update project structure and training scripts

This commit is contained in:
kaiza_hikaru 2026-06-02 17:52:26 +08:00
parent 1b6fcf93d0
commit 27155ecfb5
12 changed files with 1739 additions and 1122 deletions

6
.gitignore vendored
View File

@ -2,6 +2,8 @@ __pycache__/
*.py[cod] *.py[cod]
.vscode/ .vscode/
old_proj/
20*_RIN*/ 20*/
20*_RIN_OLD*/
*.xlsx

119
PROJECT_CONVENTIONS.md Normal file
View File

@ -0,0 +1,119 @@
# SparseFocus 项目约定
## 项目理解
SparseFocus 面向自动对焦中的离焦预测问题。
基础任务是根据显微图像内容预测离焦程度。当前工作的核心判断是:内容稀疏视野会削弱离焦预测的稳定性。当视野中缺少足够有效结构时,直接从整图或任意局部区域回归离焦量,容易引入不可靠信号。
因此,项目目标中的主要流程应当是两阶段:
1. 先预测区域重要度,也就是判断每个局部区域是否包含有用信息。
2. 再使用重要区域作为有效证据进行离焦预测。
当前代码结构中:
- `RINet` 用于从输入图像预测 `9x9` 的区域重要度图。
- `DPNet` 用于预测离焦距离。
- `train_rin.py` 是当前较清晰的区域重要度训练入口。
- `train_dpn.py` 是当前离焦回归任务的训练入口。
- `old_datasets.py` 仍然承载了较多历史数据入口逻辑和本地路径假设。
- `datasets.py` 中已有较干净的 RIN 任务专用数据集和同步变换实现。
- TOML 配置文件用于将训练超参数从训练脚本中分离出来。
后续重构应当保留这个研究意图,同时让任务边界、数据契约和实验产物更加明确。
## 重构边界
`old_proj/` 是只读参考区域。
具体规则:
- `old_proj/` 下的代码、数据、笔记和产物只能读取参考。
- 重构过程中不得在 `old_proj/` 下创建、编辑、移动、重命名、格式化或删除任何文件。
- 所有实现变更都必须发生在 `old_proj/` 之外。
- 如果需要借鉴 `old_proj/` 中的历史逻辑,只能将思路手动迁移到主项目结构中,并在主项目内重新整理实现。
- 不要从 `old_proj/` 直接导入运行时代码;它不应成为当前工程依赖图的一部分。
这条边界用于保留旧快速迭代项目的历史上下文,同时避免旧结构和误操作继续泄漏到新的工程重构中。
## 代码风格约定
### 模块职责
- 模型、数据集、训练循环、工具函数、配置文件和实验输出应当分离。
- 当标签形状或返回值契约不同的时候,优先使用任务专用的数据集类。
- 除非共享契约已经明确,否则不要让同一个数据集辅助函数同时服务区域重要度分类和离焦回归。
- 路径发现和标签解析应放在数据层,不应放进模型逻辑或训练循环。
### 命名
- 任务名称应保持一致:
- `RIN` / `RINet` 表示区域重要度预测。
- `DP` / `DPNet` 表示离焦预测。
- 优先使用能表达含义的变量名,例如 `train_image_paths`、`valid_labels`、`run_dir`、`best_valid_loss`。
- 新代码中如果领域含义已经明确,应避免使用 `type`、`data`、`label`、`imag` 这类过于泛化的命名。
### 训练脚本
- 训练脚本应保持稳定顺序:
1. 读取配置。
2. 创建运行目录。
3. 初始化日志、TensorBoard writer、随机种子和设备。
4. 构建数据集和 dataloader。
5. 构建模型、损失函数、优化器和学习率调度器。
6. 执行训练和验证循环。
7. 保存最佳权重和最后权重。
- 单轮训练函数和单轮验证函数应与 `main()` 分离。
- 验证阶段应使用 `torch.no_grad()`。
- 日志中应记录配置路径、运行目录、数据集规模、优化器、调度器、每轮损失、最佳损失和权重保存事件。
- 每次实验运行都应将配置文件复制到对应运行目录中。
### 配置与实验
- 超参数应放在 TOML 配置文件中,不应硬编码在训练脚本中。
- 实验输出应写入带时间戳的运行目录,并保持被 Git 忽略。
- checkpoint、TensorBoard event 文件和训练日志属于运行时产物,不属于源码。
- 如果配置中的数据路径指向本机特定绝对路径,应先视为临时约定,直到引入更可移植的数据根目录机制。
### 模型代码
- 神经网络模块应保持小而可组合。
- 对有明确语义的模型组件,优先使用显式的 `nn.Module` 类。
- 当张量形状不明显时,应在模块边界附近记录形状预期,尤其是 `9x9` 区域重要度图和展平后的回归输出。
- 如果构造函数已经提供通道数参数,辅助模块内部应避免再次硬编码通道数。
### 数据契约
- 数据集 `__getitem__` 的返回值应稳定,并且与具体任务匹配。
- 图像张量在进入模型前应完成一致的归一化和形状整理。
- 除非模型明确设计为其他通道契约,否则读取图像时应显式转换为 RGB。
- 从文件名、Excel 或 JSON 中解析标签的逻辑,应与训练循环隔离,并适合单独测试。
- 数据加载函数中应避免隐藏依赖本机绝对路径。
### 编码与注释
- 新文件应使用 UTF-8 编码。
- 现有乱码中文注释应在聚焦重构时修复或删除,不应复制到新代码中。
- 绝大多数时候应使用 ASCII 字符,尤其是代码标识、短字符串、运行时输出、表格字段名、Sheet 名和指标枚举值。
- 较长的大段文档、代码注释可以并且应该使用中文描述,便于直接理解。
- 保存到结果文件中的字段名、Sheet 名、指标枚举值等应使用英文,例如 `image_path`、`sparsity_level`、`dense`。
- 注释应简短并解释真实背景,优先说明领域假设、张量形状、数据格式和不明显的训练选择。
- 不要用注释重复解释简单代码本身。
### 执行检查
- 代码运行、训练、测试和静态检查默认由用户执行。
- 除非用户明确要求,助手不主动运行代码检查、训练脚本或测试脚本。
### 导入与格式
- 导入顺序应按标准库、第三方库、本地模块分组。
- 新代码中优先使用 `pathlib.Path` 处理文件系统路径。
- 格式应兼容常见 Python 工具;窄范围改动时避免引入无关格式化变更。
### Git 卫生
- `old_proj/`、`.vscode/`、`__pycache__/` 和带时间戳的实验目录应保持被忽略。
- 不要提交生成的 checkpoint、TensorBoard 日志、训练日志或 Python 字节码。
- 重构时应保持改动范围清晰,并围绕有明确含义的工程节点进行提交。

3
config_test.toml Normal file
View File

@ -0,0 +1,3 @@
excel_file = "E:/Datasets/SparseFocusDataset/DenseSparse/testData_20241227.xlsx"
rin_weight = "F:/Projects/SparseFocus/2026_05_06_19_36_35_RIN/best_model.pt"
dpn_weight = "F:/Projects/SparseFocus/2026_04_18_18_37_07_dpn/best_dpn.pth"

View File

