# MOAF/MOAFModels.py
# 2025-11-03 09:57:54 +08:00
#
# 310 lines
# 10 KiB
# Python

import torch
from torch import nn
from torchvision import models
# Model without fusion (image-only baseline)
class MOAFNoFusion(nn.Module):
    """ShuffleNetV2-x0.5 backbone followed by a small MLP regression head.

    The ``params`` argument of ``forward`` is accepted for interface
    compatibility with the fusion models but is never used.
    """

    def __init__(self):
        super().__init__()
        backbone = models.shufflenet_v2_x0_5(weights="DEFAULT")
        # Everything from conv1 through conv5; the torchvision classifier
        # (``fc``) is not used.
        layers = [
            backbone.conv1,
            backbone.maxpool,
            backbone.stage2,
            backbone.stage3,
            backbone.stage4,
            backbone.conv5,
        ]
        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # in_features=1024 must equal the channel count produced by conv5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params):
        """Return a ``(batch, 1)`` prediction computed from ``image`` alone."""
        pooled = self.avgpool(self.features(image))
        return self.regressor(pooled)
# Parameter embedding model
class ParamEmbedding(nn.Module):
    """Normalize raw objective parameters and embed them with a small MLP.

    Column ``j`` of ``params`` is rescaled to
    ``(params[:, j] - param_offsets[j]) / param_scales[j]``. The result is
    passed through Linear -> ReLU -> Linear -> ReLU -> LayerNorm.

    NOTE(review): with the defaults the three columns are presumably
    (magnification, numerical aperture, immersion refractive index). The
    default offsets/scales map [10, 100], [0, 1.25] and [1.0, 1.5] onto
    [0, 1]. Confirm against the data pipeline.

    Args:
        in_dim: number of parameter columns consumed by the MLP.
        hidden_dim: width of the hidden linear layer.
        out_dim: size of the returned embedding.
        param_offsets: per-column value subtracted before scaling. Defaults to
            ``(10.0, 0.0, 1.0)``. Must be supplied when ``in_dim != 3``.
        param_scales: per-column divisor. Defaults to ``(90.0, 1.25, 0.5)``.
            Must be supplied when ``in_dim != 3``.

    Raises:
        ValueError: if ``param_offsets`` or ``param_scales`` does not have
            exactly ``in_dim`` entries.
    """

    def __init__(self, in_dim=3, hidden_dim=64, out_dim=128,
                 param_offsets=None, param_scales=None):
        super().__init__()
        if param_offsets is None:
            param_offsets = (10.0, 0.0, 1.0)
        if param_scales is None:
            param_scales = (90.0, 1.25, 0.5)
        if len(param_offsets) != in_dim or len(param_scales) != in_dim:
            raise ValueError(
                f"ParamEmbedding expects {in_dim} offsets and scales, got "
                f"{len(param_offsets)} and {len(param_scales)}"
            )
        self.embedding = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.LayerNorm(out_dim),
        )
        # Non-persistent buffers follow .to(device)/.half() but are not written
        # to state_dict, so checkpoints saved before they existed still load
        # with strict=True.
        self.register_buffer(
            "param_offsets",
            torch.tensor(param_offsets, dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            "param_scales",
            torch.tensor(param_scales, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, params):
        """Embed ``params`` of shape ``(batch, >= in_dim)`` into ``(batch, out_dim)``."""
        # Only the first in_dim columns are used. Any extra columns are
        # ignored, exactly as the original per-column indexing did.
        selected = params[:, : self.param_offsets.numel()]
        # Min-max style normalization, one vectorized op for all columns.
        normalized_params = (selected - self.param_offsets) / self.param_scales
        return self.embedding(normalized_params)
# FiLM (feature-wise linear modulation) fusion block
class FiLMBlock(nn.Module):
    """Apply a per-channel scale and shift predicted from a parameter embedding.

    The output is ``feature_map * (1 + gamma) + beta``. Both ``gamma`` and
    ``beta`` are linear projections of ``param_emb``, broadcast over the two
    trailing (spatial) dimensions.
    """

    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.gamma_gen = nn.Linear(param_emb_dim, feat_channels)
        self.beta_gen = nn.Linear(param_emb_dim, feat_channels)

    def forward(self, feature_map, param_emb):
        # Append two singleton axes, e.g. (B, C) -> (B, C, 1, 1) for a
        # batched embedding.
        scale = 1.0 + self.gamma_gen(param_emb)[..., None, None]
        shift = self.beta_gen(param_emb)[..., None, None]
        return feature_map * scale + shift
# Channel cross-attention fusion block
class ChannelCrossAttention(nn.Module):
    """Channel gating conditioned on pooled features and a parameter embedding.

    Processing steps:
    - The feature map is global-average-pooled to a ``(B, C)`` descriptor.
    - The descriptor is concatenated with ``param_emb``.
    - The result goes through a bottleneck MLP ending in a sigmoid, giving
      per-channel weights ``w``.
    - The output is ``feature_map * (1 + w)``, a per-channel scale between
      1 and 2.
    """

    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.param_emb_dim = param_emb_dim
        self.feat_channels = feat_channels
        # Bottleneck width: a quarter of the channels, but never below 16.
        self.hidden_dim = max(16, feat_channels // 4)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.bottle_neck = nn.Sequential(
            nn.Linear(param_emb_dim + feat_channels, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, feat_channels),
            nn.Sigmoid(),
        )

    def forward(self, feature_map, param_emb):
        batch, channels, _, _ = feature_map.shape
        descriptor = self.global_pool(feature_map).view(batch, channels)
        # Pooled channel descriptor first, then the parameter embedding.
        joint = torch.cat([descriptor, param_emb], dim=1)
        gate = self.bottle_neck(joint).view(batch, channels, 1, 1)
        return feature_map * (1.0 + gate)
# SE (squeeze-and-excitation) block
class SEBlock(nn.Module):
    """Self-gated channel reweighting that uses no parameter input.

    Global average pooling gives a ``(B, C)`` descriptor. A bottleneck MLP
    with a sigmoid turns it into per-channel weights ``w``. The output is
    ``feature_map * (1 + w)``, a per-channel scale between 1 and 2.
    """

    def __init__(self, feat_channels=128):
        super().__init__()
        self.feat_channels = feat_channels
        # Bottleneck width: C // 4, floored at 16.
        self.hidden_dim = max(16, feat_channels // 4)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.bottle_neck = nn.Sequential(
            nn.Linear(feat_channels, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, feat_channels),
            nn.Sigmoid(),
        )

    def forward(self, feature_map):
        batch, channels, _, _ = feature_map.shape
        squeezed = self.global_pool(feature_map).view(batch, channels)
        excitation = self.bottle_neck(squeezed).view(batch, channels, 1, 1)
        return feature_map * (1.0 + excitation)
# Identity transform that returns only the feature map
class FusionIdentity(nn.Module):
    """Pass-through fusion stage: returns ``feature_map`` unchanged.

    It takes ``param_emb`` only so it can be called exactly like the real
    fusion blocks (``block(feature_map, param_emb)``).
    """

    def forward(self, feature_map, param_emb):
        del param_emb  # intentionally unused
        return feature_map
# Model using FiLM fusion
class MOAFWithFiLM(nn.Module):
    """ShuffleNetV2-x0.5 regressor conditioned on parameters through FiLM blocks.

    Args:
        fusion_level: collection of levels at which a ``FiLMBlock`` is inserted.
            Level 0 follows stage2 (48 channels), level 1 follows stage3
            (96 channels) and level 2 follows stage4 (192 channels). Defaults
            to ``[2]``. Levels not listed get a pass-through
            ``FusionIdentity``; values other than 0, 1 and 2 have no effect.
    """

    def __init__(self, fusion_level=None):
        super().__init__()
        backbone = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(backbone.conv1, backbone.maxpool)
        self.stage2 = backbone.stage2
        self.stage3 = backbone.stage3
        self.stage4 = backbone.stage4
        self.cbr2 = backbone.conv5
        self.param_embedding = ParamEmbedding()
        levels = [2] if fusion_level is None else fusion_level

        def make_fusion(level, channels):
            # A real FiLM block where fusion is requested, else a pass-through.
            if level in levels:
                return FiLMBlock(feat_channels=channels)
            return FusionIdentity()

        self.film_block0 = make_fusion(0, 48)
        self.film_block1 = make_fusion(1, 96)
        self.film_block2 = make_fusion(2, 192)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # in_features=1024 must equal the channel count produced by conv5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params):
        """Return a ``(batch, 1)`` prediction from ``image`` conditioned on ``params``."""
        emb = self.param_embedding(params)
        feats = self.cbrm(image)
        # Each backbone stage is followed by its (possibly identity) FiLM block.
        pipeline = (
            (self.stage2, self.film_block0),
            (self.stage3, self.film_block1),
            (self.stage4, self.film_block2),
        )
        for stage, fuse in pipeline:
            feats = fuse(stage(feats), emb)
        return self.regressor(self.avgpool(self.cbr2(feats)))
# Model using channel cross-attention fusion
class MOAFWithChannelCrossAttention(nn.Module):
    """ShuffleNetV2-x0.5 regressor conditioned on parameters through channel gating.

    Args:
        fusion_level: collection of levels at which a ``ChannelCrossAttention``
            block is inserted. Level 0 follows stage2 (48 channels), level 1
            follows stage3 (96 channels) and level 2 follows stage4
            (192 channels). Defaults to ``[2]``. Levels not listed get a
            pass-through ``FusionIdentity``; values other than 0, 1 and 2
            have no effect.
    """

    def __init__(self, fusion_level=None):
        super().__init__()
        backbone = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(backbone.conv1, backbone.maxpool)
        self.stage2 = backbone.stage2
        self.stage3 = backbone.stage3
        self.stage4 = backbone.stage4
        self.cbr2 = backbone.conv5
        self.param_embedding = ParamEmbedding()
        levels = [2] if fusion_level is None else fusion_level

        def make_fusion(level, channels):
            # A real gating block where fusion is requested, else a pass-through.
            if level in levels:
                return ChannelCrossAttention(feat_channels=channels)
            return FusionIdentity()

        self.cca_block0 = make_fusion(0, 48)
        self.cca_block1 = make_fusion(1, 96)
        self.cca_block2 = make_fusion(2, 192)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # in_features=1024 must equal the channel count produced by conv5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params):
        """Return a ``(batch, 1)`` prediction from ``image`` conditioned on ``params``."""
        emb = self.param_embedding(params)
        feats = self.cbrm(image)
        # Each backbone stage is followed by its (possibly identity) fusion block.
        pipeline = (
            (self.stage2, self.cca_block0),
            (self.stage3, self.cca_block1),
            (self.stage4, self.cca_block2),
        )
        for stage, fuse in pipeline:
            feats = fuse(stage(feats), emb)
        return self.regressor(self.avgpool(self.cbr2(feats)))
# Model using SE blocks but no parameter fusion
class MOAFWithSE(nn.Module):
    """ShuffleNetV2-x0.5 regressor with optional SE blocks and no parameter input.

    Args:
        fusion_level: collection of levels at which an ``SEBlock`` is inserted.
            Level 0 follows stage2 (48 channels), level 1 follows stage3
            (96 channels) and level 2 follows stage4 (192 channels). Defaults
            to ``[0]``. Levels not listed get ``nn.Identity``; values other
            than 0, 1 and 2 have no effect.
    """

    def __init__(self, fusion_level=None):
        super().__init__()
        backbone = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(backbone.conv1, backbone.maxpool)
        self.stage2 = backbone.stage2
        self.stage3 = backbone.stage3
        self.stage4 = backbone.stage4
        self.cbr2 = backbone.conv5
        levels = [0] if fusion_level is None else fusion_level

        def make_se(level, channels):
            # An SE block where requested, else a single-argument identity.
            if level in levels:
                return SEBlock(feat_channels=channels)
            return nn.Identity()

        self.se_block0 = make_se(0, 48)
        self.se_block1 = make_se(1, 96)
        self.se_block2 = make_se(2, 192)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # in_features=1024 must equal the channel count produced by conv5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params):
        """Return a ``(batch, 1)`` prediction from ``image``; ``params`` is unused."""
        feats = self.cbrm(image)
        pipeline = (
            (self.stage2, self.se_block0),
            (self.stage3, self.se_block1),
            (self.stage4, self.se_block2),
        )
        for stage, reweight in pipeline:
            feats = reweight(stage(feats))
        return self.regressor(self.avgpool(self.cbr2(feats)))
# Multi-regression-head model: one MLP head per objective lens
class MOAFWithMMLP(nn.Module):
    """ShuffleNetV2-x0.5 backbone with a separate MLP regression head per objective.

    Each sample is routed to the head whose reference row in
    ``lens_params_base`` is nearest (Euclidean distance) to its ``params``.

    Args:
        num_lenses: number of regression heads to create.
            NOTE(review): the reference table always has 7 rows, so lens ids
            are in 0..6. ``forward`` clamps them to ``num_lenses - 1``. With
            ``num_lenses < 7`` several lenses therefore share the last head;
            with ``num_lenses > 7`` the extra heads are never used.
    """

    def __init__(self, num_lenses=7):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.features = nn.Sequential(
            shuff.conv1, shuff.maxpool,
            shuff.stage2, shuff.stage3,
            shuff.stage4, shuff.conv5,
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.num_lenses = num_lenses
        self.regressors = nn.ModuleList([
            self._create_regressor_head() for _ in range(num_lenses)
        ])
        # Reference parameter rows per objective lens (a buffer, so not
        # learnable, but saved in state_dict).
        # NOTE(review): columns presumably (magnification, NA, immersion
        # refractive index); confirm.
        lens_params_base = torch.tensor([
            [10, 0.25, 1.0000],   # obj1
            [10, 0.30, 1.0000],   # obj2
            [20, 0.70, 1.0000],   # obj3
            [20, 0.80, 1.0000],   # obj4
            [40, 0.65, 1.0000],   # obj5
            [100, 0.80, 1.0000],  # obj6
            [100, 1.25, 1.4730],  # obj7
        ], dtype=torch.float32)
        self.register_buffer('lens_params_base', lens_params_base)

    def _create_regressor_head(self):
        """Build one regression head mapping pooled 1024-channel features to one value."""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def _find_lens_id(self, params):
        """Return, per sample, the index of the nearest row of ``lens_params_base``.

        ``params`` is presumably ``(batch, 3)``; confirm with callers. The
        result is an integer tensor of shape ``(batch,)`` with values in
        ``[0, 7)``.
        """
        # NOTE(review): distances are in raw units. Differences in the first
        # column (up to 90) dwarf those in the other two (below 1), so inputs
        # that are not exact table rows are routed mostly by that column.
        params_expanded = params.unsqueeze(1)  # [batch_size, 1, 3]
        base_expanded = self.lens_params_base.unsqueeze(0)  # [1, 7, 3]
        # Euclidean distance to every reference row (fully vectorized).
        distances = torch.sqrt(torch.sum((params_expanded - base_expanded) ** 2, dim=2))  # [batch_size, 7]
        lens_ids = torch.argmin(distances, dim=1)  # [batch_size]
        return lens_ids

    def forward(self, image, params):
        """Predict one value per sample with the head selected by its lens id.

        Returns:
            Tensor of shape ``(batch, 1)`` in the same sample order as ``image``.
        """
        x = self.avgpool(self.features(image))
        batch_size = x.size(0)
        if batch_size == 0:
            # Nothing to route. Running any head on the empty batch yields an
            # empty (0, 1) tensor instead of failing in torch.cat([]).
            return self.regressors[0](x)
        lens_ids = self._find_lens_id(params)
        lens_ids = torch.clamp(lens_ids, 0, self.num_lenses - 1)
        # Group samples by head so each head runs once on its sub-batch,
        # instead of once per sample. Per-sample indexing also forced a host
        # sync per sample on GPU; here .tolist() is the only sync.
        order = torch.argsort(lens_ids)
        counts = torch.bincount(lens_ids, minlength=self.num_lenses).tolist()
        groups = torch.split(order, counts)  # groups[k]: sample indices for head k
        outputs = [
            self.regressors[head](x.index_select(0, idx))
            for head, idx in enumerate(groups)
            if idx.numel() > 0
        ]
        stacked = torch.cat(outputs, dim=0)  # rows follow `order`
        # argsort of a permutation is its inverse: restores the input order.
        return stacked[torch.argsort(order)]