diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py index 6cf7258..cc18e8e 100644 --- a/MOAFTrainDDP.py +++ b/MOAFTrainDDP.py @@ -119,9 +119,11 @@ def fit(rank, world_size, cfg): if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists(): checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) model.load_state_dict(checkpoint["model_state_dict"]) - print_with_timestamp("Model Checkpoint Params Loaded") + if rank == 0: + print_with_timestamp("Model Checkpoint Params Loaded") else: - print_with_timestamp("Model Checkpoint Params Not Exist") + if rank == 0: + print_with_timestamp("Model Checkpoint Params Not Exist") # DDP 化模型 model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)