From 9f3ef716fbcab4a024161525ec0471d5e4d92c03 Mon Sep 17 00:00:00 2001 From: kaiza_hikaru Date: Mon, 3 Nov 2025 09:57:54 +0800 Subject: [PATCH] fix mmlp model --- MOAFModels.py | 61 ++++++++++++++++++++++--------------------------- MOAFTrainDDP.py | 5 +++- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/MOAFModels.py b/MOAFModels.py index b102a32..c426ba1 100644 --- a/MOAFModels.py +++ b/MOAFModels.py @@ -260,6 +260,18 @@ class MOAFWithMMLP(nn.Module): self._create_regressor_head() for _ in range(num_lenses) ]) + # 注册物镜参数基准张量(不可学习) + lens_params_base = torch.tensor([ + [10, 0.25, 1.0000], # obj1 + [10, 0.30, 1.0000], # obj2 + [20, 0.70, 1.0000], # obj3 + [20, 0.80, 1.0000], # obj4 + [40, 0.65, 1.0000], # obj5 + [100, 0.80, 1.0000], # obj6 + [100, 1.25, 1.4730] # obj7 + ], dtype=torch.float32) + self.register_buffer('lens_params_base', lens_params_base) + def _create_regressor_head(self): return nn.Sequential( nn.Flatten(), @@ -269,48 +281,29 @@ class MOAFWithMMLP(nn.Module): ) def _find_lens_id(self, params): - batch_size = params.shape[0] - lens_ids = [] + # 扩展维度以进行广播计算 + params_expanded = params.unsqueeze(1) # [batch_size, 1, 3] + base_expanded = self.lens_params_base.unsqueeze(0) # [1, 7, 3] - for i in range(batch_size): - mag, na, rix = params[i][0].item(), params[i][1].item(), params[i][2].item() - - # 直接的条件匹配(类似于switch-case) - if mag == 10 and na == 0.25 and rix == 1.0000: - lens_ids.append(0) # obj1 - elif mag == 10 and na == 0.30 and rix == 1.0000: - lens_ids.append(1) # obj2 - elif mag == 20 and na == 0.70 and rix == 1.0000: - lens_ids.append(2) # obj3 - elif mag == 20 and na == 0.80 and rix == 1.0000: - lens_ids.append(3) # obj4 - elif mag == 40 and na == 0.65 and rix == 1.0000: - lens_ids.append(4) # obj5 - elif mag == 100 and na == 0.80 and rix == 1.0000: - lens_ids.append(5) # obj6 - elif mag == 100 and na == 1.25 and rix == 1.4730: - lens_ids.append(6) # obj7 - else: - lens_ids.append(0) + # 计算欧氏距离(完全向量化) + distances = torch.sqrt(torch.sum((params_expanded - base_expanded) ** 2, dim=2)) # [batch_size, 7] - return torch.tensor(lens_ids, dtype=torch.long, device=params.device) + # 找到最小距离的索引 + lens_ids = torch.argmin(distances, dim=1) # [batch_size] + return lens_ids def forward(self, image, params): x = self.features(image) x = self.avgpool(x) lens_ids = self._find_lens_id(params) - batch_size = params.size(0) - outputs = [] + lens_ids = torch.clamp(lens_ids, 0, self.num_lenses - 1) + batch_size = x.size(0) + all_outputs = [] for i in range(batch_size): - current_lens_id = lens_ids[i] - # 确保lens_id在有效范围内 - if current_lens_id < 0 or current_lens_id >= self.num_lenses: - current_lens_id = 0 - - # 选择对应的回归头 - head_output = self.regressors[current_lens_id](x[i].unsqueeze(0)) - outputs.append(head_output) + head_idx = lens_ids[i] + output = self.regressors[head_idx](x[i].unsqueeze(0)) + all_outputs.append(output) - return torch.cat(outputs, dim=0) + return torch.cat(all_outputs, dim=0) diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py index c1d83ff..e0f2fb2 100644 --- a/MOAFTrainDDP.py +++ b/MOAFTrainDDP.py @@ -128,7 +128,10 @@ def fit(rank, world_size, cfg): print_with_timestamp("Model Checkpoint Params Not Exist") # DDP 化模型 - model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False) + if "mmlp" in model_type: + model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True) + else: + model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False) if rank == 0: print_with_timestamp("Model Wrapped with DDP")