"""Field-level train/valid/test splitting and a PyTorch dataset for defocus regression.

Expected on-disk layout::

    <data_root>/sampleXXX/fieldXXX/roiXXX/<label>.<ext>

``XXX`` is three digits, ``<ext>`` is one of ``IMAGE_SUFFIXES`` and the file
stem ``<label>`` is parsed as a float (the defocus-distance label).
Splitting is done per *field* directory so that images of one field never end
up in two different subsets.
"""

from pathlib import Path
import random

try:
    import numpy as np
except ImportError as exc:
    raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc

try:
    from PIL import Image
except ImportError as exc:
    raise ImportError("缺少 Pillow。请在 torch271 环境中安装 pillow,或告诉我改用其他路线。") from exc

try:
    import torch
    from torch.utils.data import Dataset
except ImportError as exc:
    raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc


# Dataset root directory. Change it here, or pass ``data_root`` explicitly to
# the helpers / ``DefocusDataset.from_split``.
DATA_ROOT = Path("E:/Datasets/SimpleAFDataset/roi_with_label")
SPLIT_NAMES = ("train", "valid", "test")
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def as_posix_path(path):
    """Return *path* as a string with forward slashes."""
    return Path(path).as_posix()


def parse_label_from_path(image_path):
    """Parse the defocus-distance label from the file stem of *image_path*.

    Raises:
        ValueError: if the stem (file name without its suffix) is not a float.
    """
    try:
        return float(Path(image_path).stem)
    except ValueError as exc:
        raise ValueError(f"无法从文件名解析标签:{as_posix_path(image_path)}") from exc


def default_image_transform(image, image_size=224):
    """Convert a PIL image into a float32 tensor of shape ``[3, H, W]`` in ``[0, 1]``.

    Args:
        image: PIL image; it is converted to RGB first.
        image_size: side length of the square output; ``None`` keeps the
            original size.
    """
    image = image.convert("RGB")
    if image_size is not None:
        image = image.resize((image_size, image_size), Image.BILINEAR)

    array = np.asarray(image, dtype=np.float32) / 255.0
    # HWC -> CHW. The transpose is only a strided view, so copy it into
    # C-contiguous memory before handing it to torch.
    array = np.ascontiguousarray(np.transpose(array, (2, 0, 1)))
    return torch.from_numpy(array)


def find_field_dirs(data_root=DATA_ROOT):
    """Return every ``sampleXXX/fieldXXX`` directory under *data_root*, sorted.

    Raises:
        FileNotFoundError: if *data_root* does not exist.
    """
    data_root = Path(data_root)
    if not data_root.exists():
        raise FileNotFoundError(f"数据集根目录不存在:{as_posix_path(data_root)}")

    field_dirs = []
    for sample_dir in sorted(data_root.glob("sample[0-9][0-9][0-9]")):
        if not sample_dir.is_dir():
            continue
        field_dirs.extend(
            field_dir
            for field_dir in sorted(sample_dir.glob("field[0-9][0-9][0-9]"))
            if field_dir.is_dir()
        )
    return field_dirs


def collect_images_from_fields(field_dirs):
    """Expand field directories into parallel lists of image paths and labels.

    Only files directly inside ``roiXXX`` sub-directories whose suffix is in
    ``IMAGE_SUFFIXES`` (case-insensitive) are collected.

    Returns:
        ``(image_paths, labels)`` with one float label per image path.

    Raises:
        ValueError: if an image file name cannot be parsed as a float label.
    """
    image_paths = []
    labels = []

    for field_dir in sorted(Path(path) for path in field_dirs):
        for roi_dir in sorted(field_dir.glob("roi[0-9][0-9][0-9]")):
            if not roi_dir.is_dir():
                continue
            for image_path in sorted(roi_dir.iterdir()):
                if not image_path.is_file():
                    continue
                if image_path.suffix.lower() not in IMAGE_SUFFIXES:
                    continue
                image_paths.append(image_path)
                labels.append(parse_label_from_path(image_path))

    return image_paths, labels