@ -15,6 +15,8 @@ from openpyxl import load_workbook
import json import json
import re import re
from pathlib import Path
# REVIEW: # REVIEW:
# 这个文件是项目的数据入口层,负责把 Excel、JSON 和图像文件组织成训练/测试阶段可消费的数据。 # 这个文件是项目的数据入口层,负责把 Excel、JSON 和图像文件组织成训练/测试阶段可消费的数据。
# 整体上它解决了“如何把作者本地数据组织接到模型上”的问题,但工程抽象较弱,路径与数据格式都强依赖作者环境。 # 整体上它解决了“如何把作者本地数据组织接到模型上”的问题,但工程抽象较弱,路径与数据格式都强依赖作者环境。
@ -139,16 +141,21 @@ def get_test_data_and_label(root_path: str, type: str):
# REVIEW: # REVIEW:
# 这里继续依赖固定数据目录结构,因此 test.py 的可运行性高度依赖作者原始数据目录。 # 这里继续依赖固定数据目录结构,因此 test.py 的可运行性高度依赖作者原始数据目录。
field_path = os.path.join( field_path = os.path.join(
r'E:\suqiang\DenseSparse\cropped', 'E:/Datasets/SparseFocusDataset/DenseSparse/cropped',
row[0], '*.jpg') row[0], '*.jpg')
image_path_list = glob.glob(field_path) image_path_list = glob.glob(field_path)
for image_patch in image_path_list: for image_patch in image_path_list:
labels = int(image_patch.split('\\')[-1].split('.')[0]) labels = int(image_patch.split('\\')[-1].split('.')[0])
if labels >= -25000 and labels <= 25000: if labels >= -25000 and labels <= 25000:
train_image_path_list.append(image_patch) train_image_path_list.append(image_patch.replace("\\", "/"))
# REVIEW:
# effective 这里统一填 0说明测试阶段 patch 是否可用不来自数据标注,而是后续 RINet 推理结果。 json_data_path = Path(image_patch.replace("\\", "/")).with_name("224_patch_effective.json").as_posix()
train_patch_effective_list.append(0) with open(json_data_path, 'r') as fs:
json_data = json.load(fs)
json_data = json.loads(json_data)
json_data = json_data['image_info_list']
train_patch_effective_list.append(json_data[0]['patch_effective'])
train_defocus_distance_list.append(labels / 1000) train_defocus_distance_list.append(labels / 1000)
num_i = num_i + 1 num_i = num_i + 1
print('num_i: ',num_i) print('num_i: ',num_i)

View File

