diff --git a/MOAFTrain.py b/MOAFTrain.py
index c5d6578..8be46a0 100644
--- a/MOAFTrain.py
+++ b/MOAFTrain.py
@@ -156,9 +156,12 @@ def main():
 
     # Load pretrained checkpoint parameters
     if checkpoint_load:
-        checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
-        model.load_state_dict(checkpoint["model_state_dict"])
-        print_with_timestamp("Model Checkpoint Params Loaded")
+        if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
+            checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
+            model.load_state_dict(checkpoint["model_state_dict"])
+            print_with_timestamp("Model Checkpoint Params Loaded")
+        else:
+            print_with_timestamp("Model Checkpoint Params Not Exist")
 
     # Loss function, optimizer, and LR scheduler
     loss_fn = nn.HuberLoss()
diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py
index 8ac7778..da515d5 100644
--- a/MOAFTrainDDP.py
+++ b/MOAFTrainDDP.py
@@ -62,7 +62,7 @@ def fit(rank, world_size, cfg):
         world_size: total number of processes (usually the number of available GPUs)
         cfg: configuration dict read from the TOML file
     """
-    # -------- init distributed env --------
+    # Initialize the distributed environment
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
@@ -71,7 +71,7 @@ def fit(rank, world_size, cfg):
    if rank == 0:
        print_with_timestamp(f"Distributed initialized. World size: {world_size}")
 
-    # -------- parse hyperparams from cfg --------
+    # Set hyperparameters
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
@@ -84,15 +84,14 @@ def fit(rank, world_size, cfg):
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]
 
-    # -------- datasets & distributed sampler --------
+    # Load the datasets
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
 
-    # DistributedSampler ensures each process sees a unique subset each epoch
+    # Wrap the datasets with distributed samplers
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
 
-    # When using DistributedSampler, do not shuffle at DataLoader level
    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers, shuffle=False,
        pin_memory=True, persistent_workers=True, sampler=train_sampler
@@ -105,7 +104,7 @@ def fit(rank, world_size, cfg):
    if rank == 0:
        print_with_timestamp("Dataset Loaded (Distributed)")
 
-    # -------- model creation --------
+    # Model selection
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
@@ -120,16 +119,19 @@ def fit(rank, world_size, cfg):
    # Load pretrained checkpoint parameters
    if checkpoint_load:
-        checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
-        model.load_state_dict(checkpoint["model_state_dict"])
-        print_with_timestamp("Model Checkpoint Params Loaded")
+        if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
+            checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
+            model.load_state_dict(checkpoint["model_state_dict"])
+            print_with_timestamp("Model Checkpoint Params Loaded")
+        else:
+            print_with_timestamp("Model Checkpoint Params Not Exist")
 
-    # wrap with DDP. device_ids ensures single-GPU per process
+    # Wrap the model with DDP
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
 
    if rank == 0:
        print_with_timestamp("Model Wrapped with DDP")
 
-    # -------- loss / optimizer / scheduler --------
+    # Loss function, optimizer, and LR scheduler
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -138,7 +140,7 @@ def fit(rank, world_size, cfg):
        lr_lambda=lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    )
-    # -------- TensorBoard & checkpoint only on rank 0 --------
+    # Show the model structure on TensorBoard
    if rank == 0:
        tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
        # tensorboard graph: use a small dummy input placed on correct device
@@ -148,7 +150,7 @@ def fit(rank, world_size, cfg):
    else:
        tb_writer = None
 
-    # -------- training loop with early stopping (only rank 0 saves checkpoints) --------
+    # Training
    best_val_loss = float('inf')
    patience_counter = 0
 
@@ -201,7 +203,7 @@ def fit(rank, world_size, cfg):
            print_with_timestamp(f"Early stopping at {epoch+1} epochs")
            break
 
-    # cleanup
+    # Clean up processes
    if tb_writer is not None:
        tb_writer.close()
    dist.destroy_process_group()
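
Note: for context, below is a minimal launch sketch for MOAFTrainDDP.py as the diff assumes it is used, i.e. fit(rank, world_size, cfg) spawned once per visible GPU with cfg read from a TOML file. The launcher file, the config.toml name, and the tomllib usage are illustrative assumptions and are not part of this patch.

# launch_ddp.py - illustrative only; file and config names are assumptions
import tomllib                      # stdlib TOML reader (Python 3.11+)

import torch
import torch.multiprocessing as mp

from MOAFTrainDDP import fit        # fit(rank, world_size, cfg) from this patch

if __name__ == "__main__":
    # Read the training configuration (model_type, output_type, lr, ...) from TOML.
    with open("config.toml", "rb") as f:
        cfg = tomllib.load(f)

    # One process per available GPU; mp.spawn passes the process index as the first
    # argument, which fit() uses as the rank.
    world_size = torch.cuda.device_count()
    mp.spawn(fit, args=(world_size, cfg), nprocs=world_size, join=True)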