MOAF/MOAFModels.py

317 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torch import nn
from torchvision import models
# 无融合模型
class MOAFNoFusion(nn.Module):
    """Image-only baseline: ShuffleNetV2-x0.5 features followed by an MLP regressor.

    The objective parameters are accepted by ``forward`` only so that this
    model shares a call signature with the fusion variants; they are ignored.

    Args:
        weights: Forwarded to ``torchvision.models.shufflenet_v2_x0_5``.
            The default ``"DEFAULT"`` keeps the original behaviour (torchvision's
            default pretrained weights). Pass ``None`` to build the backbone
            without loading pretrained weights, e.g. when a checkpoint is
            restored afterwards or no network access is available.
    """

    def __init__(self, weights="DEFAULT"):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights=weights)
        # Reuse the backbone layers up to conv5; the torchvision classifier
        # head is not used.
        self.features = nn.Sequential(
            shuff.conv1, shuff.maxpool,
            shuff.stage2, shuff.stage3,
            shuff.stage4, shuff.conv5,
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 1024 is the channel count produced by conv5 of ShuffleNetV2-x0.5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params=None):
        """Regress one value per image.

        Args:
            image: Batched image tensor fed to the ShuffleNet stem
                (presumably ``(N, 3, H, W)`` -- confirm against the data loader).
            params: Ignored. Optional so callers may omit it.

        Returns:
            Tensor whose last dimension is 1 (one prediction per sample).
        """
        x = self.features(image)
        x = self.avgpool(x)
        return self.regressor(x)
# 参数嵌入模型
class ParamEmbedding(nn.Module):
    """Embed three scalar objective parameters into a feature vector.

    The raw parameters are rescaled with fixed constants and then passed
    through a two-layer MLP followed by ``LayerNorm``.

    Args:
        in_dim: Input width of the first linear layer. ``forward`` always
            builds exactly three normalised columns, so values other than 3
            will fail at the first linear layer.
        hidden_dim: Width of the hidden layer.
        out_dim: Width of the returned embedding.
    """

    def __init__(self, in_dim=3, hidden_dim=64, out_dim=128):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.LayerNorm(out_dim),
        )

    def forward(self, params):
        """Normalise ``params`` and return their embedding.

        Args:
            params: Tensor whose last axis holds at least three values; only
                indices 0, 1 and 2 are read (presumably magnification,
                numerical aperture and refractive index -- confirm with the
                dataset). Any number of leading batch axes is accepted,
                including none.

        Returns:
            Tensor with the same leading axes as ``params`` and a last axis
            of size ``out_dim``.
        """
        # Fixed rescaling constants; indexing the last axis (instead of
        # axis 1) keeps 2-D results identical while also supporting
        # unbatched or multi-axis batches.
        normalized_params = torch.stack(
            [
                (params[..., 0] - 10.0) / 90.0,
                params[..., 1] / 1.25,
                (params[..., 2] - 1.0) / 0.5,
            ],
            dim=-1,
        )
        return self.embedding(normalized_params)
# FiLM 融合块
class FiLMBlock(nn.Module):
    """Feature-wise linear modulation of a feature map by a parameter embedding.

    Two linear layers map the embedding to a per-channel scale and a
    per-channel shift; the output is ``feature_map * (1 + scale) + shift``.

    Args:
        param_emb_dim: Width of the incoming parameter embedding.
        feat_channels: Number of channels of the feature map to modulate.
    """

    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.gamma_gen = nn.Linear(param_emb_dim, feat_channels)
        self.beta_gen = nn.Linear(param_emb_dim, feat_channels)

    def forward(self, feature_map, param_emb):
        """Return the modulated feature map (same shape as ``feature_map``)."""
        # Append two singleton axes so the per-channel values broadcast over
        # the two trailing (spatial) axes of the feature map.
        scale = self.gamma_gen(param_emb)[..., None, None]
        shift = self.beta_gen(param_emb)[..., None, None]
        modulated = feature_map * (1.0 + scale)
        return modulated + shift
