add model params loaded condition
commit 05eb174091
parent f54d817702
@@ -156,9 +156,12 @@ def main():
 
     # Load pretrained checkpoint parameters
     if checkpoint_load:
-        checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
-        model.load_state_dict(checkpoint["model_state_dict"])
-        print_with_timestamp("Model Checkpoint Params Loaded")
+        if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
+            checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
+            model.load_state_dict(checkpoint["model_state_dict"])
+            print_with_timestamp("Model Checkpoint Params Loaded")
+        else:
+            print_with_timestamp("Model Checkpoint Params Not Exist")
 
     # Loss function, optimizer, learning-rate scheduler
     loss_fn = nn.HuberLoss()
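For context, the guarded load introduced by this hunk boils down to the standalone sketch below. The placeholder model and the concrete checkpoint file name are assumptions; only the ckpts/{model_type}_{output_type}_best_model.pt pattern comes from the diff.

from pathlib import Path

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 1).to(device)                  # placeholder model
ckpt_path = Path("ckpts/example_best_model.pt")     # hypothetical file name

if ckpt_path.exists():
    # Same call as in the diff; weights_only=False permits non-tensor objects in the checkpoint.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    print("Model Checkpoint Params Loaded")
else:
    print("Model Checkpoint Params Not Exist")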
@@ -62,7 +62,7 @@ def fit(rank, world_size, cfg):
         world_size: total number of processes (usually equal to the number of available GPUs)
         cfg: config dict read from the toml file
     """
-    # -------- init distributed env --------
+    # Initialize the distributed environment
    os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
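This hunk only shows the per-process side of the setup; how fit is launched is outside the diff. A minimal launcher consistent with this signature, one process per visible GPU, could look like the following sketch. The stub fit body and the literal cfg are illustrative, not the project's code.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def fit(rank: int, world_size: int, cfg: dict) -> None:
    # Stub standing in for the real fit(): set the rendezvous address, join the group, tear down.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # mp.spawn passes the process index as the first argument (rank).
    mp.spawn(fit, args=(world_size, {"master_port": "29500"}), nprocs=world_size, join=True)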
@@ -71,7 +71,7 @@ def fit(rank, world_size, cfg):
     if rank == 0:
         print_with_timestamp(f"Distributed initialized. World size: {world_size}")
 
-    # -------- parse hyperparams from cfg --------
+    # Determine the hyperparameters
     model_type = cfg["model_type"]
     output_type = cfg["output_type"]
     dataset_dir = cfg["dataset_dir"]
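The docstring states that cfg is read from a toml file. A minimal reader covering just the keys this diff references could look like the sketch below; the config.toml path is an assumption.

import tomllib  # standard library in Python 3.11+; the tomli package offers the same API earlier

with open("config.toml", "rb") as f:   # hypothetical path
    cfg = tomllib.load(f)

model_type = cfg["model_type"]
output_type = cfg["output_type"]
dataset_dir = cfg["dataset_dir"]
objective_params_list = cfg["train_objective_params_list"]
checkpoint_load = cfg["checkpoint_load"]
master_port = cfg.get("master_port", "29500")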
@@ -84,15 +84,14 @@ def fit(rank, world_size, cfg):
     objective_params_list = cfg["train_objective_params_list"]
     checkpoint_load = cfg["checkpoint_load"]
 
-    # -------- datasets & distributed sampler --------
+    # Load the datasets
     train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
     val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
 
-    # DistributedSampler ensures each process sees a unique subset each epoch
+    # Make the datasets distributed
     train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
     val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
 
-    # When using DistributedSampler, do not shuffle at DataLoader level
     train_loader = DataLoader(
         train_set, batch_size=batch_size, num_workers=num_workers,
         shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
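The removed English comment carried a real constraint: with a DistributedSampler attached, shuffling happens in the sampler, so the DataLoader keeps shuffle=False and the sampler should be reseeded every epoch. A toy illustration follows; TensorDataset stands in for MOAFDataset, and rank/world_size are hard-coded so the sketch runs without a process group.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

train_set = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))  # stand-in for MOAFDataset

# Explicit num_replicas/rank, so no initialized process group is required here.
rank, world_size = 0, 2
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)

# shuffle=False because the sampler already shuffles its shard of the data.
train_loader = DataLoader(train_set, batch_size=8, sampler=train_sampler, shuffle=False)

for epoch in range(2):
    train_sampler.set_epoch(epoch)  # reseed so each epoch sees a different order
    for x, y in train_loader:
        pass  # training step would go here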
@@ -105,7 +104,7 @@ def fit(rank, world_size, cfg):
     if rank == 0:
         print_with_timestamp("Dataset Loaded (Distributed)")
 
-    # -------- model creation --------
+    # Select the model
     if "film" in model_type:
         fusion_depth_list = [int(ch) for ch in model_type[4:]]
         model = MOAFWithFiLM(fusion_depth_list).to(device)
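Easy to miss in the film branch: the characters after the 4-letter prefix are parsed one digit at a time into fusion depths. A worked example with a made-up model_type value:

model_type = "film124"  # hypothetical value, not from the config

if "film" in model_type:
    # Each character after "film" becomes one fusion depth.
    fusion_depth_list = [int(ch) for ch in model_type[4:]]
    print(fusion_depth_list)  # [1, 2, 4]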
@@ -120,16 +119,19 @@ def fit(rank, world_size, cfg):
 
     # Load pretrained checkpoint parameters
     if checkpoint_load:
-        checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
-        model.load_state_dict(checkpoint["model_state_dict"])
-        print_with_timestamp("Model Checkpoint Params Loaded")
+        if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
+            checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
+            model.load_state_dict(checkpoint["model_state_dict"])
+            print_with_timestamp("Model Checkpoint Params Loaded")
+        else:
+            print_with_timestamp("Model Checkpoint Params Not Exist")
 
-    # wrap with DDP. device_ids ensures single-GPU per process
+    # Wrap the model with DDP
     model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
     if rank == 0:
         print_with_timestamp("Model Wrapped with DDP")
 
-    # -------- loss / optimizer / scheduler --------
+    # Loss function, optimizer, learning-rate scheduler
     loss_fn = nn.HuberLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
     scheduler = torch.optim.lr_scheduler.LambdaLR(
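One assumption baked into this ordering: the checkpoint is loaded into the bare model before the DDP wrap, so model_state_dict is expected to hold unprefixed keys. If a checkpoint were ever saved from the DDP-wrapped model instead, its keys would carry a module. prefix and need stripping, roughly as in this hypothetical helper (not part of the diff):

import torch
import torch.nn as nn

def load_unwrapped(model: nn.Module, checkpoint: dict) -> None:
    # Strip a leading "module." (left by DDP) from every key before loading.
    state_dict = {k.removeprefix("module."): v
                  for k, v in checkpoint["model_state_dict"].items()}
    model.load_state_dict(state_dict)

# Usage sketch with a toy model and a DDP-style checkpoint:
model = nn.Linear(4, 1)
checkpoint = {"model_state_dict": {"module.weight": torch.zeros(1, 4),
                                   "module.bias": torch.zeros(1)}}
load_unwrapped(model, checkpoint)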
@@ -138,7 +140,7 @@ def fit(rank, world_size, cfg):
         else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
     )
 
-    # -------- TensorBoard & checkpoint only on rank 0 --------
+    # Show the model graph in TensorBoard
     if rank == 0:
         tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
         # tensorboard graph: use a small dummy input placed on correct device
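Only the cosine branch of the lr_lambda is visible in this hunk. A self-contained version of the warmup-then-cosine schedule it implies is sketched below; the linear-warmup branch and all hyperparameter values are assumptions.

import numpy as np
import torch
import torch.nn as nn

lr, epochs, warmup_epochs = 1e-3, 100, 5   # placeholder hyperparameters
model = nn.Linear(4, 1)                    # placeholder model

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # Multiplier on the base lr: assumed linear warmup, then the cosine decay shown in the diff.
    lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs
    if epoch < warmup_epochs
    else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
)

for epoch in range(epochs):
    optimizer.step()    # stands in for a full training epoch
    scheduler.step()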
@@ -148,7 +150,7 @@ def fit(rank, world_size, cfg):
     else:
         tb_writer = None
 
-    # -------- training loop with early stopping (only rank 0 saves checkpoints) --------
+    # Training
     best_val_loss = float('inf')
     patience_counter = 0
 
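best_val_loss and patience_counter are only initialized in this hunk; the early-stopping bookkeeping they feed (and the rank-0 checkpoint save mentioned in the old comment) presumably follows the usual pattern, sketched here with made-up patience and loss values:

best_val_loss = float('inf')
patience_counter = 0
patience = 10  # hypothetical patience

for epoch, val_loss in enumerate([0.9, 0.7, 0.72, 0.71]):  # stand-in validation losses
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # rank 0 would save ckpts/{model_type}_{output_type}_best_model.pt here
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at {epoch+1} epochs")
            break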
@@ -201,7 +203,7 @@ def fit(rank, world_size, cfg):
             print_with_timestamp(f"Early stopping at {epoch+1} epochs")
             break
 
-    # cleanup
+    # Clean up the processes
     if tb_writer is not None:
         tb_writer.close()
     dist.destroy_process_group()