first commit
This commit is contained in:
commit
382418a1d3
85
.gitignore
vendored
Normal file
85
.gitignore
vendored
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
# Python 缓存
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# 虚拟环境和本地环境文件
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.env
|
||||||
|
|
||||||
|
# Conda 本地环境目录
|
||||||
|
.conda/
|
||||||
|
conda-meta/
|
||||||
|
|
||||||
|
# IDE / 编辑器
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# 系统文件
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
desktop.ini
|
||||||
|
|
||||||
|
# 数据集和本地原始数据
|
||||||
|
data/
|
||||||
|
datasets/
|
||||||
|
raw_data/
|
||||||
|
|
||||||
|
# 训练、验证、测试输出
|
||||||
|
runs/
|
||||||
|
outputs/
|
||||||
|
output/
|
||||||
|
results/
|
||||||
|
logs/
|
||||||
|
checkpoints/
|
||||||
|
weights/
|
||||||
|
predictions/
|
||||||
|
reports/
|
||||||
|
figures/
|
||||||
|
|
||||||
|
# 模型权重和导出产物
|
||||||
|
*.pt
|
||||||
|
*.pth
|
||||||
|
*.ckpt
|
||||||
|
*.safetensors
|
||||||
|
*.onnx
|
||||||
|
*.engine
|
||||||
|
*.trt
|
||||||
|
*.torchscript
|
||||||
|
|
||||||
|
# 量化和性能测试产物
|
||||||
|
quantized/
|
||||||
|
calibration_cache/
|
||||||
|
benchmark/
|
||||||
|
profile/
|
||||||
|
*.profile
|
||||||
|
*.prof
|
||||||
|
|
||||||
|
# 数组、表格和中间结果
|
||||||
|
*.npy
|
||||||
|
*.npz
|
||||||
|
*.pkl
|
||||||
|
*.pickle
|
||||||
|
*.csv.tmp
|
||||||
|
*.json.tmp
|
||||||
|
|
||||||
|
# Jupyter
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
|
||||||
|
# 日志和临时文件
|
||||||
|
*.log
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
|
||||||
236
AGENTS.md
Normal file
236
AGENTS.md
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
## 项目目标
|
||||||
|
|
||||||
|
本项目是一个用于实践深度学习模型量化技术的小型 PyTorch 工程。背景任务是:根据离焦图像预测离焦距离。任务形式是回归问题,核心模型使用 `timm` 库提供的 MobileNetV4,并将 `num_classes` 设置为 `1`,让模型输出一个标量预测值。
|
||||||
|
|
||||||
|
本项目规模不大,源码可以直接放在项目根目录。整体设计应当清晰、直接、容易检查,不需要搭建复杂的软件包结构。
|
||||||
|
|
||||||
|
## 语言约束
|
||||||
|
|
||||||
|
- 面向用户的所有说明、总结、报错解释、运行提示、文档内容应当使用中文。
|
||||||
|
- 项目文档优先使用中文。
|
||||||
|
- 代码注释优先使用中文,除非某些英文术语是库、API、指标名或行业固定表达。
|
||||||
|
- 日志、打印信息、错误提示优先使用中文,方便用户不用切换语言环境阅读。
|
||||||
|
- 变量名、函数名、文件名仍然使用常规 Python 英文命名风格,不要为了中文化而使用拼音或中文变量名。
|
||||||
|
|
||||||
|
## 环境约束
|
||||||
|
|
||||||
|
- Conda 环境名:`torch271`。
|
||||||
|
- Python 版本:`py310`。
|
||||||
|
- 预期 PyTorch 栈:`torch271+cu126`。
|
||||||
|
- 用户会提前配置好 conda 环境。
|
||||||
|
- 缺少依赖时,不允许擅自下载或安装。
|
||||||
|
- 如果发现缺包,应当用中文说明缺少哪个包、可能需要的安装命令或替代路线,然后等待用户决定。
|
||||||
|
|
||||||
|
## 主要依赖
|
||||||
|
|
||||||
|
预期会使用的主要库:
|
||||||
|
|
||||||
|
- `python`
|
||||||
|
- `torch`
|
||||||
|
- `torchvision`
|
||||||
|
- `timm`
|
||||||
|
- `numpy`
|
||||||
|
- `Pillow`
|
||||||
|
- 按需使用:`pandas`、`matplotlib`、`scikit-learn`、`onnx`、`onnxruntime` 以及量化相关后端库。
|
||||||
|
|
||||||
|
依赖应尽量少。标准库或已有工具函数能解决的问题,不要额外引入新依赖。
|
||||||
|
|
||||||
|
## 路径规范
|
||||||
|
|
||||||
|
- 所有文件和路径处理一律使用 `pathlib`。
|
||||||
|
- 打印路径、保存路径到文本、写入 CSV/JSON 元数据、生成配置字符串时,统一转换为 POSIX 风格,也就是使用 `/` 正斜杠。
|
||||||
|
- 源码中避免硬编码 Windows 反斜杠。
|
||||||
|
- 面向用户的示例路径优先使用相对项目根目录的路径。
|
||||||
|
|
||||||
|
推荐工具函数形式:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def as_posix_path(path):
|
||||||
|
return Path(path).as_posix()
|
||||||
|
```
|
||||||
|
|
||||||
|
## 代码风格
|
||||||
|
|
||||||
|
- 代码应当直接、清楚、容易读。
|
||||||
|
- 不要求严格类型标注。
|
||||||
|
- 每个可运行源码文件都应当有显式的 `main()` 函数。
|
||||||
|
- 涉及大规模训练、评估、量化、统计的脚本,应当同时提供一个小规模 `test_*()` 函数。
|
||||||
|
- 小规模测试函数要使用类似主流程的逻辑,但只处理极少量数据,用于快速验证数据形状、模型前向传播、损失计算、指标计算、文件保存或量化流程。
|
||||||
|
- 默认运行行为应优先执行小规模测试函数,用户确认后再手动切换到 `main()` 做大规模实验。
|
||||||
|
- 使用 `if __name__ == "__main__":` 显式控制入口。
|
||||||
|
- 导入模块时不要产生训练、推理、写文件等隐藏副作用。
|
||||||
|
- 注释要简短有用,重点解释不明显的数据约定、校准策略、量化限制和指标定义。
|
||||||
|
|
||||||
|
## 计划源码文件
|
||||||
|
|
||||||
|
项目发展过程中至少应包含这些根目录源码文件:
|
||||||
|
|
||||||
|
- `dataset.py`:数据集定义、图像变换、标签解析、小规模数据读取检查。
|
||||||
|
- `model.py`:MobileNetV4 回归模型创建、检查点加载、可选模型结构检查。
|
||||||
|
- `train.py`:训练循环、验证循环、检查点保存、小规模训练测试。
|
||||||
|
- `test.py`:训练后模型的推理和评估、指标计算、预测结果导出、小规模评估测试。
|
||||||
|
- `quantize.py`:量化实验、校准、量化后评估、后端或导出相关逻辑。
|
||||||
|
- `stats.py`:数据集统计、标签分布、预测误差统计、可选图表输出。
|
||||||
|
- `utils.py`:通用路径函数、随机种子、设备选择、日志、指标、检查点辅助函数。
|
||||||
|
|
||||||
|
可以按需新增文件,但新增文件应当让项目更清楚,而不是把小项目变成复杂框架。
|
||||||
|
|
||||||
|
## 数据集设计
|
||||||
|
|
||||||
|
数据集源码文件需要明确预期数据格式。如果数据布局尚未最终确定,应当让设计可以适配以下两种常见形式:
|
||||||
|
|
||||||
|
- CSV 文件中包含图像路径和数值型离焦距离标签;
|
||||||
|
- 文件夹结构配合一个元数据文件。
|
||||||
|
|
||||||
|
推荐约定:
|
||||||
|
|
||||||
|
- 图像路径列名:`image_path`。
|
||||||
|
- 回归目标列名:`defocus_distance`。
|
||||||
|
- 所有图像路径都通过 `Path` 处理。
|
||||||
|
- 相对图像路径应当相对于元数据文件所在目录或显式传入的数据根目录解析。
|
||||||
|
- 样本返回形式建议为 `(image_tensor, target_tensor)`。
|
||||||
|
- 目标值应为浮点张量,形状要能和模型输出 `[batch, 1]` 对齐。
|
||||||
|
- 数据集文件中应包含小规模测试函数,用于读取少量样本并打印图像张量形状、目标形状、目标范围和路径示例。
|
||||||
|
|
||||||
|
## 模型设计
|
||||||
|
|
||||||
|
使用 `timm.create_model()` 创建 MobileNetV4。
|
||||||
|
|
||||||
|
推荐形式:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import timm
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(model_name="mobilenetv4_conv_small", pretrained=True):
|
||||||
|
model = timm.create_model(model_name, pretrained=pretrained, num_classes=1)
|
||||||
|
return model
|
||||||
|
```
|
||||||
|
|
||||||
|
模型创建逻辑应当独立出来,让训练、测试和量化脚本都复用同一个模型工厂函数。
|
||||||
|
|
||||||
|
如果后续要切换具体 MobileNetV4 变体,应通过函数参数或简单脚本常量控制。
|
||||||
|
|
||||||
|
## 训练设计
|
||||||
|
|
||||||
|
训练代码应包含:
|
||||||
|
|
||||||
|
- 随机种子设置;
|
||||||
|
- 设备选择;
|
||||||
|
- 数据集和 DataLoader 创建;
|
||||||
|
- 创建 `num_classes=1` 的回归模型;
|
||||||
|
- 回归损失函数,例如 `MSELoss`、`L1Loss` 或 `SmoothL1Loss`;
|
||||||
|
- MAE、RMSE 等回归指标;
|
||||||
|
- 验证循环;
|
||||||
|
- 检查点保存,内容至少包括模型状态、优化器状态、epoch、指标和必要配置;
|
||||||
|
- 小规模 `test_train()` 流程,只运行极少量样本、batch 和 epoch。
|
||||||
|
|
||||||
|
脚本默认执行时应优先跑小规模测试,避免误触发长时间训练。
|
||||||
|
|
||||||
|
## 测试与评估设计
|
||||||
|
|
||||||
|
评估代码应当:
|
||||||
|
|
||||||
|
- 加载检查点;
|
||||||
|
- 重建相同模型结构;
|
||||||
|
- 在测试数据集上推理;
|
||||||
|
- 计算回归指标;
|
||||||
|
- 可选地将预测结果保存为 CSV,并统一使用 POSIX 风格路径;
|
||||||
|
- 包含 `test_eval()` 或类似小规模评估函数。
|
||||||
|
|
||||||
|
如果后续引入测试框架,要注意 `test.py` 这个项目脚本名和测试框架的文件发现规则是否冲突。
|
||||||
|
|
||||||
|
## 量化设计
|
||||||
|
|
||||||
|
量化代码应和普通训练、评估逻辑分离,但复用数据集、模型和指标工具函数。
|
||||||
|
|
||||||
|
可选技术路线取决于 PyTorch、`timm` 和部署目标的支持情况:
|
||||||
|
|
||||||
|
- 动态量化;
|
||||||
|
- 训练后静态量化和校准;
|
||||||
|
- 量化感知训练;
|
||||||
|
- 必要时使用 ONNX 等导出路线。
|
||||||
|
|
||||||
|
每个量化实验应记录:
|
||||||
|
|
||||||
|
- 原始模型或检查点来源;
|
||||||
|
- 量化方法;
|
||||||
|
- 校准数据规模和选择规则;
|
||||||
|
- 后端或导出格式;
|
||||||
|
- 量化前后评估指标;
|
||||||
|
- 如果测量了,还应记录模型大小和推理延迟。
|
||||||
|
|
||||||
|
量化脚本必须提供小规模测试或小规模校准函数,不能默认直接跑完整量化实验。
|
||||||
|
|
||||||
|
## 统计设计
|
||||||
|
|
||||||
|
统计代码应支持:
|
||||||
|
|
||||||
|
- 数据集大小和划分统计;
|
||||||
|
- 标签最小值、最大值、均值、标准差;
|
||||||
|
- 标签分布分箱;
|
||||||
|
- 缺失文件检查;
|
||||||
|
- 图像尺寸统计;
|
||||||
|
- 评估后的预测误差统计;
|
||||||
|
- 可选图表输出。
|
||||||
|
|
||||||
|
保存统计结果时,文件名要清楚,路径统一使用 POSIX 风格。
|
||||||
|
|
||||||
|
## 工具函数设计
|
||||||
|
|
||||||
|
共享工具函数可包括:
|
||||||
|
|
||||||
|
- `project_root()`;
|
||||||
|
- POSIX 路径格式化;
|
||||||
|
- 随机种子设置;
|
||||||
|
- 设备选择;
|
||||||
|
- MAE、RMSE 等指标函数;
|
||||||
|
- 检查点保存和加载;
|
||||||
|
- 简单日志函数;
|
||||||
|
- 图像扩展名过滤。
|
||||||
|
|
||||||
|
不要把 `utils.py` 变成核心逻辑杂物间。数据集逻辑放在 `dataset.py`,模型创建放在 `model.py`,实验流程放在对应脚本里。
|
||||||
|
|
||||||
|
## 执行原则
|
||||||
|
|
||||||
|
对于耗时脚本,应始终保留两级执行入口:
|
||||||
|
|
||||||
|
1. 小规模测试函数:快速验证同一套逻辑是否能跑通。
|
||||||
|
2. `main()` 函数:完整训练、评估、量化或统计。
|
||||||
|
|
||||||
|
大规模实验前,先运行对应小规模测试函数。是否切换到完整 `main()` 由用户决定。
|
||||||
|
|
||||||
|
推荐入口形式:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_train():
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_train()
|
||||||
|
# main()
|
||||||
|
```
|
||||||
|
|
||||||
|
## 后续协作规则
|
||||||
|
|
||||||
|
- 修改项目前,先阅读本文件。
|
||||||
|
- 后续沟通、总结和错误说明使用中文。
|
||||||
|
- 保持小项目结构,除非用户明确要求改成更复杂的包结构。
|
||||||
|
- 不要未经允许安装依赖。
|
||||||
|
- 在用户提供数据集格式之前,不要假设数据布局已经固定。
|
||||||
|
- 路径处理始终使用 `pathlib`。
|
||||||
|
- 面向外部展示、保存或打印的路径统一使用 `/`。
|
||||||
|
- 可运行脚本默认应安全,优先执行小规模测试,而不是直接启动完整训练。
|
||||||
|
- 编辑已有用户工作时,不要回退无关改动。
|
||||||
|
- 在多个合理方案之间不确定时,选择最简单、最符合本文档和现有代码的方案。
|
||||||
|
|
||||||
261
dataset.py
Normal file
261
dataset.py
Normal file
@ -0,0 +1,261 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
from PIL import Image
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 Pillow。请在 torch271 环境中安装 pillow,或告诉我改用其他路线。") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# 数据集根目录。后续只需要改这里,或在构造函数中显式传入 data_root。
|
||||||
|
DATA_ROOT = Path("E:/Datasets/SimpleAFDataset/roi_with_label")
|
||||||
|
|
||||||
|
SPLIT_NAMES = ("train", "valid", "test")
|
||||||
|
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}
|
||||||
|
|
||||||
|
|
||||||
|
def as_posix_path(path):
|
||||||
|
return Path(path).as_posix()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_label_from_path(image_path):
|
||||||
|
"""从文件名去掉后缀后的字符串中解析离焦距离标签。"""
|
||||||
|
try:
|
||||||
|
return float(Path(image_path).stem)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(f"无法从文件名解析标签:{as_posix_path(image_path)}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def default_image_transform(image, image_size=224):
|
||||||
|
"""将 PIL 图像转成 MobileNet 常用输入形状:[3, H, W]。"""
|
||||||
|
image = image.convert("RGB")
|
||||||
|
if image_size is not None:
|
||||||
|
image = image.resize((image_size, image_size), Image.BILINEAR)
|
||||||
|
|
||||||
|
array = np.asarray(image, dtype=np.float32) / 255.0
|
||||||
|
array = np.transpose(array, (2, 0, 1))
|
||||||
|
return torch.from_numpy(array)
|
||||||
|
|
||||||
|
|
||||||
|
def find_field_dirs(data_root=DATA_ROOT):
|
||||||
|
"""查找所有 sampleXXX/fieldXXX 目录。"""
|
||||||
|
data_root = Path(data_root)
|
||||||
|
if not data_root.exists():
|
||||||
|
raise FileNotFoundError(f"数据集根目录不存在:{as_posix_path(data_root)}")
|
||||||
|
|
||||||
|
field_dirs = []
|
||||||
|
for sample_dir in sorted(data_root.glob("sample[0-9][0-9][0-9]")):
|
||||||
|
if not sample_dir.is_dir():
|
||||||
|
continue
|
||||||
|
for field_dir in sorted(sample_dir.glob("field[0-9][0-9][0-9]")):
|
||||||
|
if field_dir.is_dir():
|
||||||
|
field_dirs.append(field_dir)
|
||||||
|
|
||||||
|
return field_dirs
|
||||||
|
|
||||||
|
|
||||||
|
def collect_images_from_fields(field_dirs):
|
||||||
|
"""把 field 目录展开成图像路径列表和标签列表。"""
|
||||||
|
image_paths = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
for field_dir in sorted(Path(path) for path in field_dirs):
|
||||||
|
for roi_dir in sorted(field_dir.glob("roi[0-9][0-9][0-9]")):
|
||||||
|
if not roi_dir.is_dir():
|
||||||
|
continue
|
||||||
|
for image_path in sorted(roi_dir.iterdir()):
|
||||||
|
if not image_path.is_file():
|
||||||
|
continue
|
||||||
|
if image_path.suffix.lower() not in IMAGE_SUFFIXES:
|
||||||
|
continue
|
||||||
|
image_paths.append(image_path)
|
||||||
|
labels.append(parse_label_from_path(image_path))
|
||||||
|
|
||||||
|
return image_paths, labels
|
||||||
|
|
||||||
|
|
||||||
|
def split_counts(total_count, train_ratio=0.8, valid_ratio=0.1):
|
||||||
|
"""计算 field 数量的 8:1:1 划分,样本太少时尽量保留验证/测试。"""
|
||||||
|
if total_count <= 0:
|
||||||
|
return 0, 0, 0
|
||||||
|
if total_count == 1:
|
||||||
|
return 1, 0, 0
|
||||||
|
if total_count == 2:
|
||||||
|
return 1, 1, 0
|
||||||
|
|
||||||
|
train_count = int(total_count * train_ratio)
|
||||||
|
valid_count = int(total_count * valid_ratio)
|
||||||
|
test_count = total_count - train_count - valid_count
|
||||||
|
|
||||||
|
if valid_count == 0:
|
||||||
|
valid_count = 1
|
||||||
|
train_count -= 1
|
||||||
|
if test_count == 0:
|
||||||
|
test_count = 1
|
||||||
|
train_count -= 1
|
||||||
|
|
||||||
|
return train_count, valid_count, test_count
|
||||||
|
|
||||||
|
|
||||||
|
def split_fields(field_dirs, seed=2026, train_ratio=0.8, valid_ratio=0.1):
|
||||||
|
"""在 field 层面对数据进行 train/valid/test 划分。"""
|
||||||
|
field_dirs = [Path(path) for path in field_dirs]
|
||||||
|
field_dirs = sorted(field_dirs)
|
||||||
|
|
||||||
|
rng = random.Random(seed)
|
||||||
|
rng.shuffle(field_dirs)
|
||||||
|
|
||||||
|
train_count, valid_count, _ = split_counts(
|
||||||
|
len(field_dirs),
|
||||||
|
train_ratio=train_ratio,
|
||||||
|
valid_ratio=valid_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_fields = sorted(field_dirs[:train_count])
|
||||||
|
valid_fields = sorted(field_dirs[train_count:train_count + valid_count])
|
||||||
|
test_fields = sorted(field_dirs[train_count + valid_count:])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"train": train_fields,
|
||||||
|
"valid": valid_fields,
|
||||||
|
"test": test_fields,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_split_lists(data_root=DATA_ROOT, seed=2026, train_ratio=0.8, valid_ratio=0.1):
|
||||||
|
"""
|
||||||
|
在 field 层面划分数据,并返回每个子集的图像路径和标签。
|
||||||
|
|
||||||
|
返回结构:
|
||||||
|
{
|
||||||
|
"train": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
|
||||||
|
"valid": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
|
||||||
|
"test": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
field_dirs = find_field_dirs(data_root)
|
||||||
|
split_field_map = split_fields(
|
||||||
|
field_dirs,
|
||||||
|
seed=seed,
|
||||||
|
train_ratio=train_ratio,
|
||||||
|
valid_ratio=valid_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
split_data = {}
|
||||||
|
for split_name in SPLIT_NAMES:
|
||||||
|
image_paths, labels = collect_images_from_fields(split_field_map[split_name])
|
||||||
|
split_data[split_name] = {
|
||||||
|
"image_paths": image_paths,
|
||||||
|
"labels": labels,
|
||||||
|
"field_dirs": split_field_map[split_name],
|
||||||
|
}
|
||||||
|
|
||||||
|
return split_data
|
||||||
|
|
||||||
|
|
||||||
|
def get_split_items(split="train", data_root=DATA_ROOT, seed=2026):
|
||||||
|
"""返回指定 split 的图像路径列表和标签列表,可直接传给 DefocusDataset。"""
|
||||||
|
if split not in SPLIT_NAMES:
|
||||||
|
raise ValueError(f"split 必须是 {SPLIT_NAMES} 之一,当前为:{split}")
|
||||||
|
|
||||||
|
split_data = make_split_lists(data_root=data_root, seed=seed)
|
||||||
|
items = split_data[split]
|
||||||
|
return items["image_paths"], items["labels"]
|
||||||
|
|
||||||
|
|
||||||
|
class DefocusDataset(Dataset):
|
||||||
|
"""离焦距离回归数据集。"""
|
||||||
|
|
||||||
|
def __init__(self, image_paths, labels, transform=None, return_path=False):
|
||||||
|
if len(image_paths) != len(labels):
|
||||||
|
raise ValueError("image_paths 和 labels 的长度不一致。")
|
||||||
|
|
||||||
|
self.image_paths = [Path(path) for path in image_paths]
|
||||||
|
self.labels = [float(label) for label in labels]
|
||||||
|
self.transform = transform or default_image_transform
|
||||||
|
self.return_path = return_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_split(cls, split="train", data_root=DATA_ROOT, seed=2026, transform=None, return_path=False):
|
||||||
|
image_paths, labels = get_split_items(split=split, data_root=data_root, seed=seed)
|
||||||
|
return cls(image_paths, labels, transform=transform, return_path=return_path)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_paths)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
image_path = self.image_paths[index]
|
||||||
|
label = self.labels[index]
|
||||||
|
|
||||||
|
with Image.open(image_path) as image:
|
||||||
|
image_tensor = self.transform(image)
|
||||||
|
|
||||||
|
target = torch.tensor([label], dtype=torch.float32)
|
||||||
|
|
||||||
|
if self.return_path:
|
||||||
|
return image_tensor, target, as_posix_path(image_path)
|
||||||
|
return image_tensor, target
|
||||||
|
|
||||||
|
|
||||||
|
def print_split_summary(split_data):
|
||||||
|
"""打印划分结果,方便快速确认数据是否符合预期。"""
|
||||||
|
for split_name in SPLIT_NAMES:
|
||||||
|
items = split_data[split_name]
|
||||||
|
labels = items["labels"]
|
||||||
|
field_count = len(items["field_dirs"])
|
||||||
|
image_count = len(items["image_paths"])
|
||||||
|
|
||||||
|
if labels:
|
||||||
|
label_min = min(labels)
|
||||||
|
label_max = max(labels)
|
||||||
|
label_text = f"标签范围:{label_min:.6g} 到 {label_max:.6g}"
|
||||||
|
else:
|
||||||
|
label_text = "标签范围:无"
|
||||||
|
|
||||||
|
print(f"{split_name}: field 数={field_count}, 图像数={image_count}, {label_text}")
|
||||||
|
for image_path in items["image_paths"][:3]:
|
||||||
|
print(f" 示例:{as_posix_path(image_path)} -> {parse_label_from_path(image_path):.6g}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset():
|
||||||
|
"""小规模测试:扫描划分结果,并尝试读取训练集第一张图。"""
|
||||||
|
print(f"当前 DATA_ROOT:{as_posix_path(DATA_ROOT)}")
|
||||||
|
if not DATA_ROOT.exists():
|
||||||
|
print("数据集根目录还不存在,先跳过真实数据读取测试。")
|
||||||
|
return
|
||||||
|
|
||||||
|
split_data = make_split_lists(DATA_ROOT)
|
||||||
|
print_split_summary(split_data)
|
||||||
|
|
||||||
|
train_paths = split_data["train"]["image_paths"]
|
||||||
|
train_labels = split_data["train"]["labels"]
|
||||||
|
if not train_paths:
|
||||||
|
print("训练集没有找到图像,请检查目录是否符合 sampleXXX/fieldXXX/roiXXX/{label}.jpg。")
|
||||||
|
return
|
||||||
|
|
||||||
|
dataset = DefocusDataset(train_paths[:1], train_labels[:1], return_path=True)
|
||||||
|
image_tensor, target, image_path = dataset[0]
|
||||||
|
print(f"读取样本:{image_path}")
|
||||||
|
print(f"图像张量形状:{tuple(image_tensor.shape)}")
|
||||||
|
print(f"标签张量形状:{tuple(target.shape)},标签值:{target.item():.6g}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
split_data = make_split_lists(DATA_ROOT)
|
||||||
|
print_split_summary(split_data)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_dataset()
|
||||||
|
# main()
|
||||||
147
stats.py
Normal file
147
stats.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 pandas。请在 torch271 环境中安装 pandas,或告诉我改用其他路线。") from exc
|
||||||
|
|
||||||
|
from dataset import as_posix_path
|
||||||
|
|
||||||
|
|
||||||
|
INPUT_CSV = Path("predictions") / "test_predictions.csv"
|
||||||
|
OUTPUT_CSV = INPUT_CSV
|
||||||
|
SUMMARY_TXT = Path("predictions") / "test_summary.txt"
|
||||||
|
|
||||||
|
# 景深参考值。请按你的实验设定手动修改。
|
||||||
|
DEPTH_OF_FIELD = 1.0
|
||||||
|
|
||||||
|
LABEL_COLUMN = "label"
|
||||||
|
PREDICTION_COLUMN = "prediction"
|
||||||
|
SIGNED_ERROR_COLUMN = "signed_error"
|
||||||
|
ABSOLUTE_ERROR_COLUMN = "absolute_error"
|
||||||
|
DIRECTION_SIGN_COLUMN = "direction_sign"
|
||||||
|
|
||||||
|
|
||||||
|
def check_columns(dataframe):
|
||||||
|
required_columns = {LABEL_COLUMN, PREDICTION_COLUMN}
|
||||||
|
missing_columns = required_columns - set(dataframe.columns)
|
||||||
|
if missing_columns:
|
||||||
|
missing_text = ", ".join(sorted(missing_columns))
|
||||||
|
raise ValueError(f"预测结果 CSV 缺少必要列:{missing_text}")
|
||||||
|
|
||||||
|
|
||||||
|
def add_error_columns(dataframe):
|
||||||
|
check_columns(dataframe)
|
||||||
|
|
||||||
|
dataframe = dataframe.copy()
|
||||||
|
labels = dataframe[LABEL_COLUMN].astype(float)
|
||||||
|
predictions = dataframe[PREDICTION_COLUMN].astype(float)
|
||||||
|
|
||||||
|
dataframe[SIGNED_ERROR_COLUMN] = labels - predictions
|
||||||
|
dataframe[ABSOLUTE_ERROR_COLUMN] = dataframe[SIGNED_ERROR_COLUMN].abs()
|
||||||
|
dataframe[DIRECTION_SIGN_COLUMN] = np.sign(labels * predictions).astype(int)
|
||||||
|
|
||||||
|
return dataframe
|
||||||
|
|
||||||
|
|
||||||
|
def percent_text(value):
|
||||||
|
return f"{value * 100:.4f}%"
|
||||||
|
|
||||||
|
|
||||||
|
def compute_summary(dataframe, depth_of_field=DEPTH_OF_FIELD):
|
||||||
|
if depth_of_field <= 0:
|
||||||
|
raise ValueError("DEPTH_OF_FIELD 必须大于 0。")
|
||||||
|
|
||||||
|
absolute_error = dataframe[ABSOLUTE_ERROR_COLUMN].astype(float)
|
||||||
|
direction_sign = dataframe[DIRECTION_SIGN_COLUMN]
|
||||||
|
|
||||||
|
mae = absolute_error.mean()
|
||||||
|
std = absolute_error.std(ddof=0)
|
||||||
|
|
||||||
|
dof_acc_1 = (absolute_error <= depth_of_field).mean()
|
||||||
|
dof_acc_1_2 = (absolute_error <= depth_of_field / 2).mean()
|
||||||
|
dof_acc_1_3 = (absolute_error <= depth_of_field / 3).mean()
|
||||||
|
dss = (direction_sign != -1).mean()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sample_count": len(dataframe),
|
||||||
|
"mae": mae,
|
||||||
|
"std": std,
|
||||||
|
"dof_acc_1": dof_acc_1,
|
||||||
|
"dof_acc_1_2": dof_acc_1_2,
|
||||||
|
"dof_acc_1_3": dof_acc_1_3,
|
||||||
|
"dss": dss,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def save_summary(summary, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
|
||||||
|
summary_txt = Path(summary_txt)
|
||||||
|
summary_txt.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"测试结果凝练统计",
|
||||||
|
f"样本数:{summary['sample_count']}",
|
||||||
|
f"景深参考值:{depth_of_field:.6g}",
|
||||||
|
f"MAE and STD:{summary['mae']:.6g}, {summary['std']:.6g}",
|
||||||
|
f"DoF-Acc@1:{percent_text(summary['dof_acc_1'])}",
|
||||||
|
f"DoF-Acc@1/2:{percent_text(summary['dof_acc_1_2'])}",
|
||||||
|
f"DoF-Acc@1/3:{percent_text(summary['dof_acc_1_3'])}",
|
||||||
|
f"DSS:{percent_text(summary['dss'])}",
|
||||||
|
]
|
||||||
|
|
||||||
|
summary_txt.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||||
|
return summary_txt
|
||||||
|
|
||||||
|
|
||||||
|
def run_stats(input_csv=INPUT_CSV, output_csv=OUTPUT_CSV, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
|
||||||
|
input_csv = Path(input_csv)
|
||||||
|
output_csv = Path(output_csv)
|
||||||
|
|
||||||
|
if not input_csv.exists():
|
||||||
|
raise FileNotFoundError(f"找不到预测结果 CSV:{as_posix_path(input_csv)}")
|
||||||
|
|
||||||
|
dataframe = pd.read_csv(input_csv)
|
||||||
|
dataframe = add_error_columns(dataframe)
|
||||||
|
|
||||||
|
output_csv.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
dataframe.to_csv(output_csv, index=False, encoding="utf-8")
|
||||||
|
|
||||||
|
summary = compute_summary(dataframe, depth_of_field=depth_of_field)
|
||||||
|
summary_txt = save_summary(summary, summary_txt=summary_txt, depth_of_field=depth_of_field)
|
||||||
|
|
||||||
|
print(f"已更新预测结果 CSV:{as_posix_path(output_csv)}")
|
||||||
|
print(f"已保存统计摘要:{as_posix_path(summary_txt)}")
|
||||||
|
return dataframe, summary
|
||||||
|
|
||||||
|
|
||||||
|
def test_stats():
|
||||||
|
"""小规模测试:如果预测 CSV 存在,就对当前文件执行统计流程。"""
|
||||||
|
if not INPUT_CSV.exists():
|
||||||
|
print(f"预测结果 CSV 不存在,先跳过统计测试:{as_posix_path(INPUT_CSV)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
run_stats(
|
||||||
|
input_csv=INPUT_CSV,
|
||||||
|
output_csv=OUTPUT_CSV,
|
||||||
|
summary_txt=SUMMARY_TXT,
|
||||||
|
depth_of_field=DEPTH_OF_FIELD,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
run_stats(
|
||||||
|
input_csv=INPUT_CSV,
|
||||||
|
output_csv=OUTPUT_CSV,
|
||||||
|
summary_txt=SUMMARY_TXT,
|
||||||
|
depth_of_field=DEPTH_OF_FIELD,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_stats()
|
||||||
|
# main()
|
||||||
163
test.py
Normal file
163
test.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import csv
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
import timm
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 timm。请在 torch271 环境中安装 timm,或告诉我改用其他路线。") from exc
|
||||||
|
|
||||||
|
from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists
|
||||||
|
|
||||||
|
|
||||||
|
SEED = 2026
|
||||||
|
MODEL_NAME = "mobilenetv4_conv_small"
|
||||||
|
PRETRAINED = False
|
||||||
|
DEVICE = "cpu"
|
||||||
|
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
NUM_WORKERS = 8
|
||||||
|
CHECKPOINT_PATH = Path("checkpoints") / "best_mobilenetv4_defocus.pth"
|
||||||
|
OUTPUT_DIR = Path("predictions")
|
||||||
|
OUTPUT_CSV = OUTPUT_DIR / "test_predictions.csv"
|
||||||
|
|
||||||
|
TEST_BATCH_SIZE = 2
|
||||||
|
TEST_NUM_WORKERS = 0
|
||||||
|
TEST_MAX_SAMPLES = 8
|
||||||
|
TEST_OUTPUT_CSV = OUTPUT_DIR / "test_predictions_smoke.csv"
|
||||||
|
|
||||||
|
|
||||||
|
def get_device(device_name=DEVICE):
|
||||||
|
if device_name.startswith("cuda") and not torch.cuda.is_available():
|
||||||
|
raise RuntimeError("当前指定设备为 cuda:0,但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。")
|
||||||
|
return torch.device(device_name)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model():
|
||||||
|
model = timm.create_model(MODEL_NAME, pretrained=PRETRAINED, num_classes=1)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(checkpoint_path=CHECKPOINT_PATH, device_name=DEVICE):
|
||||||
|
checkpoint_path = Path(checkpoint_path)
|
||||||
|
if not checkpoint_path.exists():
|
||||||
|
raise FileNotFoundError(f"找不到检查点文件:{as_posix_path(checkpoint_path)}")
|
||||||
|
|
||||||
|
device = get_device(device_name)
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||||
|
|
||||||
|
model = create_model()
|
||||||
|
if isinstance(checkpoint, dict) and "model_state" in checkpoint:
|
||||||
|
model.load_state_dict(checkpoint["model_state"])
|
||||||
|
else:
|
||||||
|
model.load_state_dict(checkpoint)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
return model, device
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_loader(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED, max_samples=None):
|
||||||
|
split_data = make_split_lists(data_root=data_root, seed=seed)
|
||||||
|
image_paths = split_data["test"]["image_paths"]
|
||||||
|
labels = split_data["test"]["labels"]
|
||||||
|
|
||||||
|
if max_samples is not None:
|
||||||
|
image_paths = image_paths[:max_samples]
|
||||||
|
labels = labels[:max_samples]
|
||||||
|
|
||||||
|
if not image_paths:
|
||||||
|
raise RuntimeError("测试集为空,请检查数据目录结构和 field 层面的划分结果。")
|
||||||
|
|
||||||
|
dataset = DefocusDataset(image_paths, labels, return_path=True)
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict(model, dataloader, device):
|
||||||
|
rows = []
|
||||||
|
for images, targets, image_paths in dataloader:
|
||||||
|
images = images.to(device, non_blocking=True)
|
||||||
|
predictions = model(images).detach().cpu().view(-1)
|
||||||
|
targets = targets.detach().cpu().view(-1)
|
||||||
|
|
||||||
|
for image_path, label, prediction in zip(image_paths, targets.tolist(), predictions.tolist()):
|
||||||
|
rows.append({
|
||||||
|
"image_path": as_posix_path(image_path),
|
||||||
|
"label": label,
|
||||||
|
"prediction": prediction,
|
||||||
|
})
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def save_predictions_csv(rows, output_csv=OUTPUT_CSV):
|
||||||
|
output_csv = Path(output_csv)
|
||||||
|
output_csv.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with output_csv.open("w", newline="", encoding="utf-8") as file:
|
||||||
|
writer = csv.DictWriter(file, fieldnames=["image_path", "label", "prediction"])
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(rows)
|
||||||
|
|
||||||
|
return output_csv
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(checkpoint_path=CHECKPOINT_PATH, output_csv=OUTPUT_CSV, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, max_samples=None):
|
||||||
|
model, device = load_model(checkpoint_path=checkpoint_path, device_name=DEVICE)
|
||||||
|
dataloader = make_test_loader(
|
||||||
|
data_root=DATA_ROOT,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=num_workers,
|
||||||
|
seed=SEED,
|
||||||
|
max_samples=max_samples,
|
||||||
|
)
|
||||||
|
rows = predict(model, dataloader, device)
|
||||||
|
output_csv = save_predictions_csv(rows, output_csv=output_csv)
|
||||||
|
print(f"已保存测试预测结果:{as_posix_path(output_csv)}")
|
||||||
|
print(f"写入样本数:{len(rows)}")
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval():
|
||||||
|
"""小规模测试:只跑少量测试集样本,确认加载、推理和 CSV 写出链路。"""
|
||||||
|
if not DATA_ROOT.exists():
|
||||||
|
print("数据集根目录不存在,先跳过小规模测试。")
|
||||||
|
return
|
||||||
|
if not CHECKPOINT_PATH.exists():
|
||||||
|
print(f"最佳检查点不存在,先跳过小规模测试:{as_posix_path(CHECKPOINT_PATH)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
run_test(
|
||||||
|
checkpoint_path=CHECKPOINT_PATH,
|
||||||
|
output_csv=TEST_OUTPUT_CSV,
|
||||||
|
batch_size=TEST_BATCH_SIZE,
|
||||||
|
num_workers=TEST_NUM_WORKERS,
|
||||||
|
max_samples=TEST_MAX_SAMPLES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
run_test(
|
||||||
|
checkpoint_path=CHECKPOINT_PATH,
|
||||||
|
output_csv=OUTPUT_CSV,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
num_workers=NUM_WORKERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_eval()
|
||||||
|
# main()
|
||||||
446
train.py
Normal file
446
train.py
Normal file
@ -0,0 +1,446 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
import timm
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 timm。请在 torch271 环境中安装 timm,或告诉我改用其他路线。") from exc
|
||||||
|
|
||||||
|
from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists
|
||||||
|
|
||||||
|
|
||||||
|
# 基础超参数。大规模实验前先运行 test_train()。
|
||||||
|
SEED = 2026
|
||||||
|
MODEL_NAME = "mobilenetv4_conv_small"
|
||||||
|
PRETRAINED = False
|
||||||
|
DEVICE = "cuda:0"
|
||||||
|
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
NUM_WORKERS = 8
|
||||||
|
LEARNING_RATE = 1e-4
|
||||||
|
EPOCHS = 300
|
||||||
|
|
||||||
|
CHECKPOINT_DIR = Path("checkpoints")
|
||||||
|
BEST_CHECKPOINT_NAME = "best_mobilenetv4_defocus.pth"
|
||||||
|
LAST_CHECKPOINT_NAME = "last_mobilenetv4_defocus.pth"
|
||||||
|
LOG_DIR = Path("logs")
|
||||||
|
LOG_FILE = LOG_DIR / "train.log"
|
||||||
|
TENSORBOARD_DIR = Path("runs") / "mobilenetv4_defocus"
|
||||||
|
TQDM_NCOLS = 100
|
||||||
|
|
||||||
|
TEST_BATCH_SIZE = 2
|
||||||
|
TEST_NUM_WORKERS = 0
|
||||||
|
TEST_MAX_SAMPLES = 8
|
||||||
|
TEST_EPOCHS = 1
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed=SEED):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(log_file=LOG_FILE):
|
||||||
|
logger = logging.getLogger("train")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
log_file = Path(log_file)
|
||||||
|
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
fmt="%(asctime)s | %(levelname)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
logger.info(f"日志文件:{as_posix_path(log_file)}")
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
def create_summary_writer(log_dir=TENSORBOARD_DIR):
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 TensorBoard 相关依赖。请安装 tensorboard,或告诉我改用其他记录方式。") from exc
|
||||||
|
|
||||||
|
log_dir = Path(log_dir)
|
||||||
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
return SummaryWriter(log_dir=as_posix_path(log_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def make_progress_bar(iterable, desc):
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError("缺少 tqdm。请在 torch271 环境中安装 tqdm,或告诉我改用普通日志显示进度。") from exc
|
||||||
|
|
||||||
|
return tqdm(
|
||||||
|
iterable,
|
||||||
|
desc=desc,
|
||||||
|
total=len(iterable),
|
||||||
|
ncols=TQDM_NCOLS,
|
||||||
|
leave=False,
|
||||||
|
dynamic_ncols=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device(device_name=DEVICE):
|
||||||
|
if device_name.startswith("cuda") and not torch.cuda.is_available():
|
||||||
|
raise RuntimeError("当前指定设备为 cuda:0,但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。")
|
||||||
|
return torch.device(device_name)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model():
|
||||||
|
model = timm.create_model(MODEL_NAME, pretrained=PRETRAINED, num_classes=1)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def regression_metrics(predictions, targets):
|
||||||
|
errors = predictions - targets
|
||||||
|
mae = torch.mean(torch.abs(errors)).item()
|
||||||
|
rmse = torch.sqrt(torch.mean(errors ** 2)).item()
|
||||||
|
return {
|
||||||
|
"mae": mae,
|
||||||
|
"rmse": rmse,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_lr(optimizer):
|
||||||
|
return optimizer.param_groups[0]["lr"]
|
||||||
|
|
||||||
|
|
||||||
|
def move_batch_to_device(batch, device):
|
||||||
|
images, targets = batch[:2]
|
||||||
|
images = images.to(device, non_blocking=True)
|
||||||
|
targets = targets.to(device, non_blocking=True)
|
||||||
|
return images, targets
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(model, dataloader, criterion, optimizer, device, epoch=None):
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
total_count = 0
|
||||||
|
all_predictions = []
|
||||||
|
all_targets = []
|
||||||
|
|
||||||
|
desc = "训练" if epoch is None else f"训练 {epoch:03d}"
|
||||||
|
progress_bar = make_progress_bar(dataloader, desc)
|
||||||
|
|
||||||
|
for batch in progress_bar:
|
||||||
|
images, targets = move_batch_to_device(batch, device)
|
||||||
|
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
predictions = model(images)
|
||||||
|
loss = criterion(predictions, targets)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
batch_size = images.size(0)
|
||||||
|
total_loss += loss.item() * batch_size
|
||||||
|
total_count += batch_size
|
||||||
|
all_predictions.append(predictions.detach().cpu())
|
||||||
|
all_targets.append(targets.detach().cpu())
|
||||||
|
progress_bar.set_postfix(loss=f"{loss.item():<10.4f}")
|
||||||
|
|
||||||
|
mean_loss = total_loss / max(total_count, 1)
|
||||||
|
metrics = regression_metrics(torch.cat(all_predictions), torch.cat(all_targets))
|
||||||
|
metrics["loss"] = mean_loss
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def valid_epoch(model, dataloader, criterion, device, epoch=None):
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
total_count = 0
|
||||||
|
all_predictions = []
|
||||||
|
all_targets = []
|
||||||
|
|
||||||
|
desc = "验证" if epoch is None else f"验证 {epoch:03d}"
|
||||||
|
progress_bar = make_progress_bar(dataloader, desc)
|
||||||
|
|
||||||
|
for batch in progress_bar:
|
||||||
|
images, targets = move_batch_to_device(batch, device)
|
||||||
|
|
||||||
|
predictions = model(images)
|
||||||
|
loss = criterion(predictions, targets)
|
||||||
|
|
||||||
|
batch_size = images.size(0)
|
||||||
|
total_loss += loss.item() * batch_size
|
||||||
|
total_count += batch_size
|
||||||
|
all_predictions.append(predictions.detach().cpu())
|
||||||
|
all_targets.append(targets.detach().cpu())
|
||||||
|
progress_bar.set_postfix(loss=f"{loss.item():<10.4f}")
|
||||||
|
|
||||||
|
mean_loss = total_loss / max(total_count, 1)
|
||||||
|
metrics = regression_metrics(torch.cat(all_predictions), torch.cat(all_targets))
|
||||||
|
metrics["loss"] = mean_loss
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def write_tensorboard_scalars(writer, epoch, train_metrics, valid_metrics, best_valid_loss, train_seconds, valid_seconds, epoch_seconds, learning_rate):
|
||||||
|
writer.add_scalar("train/loss", train_metrics["loss"], epoch)
|
||||||
|
writer.add_scalar("train/mae", train_metrics["mae"], epoch)
|
||||||
|
writer.add_scalar("train/rmse", train_metrics["rmse"], epoch)
|
||||||
|
|
||||||
|
writer.add_scalar("valid/loss", valid_metrics["loss"], epoch)
|
||||||
|
writer.add_scalar("valid/mae", valid_metrics["mae"], epoch)
|
||||||
|
writer.add_scalar("valid/rmse", valid_metrics["rmse"], epoch)
|
||||||
|
|
||||||
|
writer.add_scalar("summary/best_valid_loss", best_valid_loss, epoch)
|
||||||
|
writer.add_scalar("time/train_seconds", train_seconds, epoch)
|
||||||
|
writer.add_scalar("time/valid_seconds", valid_seconds, epoch)
|
||||||
|
writer.add_scalar("time/epoch_seconds", epoch_seconds, epoch)
|
||||||
|
writer.add_scalar("optimizer/learning_rate", learning_rate, epoch)
|
||||||
|
writer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def make_dataloaders(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED):
|
||||||
|
split_data = make_split_lists(data_root=data_root, seed=seed)
|
||||||
|
|
||||||
|
train_dataset = DefocusDataset(
|
||||||
|
split_data["train"]["image_paths"],
|
||||||
|
split_data["train"]["labels"],
|
||||||
|
)
|
||||||
|
valid_dataset = DefocusDataset(
|
||||||
|
split_data["valid"]["image_paths"],
|
||||||
|
split_data["valid"]["labels"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(train_dataset) == 0:
|
||||||
|
raise RuntimeError("训练集为空,请检查数据目录结构。")
|
||||||
|
if len(valid_dataset) == 0:
|
||||||
|
raise RuntimeError("验证集为空,请检查 field 层面的划分结果。")
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
valid_loader = DataLoader(
|
||||||
|
valid_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, valid_loader, split_data
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(path, model, optimizer, epoch, train_metrics, valid_metrics, best_valid_loss):
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"model_name": MODEL_NAME,
|
||||||
|
"pretrained": PRETRAINED,
|
||||||
|
"model_state": model.state_dict(),
|
||||||
|
"optimizer_state": optimizer.state_dict(),
|
||||||
|
"train_metrics": train_metrics,
|
||||||
|
"valid_metrics": valid_metrics,
|
||||||
|
"best_valid_loss": best_valid_loss,
|
||||||
|
"config": {
|
||||||
|
"data_root": as_posix_path(DATA_ROOT),
|
||||||
|
"batch_size": BATCH_SIZE,
|
||||||
|
"num_workers": NUM_WORKERS,
|
||||||
|
"learning_rate": LEARNING_RATE,
|
||||||
|
"device": DEVICE,
|
||||||
|
"loss": "SmoothL1Loss",
|
||||||
|
"optimizer": "Adam",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
torch.save(checkpoint, path)
|
||||||
|
|
||||||
|
|
||||||
|
def log_epoch_metrics(logger, epoch, train_metrics, valid_metrics, train_seconds, valid_seconds, epoch_seconds, learning_rate):
|
||||||
|
logger.info(
|
||||||
|
f"第 {epoch:03d} 轮 | "
|
||||||
|
f"训练 loss={train_metrics['loss']:.6f}, MAE={train_metrics['mae']:.6f}, RMSE={train_metrics['rmse']:.6f} | "
|
||||||
|
f"验证 loss={valid_metrics['loss']:.6f}, MAE={valid_metrics['mae']:.6f}, RMSE={valid_metrics['rmse']:.6f} | "
|
||||||
|
f"耗时 train={train_seconds:.2f}s, valid={valid_seconds:.2f}s, epoch={epoch_seconds:.2f}s | "
|
||||||
|
f"学习率={learning_rate:.6g}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=None, writer=None):
|
||||||
|
logger = logger or setup_logger()
|
||||||
|
close_writer = writer is None
|
||||||
|
if writer is None:
|
||||||
|
writer = create_summary_writer()
|
||||||
|
logger.info(f"TensorBoard 目录:{as_posix_path(TENSORBOARD_DIR)}")
|
||||||
|
|
||||||
|
device = get_device(device_name)
|
||||||
|
model = create_model().to(device)
|
||||||
|
criterion = torch.nn.SmoothL1Loss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
|
best_valid_loss = float("inf")
|
||||||
|
best_path = CHECKPOINT_DIR / BEST_CHECKPOINT_NAME
|
||||||
|
last_path = CHECKPOINT_DIR / LAST_CHECKPOINT_NAME
|
||||||
|
last_epoch = None
|
||||||
|
last_train_metrics = None
|
||||||
|
last_valid_metrics = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
for epoch in range(1, epochs + 1):
|
||||||
|
epoch_start = time.perf_counter()
|
||||||
|
|
||||||
|
train_start = time.perf_counter()
|
||||||
|
train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch=epoch)
|
||||||
|
train_seconds = time.perf_counter() - train_start
|
||||||
|
|
||||||
|
valid_start = time.perf_counter()
|
||||||
|
valid_metrics = valid_epoch(model, valid_loader, criterion, device, epoch=epoch)
|
||||||
|
valid_seconds = time.perf_counter() - valid_start
|
||||||
|
|
||||||
|
epoch_seconds = time.perf_counter() - epoch_start
|
||||||
|
learning_rate = get_current_lr(optimizer)
|
||||||
|
|
||||||
|
if valid_metrics["loss"] < best_valid_loss:
|
||||||
|
best_valid_loss = valid_metrics["loss"]
|
||||||
|
save_checkpoint(best_path, model, optimizer, epoch, train_metrics, valid_metrics, best_valid_loss)
|
||||||
|
logger.info(f"已保存最佳检查点:{as_posix_path(best_path)}")
|
||||||
|
|
||||||
|
last_epoch = epoch
|
||||||
|
last_train_metrics = train_metrics
|
||||||
|
last_valid_metrics = valid_metrics
|
||||||
|
|
||||||
|
log_epoch_metrics(
|
||||||
|
logger,
|
||||||
|
epoch,
|
||||||
|
train_metrics,
|
||||||
|
valid_metrics,
|
||||||
|
train_seconds,
|
||||||
|
valid_seconds,
|
||||||
|
epoch_seconds,
|
||||||
|
learning_rate,
|
||||||
|
)
|
||||||
|
write_tensorboard_scalars(
|
||||||
|
writer,
|
||||||
|
epoch,
|
||||||
|
train_metrics,
|
||||||
|
valid_metrics,
|
||||||
|
best_valid_loss,
|
||||||
|
train_seconds,
|
||||||
|
valid_seconds,
|
||||||
|
epoch_seconds,
|
||||||
|
learning_rate,
|
||||||
|
)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.warning("检测到 Ctrl-C 手动中止训练,准备保存最后一个完整 epoch 的 last 检查点。")
|
||||||
|
finally:
|
||||||
|
if last_epoch is not None:
|
||||||
|
save_checkpoint(
|
||||||
|
last_path,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
last_epoch,
|
||||||
|
last_train_metrics,
|
||||||
|
last_valid_metrics,
|
||||||
|
best_valid_loss,
|
||||||
|
)
|
||||||
|
logger.info(f"已保存 last 检查点:{as_posix_path(last_path)},对应第 {last_epoch:03d} 轮。")
|
||||||
|
else:
|
||||||
|
logger.warning("训练尚未完整完成任何 epoch,未保存 last 检查点。")
|
||||||
|
|
||||||
|
if close_writer:
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_train():
|
||||||
|
"""小规模训练测试:只取少量样本,验证完整训练链路。"""
|
||||||
|
logger = setup_logger(LOG_DIR / "test_train.log")
|
||||||
|
set_seed(SEED)
|
||||||
|
logger.info(f"数据根目录:{as_posix_path(DATA_ROOT)}")
|
||||||
|
logger.info(f"模型:{MODEL_NAME},pretrained={PRETRAINED}")
|
||||||
|
logger.info(f"设备:{DEVICE}")
|
||||||
|
|
||||||
|
if not DATA_ROOT.exists():
|
||||||
|
logger.info("数据集根目录不存在,先跳过小规模训练测试。")
|
||||||
|
return
|
||||||
|
|
||||||
|
split_data = make_split_lists(data_root=DATA_ROOT, seed=SEED)
|
||||||
|
train_paths = split_data["train"]["image_paths"][:TEST_MAX_SAMPLES]
|
||||||
|
train_labels = split_data["train"]["labels"][:TEST_MAX_SAMPLES]
|
||||||
|
valid_paths = split_data["valid"]["image_paths"][:TEST_MAX_SAMPLES]
|
||||||
|
valid_labels = split_data["valid"]["labels"][:TEST_MAX_SAMPLES]
|
||||||
|
|
||||||
|
if not train_paths or not valid_paths:
|
||||||
|
logger.info("训练集或验证集为空,请检查数据目录结构和 field 划分结果。")
|
||||||
|
return
|
||||||
|
|
||||||
|
train_dataset = DefocusDataset(train_paths, train_labels)
|
||||||
|
valid_dataset = DefocusDataset(valid_paths, valid_labels)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=TEST_BATCH_SIZE,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=TEST_NUM_WORKERS,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
valid_loader = DataLoader(
|
||||||
|
valid_dataset,
|
||||||
|
batch_size=TEST_BATCH_SIZE,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=TEST_NUM_WORKERS,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
writer = create_summary_writer(TENSORBOARD_DIR / "test_train")
|
||||||
|
logger.info(f"TensorBoard 测试目录:{as_posix_path(TENSORBOARD_DIR / 'test_train')}")
|
||||||
|
fit(train_loader, valid_loader, epochs=TEST_EPOCHS, device_name=DEVICE, logger=logger, writer=writer)
|
||||||
|
writer.close()
|
||||||
|
logger.info("小规模训练测试完成。")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
logger = setup_logger(LOG_FILE)
|
||||||
|
set_seed(SEED)
|
||||||
|
train_loader, valid_loader, split_data = make_dataloaders(
|
||||||
|
data_root=DATA_ROOT,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
num_workers=NUM_WORKERS,
|
||||||
|
seed=SEED,
|
||||||
|
)
|
||||||
|
logger.info(f"训练图像数:{len(split_data['train']['image_paths'])}")
|
||||||
|
logger.info(f"验证图像数:{len(split_data['valid']['image_paths'])}")
|
||||||
|
fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# test_train()
|
||||||
|
main()
|
||||||
Loading…
x
Reference in New Issue
Block a user