DefocusEstimate/dataset.py
2026-05-17 21:06:33 +08:00

262 lines
8.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
import random
try:
import numpy as np
except ImportError as exc:
raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy或告诉我改用其他路线。") from exc
try:
from PIL import Image
except ImportError as exc:
raise ImportError("缺少 Pillow。请在 torch271 环境中安装 pillow或告诉我改用其他路线。") from exc
try:
import torch
from torch.utils.data import Dataset
except ImportError as exc:
raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271并已配置 torch271+cu126。") from exc
# Dataset root directory. Change it here, or pass ``data_root`` explicitly to
# the functions / constructors that accept it.
DATA_ROOT = Path("E:/Datasets/SimpleAFDataset/roi_with_label")
# Names of the three field-level subsets, in the order they are built and reported.
SPLIT_NAMES = (
    "train",
    "valid",
    "test",
)
# File extensions treated as images when scanning ROI directories.
IMAGE_SUFFIXES = {
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
}
def as_posix_path(path):
    """Return ``path`` as a string that uses forward slashes as the separator."""
    normalized = Path(path)
    return normalized.as_posix()
def parse_label_from_path(image_path):
    """Parse the defocus-distance label encoded in an image file name.

    The label is the file name with its (last) suffix removed, converted with
    ``float`` -- e.g. ``.../roi001/-1.25.jpg`` gives ``-1.25``.

    Raises:
        ValueError: if the stem is not a valid float literal.
    """
    try:
        stem_value = float(Path(image_path).stem)
    except ValueError as exc:
        raise ValueError(f"无法从文件名解析标签:{as_posix_path(image_path)}") from exc
    else:
        return stem_value
def default_image_transform(image, image_size=224):
    """Convert a PIL image into a float32 tensor of shape ``[3, H, W]``.

    The image is converted to RGB and, unless ``image_size`` is None, resized
    to ``image_size x image_size`` with bilinear resampling. Pixel values are
    divided by 255, so 8-bit input ends up in ``[0, 1]``, and channels are
    moved to the first axis.
    """
    rgb_image = image.convert("RGB")
    if image_size is not None:
        rgb_image = rgb_image.resize((image_size, image_size), Image.BILINEAR)
    hwc = np.asarray(rgb_image, dtype=np.float32) / 255.0
    chw = hwc.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(chw)
def find_field_dirs(data_root=DATA_ROOT):
    """List every ``sampleNNN/fieldNNN`` directory under ``data_root``.

    Both levels must match a three-digit name pattern and be directories.
    Results are ordered by sorted sample directory, then sorted field directory.

    Raises:
        FileNotFoundError: if ``data_root`` does not exist.
    """
    root = Path(data_root)
    if not root.exists():
        raise FileNotFoundError(f"数据集根目录不存在:{as_posix_path(root)}")
    result = []
    for sample_dir in sorted(root.glob("sample[0-9][0-9][0-9]")):
        if sample_dir.is_dir():
            result.extend(
                candidate
                for candidate in sorted(sample_dir.glob("field[0-9][0-9][0-9]"))
                if candidate.is_dir()
            )
    return result
def collect_images_from_fields(field_dirs):
    """Expand field directories into parallel lists of image paths and labels.

    Fields are visited in sorted order. Within each field, every ``roiNNN``
    sub-directory is scanned in sorted order. Regular files whose lower-cased
    suffix is in ``IMAGE_SUFFIXES`` are collected in sorted order, and each
    label is parsed from the file name by ``parse_label_from_path``.

    Returns:
        ``(image_paths, labels)``: a list of ``Path`` objects and a list of
        floats of the same length.
    """
    image_paths, labels = [], []
    for field_dir in sorted(map(Path, field_dirs)):
        for roi_dir in sorted(field_dir.glob("roi[0-9][0-9][0-9]")):
            if not roi_dir.is_dir():
                continue
            for entry in sorted(roi_dir.iterdir()):
                is_image = entry.is_file() and entry.suffix.lower() in IMAGE_SUFFIXES
                if is_image:
                    image_paths.append(entry)
                    labels.append(parse_label_from_path(entry))
    return image_paths, labels
