add model params loaded condition

kaiza_hikaru 2025-10-23 16:57:08 +08:00
parent f54d817702
commit 05eb174091
2 changed files with 22 additions and 17 deletions


@@ -156,9 +156,12 @@ def main():
# Load pretrained checkpoint parameters
if checkpoint_load:
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print_with_timestamp("Model Checkpoint Params Loaded")
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print_with_timestamp("Model Checkpoint Params Loaded")
else:
print_with_timestamp("Model Checkpoint Params Not Exist")
# Loss function, optimizer, and learning-rate scheduler
loss_fn = nn.HuberLoss()
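
For reference, the guard added in this hunk boils down to checking that the checkpoint file exists before loading it. A minimal standalone sketch of that pattern, with a placeholder nn.Linear model and a hypothetical checkpoint path standing in for the project's model and ckpts/ naming:

from pathlib import Path

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 1).to(device)  # placeholder for the real model

ckpt_path = Path("ckpts/example_best_model.pt")  # hypothetical path for illustration
if ckpt_path.exists():
    # Load only when the file is present; otherwise fall through with a notice
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    print("Model Checkpoint Params Loaded")
else:
    print("Model Checkpoint Params Not Exist")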


@@ -62,7 +62,7 @@ def fit(rank, world_size, cfg):
world_size: total number of processes, usually equal to the number of available GPUs
cfg: configuration dict read from the toml file
"""
# -------- init distributed env --------
# Initialize distributed parameters
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
@@ -71,7 +71,7 @@ def fit(rank, world_size, cfg):
if rank == 0:
print_with_timestamp(f"Distributed initialized. World size: {world_size}")
# -------- parse hyperparams from cfg --------
# Determine hyperparameters
model_type = cfg["model_type"]
output_type = cfg["output_type"]
dataset_dir = cfg["dataset_dir"]
@@ -84,15 +84,14 @@ def fit(rank, world_size, cfg):
objective_params_list = cfg["train_objective_params_list"]
checkpoint_load = cfg["checkpoint_load"]
# -------- datasets & distributed sampler --------
# Load the datasets
train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
# DistributedSampler ensures each process sees a unique subset each epoch
# Distribute the datasets
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
# When using DistributedSampler, do not shuffle at DataLoader level
train_loader = DataLoader(
train_set, batch_size=batch_size, num_workers=num_workers,
shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
@@ -105,7 +104,7 @@ def fit(rank, world_size, cfg):
if rank == 0:
print_with_timestamp("Dataset Loaded (Distributed)")
# -------- model creation --------
# Model selection
if "film" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[4:]]
model = MOAFWithFiLM(fusion_depth_list).to(device)
@@ -120,16 +119,19 @@ def fit(rank, world_size, cfg):
# Load pretrained checkpoint parameters
if checkpoint_load:
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print_with_timestamp("Model Checkpoint Params Loaded")
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print_with_timestamp("Model Checkpoint Params Loaded")
else:
print_with_timestamp("Model Checkpoint Params Not Exist")
# wrap with DDP. device_ids ensures single-GPU per process
# Wrap the model with DDP
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
if rank == 0:
print_with_timestamp("Model Wrapped with DDP")
# -------- loss / optimizer / scheduler --------
# Loss function, optimizer, and learning-rate scheduler
loss_fn = nn.HuberLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -138,7 +140,7 @@ def fit(rank, world_size, cfg):
else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
)
# -------- TensorBoard & checkpoint only on rank 0 --------
# Show the model structure in TensorBoard
if rank == 0:
tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
# tensorboard graph: use a small dummy input placed on correct device
@@ -148,7 +150,7 @@ def fit(rank, world_size, cfg):
else:
tb_writer = None
# -------- training loop with early stopping (only rank 0 saves checkpoints) --------
# Training
best_val_loss = float('inf')
patience_counter = 0
@@ -201,7 +203,7 @@ def fit(rank, world_size, cfg):
print_with_timestamp(f"Early stopping at {epoch+1} epochs")
break
# cleanup
# Clean up processes
if tb_writer is not None:
tb_writer.close()
dist.destroy_process_group()
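
On the sampling side, fit() above pairs DistributedSampler with a DataLoader that does not shuffle itself, since each sampler already shuffles its per-rank shard. A minimal single-process sketch of that pairing, with a toy TensorDataset in place of MOAFDataset and explicit num_replicas/rank standing in for an initialized process group:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset standing in for MOAFDataset
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))

# Explicit num_replicas/rank let the sketch run without dist.init_process_group
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

# shuffle=False at the DataLoader level: the sampler handles shuffling
loader = DataLoader(dataset, batch_size=8, shuffle=False, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seed so each epoch draws a different ordering
    for x, y in loader:
        pass  # forward/backward would run here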