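# MOAF model variants on a ShuffleNetV2-x0.5 backbone: a no-fusion baseline,
# FiLM / channel cross-attention / SE variants, and a multi-head (per-lens)
# regressor.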
import torch
from torch import nn
from torchvision import models


# Baseline model without parameter fusion (the params argument is ignored)
class MOAFNoFusion(nn.Module):
    def __init__(self):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.features = nn.Sequential(
            shuff.conv1, shuff.maxpool,
            shuff.stage2, shuff.stage3,
            shuff.stage4, shuff.conv5
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.features(image)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x


# Parameter embedding model
class ParamEmbedding(nn.Module):
    def __init__(self, in_dim=3, hidden_dim=64, out_dim=128):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.LayerNorm(out_dim)
        )

    def forward(self, params):
        # Min-max normalize the parameters to roughly [0, 1]
        # (magnification 10-100, NA up to 1.25, refractive index from 1.0)
        normalized_params = torch.stack([
            (params[:, 0] - 10.0) / 90.0,
            params[:, 1] / 1.25,
            (params[:, 2] - 1.0) / 0.5
        ], dim=1)
        return self.embedding(normalized_params)
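
# Illustrative example, assuming each params row follows the
# [magnification, NA, refractive_index] layout used in MOAFWithMMLP below:
#   ParamEmbedding()(torch.tensor([[100.0, 1.25, 1.4730]])).shape
#   -> torch.Size([1, 128])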


# FiLM fusion block: per-channel affine modulation (scale 1 + gamma, shift
# beta) conditioned on the parameter embedding
class FiLMBlock(nn.Module):
    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.gamma_gen = nn.Linear(param_emb_dim, feat_channels)
        self.beta_gen = nn.Linear(param_emb_dim, feat_channels)

    def forward(self, feature_map, param_emb):
        gamma = self.gamma_gen(param_emb).unsqueeze(-1).unsqueeze(-1)
        beta = self.beta_gen(param_emb).unsqueeze(-1).unsqueeze(-1)
        return feature_map * (1.0 + gamma) + beta


# Channel cross-attention fusion block: channel re-weighting driven by both
# the pooled feature map and the parameter embedding
class ChannelCrossAttention(nn.Module):
    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.param_emb_dim = param_emb_dim
        self.feat_channels = feat_channels
        self.hidden_dim = max(16, self.feat_channels // 4)

        self.global_pool = nn.AdaptiveAvgPool2d(1)

        self.bottle_neck = nn.Sequential(
            nn.Linear(self.param_emb_dim + self.feat_channels, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.feat_channels),
            nn.Sigmoid()
        )

    def forward(self, feature_map, param_emb):
        b, c, h, w = feature_map.shape
        pooled = self.global_pool(feature_map).view(b, c)
        cat = torch.cat([pooled, param_emb], dim=1)
        weights = self.bottle_neck(cat)
        weights4d = weights.view(b, c, 1, 1)
        return feature_map * (1.0 + weights4d)


# SE block (image features only, no parameter input)
class SEBlock(nn.Module):
    def __init__(self, feat_channels=128):
        super().__init__()
        self.feat_channels = feat_channels
        self.hidden_dim = max(16, self.feat_channels // 4)

        self.global_pool = nn.AdaptiveAvgPool2d(1)

        self.bottle_neck = nn.Sequential(
            nn.Linear(self.feat_channels, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.feat_channels),
            nn.Sigmoid()
        )

    def forward(self, feature_map):
        b, c, h, w = feature_map.shape
        pooled = self.global_pool(feature_map).view(b, c)
        weights = self.bottle_neck(pooled)
        weights4d = weights.view(b, c, 1, 1)
        return feature_map * (1.0 + weights4d)


# Identity "fusion": ignores the parameter embedding and returns the feature
# map unchanged
class FusionIdentity(nn.Module):
    def forward(self, feature_map, param_emb):
        return feature_map


# Model with FiLM fusion; fusion_level selects the stages (0-2) after which
# the parameter embedding is injected
class MOAFWithFiLM(nn.Module):
    def __init__(self, fusion_level=None):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5

        self.param_embedding = ParamEmbedding()

        if fusion_level is None:
            fusion_level = [2]

        self.film_block0 = FiLMBlock(feat_channels=48) if 0 in fusion_level else FusionIdentity()
        self.film_block1 = FiLMBlock(feat_channels=96) if 1 in fusion_level else FusionIdentity()
        self.film_block2 = FiLMBlock(feat_channels=192) if 2 in fusion_level else FusionIdentity()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.cbrm(image)
        x = self.stage2(x)
        param_emb = self.param_embedding(params)
        x = self.film_block0(x, param_emb)
        x = self.stage3(x)
        x = self.film_block1(x, param_emb)
        x = self.stage4(x)
        x = self.film_block2(x, param_emb)
        x = self.cbr2(x)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x


# Model with channel cross-attention fusion; fusion_level selects the stages
# (0-2) after which the parameter embedding is injected
class MOAFWithChannelCrossAttention(nn.Module):
    def __init__(self, fusion_level=None):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5

        self.param_embedding = ParamEmbedding()

        if fusion_level is None:
            fusion_level = [2]

        self.cca_block0 = ChannelCrossAttention(feat_channels=48) if 0 in fusion_level else FusionIdentity()
        self.cca_block1 = ChannelCrossAttention(feat_channels=96) if 1 in fusion_level else FusionIdentity()
        self.cca_block2 = ChannelCrossAttention(feat_channels=192) if 2 in fusion_level else FusionIdentity()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.cbrm(image)
        x = self.stage2(x)
        param_emb = self.param_embedding(params)
        x = self.cca_block0(x, param_emb)
        x = self.stage3(x)
        x = self.cca_block1(x, param_emb)
        x = self.stage4(x)
        x = self.cca_block2(x, param_emb)
        x = self.cbr2(x)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x


# Model with SE blocks but without parameter fusion; fusion_level selects the
# stages (0-2) after which an SE block is applied
class MOAFWithSE(nn.Module):
    def __init__(self, fusion_level=None):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5

        if fusion_level is None:
            fusion_level = [0]

        self.se_block0 = SEBlock(feat_channels=48) if 0 in fusion_level else nn.Identity()
        self.se_block1 = SEBlock(feat_channels=96) if 1 in fusion_level else nn.Identity()
        self.se_block2 = SEBlock(feat_channels=192) if 2 in fusion_level else nn.Identity()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.cbrm(image)
        x = self.stage2(x)
        x = self.se_block0(x)
        x = self.stage3(x)
        x = self.se_block1(x)
        x = self.stage4(x)
        x = self.se_block2(x)
        x = self.cbr2(x)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x


# Multi-regression-head model: one head per objective lens, selected by
# matching the (magnification, NA, refractive index) parameters
class MOAFWithMMLP(nn.Module):
    def __init__(self, num_lenses=7):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.features = nn.Sequential(
            shuff.conv1, shuff.maxpool,
            shuff.stage2, shuff.stage3,
            shuff.stage4, shuff.conv5
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.num_lenses = num_lenses
        self.regressors = nn.ModuleList([
            self._create_regressor_head() for _ in range(num_lenses)
        ])

    def _create_regressor_head(self):
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def _find_lens_id(self, params):
        batch_size = params.shape[0]
        lens_ids = []

        for i in range(batch_size):
            # Round to the table's precision so float32 inputs still match
            # the literals below exactly
            mag = round(params[i][0].item())
            na = round(params[i][1].item(), 2)
            rix = round(params[i][2].item(), 4)

            # Direct conditional matching (switch-case style)
            if mag == 10 and na == 0.25 and rix == 1.0000:
                lens_ids.append(0)  # obj1
            elif mag == 10 and na == 0.30 and rix == 1.0000:
                lens_ids.append(1)  # obj2
            elif mag == 20 and na == 0.70 and rix == 1.0000:
                lens_ids.append(2)  # obj3
            elif mag == 20 and na == 0.80 and rix == 1.0000:
                lens_ids.append(3)  # obj4
            elif mag == 40 and na == 0.65 and rix == 1.0000:
                lens_ids.append(4)  # obj5
            elif mag == 100 and na == 0.80 and rix == 1.0000:
                lens_ids.append(5)  # obj6
            elif mag == 100 and na == 1.25 and rix == 1.4730:
                lens_ids.append(6)  # obj7
            else:
                lens_ids.append(0)

        return torch.tensor(lens_ids, dtype=torch.long, device=params.device)

    def forward(self, image, params):
        x = self.features(image)
        x = self.avgpool(x)

        lens_ids = self._find_lens_id(params)
        batch_size = params.size(0)
        outputs = []
        for i in range(batch_size):
            current_lens_id = lens_ids[i].item()
            # Make sure lens_id is within the valid range
            if current_lens_id < 0 or current_lens_id >= self.num_lenses:
                current_lens_id = 0

            # Select the corresponding regression head
            head_output = self.regressors[current_lens_id](x[i].unsqueeze(0))
            outputs.append(head_output)

        return torch.cat(outputs, dim=0)
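

# Minimal usage sketch, assuming 3-channel inputs at 224x224 and params laid
# out as [magnification, NA, refractive_index] (as in MOAFWithMMLP's lens
# table). Note that weights="DEFAULT" downloads the pretrained ShuffleNetV2
# checkpoint on first use.
if __name__ == "__main__":
    images = torch.randn(2, 3, 224, 224)
    params = torch.tensor([[10.0, 0.25, 1.0000],
                           [100.0, 1.25, 1.4730]])

    for model_cls in (MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention,
                      MOAFWithSE, MOAFWithMMLP):
        model = model_cls()
        model.eval()
        with torch.no_grad():
            out = model(images, params)
        print(f"{model_cls.__name__}: output shape {tuple(out.shape)}")  # (2, 1)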