def split_counts(total_count, train_ratio=0.8, valid_ratio=0.1):
    """Compute how many fields go to train / valid / test.

    With the default ratios this is an 8:1:1 split. The train and valid
    counts are ``int(total_count * ratio)``, and test takes the remainder.

    With at least three fields, every subset is guaranteed at least one field.
    An empty subset borrows one field from the currently largest subset. For
    the default ratios that donor is always train. With one or two fields,
    train is filled first, then valid.

    Args:
        total_count: number of fields to split.
        train_ratio: fraction of fields used for training, in ``[0, 1]``.
        valid_ratio: fraction of fields used for validation, in ``[0, 1]``.
            ``train_ratio + valid_ratio`` must not exceed 1.

    Returns:
        ``(train_count, valid_count, test_count)``: non-negative ints that
        sum to ``max(total_count, 0)``.

    Raises:
        ValueError: if a ratio is outside ``[0, 1]`` or the two ratios sum to
            more than 1. (Previously such inputs silently produced negative
            counts.)
    """
    if not (0.0 <= train_ratio <= 1.0 and 0.0 <= valid_ratio <= 1.0):
        raise ValueError(
            f"train_ratio 和 valid_ratio 必须在 [0, 1] 范围内,当前为:{train_ratio}, {valid_ratio}"
        )
    # Small tolerance so float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio + valid_ratio > 1.0 + 1e-9:
        raise ValueError(
            f"train_ratio + valid_ratio 不能超过 1,当前为:{train_ratio + valid_ratio}"
        )
    if total_count <= 0:
        return 0, 0, 0
    if total_count == 1:
        return 1, 0, 0
    if total_count == 2:
        return 1, 1, 0
    counts = [int(total_count * train_ratio), int(total_count * valid_ratio)]
    counts.append(total_count - counts[0] - counts[1])
    # Give each empty subset one field, taken from the largest subset. Since
    # total_count >= 3, the donor always holds at least 2 fields and never
    # becomes empty itself. For the default ratios train is the largest,
    # which matches the original "take it from train" behaviour.
    for index in range(len(counts)):
        if counts[index] == 0:
            donor = max(range(len(counts)), key=counts.__getitem__)
            counts[donor] -= 1
            counts[index] = 1
    return tuple(counts)
