fix rank print
This commit is contained in:
parent
011eae0107
commit
4c3d25af96
@ -119,9 +119,11 @@ def fit(rank, world_size, cfg):
|
||||
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
|
||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||
if rank == 0:
|
||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||
else:
|
||||
print_with_timestamp("Model Checkpoint Params Not Exist")
|
||||
if rank == 0:
|
||||
print_with_timestamp("Model Checkpoint Params Not Exist")
|
||||
|
||||
# DDP 化模型
|
||||
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user