# 通道交叉注意力融合块
class ChannelCrossAttention(nn.Module):
    """Channel gating driven by pooled image features and a parameter embedding.

    The feature map is average-pooled to one value per channel, concatenated
    with the parameter embedding and passed through a bottleneck MLP ending
    in a sigmoid. The feature map is then scaled by ``1 + gate``.

    Args:
        param_emb_dim: Width of the parameter embedding.
        feat_channels: Number of channels of the feature map.
    """

    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.param_emb_dim = param_emb_dim
        self.feat_channels = feat_channels
        # Bottleneck is a quarter of the channel count, but never below 16.
        self.hidden_dim = max(16, self.feat_channels // 4)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        fused_width = self.param_emb_dim + self.feat_channels
        self.bottle_neck = nn.Sequential(
            nn.Linear(fused_width, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.feat_channels),
            nn.Sigmoid(),
        )

    def forward(self, feature_map, param_emb):
        """Return ``feature_map`` rescaled per channel (shape unchanged).

        Args:
            feature_map: 4-D tensor ``(batch, channels, height, width)``.
            param_emb: 2-D tensor concatenated with the pooled features along
                dim 1, so it must share the batch size of ``feature_map``.
        """
        batch, channels, _, _ = feature_map.shape
        descriptor = self.global_pool(feature_map).view(batch, channels)
        fused = torch.cat([descriptor, param_emb], dim=1)
        gate = self.bottle_neck(fused).view(batch, channels, 1, 1)
        return feature_map * (1.0 + gate)
# SE块
class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate without parameter input.

    The feature map is average-pooled to one value per channel, passed
    through a bottleneck MLP ending in a sigmoid, and the feature map is
    scaled by ``1 + gate``.

    Args:
        feat_channels: Number of channels of the feature map.
    """

    def __init__(self, feat_channels=128):
        super().__init__()
        self.feat_channels = feat_channels
        # Bottleneck is a quarter of the channel count, but never below 16.
        self.hidden_dim = max(16, self.feat_channels // 4)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.bottle_neck = nn.Sequential(
            nn.Linear(self.feat_channels, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.feat_channels),
            nn.Sigmoid(),
        )

    def forward(self, feature_map):
        """Return ``feature_map`` rescaled per channel (shape unchanged).

        Args:
            feature_map: 4-D tensor ``(batch, channels, height, width)``.
        """
        batch, channels, _, _ = feature_map.shape
        squeezed = self.global_pool(feature_map).view(batch, channels)
        gate = self.bottle_neck(squeezed).view(batch, channels, 1, 1)
        return feature_map * (1.0 + gate)
# 仅返回特征图的恒等变换
class FusionIdentity(nn.Module):
    """No-op stand-in for a fusion block: returns the feature map untouched."""

    def forward(self, feature_map, param_emb):
        """Return ``feature_map`` as-is.

        ``param_emb`` is accepted only so this module can be called with the
        same two arguments as the real fusion blocks; it is never read.
        """
        return feature_map
# 使用 FiLM 融合的模型
class MOAFWithFiLM(nn.Module):
    """ShuffleNetV2-x0.5 regressor with FiLM fusion of objective parameters.

    A FiLM block can be inserted after stage2 (level 0, 48 channels),
    stage3 (level 1, 96 channels) and/or stage4 (level 2, 192 channels).
    Levels that are not selected use ``FusionIdentity``.

    Args:
        fusion_level: Iterable of levels (any of 0, 1, 2) at which to fuse,
            or a single int. ``None`` selects level 2 only. Values other
            than 0, 1 and 2 are ignored.
        weights: Forwarded to ``torchvision.models.shufflenet_v2_x0_5``.
            The default ``"DEFAULT"`` keeps the original behaviour; pass
            ``None`` to skip loading pretrained weights.
    """

    def __init__(self, fusion_level=None, weights="DEFAULT"):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights=weights)
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5
        self.param_embedding = ParamEmbedding()
        if fusion_level is None:
            fusion_level = (2,)
        elif isinstance(fusion_level, int):
            # Allow ``fusion_level=1`` as shorthand for ``[1]``.
            fusion_level = (fusion_level,)
        self.film_block0 = FiLMBlock(feat_channels=48) if 0 in fusion_level else FusionIdentity()
        self.film_block1 = FiLMBlock(feat_channels=96) if 1 in fusion_level else FusionIdentity()
        self.film_block2 = FiLMBlock(feat_channels=192) if 2 in fusion_level else FusionIdentity()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 1024 is the channel count produced by conv5 of ShuffleNetV2-x0.5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params):
        """Regress one value per image, conditioned on ``params``.

        Args:
            image: Batched image tensor fed to the ShuffleNet stem
                (presumably ``(N, 3, H, W)`` -- confirm against the data loader).
            params: Per-sample objective parameters passed to
                ``ParamEmbedding``.

        Returns:
            Tensor whose last dimension is 1 (one prediction per sample).
        """
        param_emb = self.param_embedding(params)
        x = self.cbrm(image)
        x = self.stage2(x)
        x = self.film_block0(x, param_emb)
        x = self.stage3(x)
        x = self.film_block1(x, param_emb)
        x = self.stage4(x)
        x = self.film_block2(x, param_emb)
        x = self.cbr2(x)
        x = self.avgpool(x)
        return self.regressor(x)
# 使用交叉注意力融合的模型
class MOAFWithChannelCrossAttention(nn.Module):
    """ShuffleNetV2-x0.5 regressor with channel cross-attention fusion.

    A ``ChannelCrossAttention`` block can be inserted after stage2
    (level 0, 48 channels), stage3 (level 1, 96 channels) and/or stage4
    (level 2, 192 channels). Levels that are not selected use
    ``FusionIdentity``.

    Args:
        fusion_level: Iterable of levels (any of 0, 1, 2) at which to fuse,
            or a single int. ``None`` selects level 2 only. Values other
            than 0, 1 and 2 are ignored.
        weights: Forwarded to ``torchvision.models.shufflenet_v2_x0_5``.
            The default ``"DEFAULT"`` keeps the original behaviour; pass
            ``None`` to skip loading pretrained weights.
    """

    def __init__(self, fusion_level=None, weights="DEFAULT"):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights=weights)
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5
        self.param_embedding = ParamEmbedding()
        if fusion_level is None:
            fusion_level = (2,)
        elif isinstance(fusion_level, int):
            # Allow ``fusion_level=1`` as shorthand for ``[1]``.
            fusion_level = (fusion_level,)
        self.cca_block0 = ChannelCrossAttention(feat_channels=48) if 0 in fusion_level else FusionIdentity()
        self.cca_block1 = ChannelCrossAttention(feat_channels=96) if 1 in fusion_level else FusionIdentity()
        self.cca_block2 = ChannelCrossAttention(feat_channels=192) if 2 in fusion_level else FusionIdentity()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 1024 is the channel count produced by conv5 of ShuffleNetV2-x0.5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params):
        """Regress one value per image, conditioned on ``params``.

        Args:
            image: Batched image tensor fed to the ShuffleNet stem
                (presumably ``(N, 3, H, W)`` -- confirm against the data loader).
            params: Per-sample objective parameters passed to
                ``ParamEmbedding``.

        Returns:
            Tensor whose last dimension is 1 (one prediction per sample).
        """
        param_emb = self.param_embedding(params)
        x = self.cbrm(image)
        x = self.stage2(x)
        x = self.cca_block0(x, param_emb)
        x = self.stage3(x)
        x = self.cca_block1(x, param_emb)
        x = self.stage4(x)
        x = self.cca_block2(x, param_emb)
        x = self.cbr2(x)
        x = self.avgpool(x)
        return self.regressor(x)
# 使用 SE 块但无融合的模型
class MOAFWithSE(nn.Module):
    """ShuffleNetV2-x0.5 regressor with SE blocks and no parameter fusion.

    An ``SEBlock`` can be inserted after stage2 (level 0, 48 channels),
    stage3 (level 1, 96 channels) and/or stage4 (level 2, 192 channels).
    Levels that are not selected use ``nn.Identity``.

    Args:
        fusion_level: Iterable of levels (any of 0, 1, 2) at which to insert
            an SE block, or a single int. ``None`` selects level 0 only.
            Values other than 0, 1 and 2 are ignored.
        weights: Forwarded to ``torchvision.models.shufflenet_v2_x0_5``.
            The default ``"DEFAULT"`` keeps the original behaviour; pass
            ``None`` to skip loading pretrained weights.
    """

    def __init__(self, fusion_level=None, weights="DEFAULT"):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights=weights)
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5
        if fusion_level is None:
            fusion_level = (0,)
        elif isinstance(fusion_level, int):
            # Allow ``fusion_level=1`` as shorthand for ``[1]``.
            fusion_level = (fusion_level,)
        self.se_block0 = SEBlock(feat_channels=48) if 0 in fusion_level else nn.Identity()
        self.se_block1 = SEBlock(feat_channels=96) if 1 in fusion_level else nn.Identity()
        self.se_block2 = SEBlock(feat_channels=192) if 2 in fusion_level else nn.Identity()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 1024 is the channel count produced by conv5 of ShuffleNetV2-x0.5.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, image, params=None):
        """Regress one value per image.

        Args:
            image: Batched image tensor fed to the ShuffleNet stem
                (presumably ``(N, 3, H, W)`` -- confirm against the data loader).
            params: Ignored; present only for signature compatibility with
                the fusion models. Optional so callers may omit it.

        Returns:
            Tensor whose last dimension is 1 (one prediction per sample).
        """
        x = self.cbrm(image)
        x = self.stage2(x)
        x = self.se_block0(x)
        x = self.stage3(x)
        x = self.se_block1(x)
        x = self.stage4(x)
        x = self.se_block2(x)
        x = self.cbr2(x)
        x = self.avgpool(x)
        return self.regressor(x)
# 多回归头模型
class MOAFWithMMLP(nn.Module):
    """ShuffleNetV2-x0.5 backbone with one regression head per objective lens.

    Each sample is routed to the head whose lens specification matches its
    ``(magnification, NA, refractive index)`` parameters. Samples that match
    no known lens, or whose lens id is not below ``num_lenses``, use head 0.

    Args:
        num_lenses: Number of regression heads to create.
        weights: Forwarded to ``torchvision.models.shufflenet_v2_x0_5``.
            The default ``"DEFAULT"`` keeps the original behaviour; pass
            ``None`` to skip loading pretrained weights.
    """

    # Known objectives as (magnification, NA, refractive index); the position
    # in this table is the lens id (obj1 .. obj7 -> 0 .. 6).
    _LENS_SPECS = (
        (10.0, 0.25, 1.0000),
        (10.0, 0.30, 1.0000),
        (20.0, 0.70, 1.0000),
        (20.0, 0.80, 1.0000),
        (40.0, 0.65, 1.0000),
        (100.0, 0.80, 1.0000),
        (100.0, 1.25, 1.4730),
    )
    # Absolute tolerance for matching. Far smaller than the closest pair of
    # specs (NA 0.25 vs 0.30) yet large enough to absorb float rounding.
    _LENS_MATCH_ATOL = 1e-3

    def __init__(self, num_lenses=7, weights="DEFAULT"):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights=weights)
        self.features = nn.Sequential(
            shuff.conv1, shuff.maxpool,
            shuff.stage2, shuff.stage3,
            shuff.stage4, shuff.conv5,
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.num_lenses = num_lenses
        self.regressors = nn.ModuleList(
            [self._create_regressor_head() for _ in range(num_lenses)]
        )

    def _create_regressor_head(self):
        """Build one regression head (1024 pooled features -> 1 value)."""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

    def _find_lens_id(self, params):
        """Map each parameter row to a lens id.

        Exact ``==`` comparison against Python float literals fails for
        float32 tensors (e.g. ``float32(0.30) != 0.30``), which would send
        most lenses to the fallback head. Rows are therefore matched against
        the spec table with an absolute tolerance.

        Args:
            params: 2-D tensor; columns 0, 1 and 2 hold magnification, NA and
                refractive index. Extra columns are ignored.

        Returns:
            ``torch.long`` tensor of shape ``(batch,)`` on ``params.device``.
            Rows that match no known lens get id 0.
        """
        values = params[:, :3]
        if not values.is_floating_point():
            values = values.float()
        specs = torch.tensor(
            self._LENS_SPECS, dtype=values.dtype, device=values.device
        )
        # (batch, 1, 3) vs (1, lenses, 3) -> (batch, lenses) match matrix.
        deviation = (values.unsqueeze(1) - specs.unsqueeze(0)).abs()
        matches = (deviation <= self._LENS_MATCH_ATOL).all(dim=-1)
        matched_id = matches.to(torch.long).argmax(dim=1)
        # Be explicit about the fallback instead of relying on argmax of an
        # all-zero row.
        return torch.where(
            matches.any(dim=1), matched_id, torch.zeros_like(matched_id)
        )

    def forward(self, image, params):
        """Regress one value per image using the head of the matching lens.

        Args:
            image: Batched image tensor fed to the ShuffleNet stem
                (presumably ``(N, 3, H, W)`` -- confirm against the data loader).
            params: 2-D tensor of per-sample lens parameters, see
                ``_find_lens_id``.

        Returns:
            Tensor of shape ``(batch, 1)`` in the original sample order.
        """
        x = self.avgpool(self.features(image))
        lens_ids = self._find_lens_id(params)
        # Ids without a corresponding head fall back to head 0.
        invalid = (lens_ids < 0) | (lens_ids >= self.num_lenses)
        lens_ids = lens_ids.masked_fill(invalid, 0)

        # Run each head once on all of its samples rather than once per sample.
        group_outputs = []
        group_rows = []
        for lens_id in torch.unique(lens_ids).tolist():
            rows = torch.nonzero(lens_ids == lens_id, as_tuple=True)[0]
            group_outputs.append(self.regressors[lens_id](x[rows]))
            group_rows.append(rows)

        # Undo the grouping so outputs line up with the input order.
        restore = torch.argsort(torch.cat(group_rows))
        return torch.cat(group_outputs, dim=0)[restore]