138 lines
4.4 KiB
Python
138 lines
4.4 KiB
Python
import torch
|
|
from torch import nn
|
|
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
|
|
from torchvision.models.convnext import LayerNorm2d
|
|
from torchvision.ops import SqueezeExcitation, Permute, StochasticDepth
|
|
|
|
|
|
# RINet
|
|
class ImportanceClassifier(nn.Module):
    """Head that maps a backbone feature map to one sigmoid score per image.

    Pipeline: squeeze-and-excitation -> 7x7 average pool (stride 1) ->
    2x2 conv -> Hardswish -> 1x1 conv to a single channel -> sigmoid.

    NOTE(review): the 7x7 pool followed by a 2x2 conv only produces a 1x1
    output when the incoming feature map is 8x8 (presumably a 256x256 image
    through MobileNetV3-Small ``features``); a 7x7 map makes ``conv1`` fail
    because the pooled map is smaller than its kernel. Confirm the intended
    input resolution.

    Args:
        in_channels: Channels of the incoming feature map (576 as used by
            RINet with MobileNetV3-Small ``features``).
        hidden_channels: Channels produced by ``conv1``.
    """

    def __init__(self, in_channels=576, hidden_channels=96):
        super().__init__()

        # Size the SE block from ``in_channels`` rather than a hard-coded 576,
        # so non-default widths no longer mismatch. ``in_channels // 4``
        # reproduces the original 576 -> 144 squeeze for the default.
        self.se = SqueezeExcitation(
            input_channels=in_channels,
            squeeze_channels=max(1, in_channels // 4),
        )
        self.pool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=2, stride=1, bias=True)
        self.act1 = nn.Hardswish(inplace=True)
        self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, bias=True)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of feature maps.

        Args:
            x: Feature tensor, assumed (N, in_channels, H, W).

        Returns:
            Sigmoid scores with every singleton non-batch dimension removed,
            i.e. shape (N,) when the spatial output is 1x1. The batch
            dimension is always kept, so N == 1 gives shape (1,) instead of
            a 0-d tensor.
        """
        x = self.se(x)
        x = self.pool(x)
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        # A bare ``squeeze()`` would also drop the batch dim when N == 1.
        # Squeeze from the last dim backwards so lower indices stay valid;
        # ``squeeze(dim)`` is a no-op when that dim is not of size 1.
        for dim in range(x.dim() - 1, 0, -1):
            x = x.squeeze(dim)
        return x
|
|
|
|
|
|
class RINet(nn.Module):
    """MobileNetV3-Small feature extractor followed by an ImportanceClassifier head.

    Args:
        weights: Optional torchvision ``MobileNet_V3_Small_Weights`` used to
            initialise the backbone. The default ``None`` builds a randomly
            initialised backbone, exactly like the original no-argument
            constructor.
    """

    def __init__(self, weights=None):
        super().__init__()

        # Build the full torchvision model and keep only its convolutional
        # trunk (``features``); the torchvision classification head is unused.
        backbone_model = mobilenet_v3_small(weights=weights)
        self.backbone = backbone_model.features
        # 576 must match the channel count of the last ``features`` stage.
        self.classifier = ImportanceClassifier(in_channels=576, hidden_channels=96)

    def forward(self, x):
        """Return the ImportanceClassifier scores for the backbone features of ``x``.

        ``x`` is presumably an (N, 3, H, W) image batch. TODO: confirm the
        expected resolution against ImportanceClassifier's pooling and conv sizes.
        """
        features = self.backbone(x)
        return self.classifier(features)
|
|
|
|
|
|
# DPNet
|
|
class DFEBlock(nn.Module):
    """Residual depthwise-conv + MLP block with layer scale and stochastic depth.

    Layout of ``block`` (the indices appear in state_dict keys):
        0: 7x7 depthwise Conv2d, padding 3, so H and W are preserved
        1: BatchNorm2d
        2: Permute NCHW -> NHWC
        3: Linear dim -> 4*dim
        4: ReLU6
        5: Linear 4*dim -> dim
        6: Permute NHWC -> NCHW

    Args:
        dim: Number of input and output channels.
        layer_scale: Initial value of the learnable per-channel scale applied
            to the branch output.
        stochastic_depth_prob: Drop probability handed to StochasticDepth
            in "row" mode.
    """

    def __init__(self, dim, layer_scale=1e-6, stochastic_depth_prob=0.0):
        super().__init__()

        hidden = 4 * dim
        depthwise = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True)
        norm = nn.BatchNorm2d(dim, eps=1e-4)
        expand = nn.Linear(dim, hidden, bias=True)
        project = nn.Linear(hidden, dim, bias=True)
        self.block = nn.Sequential(
            depthwise,
            norm,
            Permute([0, 2, 3, 1]),  # channels-last so the Linears mix channels
            expand,
            nn.ReLU6(inplace=True),
            project,
            Permute([0, 3, 1, 2]),  # back to channels-first
        )

        # One learnable scale per channel, broadcast over H and W.
        self.layer_scale = nn.Parameter(layer_scale * torch.ones(dim, 1, 1))
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")

    def forward(self, x):
        """Return ``x + stochastic_depth(layer_scale * block(x))`` (same shape as ``x``)."""
        branch = self.stochastic_depth(self.layer_scale * self.block(x))
        return branch + x
|
|
|
|
|
|
class DFEBlockConfig:
    """Plain record describing one DPNet stage.

    Attributes:
        input_channels: Channel width of every DFEBlock in the stage.
        out_channels: Width produced by the downsampling layer that follows
            the stage, or ``None`` when DPNet adds no downsampling layer.
        num_layers: Number of DFEBlocks in the stage.
    """

    def __init__(self, input_channels, out_channels, num_layers):
        (self.input_channels,
         self.out_channels,
         self.num_layers) = (input_channels, out_channels, num_layers)
|
|
|
|
|
|
class DPNet(nn.Module):
    """Stage-based DFEBlock network ending in a single-output regressor.

    Structure:
        features: 4x4 / stride-4 stem conv, then for each stage a Sequential
            of DFEBlocks followed (when ``out_channels`` is set) by a
            LayerNorm2d + 2x2 / stride-2 conv downsampling layer.
        pool: global max pool to 1x1.
        regressor: 1x1 conv -> 1x1 conv -> flatten -> Linear to one output.

    Args:
        stochastic_depth_prob: Maximum stochastic-depth rate. Block ``i`` of
            ``n`` uses ``stochastic_depth_prob * i / (n - 1)``.
        layer_scale: Initial layer-scale value passed to every DFEBlock.
        block_setting: Optional list of DFEBlockConfig describing the stages.
            ``None`` selects the default 128/256/512/1024-channel layout with
            3/3/9/3 blocks.

    Raises:
        ValueError: If ``block_setting`` is an empty list.
    """

    def __init__(self, stochastic_depth_prob=0.0, layer_scale=1e-6, block_setting=None):
        super().__init__()

        if block_setting is None:
            block_setting = [
                DFEBlockConfig(128, 256, 3),
                DFEBlockConfig(256, 512, 3),
                DFEBlockConfig(512, 1024, 9),
                DFEBlockConfig(1024, None, 3),
            ]
        if not block_setting:
            raise ValueError("block_setting must contain at least one DFEBlockConfig")

        layers = []

        # Stem: non-overlapping 4x4 patches of the 3-channel input.
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            nn.Conv2d(3, firstconv_output_channels, kernel_size=4, stride=4, padding=0, bias=True)
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        # Avoid division by zero when the whole network has a single block;
        # for the default layout this is 17.0, as before.
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        stage_block_id = 0

        for cnf in block_setting:
            stage = []
            for _ in range(cnf.num_layers):
                # Stochastic-depth rate ramps linearly from 0 to stochastic_depth_prob.
                sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator
                stage.append(
                    DFEBlock(dim=cnf.input_channels, layer_scale=layer_scale, stochastic_depth_prob=sd_prob)
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling: normalise, then halve H and W while changing width.
                layers.append(
                    nn.Sequential(
                        LayerNorm2d(cnf.input_channels, eps=1e-4),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2, bias=True),
                    )
                )

        self.features = nn.Sequential(*layers)
        # Global max pool to 1x1, which the Flatten + Linear(100, 1) in the
        # regressor requires. On a 7x7 feature map (e.g. a 224x224 input with
        # the default layout) this equals MaxPool2d(kernel_size=7, stride=1),
        # but unlike that layer it also works at other input resolutions.
        self.pool = nn.AdaptiveMaxPool2d(1)

        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels
            if lastblock.out_channels is not None
            else lastblock.input_channels
        )

        # NOTE(review): there is no activation between the two 1x1 convs, so
        # together they act as one linear map. Confirm this is intended.
        self.regressor = nn.Sequential(
            nn.Conv2d(lastconv_output_channels, 1280, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Conv2d(1280, 100, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Flatten(1),
            nn.Linear(100, 1, bias=True),
        )

    def forward(self, x):
        """Return an (N, 1) tensor of regression outputs for the image batch ``x``.

        ``x`` is presumably (N, 3, H, W). With the default layout the total
        stride is 32, so H and W must be at least 32.
        """
        x = self.features(x)
        x = self.pool(x)
        x = self.regressor(x)
        return x
|