fix rank print

This commit is contained in:
kaiza_hikaru 2025-11-01 16:39:04 +08:00
parent 011eae0107
commit 4c3d25af96

View File

@ -119,8 +119,10 @@ def fit(rank, world_size, cfg):
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
if rank == 0:
print_with_timestamp("Model Checkpoint Params Loaded")
else:
if rank == 0:
print_with_timestamp("Model Checkpoint Params Not Exist")
# DDP 化模型