SparseFocus/models.py
2026-06-02 13:51:22 +08:00

138 lines
4.4 KiB
Python

import torch
from torch import nn
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.models.convnext import LayerNorm2d
from torchvision.ops import SqueezeExcitation, Permute, StochasticDepth
# RINet
class ImportanceClassifier(nn.Module):
    """Head that turns a backbone feature map into one sigmoid score per sample.

    Pipeline: squeeze-and-excitation channel reweighting -> 7x7 average pool
    (stride 1) -> 2x2 conv to ``hidden_channels`` -> Hardswish -> 1x1 conv to a
    single channel -> sigmoid -> ``squeeze()``.

    Args:
        in_channels: Channel count of the incoming feature map. It sets the
            width of both the SE block and the first conv.
        hidden_channels: Width of the intermediate conv layer.
        squeeze_channels: Bottleneck width of the SE block. ``None`` (default)
            uses ``in_channels // 4`` (at least 1), which is 144 for the
            default 576 channels.
    """

    def __init__(self, in_channels=576, hidden_channels=96, squeeze_channels=None):
        super().__init__()
        if squeeze_channels is None:
            squeeze_channels = max(1, in_channels // 4)
        # The SE block must match the actual input width; it was previously
        # hard-wired to 576 channels regardless of ``in_channels``.
        self.se = SqueezeExcitation(input_channels=in_channels, squeeze_channels=squeeze_channels)
        # NOTE(review): pool (k=7, s=1) then conv1 (k=2) shrink each spatial dim
        # by 7, so only an 8x8 feature map ends at 1x1. Smaller maps raise, and
        # larger maps keep spatial dims through ``squeeze()``. Confirm the
        # intended input resolution.
        self.pool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=2, stride=1, bias=True)
        self.act1 = nn.Hardswish(inplace=True)
        self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, bias=True)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of feature maps.

        Args:
            x: Feature map with ``in_channels`` channels. Presumably (B, C, 8, 8);
                see the note in ``__init__``.

        Returns:
            Sigmoid scores with every size-1 dim removed. For an 8x8 input this
            is shape (B,). With B == 1, ``squeeze()`` also drops the batch dim
            and returns a 0-d tensor.
        """
        x = self.se(x)
        x = self.pool(x)
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        return x.squeeze()
class RINet(nn.Module):
    """MobileNetV3-Small convolutional trunk followed by an ImportanceClassifier head.

    Args:
        weights: Optional ``MobileNet_V3_Small_Weights`` used to initialise the
            backbone, e.g. ``MobileNet_V3_Small_Weights.DEFAULT`` for ImageNet
            weights. ``None`` (default) keeps the random initialisation used
            previously.
    """

    def __init__(self, weights=None):
        super().__init__()
        backbone_model = mobilenet_v3_small(weights=weights)
        # Keep only the feature extractor. MobileNetV3's own pooling and
        # classifier are not used.
        self.backbone = backbone_model.features
        # 576 is the output channel count of mobilenet_v3_small().features.
        self.classifier = ImportanceClassifier(in_channels=576, hidden_channels=96)

    def forward(self, x):
        """Run the backbone, then the importance head; returns the head's squeezed scores."""
        x = self.backbone(x)
        x = self.classifier(x)
        return x
# DPNet
class DFEBlock(nn.Module):
    """Residual block: 7x7 depthwise conv + BatchNorm + inverted-bottleneck MLP.

    The MLP (``dim -> 4*dim -> dim`` with ReLU6) runs channels-last so it can
    use ``nn.Linear``. The branch output is multiplied by a learnable
    per-channel ``layer_scale``, passed through stochastic depth, and added
    back to the input.

    Args:
        dim: Channel count; the block preserves it.
        layer_scale: Initial value of every entry of the (dim, 1, 1) scale.
        stochastic_depth_prob: Drop probability for ``StochasticDepth`` (mode "row").
    """

    def __init__(self, dim, layer_scale=1e-6, stochastic_depth_prob=0.0):
        super().__init__()
        hidden = 4 * dim
        branch_layers = [
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            nn.BatchNorm2d(dim, eps=1e-4),
            Permute([0, 2, 3, 1]),  # channels-first -> channels-last for the Linear layers
            nn.Linear(dim, hidden, bias=True),
            nn.ReLU6(inplace=True),
            nn.Linear(hidden, dim, bias=True),
            Permute([0, 3, 1, 2]),  # back to channels-first
        ]
        self.block = nn.Sequential(*branch_layers)
        self.layer_scale = nn.Parameter(layer_scale * torch.ones(dim, 1, 1))
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")

    def forward(self, x):
        """Return ``x + stochastic_depth(layer_scale * block(x))``."""
        branch = self.stochastic_depth(self.layer_scale * self.block(x))
        return x + branch
