diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py index bf95fb8..c1d83ff 100644 --- a/MOAFTrainDDP.py +++ b/MOAFTrainDDP.py @@ -112,7 +112,7 @@ def fit(rank, world_size, cfg): fusion_depth_list = [int(ch) for ch in model_type[2:]] model = MOAFWithSE(fusion_depth_list).to(device) elif "mmlp" in model_type: - model = MOAFWithMMLP(fusion_depth_list).to(device) + model = MOAFWithMMLP().to(device) else: model = MOAFNoFusion().to(device)