# Defocus-distance regression dataset utilities (Python source, 262 lines, 8.7 KiB).
from pathlib import Path
import random

# Third-party packages are imported inside try/except so that a missing
# dependency is re-raised as an ImportError carrying an installation hint.
try:
    import numpy as np
except ImportError as err:
    raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from err

try:
    from PIL import Image
except ImportError as err:
    raise ImportError("缺少 Pillow。请在 torch271 环境中安装 pillow,或告诉我改用其他路线。") from err

try:
    import torch
    from torch.utils.data import Dataset
except ImportError as err:
    raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from err


# Dataset root directory. Change it here, or pass ``data_root`` explicitly
# to the functions / constructor below.
DATA_ROOT = Path("E:/Datasets/SimpleAFDataset/roi_with_label")

# Names of the three subsets, in the order they are reported.
SPLIT_NAMES = ("train", "valid", "test")

# Lower-case file suffixes that are treated as images.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def as_posix_path(path):
    """Return ``path`` as a string that uses forward slashes as separators."""
    normalized = Path(path)
    return normalized.as_posix()


def parse_label_from_path(image_path):
    """Parse the defocus-distance label from an image file name.

    The label is the file name with its final suffix removed, converted
    with ``float``.

    Raises:
        ValueError: If that string cannot be converted to a float.
    """
    stem = Path(image_path).stem
    try:
        label = float(stem)
    except ValueError as exc:
        raise ValueError(f"无法从文件名解析标签:{as_posix_path(image_path)}") from exc
    return label


def default_image_transform(image, image_size=224):
    """Convert a PIL image into a float32 tensor of shape ``[3, H, W]``.

    Args:
        image: PIL image; it is converted to RGB first.
        image_size: Side length of the square output. ``None`` keeps the
            original size.

    Returns:
        A ``torch.Tensor`` in channel-first layout whose values are the
        pixel values divided by 255.
    """
    rgb = image.convert("RGB")
    if image_size is not None:
        rgb = rgb.resize((image_size, image_size), Image.BILINEAR)

    hwc = np.asarray(rgb, dtype=np.float32)
    # Scale, then move the channel axis to the front: HWC -> CHW.
    chw = (hwc / 255.0).transpose(2, 0, 1)
    return torch.from_numpy(chw)


def find_field_dirs(data_root=DATA_ROOT):
    """Collect every ``sampleXXX/fieldXXX`` directory under ``data_root``.

    ``XXX`` stands for exactly three digits. Sample directories are visited
    in sorted order, and the field directories of each sample are sorted too.

    Raises:
        FileNotFoundError: If ``data_root`` does not exist.
    """
    root = Path(data_root)
    if not root.exists():
        raise FileNotFoundError(f"数据集根目录不存在:{as_posix_path(root)}")

    found = []
    for sample in sorted(root.glob("sample[0-9][0-9][0-9]")):
        if sample.is_dir():
            candidates = sorted(sample.glob("field[0-9][0-9][0-9]"))
            # Keep directories only; a plain file may also match the pattern.
            found.extend(field for field in candidates if field.is_dir())
    return found


def collect_images_from_fields(field_dirs):
    """Expand field directories into parallel lists of image paths and labels.

    Only files located in ``roiXXX`` sub-directories (three digits) whose
    lower-cased suffix is in ``IMAGE_SUFFIXES`` are kept. Fields, ROI
    directories and files are each visited in sorted order.

    Returns:
        ``(image_paths, labels)`` where ``labels[i]`` is parsed from the
        file name of ``image_paths[i]``.
    """
    paths = []
    targets = []

    for field in sorted(map(Path, field_dirs)):
        for roi in sorted(field.glob("roi[0-9][0-9][0-9]")):
            if not roi.is_dir():
                continue
            for entry in sorted(roi.iterdir()):
                is_image = entry.is_file() and entry.suffix.lower() in IMAGE_SUFFIXES
                if is_image:
                    paths.append(entry)
                    targets.append(parse_label_from_path(entry))

    return paths, targets