class DFEBlockConfig:
    """Plain record describing one DPNet stage.

    Attributes:
        input_channels: Channel width of every DFEBlock in the stage.
        out_channels: Width after the downsampling layer that follows the
            stage, or ``None`` when no downsampling follows it.
        num_layers: Number of DFEBlocks stacked in the stage.
    """

    def __init__(self, input_channels, out_channels, num_layers):
        self.input_channels, self.out_channels, self.num_layers = (
            input_channels,
            out_channels,
            num_layers,
        )
class DPNet(nn.Module):
    """ConvNeXt-style network of DFEBlocks that regresses one value per image.

    Layout:
      * stem: 4x4 conv with stride 4, from 3 input channels to 128;
      * four stages of DFEBlocks (3, 3, 9 and 3 blocks at 128, 256, 512 and
        1024 channels), each of the first three followed by LayerNorm2d and a
        2x2/stride-2 downsampling conv;
      * global max pool to 1x1;
      * head: 1x1 conv to 1280, 1x1 conv to 100, flatten, Linear(100, 1).

    Args:
        stochastic_depth_prob: Maximum stochastic-depth rate. Block ``i`` of the
            ``N`` blocks gets ``stochastic_depth_prob * i / (N - 1)``.
        layer_scale: Initial value of each DFEBlock's per-channel scale.
    """

    def __init__(self, stochastic_depth_prob=0.0, layer_scale=1e-6):
        super().__init__()
        # One entry per stage. out_channels=None marks the last stage, which
        # has no downsampling layer after it.
        block_setting = [
            DFEBlockConfig(128, 256, 3),
            DFEBlockConfig(256, 512, 3),
            DFEBlockConfig(512, 1024, 9),
            DFEBlockConfig(1024, None, 3),
        ]
        layers = []

        # Stem: non-overlapping 4x4 patches projected to the first stage width.
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            nn.Conv2d(3, firstconv_output_channels, kernel_size=4, stride=4, padding=0, bias=True)
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            stage = []
            for _ in range(cnf.num_layers):
                # Linear ramp: 0 for the first block up to stochastic_depth_prob for the last.
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(
                    DFEBlock(dim=cnf.input_channels, layer_scale=layer_scale, stochastic_depth_prob=sd_prob)
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Between stages: normalise, then halve H and W while changing width.
                layers.append(
                    nn.Sequential(
                        LayerNorm2d(cnf.input_channels, eps=1e-4),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2, bias=True),
                    )
                )
        self.features = nn.Sequential(*layers)

        # Global max pool to 1x1. The total feature stride is 4 * 2 * 2 * 2 = 32,
        # so 224x224 inputs give a 7x7 map. On that map this equals a 7x7 max
        # pool. Other input sizes are also reduced to 1x1, so the flattened
        # head input is always the 100 features Linear(100, 1) expects; a fixed
        # 7x7 window would break for those sizes. No parameters, so checkpoints
        # are unaffected.
        self.pool = nn.AdaptiveMaxPool2d(1)

        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels
            if lastblock.out_channels is not None
            else lastblock.input_channels
        )
        # NOTE(review): there is no activation between these layers, so the
        # head is a stack of affine maps (equivalent to a single one). Confirm
        # this is intended.
        self.regressor = nn.Sequential(
            nn.Conv2d(lastconv_output_channels, 1280, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Conv2d(1280, 100, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Flatten(1),
            nn.Linear(100, 1, bias=True),
        )

    def forward(self, x):
        """Map a (B, 3, H, W) image batch to a (B, 1) tensor of predictions."""
        x = self.features(x)
        x = self.pool(x)
        x = self.regressor(x)
        return x