import torch from torch import nn from torchvision import models # 无融合模型 class MOAFNoFusion(nn.Module): def __init__(self): super().__init__() shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") self.features = nn.Sequential( shuff.conv1, shuff.maxpool, shuff.stage2, shuff.stage3, shuff.stage4, shuff.conv5 ) self.avgpool = nn.AdaptiveAvgPool2d(1) self.regressor = nn.Sequential( nn.Flatten(), nn.Linear(1024, 128), nn.ReLU(inplace=True), nn.Linear(128, 1) ) def forward(self, image, params): x = self.features(image) x = self.avgpool(x) x = self.regressor(x) return x # 参数嵌入模型 class ParamEmbedding(nn.Module): def __init__(self, in_dim=3, hidden_dim=64, out_dim=128): super().__init__() self.embedding = nn.Sequential( nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim), nn.ReLU(), nn.LayerNorm(out_dim) ) def forward(self, params): # min-max 归一化参数 normalized_params = torch.stack([ (params[:, 0] - 10.0) / 90.0, params[:, 1] / 1.25, (params[:, 2] - 1.0) / 0.5 ], dim=1) return self.embedding(normalized_params) # FiLM 融合块 class FiLMBlock(nn.Module): def __init__(self, param_emb_dim=128, feat_channels=128): super().__init__() self.gamma_gen = nn.Linear(param_emb_dim, feat_channels) self.beta_gen = nn.Linear(param_emb_dim, feat_channels) def forward(self, feature_map, param_emb): gamma = self.gamma_gen(param_emb).unsqueeze(-1).unsqueeze(-1) beta = self.beta_gen(param_emb).unsqueeze(-1).unsqueeze(-1) return feature_map * (1.0 + gamma) + beta # 通道交叉注意力融合块 class ChannelCrossAttention(nn.Module): def __init__(self, param_emb_dim=128, feat_channels=128): super().__init__() self.param_emb_dim = param_emb_dim self.feat_channels = feat_channels self.hidden_dim = max(16, self.feat_channels // 4) self.global_pool = nn.AdaptiveAvgPool2d(1) self.bottle_neck = nn.Sequential( nn.Linear(self.param_emb_dim + self.feat_channels, self.hidden_dim), nn.ReLU(inplace=True), nn.Linear(self.hidden_dim, self.feat_channels), nn.Sigmoid() ) def forward(self, feature_map, param_emb): b, c, h, w = feature_map.shape pooled = self.global_pool(feature_map).view(b, c) cat = torch.cat([pooled, param_emb], dim=1) weights = self.bottle_neck(cat) weights4d = weights.view(b, c, 1, 1) return feature_map * (1.0 + weights4d) # SE块 class SEBlock(nn.Module): def __init__(self, feat_channels=128): super().__init__() self.feat_channels = feat_channels self.hidden_dim = max(16, self.feat_channels // 4) self.global_pool = nn.AdaptiveAvgPool2d(1) self.bottle_neck = nn.Sequential( nn.Linear(self.feat_channels, self.hidden_dim), nn.ReLU(inplace=True), nn.Linear(self.hidden_dim, self.feat_channels), nn.Sigmoid() ) def forward(self, feature_map): b, c, h, w = feature_map.shape pooled = self.global_pool(feature_map).view(b, c) weights = self.bottle_neck(pooled) weights4d = weights.view(b, c, 1, 1) return feature_map * (1.0 + weights4d) # 仅返回特征图的恒等变换 class FusionIdentity(nn.Module): def forward(self, feature_map, param_emb): return feature_map # 使用 FiLM 融合的模型 class MOAFWithFiLM(nn.Module): def __init__(self, fusion_level=None): super().__init__() shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) self.stage2 = shuff.stage2 self.stage3 = shuff.stage3 self.stage4 = shuff.stage4 self.cbr2 = shuff.conv5 self.param_embedding = ParamEmbedding() if fusion_level is None: fusion_level = [2] self.film_block0 = FiLMBlock(feat_channels=48) if 0 in fusion_level else FusionIdentity() self.film_block1 = FiLMBlock(feat_channels=96) if 1 in fusion_level else FusionIdentity() self.film_block2 = FiLMBlock(feat_channels=192) if 2 in fusion_level else FusionIdentity() self.avgpool = nn.AdaptiveAvgPool2d(1) self.regressor = nn.Sequential( nn.Flatten(), nn.Linear(1024, 128), nn.ReLU(inplace=True), nn.Linear(128, 1) ) def forward(self, image, params): x = self.cbrm(image) x = self.stage2(x) param_emb = self.param_embedding(params) x = self.film_block0(x, param_emb) x = self.stage3(x) x = self.film_block1(x, param_emb) x = self.stage4(x) x = self.film_block2(x, param_emb) x = self.cbr2(x) x = self.avgpool(x) x = self.regressor(x) return x # 使用交叉注意力融合的模型 class MOAFWithChannelCrossAttention(nn.Module): def __init__(self, fusion_level=None): super().__init__() shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) self.stage2 = shuff.stage2 self.stage3 = shuff.stage3 self.stage4 = shuff.stage4 self.cbr2 = shuff.conv5 self.param_embedding = ParamEmbedding() if fusion_level is None: fusion_level = [2] self.cca_block0 = ChannelCrossAttention(feat_channels=48) if 0 in fusion_level else FusionIdentity() self.cca_block1 = ChannelCrossAttention(feat_channels=96) if 1 in fusion_level else FusionIdentity() self.cca_block2 = ChannelCrossAttention(feat_channels=192) if 2 in fusion_level else FusionIdentity() self.avgpool = nn.AdaptiveAvgPool2d(1) self.regressor = nn.Sequential( nn.Flatten(), nn.Linear(1024, 128), nn.ReLU(inplace=True), nn.Linear(128, 1) ) def forward(self, image, params): x = self.cbrm(image) x = self.stage2(x) param_emb = self.param_embedding(params) x = self.cca_block0(x, param_emb) x = self.stage3(x) x = self.cca_block1(x, param_emb) x = self.stage4(x) x = self.cca_block2(x, param_emb) x = self.cbr2(x) x = self.avgpool(x) x = self.regressor(x) return x # 使用 SE 块但无融合的模型 class MOAFWithSE(nn.Module): def __init__(self, fusion_level=None): super().__init__() shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) self.stage2 = shuff.stage2 self.stage3 = shuff.stage3 self.stage4 = shuff.stage4 self.cbr2 = shuff.conv5 if fusion_level is None: fusion_level = [0] self.se_block0 = SEBlock(feat_channels=48) if 0 in fusion_level else nn.Identity() self.se_block1 = SEBlock(feat_channels=96) if 1 in fusion_level else nn.Identity() self.se_block2 = SEBlock(feat_channels=192) if 2 in fusion_level else nn.Identity() self.avgpool = nn.AdaptiveAvgPool2d(1) self.regressor = nn.Sequential( nn.Flatten(), nn.Linear(1024, 128), nn.ReLU(inplace=True), nn.Linear(128, 1) ) def forward(self, image, params): x = self.cbrm(image) x = self.stage2(x) x = self.se_block0(x) x = self.stage3(x) x = self.se_block1(x) x = self.stage4(x) x = self.se_block2(x) x = self.cbr2(x) x = self.avgpool(x) x = self.regressor(x) return x # 多回归头模型 class MOAFWithMMLP(nn.Module): def __init__(self, num_lenses=7): super().__init__() shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") self.features = nn.Sequential( shuff.conv1, shuff.maxpool, shuff.stage2, shuff.stage3, shuff.stage4, shuff.conv5 ) self.avgpool = nn.AdaptiveAvgPool2d(1) self.num_lenses = num_lenses self.regressors = nn.ModuleList([ self._create_regressor_head() for _ in range(num_lenses) ]) def _create_regressor_head(self): return nn.Sequential( nn.Flatten(), nn.Linear(1024, 128), nn.ReLU(inplace=True), nn.Linear(128, 1) ) def _find_lens_id(self, params): batch_size = params.shape[0] lens_ids = [] for i in range(batch_size): mag, na, rix = params[i][0].item(), params[i][1].item(), params[i][2].item() # 直接的条件匹配(类似于switch-case) if mag == 10 and na == 0.25 and rix == 1.0000: lens_ids.append(0) # obj1 elif mag == 10 and na == 0.30 and rix == 1.0000: lens_ids.append(1) # obj2 elif mag == 20 and na == 0.70 and rix == 1.0000: lens_ids.append(2) # obj3 elif mag == 20 and na == 0.80 and rix == 1.0000: lens_ids.append(3) # obj4 elif mag == 40 and na == 0.65 and rix == 1.0000: lens_ids.append(4) # obj5 elif mag == 100 and na == 0.80 and rix == 1.0000: lens_ids.append(5) # obj6 elif mag == 100 and na == 1.25 and rix == 1.4730: lens_ids.append(6) # obj7 else: lens_ids.append(0) return torch.tensor(lens_ids, dtype=torch.long, device=params.device) def forward(self, image, params): x = self.features(image) x = self.avgpool(x) lens_ids = self._find_lens_id(params) batch_size = params.size(0) outputs = [] for i in range(batch_size): current_lens_id = lens_ids[i] # 确保lens_id在有效范围内 if current_lens_id < 0 or current_lens_id >= self.num_lenses: current_lens_id = 0 # 选择对应的回归头 head_output = self.regressors[current_lens_id](x[i].unsqueeze(0)) outputs.append(head_output) return torch.cat(outputs, dim=0)