def split_counts(total_count, train_ratio=0.8, valid_ratio=0.1):
    """Compute how many fields go to the train / valid / test subsets.

    The default ratios give an 8:1:1 split. With three or more fields the
    validation and test subsets always receive at least one field each; the
    missing field is taken from train, or from the other non-empty subset
    when train is already empty, so no count can become negative.

    Args:
        total_count: Number of fields to distribute.
        train_ratio: Fraction of fields for training, in ``[0, 1]``.
        valid_ratio: Fraction of fields for validation, in ``[0, 1]``.
            ``train_ratio + valid_ratio`` must not exceed 1.

    Returns:
        ``(train_count, valid_count, test_count)``; all values are
        non-negative and sum to ``max(total_count, 0)``.

    Raises:
        ValueError: If a ratio is outside ``[0, 1]`` or the two ratios sum
            to more than 1.
    """
    ratios_in_range = 0.0 <= train_ratio <= 1.0 and 0.0 <= valid_ratio <= 1.0
    # Small tolerance so that pairs such as 0.7 + 0.3 are not rejected
    # because of floating-point rounding.
    if not ratios_in_range or train_ratio + valid_ratio > 1.0 + 1e-9:
        raise ValueError(
            "train_ratio 和 valid_ratio 必须在 [0, 1] 范围内,且两者之和不能超过 1,"
            f"当前为:train_ratio={train_ratio}, valid_ratio={valid_ratio}"
        )

    if total_count <= 0:
        return 0, 0, 0
    if total_count == 1:
        return 1, 0, 0
    if total_count == 2:
        return 1, 1, 0

    train_count = int(total_count * train_ratio)
    # Cap so that rounding can never push the test count below zero.
    valid_count = min(int(total_count * valid_ratio), total_count - train_count)
    test_count = total_count - train_count - valid_count

    if valid_count == 0:
        valid_count = 1
        # Prefer taking the field from train; fall back to test when train
        # is empty (test then holds every field, so it can spare one).
        if train_count > 0:
            train_count -= 1
        else:
            test_count -= 1
    if test_count == 0:
        test_count = 1
        # Same rule; when train is empty here, valid holds every field.
        if train_count > 0:
            train_count -= 1
        else:
            valid_count -= 1

    return train_count, valid_count, test_count


