import pandas as pd from pathlib import Path from torch.utils.data import Dataset from PIL import Image import torch from torchvision import transforms class MOAFDataset(Dataset): def __init__(self, dataset_root, tvt='train', objectives_params_list=None): """ dataset_root: 根目录(Pathable) tvt: 'train'|'val'|'test'(用于选择 transform) objectives_params_list: 列表,包含要加载的物镜目录名,例如 ["10x-0.25-1.0000", ...] """ super().__init__() self.dataset_root = Path(dataset_root) self.tvt = tvt if objectives_params_list is None: self.objectives_params_list = ["10x-0.25-1.0000"] else: self.objectives_params_list = objectives_params_list # 根据 tvt 选择 transform if self.tvt == "train": self.transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) else: self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 直接在构造函数中读取并合并 csv all_dfs = [] for param_dir in self.objectives_params_list: csv_file_path = self.dataset_root / "tvtinfo" / param_dir / f"{self.tvt}.csv" if not csv_file_path.exists(): raise FileNotFoundError(f"CSV not found: {csv_file_path}") df = pd.read_csv(csv_file_path) all_dfs.append(df) if len(all_dfs) == 0: raise ValueError("No csv files were loaded. Check objectives_params_list and dataset_root.") combined_df = pd.concat(all_dfs, ignore_index=True) # 过滤 relative 范围 self.dataframe = combined_df[(combined_df["relative"] >= -50) & (combined_df["relative"] <= 50)].reset_index(drop=True) def __len__(self): return len(self.dataframe) @staticmethod def _compute_dof_nm(mag, na, rix, wavelength_nm=550.0, pixel_size_nm=3450.0): """ 公式: DoF = lambda * n / (NA ** 2) + (n * e) / (M * NA) 输入参数均为标量(float),返回 DoF(nm) """ # 防止除以零 if na == 0 or mag == 0: return float('inf') lam = float(wavelength_nm) n = float(rix) M = float(mag) NA = float(na) e = float(pixel_size_nm) dof = (lam * n) / (NA ** 2) + ((n * e) / (M * NA)) return dof def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() row = self.dataframe.iloc[idx] img_path = self.dataset_root / row['path'] image = Image.open(img_path).convert('RGB') # 基本数值字段(注意 CSV 列名需匹配) mag = float(row['mag']) na = float(row['na']) rix = float(row['rix']) label_nm = float(row['label']) image = self.transform(image) mag_tensor = torch.tensor(mag, dtype=torch.float32) na_tensor = torch.tensor(na, dtype=torch.float32) rix_tensor = torch.tensor(rix, dtype=torch.float32) label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32) # # min-max 归一化输入参数 # mag_tensor = (mag_tensor - 10) / (100 - 10) # na_tensor = (na_tensor - 0) / (1.25 - 0) # rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0) sample = { 'image': image, 'mag': mag_tensor, 'na': na_tensor, 'rix': rix_tensor, 'label': label_nm_tensor, 'path': img_path.as_posix(), } return sample def get_dataframe(self): return self.dataframe if __name__ == "__main__": # 简单测试 train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train', objectives_params_list=[ "100x-1.25-1.4730", ]) from torch.utils.data import DataLoader train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2) for batch in train_loader: images = batch["image"] labels = batch["label"] print(f"images.shape: {images.shape}, labels.shape: {labels.shape}") mags = batch["mag"] nas = batch["na"] rixs = batch["rix"] print(f"mags: {mags}, nas: {nas}, rixs: {rixs}") print(f"mags.shape: {mags.shape}, nas.shape: {nas.shape}, rixs.shape: {rixs.shape}") params = torch.stack((mags, nas, rixs), dim=1) print(f"params shape: {params.shape}") print("first labels:") for i in range(min(4, labels.shape[0])): print(labels[i].item()) break