fix mmlp model

This commit is contained in:
kaiza_hikaru 2025-11-03 09:57:54 +08:00
parent 955083968e
commit 9f3ef716fb
2 changed files with 31 additions and 35 deletions

View File

@@ -260,6 +260,18 @@ class MOAFWithMMLP(nn.Module):
self._create_regressor_head() for _ in range(num_lenses) self._create_regressor_head() for _ in range(num_lenses)
]) ])
# 注册物镜参数基准张量(不可学习)
lens_params_base = torch.tensor([
[10, 0.25, 1.0000], # obj1
[10, 0.30, 1.0000], # obj2
[20, 0.70, 1.0000], # obj3
[20, 0.80, 1.0000], # obj4
[40, 0.65, 1.0000], # obj5
[100, 0.80, 1.0000], # obj6
[100, 1.25, 1.4730] # obj7
], dtype=torch.float32)
self.register_buffer('lens_params_base', lens_params_base)
def _create_regressor_head(self): def _create_regressor_head(self):
return nn.Sequential( return nn.Sequential(
nn.Flatten(), nn.Flatten(),
@@ -269,48 +281,29 @@ class MOAFWithMMLP(nn.Module):
) )
def _find_lens_id(self, params): def _find_lens_id(self, params):
batch_size = params.shape[0] # 扩展维度以进行广播计算
lens_ids = [] params_expanded = params.unsqueeze(1) # [batch_size, 1, 3]
base_expanded = self.lens_params_base.unsqueeze(0) # [1, 7, 3]
for i in range(batch_size): # 计算欧氏距离(完全向量化)
mag, na, rix = params[i][0].item(), params[i][1].item(), params[i][2].item() distances = torch.sqrt(torch.sum((params_expanded - base_expanded) ** 2, dim=2)) # [batch_size, 7]
# 直接的条件匹配类似于switch-case # 找到最小距离的索引
if mag == 10 and na == 0.25 and rix == 1.0000: lens_ids = torch.argmin(distances, dim=1) # [batch_size]
lens_ids.append(0) # obj1 return lens_ids
elif mag == 10 and na == 0.30 and rix == 1.0000:
lens_ids.append(1) # obj2
elif mag == 20 and na == 0.70 and rix == 1.0000:
lens_ids.append(2) # obj3
elif mag == 20 and na == 0.80 and rix == 1.0000:
lens_ids.append(3) # obj4
elif mag == 40 and na == 0.65 and rix == 1.0000:
lens_ids.append(4) # obj5
elif mag == 100 and na == 0.80 and rix == 1.0000:
lens_ids.append(5) # obj6
elif mag == 100 and na == 1.25 and rix == 1.4730:
lens_ids.append(6) # obj7
else:
lens_ids.append(0)
return torch.tensor(lens_ids, dtype=torch.long, device=params.device)
def forward(self, image, params): def forward(self, image, params):
x = self.features(image) x = self.features(image)
x = self.avgpool(x) x = self.avgpool(x)
lens_ids = self._find_lens_id(params) lens_ids = self._find_lens_id(params)
batch_size = params.size(0) lens_ids = torch.clamp(lens_ids, 0, self.num_lenses - 1)
outputs = [] batch_size = x.size(0)
all_outputs = []
for i in range(batch_size): for i in range(batch_size):
current_lens_id = lens_ids[i] head_idx = lens_ids[i]
# 确保lens_id在有效范围内 output = self.regressors[head_idx](x[i].unsqueeze(0))
if current_lens_id < 0 or current_lens_id >= self.num_lenses: all_outputs.append(output)
current_lens_id = 0
# 选择对应的回归头 return torch.cat(all_outputs, dim=0)
head_output = self.regressors[current_lens_id](x[i].unsqueeze(0))
outputs.append(head_output)
return torch.cat(outputs, dim=0)

View File

@@ -128,6 +128,9 @@ def fit(rank, world_size, cfg):
print_with_timestamp("Model Checkpoint Params Not Exist") print_with_timestamp("Model Checkpoint Params Not Exist")
# DDP 化模型 # DDP 化模型
if "mmlp" in model_type:
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
else:
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False) model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
if rank == 0: if rank == 0:
print_with_timestamp("Model Wrapped with DDP") print_with_timestamp("Model Wrapped with DDP")