first commit

This commit is contained in:
kaiza_hikaru 2026-06-02 13:51:22 +08:00
commit 1b6fcf93d0
11 changed files with 2041 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
__pycache__/
*.py[cod]
.vscode/
20*_RIN*/
20*_RIN_OLD*/

13
config_dpn.toml Normal file
View File

@ -0,0 +1,13 @@
xlsx_files = [
"E:/Datasets/SparseFocusDataset/complete_image/complete_image/excel/02/224_dataset_little.xlsx"
]
batch_size = 64
learning_rate = 1e-4
epochs = 600
num_workers = 4
seed = 3407
# 为空表示从头训练
# 非空表示仅加载模型参数初始化
init_weight = ""

14
config_rin.toml Normal file
View File

@ -0,0 +1,14 @@
xlsx_files = [
"E:/Datasets/SparseFocusDataset/complete_image/complete_image/excel/02/new_classification_dataset.xlsx"
]
batch_size = 64
learning_rate = 1e-5
epochs = 600
num_workers = 12
seed = 3407
# 为空表示从头训练
# 非空表示仅加载模型参数初始化
init_weight = ""

75
datasets.py Normal file
View File

@ -0,0 +1,75 @@
import random
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from torch.utils.data import Dataset
# 同步增强
class RINPairTransform:
    """Jointly transform an image and its 9x9 label grid.

    The image is always resized to ``image_size`` x ``image_size`` and the
    label is reshaped into a ``(9, 9)`` float32 tensor. In training mode a
    random rotation by a multiple of 90 degrees and a random horizontal flip
    are applied to *both* image and label so they stay aligned; colour jitter
    is then applied to the image only.
    """

    def __init__(self, train=True, image_size=512):
        self.train = train
        self.image_size = image_size
        # Photometric augmentation; only used when ``train`` is true.
        self.color_jitter = transforms.ColorJitter(
            brightness=(0.9, 1.4),
            contrast=(0.8, 1.5),
            saturation=(0.8, 1.5),
        )

    def __call__(self, image, label):
        """Return ``(image_tensor, label_grid)`` for one sample."""
        side = self.image_size
        image = F.resize(image, size=(side, side))
        label = torch.as_tensor(label, dtype=torch.float32).view(9, 9)

        if not self.train:
            return F.to_tensor(image), label

        # Random quarter-turn rotation, mirrored on the label grid with rot90.
        quarter_turns = random.randint(0, 3)
        if quarter_turns:
            pil_rotation = {
                1: Image.Transpose.ROTATE_90,
                2: Image.Transpose.ROTATE_180,
                3: Image.Transpose.ROTATE_270,
            }[quarter_turns]
            image = image.transpose(pil_rotation)
            label = torch.rot90(label, k=quarter_turns, dims=(0, 1))

        # Random horizontal flip: mirror the label along its column axis too.
        if random.random() < 0.5:
            image = F.hflip(image)
            label = torch.flip(label, dims=(1,))

        # Colour augmentation does not move pixels, so the label is untouched.
        image = self.color_jitter(image)
        return F.to_tensor(image), label
class RIN_Dataset(Dataset):
    """Dataset of (image, 9x9 label grid) pairs read from image file paths.

    ``transform`` is an optional callable taking ``(image, label)`` and
    returning the transformed pair. Without it the image is resized to
    512x512 and converted to a tensor, and the label becomes a ``(9, 9)``
    float32 tensor.
    """

    def __init__(self, image_path_list, label_list, transform=None):
        self.image_path_list, self.label_list = image_path_list, label_list
        self.transform = transform

    def __len__(self):
        return len(self.image_path_list)

    def __getitem__(self, index):
        path = self.image_path_list[index]
        raw_label = self.label_list[index]
        picture = Image.open(path).convert("RGB")

        if self.transform is None:
            # Default pipeline when no paired transform was supplied.
            picture = F.to_tensor(F.resize(picture, size=(512, 512)))
            raw_label = torch.as_tensor(raw_label, dtype=torch.float32).view(9, 9)
            return picture, raw_label

        picture, raw_label = self.transform(picture, raw_label)
        return picture, raw_label

137
models.py Normal file
View File

@ -0,0 +1,137 @@
import torch
from torch import nn
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from torchvision.models.convnext import LayerNorm2d
from torchvision.ops import SqueezeExcitation, Permute, StochasticDepth
# RINet
class ImportanceClassifier(nn.Module):
    """Head that turns a backbone feature map into a grid of scores in (0, 1).

    Pipeline: squeeze-excitation -> 7x7 average pooling (stride 1) -> 2x2
    convolution -> Hardswish -> 1x1 convolution to a single channel -> sigmoid.
    The pooling and the 2x2 convolution are unpadded, so an ``H x W`` input
    yields an ``(H - 7) x (W - 7)`` score map.

    Args:
        in_channels: Number of channels of the incoming feature map.
        hidden_channels: Width of the intermediate 2x2 convolution.
    """

    def __init__(self, in_channels=576, hidden_channels=96):
        super().__init__()
        # The SE block now follows ``in_channels`` instead of a hard-coded 576,
        # keeping the original 4:1 squeeze ratio (576 -> 144 for the default).
        self.se = SqueezeExcitation(
            input_channels=in_channels,
            squeeze_channels=max(1, in_channels // 4),
        )
        self.pool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=2, stride=1, bias=True)
        self.act1 = nn.Hardswish(inplace=True)
        self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, bias=True)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        """Return the score map with all size-1 dimensions removed."""
        x = self.se(x)
        x = self.pool(x)
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        # NOTE(review): squeeze() without a dim also drops the batch dimension
        # when the batch size is 1. Kept as-is so the return shape callers see
        # is unchanged; consider squeeze(1) once all callers are checked.
        return x.squeeze()
class RINet(nn.Module):
    """MobileNetV3-Small feature extractor followed by ImportanceClassifier.

    The backbone is created without pretrained weights; only its ``features``
    sub-module is kept.
    """

    def __init__(self):
        super().__init__()
        self.backbone = mobilenet_v3_small().features
        self.classifier = ImportanceClassifier(in_channels=576, hidden_channels=96)

    def forward(self, x):
        return self.classifier(self.backbone(x))
