import os
import time
import argparse
import tomllib
from pathlib import Path

import numpy as np
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE


# cfg["model_type"] prefix -> fusion model class. Every character after the
# prefix is parsed as one digit of the fusion depth list, e.g.
# "film12" -> MOAFWithFiLM([1, 2]). Any other model_type -> MOAFNoFusion().
_FUSION_MODELS = (
    ("film", MOAFWithFiLM),
    ("cca", MOAFWithChannelCrossAttention),
    ("se", MOAFWithSE),
)


def _stack_params(data, device):
    """Stack a batch's "mag", "na" and "rix" entries along dim 1 and move them to *device*.

    Presumably each entry is a 1-D (batch,) tensor, giving a (batch, 3) tensor
    -- TODO confirm against MOAFDataset.
    """
    return torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)


def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, show_progress=True):
    """Run one optimisation pass over *train_loader*.

    Args:
        model: network called as ``model(images, params)``.
        train_loader: iterable of dict batches with "image", "label", "mag", "na", "rix".
        epoch: zero-based epoch index (only used in the progress-bar label).
        epochs: total number of epochs (only used in the progress-bar label).
        device: device the batch tensors are moved to.
        optimizer: optimizer stepped once per batch.
        loss_fn: criterion applied to ``(outputs.squeeze(1), labels)``.
        show_progress: draw the tqdm bar; pass False on non-zero DDP ranks.

    Returns:
        The SUM of the per-batch losses seen by this process, as a Python float
        (divide by the number of batches to get the mean).
    """
    model.train()
    # Accumulate on the device: one GPU->CPU sync per epoch instead of one per
    # step (loss.item() blocks until the step has finished).
    loss_sum = torch.zeros((), dtype=torch.float64, device=device)
    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", disable=not show_progress):
        images = data["image"].to(device, non_blocking=True)
        labels = data["label"].to(device, non_blocking=True)
        params = _stack_params(data, device)
        optimizer.zero_grad()
        outputs = model(images, params)
        # squeeze(1) drops a size-1 dim 1 (presumably outputs are (batch, 1)) so the shape matches labels.
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach()
    return loss_sum.item()


def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, show_progress=True):
    """Evaluate *model* on *val_loader* without gradients.

    Arguments mirror ``train_epoch`` (minus the optimizer).

    Returns:
        The SUM of the per-batch losses seen by this process, as a Python float.
    """
    model.eval()
    loss_sum = torch.zeros((), dtype=torch.float64, device=device)
    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", disable=not show_progress):
            images = data["image"].to(device, non_blocking=True)
            labels = data["label"].to(device, non_blocking=True)
            params = _stack_params(data, device)
            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            loss_sum += loss
    return loss_sum.item()


def _global_mean(local_sum, local_count, device):
    """Return (sum of local_sum) / (sum of local_count) over all ranks using one all_reduce."""
    stats = torch.tensor([local_sum, float(local_count)], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total, count = stats.tolist()
    return total / max(count, 1.0)


def _build_loaders(dataset_dir, objective_params_list, output_type, batch_size, num_workers, rank, world_size):
    """Create the train/val datasets, their DistributedSamplers and DataLoaders for this rank.

    Returns:
        (train_sampler, val_sampler, train_loader, val_loader)
    """
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)

    # DistributedSampler gives each rank a disjoint shard. With the default
    # drop_last=False it pads by repeating samples so all ranks get the same
    # count, so a few val samples may be counted twice.
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # the sampler already shuffles/partitions; DataLoader must not
        pin_memory=True,
        # persistent_workers=True raises ValueError when num_workers == 0.
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_set, sampler=train_sampler, **loader_kwargs)
    val_loader = DataLoader(val_set, sampler=val_sampler, **loader_kwargs)
    return train_sampler, val_sampler, train_loader, val_loader


def _build_model(model_type):
    """Instantiate the model named by *model_type* (see _FUSION_MODELS); falls back to MOAFNoFusion."""
    for prefix, model_cls in _FUSION_MODELS:
        if model_type.startswith(prefix):
            fusion_depth_list = [int(ch) for ch in model_type[len(prefix):]]
            return model_cls(fusion_depth_list)
    return MOAFNoFusion()


def _warmup_cosine_lambda(warmup_epochs, epochs):
    """Build a LambdaLR multiplier.

    The multiplier ramps linearly to 1 over *warmup_epochs*, then follows a
    half-cosine decay that reaches 0 at ``epoch == epochs``.
    """
    # Clamp so that epochs == warmup_epochs does not divide by zero when the
    # final scheduler.step() evaluates the multiplier at epoch == epochs.
    decay_epochs = max(1, epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / decay_epochs))

    return lr_lambda


