import os
import time
import argparse
import tomllib
from pathlib import Path

import numpy as np
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP


def _unpack_batch(data, device):
    """Move one batch dict to *device* and return ``(images, labels, params)``.

    ``params`` stacks the ``mag``, ``na`` and ``rix`` entries along dim 1.
    Assuming each entry is a 1-D per-sample tensor (TODO confirm against
    MOAFDataset), this gives a ``(batch, 3)`` tensor, matching the dummy input
    used for the TensorBoard graph in ``fit``.
    """
    images = data["image"].to(device, non_blocking=True)
    labels = data["label"].to(device, non_blocking=True)
    params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
    return images, labels, params


def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
    """Run one training epoch and return the SUM of per-batch losses.

    The caller divides the result by ``len(train_loader)`` to get the mean batch
    loss. A tqdm progress bar is shown only when ``rank`` is 0 or -1.
    """
    model.train()
    train_loss = 0.0
    if rank in {-1, 0}:
        data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180)
    else:
        data_iter = train_loader
    for data in data_iter:
        images, labels, params = _unpack_batch(data, device)
        optimizer.zero_grad()
        outputs = model(images, params)
        # squeeze(1) drops the singleton output column so predictions line up with
        # labels (presumably the model outputs (batch, 1) — confirm).
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss


def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
    """Evaluate one epoch without gradients and return the SUM of per-batch losses.

    The caller divides the result by ``len(val_loader)`` to get the mean batch
    loss. A tqdm progress bar is shown only when ``rank`` is 0 or -1.
    """
    model.eval()
    val_loss = 0.0
    if rank in {-1, 0}:
        data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180)
    else:
        data_iter = val_loader
    with torch.no_grad():
        for data in data_iter:
            images, labels, params = _unpack_batch(data, device)
            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()
    return val_loss


def _make_lr_lambda(warmup_epochs, epochs):
    """Return an ``epoch -> LR multiplier`` function: linear warmup, then cosine decay.

    For ``epoch < warmup_epochs`` the multiplier is ``(epoch + 1) / warmup_epochs``.
    After that it is ``0.5 * (1 + cos(pi * (epoch - warmup_epochs) / decay))``,
    where ``decay = epochs - warmup_epochs``. The decay length is clamped to at
    least 1 so that the scheduler step after the last epoch cannot divide by
    zero when ``warmup_epochs == epochs``.
    """
    decay_epochs = max(epochs - warmup_epochs, 1)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return float(0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / decay_epochs)))

    return lr_lambda


