fix mmlp model
This commit is contained in:
parent
955083968e
commit
9f3ef716fb
@ -260,6 +260,18 @@ class MOAFWithMMLP(nn.Module):
|
||||
self._create_regressor_head() for _ in range(num_lenses)
|
||||
])
|
||||
|
||||
# 注册物镜参数基准张量(不可学习)
|
||||
lens_params_base = torch.tensor([
|
||||
[10, 0.25, 1.0000], # obj1
|
||||
[10, 0.30, 1.0000], # obj2
|
||||
[20, 0.70, 1.0000], # obj3
|
||||
[20, 0.80, 1.0000], # obj4
|
||||
[40, 0.65, 1.0000], # obj5
|
||||
[100, 0.80, 1.0000], # obj6
|
||||
[100, 1.25, 1.4730] # obj7
|
||||
], dtype=torch.float32)
|
||||
self.register_buffer('lens_params_base', lens_params_base)
|
||||
|
||||
def _create_regressor_head(self):
|
||||
return nn.Sequential(
|
||||
nn.Flatten(),
|
||||
@ -269,48 +281,29 @@ class MOAFWithMMLP(nn.Module):
|
||||
)
|
||||
|
||||
def _find_lens_id(self, params):
|
||||
batch_size = params.shape[0]
|
||||
lens_ids = []
|
||||
# 扩展维度以进行广播计算
|
||||
params_expanded = params.unsqueeze(1) # [batch_size, 1, 3]
|
||||
base_expanded = self.lens_params_base.unsqueeze(0) # [1, 7, 3]
|
||||
|
||||
for i in range(batch_size):
|
||||
mag, na, rix = params[i][0].item(), params[i][1].item(), params[i][2].item()
|
||||
# 计算欧氏距离(完全向量化)
|
||||
distances = torch.sqrt(torch.sum((params_expanded - base_expanded) ** 2, dim=2)) # [batch_size, 7]
|
||||
|
||||
# 直接的条件匹配(类似于switch-case)
|
||||
if mag == 10 and na == 0.25 and rix == 1.0000:
|
||||
lens_ids.append(0) # obj1
|
||||
elif mag == 10 and na == 0.30 and rix == 1.0000:
|
||||
lens_ids.append(1) # obj2
|
||||
elif mag == 20 and na == 0.70 and rix == 1.0000:
|
||||
lens_ids.append(2) # obj3
|
||||
elif mag == 20 and na == 0.80 and rix == 1.0000:
|
||||
lens_ids.append(3) # obj4
|
||||
elif mag == 40 and na == 0.65 and rix == 1.0000:
|
||||
lens_ids.append(4) # obj5
|
||||
elif mag == 100 and na == 0.80 and rix == 1.0000:
|
||||
lens_ids.append(5) # obj6
|
||||
elif mag == 100 and na == 1.25 and rix == 1.4730:
|
||||
lens_ids.append(6) # obj7
|
||||
else:
|
||||
lens_ids.append(0)
|
||||
|
||||
return torch.tensor(lens_ids, dtype=torch.long, device=params.device)
|
||||
# 找到最小距离的索引
|
||||
lens_ids = torch.argmin(distances, dim=1) # [batch_size]
|
||||
return lens_ids
|
||||
|
||||
def forward(self, image, params):
|
||||
x = self.features(image)
|
||||
x = self.avgpool(x)
|
||||
|
||||
lens_ids = self._find_lens_id(params)
|
||||
batch_size = params.size(0)
|
||||
outputs = []
|
||||
lens_ids = torch.clamp(lens_ids, 0, self.num_lenses - 1)
|
||||
batch_size = x.size(0)
|
||||
all_outputs = []
|
||||
for i in range(batch_size):
|
||||
current_lens_id = lens_ids[i]
|
||||
# 确保lens_id在有效范围内
|
||||
if current_lens_id < 0 or current_lens_id >= self.num_lenses:
|
||||
current_lens_id = 0
|
||||
head_idx = lens_ids[i]
|
||||
output = self.regressors[head_idx](x[i].unsqueeze(0))
|
||||
all_outputs.append(output)
|
||||
|
||||
# 选择对应的回归头
|
||||
head_output = self.regressors[current_lens_id](x[i].unsqueeze(0))
|
||||
outputs.append(head_output)
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
return torch.cat(all_outputs, dim=0)
|
||||
|
||||
|
||||
@ -128,7 +128,10 @@ def fit(rank, world_size, cfg):
|
||||
print_with_timestamp("Model Checkpoint Params Not Exist")
|
||||
|
||||
# DDP 化模型
|
||||
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
|
||||
if "mmlp" in model_type:
|
||||
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
|
||||
else:
|
||||
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
|
||||
if rank == 0:
|
||||
print_with_timestamp("Model Wrapped with DDP")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user