@ -1,268 +0,0 @@
import os
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from model.DPNet import register_model
from model.RINet import MobileNetV3_small as con_classification
from dataset import datasets as dataset_F
import util.utils as util
from tqdm import tqdm
import openpyxl
# REVIEW:
# 该脚本是项目的测试/推理主入口,也是两阶段方法真正汇合的地方:
# 先用 RINet 对整图打 patch 有效性分数,再用 DPNet 对候选 patch 做 defocus distance 回归,
# 最后通过中位数聚合得到整图预测。
#
# REVIEW:
# 如果想快速理解作者的方法,这个文件是最关键的,因为它最完整地体现了算法在实际评估中的执行顺序。
# Set the visible GPU devices for CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Image patch size
size = 224
# Paths to test data Excel files
root_path_list = [r'E:\suqiang\DenseSparse\testData_20241227.xlsx']
# REVIEW:
# 测试集路径完全硬编码,说明当前脚本更像作者自己的实验工具而不是通用命令行程序。
# Define data sheet types and test categories (cell or tissue)
data_type_list = ['Sheet', 'Sheet1', 'Sheet1', 'Sheet1']
cell_or_tissue_list = ['tissue', 'tissue']
# Load the classification model and pre-trained weights
classification_model = con_classification()
state_load = torch.load('../weight/RINet_best_model.pt')
classification_model = nn.DataParallel(classification_model).to('cuda')
classification_model.load_state_dict(state_load, False)
classification_model.eval() # Set the model to evaluation mode
# REVIEW:
# RINet 权重和结构之间没有版本检查,只有运行时才能暴露兼容性问题。
# Loop for different models (e.g., DPNet)
for c in range(1):
if c == 0:
state_load = torch.load('../weight/DP_best_model.pt')
model = register_model()
names = 'ours'
model = nn.DataParallel(model).to('cuda')
model.load_state_dict(state_load, False)
model.eval() # Set the model to evaluation mode
# Iterate over test datasets
for j in range(len(root_path_list)):
root_path = root_path_list[j]
data_type = data_type_list[j]
cell_or_tissue = cell_or_tissue_list[j]
# Load test data and labels from the specified Excel sheet
(train_image_path_list, train_patch_effective_list,
train_defocus_distance_list) = dataset_F.get_test_data_and_label(root_path=root_path, type=f'{data_type}')
# Define image transformation pipeline
train_transform = transforms.Compose([
transforms.CenterCrop((2016, 2016)), # Crop the center region of the image
transforms.ToTensor(), # Convert the image to a tensor
])
# REVIEW:
# 这里把整图固定裁成 2016x2016并配合 size=224 形成 9x9 patch 网格。
# 因此测试流程对输入尺寸有很强假设。
# Create a custom dataset and DataLoader
train_data = dataset_F.MyDataset_test(train_image_path_list, train_defocus_distance_list,
train_patch_effective_list, train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, pin_memory=True, num_workers=0)
# Initialize an Excel workbook to save results
workbook = openpyxl.Workbook()
sheet = workbook.create_sheet(index=0)
n = 0 # Row index for Excel sheet
# Iterate through the test data
for i, (img, label, image_name, effective) in tqdm(enumerate(train_loader)):
image_path = image_name[0]
image_label = int(image_path.split('\\')[-1].split('.')[0]) # Extract ground truth label
# Add header row to the Excel sheet on the first iteration
if i == 0:
sheet.cell(1, 1).value = 'groundTruth'
sheet.cell(1, 2).value = 'all_prediction'
sheet.cell(1, 3).value = 'all_can_use_prediction'
sheet.cell(1, 4).value = '61_prediction'
sheet.cell(1, 5).value = '51_prediction'
sheet.cell(1, 6).value = '41_prediction'
sheet.cell(1, 7).value = '31_prediction'
sheet.cell(1, 8).value = '25_prediction'
sheet.cell(1, 9).value = '19_prediction'
sheet.cell(1, 10).value = '15_prediction'
sheet.cell(1, 11).value = '9_prediction'
sheet.cell(1, 12).value = '5_prediction'
sheet.cell(1, 13).value = '3_prediction'
sheet.cell(1, 14).value = '1_prediction'
sheet.cell(1, 15).value = 'path'
sheet.cell(1, 16).value = 'max_can_use_patch_num'
sheet.cell(1, 17).value = 'avg_prediction'
# Preprocess the image for the classification model
img2 = F.interpolate(img, size=(512, 512), mode='bilinear', align_corners=False)
img = torch.cuda.FloatTensor(np.array(img))
img2 = torch.cuda.FloatTensor(np.array(img2))
# REVIEW:
# 这里继续沿用了 Tensor -> numpy -> CUDA Tensor 的旧式转换写法。
if torch.cuda.is_available():
img = img.cuda()
img2 = img2.cuda()
# Perform classification on the preprocessed image
output_weight = classification_model(img2)
output_weight_userful = output_weight[torch.abs(output_weight) > 0.8] # Filter predictions above a threshold
can_use_num = len(output_weight_userful) # Count usable predictions
# REVIEW:
# patch 是否“可用”由阈值 0.8 决定,这是一个非常关键的经验超参数。
# Initialize usable patch numbers for different thresholds
can_use_num_61, can_use_num_51, can_use_num_41 = 61, 51, 41
can_use_num_31, can_use_num_25, can_use_num_19 = 31, 25, 19
can_use_num_15, can_use_num_9, can_use_num_5 = 15, 9, 5
can_use_num_3, can_use_num_1 = 3, 1
# Handle cases with no usable patches
if can_use_num == 0:
util.print_info(f'{image_name} don\'t have enough info')
continue
# REVIEW:
# 没有任何 patch 通过阈值时直接跳过样本,会导致结果文件中少样本,这是需要知晓的隐式行为。
# Adjust usable patch numbers based on conditions
if can_use_num % 2 == 0:
can_use_num -= 1
if 51 <= can_use_num < 61:
can_use_num_61 = can_use_num
elif 41 <= can_use_num < 51:
can_use_num_61 = can_use_num_51 = can_use_num
elif 31 <= can_use_num < 41:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num
elif 25 <= can_use_num < 31:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num
elif 19 <= can_use_num < 25:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num
elif 15 <= can_use_num < 19:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num
elif 9 <= can_use_num < 15:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num
elif 5 <= can_use_num < 9:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num
elif 3 <= can_use_num < 5:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num
elif can_use_num < 3:
can_use_num_61 = can_use_num_51 = can_use_num_41 = can_use_num_31 = can_use_num_25 = can_use_num_19 = can_use_num_15 = can_use_num_9 = can_use_num_5 = can_use_num_3 = 1
# Extract top predictions for each patch group
max_all_can_value, max_all_can_index = torch.topk(output_weight, can_use_num)
max_61_value, max_61_index = torch.topk(output_weight, can_use_num_61)
max_51_value, max_51_index = torch.topk(output_weight, can_use_num_51)
max_41_value, max_41_index = torch.topk(output_weight, can_use_num_41)
max_31_value, max_31_index = torch.topk(output_weight, can_use_num_31)
max_25_value, max_25_index = torch.topk(output_weight, can_use_num_25)
max_19_value, max_19_index = torch.topk(output_weight, can_use_num_19)
max_15_value, max_15_index = torch.topk(output_weight, can_use_num_15)
max_9_value, max_9_index = torch.topk(output_weight, can_use_num_9)
max_5_value, max_5_index = torch.topk(output_weight, can_use_num_5)
max_3_value, max_3_index = torch.topk(output_weight, can_use_num_3)
max_1_value, max_1_index = torch.topk(output_weight, can_use_num_1)
patch_image_list_all = []
patch_image_list_all_can_use = []
patch_image_list_61 = []
patch_image_list_51 = []
patch_image_list_41 = []
patch_image_list_31 = []
patch_image_list_25 = []
patch_image_list_19 = []
patch_image_list_15 = []
patch_image_list_9 = []
patch_image_list_5 = []
patch_image_list_3 = []
patch_image_list_1 = []
label_list = []
# REVIEW:
# 这里同时比较多套 top-k 聚合结果,说明作者不仅关心最终预测,也在研究“保留多少 patch 更合理”。
for k in range(0, 2016, size):
a = int(k / size)
for j in range(0, 2016, size):
b = int(j / size)
patch_image = img[:, :, k:k + size, j:j + size]
patch_image_list_all.append(patch_image)
if len(patch_image_list_all) != 0:
out_all = model(torch.cat(patch_image_list_all, dim=0))
# REVIEW:
# 81 个 patch 一次性送入 DPNet 是个不错的实现,避免了逐 patch 前向的重复开销。
for z in range(len(out_all)):
if torch.any(max_all_can_index.eq(z)):
patch_image_list_all_can_use.append(out_all[z].item())
if torch.any(max_61_index.eq(z)):
patch_image_list_61.append(out_all[z].item())
if torch.any(max_51_index.eq(z)):
patch_image_list_51.append(out_all[z].item())
if torch.any(max_41_index.eq(z)):
patch_image_list_41.append(out_all[z].item())
if torch.any(max_31_index.eq(z)):
patch_image_list_31.append(out_all[z].item())
if torch.any(max_25_index.eq(z)):
patch_image_list_25.append(out_all[z].item())
if torch.any(max_19_index.eq(z)):
patch_image_list_19.append(out_all[z].item())
if torch.any(max_15_index.eq(z)):
patch_image_list_15.append(out_all[z].item())
if torch.any(max_9_index.eq(z)):
patch_image_list_9.append(out_all[z].item())
if torch.any(max_5_index.eq(z)):
patch_image_list_5.append(out_all[z].item())
if torch.any(max_3_index.eq(z)):
patch_image_list_3.append(out_all[z].item())
if torch.any(max_1_index.eq(z)):
patch_image_list_1.append(out_all[z].item())
nums_all = float(np.median(out_all.tolist()))
avg_nums_all = float(np.mean(out_all.tolist()))
data = []
data.append(image_label)
data.append(float(np.median(out_all.tolist())) * 1000)
data.append(float(np.median(patch_image_list_all_can_use)) * 1000)
data.append(float(np.median(patch_image_list_61)) * 1000)
data.append(float(np.median(patch_image_list_51)) * 1000)
data.append(float(np.median(patch_image_list_41)) * 1000)
data.append(float(np.median(patch_image_list_31)) * 1000)
data.append(float(np.median(patch_image_list_25)) * 1000)
data.append(float(np.median(patch_image_list_19)) * 1000)
data.append(float(np.median(patch_image_list_15)) * 1000)
data.append(float(np.median(patch_image_list_9)) * 1000)
data.append(float(np.median(patch_image_list_5)) * 1000)
data.append(float(np.median(patch_image_list_3)) * 1000)
data.append(float(np.median(patch_image_list_1)) * 1000)
data.append((image_path.split('\\')[-3] + '\\' +image_path.split('\\')[-2] + '\\' +image_path.split('\\')[-1]))
data.append(len(output_weight_userful))
data.append(float(np.mean(out_all.tolist())) * 1000)
for (m, info) in enumerate(data):
sheet.cell(n + 2, m + 1).value = info
n = n + 1
util.save_val_image(1, img,
score=out_all, image_name=image_name,
type_str=cell_or_tissue, type='test', type_label=output_weight)
# REVIEW:
# 数值输出之外还保留了网格可视化,这对解释模型行为和排查异常样本很有帮助。
workbook.save(f'./{names}_to_{cell_or_tissue}_{data_type}_2.xlsx')
# REVIEW:
# 最终结果保存为 Excel非常贴合论文实验场景如果做批量评估结构化文本格式会更便于分析。

File diff suppressed because one or more lines are too long

120
stat_test_results.py Normal file
View File

@ -0,0 +1,120 @@
import argparse
from pathlib import Path
import pandas as pd
# Column names that check_columns() requires in the result table before any
# aggregation is attempted.
REQUIRED_COLUMNS = [
    "image_path",
    "roi_no",
    "importance_label",
    "importance_prediction",
    "defocus_label",
    "defocus_prediction",
]
def get_args():
    """Parse the command-line arguments of the statistics script.

    Returns:
        argparse.Namespace: Namespace whose ``result_file`` attribute is the
        path given on the command line, or ``"test_results.xlsx"`` when the
        positional argument is omitted.
    """
    cli = argparse.ArgumentParser(description="Aggregate SparseFocus test results")
    # nargs="?" makes the positional argument optional so the default applies.
    result_file_options = {
        "nargs": "?",
        "default": "test_results.xlsx",
        "help": "Path to test result xlsx file",
    }
    cli.add_argument("result_file", **result_file_options)
    namespace = cli.parse_args()
    return namespace