# DPNet
class DFEBlock(nn.Module):
    """Residual block: depthwise 7x7 conv, BatchNorm and a channel-wise MLP.

    The MLP expands to ``4 * dim`` with ReLU6 and projects back to ``dim``; it
    runs in channels-last layout, hence the two ``Permute`` layers. The branch
    output is multiplied by a learnable per-channel scale (initialised to
    ``layer_scale``), passed through stochastic depth and added to the input.
    """

    def __init__(self, dim, layer_scale=1e-6, stochastic_depth_prob=0.0):
        super().__init__()
        expanded = 4 * dim
        branch_layers = [
            # Depthwise convolution: one 7x7 filter per channel.
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            nn.BatchNorm2d(dim, eps=1e-4),
            # NCHW -> NHWC so nn.Linear acts on the channel axis.
            Permute([0, 2, 3, 1]),
            nn.Linear(dim, expanded, bias=True),
            nn.ReLU6(inplace=True),
            nn.Linear(expanded, dim, bias=True),
            # NHWC -> NCHW.
            Permute([0, 3, 1, 2]),
        ]
        self.block = nn.Sequential(*branch_layers)
        self.layer_scale = nn.Parameter(layer_scale * torch.ones(dim, 1, 1))
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")

    def forward(self, x):
        branch = self.block(x)
        branch = self.stochastic_depth(self.layer_scale * branch)
        return branch + x
class DFEBlockConfig:
    """Settings for one DPNet stage.

    ``input_channels`` is the stage width, ``out_channels`` the width after
    the following downsampling layer (``None`` when there is none), and
    ``num_layers`` the number of blocks in the stage.
    """

    def __init__(self, input_channels, out_channels, num_layers):
        (
            self.input_channels,
            self.out_channels,
            self.num_layers,
        ) = (input_channels, out_channels, num_layers)
