237 lines
8.7 KiB
Markdown
237 lines
8.7 KiB
Markdown
# AGENTS.md
|
||
|
||
## 项目目标
|
||
|
||
本项目是一个用于实践深度学习模型量化技术的小型 PyTorch 工程。背景任务是:根据离焦图像预测离焦距离。任务形式是回归问题,核心模型使用 `timm` 库提供的 MobileNetV4,并将 `num_classes` 设置为 `1`,让模型输出一个标量预测值。
|
||
|
||
本项目规模不大,源码可以直接放在项目根目录。整体设计应当清晰、直接、容易检查,不需要搭建复杂的软件包结构。
|
||
|
||
## 语言约束
|
||
|
||
- 面向用户的所有说明、总结、报错解释、运行提示、文档内容应当使用中文。
|
||
- 项目文档优先使用中文。
|
||
- 代码注释优先使用中文,除非某些英文术语是库、API、指标名或行业固定表达。
|
||
- 日志、打印信息、错误提示优先使用中文,方便用户不用切换语言环境阅读。
|
||
- 变量名、函数名、文件名仍然使用常规 Python 英文命名风格,不要为了中文化而使用拼音或中文变量名。
|
||
|
||
## 环境约束
|
||
|
||
- Conda 环境名:`torch271`。
|
||
- Python 版本:`py310`。
|
||
- 预期 PyTorch 栈:`torch271+cu126`。
|
||
- 用户会提前配置好 conda 环境。
|
||
- 缺少依赖时,不允许擅自下载或安装。
|
||
- 如果发现缺包,应当用中文说明缺少哪个包、可能需要的安装命令或替代路线,然后等待用户决定。
|
||
|
||
## 主要依赖
|
||
|
||
预期会使用的主要库:
|
||
|
||
- `python`
|
||
- `torch`
|
||
- `torchvision`
|
||
- `timm`
|
||
- `numpy`
|
||
- `Pillow`
|
||
- 按需使用:`pandas`、`matplotlib`、`scikit-learn`、`onnx`、`onnxruntime` 以及量化相关后端库。
|
||
|
||
依赖应尽量少。标准库或已有工具函数能解决的问题,不要额外引入新依赖。
|
||
|
||
## 路径规范
|
||
|
||
- 所有文件和路径处理一律使用 `pathlib`。
|
||
- 打印路径、保存路径到文本、写入 CSV/JSON 元数据、生成配置字符串时,统一转换为 POSIX 风格,也就是使用 `/` 正斜杠。
|
||
- 源码中避免硬编码 Windows 反斜杠。
|
||
- 面向用户的示例路径优先使用相对项目根目录的路径。
|
||
|
||
推荐工具函数形式:
|
||
|
||
```python
|
||
from pathlib import Path
|
||
|
||
|
||
def as_posix_path(path):
|
||
return Path(path).as_posix()
|
||
```
|
||
|
||
## 代码风格
|
||
|
||
- 代码应当直接、清楚、容易读。
|
||
- 不要求严格类型标注。
|
||
- 每个可运行源码文件都应当有显式的 `main()` 函数。
|
||
- 涉及大规模训练、评估、量化、统计的脚本,应当同时提供一个小规模 `test_*()` 函数。
|
||
- 小规模测试函数要使用类似主流程的逻辑,但只处理极少量数据,用于快速验证数据形状、模型前向传播、损失计算、指标计算、文件保存或量化流程。
|
||
- 默认运行行为应优先执行小规模测试函数,用户确认后再手动切换到 `main()` 做大规模实验。
|
||
- 使用 `if __name__ == "__main__":` 显式控制入口。
|
||
- 导入模块时不要产生训练、推理、写文件等隐藏副作用。
|
||
- 注释要简短有用,重点解释不明显的数据约定、校准策略、量化限制和指标定义。
|
||
|
||
## 计划源码文件
|
||
|
||
项目发展过程中至少应包含这些根目录源码文件:
|
||
|
||
- `dataset.py`:数据集定义、图像变换、标签解析、小规模数据读取检查。
|
||
- `model.py`:MobileNetV4 回归模型创建、检查点加载、可选模型结构检查。
|
||
- `train.py`:训练循环、验证循环、检查点保存、小规模训练测试。
|
||
- `test.py`:训练后模型的推理和评估、指标计算、预测结果导出、小规模评估测试。
|
||
- `quantize.py`:量化实验、校准、量化后评估、后端或导出相关逻辑。
|
||
- `stats.py`:数据集统计、标签分布、预测误差统计、可选图表输出。
|
||
- `utils.py`:通用路径函数、随机种子、设备选择、日志、指标、检查点辅助函数。
|
||
|
||
可以按需新增文件,但新增文件应当让项目更清楚,而不是把小项目变成复杂框架。
|
||
|
||
## 数据集设计
|
||
|
||
数据集源码文件需要明确预期数据格式。如果数据布局尚未最终确定,应当让设计可以适配以下两种常见形式:
|
||
|
||
- CSV 文件中包含图像路径和数值型离焦距离标签;
|
||
- 文件夹结构配合一个元数据文件。
|
||
|
||
推荐约定:
|
||
|
||
- 图像路径列名:`image_path`。
|
||
- 回归目标列名:`defocus_distance`。
|
||
- 所有图像路径都通过 `Path` 处理。
|
||
- 相对图像路径应当相对于元数据文件所在目录或显式传入的数据根目录解析。
|
||
- 样本返回形式建议为 `(image_tensor, target_tensor)`。
|
||
- 目标值应为浮点张量,形状要能和模型输出 `[batch, 1]` 对齐。
|
||
- 数据集文件中应包含小规模测试函数,用于读取少量样本并打印图像张量形状、目标形状、目标范围和路径示例。
|
||
|
||
## 模型设计
|
||
|
||
使用 `timm.create_model()` 创建 MobileNetV4。
|
||
|
||
推荐形式:
|
||
|
||
```python
|
||
import timm
|
||
|
||
|
||
def create_model(model_name="mobilenetv4_conv_small", pretrained=True):
|
||
model = timm.create_model(model_name, pretrained=pretrained, num_classes=1)
|
||
return model
|
||
```
|
||
|
||
模型创建逻辑应当独立出来,让训练、测试和量化脚本都复用同一个模型工厂函数。
|
||
|
||
如果后续要切换具体 MobileNetV4 变体,应通过函数参数或简单脚本常量控制。
|
||
|
||
## 训练设计
|
||
|
||
训练代码应包含:
|
||
|
||
- 随机种子设置;
|
||
- 设备选择;
|
||
- 数据集和 DataLoader 创建;
|
||
- 创建 `num_classes=1` 的回归模型;
|
||
- 回归损失函数,例如 `MSELoss`、`L1Loss` 或 `SmoothL1Loss`;
|
||
- MAE、RMSE 等回归指标;
|
||
- 验证循环;
|
||
- 检查点保存,内容至少包括模型状态、优化器状态、epoch、指标和必要配置;
|
||
- 小规模 `test_train()` 流程,只运行极少量样本、batch 和 epoch。
|
||
|
||
脚本默认执行时应优先跑小规模测试,避免误触发长时间训练。
|
||
|
||
## 测试与评估设计
|
||
|
||
评估代码应当:
|
||
|
||
- 加载检查点;
|
||
- 重建相同模型结构;
|
||
- 在测试数据集上推理;
|
||
- 计算回归指标;
|
||
- 可选地将预测结果保存为 CSV,并统一使用 POSIX 风格路径;
|
||
- 包含 `test_eval()` 或类似小规模评估函数。
|
||
|
||
如果后续引入测试框架,要注意 `test.py` 这个项目脚本名和测试框架的文件发现规则是否冲突。
|
||
|
||
## 量化设计
|
||
|
||
量化代码应和普通训练、评估逻辑分离,但复用数据集、模型和指标工具函数。
|
||
|
||
可选技术路线取决于 PyTorch、`timm` 和部署目标的支持情况:
|
||
|
||
- 动态量化;
|
||
- 训练后静态量化和校准;
|
||
- 量化感知训练;
|
||
- 必要时使用 ONNX 等导出路线。
|
||
|
||
每个量化实验应记录:
|
||
|
||
- 原始模型或检查点来源;
|
||
- 量化方法;
|
||
- 校准数据规模和选择规则;
|
||
- 后端或导出格式;
|
||
- 量化前后评估指标;
|
||
- 如果测量了,还应记录模型大小和推理延迟。
|
||
|
||
量化脚本必须提供小规模测试或小规模校准函数,不能默认直接跑完整量化实验。
|
||
|
||
## 统计设计
|
||
|
||
统计代码应支持:
|
||
|
||
- 数据集大小和划分统计;
|
||
- 标签最小值、最大值、均值、标准差;
|
||
- 标签分布分箱;
|
||
- 缺失文件检查;
|
||
- 图像尺寸统计;
|
||
- 评估后的预测误差统计;
|
||
- 可选图表输出。
|
||
|
||
保存统计结果时,文件名要清楚,路径统一使用 POSIX 风格。
|
||
|
||
## 工具函数设计
|
||
|
||
共享工具函数可包括:
|
||
|
||
- `project_root()`;
|
||
- POSIX 路径格式化;
|
||
- 随机种子设置;
|
||
- 设备选择;
|
||
- MAE、RMSE 等指标函数;
|
||
- 检查点保存和加载;
|
||
- 简单日志函数;
|
||
- 图像扩展名过滤。
|
||
|
||
不要把 `utils.py` 变成核心逻辑杂物间。数据集逻辑放在 `dataset.py`,模型创建放在 `model.py`,实验流程放在对应脚本里。
|
||
|
||
## 执行原则
|
||
|
||
对于耗时脚本,应始终保留两级执行入口:
|
||
|
||
1. 小规模测试函数:快速验证同一套逻辑是否能跑通。
|
||
2. `main()` 函数:完整训练、评估、量化或统计。
|
||
|
||
大规模实验前,先运行对应小规模测试函数。是否切换到完整 `main()` 由用户决定。
|
||
|
||
推荐入口形式:
|
||
|
||
```python
|
||
def test_train():
|
||
...
|
||
|
||
|
||
def main():
|
||
...
|
||
|
||
|
||
if __name__ == "__main__":
|
||
test_train()
|
||
# main()
|
||
```
|
||
|
||
## 后续协作规则
|
||
|
||
- 修改项目前,先阅读本文件。
|
||
- 后续沟通、总结和错误说明使用中文。
|
||
- 保持小项目结构,除非用户明确要求改成更复杂的包结构。
|
||
- 不要未经允许安装依赖。
|
||
- 在用户提供数据集格式之前,不要假设数据布局已经固定。
|
||
- 路径处理始终使用 `pathlib`。
|
||
- 面向外部展示、保存或打印的路径统一使用 `/`。
|
||
- 可运行脚本默认应安全,优先执行小规模测试,而不是直接启动完整训练。
|
||
- 编辑已有用户工作时,不要回退无关改动。
|
||
- 在多个合理方案之间不确定时,选择最简单、最符合本文档和现有代码的方案。
|
||
|