Add a checkpoint-file existence check before loading model params
This commit is contained in:
parent
f54d817702
commit
05eb174091
@ -156,9 +156,12 @@ def main():
|
||||
|
||||
# 形式化预训练参数加载
|
||||
if checkpoint_load:
|
||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
|
||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||
else:
|
||||
print_with_timestamp("Model Checkpoint Params Not Exist")
|
||||
|
||||
# 损失函数、优化器、学习率调度器
|
||||
loss_fn = nn.HuberLoss()
|
||||
|
||||
@ -62,7 +62,7 @@ def fit(rank, world_size, cfg):
|
||||
world_size: 进程总数(通常等于可用 GPU 数)
|
||||
cfg: 从 toml 读取的配置字典
|
||||
"""
|
||||
# -------- init distributed env --------
|
||||
# 初始化分布式参数
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
|
||||
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
|
||||
@ -71,7 +71,7 @@ def fit(rank, world_size, cfg):
|
||||
if rank == 0:
|
||||
print_with_timestamp(f"Distributed initialized. World size: {world_size}")
|
||||
|
||||
# -------- parse hyperparams from cfg --------
|
||||
# 确定超参数
|
||||
model_type = cfg["model_type"]
|
||||
output_type = cfg["output_type"]
|
||||
dataset_dir = cfg["dataset_dir"]
|
||||
@ -84,15 +84,14 @@ def fit(rank, world_size, cfg):
|
||||
objective_params_list = cfg["train_objective_params_list"]
|
||||
checkpoint_load = cfg["checkpoint_load"]
|
||||
|
||||
# -------- datasets & distributed sampler --------
|
||||
# 加载数据集
|
||||
train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
|
||||
val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
|
||||
|
||||
# DistributedSampler ensures each process sees a unique subset each epoch
|
||||
# 分布式化数据集
|
||||
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
|
||||
val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
|
||||
|
||||
# When using DistributedSampler, do not shuffle at DataLoader level
|
||||
train_loader = DataLoader(
|
||||
train_set, batch_size=batch_size, num_workers=num_workers,
|
||||
shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
|
||||
@ -105,7 +104,7 @@ def fit(rank, world_size, cfg):
|
||||
if rank == 0:
|
||||
print_with_timestamp("Dataset Loaded (Distributed)")
|
||||
|
||||
# -------- model creation --------
|
||||
# 模型选择
|
||||
if "film" in model_type:
|
||||
fusion_depth_list = [int(ch) for ch in model_type[4:]]
|
||||
model = MOAFWithFiLM(fusion_depth_list).to(device)
|
||||
@ -120,16 +119,19 @@ def fit(rank, world_size, cfg):
|
||||
|
||||
# 形式化预训练参数加载
|
||||
if checkpoint_load:
|
||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
|
||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||
else:
|
||||
print_with_timestamp("Model Checkpoint Params Not Exist")
|
||||
|
||||
# wrap with DDP. device_ids ensures single-GPU per process
|
||||
# DDP 化模型
|
||||
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
|
||||
if rank == 0:
|
||||
print_with_timestamp("Model Wrapped with DDP")
|
||||
|
||||
# -------- loss / optimizer / scheduler --------
|
||||
# 损失函数、优化器、学习率调度器
|
||||
loss_fn = nn.HuberLoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||
@ -138,7 +140,7 @@ def fit(rank, world_size, cfg):
|
||||
else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
|
||||
)
|
||||
|
||||
# -------- TensorBoard & checkpoint only on rank 0 --------
|
||||
# Tensorboard 上显示模型结构
|
||||
if rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
|
||||
# tensorboard graph: use a small dummy input placed on correct device
|
||||
@ -148,7 +150,7 @@ def fit(rank, world_size, cfg):
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
# -------- training loop with early stopping (only rank 0 saves checkpoints) --------
|
||||
# 训练
|
||||
best_val_loss = float('inf')
|
||||
patience_counter = 0
|
||||
|
||||
@ -201,7 +203,7 @@ def fit(rank, world_size, cfg):
|
||||
print_with_timestamp(f"Early stopping at {epoch+1} epochs")
|
||||
break
|
||||
|
||||
# cleanup
|
||||
# 清除进程
|
||||
if tb_writer is not None:
|
||||
tb_writer.close()
|
||||
dist.destroy_process_group()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user