diff --git a/MOAFTrain.py b/MOAFTrain.py index 8dde796..9c3d7de 100644 --- a/MOAFTrain.py +++ b/MOAFTrain.py @@ -143,7 +143,7 @@ def main(): fusion_depth_list = [int(ch) for ch in model_type[2:]] model = MOAFWithSE(fusion_depth_list).to(device) elif "mmlp" in model_type: - model = MOAFWithMMLP(fusion_depth_list).to(device) + model = MOAFWithMMLP().to(device) else: model = MOAFNoFusion().to(device) print_with_timestamp("Model Loaded")