first commit
This commit is contained in:
commit
1b6fcf93d0
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
.vscode/
|
||||
|
||||
20*_RIN*/
|
||||
20*_RIN_OLD*/
|
||||
13
config_dpn.toml
Normal file
13
config_dpn.toml
Normal file
@ -0,0 +1,13 @@
|
||||
xlsx_files = [
|
||||
"E:/Datasets/SparseFocusDataset/complete_image/complete_image/excel/02/224_dataset_little.xlsx"
|
||||
]
|
||||
|
||||
batch_size = 64
|
||||
learning_rate = 1e-4
|
||||
epochs = 600
|
||||
num_workers = 4
|
||||
seed = 3407
|
||||
|
||||
# 为空表示从头训练
|
||||
# 非空表示仅加载模型参数初始化
|
||||
init_weight = ""
|
||||
14
config_rin.toml
Normal file
14
config_rin.toml
Normal file
@ -0,0 +1,14 @@
|
||||
xlsx_files = [
|
||||
"E:/Datasets/SparseFocusDataset/complete_image/complete_image/excel/02/new_classification_dataset.xlsx"
|
||||
]
|
||||
|
||||
|
||||
batch_size = 64
|
||||
learning_rate = 1e-5
|
||||
epochs = 600
|
||||
num_workers = 12
|
||||
seed = 3407
|
||||
|
||||
# 为空表示从头训练
|
||||
# 非空表示仅加载模型参数初始化
|
||||
init_weight = ""
|
||||
75
datasets.py
Normal file
75
datasets.py
Normal file
@ -0,0 +1,75 @@
|
||||
import random
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional as F
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# 同步增强
|
||||
class RINPairTransform:
|
||||
def __init__(self, train=True, image_size=512):
|
||||
self.train = train
|
||||
self.image_size = image_size
|
||||
self.color_jitter = transforms.ColorJitter(
|
||||
brightness=(0.9, 1.4),
|
||||
contrast=(0.8, 1.5),
|
||||
saturation=(0.8, 1.5),
|
||||
)
|
||||
|
||||
def __call__(self, image, label):
|
||||
image = F.resize(image, size=(self.image_size, self.image_size))
|
||||
label = torch.as_tensor(label, dtype=torch.float32).view(9, 9)
|
||||
|
||||
if self.train:
|
||||
# D4 数据增强
|
||||
# 随机 90 度旋转
|
||||
k = random.randint(0, 3)
|
||||
|
||||
if k == 1:
|
||||
image = image.transpose(Image.Transpose.ROTATE_90)
|
||||
label = torch.rot90(label, k=1, dims=(0, 1))
|
||||
|
||||
elif k == 2:
|
||||
image = image.transpose(Image.Transpose.ROTATE_180)
|
||||
label = torch.rot90(label, k=2, dims=(0, 1))
|
||||
|
||||
elif k == 3:
|
||||
image = image.transpose(Image.Transpose.ROTATE_270)
|
||||
label = torch.rot90(label, k=3, dims=(0, 1))
|
||||
|
||||
# 随机翻转
|
||||
if random.random() < 0.5:
|
||||
image = F.hflip(image)
|
||||
label = torch.flip(label, dims=(1,))
|
||||
|
||||
# 颜色增强
|
||||
image = self.color_jitter(image)
|
||||
|
||||
image = F.to_tensor(image)
|
||||
return image, label
|
||||
|
||||
|
||||
class RIN_Dataset(Dataset):
|
||||
def __init__(self, image_path_list, label_list, transform=None):
|
||||
self.image_path_list = image_path_list
|
||||
self.label_list = label_list
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_path_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_path = self.image_path_list[index]
|
||||
label = self.label_list[index]
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
if self.transform is not None:
|
||||
image, label = self.transform(image, label)
|
||||
else:
|
||||
image = F.resize(image, size=(512, 512))
|
||||
image = F.to_tensor(image)
|
||||
label = torch.as_tensor(label, dtype=torch.float32).view(9, 9)
|
||||
|
||||
return image, label
|
||||
137
models.py
Normal file
137
models.py
Normal file
@ -0,0 +1,137 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
|
||||
from torchvision.models.convnext import LayerNorm2d
|
||||
from torchvision.ops import SqueezeExcitation, Permute, StochasticDepth
|
||||
|
||||
|
||||
# RINet
|
||||
class ImportanceClassifier(nn.Module):
|
||||
def __init__(self, in_channels=576, hidden_channels=96):
|
||||
super().__init__()
|
||||
|
||||
self.se = SqueezeExcitation(input_channels=576, squeeze_channels=144)
|
||||
self.pool = nn.AvgPool2d(kernel_size=7, stride=1)
|
||||
self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=2, stride=1, bias=True)
|
||||
self.act1 = nn.Hardswish(inplace=True)
|
||||
self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, bias=True)
|
||||
self.act2 = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.se(x)
|
||||
x = self.pool(x)
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.act2(x)
|
||||
return x.squeeze()
|
||||
|
||||
|
||||
class RINet(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
backbone_model = mobilenet_v3_small()
|
||||
|
||||
self.backbone = backbone_model.features
|
||||
self.classifier = ImportanceClassifier(in_channels=576, hidden_channels=96)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
# DPNet
|
||||
class DFEBlock(nn.Module):
|
||||
def __init__(self, dim, layer_scale=1e-6, stochastic_depth_prob=0.0):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
|
||||
nn.BatchNorm2d(dim, eps=1e-4),
|
||||
Permute([0, 2, 3, 1]),
|
||||
nn.Linear(dim, 4 * dim, bias=True),
|
||||
nn.ReLU6(inplace=True),
|
||||
nn.Linear(4 * dim, dim, bias=True),
|
||||
Permute([0, 3, 1, 2]),
|
||||
)
|
||||
|
||||
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
|
||||
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
result = self.layer_scale * self.block(x)
|
||||
result = self.stochastic_depth(result)
|
||||
result += x
|
||||
return result
|
||||
|
||||
|
||||
class DFEBlockConfig:
|
||||
def __init__(self, input_channels, out_channels, num_layers):
|
||||
self.input_channels = input_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_layers = num_layers
|
||||
|
||||
|
||||
class DPNet(nn.Module):
|
||||
def __init__(self, stochastic_depth_prob=0.0, layer_scale=1e-6):
|
||||
super().__init__()
|
||||
|
||||
block_setting = [
|
||||
DFEBlockConfig(128, 256, 3),
|
||||
DFEBlockConfig(256, 512, 3),
|
||||
DFEBlockConfig(512, 1024, 9),
|
||||
DFEBlockConfig(1024, None, 3),
|
||||
]
|
||||
|
||||
layers = []
|
||||
|
||||
firstconv_output_channels = block_setting[0].input_channels
|
||||
layers.append(
|
||||
nn.Conv2d(3, firstconv_output_channels, kernel_size=4, stride=4, padding=0, bias=True)
|
||||
)
|
||||
|
||||
total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
|
||||
stage_block_id = 0
|
||||
|
||||
for cnf in block_setting:
|
||||
stage = []
|
||||
for _ in range(cnf.num_layers):
|
||||
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
|
||||
stage.append(
|
||||
DFEBlock(dim=cnf.input_channels, layer_scale=layer_scale, stochastic_depth_prob=sd_prob)
|
||||
)
|
||||
stage_block_id += 1
|
||||
layers.append(nn.Sequential(*stage))
|
||||
if cnf.out_channels is not None:
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
LayerNorm2d(cnf.input_channels, eps=1e-4),
|
||||
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2, bias=True)
|
||||
)
|
||||
)
|
||||
|
||||
self.features = nn.Sequential(*layers)
|
||||
self.pool = nn.MaxPool2d(kernel_size=7, stride=1)
|
||||
|
||||
lastblock = block_setting[-1]
|
||||
lastconv_output_channels = (
|
||||
lastblock.out_channels
|
||||
if lastblock.out_channels is not None
|
||||
else lastblock.input_channels
|
||||
)
|
||||
|
||||
self.regressor = nn.Sequential(
|
||||
nn.Conv2d(lastconv_output_channels, 1280, kernel_size=1, stride=1, padding=0, bias=True),
|
||||
nn.Conv2d(1280, 100, kernel_size=1, stride=1, padding=0, bias=True),
|
||||
nn.Flatten(1),
|
||||
nn.Linear(100, 1, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.pool(x)
|
||||
x = self.regressor(x)
|
||||
return x
|
||||
217
old_datasets.py
Normal file
217
old_datasets.py
Normal file
@ -0,0 +1,217 @@
|
||||
import os
|
||||
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import random
|
||||
from torchvision.transforms import ToPILImage
|
||||
import torch
|
||||
import glob
|
||||
import torchvision.transforms as transforms
|
||||
from openpyxl import load_workbook
|
||||
import json
|
||||
import re
|
||||
|
||||
# REVIEW:
|
||||
# 这个文件是项目的数据入口层,负责把 Excel、JSON 和图像文件组织成训练/测试阶段可消费的数据。
|
||||
# 整体上它解决了“如何把作者本地数据组织接到模型上”的问题,但工程抽象较弱,路径与数据格式都强依赖作者环境。
|
||||
#
|
||||
# REVIEW:
|
||||
# 当前文件同时服务于两个任务:
|
||||
# 1. DPNet 的 defocus distance 回归;
|
||||
# 2. RINet 的 patch 有效性二分类。
|
||||
# 这种复用在小型研究项目里很常见,但随着任务差异增大,后续维护成本会逐渐上升。
|
||||
|
||||
|
||||
# DPNet load datasets
|
||||
#"""
|
||||
# param root_path_list:Loading path of xlsx.
|
||||
# return:The training and validation data both contain the path of the image and the defocus - distance label of the image.
|
||||
#"""
|
||||
def get_DPNet_train_data_and_label(root_path_list:list):
|
||||
# REVIEW:
|
||||
# 该函数读取 DPNet 的训练/验证数据,返回的是 4 个平行列表:
|
||||
# 训练图像路径、训练标签、验证图像路径、验证标签。
|
||||
# 好处是简单直接,缺点是多个列表之间的“同位关系”完全靠调用方自己保证。
|
||||
|
||||
print("begin load data")
|
||||
train_image_path_list = []
|
||||
train_defocus_distance_list = []
|
||||
val_image_path_list = []
|
||||
val_defocus_distance_list = []
|
||||
start_time = time.time()
|
||||
for i, root_path in enumerate(root_path_list):
|
||||
xlsx_list = glob.glob(root_path)
|
||||
for xlsx in xlsx_list:
|
||||
wb = load_workbook(xlsx)
|
||||
# REVIEW:
|
||||
# 这里使用 openpyxl 的旧接口 get_sheet_by_name,说明代码更偏实验性质。
|
||||
# 运行上通常没问题,但长期兼容性和可维护性一般。
|
||||
train_sheet = wb.get_sheet_by_name('train')
|
||||
val_sheet = wb.get_sheet_by_name('val')
|
||||
# REVIEW:
|
||||
# train_random_num / val_random_num 被用来做一种“每隔若干样本取 1 个”的下采样。
|
||||
# 这种写法能快速控量,但采样策略不够直观,也不便于复现实验数据划分。
|
||||
train_random_num = random.randint(0, 9)
|
||||
val_random_num = random.randint(0, 9)
|
||||
for row in train_sheet.iter_rows(values_only=True):
|
||||
# REVIEW:
|
||||
# 这里直接把作者本地 Linux 绝对路径写进了数据查找逻辑,是本项目可移植性较差的核心原因之一。
|
||||
image_path_list = glob.glob(os.path.join('E:/Datasets/SparseFocusDataset/224_image', row[0], '*/*.jpg'))
|
||||
for (i, image_path) in enumerate(image_path_list):
|
||||
labels = int(image_path.replace('\\', '/').split('/')[-1].split('.')[0])
|
||||
# REVIEW:
|
||||
# 标签直接从文件名解析,说明数据组织约定非常强。如果文件命名规则变化,这里会立即失效。
|
||||
if labels >= -25000 and labels < 25000:
|
||||
if train_random_num % 5 == 0:
|
||||
train_image_path_list.append(image_path)
|
||||
train_defocus_distance_list.append(labels / 1000)
|
||||
train_random_num += 1
|
||||
for row in val_sheet.iter_rows(values_only=True):
|
||||
image_path_list = glob.glob(
|
||||
os.path.join('E:/Datasets/SparseFocusDataset/224_image', row[0],
|
||||
'*/*.jpg'))
|
||||
for (i, image_path) in enumerate(image_path_list):
|
||||
labels = int(image_path.replace('\\', '/').split('/')[-1].split('.')[0])
|
||||
if labels >= -25000 and labels < 25000:
|
||||
if val_random_num % 5 == 0:
|
||||
val_image_path_list.append(image_path)
|
||||
val_defocus_distance_list.append(labels / 1000)
|
||||
val_random_num += 1
|
||||
print('need time:', time.time() - start_time)
|
||||
print('train data nums:', len(train_image_path_list))
|
||||
print('val data nums:', len(val_image_path_list))
|
||||
return train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list
|
||||
|
||||
# Load RINet need datasets
|
||||
# """
|
||||
# param root_path: Loading path of xlsx.
|
||||
# param type: Control the type of data: rain or val.
|
||||
# return: The path of the image and the labels indicating whether different areas on the image contain content.
|
||||
# """
|
||||
def get_RINet_data(root_path: str, type: str):
|
||||
# REVIEW:
|
||||
# 该函数负责准备 RINet 所需的 patch 分类数据。与 DPNet 不同,这里的标签来自 JSON,
|
||||
# 表示 patch 是否“有效/包含足够信息”。
|
||||
|
||||
wb = load_workbook(root_path)
|
||||
train_image_path_list = []
|
||||
train_patch_effective_list = []
|
||||
test_sheet = wb.get_sheet_by_name(type)
|
||||
for (i, row) in enumerate(test_sheet.iter_rows(values_only=True)):
|
||||
# REVIEW:
|
||||
# 此处改为依赖作者本地 Windows 绝对路径,说明仓库内容很可能来自多环境拼接。
|
||||
field_path = os.path.join('E:/Datasets/SparseFocusDataset/complete_image/complete_image/90_patch_dataset_lap_hsv',row[0])
|
||||
json_data_path = os.path.join(field_path, '224_patch_effective.json')
|
||||
with open(json_data_path, 'r') as fs:
|
||||
json_data = json.load(fs)
|
||||
# REVIEW:
|
||||
# json.load 后又 json.loads,意味着文件中存放的不是标准 JSON 对象,而是“JSON 字符串”。
|
||||
# 这反映出上游数据预处理存在历史包袱,但这里做了兼容。
|
||||
json_data = json.loads(json_data)
|
||||
json_data = json_data['image_info_list']
|
||||
for (i, image_data) in enumerate(json_data):
|
||||
train_image_path_list.append(json_data[i]['image_path'].replace(r"E:\suqiang\new_microscope_data\90_patch_dataset_hsv", "E:/Datasets/SparseFocusDataset/complete_image/complete_image/90_patch_dataset_lap_hsv"))
|
||||
|
||||
train_patch_effective_list.append(json_data[i]['patch_effective'])
|
||||
return train_image_path_list, train_patch_effective_list
|
||||
# Load test need datasets
|
||||
# """
|
||||
# param root_path: Loading path of xlsx.
|
||||
# param type: Control the type of data: rain or val.
|
||||
# return: The path of the image and the labels indicating whether different areas on the image contain content.
|
||||
# """
|
||||
def get_test_data_and_label(root_path: str, type: str):
|
||||
# REVIEW:
|
||||
# 该函数为测试脚本提供整图 patch 级样本列表及其真实标签。
|
||||
# 它不仅做了 Excel 读取,还做了 patch 文件展开,职责比函数名看起来更重一些。
|
||||
|
||||
wb = load_workbook(root_path)
|
||||
train_image_path_list = []
|
||||
train_patch_effective_list = []
|
||||
train_defocus_distance_list = []
|
||||
test_sheet = wb.get_sheet_by_name(type)
|
||||
num_i = 0
|
||||
for (i, row) in enumerate(test_sheet.iter_rows(values_only=True)):
|
||||
# REVIEW:
|
||||
# 这里继续依赖固定数据目录结构,因此 test.py 的可运行性高度依赖作者原始数据目录。
|
||||
field_path = os.path.join(
|
||||
r'E:\suqiang\DenseSparse\cropped',
|
||||
row[0], '*.jpg')
|
||||
image_path_list = glob.glob(field_path)
|
||||
for image_patch in image_path_list:
|
||||
labels = int(image_patch.split('\\')[-1].split('.')[0])
|
||||
if labels >= -25000 and labels <= 25000:
|
||||
train_image_path_list.append(image_patch)
|
||||
# REVIEW:
|
||||
# effective 这里统一填 0,说明测试阶段 patch 是否可用不来自数据标注,而是后续 RINet 推理结果。
|
||||
train_patch_effective_list.append(0)
|
||||
train_defocus_distance_list.append(labels / 1000)
|
||||
num_i = num_i + 1
|
||||
print('num_i: ',num_i)
|
||||
return (train_image_path_list, train_patch_effective_list,
|
||||
train_defocus_distance_list)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
# REVIEW:
|
||||
# 这是训练阶段的通用 Dataset,DPNet 和 RINet 都复用它。
|
||||
# 通过允许 label 为任意列表实现任务复用,优点是轻量,缺点是接口语义不够强约束。
|
||||
def __init__(self, data, label, transform=None):
|
||||
self.data = data
|
||||
self.label = label
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
# return self.data.size(0)
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
imag = self.data[index]
|
||||
# REVIEW:
|
||||
# 这里未显式 convert('RGB'),如果数据源中出现灰度图或 RGBA 图,通道数可能和模型预期不一致。
|
||||
image = Image.open(imag)
|
||||
if len(self.label) > 0:
|
||||
label = torch.tensor(self.label[index])
|
||||
else:
|
||||
label = torch.tensor([])
|
||||
if self.transform != None:
|
||||
image = self.transform(image)
|
||||
# REVIEW:
|
||||
# 额外返回 imag 路径有助于日志和调试,但也让 batch 中混入了非张量字段。
|
||||
return image, label, imag
|
||||
|
||||
|
||||
|
||||
class MyDataset_test(Dataset):
|
||||
# REVIEW:
|
||||
# 测试专用 Dataset 在返回值中加入 effective 字段,以兼容 test.py 的历史接口。
|
||||
# 当前 effective 在数据层面基本只是占位。
|
||||
def __init__(self,data,label,effective,transform=None):
|
||||
self.data = data
|
||||
self.label = label
|
||||
self.effective = effective
|
||||
self.transform = transform
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_name = self.data[index]
|
||||
image = Image.open(image_name)
|
||||
# image = image.resize((2016, 2016))
|
||||
# image = np.array(image) # H W C
|
||||
# image = image[:,16:2032, 216:2232]
|
||||
label = self.label[index]
|
||||
if self.transform != None:
|
||||
image = self.transform(image)
|
||||
effective = self.effective[index]
|
||||
# REVIEW:
|
||||
# 与 MyDataset 不同,这里 label 保持原始 Python 数值而非 tensor,两个 Dataset 的接口风格并不完全一致。
|
||||
return image,label,image_name,effective
|
||||
268
old_test.py
Normal file
268
old_test.py
Normal file
@ -0,0 +1,268 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
from model.DPNet import register_model
|
||||
from model.RINet import MobileNetV3_small as con_classification
|
||||
|
||||
from dataset import datasets as dataset_F
|
||||
import util.utils as util
|
||||
from tqdm import tqdm
|
||||
import openpyxl
|
||||
|
||||
# REVIEW:
|
||||
# 该脚本是项目的测试/推理主入口,也是两阶段方法真正汇合的地方:
|
||||
# 先用 RINet 对整图打 patch 有效性分数,再用 DPNet 对候选 patch 做 defocus distance 回归,
|
||||
# 最后通过中位数聚合得到整图预测。
|
||||
#
|
||||
# REVIEW:
|
||||
# 如果想快速理解作者的方法,这个文件是最关键的,因为它最完整地体现了算法在实际评估中的执行顺序。
|
||||
|
||||
# Set the visible GPU devices for CUDA
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||
|
||||
# Image patch size
|
||||
size = 224
|
||||
|
||||
# Paths to test data Excel files
|
||||
root_path_list = [r'E:\suqiang\DenseSparse\testData_20241227.xlsx']
|
||||
# REVIEW:
|
||||
# 测试集路径完全硬编码,说明当前脚本更像作者自己的实验工具而不是通用命令行程序。
|
||||
|
||||
# Define data sheet types and test categories (cell or tissue)
|
||||
data_type_list = ['Sheet', 'Sheet1', 'Sheet1', 'Sheet1']
|
||||
cell_or_tissue_list = ['tissue', 'tissue']
|
||||
|
||||
# Load the classification model and pre-trained weights
|
||||
classification_model = con_classification()
|
||||
state_load = torch.load('../weight/RINet_best_model.pt')
|
||||
classification_model = nn.DataParallel(classification_model).to('cuda')
|
||||
classification_model.load_state_dict(state_load, False)
|
||||
classification_model.eval() # Set the model to evaluation mode
|
||||
# REVIEW:
|
||||
# RINet 权重和结构之间没有版本检查,只有运行时才能暴露兼容性问题。
|
||||
|
||||
# Loop for different models (e.g., DPNet)
|
||||
for c in range(1):
|
||||
if c == 0:
|
||||
state_load = torch.load('../weight/DP_best_model.pt')
|
||||
model = register_model()
|
||||
names = 'ours'
|
||||
model = nn.DataParallel(model).to('cuda')
|
||||
model.load_state_dict(state_load, False)
|
||||
model.eval() # Set the model to evaluation mode
|
||||
|
||||
# Iterate over test datasets
|
||||
for j in range(len(root_path_list)):
|
||||
root_path = root_path_list[j]
|
||||
data_type = data_type_list[j]
|
||||
cell_or_tissue = cell_or_tissue_list[j]
|
||||
|
||||
# Load test data and labels from the specified Excel sheet
|
||||
(train_image_path_list, train_patch_effective_list,
|
||||
train_defocus_distance_list) = dataset_F.get_test_data_and_label(root_path=root_path, type=f'{data_type}')
|
||||
|
||||
# Define image transformation pipeline
|
||||
train_transform = transforms.Compose([
|
||||
transforms.CenterCrop((2016, 2016)), # Crop the center region of the image
|
||||
transforms.ToTensor(), # Convert the image to a tensor
|
||||
])
|
||||
# REVIEW:
|
||||
# 这里把整图固定裁成 2016x2016,并配合 size=224 形成 9x9 patch 网格。
|
||||
# 因此测试流程对输入尺寸有很强假设。
|
||||
|
||||
# Create a custom dataset and DataLoader
|
||||
train_data = dataset_F.MyDataset_test(train_image_path_list, train_defocus_distance_list,
|
||||
train_patch_effective_list, train_transform)
|
||||
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, pin_memory=True, num_workers=0)
|
||||
|
||||
# Initialize an Excel workbook to save results
|
||||
workbook = openpyxl.Workbook()
|
||||
sheet = workbook.create_sheet(index=0)
|
||||
n = 0 # Row index for Excel sheet
|
||||
|
||||
# Iterate through the test data
|
||||
for i, (img, label, image_name, effective) in tqdm(enumerate(train_loader)):
|
||||
image_path = image_name[0]
|
||||
image_label = int(image_path.split('\\')[-1].split('.')[0]) # Extract ground truth label
|
||||
|
||||
# Add header row to the Excel sheet on the first iteration
|
||||
if i == 0:
|
||||
sheet.cell(1, 1).value = 'groundTruth'
|
||||
sheet.cell(1, 2).value = 'all_prediction'
|
||||
sheet.cell(1, 3).value = 'all_can_use_prediction'
|
||||
sheet.cell(1, 4).value = '61_prediction'
|
||||
sheet.cell(1, 5).value = '51_prediction'
|
||||
sheet.cell(1, 6).value = '41_prediction'
|
||||
sheet.cell(1, 7).value = '31_prediction'
|
||||
sheet.cell(1, 8).value = '25_prediction'
|
||||
sheet.cell(1, 9).value = '19_prediction'
|
||||
sheet.cell(1, 10).value = '15_prediction'
|
||||
sheet.cell(1, 11).value = '9_prediction'
|
||||
sheet.cell(1, 12).value = '5_prediction'
|
||||
sheet.cell(1, 13).value = '3_prediction'
|
||||
sheet.cell(1, 14).value = '1_prediction'
|
||||
sheet.cell(1, 15).value = 'path'
|
||||
sheet.cell(1, 16).value = 'max_can_use_patch_num'
|
||||
sheet.cell(1, 17).value = 'avg_prediction'
|
||||
|
||||
# Preprocess the image for the classification model
|
||||
img2 = F.interpolate(img, size=(512, 512), mode='bilinear', align_corners=False)
|
||||
img = torch.cuda.FloatTensor(np.array(img))
|
||||
img2 = torch.cuda.FloatTensor(np.array(img2))
|
||||
# REVIEW:
|
||||
# 这里继续沿用了 Tensor -> numpy -> CUDA Tensor 的旧式转换写法。
|
||||
if torch.cuda.is_available():
|
||||
img = img.cuda()
|
||||
img2 = img2.cuda()
|
||||
|
||||
# Perform classification on the preprocessed image
|
||||
output_weight = classification_model(img2)
|
||||
output_weight_userful = output_weight[torch.abs(output_weight) > 0.8] # Filter predictions above a threshold
|
||||
can_use_num = len(output_weight_userful) # Count usable predictions
|
||||
# REVIEW:
|
||||
# patch 是否“可用”由阈值 0.8 决定,这是一个非常关键的经验超参数。
|
||||
|
||||
# Initialize usable patch numbers for different thresholds
|
||||
can_use_num_61, can_use_num_51, can_use_num_41 = 61, 51, 41
|
||||
can_use_num_31, can_use_num_25, can_use_num_19 = 31, 25, 19
|
||||
can_use_num_15, can_use_num_9, can_use_num_5 = 15, 9, 5
|
||||
can_use_num_3, can_use_num_1 = 3, 1
|
||||
|
||||
# Handle cases with no usable patches
|
||||
if can_use_num == 0:
|
||||
util.print_info(f'{image_name} don\'t have enough info')
|
||||
continue
|
||||
# REVIEW:
|
||||
# 没有任何 patch 通过阈值时直接跳过样本,会导致结果文件中少样本,这是需要知晓的隐式行为。
|
||||
|
||||
# Adjust usable patch numbers based on conditions
|
||||
if can_use_num % 2 == 0:
|
||||
can_use_num -= 1
|
||||
if 51 <= can_use_num < 61:
|
||||
can_use_num_61 = can_use_num
|
||||
elif 41 <= can_use_num < 51:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num
|
||||
elif 31 <= can_use_num < 41:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num
|
||||
elif 25 <= can_use_num < 31:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num
|
||||
elif 19 <= can_use_num < 25:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num
|
||||
elif 15 <= can_use_num < 19:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num
|
||||
elif 9 <= can_use_num < 15:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num
|
||||
elif 5 <= can_use_num < 9:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num
|
||||
elif 3 <= can_use_num < 5:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num
|
||||
elif can_use_num < 3:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num_3 = 1
|
||||
|
||||
# Extract top predictions for each patch group
|
||||
max_all_can_value, max_all_can_index = torch.topk(output_weight, can_use_num)
|
||||
max_61_value, max_61_index = torch.topk(output_weight, can_use_num_61)
|
||||
max_51_value, max_51_index = torch.topk(output_weight, can_use_num_51)
|
||||
max_41_value, max_41_index = torch.topk(output_weight, can_use_num_41)
|
||||
max_31_value, max_31_index = torch.topk(output_weight, can_use_num_31)
|
||||
max_25_value, max_25_index = torch.topk(output_weight, can_use_num_25)
|
||||
max_19_value, max_19_index = torch.topk(output_weight, can_use_num_19)
|
||||
max_15_value, max_15_index = torch.topk(output_weight, can_use_num_15)
|
||||
max_9_value, max_9_index = torch.topk(output_weight, can_use_num_9)
|
||||
max_5_value, max_5_index = torch.topk(output_weight, can_use_num_5)
|
||||
max_3_value, max_3_index = torch.topk(output_weight, can_use_num_3)
|
||||
max_1_value, max_1_index = torch.topk(output_weight, can_use_num_1)
|
||||
patch_image_list_all = []
|
||||
patch_image_list_all_can_use = []
|
||||
patch_image_list_61 = []
|
||||
patch_image_list_51 = []
|
||||
patch_image_list_41 = []
|
||||
patch_image_list_31 = []
|
||||
patch_image_list_25 = []
|
||||
patch_image_list_19 = []
|
||||
patch_image_list_15 = []
|
||||
patch_image_list_9 = []
|
||||
patch_image_list_5 = []
|
||||
patch_image_list_3 = []
|
||||
patch_image_list_1 = []
|
||||
label_list = []
|
||||
# REVIEW:
|
||||
# 这里同时比较多套 top-k 聚合结果,说明作者不仅关心最终预测,也在研究“保留多少 patch 更合理”。
|
||||
for k in range(0, 2016, size):
|
||||
a = int(k / size)
|
||||
for j in range(0, 2016, size):
|
||||
b = int(j / size)
|
||||
patch_image = img[:, :, k:k + size, j:j + size]
|
||||
patch_image_list_all.append(patch_image)
|
||||
if len(patch_image_list_all) != 0:
|
||||
out_all = model(torch.cat(patch_image_list_all, dim=0))
|
||||
# REVIEW:
|
||||
# 81 个 patch 一次性送入 DPNet 是个不错的实现,避免了逐 patch 前向的重复开销。
|
||||
for z in range(len(out_all)):
|
||||
if torch.any(max_all_can_index.eq(z)):
|
||||
patch_image_list_all_can_use.append(out_all[z].item())
|
||||
if torch.any(max_61_index.eq(z)):
|
||||
patch_image_list_61.append(out_all[z].item())
|
||||
if torch.any(max_51_index.eq(z)):
|
||||
patch_image_list_51.append(out_all[z].item())
|
||||
if torch.any(max_41_index.eq(z)):
|
||||
patch_image_list_41.append(out_all[z].item())
|
||||
if torch.any(max_31_index.eq(z)):
|
||||
patch_image_list_31.append(out_all[z].item())
|
||||
if torch.any(max_25_index.eq(z)):
|
||||
patch_image_list_25.append(out_all[z].item())
|
||||
if torch.any(max_19_index.eq(z)):
|
||||
patch_image_list_19.append(out_all[z].item())
|
||||
if torch.any(max_15_index.eq(z)):
|
||||
patch_image_list_15.append(out_all[z].item())
|
||||
if torch.any(max_9_index.eq(z)):
|
||||
patch_image_list_9.append(out_all[z].item())
|
||||
if torch.any(max_5_index.eq(z)):
|
||||
patch_image_list_5.append(out_all[z].item())
|
||||
if torch.any(max_3_index.eq(z)):
|
||||
patch_image_list_3.append(out_all[z].item())
|
||||
if torch.any(max_1_index.eq(z)):
|
||||
patch_image_list_1.append(out_all[z].item())
|
||||
nums_all = float(np.median(out_all.tolist()))
|
||||
avg_nums_all = float(np.mean(out_all.tolist()))
|
||||
data = []
|
||||
data.append(image_label)
|
||||
data.append(float(np.median(out_all.tolist())) * 1000)
|
||||
data.append(float(np.median(patch_image_list_all_can_use)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_61)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_51)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_41)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_31)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_25)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_19)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_15)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_9)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_5)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_3)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_1)) * 1000)
|
||||
data.append((image_path.split('\\')[-3] + '\\' +image_path.split('\\')[-2] + '\\' +image_path.split('\\')[-1]))
|
||||
data.append(len(output_weight_userful))
|
||||
data.append(float(np.mean(out_all.tolist())) * 1000)
|
||||
for (m, info) in enumerate(data):
|
||||
sheet.cell(n + 2, m + 1).value = info
|
||||
n = n + 1
|
||||
util.save_val_image(1, img,
|
||||
score=out_all, image_name=image_name,
|
||||
type_str=cell_or_tissue, type='test', type_label=output_weight)
|
||||
# REVIEW:
|
||||
# 数值输出之外还保留了网格可视化,这对解释模型行为和排查异常样本很有帮助。
|
||||
workbook.save(f'./{names}_to_{cell_or_tissue}_{data_type}_2.xlsx')
|
||||
# REVIEW:
|
||||
# 最终结果保存为 Excel,非常贴合论文实验场景;如果做批量评估,结构化文本格式会更便于分析。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
676
playground.ipynb
Normal file
676
playground.ipynb
Normal file
File diff suppressed because one or more lines are too long
307
train_dpn.py
Normal file
307
train_dpn.py
Normal file
@ -0,0 +1,307 @@
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as transforms
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import tomllib
|
||||
|
||||
from models import DPNet
|
||||
import old_datasets as dataset_F
|
||||
import utils
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, writer, global_step):
|
||||
model.train()
|
||||
|
||||
running_loss = 0.0
|
||||
sample_count = 0
|
||||
|
||||
for img, label, image_name in tqdm(train_loader, desc="Train", bar_format="{l_bar}{bar:20}{r_bar}"):
|
||||
img = img.to(device, non_blocking=True)
|
||||
label = label.to(device, non_blocking=True).float().view(-1)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
output = model(img).view(-1)
|
||||
loss = criterion(output, label)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
global_step += 1
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
writer.add_scalar("lr", lr, global_step)
|
||||
|
||||
batch_size = img.size(0)
|
||||
running_loss += loss.item() * batch_size
|
||||
sample_count += batch_size
|
||||
|
||||
epoch_loss = running_loss / sample_count
|
||||
return epoch_loss, global_step
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_epoch(model, val_loader, criterion, device):
|
||||
model.eval()
|
||||
|
||||
running_loss = 0.0
|
||||
sample_count = 0
|
||||
|
||||
for img, label, image_name in tqdm(val_loader, desc="Validate", bar_format="{l_bar}{bar:20}{r_bar}"):
|
||||
img = img.to(device, non_blocking=True)
|
||||
label = label.to(device, non_blocking=True).float().view(-1)
|
||||
|
||||
output = model(img).view(-1)
|
||||
loss = criterion(output, label)
|
||||
|
||||
batch_size = img.size(0)
|
||||
running_loss += loss.item() * batch_size
|
||||
sample_count += batch_size
|
||||
|
||||
epoch_loss = running_loss / sample_count
|
||||
return epoch_loss
|
||||
|
||||
|
||||
def main():
|
||||
# =========================
|
||||
# 1. 读取配置
|
||||
# =========================
|
||||
parser = argparse.ArgumentParser(description="Train DPNet")
|
||||
parser.add_argument("--config", type=str, required=True, help="Path to TOML config file")
|
||||
args = parser.parse_args()
|
||||
|
||||
config_path = Path(args.config)
|
||||
with config_path.open("rb") as f:
|
||||
config = tomllib.load(f)
|
||||
|
||||
xlsx_files = config["xlsx_files"]
|
||||
batch_size = config["batch_size"]
|
||||
learning_rate = config["learning_rate"]
|
||||
epochs = config["epochs"]
|
||||
num_workers = config["num_workers"]
|
||||
seed = config["seed"]
|
||||
init_weight = config["init_weight"]
|
||||
|
||||
# =========================
|
||||
# 2. 创建输出目录
|
||||
# =========================
|
||||
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn")
|
||||
run_dir = Path.cwd() / run_name
|
||||
run_dir.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
shutil.copy2(config_path, run_dir / config_path.name)
|
||||
|
||||
# =========================
|
||||
# 3. 初始化日志与 TensorBoard
|
||||
# =========================
|
||||
logger = logging.getLogger("dpnet_train")
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
logger.handlers.clear()
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
file_handler = logging.FileHandler(run_dir / "train.log", encoding="utf-8")
|
||||
file_handler.setLevel(logging.INFO)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setLevel(logging.INFO)
|
||||
stream_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
writer = SummaryWriter(log_dir=str(run_dir))
|
||||
|
||||
logger.info(f"Run directory: {run_dir}")
|
||||
logger.info(f"Config path: {config_path}")
|
||||
logger.info(f"Loaded config:\n{pformat(config, sort_dicts=False)}")
|
||||
|
||||
# =========================
|
||||
# 4. 设置随机种子与设备
|
||||
# =========================
|
||||
utils.set_seeds(seed=seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# =========================
|
||||
# 5. 准备数据集与 DataLoader
|
||||
# =========================
|
||||
train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list = \
|
||||
dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files)
|
||||
|
||||
train_transform = transforms.Compose([
|
||||
transforms.ColorJitter(
|
||||
brightness=(0.9, 1.4),
|
||||
contrast=(0.8, 1.5),
|
||||
saturation=(0.8, 1.5),
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
val_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
train_dataset = dataset_F.MyDataset(
|
||||
train_image_path_list,
|
||||
train_defocus_distance_list,
|
||||
train_transform,
|
||||
)
|
||||
val_dataset = dataset_F.MyDataset(
|
||||
val_image_path_list,
|
||||
val_defocus_distance_list,
|
||||
val_transform,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=(device.type == "cuda"),
|
||||
persistent_workers=(num_workers > 0),
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
dataset=val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=(device.type == "cuda"),
|
||||
persistent_workers=(num_workers > 0),
|
||||
)
|
||||
|
||||
logger.info(f"Train dataset size: {len(train_dataset)}")
|
||||
logger.info(f"Val dataset size: {len(val_dataset)}")
|
||||
logger.info(f"Train steps per epoch: {len(train_loader)}")
|
||||
|
||||
# =========================
|
||||
# 6. 准备模型
|
||||
# =========================
|
||||
model = DPNet().to(device)
|
||||
|
||||
if init_weight:
|
||||
state_dict = torch.load(init_weight, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
logger.info(f"Loaded init weight from: {init_weight}")
|
||||
else:
|
||||
logger.info("Training from scratch")
|
||||
|
||||
# =========================
|
||||
# 7. 准备损失函数、优化器、调度器
|
||||
# =========================
|
||||
criterion = nn.MSELoss(reduction="mean")
|
||||
|
||||
optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
|
||||
|
||||
total_steps = epochs * len(train_loader)
|
||||
warmup_steps = min(2000, total_steps - 1) if total_steps > 1 else 0
|
||||
|
||||
scheduler = SequentialLR(
|
||||
optimizer,
|
||||
schedulers=[
|
||||
LinearLR(
|
||||
optimizer,
|
||||
start_factor=0.1,
|
||||
end_factor=1.0,
|
||||
total_iters=warmup_steps,
|
||||
),
|
||||
CosineAnnealingLR(
|
||||
optimizer,
|
||||
T_max=max(1, total_steps - warmup_steps),
|
||||
eta_min=0.0,
|
||||
),
|
||||
],
|
||||
milestones=[warmup_steps],
|
||||
)
|
||||
|
||||
logger.info("Loss: MSELoss(reduction='mean')")
|
||||
logger.info(f"Optimizer: Adam(lr={learning_rate}, betas=(0.9, 0.999))")
|
||||
logger.info("Scheduler: step-based warmup + cosine annealing")
|
||||
logger.info(f"Total training steps: {total_steps}")
|
||||
logger.info(f"Warmup steps: {warmup_steps}")
|
||||
|
||||
# =========================
|
||||
# 8. 开始训练循环
|
||||
# =========================
|
||||
best_val_loss = float("inf")
|
||||
global_step = 0
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
epoch_start_time = time.time()
|
||||
|
||||
train_loss, global_step = train_epoch(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
device=device,
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
)
|
||||
|
||||
val_loss = validate_epoch(
|
||||
model=model,
|
||||
val_loader=val_loader,
|
||||
criterion=criterion,
|
||||
device=device,
|
||||
)
|
||||
|
||||
epoch_total_time = time.time() - epoch_start_time
|
||||
|
||||
writer.add_scalar("train_loss", train_loss, epoch)
|
||||
writer.add_scalar("val_loss", val_loss, epoch)
|
||||
writer.add_scalar("epoch_total_time", epoch_total_time, epoch)
|
||||
|
||||
logger.info(
|
||||
f"Epoch [{epoch}/{epochs}] | "
|
||||
f"train_loss={train_loss:.8f} | "
|
||||
f"val_loss={val_loss:.8f} | "
|
||||
f"epoch_total_time={epoch_total_time:.2f}s | "
|
||||
f"global_step={global_step}"
|
||||
)
|
||||
|
||||
torch.save(model.state_dict(), run_dir / "last_dpn.pth")
|
||||
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
torch.save(model.state_dict(), run_dir / "best_dpn.pth")
|
||||
logger.info(f"Best model updated, best_val_loss={best_val_loss:.8f}")
|
||||
|
||||
# =========================
|
||||
# 9. 收尾
|
||||
# =========================
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Training finished")
|
||||
logger.info(f"Best validation loss: {best_val_loss:.8f}")
|
||||
logger.info(f"Total time: {total_time:.2f} seconds")
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
217
train_rin.py
Normal file
217
train_rin.py
Normal file
@ -0,0 +1,217 @@
|
||||
import shutil
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from models import RINet
|
||||
import old_datasets as dataset_F
|
||||
from datasets import RIN_Dataset, RINPairTransform
|
||||
import utils
|
||||
|
||||
|
||||
# 训练一轮
|
||||
def train_epoch(model, loader, criterion, optimizer, device):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
total_samples = 0
|
||||
|
||||
for images, labels in tqdm(
|
||||
loader,
|
||||
desc="Train:",
|
||||
bar_format="{l_bar}{bar:20}{r_bar}",
|
||||
leave=False,
|
||||
):
|
||||
images = images.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True).view(-1)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
outputs = model(images).view(-1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
this_batch_size = images.size(0)
|
||||
running_loss += loss.item() * this_batch_size
|
||||
total_samples += this_batch_size
|
||||
|
||||
epoch_loss = running_loss / total_samples
|
||||
return epoch_loss
|
||||
|
||||
|
||||
# 验证一轮
|
||||
@torch.no_grad()
|
||||
def valid_epoch(model, loader, criterion, device):
|
||||
model.eval()
|
||||
|
||||
running_loss = 0.0
|
||||
total_samples = 0
|
||||
|
||||
for images, labels in tqdm(
|
||||
loader,
|
||||
desc=f"Valid:",
|
||||
bar_format="{l_bar}{bar:20}{r_bar}",
|
||||
leave=False,
|
||||
):
|
||||
images = images.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True).view(-1)
|
||||
|
||||
outputs = model(images).view(-1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
this_batch_size = images.size(0)
|
||||
running_loss += loss.item() * this_batch_size
|
||||
total_samples += this_batch_size
|
||||
|
||||
epoch_loss = running_loss / total_samples
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
||||
# 主训练函数
|
||||
def main():
|
||||
# ========== 1 配置文件与超参数 ==========
|
||||
config, config_path = utils.get_hyperparams()
|
||||
|
||||
XLSX_FILES = config["xlsx_files"]
|
||||
BATCH_SIZE = config["batch_size"]
|
||||
NUM_WORKERS = config["num_workers"]
|
||||
LEARNING_RATE = config["learning_rate"]
|
||||
NUM_EPOCHS = config["epochs"]
|
||||
SEED = config["seed"]
|
||||
INIT_WEIGHT_PATH = config["init_weight"]
|
||||
|
||||
# ========== 2 创建输出文件目录 ==========
|
||||
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_RIN")
|
||||
run_dir = Path.cwd() / run_name
|
||||
run_dir.mkdir(parents=True, exist_ok=False)
|
||||
shutil.copy2(config_path, run_dir / config_path.name)
|
||||
|
||||
# ========== 3 日志、tensorboard、随机种子与设备 ==========
|
||||
logger = utils.get_logger(__name__, run_dir / "train.log")
|
||||
writer = SummaryWriter(str(run_dir / "run"))
|
||||
utils.set_seeds(SEED)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
logger.info(f"Config path: {config_path}")
|
||||
logger.info(f"Loaded config: {str(config)}")
|
||||
logger.info(f"Run directory: {run_dir}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# ========== 4 数据与 loader ==========
|
||||
train_image_path_list, train_patch_effective_list = (
|
||||
dataset_F.get_RINet_data(XLSX_FILES[0], "train")
|
||||
)
|
||||
valid_image_path_list, valid_patch_effective_list = (
|
||||
dataset_F.get_RINet_data(XLSX_FILES[0], "val")
|
||||
)
|
||||
|
||||
train_transform = RINPairTransform(train=True, image_size=512)
|
||||
valid_transform = RINPairTransform(train=False, image_size=512)
|
||||
|
||||
train_set = RIN_Dataset(
|
||||
train_image_path_list,
|
||||
train_patch_effective_list,
|
||||
train_transform,
|
||||
)
|
||||
valid_set = RIN_Dataset(
|
||||
valid_image_path_list,
|
||||
valid_patch_effective_list,
|
||||
valid_transform,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset=train_set,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=True,
|
||||
persistent_workers=(NUM_WORKERS > 0),
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dataset=valid_set,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=False,
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=True,
|
||||
persistent_workers=(NUM_WORKERS > 0),
|
||||
)
|
||||
|
||||
logger.info(f"Train dataset size: {len(train_set)}")
|
||||
logger.info(f"Val dataset size: {len(valid_set)}")
|
||||
logger.info(f"Train steps per epoch: {len(train_loader)}")
|
||||
|
||||
# ========== 5 模型、损失、优化器、调度器 ==========
|
||||
model = RINet().to(device)
|
||||
if INIT_WEIGHT_PATH:
|
||||
state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
|
||||
else:
|
||||
logger.info("Training from scratch")
|
||||
|
||||
criterion = nn.BCELoss()
|
||||
optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
|
||||
scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)
|
||||
|
||||
logger.info("Loss: BCELoss()")
|
||||
logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
|
||||
logger.info("Scheduler: epoch-based warmup + cosine annealing")
|
||||
|
||||
# ========== 6 开始训练 ==========
|
||||
logger.info("START TRAINING")
|
||||
best_valid_loss = float("inf")
|
||||
|
||||
try:
|
||||
for epoch in range(1, NUM_EPOCHS + 1):
|
||||
epoch_start_time = time.time()
|
||||
train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
|
||||
valid_loss = valid_epoch(model, valid_loader, criterion, device)
|
||||
epoch_lr = optimizer.param_groups[0]["lr"] # 当前轮学习率
|
||||
scheduler.step()
|
||||
epoch_time_cost = time.time() - epoch_start_time
|
||||
|
||||
# 如果更好则保存
|
||||
if valid_loss < best_valid_loss:
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.state_dict(), run_dir / "best_model.pt")
|
||||
logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")
|
||||
|
||||
# 日志与 tensorboard
|
||||
logger.info(
|
||||
f"Epoch [{epoch}/{NUM_EPOCHS}] "
|
||||
f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
|
||||
f"Best Valid Loss: {best_valid_loss:.4f} | "
|
||||
f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
|
||||
f"Epoch Learning Rate: {epoch_lr:.6e}"
|
||||
)
|
||||
|
||||
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||
writer.add_scalar("Loss/valid", valid_loss, epoch)
|
||||
writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
|
||||
writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
|
||||
writer.add_scalar("Time/learning_rate", epoch_lr, epoch)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
|
||||
finally:
|
||||
torch.save(model.state_dict(), run_dir / "last_model.pt")
|
||||
logger.info("Last model saved")
|
||||
|
||||
writer.close()
|
||||
logger.info("TensorBoard writer closed")
|
||||
|
||||
logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
110
utils.py
Normal file
110
utils.py
Normal file
@ -0,0 +1,110 @@
|
||||
import math
|
||||
import torch
|
||||
import random
|
||||
import logging
|
||||
import tomllib
|
||||
import argparse
|
||||
|
||||
from pathlib import Path
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# 设置固定的随机数种子
|
||||
def set_seeds(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
# 获取日志句柄
|
||||
def get_logger(name, log_file):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
logger.handlers.clear()
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
log_file,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
# 初始化 DDP 并行
|
||||
def setup_distributed():
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("DDP training requires CUDA.")
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
local_rank = dist.get_node_local_rank()
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
|
||||
is_main_process = (rank == 0)
|
||||
|
||||
return local_rank, rank, world_size, device, is_main_process
|
||||
|
||||
|
||||
# 释放 DDP 并行
|
||||
def cleanup_distributed():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
# 命令行参数解析配置文件
|
||||
def get_hyperparams():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("config", help="Path to TOML config file")
|
||||
args = parser.parse_args()
|
||||
|
||||
config_path = Path(args.config)
|
||||
with config_path.open("rb") as f:
|
||||
return tomllib.load(f), config_path
|
||||
|
||||
|
||||
# 线性预热与余弦退火调度器
|
||||
def get_warmup_cosine_scheduler(optimizer, epochs):
|
||||
max_warmup_epochs, start_factor, eta_min_factor = 10, 0.1, 0.0
|
||||
warmup_epochs = min(epochs // 10, max_warmup_epochs)
|
||||
cosine_epochs = epochs - warmup_epochs
|
||||
|
||||
def lr_lambda(current_epoch):
|
||||
# 线性预热阶段
|
||||
if warmup_epochs > 0 and current_epoch < warmup_epochs:
|
||||
return start_factor + (1.0 - start_factor) * (
|
||||
current_epoch / warmup_epochs
|
||||
)
|
||||
|
||||
# 余弦退火阶段
|
||||
cosine_epoch = current_epoch - warmup_epochs
|
||||
return eta_min_factor + (1.0 - eta_min_factor) * 0.5 * (
|
||||
1.0 + math.cos(math.pi * cosine_epoch / cosine_epochs)
|
||||
)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda=lr_lambda)
|
||||
Loading…
x
Reference in New Issue
Block a user