class DPNet(nn.Module):
    """Four-stage convolutional regressor producing one scalar per image.

    Layout: a 4x4 stride-4 stem convolution, then stages of DFEBlocks with
    widths 128/256/512/1024 and depths 3/3/9/3, separated by LayerNorm2d plus
    a 2x2 stride-2 convolution. A 7x7 max pool and a head of two 1x1
    convolutions, a flatten and a ``Linear(100, 1)`` follow. Because the head
    flattens into 100 features, the pooled map must be 1x1, i.e. the feature
    map must be 7x7 (a 224x224 input at the total stride of 32).

    Args:
        stochastic_depth_prob: Drop probability of the last block; earlier
            blocks get a linearly smaller value starting from 0.
        layer_scale: Initial value of each block's per-channel scale.
    """

    def __init__(self, stochastic_depth_prob=0.0, layer_scale=1e-6):
        super().__init__()
        stages = [
            DFEBlockConfig(128, 256, 3),
            DFEBlockConfig(256, 512, 3),
            DFEBlockConfig(512, 1024, 9),
            DFEBlockConfig(1024, None, 3),
        ]

        # Stem: patchify the RGB input to the width of the first stage.
        stem_width = stages[0].input_channels
        modules = [
            nn.Conv2d(3, stem_width, kernel_size=4, stride=4, padding=0, bias=True)
        ]

        num_blocks = sum(stage.num_layers for stage in stages)
        block_index = 0
        for stage in stages:
            width = stage.input_channels
            blocks = []
            for _ in range(stage.num_layers):
                # Linearly increasing drop probability over the whole network.
                drop_prob = stochastic_depth_prob * block_index / (num_blocks - 1.0)
                blocks.append(
                    DFEBlock(dim=width, layer_scale=layer_scale, stochastic_depth_prob=drop_prob)
                )
                block_index += 1
            modules.append(nn.Sequential(*blocks))

            if stage.out_channels is None:
                # Last stage: no downsampling layer follows.
                continue
            modules.append(
                nn.Sequential(
                    LayerNorm2d(width, eps=1e-4),
                    nn.Conv2d(width, stage.out_channels, kernel_size=2, stride=2, bias=True),
                )
            )

        self.features = nn.Sequential(*modules)
        self.pool = nn.MaxPool2d(kernel_size=7, stride=1)

        final_stage = stages[-1]
        if final_stage.out_channels is None:
            head_in_channels = final_stage.input_channels
        else:
            head_in_channels = final_stage.out_channels
        self.regressor = nn.Sequential(
            nn.Conv2d(head_in_channels, 1280, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Conv2d(1280, 100, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Flatten(1),
            nn.Linear(100, 1, bias=True),
        )

    def forward(self, x):
        return self.regressor(self.pool(self.features(x)))

217
old_datasets.py Normal file
View File

@ -0,0 +1,217 @@
import os
import time
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np
import random
from torchvision.transforms import ToPILImage
import torch
import glob
import torchvision.transforms as transforms
from openpyxl import load_workbook
import json
import re
# REVIEW:
# 这个文件是项目的数据入口层,负责把 Excel、JSON 和图像文件组织成训练/测试阶段可消费的数据。
# 整体上它解决了“如何把作者本地数据组织接到模型上”的问题,但工程抽象较弱,路径与数据格式都强依赖作者环境。
#
# REVIEW:
# 当前文件同时服务于两个任务:
# 1. DPNet 的 defocus distance 回归;
# 2. RINet 的 patch 有效性二分类。
# 这种复用在小型研究项目里很常见,但随着任务差异增大,后续维护成本会逐渐上升。
# DPNet load datasets
#"""
# param root_path_list:Loading path of xlsx.
# return:The training and validation data both contain the path of the image and the defocus - distance label of the image.
#"""
def _collect_subsampled_images(sheet, image_root, counter):
    """Collect every fifth in-range image listed by one worksheet.

    Each row's first cell names a sub-directory of ``image_root``; its
    ``*/*.jpg`` files are scanned. The label is the integer before the first
    dot of the file name. Only labels in ``[-25000, 25000)`` are considered,
    and of those only the ones for which the running ``counter`` is a multiple
    of five are kept. Returns ``(paths, labels / 1000)``.
    """
    paths = []
    distances = []
    for row in sheet.iter_rows(values_only=True):
        pattern = os.path.join(image_root, row[0], '*/*.jpg')
        for image_path in glob.glob(pattern):
            # Normalise separators so the parse works for both path styles.
            file_name = image_path.replace('\\', '/').split('/')[-1]
            label = int(file_name.split('.')[0])
            if -25000 <= label < 25000:
                if counter % 5 == 0:
                    paths.append(image_path)
                    distances.append(label / 1000)
                # The counter advances for every in-range image and carries
                # over from one row to the next.
                counter += 1
    return paths, distances


def get_DPNet_train_data_and_label(root_path_list:list,
                                   image_root:str='E:/Datasets/SparseFocusDataset/224_image'):
    """Build the DPNet training and validation sample lists.

    Args:
        root_path_list: Glob patterns matching the ``.xlsx`` index files. Each
            workbook must contain a ``train`` and a ``val`` sheet whose first
            column names a directory below ``image_root``.
        image_root: Directory holding the image folders. Defaults to the path
            that used to be hard-coded here.

    Returns:
        Four parallel lists: training image paths, training labels,
        validation image paths, validation labels. Labels are the integer
        parsed from the file name divided by 1000.

    Raises:
        KeyError: If a workbook lacks the ``train`` or ``val`` sheet.

    Note:
        Roughly one image in five is kept, starting at a random offset drawn
        per workbook, so the result depends on the ``random`` module's state.
    """
    print("begin load data")
    train_image_path_list = []
    train_defocus_distance_list = []
    val_image_path_list = []
    val_defocus_distance_list = []
    start_time = time.time()
    for root_path in root_path_list:
        for xlsx in glob.glob(root_path):
            wb = load_workbook(xlsx)
            # Index access replaces the deprecated Workbook.get_sheet_by_name.
            train_sheet = wb['train']
            val_sheet = wb['val']
            # Draw both offsets up front, in the original order, so seeded
            # runs select the same images as before.
            train_random_num = random.randint(0, 9)
            val_random_num = random.randint(0, 9)

            paths, distances = _collect_subsampled_images(train_sheet, image_root, train_random_num)
            train_image_path_list.extend(paths)
            train_defocus_distance_list.extend(distances)

            paths, distances = _collect_subsampled_images(val_sheet, image_root, val_random_num)
            val_image_path_list.extend(paths)
            val_defocus_distance_list.extend(distances)
    print('need time', time.time() - start_time)
    print('train data nums', len(train_image_path_list))
    print('val data nums', len(val_image_path_list))
    return train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list
# Load RINet need datasets
# """
# param root_path: Loading path of xlsx.
# param type: Control the type of data: train or val.
# return: The path of the image and the labels indicating whether different areas on the image contain content.
# """
def get_RINet_data(root_path: str, type: str,
                   dataset_root: str = 'E:/Datasets/SparseFocusDataset/complete_image/complete_image/90_patch_dataset_lap_hsv',
                   legacy_root: str = r"E:\suqiang\new_microscope_data\90_patch_dataset_hsv"):
    """Load image paths and patch-effectiveness labels for RINet.

    Args:
        root_path: Path of the ``.xlsx`` index file.
        type: Name of the worksheet to read (the parameter name shadows the
            builtin but is kept because callers pass it by keyword).
        dataset_root: Directory containing one folder per worksheet row, each
            with a ``224_patch_effective.json`` file. Defaults to the path
            that used to be hard-coded here.
        legacy_root: Path prefix stored inside the JSON files; it is replaced
            by ``dataset_root`` in every returned image path.

    Returns:
        Two parallel lists: image paths and their ``patch_effective`` values.

    Raises:
        KeyError: If the workbook has no sheet called ``type``.
    """
    wb = load_workbook(root_path)
    image_path_list = []
    patch_effective_list = []
    # Index access replaces the deprecated Workbook.get_sheet_by_name.
    sheet = wb[type]
    for row in sheet.iter_rows(values_only=True):
        json_data_path = os.path.join(dataset_root, row[0], '224_patch_effective.json')
        with open(json_data_path, 'r') as fs:
            json_data = json.load(fs)
        # The files store a JSON *string* that itself contains JSON, so a
        # second decode is needed; an already-decoded object is accepted too.
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        for image_data in json_data['image_info_list']:
            image_path_list.append(image_data['image_path'].replace(legacy_root, dataset_root))
            patch_effective_list.append(image_data['patch_effective'])
    return image_path_list, patch_effective_list
# Load test need datasets
# """
# param root_path: Loading path of xlsx.
# param type: Control the type of data: rain or val.
# return: The path of the image and the labels indicating whether different areas on the image contain content.
# """
def get_test_data_and_label(root_path: str, type: str,
                            image_root: str = r'E:\suqiang\DenseSparse\cropped'):
    """List the test images with their placeholder and distance labels.

    Args:
        root_path: Path of the ``.xlsx`` index file.
        type: Name of the worksheet to read (kept as-is for keyword callers).
        image_root: Directory containing one folder of ``*.jpg`` files per
            worksheet row. Defaults to the path that used to be hard-coded.

    Returns:
        Three parallel lists: image paths, patch-effectiveness placeholders
        (always 0) and labels. Each label is the integer before the first dot
        of the file name divided by 1000; only values in ``[-25000, 25000]``
        are kept.

    Raises:
        KeyError: If the workbook has no sheet called ``type``.
    """
    wb = load_workbook(root_path)
    image_path_list = []
    patch_effective_list = []
    defocus_distance_list = []
    # Index access replaces the deprecated Workbook.get_sheet_by_name.
    sheet = wb[type]
    num_i = 0
    for row in sheet.iter_rows(values_only=True):
        field_path = os.path.join(image_root, row[0], '*.jpg')
        for image_patch in glob.glob(field_path):
            # Normalise separators so the parse does not require backslashes.
            file_name = image_patch.replace('\\', '/').split('/')[-1]
            labels = int(file_name.split('.')[0])
            if -25000 <= labels <= 25000:
                image_path_list.append(image_patch)
                # This function always stores 0 here; no annotation is read.
                patch_effective_list.append(0)
                defocus_distance_list.append(labels / 1000)
        num_i = num_i + 1
    # NOTE(review): the extracted source lost its indentation; counting rows
    # and printing once after the loop is the assumed original intent.
    print('num_i: ',num_i)
    return (image_path_list, patch_effective_list,
            defocus_distance_list)
class MyDataset(Dataset):
    """Image dataset returning ``(image, label, image_path)`` triples.

    Args:
        data: Sequence of image file paths.
        label: Sequence of labels parallel to ``data``. May be empty, in which
            case every sample gets an empty tensor as label.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data[index]
        # Close the file handle deterministically and force three channels so
        # grayscale or RGBA files cannot change the tensor's channel count.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if len(self.label) > 0:
            label = torch.tensor(self.label[index])
        else:
            label = torch.tensor([])
        if self.transform is not None:
            image = self.transform(image)
        # The path is returned as well, which is handy for logging.
        return image, label, image_path
class MyDataset_test(Dataset):
    """Test-time dataset returning ``(image, label, image_path, effective)``.

    Args:
        data: Sequence of image file paths.
        label: Sequence of labels parallel to ``data``; returned unchanged
            (not converted to a tensor, unlike ``MyDataset``).
        effective: Sequence of per-image values parallel to ``data``; returned
            unchanged.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self,data,label,effective,transform=None):
        self.data = data
        self.label = label
        self.effective = effective
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_name = self.data[index]
        # Close the file handle deterministically and force three channels so
        # grayscale or RGBA files cannot change the tensor's channel count.
        with Image.open(image_name) as raw_image:
            image = raw_image.convert('RGB')
        label = self.label[index]
        if self.transform is not None:
            image = self.transform(image)
        effective = self.effective[index]
        return image,label,image_name,effective

268
old_test.py Normal file
View File

@ -0,0 +1,268 @@
import os
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from model.DPNet import register_model
from model.RINet import MobileNetV3_small as con_classification
from dataset import datasets as dataset_F
import util.utils as util
from tqdm import tqdm
import openpyxl
# REVIEW:
# This script is the test/inference entry point and the place where the two
# stages meet: RINet scores the patches of a whole image, DPNet regresses a
# defocus distance for every patch, and medians over different patch
# selections give the image-level predictions.
#
# REVIEW:
# This is the most useful file for understanding the method, because it shows
# the complete execution order used during evaluation.
# NOTE(review): the extracted source lost its indentation; the nesting below
# is a reconstruction and should be confirmed against the original file.
# Set the visible GPU devices for CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Image patch size
size = 224
# Paths to test data Excel files
root_path_list = [r'E:\suqiang\DenseSparse\testData_20241227.xlsx']
# REVIEW:
# The test-set path is hard-coded, so this is an experiment tool rather than
# a general command-line program.
# Define data sheet types and test categories (cell or tissue)
# Both lists are indexed with the position of the entry in root_path_list.
data_type_list = ['Sheet', 'Sheet1', 'Sheet1', 'Sheet1']
cell_or_tissue_list = ['tissue', 'tissue']
# Load the classification model and pre-trained weights
classification_model = con_classification()
state_load = torch.load('../weight/RINet_best_model.pt')
classification_model = nn.DataParallel(classification_model).to('cuda')
# The second positional argument is strict=False.
classification_model.load_state_dict(state_load, False)
classification_model.eval()  # Set the model to evaluation mode
# REVIEW:
# There is no version check between the RINet weights and the model
# structure; incompatibilities only surface at run time.
# Loop for different models (e.g., DPNet)
for c in range(1):
    if c == 0:
        state_load = torch.load('../weight/DP_best_model.pt')
        model = register_model()
        names = 'ours'
    model = nn.DataParallel(model).to('cuda')
    model.load_state_dict(state_load, False)
    model.eval()  # Set the model to evaluation mode
    # Iterate over test datasets
    for j in range(len(root_path_list)):
        root_path = root_path_list[j]
        data_type = data_type_list[j]
        cell_or_tissue = cell_or_tissue_list[j]
        # Load test data and labels from the specified Excel sheet
        (train_image_path_list, train_patch_effective_list,
         train_defocus_distance_list) = dataset_F.get_test_data_and_label(root_path=root_path, type=f'{data_type}')
        # Define image transformation pipeline
        train_transform = transforms.Compose([
            transforms.CenterCrop((2016, 2016)),  # Crop the center region of the image
            transforms.ToTensor(),  # Convert the image to a tensor
        ])
        # REVIEW:
        # The image is cropped to a fixed 2016x2016, which together with
        # size=224 gives a 9x9 patch grid, so the test flow makes a strong
        # assumption about the input size.
        # Create a custom dataset and DataLoader
        train_data = dataset_F.MyDataset_test(train_image_path_list, train_defocus_distance_list,
                                              train_patch_effective_list, train_transform)
        train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, pin_memory=True, num_workers=0)
        # Initialize an Excel workbook to save results
        workbook = openpyxl.Workbook()
        sheet = workbook.create_sheet(index=0)
        n = 0  # Row index for Excel sheet
        # Iterate through the test data
        for i, (img, label, image_name, effective) in tqdm(enumerate(train_loader)):
            image_path = image_name[0]
            image_label = int(image_path.split('\\')[-1].split('.')[0])  # Extract ground truth label
            # Add header row to the Excel sheet on the first iteration
            if i == 0:
                sheet.cell(1, 1).value = 'groundTruth'
                sheet.cell(1, 2).value = 'all_prediction'
                sheet.cell(1, 3).value = 'all_can_use_prediction'
                sheet.cell(1, 4).value = '61_prediction'
                sheet.cell(1, 5).value = '51_prediction'
                sheet.cell(1, 6).value = '41_prediction'
                sheet.cell(1, 7).value = '31_prediction'
                sheet.cell(1, 8).value = '25_prediction'
                sheet.cell(1, 9).value = '19_prediction'
                sheet.cell(1, 10).value = '15_prediction'
                sheet.cell(1, 11).value = '9_prediction'
                sheet.cell(1, 12).value = '5_prediction'
                sheet.cell(1, 13).value = '3_prediction'
                sheet.cell(1, 14).value = '1_prediction'
                sheet.cell(1, 15).value = 'path'
                sheet.cell(1, 16).value = 'max_can_use_patch_num'
                sheet.cell(1, 17).value = 'avg_prediction'
            # Preprocess the image for the classification model
            img2 = F.interpolate(img, size=(512, 512), mode='bilinear', align_corners=False)
            img = torch.cuda.FloatTensor(np.array(img))
            img2 = torch.cuda.FloatTensor(np.array(img2))
            # REVIEW:
            # This keeps the old Tensor -> numpy -> CUDA tensor conversion.
            if torch.cuda.is_available():
                img = img.cuda()
                img2 = img2.cuda()
            # Perform classification on the preprocessed image
            output_weight = classification_model(img2)
            output_weight_userful = output_weight[torch.abs(output_weight) > 0.8]  # Filter predictions above a threshold
            can_use_num = len(output_weight_userful)  # Count usable predictions
            # REVIEW:
            # Whether a patch counts as usable is decided by the 0.8
            # threshold, an important empirical hyper-parameter.
            # Initialize usable patch numbers for different thresholds
            can_use_num_61, can_use_num_51, can_use_num_41 = 61, 51, 41
            can_use_num_31, can_use_num_25, can_use_num_19 = 31, 25, 19
            can_use_num_15, can_use_num_9, can_use_num_5 = 15, 9, 5
            can_use_num_3, can_use_num_1 = 3, 1
            # Handle cases with no usable patches
            if can_use_num == 0:
                util.print_info(f'{image_name} don\'t have enough info')
                continue
            # REVIEW:
            # When no patch passes the threshold the sample is skipped, so the
            # result file contains fewer rows than there are images.
            # Adjust usable patch numbers based on conditions
            # An even count is reduced by one so every median is taken over an
            # odd number of values.
            if can_use_num % 2 == 0:
                can_use_num -= 1
            # Every top-k size larger than the usable count is capped to it.
            if 51 <= can_use_num < 61:
                can_use_num_61 = can_use_num
            elif 41 <= can_use_num < 51:
                can_use_num_61 = can_use_num_51 = can_use_num
            elif 31 <= can_use_num < 41:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num
            elif 25 <= can_use_num < 31:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num
            elif 19 <= can_use_num < 25:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num
            elif 15 <= can_use_num < 19:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num
            elif 9 <= can_use_num < 15:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num
            elif 5 <= can_use_num < 9:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num
            elif 3 <= can_use_num < 5:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num
            elif can_use_num < 3:
                can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num_3 = 1
            # Extract top predictions for each patch group
            # Only the index tensors are used below; the values are not.
            max_all_can_value, max_all_can_index = torch.topk(output_weight, can_use_num)
            max_61_value, max_61_index = torch.topk(output_weight, can_use_num_61)
            max_51_value, max_51_index = torch.topk(output_weight, can_use_num_51)
            max_41_value, max_41_index = torch.topk(output_weight, can_use_num_41)
            max_31_value, max_31_index = torch.topk(output_weight, can_use_num_31)
            max_25_value, max_25_index = torch.topk(output_weight, can_use_num_25)
            max_19_value, max_19_index = torch.topk(output_weight, can_use_num_19)
            max_15_value, max_15_index = torch.topk(output_weight, can_use_num_15)
            max_9_value, max_9_index = torch.topk(output_weight, can_use_num_9)
            max_5_value, max_5_index = torch.topk(output_weight, can_use_num_5)
            max_3_value, max_3_index = torch.topk(output_weight, can_use_num_3)
            max_1_value, max_1_index = torch.topk(output_weight, can_use_num_1)
            patch_image_list_all = []
            patch_image_list_all_can_use = []
            patch_image_list_61 = []
            patch_image_list_51 = []
            patch_image_list_41 = []
            patch_image_list_31 = []
            patch_image_list_25 = []
            patch_image_list_19 = []
            patch_image_list_15 = []
            patch_image_list_9 = []
            patch_image_list_5 = []
            patch_image_list_3 = []
            patch_image_list_1 = []
            label_list = []
            # REVIEW:
            # Several top-k aggregations are compared side by side, i.e. the
            # script also studies how many patches should be kept.
            # Cut the image into size x size patches in row-major order.
            # NOTE(review): the inner loop variable reuses the name j of the
            # dataset loop above.
            for k in range(0, 2016, size):
                a = int(k / size)
                for j in range(0, 2016, size):
                    b = int(j / size)
                    patch_image = img[:, :, k:k + size, j:j + size]
                    patch_image_list_all.append(patch_image)
            if len(patch_image_list_all) != 0:
                out_all = model(torch.cat(patch_image_list_all, dim=0))
                # REVIEW:
                # All patches go through DPNet in a single batch, which avoids
                # a separate forward pass per patch.
                # Sort each patch prediction into every selection whose top-k
                # indices contain the patch position z.
                for z in range(len(out_all)):
                    if torch.any(max_all_can_index.eq(z)):
                        patch_image_list_all_can_use.append(out_all[z].item())
                    if torch.any(max_61_index.eq(z)):
                        patch_image_list_61.append(out_all[z].item())
                    if torch.any(max_51_index.eq(z)):
                        patch_image_list_51.append(out_all[z].item())
                    if torch.any(max_41_index.eq(z)):
                        patch_image_list_41.append(out_all[z].item())
                    if torch.any(max_31_index.eq(z)):
                        patch_image_list_31.append(out_all[z].item())
                    if torch.any(max_25_index.eq(z)):
                        patch_image_list_25.append(out_all[z].item())
                    if torch.any(max_19_index.eq(z)):
                        patch_image_list_19.append(out_all[z].item())
                    if torch.any(max_15_index.eq(z)):
                        patch_image_list_15.append(out_all[z].item())
                    if torch.any(max_9_index.eq(z)):
                        patch_image_list_9.append(out_all[z].item())
                    if torch.any(max_5_index.eq(z)):
                        patch_image_list_5.append(out_all[z].item())
                    if torch.any(max_3_index.eq(z)):
                        patch_image_list_3.append(out_all[z].item())
                    if torch.any(max_1_index.eq(z)):
                        patch_image_list_1.append(out_all[z].item())
                nums_all = float(np.median(out_all.tolist()))
                avg_nums_all = float(np.mean(out_all.tolist()))
                # One result row, in the same order as the header columns.
                # Predictions are multiplied by 1000 before being written.
                data = []
                data.append(image_label)
                data.append(float(np.median(out_all.tolist())) * 1000)
                data.append(float(np.median(patch_image_list_all_can_use)) * 1000)
                data.append(float(np.median(patch_image_list_61)) * 1000)
                data.append(float(np.median(patch_image_list_51)) * 1000)
                data.append(float(np.median(patch_image_list_41)) * 1000)
                data.append(float(np.median(patch_image_list_31)) * 1000)
                data.append(float(np.median(patch_image_list_25)) * 1000)
                data.append(float(np.median(patch_image_list_19)) * 1000)
                data.append(float(np.median(patch_image_list_15)) * 1000)
                data.append(float(np.median(patch_image_list_9)) * 1000)
                data.append(float(np.median(patch_image_list_5)) * 1000)
                data.append(float(np.median(patch_image_list_3)) * 1000)
                data.append(float(np.median(patch_image_list_1)) * 1000)
                # Last three components of the backslash-separated path.
                data.append((image_path.split('\\')[-3] + '\\' +image_path.split('\\')[-2] + '\\' +image_path.split('\\')[-1]))
                data.append(len(output_weight_userful))
                data.append(float(np.mean(out_all.tolist())) * 1000)
                for (m, info) in enumerate(data):
                    sheet.cell(n + 2, m + 1).value = info
                n = n + 1
                util.save_val_image(1, img,
                                    score=out_all, image_name=image_name,
                                    type_str=cell_or_tissue, type='test', type_label=output_weight)
                # REVIEW:
                # Besides the numeric output a grid visualisation is saved,
                # which helps to explain model behaviour and inspect outliers.
        workbook.save(f'./{names}_to_{cell_or_tissue}_{data_type}_2.xlsx')
        # REVIEW:
        # Results are stored as Excel, which suits paper experiments; for
        # batch evaluation a structured text format would be easier to analyse.