def split_counts(total_count, train_ratio=0.8, valid_ratio=0.1):
    """Compute how many fields go to train / valid / test.

    The test share is whatever remains after train and valid. With three or
    more fields, valid and test are each guaranteed at least one field.

    Returns:
        ``(train_count, valid_count, test_count)``; all values are
        non-negative and sum to ``max(total_count, 0)``.

    Raises:
        ValueError: if a ratio is negative or the two ratios sum to more than 1.
    """
    # Small tolerance so that float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio < 0 or valid_ratio < 0 or train_ratio + valid_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio 和 valid_ratio 必须非负且两者之和不超过 1,当前为:{train_ratio}, {valid_ratio}"
        )

    if total_count <= 0:
        return 0, 0, 0
    if total_count == 1:
        return 1, 0, 0
    if total_count == 2:
        return 1, 1, 0

    train_count = min(int(total_count * train_ratio), total_count)
    valid_count = min(int(total_count * valid_ratio), total_count - train_count)
    counts = {
        "train": train_count,
        "valid": valid_count,
        "test": total_count - train_count - valid_count,
    }

    for needy in ("valid", "test"):
        if counts[needy] != 0:
            continue
        # Borrow from train while that leaves train non-empty; otherwise take
        # from the largest bucket. With total_count >= 3 that bucket holds at
        # least two fields, so no count can drop below zero (a negative count
        # would make the slices in split_fields overlap).
        donor = "train" if counts["train"] > 1 else max(counts, key=counts.get)
        counts[donor] -= 1
        counts[needy] = 1

    return counts["train"], counts["valid"], counts["test"]