def check_columns(df):
    """Ensure the result table contains every column the aggregation reads.

    Args:
        df: Table exposing a ``columns`` collection (a pandas DataFrame in
            this script).

    Raises:
        ValueError: If any name in ``REQUIRED_COLUMNS`` is absent; the message
            lists the missing names in ``REQUIRED_COLUMNS`` order.
    """
    missing_columns = []
    for column in REQUIRED_COLUMNS:
        if column not in df.columns:
            missing_columns.append(column)
    if not missing_columns:
        return
    raise ValueError(f"Missing required columns: {missing_columns}")
def get_sparsity_level(usable_count):
    """Map a usable-ROI count to a sparsity category name.

    Args:
        usable_count: Number of usable ROIs counted for one image.

    Returns:
        str: ``"invalid"`` for exactly 0, ``"dense"`` above 41, ``"sparse"``
        above 9, and ``"extremely_sparse"`` for everything else.
    """
    if usable_count == 0:
        return "invalid"
    # Exclusive lower bounds, checked from the densest category downwards.
    for lower_bound, level in ((41, "dense"), (9, "sparse")):
        if usable_count > lower_bound:
            return level
    return "extremely_sparse"
def iter_image_groups(df):
    """Yield the result table in consecutive 81-row groups.

    Each group is a copy of 81 consecutive rows. The ``roi_no`` column of
    every group must read 0..80 in order once cast to int.

    Args:
        df: Result table with ``roi_no`` and ``image_path`` columns.

    Yields:
        pandas.DataFrame: An independent copy of one 81-row group.

    Raises:
        ValueError: If the row count is not a multiple of 81, or if a group's
            ``roi_no`` values are not exactly 0..80 in order.
    """
    row_count = len(df)
    if row_count % 81:
        raise ValueError(f"Row count should be divisible by 81, got {len(df)}.")

    expected_roi_no = list(range(81))
    for first_row in range(0, row_count, 81):
        group = df.iloc[first_row: first_row + 81].copy()
        if group["roi_no"].astype(int).tolist() != expected_roi_no:
            # Name the first image of the broken group to help locate it.
            image_path = group["image_path"].iloc[0]
            raise ValueError(f"ROI order mismatch near image: {image_path}")
        yield group
def aggregate_one_image(group):
    """Build the per-image statistics row from one image's ROI rows.

    Args:
        group: Rows of one image, with the columns ``image_path``,
            ``defocus_label``, ``importance_label``,
            ``importance_prediction`` and ``defocus_prediction``.

    Returns:
        dict: Keys in insertion order: ``image_path``, ``defocus_label``,
        ``usable_roi_count``, ``sparsity_level``, ``pred_importance_gt_0_8``,
        ``all_blocks_median``, then ``top_81_median`` down to
        ``top_1_median``. ``pred_importance_gt_0_8`` is ``pd.NA`` when no row
        has ``importance_prediction`` above 0.8.
    """
    defocus_predictions = group["defocus_prediction"]
    usable_count = int((group["importance_label"] > 0).sum())

    # Median over ROIs whose predicted importance exceeds the 0.8 threshold.
    confident_rows = group[group["importance_prediction"] > 0.8]
    if len(confident_rows) > 0:
        pred_importance_gt_0_8 = confident_rows["defocus_prediction"].median()
    else:
        pred_importance_gt_0_8 = pd.NA

    # mergesort is stable, so ties keep their original row order.
    ranked_predictions = group.sort_values(
        by="importance_prediction",
        ascending=False,
        kind="mergesort",
    )["defocus_prediction"]

    row = {
        "image_path": group["image_path"].iloc[0],
        "defocus_label": group["defocus_label"].iloc[0],
        "usable_roi_count": usable_count,
        "sparsity_level": get_sparsity_level(usable_count),
        "pred_importance_gt_0_8": pred_importance_gt_0_8,
        "all_blocks_median": defocus_predictions.median(),
    }
    row.update(
        {
            f"top_{k}_median": ranked_predictions.head(k).median()
            for k in range(81, 0, -1)
        }
    )
    return row
def aggregate_results(df):
    """Aggregate the per-ROI result table into one statistics row per image.

    Args:
        df: Per-ROI result table; validated with ``check_columns`` first.

    Returns:
        pandas.DataFrame: One row per group yielded by ``iter_image_groups``,
        in the same order.
    """
    check_columns(df)
    per_image_rows = []
    for group in iter_image_groups(df):
        per_image_rows.append(aggregate_one_image(group))
    return pd.DataFrame(per_image_rows)
def main():
    """Read the result workbook, aggregate it, and write the result to Sheet2.

    The first sheet of the workbook is read. The statistics are written back
    into the same file as ``Sheet2``, replacing that sheet if it exists.
    """
    result_path = Path(get_args().result_file)

    print(f"Reading test results: {result_path}")
    stat_df = aggregate_results(pd.read_excel(result_path, sheet_name=0))
    print(f"Image count: {len(stat_df)}")

    print("Writing Sheet2")
    # Append mode keeps the other sheets of the workbook.
    writer_options = {
        "engine": "openpyxl",
        "mode": "a",
        "if_sheet_exists": "replace",
    }
    with pd.ExcelWriter(result_path, **writer_options) as writer:
        stat_df.to_excel(writer, sheet_name="Sheet2", index=False)

    print(f"Saved statistics to Sheet2: {result_path}")


if __name__ == "__main__":
    main()

145
test.py Normal file
View File

@ -0,0 +1,145 @@
from pathlib import Path
import torch
import pandas as pd
from PIL import Image
from torchvision.transforms import functional as F
from tqdm import tqdm
import old_datasets as dataset_F
import utils
from models import DPNet, RINet
def strip_module_prefix(state_dict):
    """Drop one leading ``"module."`` from every key when all keys have it.

    Args:
        state_dict: Mapping from parameter names to values.

    Returns:
        The input object itself when at least one key lacks the prefix;
        otherwise a new ``dict`` whose keys have one leading ``"module."``
        removed.
    """
    prefix = "module."
    for key in state_dict:
        if not key.startswith(prefix):
            # Mixed or unprefixed keys: leave the mapping untouched.
            return state_dict
    return {key[len(prefix):]: value for key, value in state_dict.items()}