676
playground.ipynb Normal file

File diff suppressed because one or more lines are too long

307
train_dpn.py Normal file
View File

@ -0,0 +1,307 @@
import argparse
import logging
import shutil
import time
from tqdm import tqdm
from datetime import datetime
from pathlib import Path
from pprint import pformat
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import tomllib
from models import DPNet
import old_datasets as dataset_F
import utils
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, writer, global_step):
    """Train for one epoch, stepping the scheduler after every batch.

    Each batch is an ``(image, label, image_name)`` triple; the name is not
    used. The learning rate is written to ``writer`` under the tag ``"lr"``
    after every optimizer step.

    Returns:
        ``(epoch_loss, global_step)``: the sample-weighted mean loss and the
        step counter advanced by the number of batches processed.
    """
    model.train()
    loss_sum, seen = 0.0, 0
    progress = tqdm(train_loader, desc="Train", bar_format="{l_bar}{bar:20}{r_bar}")
    for img, label, _ in progress:
        img = img.to(device, non_blocking=True)
        target = label.to(device, non_blocking=True).float().view(-1)

        optimizer.zero_grad(set_to_none=True)
        prediction = model(img).view(-1)
        loss = criterion(prediction, target)
        loss.backward()
        optimizer.step()
        scheduler.step()

        global_step += 1
        writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)

        # Weight by batch size so a smaller last batch does not skew the mean.
        n_samples = img.size(0)
        loss_sum += loss.item() * n_samples
        seen += n_samples
    return loss_sum / seen, global_step
