"""Train DPNet from a TOML configuration file.

The script creates a timestamped run directory in the current working
directory, copies the config into it, trains with MSE loss using Adam and a
per-step warmup + cosine learning-rate schedule, logs to a file, the console
and TensorBoard, and saves the last and best (lowest validation loss)
model weights.
"""

import argparse
import logging
import shutil
import time
import tomllib
from datetime import datetime
from pathlib import Path
from pprint import pformat

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from models import DPNet
import old_datasets as dataset_F
import utils

# Progress-bar layout shared by the training and validation loops.
_BAR_FORMAT = "{l_bar}{bar:20}{r_bar}"

# Config keys that must be present; checked in this order before any output
# directory is created.
_REQUIRED_CONFIG_KEYS = (
    "xlsx_files",
    "batch_size",
    "learning_rate",
    "epochs",
    "num_workers",
    "seed",
    "init_weight",
)

# Upper bound on the number of linear warmup steps.
_MAX_WARMUP_STEPS = 2000


def _mean_loss(running_loss, sample_count, phase):
    """Return ``running_loss / sample_count`` or raise if nothing was seen.

    Raises:
        ValueError: If ``sample_count`` is zero, i.e. the loader yielded no
            samples, instead of an opaque ``ZeroDivisionError``.
    """
    if sample_count == 0:
        raise ValueError(
            f"{phase} loader produced no samples; cannot compute the epoch loss"
        )
    return running_loss / sample_count