def load_model_weight(model, weight_path, device):
    """Load weights into ``model``, move it to ``device`` and set eval mode.

    Args:
        model: Module exposing ``load_state_dict``, ``to`` and ``eval``.
        weight_path: Checkpoint path passed to ``torch.load``.
        device: Target device passed to ``model.to``.

    Returns:
        The same ``model`` object after loading.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location="cpu")

    # Some checkpoints nest the weights under a "state_dict" entry.
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    weights = checkpoint["state_dict"] if is_wrapped else checkpoint

    model.load_state_dict(strip_module_prefix(weights), strict=True)
    model.to(device)
    model.eval()
    return model
def build_patch_tensor(cropped_image):
    """Slice an image into a 9x9 grid of 224-pixel patches and stack them.

    Args:
        cropped_image: Image accepted by ``F.to_tensor``.
            NOTE(review): presumably at least 2016x2016 so that every slice
            is a full 224x224 patch — confirm against callers.

    Returns:
        torch.Tensor: The 81 patches stacked along a new leading dimension,
        in row-major order (all columns of the first grid row come first).
    """
    image_tensor = F.to_tensor(cropped_image)
    offsets = [index * 224 for index in range(9)]
    tiles = [
        image_tensor[:, top: top + 224, left: left + 224]
        for top in offsets
        for left in offsets
    ]
    return torch.stack(tiles, dim=0)
def predict_one_image(image_path, importance_label, defocus_label, rin_model, dpn_model, device):
    """Run both models on one image and return one result row per ROI.

    The image is converted to RGB and center-cropped to 2016x2016. The RIN
    model receives the crop resized to 512x512; the DPN model receives the 81
    patches built by ``build_patch_tensor``.

    Args:
        image_path: Path of the image to open.
        importance_label: Sequence of 81 per-ROI importance labels.
        defocus_label: Defocus label of the image, convertible to ``float``.
        rin_model: Model called on the resized image batch.
        dpn_model: Model called on the patch batch.
        device: Device the input tensors are moved to.

    Returns:
        list[list]: 81 rows of ``[image_path, roi_no, importance_label,
        importance_prediction, defocus_label, defocus_prediction]``; both
        defocus values are multiplied by 1000.0.

    Raises:
        RuntimeError: If the label sequence or either flattened model output
            does not contain exactly 81 entries.
    """
    if len(importance_label) != 81:
        raise RuntimeError(f"Importance label count should be 81, got {len(importance_label)}.")

    cropped_image = F.center_crop(Image.open(image_path).convert("RGB"), [2016, 2016])
    rin_tensor = F.to_tensor(F.resize(cropped_image, [512, 512])).unsqueeze(0).to(device)
    patch_tensor = build_patch_tensor(cropped_image).to(device)

    with torch.no_grad():
        importance_predictions = rin_model(rin_tensor).reshape(-1).detach().cpu()
        defocus_predictions = dpn_model(patch_tensor).reshape(-1).detach().cpu()

    if len(importance_predictions) != 81:
        raise RuntimeError(f"RIN output count should be 81, got {len(importance_predictions)}.")
    if len(defocus_predictions) != 81:
        raise RuntimeError(f"DPN output count should be 81, got {len(defocus_predictions)}.")

    # Values shared by every row of this image are computed once.
    path_text = str(image_path)
    scaled_defocus_label = float(defocus_label) * 1000.0
    return [
        [
            path_text,
            roi_no,
            importance_label[roi_no],
            float(importance_predictions[roi_no]),
            scaled_defocus_label,
            float(defocus_predictions[roi_no]) * 1000.0,
        ]
        for roi_no in range(81)
    ]
def save_results(rows, output_path):
    """Write the per-ROI result rows to an Excel file without the index.

    Args:
        rows: Sequence of six-element rows matching the column order below.
        output_path: Destination passed to ``DataFrame.to_excel``.
    """
    column_names = [
        "image_path",
        "roi_no",
        "importance_label",
        "importance_prediction",
        "defocus_label",
        "defocus_prediction",
    ]
    pd.DataFrame(rows, columns=column_names).to_excel(output_path, index=False)
def main():
    """Run the two-model test pass and save per-ROI results to an xlsx file.

    Reads ``excel_file``, ``rin_weight`` and ``dpn_weight`` from the config
    returned by ``utils.get_hyperparams``, predicts every test image, and
    writes all rows to ``test_results.xlsx`` in the current working directory.
    """
    config, config_path = utils.get_hyperparams()
    excel_file = config["excel_file"]
    rin_weight = config["rin_weight"]
    dpn_weight = config["dpn_weight"]

    # Fall back to the CPU when CUDA is unavailable.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    output_path = Path.cwd() / "test_results.xlsx"
    print(f"Config path: {config_path}")
    print(f"Using device: {device}")

    print("Loading models")
    rin_model = load_model_weight(RINet(), rin_weight, device)
    dpn_model = load_model_weight(DPNet(), dpn_weight, device)

    print("Loading test data")
    image_paths, importance_labels, defocus_labels = dataset_F.get_test_data_and_label(
        excel_file,
        "Sheet",
    )
    print(f"Test image count: {len(image_paths)}")

    samples = zip(image_paths, importance_labels, defocus_labels)
    progress = tqdm(
        samples,
        total=len(image_paths),
        desc="Test",
        bar_format="{l_bar}{bar:20}{r_bar}",
    )
    all_rows = []
    for image_path, importance_label, defocus_label in progress:
        image_rows = predict_one_image(
            image_path,
            importance_label,
            defocus_label,
            rin_model,
            dpn_model,
            device,
        )
        all_rows += image_rows

    save_results(all_rows, output_path)
    print(f"Saved test results: {output_path}")
    print(f"Saved row count: {len(all_rows)}")


if __name__ == "__main__":
    main()

View File

@ -1,307 +1,226 @@
import argparse
import logging
import shutil import shutil
import time import time
from tqdm import tqdm
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from pprint import pformat
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torch.optim import Adam from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import tomllib from tqdm import tqdm
from models import DPNet
import old_datasets as dataset_F import old_datasets as dataset_F
import utils import utils
from models import DPNet
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, writer, global_step): # 训练一轮
def train_epoch(model, loader, criterion, optimizer, device):
model.train() model.train()
running_loss = 0.0 running_loss = 0.0
sample_count = 0 total_samples = 0
for img, label, image_name in tqdm(train_loader, desc="Train", bar_format="{l_bar}{bar:20}{r_bar}"): for images, labels, image_names in tqdm(
img = img.to(device, non_blocking=True) loader,
label = label.to(device, non_blocking=True).float().view(-1) desc="Train:",
bar_format="{l_bar}{bar:20}{r_bar}",
leave=False,
):
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True).float().view(-1)
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
output = model(img).view(-1) outputs = model(images).view(-1)
loss = criterion(output, label) loss = criterion(outputs, labels)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
scheduler.step()
global_step += 1 this_batch_size = images.size(0)
lr = optimizer.param_groups[0]["lr"] running_loss += loss.item() * this_batch_size
writer.add_scalar("lr", lr, global_step) total_samples += this_batch_size
batch_size = img.size(0) epoch_loss = running_loss / total_samples
running_loss += loss.item() * batch_size
sample_count += batch_size
epoch_loss = running_loss / sample_count
return epoch_loss, global_step
@torch.no_grad()
def validate_epoch(model, val_loader, criterion, device):
model.eval()
running_loss = 0.0
sample_count = 0
for img, label, image_name in tqdm(val_loader, desc="Validate", bar_format="{l_bar}{bar:20}{r_bar}"):
img = img.to(device, non_blocking=True)
label = label.to(device, non_blocking=True).float().view(-1)
output = model(img).view(-1)
loss = criterion(output, label)
batch_size = img.size(0)
running_loss += loss.item() * batch_size
sample_count += batch_size
epoch_loss = running_loss / sample_count
return epoch_loss return epoch_loss
# 验证一轮
@torch.no_grad()
def valid_epoch(model, loader, criterion, device):
model.eval()
running_loss = 0.0
total_samples = 0
for images, labels, image_names in tqdm(
loader,
desc="Valid:",
bar_format="{l_bar}{bar:20}{r_bar}",
leave=False,
):
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True).float().view(-1)
outputs = model(images).view(-1)
loss = criterion(outputs, labels)
this_batch_size = images.size(0)
running_loss += loss.item() * this_batch_size
total_samples += this_batch_size
epoch_loss = running_loss / total_samples
return epoch_loss
# 主训练函数
def main(): def main():
# ========================= # ========== 1 配置文件与超参数 ==========
# 1. 读取配置 config, config_path = utils.get_hyperparams()
# =========================
parser = argparse.ArgumentParser(description="Train DPNet")
parser.add_argument("--config", type=str, required=True, help="Path to TOML config file")
args = parser.parse_args()
config_path = Path(args.config) XLSX_FILES = config["xlsx_files"]
with config_path.open("rb") as f: BATCH_SIZE = config["batch_size"]
config = tomllib.load(f) NUM_WORKERS = config["num_workers"]
LEARNING_RATE = config["learning_rate"]
NUM_EPOCHS = config["epochs"]
SEED = config["seed"]
INIT_WEIGHT_PATH = config["init_weight"]
xlsx_files = config["xlsx_files"] # ========== 2 创建输出文件目录 ==========
batch_size = config["batch_size"]
learning_rate = config["learning_rate"]
epochs = config["epochs"]
num_workers = config["num_workers"]
seed = config["seed"]
init_weight = config["init_weight"]
# =========================
# 2. 创建输出目录
# =========================
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn") run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn")
run_dir = Path.cwd() / run_name run_dir = Path.cwd() / run_name
run_dir.mkdir(parents=True, exist_ok=False) run_dir.mkdir(parents=True, exist_ok=False)
shutil.copy2(config_path, run_dir / config_path.name) shutil.copy2(config_path, run_dir / config_path.name)
# ========================= # ========== 3 日志、TensorBoard、随机种子与设备 ==========
# 3. 初始化日志与 TensorBoard logger = utils.get_logger(__name__, run_dir / "train.log")
# ========================= writer = SummaryWriter(str(run_dir / "run"))
logger = logging.getLogger("dpnet_train") utils.set_seeds(SEED)
logger.setLevel(logging.INFO) device = torch.device("cuda:0")
logger.propagate = False
logger.handlers.clear()
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_handler = logging.FileHandler(run_dir / "train.log", encoding="utf-8")
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
writer = SummaryWriter(log_dir=str(run_dir))
logger.info(f"Run directory: {run_dir}")
logger.info(f"Config path: {config_path}") logger.info(f"Config path: {config_path}")
logger.info(f"Loaded config:\n{pformat(config, sort_dicts=False)}") logger.info(f"Loaded config: {str(config)}")
logger.info(f"Run directory: {run_dir}")
# =========================
# 4. 设置随机种子与设备
# =========================
utils.set_seeds(seed=seed)
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
logger.info(f"Using device: {device}") logger.info(f"Using device: {device}")
# ========================= # ========== 4 数据与 loader ==========
# 5. 准备数据集与 DataLoader (
# ========================= train_image_path_list,
train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list = \ train_defocus_distance_list,
dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files) valid_image_path_list,
valid_defocus_distance_list,
) = dataset_F.get_DPNet_train_data_and_label(root_path_list=XLSX_FILES)
train_transform = transforms.Compose([ train_transform = transforms.Compose(
transforms.ColorJitter( [
brightness=(0.9, 1.4), transforms.ColorJitter(
contrast=(0.8, 1.5), brightness=(0.9, 1.4),
saturation=(0.8, 1.5), contrast=(0.8, 1.5),
), saturation=(0.8, 1.5),
transforms.ToTensor(), ),
]) transforms.ToTensor(),
]
)
valid_transform = transforms.Compose(
[
transforms.ToTensor(),
]
)
val_transform = transforms.Compose([ train_set = dataset_F.MyDataset(
transforms.ToTensor(),
])
train_dataset = dataset_F.MyDataset(
train_image_path_list, train_image_path_list,
train_defocus_distance_list, train_defocus_distance_list,
train_transform, train_transform,
) )
val_dataset = dataset_F.MyDataset( valid_set = dataset_F.MyDataset(
val_image_path_list, valid_image_path_list,
val_defocus_distance_list, valid_defocus_distance_list,
val_transform, valid_transform,
) )
train_loader = DataLoader( train_loader = DataLoader(
dataset=train_dataset, dataset=train_set,
batch_size=batch_size, batch_size=BATCH_SIZE,
shuffle=True, shuffle=True,
num_workers=num_workers, num_workers=NUM_WORKERS,
pin_memory=(device.type == "cuda"), pin_memory=True,
persistent_workers=(num_workers > 0), persistent_workers=(NUM_WORKERS > 0),
) )
valid_loader = DataLoader(
val_loader = DataLoader( dataset=valid_set,
dataset=val_dataset, batch_size=BATCH_SIZE,
batch_size=batch_size,
shuffle=False, shuffle=False,
num_workers=num_workers, num_workers=NUM_WORKERS,
pin_memory=(device.type == "cuda"), pin_memory=True,
persistent_workers=(num_workers > 0), persistent_workers=(NUM_WORKERS > 0),
) )
logger.info(f"Train dataset size: {len(train_dataset)}") logger.info(f"Train dataset size: {len(train_set)}")
logger.info(f"Val dataset size: {len(val_dataset)}") logger.info(f"Val dataset size: {len(valid_set)}")
logger.info(f"Train steps per epoch: {len(train_loader)}") logger.info(f"Train steps per epoch: {len(train_loader)}")
# ========================= # ========== 5 模型、损失、优化器、调度器 ==========
# 6. 准备模型
# =========================
model = DPNet().to(device) model = DPNet().to(device)
if INIT_WEIGHT_PATH:
if init_weight: state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
state_dict = torch.load(init_weight, map_location="cpu")
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
logger.info(f"Loaded init weight from: {init_weight}") logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
else: else:
logger.info("Training from scratch") logger.info("Training from scratch")
# =========================
# 7. 准备损失函数、优化器、调度器
# =========================
criterion = nn.MSELoss(reduction="mean") criterion = nn.MSELoss(reduction="mean")
optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999)) scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)
total_steps = epochs * len(train_loader)
warmup_steps = min(2000, total_steps - 1) if total_steps > 1 else 0
scheduler = SequentialLR(
optimizer,
schedulers=[
LinearLR(
optimizer,
start_factor=0.1,
end_factor=1.0,
total_iters=warmup_steps,
),
CosineAnnealingLR(
optimizer,
T_max=max(1, total_steps - warmup_steps),
eta_min=0.0,
),
],
milestones=[warmup_steps],
)
logger.info("Loss: MSELoss(reduction='mean')") logger.info("Loss: MSELoss(reduction='mean')")
logger.info(f"Optimizer: Adam(lr={learning_rate}, betas=(0.9, 0.999))") logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
logger.info("Scheduler: step-based warmup + cosine annealing") logger.info("Scheduler: epoch-based warmup + cosine annealing")
logger.info(f"Total training steps: {total_steps}")
logger.info(f"Warmup steps: {warmup_steps}")
# ========================= # ========== 6 开始训练 ==========
# 8. 开始训练循环 logger.info("START TRAINING")
# ========================= best_valid_loss = float("inf")
best_val_loss = float("inf")
global_step = 0
start_time = time.time()
for epoch in range(1, epochs + 1): try:
epoch_start_time = time.time() for epoch in range(1, NUM_EPOCHS + 1):
epoch_start_time = time.time()
train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
valid_loss = valid_epoch(model, valid_loader, criterion, device)
epoch_lr = optimizer.param_groups[0]["lr"]
scheduler.step()
epoch_time_cost = time.time() - epoch_start_time
train_loss, global_step = train_epoch( if valid_loss < best_valid_loss:
model=model, best_valid_loss = valid_loss
train_loader=train_loader, torch.save(model.state_dict(), run_dir / "best_model.pt")
criterion=criterion, logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")
optimizer=optimizer,
scheduler=scheduler,
device=device,
writer=writer,
global_step=global_step,
)
val_loss = validate_epoch( logger.info(
model=model, f"Epoch [{epoch}/{NUM_EPOCHS}] "
val_loader=val_loader, f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
criterion=criterion, f"Best Valid Loss: {best_valid_loss:.4f} | "
device=device, f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
) f"Epoch Learning Rate: {epoch_lr:.6e}"
)
epoch_total_time = time.time() - epoch_start_time writer.add_scalar("Loss/train", train_loss, epoch)
writer.add_scalar("Loss/valid", valid_loss, epoch)
writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
writer.add_scalar("Time/learning_rate", epoch_lr, epoch)
writer.add_scalar("train_loss", train_loss, epoch) except KeyboardInterrupt:
writer.add_scalar("val_loss", val_loss, epoch) logger.info("Training interrupted by user")
writer.add_scalar("epoch_total_time", epoch_total_time, epoch)
logger.info( finally:
f"Epoch [{epoch}/{epochs}] | " torch.save(model.state_dict(), run_dir / "last_model.pt")
f"train_loss={train_loss:.8f} | " logger.info("Last model saved")
f"val_loss={val_loss:.8f} | "
f"epoch_total_time={epoch_total_time:.2f}s | "
f"global_step={global_step}"
)
torch.save(model.state_dict(), run_dir / "last_dpn.pth") writer.close()
logger.info("TensorBoard writer closed")
if val_loss < best_val_loss: logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")
best_val_loss = val_loss
torch.save(model.state_dict(), run_dir / "best_dpn.pth")
logger.info(f"Best model updated, best_val_loss={best_val_loss:.8f}")
# =========================
# 9. 收尾
# =========================
total_time = time.time() - start_time
logger.info("Training finished")
logger.info(f"Best validation loss: {best_val_loss:.8f}")
logger.info(f"Total time: {total_time:.2f} seconds")
writer.close()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

312
train_dpn_ddp.py Normal file
View File

@ -0,0 +1,312 @@
import shutil
import time
from datetime import datetime
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import old_datasets as dataset_F
import utils
from models import DPNet
def reduce_epoch_loss(running_loss, total_samples, device):
    """Return the per-sample mean loss aggregated over every process.

    Args:
        running_loss: This process's sum of per-batch losses, each already
            multiplied by its batch size.
        total_samples: Number of samples this process contributed.
        device: Device on which the temporary reduction tensor is created.

    Returns:
        A Python float: the summed loss of all processes divided by the
        summed sample count of all processes.
    """
    # Pack both quantities into one float64 tensor so that a single
    # all-reduce call sums them together.
    totals = torch.tensor(
        [running_loss, total_samples],
        dtype=torch.float64,
        device=device,
    )
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)

    global_loss_sum, global_sample_count = totals[0], totals[1]
    # Divide as tensors (not Python floats) so a zero sample count yields
    # NaN instead of raising ZeroDivisionError.
    mean_loss = global_loss_sum / global_sample_count
    return mean_loss.item()
# Train for one epoch
def train_epoch(model, loader, criterion, optimizer, device, is_main_process):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Network to train; it is switched to training mode here.
        loader: Iterable yielding ``(images, labels, image_names)`` batches.
        criterion: Loss callable applied to flattened outputs and labels.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        is_main_process: When False the tqdm progress bar is disabled.

    Returns:
        The batch-size-weighted mean loss, reduced across all processes by
        ``reduce_epoch_loss``.
    """
    model.train()

    loss_sum = 0.0
    sample_count = 0

    progress = tqdm(
        loader,
        desc="Train:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
        disable=not is_main_process,
    )
    for batch in progress:
        # The third element (image names) is not needed for training.
        images, labels, _image_names = batch

        inputs = images.to(device, non_blocking=True)
        # Flatten targets and predictions to 1-D so their shapes match.
        targets = labels.to(device, non_blocking=True).float().view(-1)

        optimizer.zero_grad(set_to_none=True)
        predictions = model(inputs).view(-1)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch value is a
        # per-sample average even when the last batch is smaller.
        batch_len = inputs.size(0)
        loss_sum += batch_loss.item() * batch_len
        sample_count += batch_len

    return reduce_epoch_loss(loss_sum, sample_count, device)