def split_fields(field_dirs, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Split field directories into train / valid / test subsets.

    The fields are sorted before the seeded shuffle, so the result depends
    only on the set of fields and *seed*, not on the input order.

    Returns:
        Dict with keys ``"train"``, ``"valid"`` and ``"test"``, each a sorted
        list of ``Path`` objects; the three lists are disjoint.
    """
    field_dirs = sorted(Path(path) for path in field_dirs)
    random.Random(seed).shuffle(field_dirs)

    train_count, valid_count, _ = split_counts(
        len(field_dirs),
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )
    valid_end = train_count + valid_count

    return {
        "train": sorted(field_dirs[:train_count]),
        "valid": sorted(field_dirs[train_count:valid_end]),
        "test": sorted(field_dirs[valid_end:]),
    }


def make_split_lists(data_root=DATA_ROOT, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Split the dataset per field and list the images and labels of each subset.

    Returns:
        A dict of the form::

            {
                "train": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
                "valid": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
                "test": {"image_paths": [...], "labels": [...], "field_dirs": [...]},
            }
    """
    field_dirs = find_field_dirs(data_root)
    split_field_map = split_fields(
        field_dirs,
        seed=seed,
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )

    split_data = {}
    for split_name in SPLIT_NAMES:
        image_paths, labels = collect_images_from_fields(split_field_map[split_name])
        split_data[split_name] = {
            "image_paths": image_paths,
            "labels": labels,
            "field_dirs": split_field_map[split_name],
        }
    return split_data


def get_split_items(split="train", data_root=DATA_ROOT, seed=2026, train_ratio=0.8, valid_ratio=0.1):
    """Return ``(image_paths, labels)`` of one subset, ready for ``DefocusDataset``.

    Raises:
        ValueError: if *split* is not one of ``SPLIT_NAMES``.
    """
    if split not in SPLIT_NAMES:
        raise ValueError(f"split 必须是 {SPLIT_NAMES} 之一,当前为:{split}")

    split_data = make_split_lists(
        data_root=data_root,
        seed=seed,
        train_ratio=train_ratio,
        valid_ratio=valid_ratio,
    )
    items = split_data[split]
    return items["image_paths"], items["labels"]


class DefocusDataset(Dataset):
    """Defocus-distance regression dataset.

    Each item is ``(image_tensor, target)``, or
    ``(image_tensor, target, posix_path)`` when ``return_path`` is true.
    ``target`` is a float32 tensor of shape ``[1]``.
    """

    def __init__(self, image_paths, labels, transform=None, return_path=False):
        """Store the samples.

        Args:
            image_paths: image file paths.
            labels: one numeric label per image path.
            transform: callable applied to the opened PIL image; ``None``
                selects ``default_image_transform``.
            return_path: also return the image path from ``__getitem__``.

        Raises:
            ValueError: if *image_paths* and *labels* differ in length.
        """
        if len(image_paths) != len(labels):
            raise ValueError("image_paths 和 labels 的长度不一致。")

        self.image_paths = [Path(path) for path in image_paths]
        self.labels = [float(label) for label in labels]
        # Compare against None instead of relying on truthiness: a callable
        # that defines __len__ (e.g. an empty container module) is falsy and
        # would otherwise be silently replaced by the default transform.
        self.transform = default_image_transform if transform is None else transform
        self.return_path = return_path

    @classmethod
    def from_split(
        cls,
        split="train",
        data_root=DATA_ROOT,
        seed=2026,
        transform=None,
        return_path=False,
        train_ratio=0.8,
        valid_ratio=0.1,
    ):
        """Build the dataset for one subset of the field-level split."""
        image_paths, labels = get_split_items(
            split=split,
            data_root=data_root,
            seed=seed,
            train_ratio=train_ratio,
            valid_ratio=valid_ratio,
        )
        return cls(image_paths, labels, transform=transform, return_path=return_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label = self.labels[index]

        # The transform runs inside the ``with`` block, while the image file
        # is still open.
        with Image.open(image_path) as image:
            image_tensor = self.transform(image)

        target = torch.tensor([label], dtype=torch.float32)
        if self.return_path:
            return image_tensor, target, as_posix_path(image_path)
        return image_tensor, target


def print_split_summary(split_data):
    """Print field/image counts, label range and up to three examples per subset.

    Args:
        split_data: the structure returned by ``make_split_lists``.
    """
    for split_name in SPLIT_NAMES:
        items = split_data[split_name]
        labels = items["labels"]
        field_count = len(items["field_dirs"])
        image_count = len(items["image_paths"])

        if labels:
            label_min = min(labels)
            label_max = max(labels)
            label_text = f"标签范围:{label_min:.6g} 到 {label_max:.6g}"
        else:
            label_text = "标签范围:无"

        print(f"{split_name}: field 数={field_count}, 图像数={image_count}, {label_text}")
        for image_path in items["image_paths"][:3]:
            print(f" 示例:{as_posix_path(image_path)} -> {parse_label_from_path(image_path):.6g}")


def test_dataset():
    """Smoke test: print the split summary and read the first training image.

    Returns early (after printing a notice) when ``DATA_ROOT`` does not exist
    or the training subset is empty.
    """
    print(f"当前 DATA_ROOT:{as_posix_path(DATA_ROOT)}")
    if not DATA_ROOT.exists():
        print("数据集根目录还不存在,先跳过真实数据读取测试。")
        return

    split_data = make_split_lists(DATA_ROOT)
    print_split_summary(split_data)

    train_paths = split_data["train"]["image_paths"]
    train_labels = split_data["train"]["labels"]
    if not train_paths:
        print("训练集没有找到图像,请检查目录是否符合 sampleXXX/fieldXXX/roiXXX/{label}.jpg。")
        return

    dataset = DefocusDataset(train_paths[:1], train_labels[:1], return_path=True)
    image_tensor, target, image_path = dataset[0]
    print(f"读取样本:{image_path}")
    print(f"图像张量形状:{tuple(image_tensor.shape)}")
    print(f"标签张量形状:{tuple(target.shape)},标签值:{target.item():.6g}")


def main():
    """Print the split summary for ``DATA_ROOT``."""
    split_data = make_split_lists(DATA_ROOT)
    print_split_summary(split_data)


if __name__ == "__main__":
    # Runs the smoke test by default; call main() instead to only print the
    # split summary.
    test_dataset()