def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, writer, global_step):
    """Run one training epoch.

    The scheduler is stepped once per batch, and the optimizer's first
    param-group learning rate is written to ``writer`` under the ``"lr"`` tag
    after every step.

    Args:
        model: Network to train; put into ``train()`` mode here.
        train_loader: Iterable yielding ``(img, label, image_name)`` batches.
        criterion: Loss callable applied to flattened outputs and labels.
            The epoch loss weights each batch loss by its batch size, which
            assumes the criterion returns a per-batch mean.
        optimizer: Optimizer updating ``model``.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the batches are moved to.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        global_step: Number of optimizer steps taken before this epoch.

    Returns:
        Tuple ``(epoch_loss, global_step)``: the sample-weighted mean loss
        and the updated step counter.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()

    running_loss = 0.0
    sample_count = 0

    for img, label, _ in tqdm(train_loader, desc="Train", bar_format=_BAR_FORMAT):
        img = img.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True).float().view(-1)

        optimizer.zero_grad(set_to_none=True)
        output = model(img).view(-1)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        scheduler.step()

        global_step += 1
        lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("lr", lr, global_step)

        batch_size = img.size(0)
        running_loss += loss.item() * batch_size
        sample_count += batch_size

    epoch_loss = _mean_loss(running_loss, sample_count, "Train")
    return epoch_loss, global_step


@torch.no_grad()
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; put into ``eval()`` mode here.
        val_loader: Iterable yielding ``(img, label, image_name)`` batches.
        criterion: Loss callable applied to flattened outputs and labels.
        device: Device the batches are moved to.

    Returns:
        The sample-weighted mean loss over the loader.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()

    running_loss = 0.0
    sample_count = 0

    for img, label, _ in tqdm(val_loader, desc="Validate", bar_format=_BAR_FORMAT):
        img = img.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True).float().view(-1)

        output = model(img).view(-1)
        loss = criterion(output, label)

        batch_size = img.size(0)
        running_loss += loss.item() * batch_size
        sample_count += batch_size

    return _mean_loss(running_loss, sample_count, "Validation")


def _build_logger(run_dir):
    """Create the ``dpnet_train`` logger writing to ``train.log`` and the console."""
    logger = logging.getLogger("dpnet_train")
    logger.setLevel(logging.INFO)
    logger.propagate = False
    # Drop handlers left over from a previous call so lines are not duplicated.
    logger.handlers.clear()

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    file_handler = logging.FileHandler(run_dir / "train.log", encoding="utf-8")
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def _build_scheduler(optimizer, total_steps):
    """Build the per-step warmup + cosine schedule; return ``(scheduler, warmup_steps)``."""
    # Keep at least one step for the cosine phase whenever there is more than
    # one step in total.
    warmup_steps = min(_MAX_WARMUP_STEPS, total_steps - 1) if total_steps > 1 else 0

    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(
                optimizer,
                start_factor=0.1,
                end_factor=1.0,
                total_iters=warmup_steps,
            ),
            CosineAnnealingLR(
                optimizer,
                T_max=max(1, total_steps - warmup_steps),
                eta_min=0.0,
            ),
        ],
        milestones=[warmup_steps],
    )
    return scheduler, warmup_steps


def _run_training(settings, run_dir, logger, writer):
    """Build data, model and optimizer from ``settings`` and run all epochs.

    Args:
        settings: Mapping holding every key in ``_REQUIRED_CONFIG_KEYS``.
        run_dir: Directory receiving ``last_dpn.pth`` and ``best_dpn.pth``.
        logger: Logger used for progress messages.
        writer: TensorBoard writer; the caller is responsible for closing it.

    Raises:
        ValueError: If the training or validation dataset is empty.
    """
    xlsx_files = settings["xlsx_files"]
    batch_size = settings["batch_size"]
    learning_rate = settings["learning_rate"]
    epochs = settings["epochs"]
    num_workers = settings["num_workers"]
    seed = settings["seed"]
    init_weight = settings["init_weight"]

    # =========================
    # 4. Random seed and device
    # =========================
    utils.set_seeds(seed=seed)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    logger.info(f"Using device: {device}")

    # =========================
    # 5. Datasets and DataLoaders
    # =========================
    train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list = \
        dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files)

    train_transform = transforms.Compose([
        transforms.ColorJitter(
            brightness=(0.9, 1.4),
            contrast=(0.8, 1.5),
            saturation=(0.8, 1.5),
        ),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = dataset_F.MyDataset(
        train_image_path_list,
        train_defocus_distance_list,
        train_transform,
    )
    val_dataset = dataset_F.MyDataset(
        val_image_path_list,
        val_defocus_distance_list,
        val_transform,
    )

    # Fail fast: an empty split would otherwise only surface as a division by
    # zero when the epoch loss is averaged (for validation, after a whole
    # training epoch has already run).
    if len(train_dataset) == 0:
        raise ValueError("Training dataset is empty; check 'xlsx_files' in the config")
    if len(val_dataset) == 0:
        raise ValueError("Validation dataset is empty; check 'xlsx_files' in the config")

    use_pinned_memory = device.type == "cuda"
    keep_workers_alive = num_workers > 0

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_pinned_memory,
        persistent_workers=keep_workers_alive,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_pinned_memory,
        persistent_workers=keep_workers_alive,
    )

    logger.info(f"Train dataset size: {len(train_dataset)}")
    logger.info(f"Val dataset size: {len(val_dataset)}")
    logger.info(f"Train steps per epoch: {len(train_loader)}")

    # =========================
    # 6. Model
    # =========================
    model = DPNet().to(device)

    if init_weight:
        # NOTE(review): torch.load unpickles the file, so only point
        # 'init_weight' at checkpoints from a trusted source. Consider passing
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(init_weight, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        logger.info(f"Loaded init weight from: {init_weight}")
    else:
        logger.info("Training from scratch")

    # =========================
    # 7. Loss, optimizer, scheduler
    # =========================
    criterion = nn.MSELoss(reduction="mean")
    optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

    total_steps = epochs * len(train_loader)
    scheduler, warmup_steps = _build_scheduler(optimizer, total_steps)

    logger.info("Loss: MSELoss(reduction='mean')")
    logger.info(f"Optimizer: Adam(lr={learning_rate}, betas=(0.9, 0.999))")
    logger.info("Scheduler: step-based warmup + cosine annealing")
    logger.info(f"Total training steps: {total_steps}")
    logger.info(f"Warmup steps: {warmup_steps}")

    # =========================
    # 8. Training loop
    # =========================
    best_val_loss = float("inf")
    global_step = 0
    start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        train_loss, global_step = train_epoch(
            model=model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            writer=writer,
            global_step=global_step,
        )

        val_loss = validate_epoch(
            model=model,
            val_loader=val_loader,
            criterion=criterion,
            device=device,
        )

        epoch_total_time = time.time() - epoch_start_time

        writer.add_scalar("train_loss", train_loss, epoch)
        writer.add_scalar("val_loss", val_loss, epoch)
        writer.add_scalar("epoch_total_time", epoch_total_time, epoch)

        logger.info(
            f"Epoch [{epoch}/{epochs}] | "
            f"train_loss={train_loss:.8f} | "
            f"val_loss={val_loss:.8f} | "
            f"epoch_total_time={epoch_total_time:.2f}s | "
            f"global_step={global_step}"
        )

        # Always keep the most recent weights; additionally keep the best.
        torch.save(model.state_dict(), run_dir / "last_dpn.pth")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), run_dir / "best_dpn.pth")
            logger.info(f"Best model updated, best_val_loss={best_val_loss:.8f}")

    # =========================
    # 9. Wrap-up
    # =========================
    total_time = time.time() - start_time
    logger.info("Training finished")
    logger.info(f"Best validation loss: {best_val_loss:.8f}")
    logger.info(f"Total time: {total_time:.2f} seconds")


def main():
    """Parse ``--config``, prepare the run directory and train DPNet.

    Raises:
        KeyError: If a required key is missing from the config file.
        FileExistsError: If the timestamped run directory already exists.
    """
    # =========================
    # 1. Load configuration
    # =========================
    parser = argparse.ArgumentParser(description="Train DPNet")
    parser.add_argument("--config", type=str, required=True, help="Path to TOML config file")
    args = parser.parse_args()

    config_path = Path(args.config)
    with config_path.open("rb") as f:
        config = tomllib.load(f)

    # Look up every required key now so a bad config fails before any
    # directory or log file is created.
    settings = {key: config[key] for key in _REQUIRED_CONFIG_KEYS}

    # =========================
    # 2. Create output directory
    # =========================
    run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn")
    run_dir = Path.cwd() / run_name
    run_dir.mkdir(parents=True, exist_ok=False)

    shutil.copy2(config_path, run_dir / config_path.name)

    # =========================
    # 3. Initialise logging and TensorBoard
    # =========================
    logger = _build_logger(run_dir)
    writer = SummaryWriter(log_dir=str(run_dir))

    try:
        logger.info(f"Run directory: {run_dir}")
        logger.info(f"Config path: {config_path}")
        logger.info(f"Loaded config:\n{pformat(config, sort_dicts=False)}")

        _run_training(settings, run_dir, logger, writer)
    except Exception:
        # Record the traceback in train.log before propagating the failure.
        logger.exception("Training failed")
        raise
    finally:
        # Close the writer even when training fails or is interrupted so
        # pending TensorBoard events are flushed.
        writer.close()


if __name__ == "__main__":
    main()