def split_fields(field_dirs, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Split field directories into train / valid / test at the field level.

    The directories are sorted, shuffled with ``random.Random(seed)`` and
    then cut according to ``split_counts``, so the same set of directories
    and the same seed always give the same split.

    Returns:
        Dict with the keys ``"train"``, ``"valid"`` and ``"test"``; each
        value is a sorted list of ``Path`` objects.
    """
    shuffled = sorted(Path(entry) for entry in field_dirs)
    random.Random(seed).shuffle(shuffled)

    n_train, n_valid, _ = split_counts(
        len(shuffled),
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )
    valid_end = n_train + n_valid

    return {
        "train": sorted(shuffled[:n_train]),
        "valid": sorted(shuffled[n_train:valid_end]),
        "test": sorted(shuffled[valid_end:]),
    }


def make_split_lists(data_root=DATA_ROOT, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Split the dataset at the field level and list the images of each subset.

    Returns:
        A dict of the form::

            {
                "train": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
                "valid": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
                "test": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
            }
    """
    fields_by_split = split_fields(
        find_field_dirs(data_root),
        seed=seed,
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )

    result = {}
    for name in SPLIT_NAMES:
        fields = fields_by_split[name]
        paths, targets = collect_images_from_fields(fields)
        result[name] = {
            "image_paths": paths,
            "labels": targets,
            "field_dirs": fields,
        }
    return result


def get_split_items(split="train", data_root=DATA_ROOT, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Return the image paths and labels of one subset.

    The two returned lists can be passed directly to ``DefocusDataset``.

    Args:
        split: One of ``SPLIT_NAMES``.
        data_root: Dataset root directory.
        seed: Seed of the field-level shuffle.
        train_ratio: Fraction of fields used for training; forwarded to
            ``make_split_lists`` (the default matches its default).
        valid_ratio: Fraction of fields used for validation; forwarded to
            ``make_split_lists`` (the default matches its default).

    Returns:
        ``(image_paths, labels)`` for the requested subset.

    Raises:
        ValueError: If ``split`` is not one of ``SPLIT_NAMES``.
    """
    if split not in SPLIT_NAMES:
        raise ValueError(f"split 必须是 {SPLIT_NAMES} 之一,当前为:{split}")

    # Forward the ratios so that a non-default split built elsewhere with
    # make_split_lists can be reproduced through this helper as well.
    split_data = make_split_lists(
        data_root=data_root,
        seed=seed,
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )
    items = split_data[split]
    return items["image_paths"], items["labels"]


class DefocusDataset(Dataset):
    """Defocus-distance regression dataset.

    Each item is ``(image_tensor, target)``, or
    ``(image_tensor, target, path)`` when ``return_path`` is true.
    ``target`` is a float32 tensor of shape ``[1]`` and ``path`` is the image
    path as a forward-slash string.
    """

    def __init__(self, image_paths, labels, transform=None, return_path=False):
        """Store the samples.

        Args:
            image_paths: Iterable of image paths.
            labels: Iterable of labels, one per image; each is converted
                with ``float``.
            transform: Callable applied to the opened PIL image. ``None``
                selects ``default_image_transform``.
            return_path: Whether ``__getitem__`` also returns the path.

        Raises:
            ValueError: If the two iterables differ in length.
        """
        # Materialize first so that generators and other iterables without
        # __len__ are accepted too.
        paths = [Path(path) for path in image_paths]
        values = [float(label) for label in labels]
        if len(paths) != len(values):
            raise ValueError("image_paths 和 labels 的长度不一致。")

        self.image_paths = paths
        self.labels = values
        # Explicit None check: a callable that defines __len__ or __bool__
        # can be falsy and would otherwise be silently replaced.
        self.transform = default_image_transform if transform is None else transform
        self.return_path = return_path

    @classmethod
    def from_split(cls, split="train", data_root=DATA_ROOT, seed=2026, transform=None, return_path=False):
        """Build the dataset for one subset as returned by ``get_split_items``."""
        image_paths, labels = get_split_items(split=split, data_root=data_root, seed=seed)
        return cls(image_paths, labels, transform=transform, return_path=return_path)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load and transform the image at ``index`` and return it with its target."""
        image_path = self.image_paths[index]
        label = self.labels[index]

        # The transform runs inside the ``with`` block, while the image file
        # is still open.
        with Image.open(image_path) as image:
            image_tensor = self.transform(image)

        target = torch.tensor([label], dtype=torch.float32)

        if self.return_path:
            return image_tensor, target, as_posix_path(image_path)
        return image_tensor, target


def print_split_summary(split_data):
    """Print a short report per subset to check the split at a glance.

    For every name in ``SPLIT_NAMES`` this prints the number of fields, the
    number of images, the label range, and up to three example images with
    the label parsed from their file names.
    """
    for name in SPLIT_NAMES:
        entry = split_data[name]
        targets = entry["labels"]
        n_fields = len(entry["field_dirs"])
        n_images = len(entry["image_paths"])

        label_text = (
            f"标签范围:{min(targets):.6g} 到 {max(targets):.6g}"
            if targets
            else "标签范围:无"
        )

        print(f"{name}: field 数={n_fields}, 图像数={n_images}, {label_text}")
        for sample_path in entry["image_paths"][:3]:
            print(f" 示例:{as_posix_path(sample_path)} -> {parse_label_from_path(sample_path):.6g}")


def test_dataset():
    """Smoke test: print the split summary and read the first training image.

    Returns early, after printing a message, when ``DATA_ROOT`` does not
    exist or the training subset contains no images.
    """
    print(f"当前 DATA_ROOT:{as_posix_path(DATA_ROOT)}")
    if not DATA_ROOT.exists():
        print("数据集根目录还不存在,先跳过真实数据读取测试。")
        return

    split_data = make_split_lists(DATA_ROOT)
    print_split_summary(split_data)

    train_items = split_data["train"]
    train_paths, train_labels = train_items["image_paths"], train_items["labels"]
    if not train_paths:
        print("训练集没有找到图像,请检查目录是否符合 sampleXXX/fieldXXX/roiXXX/{label}.jpg。")
        return

    # One sample is enough to exercise image loading and the default transform.
    probe = DefocusDataset(train_paths[:1], train_labels[:1], return_path=True)
    image_tensor, target, image_path = probe[0]
    print(f"读取样本:{image_path}")
    print(f"图像张量形状:{tuple(image_tensor.shape)}")
    print(f"标签张量形状:{tuple(target.shape)},标签值:{target.item():.6g}")


def main():
    """Build the split for ``DATA_ROOT`` and print its summary."""
    print_split_summary(make_split_lists(DATA_ROOT))


if __name__ == "__main__":
    # Run the smoke test by default; switch to main() to print only the
    # split summary.
    test_dataset()
    # main()