# Validate for one epoch
@torch.no_grad()
def valid_epoch(model, loader, criterion, device, is_main_process):
    """Evaluate ``model`` over ``loader`` without gradients; return mean loss.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable yielding ``(images, labels, image_names)`` batches.
        criterion: Loss callable applied to flattened outputs and labels.
        device: Device the batch tensors are moved to.
        is_main_process: When False the tqdm progress bar is disabled.

    Returns:
        The batch-size-weighted mean loss, reduced across all processes by
        ``reduce_epoch_loss``.
    """
    model.eval()

    loss_sum = 0.0
    sample_count = 0

    progress = tqdm(
        loader,
        desc="Valid:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
        disable=not is_main_process,
    )
    for batch in progress:
        # The third element (image names) is not needed for validation.
        images, labels, _image_names = batch

        inputs = images.to(device, non_blocking=True)
        # Flatten targets and predictions to 1-D so their shapes match.
        targets = labels.to(device, non_blocking=True).float().view(-1)

        predictions = model(inputs).view(-1)
        batch_loss = criterion(predictions, targets)

        # Weight by batch size so the epoch value is a per-sample average.
        batch_len = inputs.size(0)
        loss_sum += batch_loss.item() * batch_len
        sample_count += batch_len

    return reduce_epoch_loss(loss_sum, sample_count, device)