@torch.no_grad()
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate the model on ``val_loader`` and return the mean loss.

    Runs in eval mode without gradient tracking. Each batch is an
    ``(image, label, image_name)`` triple; the loss is averaged over samples.
    """
    model.eval()
    loss_sum, seen = 0.0, 0
    progress = tqdm(val_loader, desc="Validate", bar_format="{l_bar}{bar:20}{r_bar}")
    for img, label, _ in progress:
        img = img.to(device, non_blocking=True)
        target = label.to(device, non_blocking=True).float().view(-1)
        loss = criterion(model(img).view(-1), target)

        n_samples = img.size(0)
        loss_sum += loss.item() * n_samples
        seen += n_samples
    return loss_sum / seen
def _build_logger(log_file):
    """Return the "dpnet_train" logger writing to ``log_file`` and the console."""
    logger = logging.getLogger("dpnet_train")
    logger.setLevel(logging.INFO)
    logger.propagate = False
    # Drop handlers left over from an earlier call in the same process.
    logger.handlers.clear()
    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handlers = (
        logging.FileHandler(log_file, encoding="utf-8"),
        logging.StreamHandler(),
    )
    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
    for handler in handlers:
        logger.addHandler(handler)
    return logger


def _build_loader(dataset, batch_size, num_workers, device, shuffle):
    """Create a DataLoader with the settings shared by training and validation."""
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"),
        persistent_workers=(num_workers > 0),
    )


def main():
    """Train DPNet according to the TOML file given with ``--config``.

    Creates a timestamped run directory in the current working directory and
    stores a copy of the config, ``train.log``, TensorBoard events and the
    ``last_dpn.pth`` / ``best_dpn.pth`` checkpoints there.
    """
    # ---- 1. Read the configuration ----
    parser = argparse.ArgumentParser(description="Train DPNet")
    parser.add_argument("--config", type=str, required=True, help="Path to TOML config file")
    config_path = Path(parser.parse_args().config)
    with config_path.open("rb") as config_file:
        config = tomllib.load(config_file)

    xlsx_files = config["xlsx_files"]
    batch_size = config["batch_size"]
    learning_rate = config["learning_rate"]
    epochs = config["epochs"]
    num_workers = config["num_workers"]
    seed = config["seed"]
    init_weight = config["init_weight"]

    # ---- 2. Create the output directory ----
    run_dir = Path.cwd() / datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn")
    run_dir.mkdir(parents=True, exist_ok=False)
    shutil.copy2(config_path, run_dir / config_path.name)

    # ---- 3. Logging and TensorBoard ----
    logger = _build_logger(run_dir / "train.log")
    writer = SummaryWriter(log_dir=str(run_dir))

    logger.info(f"Run directory: {run_dir}")
    logger.info(f"Config path: {config_path}")
    logger.info(f"Loaded config:\n{pformat(config, sort_dicts=False)}")

    # ---- 4. Random seed and device ----
    utils.set_seeds(seed=seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # ---- 5. Datasets and loaders ----
    (
        train_image_path_list,
        train_defocus_distance_list,
        val_image_path_list,
        val_defocus_distance_list,
    ) = dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files)

    # Only the training images get colour augmentation.
    jitter = transforms.ColorJitter(
        brightness=(0.9, 1.4),
        contrast=(0.8, 1.5),
        saturation=(0.8, 1.5),
    )
    train_transform = transforms.Compose([jitter, transforms.ToTensor()])
    val_transform = transforms.Compose([transforms.ToTensor()])

    train_dataset = dataset_F.MyDataset(
        train_image_path_list, train_defocus_distance_list, train_transform
    )
    val_dataset = dataset_F.MyDataset(
        val_image_path_list, val_defocus_distance_list, val_transform
    )

    train_loader = _build_loader(train_dataset, batch_size, num_workers, device, shuffle=True)
    val_loader = _build_loader(val_dataset, batch_size, num_workers, device, shuffle=False)

    logger.info(f"Train dataset size: {len(train_dataset)}")
    logger.info(f"Val dataset size: {len(val_dataset)}")
    logger.info(f"Train steps per epoch: {len(train_loader)}")

    # ---- 6. Model ----
    model = DPNet().to(device)
    if not init_weight:
        logger.info("Training from scratch")
    else:
        # Only model parameters are restored; the optimizer starts fresh.
        state_dict = torch.load(init_weight, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        logger.info(f"Loaded init weight from: {init_weight}")

    # ---- 7. Loss, optimizer and scheduler ----
    criterion = nn.MSELoss(reduction="mean")
    optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

    # The scheduler is stepped once per batch, so everything counts steps.
    total_steps = epochs * len(train_loader)
    if total_steps > 1:
        warmup_steps = min(2000, total_steps - 1)
    else:
        warmup_steps = 0

    warmup = LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_steps,
    )
    cosine = CosineAnnealingLR(
        optimizer,
        T_max=max(1, total_steps - warmup_steps),
        eta_min=0.0,
    )
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup, cosine],
        milestones=[warmup_steps],
    )

    logger.info("Loss: MSELoss(reduction='mean')")
    logger.info(f"Optimizer: Adam(lr={learning_rate}, betas=(0.9, 0.999))")
    logger.info("Scheduler: step-based warmup + cosine annealing")
    logger.info(f"Total training steps: {total_steps}")
    logger.info(f"Warmup steps: {warmup_steps}")

    # ---- 8. Training loop ----
    best_val_loss = float("inf")
    global_step = 0
    start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        train_loss, global_step = train_epoch(
            model=model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            writer=writer,
            global_step=global_step,
        )
        val_loss = validate_epoch(
            model=model,
            val_loader=val_loader,
            criterion=criterion,
            device=device,
        )

        epoch_total_time = time.time() - epoch_start_time

        for tag, value in (
            ("train_loss", train_loss),
            ("val_loss", val_loss),
            ("epoch_total_time", epoch_total_time),
        ):
            writer.add_scalar(tag, value, epoch)

        logger.info(
            f"Epoch [{epoch}/{epochs}] | "
            f"train_loss={train_loss:.8f} | "
            f"val_loss={val_loss:.8f} | "
            f"epoch_total_time={epoch_total_time:.2f}s | "
            f"global_step={global_step}"
        )

        # Always keep the latest weights; keep the best ones separately.
        torch.save(model.state_dict(), run_dir / "last_dpn.pth")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), run_dir / "best_dpn.pth")
            logger.info(f"Best model updated, best_val_loss={best_val_loss:.8f}")

    # ---- 9. Wrap up ----
    total_time = time.time() - start_time
    logger.info("Training finished")
    logger.info(f"Best validation loss: {best_val_loss:.8f}")
    logger.info(f"Total time: {total_time:.2f} seconds")
    writer.close()
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()

217
train_rin.py Normal file
View File

@ -0,0 +1,217 @@
import shutil
import time
from tqdm import tqdm
from datetime import datetime
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models import RINet
import old_datasets as dataset_F
from datasets import RIN_Dataset, RINPairTransform
import utils
# Train for one epoch
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: network to train; it is switched to training mode first.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable applied to the flattened outputs and labels.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.

    Returns:
        The per-sample average loss: each batch loss is weighted by the
        number of images in the batch and the sum is divided by the total
        number of images seen.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    model.train()

    loss_sum = 0.0
    seen = 0
    progress = tqdm(
        loader,
        desc="Train:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
    )
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device, non_blocking=True)
        # Labels and predictions are both flattened to 1-D before the loss.
        targets = batch_labels.to(device, non_blocking=True).view(-1)

        optimizer.zero_grad(set_to_none=True)
        predictions = model(batch_images).view(-1)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        batch_count = batch_images.size(0)
        loss_sum += batch_loss.item() * batch_count
        seen += batch_count

    return loss_sum / seen
# Validate for one epoch
@torch.no_grad()
def valid_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean loss.

    Args:
        model: network to evaluate; it is switched to eval mode first.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable applied to the flattened outputs and labels.
        device: device the batch tensors are moved to.

    Returns:
        The per-sample average loss (batch losses weighted by batch size,
        divided by the total number of images seen).

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    model.eval()

    loss_sum = 0.0
    seen = 0
    progress = tqdm(
        loader,
        desc="Valid:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
    )
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device, non_blocking=True)
        # Labels and predictions are both flattened to 1-D before the loss.
        targets = batch_labels.to(device, non_blocking=True).view(-1)

        predictions = model(batch_images).view(-1)
        batch_loss = criterion(predictions, targets)

        # Weight by batch size so a smaller final batch does not skew the mean.
        batch_count = batch_images.size(0)
        loss_sum += batch_loss.item() * batch_count
        seen += batch_count

    return loss_sum / seen
# Main training routine
def main():
    """Train RINet from a TOML config given on the command line.

    The run proceeds in six stages: read the config, create a timestamped
    run directory (the config file is copied into it), set up logging,
    TensorBoard, seeds and the device, build the datasets and loaders, build
    the model / loss / optimizer / scheduler, and finally run the epoch loop.

    The best weights (lowest validation loss) are written to
    ``best_model.pt`` and the final weights to ``last_model.pt`` inside the
    run directory; the latter is saved even when training is interrupted or
    fails, because it is written from a ``finally`` block.
    """
    # ========== 1 Config file and hyperparameters ==========
    config, config_path = utils.get_hyperparams()
    XLSX_FILES = config["xlsx_files"]
    BATCH_SIZE = config["batch_size"]
    NUM_WORKERS = config["num_workers"]
    LEARNING_RATE = config["learning_rate"]
    NUM_EPOCHS = config["epochs"]
    SEED = config["seed"]
    INIT_WEIGHT_PATH = config["init_weight"]

    # ========== 2 Create the output directory ==========
    run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_RIN")
    run_dir = Path.cwd() / run_name
    # exist_ok=False: refuse to reuse (and overwrite) an existing run directory.
    run_dir.mkdir(parents=True, exist_ok=False)
    shutil.copy2(config_path, run_dir / config_path.name)

    # ========== 3 Logging, TensorBoard, random seed and device ==========
    logger = utils.get_logger(__name__, run_dir / "train.log")
    writer = SummaryWriter(str(run_dir / "run"))
    utils.set_seeds(SEED)
    # Prefer the first GPU, but fall back to CPU so the script still runs on
    # hosts without CUDA instead of failing when the model is moved.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    logger.info(f"Config path: {config_path}")
    logger.info(f"Loaded config: {str(config)}")
    logger.info(f"Run directory: {run_dir}")
    logger.info(f"Using device: {device}")
    if not use_cuda:
        logger.warning("CUDA is not available, falling back to CPU")

    # ========== 4 Data and loaders ==========
    # NOTE(review): only the first entry of xlsx_files is read; any further
    # entries are ignored — confirm this is intended.
    train_image_path_list, train_patch_effective_list = (
        dataset_F.get_RINet_data(XLSX_FILES[0], "train")
    )
    valid_image_path_list, valid_patch_effective_list = (
        dataset_F.get_RINet_data(XLSX_FILES[0], "val")
    )
    train_transform = RINPairTransform(train=True, image_size=512)
    valid_transform = RINPairTransform(train=False, image_size=512)
    train_set = RIN_Dataset(
        train_image_path_list,
        train_patch_effective_list,
        train_transform,
    )
    valid_set = RIN_Dataset(
        valid_image_path_list,
        valid_patch_effective_list,
        valid_transform,
    )
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        # Pinned memory only helps host-to-GPU copies.
        pin_memory=use_cuda,
        persistent_workers=(NUM_WORKERS > 0),
    )
    valid_loader = DataLoader(
        dataset=valid_set,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=use_cuda,
        persistent_workers=(NUM_WORKERS > 0),
    )
    logger.info(f"Train dataset size: {len(train_set)}")
    logger.info(f"Val dataset size: {len(valid_set)}")
    logger.info(f"Train steps per epoch: {len(train_loader)}")

    # ========== 5 Model, loss, optimizer, scheduler ==========
    model = RINet().to(device)
    if INIT_WEIGHT_PATH:
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a saved state_dict needs.
        state_dict = torch.load(
            INIT_WEIGHT_PATH,
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(state_dict, strict=True)
        logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
    else:
        logger.info("Training from scratch")
    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)
    logger.info("Loss: BCELoss()")
    logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
    logger.info("Scheduler: epoch-based warmup + cosine annealing")

    # ========== 6 Training loop ==========
    logger.info("START TRAINING")
    best_valid_loss = float("inf")
    try:
        for epoch in range(1, NUM_EPOCHS + 1):
            epoch_start_time = time.time()
            train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
            valid_loss = valid_epoch(model, valid_loader, criterion, device)
            # Read the LR before stepping the scheduler so the logged value
            # is the one actually used during this epoch.
            epoch_lr = optimizer.param_groups[0]["lr"]
            scheduler.step()
            epoch_time_cost = time.time() - epoch_start_time

            # Save whenever the validation loss improves
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), run_dir / "best_model.pt")
                logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")

            # Log file and TensorBoard
            logger.info(
                f"Epoch [{epoch}/{NUM_EPOCHS}] "
                f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
                f"Best Valid Loss: {best_valid_loss:.4f} | "
                f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
                f"Epoch Learning Rate: {epoch_lr:.6e}"
            )
            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/valid", valid_loss, epoch)
            writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
            writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
            writer.add_scalar("Time/learning_rate", epoch_lr, epoch)
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    finally:
        # Always keep the latest weights and flush TensorBoard, even when the
        # loop stops early.
        torch.save(model.state_dict(), run_dir / "last_model.pt")
        logger.info("Last model saved")
        writer.close()
        logger.info("TensorBoard writer closed")
    logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()

110
utils.py Normal file
View File

@ -0,0 +1,110 @@
import math
import torch
import random
import logging
import tomllib
import argparse
from pathlib import Path
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
import torch.distributed as dist
# Fix every random seed
def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed: value passed unchanged to each seeding function.

    After the call ``torch.backends.cudnn.deterministic`` is ``True`` and
    ``torch.backends.cudnn.benchmark`` is ``False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Build the logger
def get_logger(name, log_file):
    """Return a logger that writes INFO-level records to the console and a file.

    Calling this again with the same ``name`` reconfigures the same logger:
    handlers left over from the previous call are removed and closed first,
    so their log files are not kept open.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_file: path of the log file; opened in append mode as UTF-8.

    Returns:
        The configured ``logging.Logger`` with exactly two handlers
        (a stream handler and a file handler) and propagation disabled.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Do not forward records to ancestor loggers, which would print them twice.
    logger.propagate = False

    # Close stale handlers instead of merely dropping them: clearing the list
    # alone would leave any previous FileHandler's file descriptor open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    file_handler = logging.FileHandler(
        log_file,
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    return logger
# Initialise DDP
def setup_distributed():
    """Initialise the NCCL process group and bind this process to its GPU.

    Returns:
        A tuple ``(local_rank, rank, world_size, device, is_main_process)``
        where ``device`` is the CUDA device indexed by ``local_rank`` and
        ``is_main_process`` is ``True`` only when the global rank is 0.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("DDP training requires CUDA.")

    dist.init_process_group(backend="nccl")
    local_rank = dist.get_node_local_rank()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()

    # Make the node-local GPU the current CUDA device for this process.
    torch.cuda.set_device(local_rank)
    return (
        local_rank,
        global_rank,
        num_processes,
        torch.device("cuda", local_rank),
        global_rank == 0,
    )
# Tear down DDP
def cleanup_distributed():
    """Destroy the default process group if one is currently initialised.

    The call is a no-op when distributed support is unavailable or no
    process group was initialised, so it is safe to call from a ``finally``
    block even when setup failed part-way through.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