def split_fields(field_dirs, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Partition field directories into train / valid / test subsets.

    The split is done at field level, so all ROIs of one field stay in the
    same subset. Fields are sorted first and then shuffled with
    ``random.Random(seed)``, so the same inputs and seed always give the same
    partition. Subset sizes come from ``split_counts``. Each returned list is
    sorted.

    Returns:
        dict with keys ``"train"``, ``"valid"`` and ``"test"``, each a list of
        ``Path`` objects.
    """
    shuffled = sorted(Path(entry) for entry in field_dirs)
    random.Random(seed).shuffle(shuffled)
    n_train, n_valid, _ = split_counts(
        len(shuffled),
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )
    valid_end = n_train + n_valid
    return {
        "train": sorted(shuffled[:n_train]),
        "valid": sorted(shuffled[n_train:valid_end]),
        "test": sorted(shuffled[valid_end:]),
    }
def make_split_lists(data_root=DATA_ROOT, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """
    Split the dataset at field level and collect each subset's images and labels.

    Returned structure:
    {
        "train": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
        "valid": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
        "test":  {"image_paths": [...], "labels": [...], "field_dirs": [...]},
    }
    """
    fields_by_split = split_fields(
        find_field_dirs(data_root),
        seed=seed,
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )
    result = {}
    for name in SPLIT_NAMES:
        fields = fields_by_split[name]
        paths, targets = collect_images_from_fields(fields)
        result[name] = {"image_paths": paths, "labels": targets, "field_dirs": fields}
    return result
def get_split_items(split="train", data_root=DATA_ROOT, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Return ``(image_paths, labels)`` for one split, ready to pass to ``DefocusDataset``.

    Args:
        split: one of ``SPLIT_NAMES``.
        data_root: dataset root directory.
        seed: shuffle seed for the field-level split. Using the same seed and
            ratios across calls yields the same partition.
        train_ratio, valid_ratio: forwarded to ``make_split_lists``. The
            defaults equal that function's defaults, so existing callers are
            unaffected. Before these parameters existed, a custom split could
            not be reproduced through this helper.

    Raises:
        ValueError: if ``split`` is not a known split name.
    """
    if split not in SPLIT_NAMES:
        raise ValueError(f"split 必须是 {SPLIT_NAMES} 之一,当前为:{split}")
    split_data = make_split_lists(
        data_root=data_root,
        seed=seed,
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )
    items = split_data[split]
    return items["image_paths"], items["labels"]
class DefocusDataset(Dataset):
    """Regression dataset pairing each image file with its defocus-distance label.

    An item is ``(image_tensor, target)``, where ``target`` is a float32 tensor
    of shape ``[1]``. With ``return_path=True``, the POSIX-style path string is
    appended as a third element.
    """

    def __init__(self, image_paths, labels, transform=None, return_path=False):
        """
        Args:
            image_paths: sized iterable of image paths (str or Path).
            labels: sized iterable of numeric labels, same length as ``image_paths``.
            transform: callable applied to the opened PIL image. When falsy,
                ``default_image_transform`` is used.
            return_path: also return the image path from ``__getitem__``.

        Raises:
            ValueError: if ``image_paths`` and ``labels`` differ in length.
        """
        if len(image_paths) != len(labels):
            raise ValueError("image_paths 和 labels 的长度不一致。")
        self.image_paths = list(map(Path, image_paths))
        self.labels = list(map(float, labels))
        self.transform = transform if transform else default_image_transform
        self.return_path = return_path

    @classmethod
    def from_split(cls, split="train", data_root=DATA_ROOT, seed=2026, transform=None, return_path=False):
        """Build the dataset for one field-level split of ``data_root``."""
        paths, targets = get_split_items(split=split, data_root=data_root, seed=seed)
        return cls(paths, targets, transform=transform, return_path=return_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        path, value = self.image_paths[index], self.labels[index]
        # The file is closed as soon as the transform has produced a tensor.
        with Image.open(path) as pil_image:
            tensor = self.transform(pil_image)
        target = torch.tensor([value], dtype=torch.float32)
        if not self.return_path:
            return tensor, target
        return tensor, target, as_posix_path(path)
def print_split_summary(split_data):
    """Print per-split field count, image count, label range and up to three examples.

    ``split_data`` is a mapping from each name in ``SPLIT_NAMES`` to a dict
    with ``"image_paths"``, ``"labels"`` and ``"field_dirs"``, i.e. the
    structure built by ``make_split_lists``.
    """
    for split_name in SPLIT_NAMES:
        items = split_data[split_name]
        labels = items["labels"]
        field_count = len(items["field_dirs"])
        image_count = len(items["image_paths"])
        if labels:
            # Keep an explicit separator between min and max. Printed back to
            # back, -3 and 1.5 would read as "-31.5".
            label_text = f"标签范围:{min(labels):.6g} ~ {max(labels):.6g}"
        else:
            label_text = "标签范围:无"
        print(f"{split_name}: field 数={field_count}, 图像数={image_count}, {label_text}")
        for image_path in items["image_paths"][:3]:
            print(f" 示例:{as_posix_path(image_path)} -> {parse_label_from_path(image_path):.6g}")
def test_dataset():
    """Small smoke test: split DATA_ROOT, print the summary and load one training image.

    Prints a message and returns early when DATA_ROOT does not exist or the
    training split contains no images.
    """
    # Separator added so the path no longer runs straight into "DATA_ROOT".
    print(f"当前 DATA_ROOT:{as_posix_path(DATA_ROOT)}")
    if not DATA_ROOT.exists():
        print("数据集根目录还不存在,先跳过真实数据读取测试。")
        return
    split_data = make_split_lists(DATA_ROOT)
    print_split_summary(split_data)
    train_paths = split_data["train"]["image_paths"]
    train_labels = split_data["train"]["labels"]
    if not train_paths:
        print("训练集没有找到图像,请检查目录是否符合 sampleXXX/fieldXXX/roiXXX/{label}.jpg。")
        return
    # A one-sample dataset, so only a single image is decoded.
    dataset = DefocusDataset(train_paths[:1], train_labels[:1], return_path=True)
    image_tensor, target, image_path = dataset[0]
    print(f"读取样本:{image_path}")
    print(f"图像张量形状:{tuple(image_tensor.shape)}")
    print(f"标签张量形状:{tuple(target.shape)},标签值:{target.item():.6g}")
def main():
    """Split the dataset under DATA_ROOT and print the per-split summary."""
    print_split_summary(make_split_lists(DATA_ROOT))
if __name__ == "__main__":
    # By default run the smoke test against the real dataset. Call main()
    # instead to print only the split summary.
    test_dataset()