def create_run_dir(config_path, is_main_process):
    """Create the timestamped run directory once and share it with all ranks.

    Only the main process creates the directory and copies the config file
    into it; the resulting path is then broadcast from rank 0 so every
    process returns the same location.

    Args:
        config_path: ``Path`` of the config file to copy into the run dir.
        is_main_process: True on the process that should create the dir.

    Returns:
        ``Path`` of the run directory (identical on every process).

    Raises:
        FileExistsError: On the main process, if the directory already
            exists (``exist_ok=False``).
    """
    # Single-slot list: filled on the main process, overwritten on the
    # others by the broadcast below.
    payload = [None]

    if is_main_process:
        timestamped_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn_ddp")
        new_dir = Path.cwd() / timestamped_name
        new_dir.mkdir(parents=True, exist_ok=False)
        shutil.copy2(config_path, new_dir / config_path.name)
        payload[0] = str(new_dir)

    dist.broadcast_object_list(payload, src=0)
    # Hold every rank here until all of them have received the path.
    dist.barrier()

    return Path(payload[0])
# Main training routine
def main():
    """Train DPNet with DistributedDataParallel, one process per device.

    Steps: read the config, create a shared run directory, set up logging
    and TensorBoard on the main process, build datasets/samplers/loaders,
    build the model/loss/optimizer/scheduler, run the train/validate loop,
    and finally save ``last_model.pt`` and release the process group.
    ``best_model.pt`` is written whenever the reduced validation loss
    improves. Only the main process logs and writes files.
    """
    local_rank, rank, world_size, device, is_main_process = utils.setup_distributed()

    logger = None
    writer = None
    model = None
    best_valid_loss = float("inf")
    # Becomes True only when the training section is left without an
    # exception propagating; used to decide whether the final barrier is
    # safe to enter.
    exited_cleanly = False

    try:
        # ========== 1. Config file and hyperparameters ==========
        config, config_path = utils.get_hyperparams()

        XLSX_FILES = config["xlsx_files"]
        BATCH_SIZE = config["batch_size"]
        NUM_WORKERS = config["num_workers"]
        LEARNING_RATE = config["learning_rate"]
        NUM_EPOCHS = config["epochs"]
        SEED = config["seed"]
        INIT_WEIGHT_PATH = config["init_weight"]

        # ========== 2. Create the output directory ==========
        run_dir = create_run_dir(config_path, is_main_process)

        # ========== 3. Logging, TensorBoard, random seed and device ==========
        if is_main_process:
            logger = utils.get_logger(__name__, run_dir / "train.log")
            writer = SummaryWriter(str(run_dir / "run"))

        utils.set_seeds(SEED)

        if is_main_process:
            logger.info(f"Config path: {config_path}")
            logger.info(f"Loaded config: {str(config)}")
            logger.info(f"Run directory: {run_dir}")
            logger.info(f"Using world size: {world_size}")
            logger.info(f"Using device: {device}")

        # ========== 4. Data and loaders ==========
        (
            train_image_path_list,
            train_defocus_distance_list,
            valid_image_path_list,
            valid_defocus_distance_list,
        ) = dataset_F.get_DPNet_train_data_and_label(root_path_list=XLSX_FILES)

        train_transform = transforms.Compose(
            [
                transforms.ColorJitter(
                    brightness=(0.9, 1.4),
                    contrast=(0.8, 1.5),
                    saturation=(0.8, 1.5),
                ),
                transforms.ToTensor(),
            ]
        )
        valid_transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

        train_set = dataset_F.MyDataset(
            train_image_path_list,
            train_defocus_distance_list,
            train_transform,
        )
        valid_set = dataset_F.MyDataset(
            valid_image_path_list,
            valid_defocus_distance_list,
            valid_transform,
        )

        # Each process sees its own shard of the data.
        # NOTE(review): with the default drop_last=False, DistributedSampler
        # may repeat a few samples to even out the shards, so the reduced
        # validation loss can count some samples twice - confirm this is
        # acceptable.
        train_sampler = DistributedSampler(
            train_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
        )
        valid_sampler = DistributedSampler(
            valid_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
        )

        # shuffle stays False on the loaders: ordering is delegated to the
        # samplers.
        train_loader = DataLoader(
            dataset=train_set,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=train_sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            persistent_workers=(NUM_WORKERS > 0),
        )
        valid_loader = DataLoader(
            dataset=valid_set,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=valid_sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            persistent_workers=(NUM_WORKERS > 0),
        )

        if is_main_process:
            logger.info(f"Train dataset size: {len(train_set)}")
            logger.info(f"Val dataset size: {len(valid_set)}")
            logger.info(f"Train steps per epoch per process: {len(train_loader)}")

        # ========== 5. Model, loss, optimizer, scheduler ==========
        model = DPNet().to(device)

        if INIT_WEIGHT_PATH:
            state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
            model.load_state_dict(state_dict, strict=True)
            if is_main_process:
                logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
        elif is_main_process:
            logger.info("Training from scratch")

        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )

        criterion = nn.MSELoss(reduction="mean")
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
        scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)

        if is_main_process:
            logger.info("Loss: MSELoss(reduction='mean')")
            logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
            logger.info("Scheduler: epoch-based warmup + cosine annealing")

        # ========== 6. Training loop ==========
        if is_main_process:
            logger.info("START TRAINING")

        try:
            for epoch in range(1, NUM_EPOCHS + 1):
                # Re-seed the sampler so each epoch gets a different shuffle.
                train_sampler.set_epoch(epoch)
                epoch_start_time = time.time()

                train_loss = train_epoch(
                    model,
                    train_loader,
                    criterion,
                    optimizer,
                    device,
                    is_main_process,
                )
                valid_loss = valid_epoch(
                    model,
                    valid_loader,
                    criterion,
                    device,
                    is_main_process,
                )

                # Read the learning rate used during this epoch before the
                # scheduler advances it.
                epoch_lr = optimizer.param_groups[0]["lr"]
                scheduler.step()
                epoch_time_cost = time.time() - epoch_start_time

                if is_main_process and valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    # Save the wrapped module's weights so the checkpoint
                    # keys carry no DDP "module." prefix.
                    torch.save(model.module.state_dict(), run_dir / "best_model.pt")
                    logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")

                if is_main_process:
                    logger.info(
                        f"Epoch [{epoch}/{NUM_EPOCHS}] "
                        f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
                        f"Best Valid Loss: {best_valid_loss:.4f} | "
                        f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
                        f"Epoch Learning Rate: {epoch_lr:.6e}"
                    )
                    writer.add_scalar("Loss/train", train_loss, epoch)
                    writer.add_scalar("Loss/valid", valid_loss, epoch)
                    writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
                    writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
                    writer.add_scalar("Time/learning_rate", epoch_lr, epoch)

        except KeyboardInterrupt:
            if is_main_process:
                logger.info("Training interrupted by user")

        exited_cleanly = True

    finally:
        if is_main_process and model is not None:
            # A failure between model construction and the DDP wrap (e.g.
            # while loading the init weights) leaves a bare DPNet here;
            # unwrap only when the model is actually wrapped so this cleanup
            # cannot raise AttributeError and hide the original error.
            model_to_save = (
                model.module if isinstance(model, DistributedDataParallel) else model
            )
            torch.save(model_to_save.state_dict(), run_dir / "last_model.pt")
            logger.info("Last model saved")

        if writer is not None:
            writer.close()
            logger.info("TensorBoard writer closed")

        if is_main_process and logger is not None:
            logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")

        # Synchronise only on a clean exit. If this rank is leaving because
        # of an exception, the other ranks are not heading for this barrier,
        # so entering it would stall the failing rank instead of letting it
        # exit and the launcher tear the job down.
        if exited_cleanly and dist.is_initialized():
            dist.barrier()

        utils.cleanup_distributed()