def fit(rank, world_size, cfg):
    """Per-process entry point run by mp.spawn; each process drives one GPU.

    Args:
        rank: global rank of this process (0 .. world_size - 1); also its CUDA device index.
        world_size: total number of processes (normally the number of GPUs used).
        cfg: configuration dict loaded from the TOML file.

    Only rank 0 writes TensorBoard logs and checkpoints. Losses are averaged over
    all ranks. Rank 0 makes the early-stopping decision and broadcasts it, so
    every rank leaves the training loop on the same epoch.
    """
    # -------- init distributed env --------
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    # os.environ values must be str; an integer port from TOML would raise TypeError.
    os.environ["MASTER_PORT"] = str(cfg.get("master_port", "29500"))
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")
    is_main = rank == 0
    if is_main:
        print_with_timestamp(f"Distributed initialized. World size: {world_size}")

    # -------- parse hyperparams from cfg --------
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    patience = int(cfg["patience"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]
    ckpt_path = Path("ckpts") / f"{model_type}_{output_type}_best_model.pt"

    # -------- datasets & distributed sampler --------
    train_sampler, val_sampler, train_loader, val_loader = _build_loaders(
        dataset_dir, objective_params_list, output_type, batch_size, num_workers, rank, world_size
    )
    # Every rank has the same loader length (DistributedSampler pads), so all ranks fail together here.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("Empty train or val loader; check dataset_dir and train_objective_params_list")
    if is_main:
        print_with_timestamp("Dataset Loaded (Distributed)")

    # -------- model creation --------
    model = _build_model(model_type).to(device)
    # Optionally warm-start from the best checkpoint previously saved by this script.
    if checkpoint_load:
        # weights_only=False lets torch.load unpickle arbitrary objects: only
        # point this at checkpoints you produced yourself.
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        if is_main:
            print_with_timestamp("Model Checkpoint Params Loaded")

    # One GPU per process, so device_ids/output_device are this rank's GPU.
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    if is_main:
        print_with_timestamp("Model Wrapped with DDP")

    # -------- loss / optimizer / scheduler --------
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmup_cosine_lambda(warmup_epochs, epochs))

    # -------- TensorBoard only on rank 0 --------
    tb_writer = None
    if is_main:
        tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
        # Trace the unwrapped module with small dummy inputs on the right device.
        dummy_images = torch.randn(5, 3, 384, 384, device=device)
        dummy_params = torch.randn(5, 3, device=device)
        tb_writer.add_graph(model.module, (dummy_images, dummy_params))

    # -------- training loop with early stopping --------
    best_val_loss = float("inf")
    patience_counter = 0
    if is_main:
        print_with_timestamp("Start training (DDP)")

    for epoch in range(epochs):
        # set_epoch must be called every epoch so the sampler's shuffle differs across epochs.
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        start_time = time.time()

        train_sum = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, show_progress=is_main)
        val_sum = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, show_progress=is_main)
        # Average over every rank's batches, so logging, checkpoint selection and
        # early stopping reflect the whole split rather than rank 0's shard.
        avg_train_loss = _global_mean(train_sum, len(train_loader), device)
        avg_val_loss = _global_mean(val_sum, len(val_loader), device)

        current_lr = optimizer.param_groups[0]["lr"]
        scheduler.step()
        epoch_time = time.time() - start_time

        stop_flag = torch.zeros(1, dtype=torch.int32, device=device)
        # Only rank 0 writes TensorBoard and checkpoints, to avoid duplicate IO.
        if is_main:
            tb_writer.add_scalar("Loss/train", avg_train_loss, epoch)
            tb_writer.add_scalar("Loss/val", avg_val_loss, epoch)
            tb_writer.add_scalar("LearningRate", current_lr, epoch)
            tb_writer.add_scalar("Time/epoch", epoch_time, epoch)
            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                save_dict = {
                    "epoch": epoch,
                    # Save the unwrapped module's weights (no "module." key prefix).
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss
                }
                ckpt_path.parent.mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, ckpt_path)
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
            else:
                patience_counter += 1
                # Stops after patience + 1 consecutive non-improving epochs.
                if patience_counter > patience:
                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
                    stop_flag.fill_(1)

        # All ranks must leave the loop together. If only rank 0 broke out, the
        # other ranks would block forever in the next DDP collective.
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item():
            break

    # -------- cleanup --------
    if tb_writer is not None:
        tb_writer.close()
    dist.destroy_process_group()
    if is_main:
        print_with_timestamp("Training completed and distributed cleaned up")


def parse_args_and_cfg():
    """Parse the command line and load the TOML config it names.

    Returns:
        (args, cfg): the argparse namespace and the config dict.
    """
    parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    parser.add_argument("--nproc_per_node", type=int, default=None,
                        help="number of processes (GPUs) to use on this node")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)
    return args, cfg


def main():
    """Work out the process count and spawn one ``fit`` process per GPU.

    Raises:
        RuntimeError: no CUDA device is visible.
        ValueError: --nproc_per_node is < 1 or exceeds the visible GPU count.
    """
    args, cfg = parse_args_and_cfg()

    # Use every visible GPU unless --nproc_per_node says otherwise.
    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.")
    nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus
    if nproc < 1:
        raise ValueError(f"nproc_per_node must be at least 1, got {nproc}")
    if nproc > available_gpus:
        raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available")

    # mp.spawn runs fit(rank, nproc, cfg) in each child, with rank = 0 .. nproc - 1.
    mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True)


if __name__ == "__main__":
    main()