import torch
from torch import nn
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.models.convnext import LayerNorm2d
from torchvision.ops import SqueezeExcitation, Permute, StochasticDepth


# RINet
class ImportanceClassifier(nn.Module):
    """Score head: SE gate -> 7x7 average pool -> 2x2 conv -> 1x1 conv -> sigmoid.

    Args:
        in_channels: Number of channels of the incoming feature map. Used for
            both the squeeze-excitation gate and the first convolution.
        hidden_channels: Number of channels between the two convolutions.

    The pooling layer (kernel 7, stride 1) and the unpadded 2x2 convolution
    shrink each spatial side by 7 in total, so the incoming feature map must
    be at least 8x8; an 8x8 map yields a single 1x1 score per sample.
    """

    def __init__(self, in_channels=576, hidden_channels=96):
        super().__init__()
        # The SE gate is sized from ``in_channels`` rather than a fixed 576 so
        # that a non-default ``in_channels`` produces a usable module. The
        # squeeze width keeps the 4:1 ratio, i.e. 144 for the default 576.
        self.se = SqueezeExcitation(
            input_channels=in_channels,
            squeeze_channels=max(1, in_channels // 4),
        )
        self.pool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.conv1 = nn.Conv2d(
            in_channels, hidden_channels, kernel_size=2, stride=1, bias=True
        )
        self.act1 = nn.Hardswish(inplace=True)
        self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, bias=True)
        self.act2 = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid scores for ``x`` with all size-1 dimensions removed.

        NOTE: ``squeeze()`` is called without a ``dim`` argument, so every
        singleton axis is dropped. For an ``(N, C, 8, 8)`` input the result is
        ``(N,)``, and it becomes a 0-dim tensor when ``N == 1``.
        """
        x = self.se(x)
        x = self.pool(x)
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        return x.squeeze()


class RINet(nn.Module):
    """MobileNetV3-Small feature extractor followed by an ImportanceClassifier.

    The backbone is built by calling ``mobilenet_v3_small()`` with no
    arguments, and only its ``features`` sub-module is kept. The classifier
    head is configured for 576 input channels and 96 hidden channels.
    """

    def __init__(self):
        super().__init__()
        backbone_model = mobilenet_v3_small()
        self.backbone = backbone_model.features
        self.classifier = ImportanceClassifier(in_channels=576, hidden_channels=96)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone, then the classifier head, and return its output."""
        features = self.backbone(x)
        return self.classifier(features)


# DPNet
class DFEBlock(nn.Module):
    """Residual block: depthwise 7x7 conv, BatchNorm, and a two-layer MLP.

    Args:
        dim: Channel count of the input and the output.
        layer_scale: Initial value of the learnable per-channel scale applied
            to the branch output before the residual addition.
        stochastic_depth_prob: Probability passed to ``StochasticDepth``
            (``mode="row"``) for the branch.
    """

    def __init__(self, dim, layer_scale=1e-6, stochastic_depth_prob=0.0):
        super().__init__()
        self.block = nn.Sequential(
            # groups=dim makes this a depthwise convolution; padding=3 keeps
            # the spatial size unchanged for the 7x7 kernel.
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            nn.BatchNorm2d(dim, eps=1e-4),
            # Move channels last so nn.Linear acts on the channel axis.
            Permute([0, 2, 3, 1]),
            nn.Linear(dim, 4 * dim, bias=True),
            nn.ReLU6(inplace=True),
            nn.Linear(4 * dim, dim, bias=True),
            # Restore the channels-first layout.
            Permute([0, 3, 1, 2]),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the scaled, stochastically dropped branch output."""
        result = self.layer_scale * self.block(x)
        result = self.stochastic_depth(result)
        result += x
        return result


class DFEBlockConfig:
    """Settings for one stage of DFEBlocks.

    Args:
        input_channels: Channel width of the blocks in the stage.
        out_channels: Channel width after the stage, or ``None`` when the
            stage has no following channel change.
        num_layers: Number of blocks in the stage.
    """

    def __init__(self, input_channels, out_channels, num_layers):
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
class DPNet(nn.Module):
    """Four-stage stack of DFEBlocks followed by a single-output regressor.

    Layout, in order:
      * a 4x4, stride-4 convolution from 3 to 128 channels;
      * stages of 3, 3, 9 and 3 DFEBlocks at widths 128, 256, 512 and 1024,
        with a ``LayerNorm2d`` + 2x2 stride-2 convolution after each of the
        first three stages to halve the resolution and double the width;
      * a 7x7, stride-1 max pool;
      * two 1x1 convolutions (to 1280, then 100 channels), a flatten, and a
        ``Linear(100, 1)``.

    Because the flattened tensor feeds ``Linear(100, 1)``, the pooled map has
    to be 1x1, i.e. the feature map entering the pool must be 7x7 (for
    example a 224x224 input, given the overall stride of 32).

    Args:
        stochastic_depth_prob: Stochastic-depth probability of the last block;
            earlier blocks get linearly smaller values starting from 0.
        layer_scale: Initial layer-scale value handed to every DFEBlock.
    """

    def __init__(self, stochastic_depth_prob=0.0, layer_scale=1e-6):
        super().__init__()

        stage_configs = [
            DFEBlockConfig(128, 256, 3),
            DFEBlockConfig(256, 512, 3),
            DFEBlockConfig(512, 1024, 9),
            DFEBlockConfig(1024, None, 3),
        ]

        # Stem: non-overlapping 4x4 patches projected to the first stage width.
        stem_channels = stage_configs[0].input_channels
        modules = [
            nn.Conv2d(3, stem_channels, kernel_size=4, stride=4, padding=0, bias=True)
        ]

        # The drop probability grows linearly with the global block index.
        num_blocks = sum(cfg.num_layers for cfg in stage_configs)
        max_index = num_blocks - 1.0
        first_index = 0
        for cfg in stage_configs:
            width = cfg.input_channels
            next_index = first_index + cfg.num_layers
            stage_blocks = [
                DFEBlock(
                    dim=width,
                    layer_scale=layer_scale,
                    stochastic_depth_prob=stochastic_depth_prob * index / max_index,
                )
                for index in range(first_index, next_index)
            ]
            first_index = next_index
            modules.append(nn.Sequential(*stage_blocks))

            # The final stage has no out_channels and so no downsampling step.
            if cfg.out_channels is None:
                continue
            modules.append(
                nn.Sequential(
                    LayerNorm2d(width, eps=1e-4),
                    nn.Conv2d(width, cfg.out_channels, kernel_size=2, stride=2, bias=True),
                )
            )

        self.features = nn.Sequential(*modules)
        self.pool = nn.MaxPool2d(kernel_size=7, stride=1)

        final_cfg = stage_configs[-1]
        if final_cfg.out_channels is None:
            head_channels = final_cfg.input_channels
        else:
            head_channels = final_cfg.out_channels

        self.regressor = nn.Sequential(
            nn.Conv2d(head_channels, 1280, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Conv2d(1280, 100, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Flatten(1),
            nn.Linear(100, 1, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply features, max pooling and the regressor; return the result."""
        return self.regressor(self.pool(self.features(x)))