# Script entry point: start DDP training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

View File

@ -90,7 +90,7 @@ def main():
INIT_WEIGHT_PATH = config["init_weight"] INIT_WEIGHT_PATH = config["init_weight"]
# ========== 2 创建输出文件目录 ========== # ========== 2 创建输出文件目录 ==========
run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_RIN") run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_rin")
run_dir = Path.cwd() / run_name run_dir = Path.cwd() / run_name
run_dir.mkdir(parents=True, exist_ok=False) run_dir.mkdir(parents=True, exist_ok=False)
shutil.copy2(config_path, run_dir / config_path.name) shutil.copy2(config_path, run_dir / config_path.name)

View File

@ -1,4 +1,5 @@
import math import math
import os
import torch import torch
import random import random
import logging import logging
@ -60,7 +61,7 @@ def setup_distributed():
dist.init_process_group(backend="nccl") dist.init_process_group(backend="nccl")
local_rank = dist.get_node_local_rank() local_rank = int(os.environ["LOCAL_RANK"])
rank = dist.get_rank() rank = dist.get_rank()
world_size = dist.get_world_size() world_size = dist.get_world_size()
@ -74,7 +75,8 @@ def setup_distributed():
# 释放 DDP 并行 # 释放 DDP 并行
def cleanup_distributed(): def cleanup_distributed():
dist.destroy_process_group() if dist.is_initialized():
dist.destroy_process_group()
# 命令行参数解析配置文件 # 命令行参数解析配置文件