MOAF/MOAFDatasets.py
2025-10-23 16:40:40 +08:00

161 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset
from PIL import Image
import torch
from torchvision import transforms
class MOAFDataset(Dataset):
    """Dataset pairing images with objective parameters and a defocus label.

    Samples are indexed by per-objective CSV files located at
    ``<dataset_root>/tvtinfo/<objective_dir>/<tvt>.csv``. Each CSV row must
    provide the columns ``path`` (image path joined onto ``dataset_root``),
    ``mag``, ``na``, ``rix``, ``label`` and ``relative``.
    """

    # Bounds for min-max normalization of the per-sample parameters.
    # Values outside a range are not clamped.
    MAG_RANGE = (10, 100)
    NA_RANGE = (0, 1.25)
    RIX_RANGE = (1.0, 1.5)
    # Optical constants used when converting the label to a DoF ratio.
    WAVELENGTH_NM = 550.0
    PIXEL_SIZE_NM = 3450.0
    # Per-channel mean/std passed to transforms.Normalize.
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, dataset_root, tvt='train', objectives_params_list=None,
                 output_type='distance', *, relative_range=(-50, 50)):
        """Build the transform and load the merged, filtered CSV index.

        Args:
            dataset_root: Dataset root directory (str or path-like).
            tvt: Split name, normally 'train' | 'val' | 'test'. It selects
                the CSV file ``<tvt>.csv`` and the transform pipeline. Random
                augmentation is applied only for 'train'.
            objectives_params_list: Objective directory names to load, e.g.
                ["10x-0.25-1.0000", ...]. A single string is treated as a
                one-element list. Defaults to ["10x-0.25-1.0000"].
            output_type: 'distance' returns the label in nm. 'ratio' returns
                defocus / DoF. Any other value falls back to 'distance'.
            relative_range: Inclusive (low, high) bounds applied to the
                ``relative`` column. Rows outside are dropped.

        Raises:
            FileNotFoundError: If a required CSV file does not exist.
            ValueError: If no CSV file was loaded, or if relative_range has
                low > high.
        """
        super().__init__()
        self.dataset_root = Path(dataset_root)
        self.tvt = tvt
        if objectives_params_list is None:
            objectives_params_list = ["10x-0.25-1.0000"]
        elif isinstance(objectives_params_list, str):
            # Treat a bare string as one objective rather than iterating its characters.
            objectives_params_list = [objectives_params_list]
        # Store a copy so later mutation of the caller's list has no effect here.
        self.objectives_params_list = list(objectives_params_list)
        # Non-string or unknown output_type values fall back to 'distance'.
        if isinstance(output_type, str) and output_type.lower() == "ratio":
            self.output_type = "ratio"
        else:
            self.output_type = "distance"
        self.transform = self._build_transform(self.tvt)
        self.dataframe = self._load_dataframe(
            self.dataset_root, self.tvt, self.objectives_params_list, relative_range
        )

    @classmethod
    def _build_transform(cls, tvt):
        """Return the image transform: flips + color jitter for 'train', then ToTensor + Normalize."""
        to_normalized_tensor = [
            transforms.ToTensor(),
            transforms.Normalize(mean=list(cls.NORM_MEAN), std=list(cls.NORM_STD)),
        ]
        if tvt == "train":
            augmentations = [
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            ]
            return transforms.Compose(augmentations + to_normalized_tensor)
        return transforms.Compose(to_normalized_tensor)

    @staticmethod
    def _load_dataframe(dataset_root, tvt, objectives_params_list, relative_range=(-50, 50)):
        """Read, concatenate and filter the per-objective ``<tvt>.csv`` files.

        Returns a DataFrame with a fresh 0..n-1 index that holds only the rows
        whose ``relative`` value lies within ``relative_range`` (inclusive).
        """
        low, high = relative_range
        if low > high:
            raise ValueError(f"relative_range must satisfy low <= high, got {relative_range!r}")
        frames = []
        for param_dir in objectives_params_list:
            csv_file_path = Path(dataset_root) / "tvtinfo" / param_dir / f"{tvt}.csv"
            if not csv_file_path.exists():
                raise FileNotFoundError(f"CSV not found: {csv_file_path}")
            frames.append(pd.read_csv(csv_file_path))
        if not frames:
            raise ValueError("No csv files were loaded. Check objectives_params_list and dataset_root.")
        combined_df = pd.concat(frames, ignore_index=True)
        # Series.between is inclusive on both ends by default; NaN rows are dropped.
        in_range = combined_df["relative"].between(low, high)
        return combined_df[in_range].reset_index(drop=True)

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _compute_dof_nm(mag, na, rix, wavelength_nm=550.0, pixel_size_nm=3450.0):
        """Compute the depth of field in nm.

        Formula: DoF = lambda * n / (NA ** 2) + (n * e) / (M * NA), where
        lambda = wavelength_nm, n = rix, M = mag, NA = na and e = pixel_size_nm.
        All inputs are scalars.

        Returns:
            The DoF as a float in nm, or ``float('inf')`` when na or mag is 0.
        """
        # Guard against division by zero.
        if na == 0 or mag == 0:
            return float('inf')
        lam = float(wavelength_nm)
        n = float(rix)
        M = float(mag)
        NA = float(na)
        e = float(pixel_size_nm)
        return (lam * n) / (NA ** 2) + ((n * e) / (M * NA))

    @staticmethod
    def _min_max(value, bounds):
        """Return ``(value - low) / (high - low)`` as a float32 scalar tensor (no clamping)."""
        low, high = bounds
        return (torch.tensor(value, dtype=torch.float32) - low) / (high - low)

    def _convert_label(self, label_nm, mag, na, rix):
        """Return the label as a float32 tensor in the unit selected by ``self.output_type``."""
        if self.output_type != "ratio":
            # 'distance' mode: the label is returned in nm.
            return torch.tensor(label_nm, dtype=torch.float32)
        dof_nm = self._compute_dof_nm(
            mag=mag, na=na, rix=rix,
            wavelength_nm=self.WAVELENGTH_NM, pixel_size_nm=self.PIXEL_SIZE_NM,
        )
        if dof_nm != 0 and dof_nm != float('inf'):
            label_out = label_nm / dof_nm
        else:
            # NOTE(review): this falls back to nm, so such samples use different units
            # from the rest in 'ratio' mode. It is reached only when the DoF is 0 or inf
            # (e.g. mag, na or rix equal to 0).
            label_out = label_nm
        return torch.tensor(float(label_out), dtype=torch.float32)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Keys: 'image' (transformed image), 'mag' / 'na' / 'rix' (min-max
        normalized float32 scalars), 'label' (float32 scalar, nm or DoF
        ratio depending on output_type) and 'path' (POSIX image path string).
        """
        if torch.is_tensor(idx):
            # DataLoader passes plain ints; also accept 0-d / single-element tensors.
            idx = int(idx)
        row = self.dataframe.iloc[idx]
        img_path = self.dataset_root / row['path']
        # The context manager closes the file handle. convert() returns an
        # independent, fully loaded image, so it stays valid after the file closes.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        # Scalar fields; the CSV column names must match.
        mag = float(row['mag'])
        na = float(row['na'])
        rix = float(row['rix'])
        label_nm = float(row['label'])
        return {
            'image': image,
            'mag': self._min_max(mag, self.MAG_RANGE),
            'na': self._min_max(na, self.NA_RANGE),
            'rix': self._min_max(rix, self.RIX_RANGE),
            'label': self._convert_label(label_nm, mag, na, rix),
            'path': img_path.as_posix(),
        }

    def get_dataframe(self):
        """Return the filtered sample index DataFrame."""
        return self.dataframe
if __name__ == "__main__":
    # Smoke test: load a single batch and print tensor shapes plus a few labels.
    from torch.utils.data import DataLoader

    dataset = MOAFDataset(
        "F:/Datasets/MODatasetD",
        tvt='train',
        objectives_params_list=["100x-1.25-1.4730"],
        output_type='ratio',
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    # Take at most one batch; nothing is printed if the loader is empty.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        images, labels = first_batch["image"], first_batch["label"]
        print(f"images.shape: {images.shape}, labels.shape: {labels.shape}")

        mags, nas, rixs = first_batch["mag"], first_batch["na"], first_batch["rix"]
        print(f"mags: {mags}, nas: {nas}, rixs: {rixs}")
        print(f"mags.shape: {mags.shape}, nas.shape: {nas.shape}, rixs.shape: {rixs.shape}")

        params = torch.stack((mags, nas, rixs), dim=1)
        print(f"params shape: {params.shape}")

        print("first labels:")
        for value in labels[:4].tolist():
            print(value)