fix mmlp model
This commit is contained in:
parent
955083968e
commit
9f3ef716fb
@ -260,6 +260,18 @@ class MOAFWithMMLP(nn.Module):
|
|||||||
self._create_regressor_head() for _ in range(num_lenses)
|
self._create_regressor_head() for _ in range(num_lenses)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
# 注册物镜参数基准张量(不可学习)
|
||||||
|
lens_params_base = torch.tensor([
|
||||||
|
[10, 0.25, 1.0000], # obj1
|
||||||
|
[10, 0.30, 1.0000], # obj2
|
||||||
|
[20, 0.70, 1.0000], # obj3
|
||||||
|
[20, 0.80, 1.0000], # obj4
|
||||||
|
[40, 0.65, 1.0000], # obj5
|
||||||
|
[100, 0.80, 1.0000], # obj6
|
||||||
|
[100, 1.25, 1.4730] # obj7
|
||||||
|
], dtype=torch.float32)
|
||||||
|
self.register_buffer('lens_params_base', lens_params_base)
|
||||||
|
|
||||||
def _create_regressor_head(self):
|
def _create_regressor_head(self):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Flatten(),
|
nn.Flatten(),
|
||||||
@ -269,48 +281,29 @@ class MOAFWithMMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _find_lens_id(self, params):
|
def _find_lens_id(self, params):
|
||||||
batch_size = params.shape[0]
|
# 扩展维度以进行广播计算
|
||||||
lens_ids = []
|
params_expanded = params.unsqueeze(1) # [batch_size, 1, 3]
|
||||||
|
base_expanded = self.lens_params_base.unsqueeze(0) # [1, 7, 3]
|
||||||
|
|
||||||
for i in range(batch_size):
|
# 计算欧氏距离(完全向量化)
|
||||||
mag, na, rix = params[i][0].item(), params[i][1].item(), params[i][2].item()
|
distances = torch.sqrt(torch.sum((params_expanded - base_expanded) ** 2, dim=2)) # [batch_size, 7]
|
||||||
|
|
||||||
# 直接的条件匹配(类似于switch-case)
|
# 找到最小距离的索引
|
||||||
if mag == 10 and na == 0.25 and rix == 1.0000:
|
lens_ids = torch.argmin(distances, dim=1) # [batch_size]
|
||||||
lens_ids.append(0) # obj1
|
return lens_ids
|
||||||
elif mag == 10 and na == 0.30 and rix == 1.0000:
|
|
||||||
lens_ids.append(1) # obj2
|
|
||||||
elif mag == 20 and na == 0.70 and rix == 1.0000:
|
|
||||||
lens_ids.append(2) # obj3
|
|
||||||
elif mag == 20 and na == 0.80 and rix == 1.0000:
|
|
||||||
lens_ids.append(3) # obj4
|
|
||||||
elif mag == 40 and na == 0.65 and rix == 1.0000:
|
|
||||||
lens_ids.append(4) # obj5
|
|
||||||
elif mag == 100 and na == 0.80 and rix == 1.0000:
|
|
||||||
lens_ids.append(5) # obj6
|
|
||||||
elif mag == 100 and na == 1.25 and rix == 1.4730:
|
|
||||||
lens_ids.append(6) # obj7
|
|
||||||
else:
|
|
||||||
lens_ids.append(0)
|
|
||||||
|
|
||||||
return torch.tensor(lens_ids, dtype=torch.long, device=params.device)
|
|
||||||
|
|
||||||
def forward(self, image, params):
|
def forward(self, image, params):
|
||||||
x = self.features(image)
|
x = self.features(image)
|
||||||
x = self.avgpool(x)
|
x = self.avgpool(x)
|
||||||
|
|
||||||
lens_ids = self._find_lens_id(params)
|
lens_ids = self._find_lens_id(params)
|
||||||
batch_size = params.size(0)
|
lens_ids = torch.clamp(lens_ids, 0, self.num_lenses - 1)
|
||||||
outputs = []
|
batch_size = x.size(0)
|
||||||
|
all_outputs = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
current_lens_id = lens_ids[i]
|
head_idx = lens_ids[i]
|
||||||
# 确保lens_id在有效范围内
|
output = self.regressors[head_idx](x[i].unsqueeze(0))
|
||||||
if current_lens_id < 0 or current_lens_id >= self.num_lenses:
|
all_outputs.append(output)
|
||||||
current_lens_id = 0
|
|
||||||
|
|
||||||
# 选择对应的回归头
|
return torch.cat(all_outputs, dim=0)
|
||||||
head_output = self.regressors[current_lens_id](x[i].unsqueeze(0))
|
|
||||||
outputs.append(head_output)
|
|
||||||
|
|
||||||
return torch.cat(outputs, dim=0)
|
|
||||||
|
|
||||||
|
|||||||
@ -128,6 +128,9 @@ def fit(rank, world_size, cfg):
|
|||||||
print_with_timestamp("Model Checkpoint Params Not Exist")
|
print_with_timestamp("Model Checkpoint Params Not Exist")
|
||||||
|
|
||||||
# DDP 化模型
|
# DDP 化模型
|
||||||
|
if "mmlp" in model_type:
|
||||||
|
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
|
||||||
|
else:
|
||||||
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
|
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print_with_timestamp("Model Wrapped with DDP")
|
print_with_timestamp("Model Wrapped with DDP")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user