update project structure and training scripts
This commit is contained in:
parent
1b6fcf93d0
commit
27155ecfb5
6
.gitignore
vendored
6
.gitignore
vendored
@ -2,6 +2,8 @@ __pycache__/
|
||||
*.py[cod]
|
||||
|
||||
.vscode/
|
||||
old_proj/
|
||||
|
||||
20*_RIN*/
|
||||
20*_RIN_OLD*/
|
||||
20*/
|
||||
|
||||
*.xlsx
|
||||
|
||||
119
PROJECT_CONVENTIONS.md
Normal file
119
PROJECT_CONVENTIONS.md
Normal file
@ -0,0 +1,119 @@
|
||||
# SparseFocus 项目约定
|
||||
|
||||
## 项目理解
|
||||
|
||||
SparseFocus 面向自动对焦中的离焦预测问题。
|
||||
|
||||
基础任务是根据显微图像内容预测离焦程度。当前工作的核心判断是:内容稀疏视野会削弱离焦预测的稳定性。当视野中缺少足够有效结构时,直接从整图或任意局部区域回归离焦量,容易引入不可靠信号。
|
||||
|
||||
因此,项目目标中的主要流程应当是两阶段:
|
||||
|
||||
1. 先预测区域重要度,也就是判断每个局部区域是否包含有用信息。
|
||||
2. 再使用重要区域作为有效证据进行离焦预测。
|
||||
|
||||
当前代码结构中:
|
||||
|
||||
- `RINet` 用于从输入图像预测 `9x9` 的区域重要度图。
|
||||
- `DPNet` 用于预测离焦距离。
|
||||
- `train_rin.py` 是当前较清晰的区域重要度训练入口。
|
||||
- `train_dpn.py` 是当前离焦回归任务的训练入口。
|
||||
- `old_datasets.py` 仍然承载了较多历史数据入口逻辑和本地路径假设。
|
||||
- `datasets.py` 中已有较干净的 RIN 任务专用数据集和同步变换实现。
|
||||
- TOML 配置文件用于将训练超参数从训练脚本中分离出来。
|
||||
|
||||
后续重构应当保留这个研究意图,同时让任务边界、数据契约和实验产物更加明确。
|
||||
|
||||
## 重构边界
|
||||
|
||||
`old_proj/` 是只读参考区域。
|
||||
|
||||
具体规则:
|
||||
|
||||
- `old_proj/` 下的代码、数据、笔记和产物只能读取参考。
|
||||
- 重构过程中不得在 `old_proj/` 下创建、编辑、移动、重命名、格式化或删除任何文件。
|
||||
- 所有实现变更都必须发生在 `old_proj/` 之外。
|
||||
- 如果需要借鉴 `old_proj/` 中的历史逻辑,只能将思路手动迁移到主项目结构中,并在主项目内重新整理实现。
|
||||
- 不要从 `old_proj/` 直接导入运行时代码;它不应成为当前工程依赖图的一部分。
|
||||
|
||||
这条边界用于保留旧快速迭代项目的历史上下文,同时避免旧结构和误操作继续泄漏到新的工程重构中。
|
||||
|
||||
## 代码风格约定
|
||||
|
||||
### 模块职责
|
||||
|
||||
- 模型、数据集、训练循环、工具函数、配置文件和实验输出应当分离。
|
||||
- 当标签形状或返回值契约不同的时候,优先使用任务专用的数据集类。
|
||||
- 除非共享契约已经明确,否则不要让同一个数据集辅助函数同时服务区域重要度分类和离焦回归。
|
||||
- 路径发现和标签解析应放在数据层,不应放进模型逻辑或训练循环。
|
||||
|
||||
### 命名
|
||||
|
||||
- 任务名称应保持一致:
|
||||
- `RIN` / `RINet` 表示区域重要度预测。
|
||||
- `DP` / `DPNet` 表示离焦预测。
|
||||
- 优先使用能表达含义的变量名,例如 `train_image_paths`、`valid_labels`、`run_dir` 和 `best_valid_loss`。
|
||||
- 新代码中如果领域含义已经明确,应避免使用 `type`、`data`、`label`、`imag` 这类过于泛化的命名。
|
||||
|
||||
### 训练脚本
|
||||
|
||||
- 训练脚本应保持稳定顺序:
|
||||
1. 读取配置。
|
||||
2. 创建运行目录。
|
||||
3. 初始化日志、TensorBoard writer、随机种子和设备。
|
||||
4. 构建数据集和 dataloader。
|
||||
5. 构建模型、损失函数、优化器和学习率调度器。
|
||||
6. 执行训练和验证循环。
|
||||
7. 保存最佳权重和最后权重。
|
||||
- 单轮训练函数和单轮验证函数应与 `main()` 分离。
|
||||
- 验证阶段应使用 `torch.no_grad()`。
|
||||
- 日志中应记录配置路径、运行目录、数据集规模、优化器、调度器、每轮损失、最佳损失和权重保存事件。
|
||||
- 每次实验运行都应将配置文件复制到对应运行目录中。
|
||||
|
||||
### 配置与实验
|
||||
|
||||
- 超参数应放在 TOML 配置文件中,不应硬编码在训练脚本中。
|
||||
- 实验输出应写入带时间戳的运行目录,并保持被 Git 忽略。
|
||||
- checkpoint、TensorBoard event 文件和训练日志属于运行时产物,不属于源码。
|
||||
- 如果配置中的数据路径指向本机特定绝对路径,应先视为临时约定,直到引入更可移植的数据根目录机制。
|
||||
|
||||
### 模型代码
|
||||
|
||||
- 神经网络模块应保持小而可组合。
|
||||
- 对有明确语义的模型组件,优先使用显式的 `nn.Module` 类。
|
||||
- 当张量形状不明显时,应在模块边界附近记录形状预期,尤其是 `9x9` 区域重要度图和展平后的回归输出。
|
||||
- 如果构造函数已经提供通道数参数,辅助模块内部应避免再次硬编码通道数。
|
||||
|
||||
### 数据契约
|
||||
|
||||
- 数据集 `__getitem__` 的返回值应稳定,并且与具体任务匹配。
|
||||
- 图像张量在进入模型前应完成一致的归一化和形状整理。
|
||||
- 除非模型明确设计为其他通道契约,否则读取图像时应显式转换为 RGB。
|
||||
- 从文件名、Excel 或 JSON 中解析标签的逻辑,应与训练循环隔离,并适合单独测试。
|
||||
- 数据加载函数中应避免隐藏依赖本机绝对路径。
|
||||
|
||||
### 编码与注释
|
||||
|
||||
- 新文件应使用 UTF-8 编码。
|
||||
- 现有乱码中文注释应在聚焦重构时修复或删除,不应复制到新代码中。
|
||||
- 绝大多数时候应使用 ASCII 字符,尤其是代码标识、短字符串、运行时输出、表格字段名、Sheet 名和指标枚举值。
|
||||
- 较长的大段文档、代码注释可以并且应该使用中文描述,便于直接理解。
|
||||
- 保存到结果文件中的字段名、Sheet 名、指标枚举值等应使用英文,例如 `image_path`、`sparsity_level`、`dense`。
|
||||
- 注释应简短并解释真实背景,优先说明领域假设、张量形状、数据格式和不明显的训练选择。
|
||||
- 不要用注释重复解释简单代码本身。
|
||||
|
||||
### 执行检查
|
||||
|
||||
- 代码运行、训练、测试和静态检查默认由用户执行。
|
||||
- 除非用户明确要求,助手不主动运行代码检查、训练脚本或测试脚本。
|
||||
|
||||
### 导入与格式
|
||||
|
||||
- 导入顺序应按标准库、第三方库、本地模块分组。
|
||||
- 新代码中优先使用 `pathlib.Path` 处理文件系统路径。
|
||||
- 格式应兼容常见 Python 工具;窄范围改动时避免引入无关格式化变更。
|
||||
|
||||
### Git 卫生
|
||||
|
||||
- `old_proj/`、`.vscode/`、`__pycache__/` 和带时间戳的实验目录应保持被忽略。
|
||||
- 不要提交生成的 checkpoint、TensorBoard 日志、训练日志或 Python 字节码。
|
||||
- 重构时应保持改动范围清晰,并围绕有明确含义的工程节点进行提交。
|
||||
3
config_test.toml
Normal file
3
config_test.toml
Normal file
@ -0,0 +1,3 @@
|
||||
excel_file = "E:/Datasets/SparseFocusDataset/DenseSparse/testData_20241227.xlsx"
|
||||
rin_weight = "F:/Projects/SparseFocus/2026_05_06_19_36_35_RIN/best_model.pt"
|
||||
dpn_weight = "F:/Projects/SparseFocus/2026_04_18_18_37_07_dpn/best_dpn.pth"
|
||||
@ -15,6 +15,8 @@ from openpyxl import load_workbook
|
||||
import json
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# REVIEW:
|
||||
# 这个文件是项目的数据入口层,负责把 Excel、JSON 和图像文件组织成训练/测试阶段可消费的数据。
|
||||
# 整体上它解决了“如何把作者本地数据组织接到模型上”的问题,但工程抽象较弱,路径与数据格式都强依赖作者环境。
|
||||
@ -139,16 +141,21 @@ def get_test_data_and_label(root_path: str, type: str):
|
||||
# REVIEW:
|
||||
# 这里继续依赖固定数据目录结构,因此 test.py 的可运行性高度依赖作者原始数据目录。
|
||||
field_path = os.path.join(
|
||||
r'E:\suqiang\DenseSparse\cropped',
|
||||
'E:/Datasets/SparseFocusDataset/DenseSparse/cropped',
|
||||
row[0], '*.jpg')
|
||||
image_path_list = glob.glob(field_path)
|
||||
for image_patch in image_path_list:
|
||||
labels = int(image_patch.split('\\')[-1].split('.')[0])
|
||||
if labels >= -25000 and labels <= 25000:
|
||||
train_image_path_list.append(image_patch)
|
||||
# REVIEW:
|
||||
# effective 这里统一填 0,说明测试阶段 patch 是否可用不来自数据标注,而是后续 RINet 推理结果。
|
||||
train_patch_effective_list.append(0)
|
||||
train_image_path_list.append(image_patch.replace("\\", "/"))
|
||||
|
||||
json_data_path = Path(image_patch.replace("\\", "/")).with_name("224_patch_effective.json").as_posix()
|
||||
with open(json_data_path, 'r') as fs:
|
||||
json_data = json.load(fs)
|
||||
json_data = json.loads(json_data)
|
||||
json_data = json_data['image_info_list']
|
||||
train_patch_effective_list.append(json_data[0]['patch_effective'])
|
||||
|
||||
train_defocus_distance_list.append(labels / 1000)
|
||||
num_i = num_i + 1
|
||||
print('num_i: ',num_i)
|
||||
|
||||
268
old_test.py
268
old_test.py
@ -1,268 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
from model.DPNet import register_model
|
||||
from model.RINet import MobileNetV3_small as con_classification
|
||||
|
||||
from dataset import datasets as dataset_F
|
||||
import util.utils as util
|
||||
from tqdm import tqdm
|
||||
import openpyxl
|
||||
|
||||
# REVIEW:
|
||||
# 该脚本是项目的测试/推理主入口,也是两阶段方法真正汇合的地方:
|
||||
# 先用 RINet 对整图打 patch 有效性分数,再用 DPNet 对候选 patch 做 defocus distance 回归,
|
||||
# 最后通过中位数聚合得到整图预测。
|
||||
#
|
||||
# REVIEW:
|
||||
# 如果想快速理解作者的方法,这个文件是最关键的,因为它最完整地体现了算法在实际评估中的执行顺序。
|
||||
|
||||
# Set the visible GPU devices for CUDA
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||
|
||||
# Image patch size
|
||||
size = 224
|
||||
|
||||
# Paths to test data Excel files
|
||||
root_path_list = [r'E:\suqiang\DenseSparse\testData_20241227.xlsx']
|
||||
# REVIEW:
|
||||
# 测试集路径完全硬编码,说明当前脚本更像作者自己的实验工具而不是通用命令行程序。
|
||||
|
||||
# Define data sheet types and test categories (cell or tissue)
|
||||
data_type_list = ['Sheet', 'Sheet1', 'Sheet1', 'Sheet1']
|
||||
cell_or_tissue_list = ['tissue', 'tissue']
|
||||
|
||||
# Load the classification model and pre-trained weights
|
||||
classification_model = con_classification()
|
||||
state_load = torch.load('../weight/RINet_best_model.pt')
|
||||
classification_model = nn.DataParallel(classification_model).to('cuda')
|
||||
classification_model.load_state_dict(state_load, False)
|
||||
classification_model.eval() # Set the model to evaluation mode
|
||||
# REVIEW:
|
||||
# RINet 权重和结构之间没有版本检查,只有运行时才能暴露兼容性问题。
|
||||
|
||||
# Loop for different models (e.g., DPNet)
|
||||
for c in range(1):
|
||||
if c == 0:
|
||||
state_load = torch.load('../weight/DP_best_model.pt')
|
||||
model = register_model()
|
||||
names = 'ours'
|
||||
model = nn.DataParallel(model).to('cuda')
|
||||
model.load_state_dict(state_load, False)
|
||||
model.eval() # Set the model to evaluation mode
|
||||
|
||||
# Iterate over test datasets
|
||||
for j in range(len(root_path_list)):
|
||||
root_path = root_path_list[j]
|
||||
data_type = data_type_list[j]
|
||||
cell_or_tissue = cell_or_tissue_list[j]
|
||||
|
||||
# Load test data and labels from the specified Excel sheet
|
||||
(train_image_path_list, train_patch_effective_list,
|
||||
train_defocus_distance_list) = dataset_F.get_test_data_and_label(root_path=root_path, type=f'{data_type}')
|
||||
|
||||
# Define image transformation pipeline
|
||||
train_transform = transforms.Compose([
|
||||
transforms.CenterCrop((2016, 2016)), # Crop the center region of the image
|
||||
transforms.ToTensor(), # Convert the image to a tensor
|
||||
])
|
||||
# REVIEW:
|
||||
# 这里把整图固定裁成 2016x2016,并配合 size=224 形成 9x9 patch 网格。
|
||||
# 因此测试流程对输入尺寸有很强假设。
|
||||
|
||||
# Create a custom dataset and DataLoader
|
||||
train_data = dataset_F.MyDataset_test(train_image_path_list, train_defocus_distance_list,
|
||||
train_patch_effective_list, train_transform)
|
||||
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, pin_memory=True, num_workers=0)
|
||||
|
||||
# Initialize an Excel workbook to save results
|
||||
workbook = openpyxl.Workbook()
|
||||
sheet = workbook.create_sheet(index=0)
|
||||
n = 0 # Row index for Excel sheet
|
||||
|
||||
# Iterate through the test data
|
||||
for i, (img, label, image_name, effective) in tqdm(enumerate(train_loader)):
|
||||
image_path = image_name[0]
|
||||
image_label = int(image_path.split('\\')[-1].split('.')[0]) # Extract ground truth label
|
||||
|
||||
# Add header row to the Excel sheet on the first iteration
|
||||
if i == 0:
|
||||
sheet.cell(1, 1).value = 'groundTruth'
|
||||
sheet.cell(1, 2).value = 'all_prediction'
|
||||
sheet.cell(1, 3).value = 'all_can_use_prediction'
|
||||
sheet.cell(1, 4).value = '61_prediction'
|
||||
sheet.cell(1, 5).value = '51_prediction'
|
||||
sheet.cell(1, 6).value = '41_prediction'
|
||||
sheet.cell(1, 7).value = '31_prediction'
|
||||
sheet.cell(1, 8).value = '25_prediction'
|
||||
sheet.cell(1, 9).value = '19_prediction'
|
||||
sheet.cell(1, 10).value = '15_prediction'
|
||||
sheet.cell(1, 11).value = '9_prediction'
|
||||
sheet.cell(1, 12).value = '5_prediction'
|
||||
sheet.cell(1, 13).value = '3_prediction'
|
||||
sheet.cell(1, 14).value = '1_prediction'
|
||||
sheet.cell(1, 15).value = 'path'
|
||||
sheet.cell(1, 16).value = 'max_can_use_patch_num'
|
||||
sheet.cell(1, 17).value = 'avg_prediction'
|
||||
|
||||
# Preprocess the image for the classification model
|
||||
img2 = F.interpolate(img, size=(512, 512), mode='bilinear', align_corners=False)
|
||||
img = torch.cuda.FloatTensor(np.array(img))
|
||||
img2 = torch.cuda.FloatTensor(np.array(img2))
|
||||
# REVIEW:
|
||||
# 这里继续沿用了 Tensor -> numpy -> CUDA Tensor 的旧式转换写法。
|
||||
if torch.cuda.is_available():
|
||||
img = img.cuda()
|
||||
img2 = img2.cuda()
|
||||
|
||||
# Perform classification on the preprocessed image
|
||||
output_weight = classification_model(img2)
|
||||
output_weight_userful = output_weight[torch.abs(output_weight) > 0.8] # Filter predictions above a threshold
|
||||
can_use_num = len(output_weight_userful) # Count usable predictions
|
||||
# REVIEW:
|
||||
# patch 是否“可用”由阈值 0.8 决定,这是一个非常关键的经验超参数。
|
||||
|
||||
# Initialize usable patch numbers for different thresholds
|
||||
can_use_num_61, can_use_num_51, can_use_num_41 = 61, 51, 41
|
||||
can_use_num_31, can_use_num_25, can_use_num_19 = 31, 25, 19
|
||||
can_use_num_15, can_use_num_9, can_use_num_5 = 15, 9, 5
|
||||
can_use_num_3, can_use_num_1 = 3, 1
|
||||
|
||||
# Handle cases with no usable patches
|
||||
if can_use_num == 0:
|
||||
util.print_info(f'{image_name} don\'t have enough info')
|
||||
continue
|
||||
# REVIEW:
|
||||
# 没有任何 patch 通过阈值时直接跳过样本,会导致结果文件中少样本,这是需要知晓的隐式行为。
|
||||
|
||||
# Adjust usable patch numbers based on conditions
|
||||
if can_use_num % 2 == 0:
|
||||
can_use_num -= 1
|
||||
if 51 <= can_use_num < 61:
|
||||
can_use_num_61 = can_use_num
|
||||
elif 41 <= can_use_num < 51:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num
|
||||
elif 31 <= can_use_num < 41:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num
|
||||
elif 25 <= can_use_num < 31:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num
|
||||
elif 19 <= can_use_num < 25:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num
|
||||
elif 15 <= can_use_num < 19:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num
|
||||
elif 9 <= can_use_num < 15:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num
|
||||
elif 5 <= can_use_num < 9:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num
|
||||
elif 3 <= can_use_num < 5:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num
|
||||
elif can_use_num < 3:
|
||||
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num_3 = 1
|
||||
|
||||
# Extract top predictions for each patch group
|
||||
max_all_can_value, max_all_can_index = torch.topk(output_weight, can_use_num)
|
||||
max_61_value, max_61_index = torch.topk(output_weight, can_use_num_61)
|
||||
max_51_value, max_51_index = torch.topk(output_weight, can_use_num_51)
|
||||
max_41_value, max_41_index = torch.topk(output_weight, can_use_num_41)
|
||||
max_31_value, max_31_index = torch.topk(output_weight, can_use_num_31)
|
||||
max_25_value, max_25_index = torch.topk(output_weight, can_use_num_25)
|
||||
max_19_value, max_19_index = torch.topk(output_weight, can_use_num_19)
|
||||
max_15_value, max_15_index = torch.topk(output_weight, can_use_num_15)
|
||||
max_9_value, max_9_index = torch.topk(output_weight, can_use_num_9)
|
||||
max_5_value, max_5_index = torch.topk(output_weight, can_use_num_5)
|
||||
max_3_value, max_3_index = torch.topk(output_weight, can_use_num_3)
|
||||
max_1_value, max_1_index = torch.topk(output_weight, can_use_num_1)
|
||||
patch_image_list_all = []
|
||||
patch_image_list_all_can_use = []
|
||||
patch_image_list_61 = []
|
||||
patch_image_list_51 = []
|
||||
patch_image_list_41 = []
|
||||
patch_image_list_31 = []
|
||||
patch_image_list_25 = []
|
||||
patch_image_list_19 = []
|
||||
patch_image_list_15 = []
|
||||
patch_image_list_9 = []
|
||||
patch_image_list_5 = []
|
||||
patch_image_list_3 = []
|
||||
patch_image_list_1 = []
|
||||
label_list = []
|
||||
# REVIEW:
|
||||
# 这里同时比较多套 top-k 聚合结果,说明作者不仅关心最终预测,也在研究“保留多少 patch 更合理”。
|
||||
for k in range(0, 2016, size):
|
||||
a = int(k / size)
|
||||
for j in range(0, 2016, size):
|
||||
b = int(j / size)
|
||||
patch_image = img[:, :, k:k + size, j:j + size]
|
||||
patch_image_list_all.append(patch_image)
|
||||
if len(patch_image_list_all) != 0:
|
||||
out_all = model(torch.cat(patch_image_list_all, dim=0))
|
||||
# REVIEW:
|
||||
# 81 个 patch 一次性送入 DPNet 是个不错的实现,避免了逐 patch 前向的重复开销。
|
||||
for z in range(len(out_all)):
|
||||
if torch.any(max_all_can_index.eq(z)):
|
||||
patch_image_list_all_can_use.append(out_all[z].item())
|
||||
if torch.any(max_61_index.eq(z)):
|
||||
patch_image_list_61.append(out_all[z].item())
|
||||
if torch.any(max_51_index.eq(z)):
|
||||
patch_image_list_51.append(out_all[z].item())
|
||||
if torch.any(max_41_index.eq(z)):
|
||||
patch_image_list_41.append(out_all[z].item())
|
||||
if torch.any(max_31_index.eq(z)):
|
||||
patch_image_list_31.append(out_all[z].item())
|
||||
if torch.any(max_25_index.eq(z)):
|
||||
patch_image_list_25.append(out_all[z].item())
|
||||
if torch.any(max_19_index.eq(z)):
|
||||
patch_image_list_19.append(out_all[z].item())
|
||||
if torch.any(max_15_index.eq(z)):
|
||||
patch_image_list_15.append(out_all[z].item())
|
||||
if torch.any(max_9_index.eq(z)):
|
||||
patch_image_list_9.append(out_all[z].item())
|
||||
if torch.any(max_5_index.eq(z)):
|
||||
patch_image_list_5.append(out_all[z].item())
|
||||
if torch.any(max_3_index.eq(z)):
|
||||
patch_image_list_3.append(out_all[z].item())
|
||||
if torch.any(max_1_index.eq(z)):
|
||||
patch_image_list_1.append(out_all[z].item())
|
||||
nums_all = float(np.median(out_all.tolist()))
|
||||
avg_nums_all = float(np.mean(out_all.tolist()))
|
||||
data = []
|
||||
data.append(image_label)
|
||||
data.append(float(np.median(out_all.tolist())) * 1000)
|
||||
data.append(float(np.median(patch_image_list_all_can_use)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_61)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_51)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_41)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_31)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_25)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_19)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_15)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_9)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_5)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_3)) * 1000)
|
||||
data.append(float(np.median(patch_image_list_1)) * 1000)
|
||||
data.append((image_path.split('\\')[-3] + '\\' +image_path.split('\\')[-2] + '\\' +image_path.split('\\')[-1]))
|
||||
data.append(len(output_weight_userful))
|
||||
data.append(float(np.mean(out_all.tolist())) * 1000)
|
||||
for (m, info) in enumerate(data):
|
||||
sheet.cell(n + 2, m + 1).value = info
|
||||
n = n + 1
|
||||
util.save_val_image(1, img,
|
||||
score=out_all, image_name=image_name,
|
||||
type_str=cell_or_tissue, type='test', type_label=output_weight)
|
||||
# REVIEW:
|
||||
# 数值输出之外还保留了网格可视化,这对解释模型行为和排查异常样本很有帮助。
|
||||
workbook.save(f'./{names}_to_{cell_or_tissue}_{data_type}_2.xlsx')
|
||||
# REVIEW:
|
||||
# 最终结果保存为 Excel,非常贴合论文实验场景;如果做批量评估,结构化文本格式会更便于分析。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
1486
playground.ipynb
1486
playground.ipynb
File diff suppressed because one or more lines are too long
120
stat_test_results.py
Normal file
120
stat_test_results.py
Normal file
@ -0,0 +1,120 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
REQUIRED_COLUMNS = [
|
||||
"image_path",
|
||||
"roi_no",
|
||||
"importance_label",
|
||||
"importance_prediction",
|
||||
"defocus_label",
|
||||
"defocus_prediction",
|
||||
]
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Aggregate SparseFocus test results")
|
||||
parser.add_argument(
|
||||
"result_file",
|
||||
nargs="?",
|
||||
default="test_results.xlsx",
|
||||
help="Path to test result xlsx file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def check_columns(df):
|
||||
missing_columns = [column for column in REQUIRED_COLUMNS if column not in df.columns]
|
||||
if missing_columns:
|
||||
raise ValueError(f"Missing required columns: {missing_columns}")
|
||||
|
||||
|
||||
def get_sparsity_level(usable_count):
|
||||
if usable_count == 0:
|
||||
return "invalid"
|
||||
if usable_count > 41:
|
||||
return "dense"
|
||||
if usable_count > 9:
|
||||
return "sparse"
|
||||
return "extremely_sparse"
|
||||
|
||||
|
||||
def iter_image_groups(df):
|
||||
if len(df) % 81 != 0:
|
||||
raise ValueError(f"Row count should be divisible by 81, got {len(df)}.")
|
||||
|
||||
for start in range(0, len(df), 81):
|
||||
group = df.iloc[start: start + 81].copy()
|
||||
expected_roi_no = list(range(81))
|
||||
actual_roi_no = group["roi_no"].astype(int).tolist()
|
||||
if actual_roi_no != expected_roi_no:
|
||||
image_path = group["image_path"].iloc[0]
|
||||
raise ValueError(f"ROI order mismatch near image: {image_path}")
|
||||
|
||||
yield group
|
||||
|
||||
|
||||
def aggregate_one_image(group):
|
||||
image_path = group["image_path"].iloc[0]
|
||||
defocus_label = group["defocus_label"].iloc[0]
|
||||
usable_count = int((group["importance_label"] > 0).sum())
|
||||
sparsity_level = get_sparsity_level(usable_count)
|
||||
|
||||
selected_by_prediction = group[group["importance_prediction"] > 0.8]
|
||||
if len(selected_by_prediction) > 0:
|
||||
pred_importance_gt_0_8 = selected_by_prediction["defocus_prediction"].median()
|
||||
else:
|
||||
pred_importance_gt_0_8 = pd.NA
|
||||
|
||||
sorted_group = group.sort_values(
|
||||
by="importance_prediction",
|
||||
ascending=False,
|
||||
kind="mergesort",
|
||||
)
|
||||
|
||||
row = {
|
||||
"image_path": image_path,
|
||||
"defocus_label": defocus_label,
|
||||
"usable_roi_count": usable_count,
|
||||
"sparsity_level": sparsity_level,
|
||||
"pred_importance_gt_0_8": pred_importance_gt_0_8,
|
||||
"all_blocks_median": group["defocus_prediction"].median(),
|
||||
}
|
||||
|
||||
for k in range(81, 0, -1):
|
||||
row[f"top_{k}_median"] = sorted_group.head(k)["defocus_prediction"].median()
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def aggregate_results(df):
|
||||
check_columns(df)
|
||||
rows = [aggregate_one_image(group) for group in iter_image_groups(df)]
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
result_path = Path(args.result_file)
|
||||
|
||||
print(f"Reading test results: {result_path}")
|
||||
df = pd.read_excel(result_path, sheet_name=0)
|
||||
stat_df = aggregate_results(df)
|
||||
|
||||
print(f"Image count: {len(stat_df)}")
|
||||
print("Writing Sheet2")
|
||||
with pd.ExcelWriter(
|
||||
result_path,
|
||||
engine="openpyxl",
|
||||
mode="a",
|
||||
if_sheet_exists="replace",
|
||||
) as writer:
|
||||
stat_df.to_excel(writer, sheet_name="Sheet2", index=False)
|
||||
|
||||
print(f"Saved statistics to Sheet2: {result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
145
test.py
Normal file
145
test.py
Normal file
@ -0,0 +1,145 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from torchvision.transforms import functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
import old_datasets as dataset_F
|
||||
import utils
|
||||
from models import DPNet, RINet
|
||||
|
||||
|
||||
def strip_module_prefix(state_dict):
|
||||
if not all(key.startswith("module.") for key in state_dict):
|
||||
return state_dict
|
||||
|
||||
return {key.removeprefix("module."): value for key, value in state_dict.items()}
|
||||
|
||||
|
||||
def load_model_weight(model, weight_path, device):
|
||||
state_dict = torch.load(weight_path, map_location="cpu")
|
||||
if isinstance(state_dict, dict) and "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
model.load_state_dict(strip_module_prefix(state_dict), strict=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def build_patch_tensor(cropped_image):
|
||||
cropped_tensor = F.to_tensor(cropped_image)
|
||||
patches = []
|
||||
|
||||
for row in range(9):
|
||||
for col in range(9):
|
||||
top = row * 224
|
||||
left = col * 224
|
||||
patch = cropped_tensor[:, top: top + 224, left: left + 224]
|
||||
patches.append(patch)
|
||||
|
||||
return torch.stack(patches, dim=0)
|
||||
|
||||
|
||||
def predict_one_image(image_path, importance_label, defocus_label, rin_model, dpn_model, device):
|
||||
if len(importance_label) != 81:
|
||||
raise RuntimeError(f"Importance label count should be 81, got {len(importance_label)}.")
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
cropped_image = F.center_crop(image, [2016, 2016])
|
||||
|
||||
rin_image = F.resize(cropped_image, [512, 512])
|
||||
rin_tensor = F.to_tensor(rin_image).unsqueeze(0).to(device)
|
||||
patch_tensor = build_patch_tensor(cropped_image).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
importance_predictions = rin_model(rin_tensor).reshape(-1).detach().cpu()
|
||||
defocus_predictions = dpn_model(patch_tensor).reshape(-1).detach().cpu()
|
||||
|
||||
if len(importance_predictions) != 81:
|
||||
raise RuntimeError(f"RIN output count should be 81, got {len(importance_predictions)}.")
|
||||
if len(defocus_predictions) != 81:
|
||||
raise RuntimeError(f"DPN output count should be 81, got {len(defocus_predictions)}.")
|
||||
|
||||
rows = []
|
||||
for roi_no in range(81):
|
||||
rows.append(
|
||||
[
|
||||
str(image_path),
|
||||
roi_no,
|
||||
importance_label[roi_no],
|
||||
float(importance_predictions[roi_no]),
|
||||
float(defocus_label) * 1000.0,
|
||||
float(defocus_predictions[roi_no]) * 1000.0,
|
||||
]
|
||||
)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def save_results(rows, output_path):
|
||||
df = pd.DataFrame(
|
||||
rows,
|
||||
columns=[
|
||||
"image_path",
|
||||
"roi_no",
|
||||
"importance_label",
|
||||
"importance_prediction",
|
||||
"defocus_label",
|
||||
"defocus_prediction",
|
||||
],
|
||||
)
|
||||
df.to_excel(output_path, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
config, config_path = utils.get_hyperparams()
|
||||
|
||||
excel_file = config["excel_file"]
|
||||
rin_weight = config["rin_weight"]
|
||||
dpn_weight = config["dpn_weight"]
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
output_path = Path.cwd() / "test_results.xlsx"
|
||||
|
||||
print(f"Config path: {config_path}")
|
||||
print(f"Using device: {device}")
|
||||
print("Loading models")
|
||||
|
||||
rin_model = load_model_weight(RINet(), rin_weight, device)
|
||||
dpn_model = load_model_weight(DPNet(), dpn_weight, device)
|
||||
|
||||
print("Loading test data")
|
||||
image_paths, importance_labels, defocus_labels = dataset_F.get_test_data_and_label(
|
||||
excel_file,
|
||||
"Sheet",
|
||||
)
|
||||
print(f"Test image count: {len(image_paths)}")
|
||||
|
||||
all_rows = []
|
||||
for image_path, importance_label, defocus_label in tqdm(
|
||||
zip(image_paths, importance_labels, defocus_labels),
|
||||
total=len(image_paths),
|
||||
desc="Test",
|
||||
bar_format="{l_bar}{bar:20}{r_bar}",
|
||||
):
|
||||
all_rows.extend(
|
||||
predict_one_image(
|
||||
image_path,
|
||||
importance_label,
|
||||
defocus_label,
|
||||
rin_model,
|
||||
dpn_model,
|
||||
device,
|
||||
)
|
||||
)
|
||||
|
||||
save_results(all_rows, output_path)
|
||||
print(f"Saved test results: {output_path}")
|
||||
print(f"Saved row count: {len(all_rows)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
377
train_dpn.py
377
train_dpn.py
@ -1,307 +1,226 @@
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as transforms
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import tomllib
|
||||
from tqdm import tqdm
|
||||
|
||||
from models import DPNet
|
||||
import old_datasets as dataset_F
|
||||
import utils
|
||||
from models import DPNet
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, writer, global_step):
|
||||
# 训练一轮
|
||||
def train_epoch(model, loader, criterion, optimizer, device):
|
||||
model.train()
|
||||
|
||||
running_loss = 0.0
|
||||
sample_count = 0
|
||||
total_samples = 0
|
||||
|
||||
for img, label, image_name in tqdm(train_loader, desc="Train", bar_format="{l_bar}{bar:20}{r_bar}"):
|
||||
img = img.to(device, non_blocking=True)
|
||||
label = label.to(device, non_blocking=True).float().view(-1)
|
||||
for images, labels, image_names in tqdm(
|
||||
loader,
|
||||
desc="Train:",
|
||||
bar_format="{l_bar}{bar:20}{r_bar}",
|
||||
leave=False,
|
||||
):
|
||||
images = images.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True).float().view(-1)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
output = model(img).view(-1)
|
||||
loss = criterion(output, label)
|
||||
outputs = model(images).view(-1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
global_step += 1
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
writer.add_scalar("lr", lr, global_step)
|
||||
this_batch_size = images.size(0)
|
||||
running_loss += loss.item() * this_batch_size
|
||||
total_samples += this_batch_size
|
||||
|
||||
batch_size = img.size(0)
|
||||
running_loss += loss.item() * batch_size
|
||||
sample_count += batch_size
|
||||
|
||||
epoch_loss = running_loss / sample_count
|
||||
return epoch_loss, global_step
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_epoch(model, val_loader, criterion, device):
|
||||
model.eval()
|
||||
|
||||
running_loss = 0.0
|
||||
sample_count = 0
|
||||
|
||||
for img, label, image_name in tqdm(val_loader, desc="Validate", bar_format="{l_bar}{bar:20}{r_bar}"):
|
||||
img = img.to(device, non_blocking=True)
|
||||
label = label.to(device, non_blocking=True).float().view(-1)
|
||||
|
||||
output = model(img).view(-1)
|
||||
loss = criterion(output, label)
|
||||
|
||||
batch_size = img.size(0)
|
||||
running_loss += loss.item() * batch_size
|
||||
sample_count += batch_size
|
||||
|
||||
epoch_loss = running_loss / sample_count
|
||||
epoch_loss = running_loss / total_samples
|
||||
return epoch_loss
|
||||
|
||||
|
||||
# 验证一轮
|
||||
@torch.no_grad()
|
||||
def valid_epoch(model, loader, criterion, device):
|
||||
model.eval()
|
||||
running_loss = 0.0
|
||||
total_samples = 0
|
||||
|
||||
for images, labels, image_names in tqdm(
|
||||
loader,
|
||||
desc="Valid:",
|
||||
bar_format="{l_bar}{bar:20}{r_bar}",
|
||||
leave=False,
|
||||
):
|
||||
images = images.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True).float().view(-1)
|
||||
|
||||
outputs = model(images).view(-1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
this_batch_size = images.size(0)
|
||||
running_loss += loss.item() * this_batch_size
|
||||
total_samples += this_batch_size
|
||||
|
||||
epoch_loss = running_loss / total_samples
|
||||
return epoch_loss
|
||||
|
||||
|
||||
# 主训练函数
|
||||
def main():
|
||||
# =========================
|
||||
# 1. 读取配置
|
||||
# =========================
|
||||
parser = argparse.ArgumentParser(description="Train DPNet")
|
||||
parser.add_argument("--config", type=str, required=True, help="Path to TOML config file")
|
||||
args = parser.parse_args()
|
||||
# ========== 1 配置文件与超参数 ==========
|
||||
config, config_path = utils.get_hyperparams()
|
||||
|
||||
config_path = Path(args.config)
|
||||
with config_path.open("rb") as f:
|
||||
config = tomllib.load(f)
|
||||
XLSX_FILES = config["xlsx_files"]
|
||||
BATCH_SIZE = config["batch_size"]
|
||||
NUM_WORKERS = config["num_workers"]
|
||||
LEARNING_RATE = config["learning_rate"]
|
||||
NUM_EPOCHS = config["epochs"]
|
||||
SEED = config["seed"]
|
||||
INIT_WEIGHT_PATH = config["init_weight"]
|
||||
|
||||
xlsx_files = config["xlsx_files"]
|
||||
batch_size = config["batch_size"]
|
||||
learning_rate = config["learning_rate"]
|
||||
epochs = config["epochs"]
|
||||
num_workers = config["num_workers"]
|
||||
seed = config["seed"]
|
||||
init_weight = config["init_weight"]
|
||||
|
||||
# =========================
|
||||
# 2. 创建输出目录
|
||||
# =========================
|
||||
# ========== 2 创建输出文件目录 ==========
|
||||
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn")
|
||||
run_dir = Path.cwd() / run_name
|
||||
run_dir.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
shutil.copy2(config_path, run_dir / config_path.name)
|
||||
|
||||
# =========================
|
||||
# 3. 初始化日志与 TensorBoard
|
||||
# =========================
|
||||
logger = logging.getLogger("dpnet_train")
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
logger.handlers.clear()
|
||||
# ========== 3 日志、TensorBoard、随机种子与设备 ==========
|
||||
logger = utils.get_logger(__name__, run_dir / "train.log")
|
||||
writer = SummaryWriter(str(run_dir / "run"))
|
||||
utils.set_seeds(SEED)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
file_handler = logging.FileHandler(run_dir / "train.log", encoding="utf-8")
|
||||
file_handler.setLevel(logging.INFO)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setLevel(logging.INFO)
|
||||
stream_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
writer = SummaryWriter(log_dir=str(run_dir))
|
||||
|
||||
logger.info(f"Run directory: {run_dir}")
|
||||
logger.info(f"Config path: {config_path}")
|
||||
logger.info(f"Loaded config:\n{pformat(config, sort_dicts=False)}")
|
||||
|
||||
# =========================
|
||||
# 4. 设置随机种子与设备
|
||||
# =========================
|
||||
utils.set_seeds(seed=seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
logger.info(f"Loaded config: {str(config)}")
|
||||
logger.info(f"Run directory: {run_dir}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# =========================
|
||||
# 5. 准备数据集与 DataLoader
|
||||
# =========================
|
||||
train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list = \
|
||||
dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files)
|
||||
# ========== 4 数据与 loader ==========
|
||||
(
|
||||
train_image_path_list,
|
||||
train_defocus_distance_list,
|
||||
valid_image_path_list,
|
||||
valid_defocus_distance_list,
|
||||
) = dataset_F.get_DPNet_train_data_and_label(root_path_list=XLSX_FILES)
|
||||
|
||||
train_transform = transforms.Compose([
|
||||
transforms.ColorJitter(
|
||||
brightness=(0.9, 1.4),
|
||||
contrast=(0.8, 1.5),
|
||||
saturation=(0.8, 1.5),
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
train_transform = transforms.Compose(
|
||||
[
|
||||
transforms.ColorJitter(
|
||||
brightness=(0.9, 1.4),
|
||||
contrast=(0.8, 1.5),
|
||||
saturation=(0.8, 1.5),
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
valid_transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
val_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
train_dataset = dataset_F.MyDataset(
|
||||
train_set = dataset_F.MyDataset(
|
||||
train_image_path_list,
|
||||
train_defocus_distance_list,
|
||||
train_transform,
|
||||
)
|
||||
val_dataset = dataset_F.MyDataset(
|
||||
val_image_path_list,
|
||||
val_defocus_distance_list,
|
||||
val_transform,
|
||||
valid_set = dataset_F.MyDataset(
|
||||
valid_image_path_list,
|
||||
valid_defocus_distance_list,
|
||||
valid_transform,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
dataset=train_set,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=(device.type == "cuda"),
|
||||
persistent_workers=(num_workers > 0),
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=True,
|
||||
persistent_workers=(NUM_WORKERS > 0),
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
dataset=val_dataset,
|
||||
batch_size=batch_size,
|
||||
valid_loader = DataLoader(
|
||||
dataset=valid_set,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=(device.type == "cuda"),
|
||||
persistent_workers=(num_workers > 0),
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=True,
|
||||
persistent_workers=(NUM_WORKERS > 0),
|
||||
)
|
||||
|
||||
logger.info(f"Train dataset size: {len(train_dataset)}")
|
||||
logger.info(f"Val dataset size: {len(val_dataset)}")
|
||||
logger.info(f"Train dataset size: {len(train_set)}")
|
||||
logger.info(f"Val dataset size: {len(valid_set)}")
|
||||
logger.info(f"Train steps per epoch: {len(train_loader)}")
|
||||
|
||||
# =========================
|
||||
# 6. 准备模型
|
||||
# =========================
|
||||
# ========== 5 模型、损失、优化器、调度器 ==========
|
||||
model = DPNet().to(device)
|
||||
|
||||
if init_weight:
|
||||
state_dict = torch.load(init_weight, map_location="cpu")
|
||||
if INIT_WEIGHT_PATH:
|
||||
state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
logger.info(f"Loaded init weight from: {init_weight}")
|
||||
logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
|
||||
else:
|
||||
logger.info("Training from scratch")
|
||||
|
||||
# =========================
|
||||
# 7. 准备损失函数、优化器、调度器
|
||||
# =========================
|
||||
criterion = nn.MSELoss(reduction="mean")
|
||||
|
||||
optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
|
||||
|
||||
total_steps = epochs * len(train_loader)
|
||||
warmup_steps = min(2000, total_steps - 1) if total_steps > 1 else 0
|
||||
|
||||
scheduler = SequentialLR(
|
||||
optimizer,
|
||||
schedulers=[
|
||||
LinearLR(
|
||||
optimizer,
|
||||
start_factor=0.1,
|
||||
end_factor=1.0,
|
||||
total_iters=warmup_steps,
|
||||
),
|
||||
CosineAnnealingLR(
|
||||
optimizer,
|
||||
T_max=max(1, total_steps - warmup_steps),
|
||||
eta_min=0.0,
|
||||
),
|
||||
],
|
||||
milestones=[warmup_steps],
|
||||
)
|
||||
optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
|
||||
scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)
|
||||
|
||||
logger.info("Loss: MSELoss(reduction='mean')")
|
||||
logger.info(f"Optimizer: Adam(lr={learning_rate}, betas=(0.9, 0.999))")
|
||||
logger.info("Scheduler: step-based warmup + cosine annealing")
|
||||
logger.info(f"Total training steps: {total_steps}")
|
||||
logger.info(f"Warmup steps: {warmup_steps}")
|
||||
logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
|
||||
logger.info("Scheduler: epoch-based warmup + cosine annealing")
|
||||
|
||||
# =========================
|
||||
# 8. 开始训练循环
|
||||
# =========================
|
||||
best_val_loss = float("inf")
|
||||
global_step = 0
|
||||
start_time = time.time()
|
||||
# ========== 6 开始训练 ==========
|
||||
logger.info("START TRAINING")
|
||||
best_valid_loss = float("inf")
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
epoch_start_time = time.time()
|
||||
try:
|
||||
for epoch in range(1, NUM_EPOCHS + 1):
|
||||
epoch_start_time = time.time()
|
||||
train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
|
||||
valid_loss = valid_epoch(model, valid_loader, criterion, device)
|
||||
epoch_lr = optimizer.param_groups[0]["lr"]
|
||||
scheduler.step()
|
||||
epoch_time_cost = time.time() - epoch_start_time
|
||||
|
||||
train_loss, global_step = train_epoch(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
device=device,
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
)
|
||||
if valid_loss < best_valid_loss:
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.state_dict(), run_dir / "best_model.pt")
|
||||
logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")
|
||||
|
||||
val_loss = validate_epoch(
|
||||
model=model,
|
||||
val_loader=val_loader,
|
||||
criterion=criterion,
|
||||
device=device,
|
||||
)
|
||||
logger.info(
|
||||
f"Epoch [{epoch}/{NUM_EPOCHS}] "
|
||||
f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
|
||||
f"Best Valid Loss: {best_valid_loss:.4f} | "
|
||||
f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
|
||||
f"Epoch Learning Rate: {epoch_lr:.6e}"
|
||||
)
|
||||
|
||||
epoch_total_time = time.time() - epoch_start_time
|
||||
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||
writer.add_scalar("Loss/valid", valid_loss, epoch)
|
||||
writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
|
||||
writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
|
||||
writer.add_scalar("Time/learning_rate", epoch_lr, epoch)
|
||||
|
||||
writer.add_scalar("train_loss", train_loss, epoch)
|
||||
writer.add_scalar("val_loss", val_loss, epoch)
|
||||
writer.add_scalar("epoch_total_time", epoch_total_time, epoch)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
|
||||
logger.info(
|
||||
f"Epoch [{epoch}/{epochs}] | "
|
||||
f"train_loss={train_loss:.8f} | "
|
||||
f"val_loss={val_loss:.8f} | "
|
||||
f"epoch_total_time={epoch_total_time:.2f}s | "
|
||||
f"global_step={global_step}"
|
||||
)
|
||||
finally:
|
||||
torch.save(model.state_dict(), run_dir / "last_model.pt")
|
||||
logger.info("Last model saved")
|
||||
|
||||
torch.save(model.state_dict(), run_dir / "last_dpn.pth")
|
||||
writer.close()
|
||||
logger.info("TensorBoard writer closed")
|
||||
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
torch.save(model.state_dict(), run_dir / "best_dpn.pth")
|
||||
logger.info(f"Best model updated, best_val_loss={best_val_loss:.8f}")
|
||||
|
||||
# =========================
|
||||
# 9. 收尾
|
||||
# =========================
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Training finished")
|
||||
logger.info(f"Best validation loss: {best_val_loss:.8f}")
|
||||
logger.info(f"Total time: {total_time:.2f} seconds")
|
||||
|
||||
writer.close()
|
||||
logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
312
train_dpn_ddp.py
Normal file
312
train_dpn_ddp.py
Normal file
@ -0,0 +1,312 @@
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as transforms
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
import old_datasets as dataset_F
|
||||
import utils
|
||||
from models import DPNet
|
||||
|
||||
|
||||
def reduce_epoch_loss(running_loss, total_samples, device):
|
||||
stats = torch.tensor(
|
||||
[running_loss, total_samples],
|
||||
dtype=torch.float64,
|
||||
device=device,
|
||||
)
|
||||
dist.all_reduce(stats, op=dist.ReduceOp.SUM)
|
||||
return (stats[0] / stats[1]).item()
|
||||
|
||||
|
||||
# 训练一轮
|
||||
def train_epoch(model, loader, criterion, optimizer, device, is_main_process):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
total_samples = 0
|
||||
|
||||
for images, labels, image_names in tqdm(
|
||||
loader,
|
||||
desc="Train:",
|
||||
bar_format="{l_bar}{bar:20}{r_bar}",
|
||||
leave=False,
|
||||
disable=not is_main_process,
|
||||
):
|
||||
images = images.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True).float().view(-1)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
outputs = model(images).view(-1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
this_batch_size = images.size(0)
|
||||
running_loss += loss.item() * this_batch_size
|
||||
total_samples += this_batch_size
|
||||
|
||||
return reduce_epoch_loss(running_loss, total_samples, device)
|
||||
|
||||
|
||||
# 验证一轮
|
||||
@torch.no_grad()
|
||||
def valid_epoch(model, loader, criterion, device, is_main_process):
|
||||
model.eval()
|
||||
running_loss = 0.0
|
||||
total_samples = 0
|
||||
|
||||
for images, labels, image_names in tqdm(
|
||||
loader,
|
||||
desc="Valid:",
|
||||
bar_format="{l_bar}{bar:20}{r_bar}",
|
||||
leave=False,
|
||||
disable=not is_main_process,
|
||||
):
|
||||
images = images.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True).float().view(-1)
|
||||
|
||||
outputs = model(images).view(-1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
this_batch_size = images.size(0)
|
||||
running_loss += loss.item() * this_batch_size
|
||||
total_samples += this_batch_size
|
||||
|
||||
return reduce_epoch_loss(running_loss, total_samples, device)
|
||||
|
||||
|
||||
def create_run_dir(config_path, is_main_process):
|
||||
if is_main_process:
|
||||
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn_ddp")
|
||||
run_dir = Path.cwd() / run_name
|
||||
run_dir.mkdir(parents=True, exist_ok=False)
|
||||
shutil.copy2(config_path, run_dir / config_path.name)
|
||||
run_dir_text = str(run_dir)
|
||||
else:
|
||||
run_dir_text = None
|
||||
|
||||
shared_value = [run_dir_text]
|
||||
dist.broadcast_object_list(shared_value, src=0)
|
||||
dist.barrier()
|
||||
|
||||
return Path(shared_value[0])
|
||||
|
||||
|
||||
# 主训练函数
|
||||
def main():
|
||||
local_rank, rank, world_size, device, is_main_process = utils.setup_distributed()
|
||||
|
||||
logger = None
|
||||
writer = None
|
||||
model = None
|
||||
best_valid_loss = float("inf")
|
||||
|
||||
try:
|
||||
# ========== 1 配置文件与超参数 ==========
|
||||
config, config_path = utils.get_hyperparams()
|
||||
|
||||
XLSX_FILES = config["xlsx_files"]
|
||||
BATCH_SIZE = config["batch_size"]
|
||||
NUM_WORKERS = config["num_workers"]
|
||||
LEARNING_RATE = config["learning_rate"]
|
||||
NUM_EPOCHS = config["epochs"]
|
||||
SEED = config["seed"]
|
||||
INIT_WEIGHT_PATH = config["init_weight"]
|
||||
|
||||
# ========== 2 创建输出文件目录 ==========
|
||||
run_dir = create_run_dir(config_path, is_main_process)
|
||||
|
||||
# ========== 3 日志、TensorBoard、随机种子与设备 ==========
|
||||
if is_main_process:
|
||||
logger = utils.get_logger(__name__, run_dir / "train.log")
|
||||
writer = SummaryWriter(str(run_dir / "run"))
|
||||
|
||||
utils.set_seeds(SEED)
|
||||
|
||||
if is_main_process:
|
||||
logger.info(f"Config path: {config_path}")
|
||||
logger.info(f"Loaded config: {str(config)}")
|
||||
logger.info(f"Run directory: {run_dir}")
|
||||
logger.info(f"Using world size: {world_size}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# ========== 4 数据与 loader ==========
|
||||
(
|
||||
train_image_path_list,
|
||||
train_defocus_distance_list,
|
||||
valid_image_path_list,
|
||||
valid_defocus_distance_list,
|
||||
) = dataset_F.get_DPNet_train_data_and_label(root_path_list=XLSX_FILES)
|
||||
|
||||
train_transform = transforms.Compose(
|
||||
[
|
||||
transforms.ColorJitter(
|
||||
brightness=(0.9, 1.4),
|
||||
contrast=(0.8, 1.5),
|
||||
saturation=(0.8, 1.5),
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
valid_transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
train_set = dataset_F.MyDataset(
|
||||
train_image_path_list,
|
||||
train_defocus_distance_list,
|
||||
train_transform,
|
||||
)
|
||||
valid_set = dataset_F.MyDataset(
|
||||
valid_image_path_list,
|
||||
valid_defocus_distance_list,
|
||||
valid_transform,
|
||||
)
|
||||
|
||||
train_sampler = DistributedSampler(
|
||||
train_set,
|
||||
num_replicas=world_size,
|
||||
rank=rank,
|
||||
shuffle=True,
|
||||
)
|
||||
valid_sampler = DistributedSampler(
|
||||
valid_set,
|
||||
num_replicas=world_size,
|
||||
rank=rank,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset=train_set,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=False,
|
||||
sampler=train_sampler,
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=True,
|
||||
persistent_workers=(NUM_WORKERS > 0),
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dataset=valid_set,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=False,
|
||||
sampler=valid_sampler,
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=True,
|
||||
persistent_workers=(NUM_WORKERS > 0),
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logger.info(f"Train dataset size: {len(train_set)}")
|
||||
logger.info(f"Val dataset size: {len(valid_set)}")
|
||||
logger.info(f"Train steps per epoch per process: {len(train_loader)}")
|
||||
|
||||
# ========== 5 模型、损失、优化器、调度器 ==========
|
||||
model = DPNet().to(device)
|
||||
if INIT_WEIGHT_PATH:
|
||||
state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
if is_main_process:
|
||||
logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
|
||||
elif is_main_process:
|
||||
logger.info("Training from scratch")
|
||||
|
||||
model = DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[local_rank],
|
||||
output_device=local_rank,
|
||||
)
|
||||
|
||||
criterion = nn.MSELoss(reduction="mean")
|
||||
optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
|
||||
scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)
|
||||
|
||||
if is_main_process:
|
||||
logger.info("Loss: MSELoss(reduction='mean')")
|
||||
logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
|
||||
logger.info("Scheduler: epoch-based warmup + cosine annealing")
|
||||
|
||||
# ========== 6 开始训练 ==========
|
||||
if is_main_process:
|
||||
logger.info("START TRAINING")
|
||||
|
||||
try:
|
||||
for epoch in range(1, NUM_EPOCHS + 1):
|
||||
train_sampler.set_epoch(epoch)
|
||||
epoch_start_time = time.time()
|
||||
|
||||
train_loss = train_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
device,
|
||||
is_main_process,
|
||||
)
|
||||
valid_loss = valid_epoch(
|
||||
model,
|
||||
valid_loader,
|
||||
criterion,
|
||||
device,
|
||||
is_main_process,
|
||||
)
|
||||
epoch_lr = optimizer.param_groups[0]["lr"]
|
||||
scheduler.step()
|
||||
epoch_time_cost = time.time() - epoch_start_time
|
||||
|
||||
if is_main_process and valid_loss < best_valid_loss:
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.module.state_dict(), run_dir / "best_model.pt")
|
||||
logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")
|
||||
|
||||
if is_main_process:
|
||||
logger.info(
|
||||
f"Epoch [{epoch}/{NUM_EPOCHS}] "
|
||||
f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
|
||||
f"Best Valid Loss: {best_valid_loss:.4f} | "
|
||||
f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
|
||||
f"Epoch Learning Rate: {epoch_lr:.6e}"
|
||||
)
|
||||
|
||||
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||
writer.add_scalar("Loss/valid", valid_loss, epoch)
|
||||
writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
|
||||
writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
|
||||
writer.add_scalar("Time/learning_rate", epoch_lr, epoch)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
if is_main_process:
|
||||
logger.info("Training interrupted by user")
|
||||
|
||||
finally:
|
||||
if is_main_process and model is not None:
|
||||
torch.save(model.module.state_dict(), run_dir / "last_model.pt")
|
||||
logger.info("Last model saved")
|
||||
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
logger.info("TensorBoard writer closed")
|
||||
|
||||
if is_main_process and logger is not None:
|
||||
logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
utils.cleanup_distributed()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -90,7 +90,7 @@ def main():
|
||||
INIT_WEIGHT_PATH = config["init_weight"]
|
||||
|
||||
# ========== 2 创建输出文件目录 ==========
|
||||
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_RIN")
|
||||
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_rin")
|
||||
run_dir = Path.cwd() / run_name
|
||||
run_dir.mkdir(parents=True, exist_ok=False)
|
||||
shutil.copy2(config_path, run_dir / config_path.name)
|
||||
|
||||
6
utils.py
6
utils.py
@ -1,4 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
import random
|
||||
import logging
|
||||
@ -60,7 +61,7 @@ def setup_distributed():
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
local_rank = dist.get_node_local_rank()
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
@ -74,7 +75,8 @@ def setup_distributed():
|
||||
|
||||
# 释放 DDP 并行
|
||||
def cleanup_distributed():
|
||||
dist.destroy_process_group()
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
# 命令行参数解析配置文件
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user