# 命令行参数解析配置文件
def get_hyperparams():
parser = argparse.ArgumentParser()
parser.add_argument("config", help="Path to TOML config file")
args = parser.parse_args()
config_path = Path(args.config)
with config_path.open("rb") as f:
return tomllib.load(f), config_path
# Linear warmup followed by cosine annealing
def get_warmup_cosine_scheduler(optimizer, epochs):
    """Build an epoch-based ``LambdaLR`` with linear warmup and cosine decay.

    The returned scheduler multiplies each param group's initial learning
    rate by a factor that depends on the number of ``step()`` calls so far:

    * warmup: the factor rises linearly from 0.1 towards 1.0 over
      ``min(epochs // 10, 10)`` epochs;
    * cosine: the factor then follows half a cosine from 1.0 down to 0.0
      over the remaining epochs;
    * past the end: the factor stays at the final value instead of following
      the cosine back up.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        epochs: planned number of epochs. Values below 1 are tolerated and
            yield a schedule with no warmup and a one-epoch cosine span.

    Returns:
        The configured ``LambdaLR`` instance.
    """
    max_warmup_epochs, start_factor, eta_min_factor = 10, 0.1, 0.0
    warmup_epochs = max(0, min(epochs // 10, max_warmup_epochs))
    # Keep the cosine span at least 1 so the division below cannot hit zero
    # (LambdaLR evaluates the lambda once while it is being constructed).
    cosine_epochs = max(epochs - warmup_epochs, 1)

    def lr_lambda(current_epoch):
        # Linear warmup phase
        if warmup_epochs > 0 and current_epoch < warmup_epochs:
            return start_factor + (1.0 - start_factor) * (
                current_epoch / warmup_epochs
            )
        # Cosine annealing phase; progress is clamped to 1.0 so extra steps
        # beyond the planned epochs hold the minimum instead of rising again.
        progress = min((current_epoch - warmup_epochs) / cosine_epochs, 1.0)
        return eta_min_factor + (1.0 - eta_min_factor) * 0.5 * (
            1.0 + math.cos(math.pi * progress)
        )

    return LambdaLR(optimizer, lr_lambda=lr_lambda)