def _build_loader(dataset, batch_size, num_workers, sampler):
    """Wrap *dataset* in a DataLoader whose ordering is driven by *sampler*."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
        sampler=sampler,
    )


def _build_model(model_type, device):
    """Instantiate the network selected by *model_type* and move it to *device*.

    For the ``film`` / ``cca`` / ``se`` variants, the characters after the prefix
    are read as single-digit fusion depths (e.g. ``"film123"`` -> ``[1, 2, 3]``).
    ``mmlp`` selects MOAFWithMMLP. Any other name falls back to MOAFNoFusion.
    """
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list)
    elif "cca" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list)
    elif "se" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list)
    elif "mmlp" in model_type:
        model = MOAFWithMMLP()
    else:
        model = MOAFNoFusion()
    return model.to(device)


def _load_checkpoint(model, ckpt_path, device, rank):
    """Load ``model_state_dict`` from *ckpt_path* into *model* if the file exists."""
    if ckpt_path.exists():
        # weights_only=False lets torch.load unpickle arbitrary objects:
        # only point this at checkpoints you produced yourself.
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        if rank == 0:
            print_with_timestamp("Model Checkpoint Params Loaded")
    elif rank == 0:
        print_with_timestamp("Model Checkpoint Params Not Exist")


def _global_mean(values, device):
    """Average a list of per-rank Python floats over all ranks.

    This is a collective call: every rank must call it with the same number of values.
    """
    buf = torch.tensor(values, dtype=torch.float64, device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    buf /= dist.get_world_size()
    return buf.tolist()


def fit(rank, world_size, cfg):
    """Train one DDP replica; launched once per GPU by ``mp.spawn`` in ``main``.

    Args:
        rank: Index of this process, also used as its CUDA device index.
        world_size: Total number of spawned processes.
        cfg: Parsed TOML config.
            Required keys: ``model_type``, ``dataset_type``, ``dataset_dir``,
            ``batch_size``, ``num_workers``, ``lr``, ``epochs``,
            ``warmup_epochs``, ``train_objective_params_list``,
            ``checkpoint_load``.
            Optional key: ``master_port`` (default ``"29500"``).

    Side effects:
        Rank 0 writes TensorBoard logs to ``runs/<model_type>_<dataset_type>``.
        Rank 0 also saves the checkpoint with the lowest validation loss
        (averaged over all ranks) to
        ``ckpts/<model_type>_<dataset_type>_best_model.pt``.
    """
    # Initialise the distributed process group (single-node, localhost rendezvous).
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    # str(): a TOML integer port would otherwise make os.environ raise TypeError.
    os.environ['MASTER_PORT'] = str(cfg.get("master_port", "29500"))
    # Bind this process to its GPU before the NCCL group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    device = torch.device(f'cuda:{rank}')
    if rank == 0:
        print_with_timestamp(f"Distributed initialized. World size: {world_size}")

    tb_writer = None
    try:
        # Hyperparameters
        model_type = cfg["model_type"]
        dataset_type = cfg["dataset_type"]
        dataset_dir = cfg["dataset_dir"]
        batch_size = int(cfg["batch_size"])
        num_workers = int(cfg["num_workers"])
        lr = float(cfg["lr"])
        epochs = int(cfg["epochs"])
        warmup_epochs = int(cfg["warmup_epochs"])
        objective_params_list = cfg["train_objective_params_list"]
        checkpoint_load = cfg["checkpoint_load"]
        ckpt_path = Path("ckpts") / f"{model_type}_{dataset_type}_best_model.pt"

        # Datasets, each sharded across ranks by a DistributedSampler.
        train_set = MOAFDataset(dataset_dir, "train", objective_params_list)
        val_set = MOAFDataset(dataset_dir, "val", objective_params_list)
        train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
        val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
        train_loader = _build_loader(train_set, batch_size, num_workers, train_sampler)
        val_loader = _build_loader(val_set, batch_size, num_workers, val_sampler)
        if rank == 0:
            print_with_timestamp("Dataset Loaded (Distributed)")

        # Model, optionally warm-started from the previous best checkpoint
        # (loaded before DDP wrapping so the plain module's keys match).
        model = _build_model(model_type, device)
        if checkpoint_load:
            _load_checkpoint(model, ckpt_path, device, rank)

        # The MMLP model is wrapped with find_unused_parameters=True; all others with False.
        model = DDP(model, device_ids=[rank], output_device=rank,
                    find_unused_parameters="mmlp" in model_type)
        if rank == 0:
            print_with_timestamp("Model Wrapped with DDP")

        # Loss, optimiser and per-epoch LR schedule
        loss_fn = nn.HuberLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_make_lr_lambda(warmup_epochs, epochs))

        # TensorBoard (rank 0 only), including a traced model graph.
        if rank == 0:
            tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{dataset_type}")
            dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
            tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))

        best_val_loss = float('inf')
        if rank == 0:
            print_with_timestamp("Start training (DDP)")
        for epoch in range(epochs):
            # set_epoch must be called every epoch so the sampler's shuffle
            # differs between epochs.
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)
            start_time = time.time()
            local_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank) / len(train_loader)
            local_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank) / len(val_loader)
            # Each rank only saw its own shard. Average across ranks so logging
            # and best-model selection reflect the whole dataset.
            avg_train_loss, avg_val_loss = _global_mean([local_train_loss, local_val_loss], device)
            current_lr = optimizer.param_groups[0]['lr']  # LR used during this epoch
            scheduler.step()
            epoch_time = time.time() - start_time

            # Only rank 0 writes TensorBoard logs and checkpoints, to avoid duplicate IO.
            if rank == 0:
                tb_writer.add_scalar('Loss/train', avg_train_loss, epoch)
                tb_writer.add_scalar('Loss/val', avg_val_loss, epoch)
                tb_writer.add_scalar('LearningRate', current_lr, epoch)
                tb_writer.add_scalar('Time/epoch', epoch_time, epoch)
                print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                     f"Train Loss: {avg_train_loss:.6f} | "
                                     f"Val Loss: {avg_val_loss:.6f} | "
                                     f"LR: {current_lr:.2e} | "
                                     f"Time: {epoch_time:.2f}s")
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    save_dict = {
                        "epoch": epoch,
                        # Save the unwrapped module's state_dict (no DDP "module." prefix).
                        "model_state_dict": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "train_loss": avg_train_loss,
                        "val_loss": avg_val_loss
                    }
                    ckpt_path.parent.mkdir(exist_ok=True, parents=True)
                    torch.save(save_dict, ckpt_path)
                    print_with_timestamp(f"New best model saved at epoch {epoch+1}")
    finally:
        # Always release the writer and the process group, even if training raised.
        if tb_writer is not None:
            tb_writer.close()
        dist.destroy_process_group()
    if rank == 0:
        print_with_timestamp("Training completed and distributed cleaned up")


def parse_args_and_cfg():
    """Parse the command line and load the TOML config it names.

    Returns:
        ``(args, cfg)``: the argparse namespace and the config dict.
    """
    parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node")
    args = parser.parse_args()
    with open(args.config, "rb") as f:  # tomllib requires a binary file handle
        cfg = tomllib.load(f)
    return args, cfg


def main():
    """Spawn one ``fit`` process per GPU (or per ``--nproc_per_node``).

    Raises:
        RuntimeError: if no CUDA device is visible.
        ValueError: if the requested process count is < 1 or exceeds the GPU count.
    """
    args, cfg = parse_args_and_cfg()
    # Use every visible GPU unless the user specified --nproc_per_node.
    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.")
    nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus
    if nproc < 1:
        raise ValueError(f"--nproc_per_node must be at least 1, got {nproc}")
    if nproc > available_gpus:
        raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available")
    # mp.spawn calls fit(rank, nproc, cfg) in each child, with rank = 0..nproc-1.
    mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True)


if __name__ == "__main__":
    main()