# File listing: 161 lines, 6.0 KiB, Python
import math
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class MOAFDataset(Dataset):
    """Image dataset merged from per-objective CSV split files.

    For every objective directory name in ``objectives_params_list`` the file
    ``<dataset_root>/tvtinfo/<objective>/<tvt>.csv`` is read. The frames are
    concatenated, and rows whose ``relative`` column lies outside
    ``relative_range`` (inclusive) are dropped.

    Each sample is a dict with these keys:

    - ``image``: transformed image tensor.
    - ``mag``, ``na``, ``rix``: min-max normalized float32 scalar tensors.
    - ``label``: float32 scalar tensor (nm or defocus/DoF, see ``output_type``).
    - ``path``: POSIX string of the image file.
    """

    # Min-max normalization ranges (low, high) for the objective parameters.
    _PARAM_RANGES = {
        "mag": (10.0, 100.0),
        "na": (0.0, 1.25),
        "rix": (1.0, 1.5),
    }

    # Per-channel mean/std passed to transforms.Normalize (standard ImageNet values).
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, dataset_root, tvt='train', objectives_params_list=None,
                 output_type='distance', wavelength_nm=550.0,
                 pixel_size_nm=3450.0, relative_range=(-50, 50)):
        """
        Args:
            dataset_root: Dataset root directory (anything ``Path`` accepts).
            tvt: Split name, 'train' | 'val' | 'test'. Selects both the CSV
                file and the transform; only 'train' gets augmentation.
            objectives_params_list: Objective directory names to load, e.g.
                ["10x-0.25-1.0000", ...]. Defaults to ["10x-0.25-1.0000"].
            output_type: 'distance' returns the label in nm; 'ratio' returns
                defocus / DoF. Any other value falls back to 'distance'.
            wavelength_nm: Wavelength used in the DoF formula ('ratio' only).
            pixel_size_nm: Pixel size used in the DoF formula ('ratio' only).
            relative_range: Inclusive (low, high) bounds applied to the
                ``relative`` column; rows outside the range are discarded.

        Raises:
            FileNotFoundError: A requested CSV file does not exist.
            ValueError: No CSV file was loaded (empty objectives list).
        """
        super().__init__()
        self.dataset_root = Path(dataset_root)
        self.tvt = tvt
        self.objectives_params_list = (
            ["10x-0.25-1.0000"] if objectives_params_list is None
            else objectives_params_list
        )

        # Unknown or non-string output_type values fall back to 'distance'.
        if isinstance(output_type, str) and output_type.lower() == "ratio":
            self.output_type = "ratio"
        else:
            self.output_type = "distance"

        self.wavelength_nm = float(wavelength_nm)
        self.pixel_size_nm = float(pixel_size_nm)

        self.transform = self._build_transform(self.tvt)
        self.dataframe = self._load_dataframe(relative_range)

    @classmethod
    def _build_transform(cls, tvt):
        """Return the torchvision pipeline for split ``tvt`` (augmentation only for 'train')."""
        steps = []
        if tvt == "train":
            steps += [
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=list(cls._NORM_MEAN), std=list(cls._NORM_STD)),
        ]
        return transforms.Compose(steps)

    def _load_dataframe(self, relative_range):
        """Read, concatenate and range-filter the CSV split of every requested objective."""
        all_dfs = []
        for param_dir in self.objectives_params_list:
            csv_file_path = self.dataset_root / "tvtinfo" / param_dir / f"{self.tvt}.csv"
            if not csv_file_path.exists():
                raise FileNotFoundError(f"CSV not found: {csv_file_path}")
            all_dfs.append(pd.read_csv(csv_file_path))

        if not all_dfs:
            raise ValueError("No csv files were loaded. Check objectives_params_list and dataset_root.")

        combined_df = pd.concat(all_dfs, ignore_index=True)
        low, high = relative_range
        relative = combined_df["relative"]
        keep = (relative >= low) & (relative <= high)
        return combined_df[keep].reset_index(drop=True)

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _compute_dof_nm(mag, na, rix, wavelength_nm=550.0, pixel_size_nm=3450.0):
        """Return the depth of field in nm.

        Formula: DoF = lambda * n / (NA ** 2) + (n * e) / (M * NA), where e is
        the pixel size. All inputs are scalars (converted with ``float``).
        Returns ``inf`` when ``na`` or ``mag`` is 0, which avoids division by zero.
        """
        if na == 0 or mag == 0:
            return float('inf')
        lam = float(wavelength_nm)
        n = float(rix)
        M = float(mag)
        NA = float(na)
        e = float(pixel_size_nm)
        return (lam * n) / (NA ** 2) + ((n * e) / (M * NA))

    @staticmethod
    def _min_max(value, low, high):
        """Return ``(value - low) / (high - low)`` as a float32 scalar tensor."""
        return (torch.tensor(value, dtype=torch.float32) - low) / (high - low)

    def _label_to_output(self, label_nm, mag, na, rix):
        """Convert a label given in nm to the configured output unit.

        In 'distance' mode, ``label_nm`` is returned unchanged. In 'ratio' mode,
        ``label_nm / DoF`` is returned. If the DoF is not a finite, non-zero
        number (e.g. NA or magnification is 0, or an input is NaN), the nm
        value is returned as a fallback.
        """
        if self.output_type != "ratio":
            return label_nm
        dof_nm = self._compute_dof_nm(mag=mag, na=na, rix=rix,
                                      wavelength_nm=self.wavelength_nm,
                                      pixel_size_nm=self.pixel_size_nm)
        # isfinite also rejects NaN and -inf, which the old `== inf` test missed.
        if math.isfinite(dof_nm) and dof_nm != 0:
            return label_nm / dof_nm
        return label_nm  # fallback: the value stays in nm

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.dataframe.iloc[idx]
        img_path = self.dataset_root / row['path']
        # Closing the file promptly avoids leaking handles across many
        # DataLoader iterations; convert() returns an independent loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Raw numeric fields (the CSV column names must match).
        mag = float(row['mag'])
        na = float(row['na'])
        rix = float(row['rix'])
        label_nm = float(row['label'])

        image = self.transform(image)

        label_out = self._label_to_output(label_nm, mag, na, rix)
        return {
            'image': image,
            'mag': self._min_max(mag, *self._PARAM_RANGES["mag"]),
            'na': self._min_max(na, *self._PARAM_RANGES["na"]),
            'rix': self._min_max(rix, *self._PARAM_RANGES["rix"]),
            'label': torch.tensor(float(label_out), dtype=torch.float32),
            'path': img_path.as_posix(),
        }

    def get_dataframe(self):
        """Return the filtered, concatenated DataFrame backing this dataset."""
        return self.dataframe
if __name__ == "__main__":
    # Quick smoke test: build the training split and inspect one batch.
    demo_root = "F:/Datasets/MODatasetD"
    demo_objectives = ["100x-1.25-1.4730"]
    train_set = MOAFDataset(demo_root, tvt='train',
                            objectives_params_list=demo_objectives,
                            output_type='ratio')

    from torch.utils.data import DataLoader

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)

    for batch in train_loader:
        images, labels = batch["image"], batch["label"]
        print(f"images.shape: {images.shape}, labels.shape: {labels.shape}")

        mags, nas, rixs = batch["mag"], batch["na"], batch["rix"]
        print(f"mags: {mags}, nas: {nas}, rixs: {rixs}")
        print(f"mags.shape: {mags.shape}, nas.shape: {nas.shape}, rixs.shape: {rixs.shape}")

        params = torch.stack((mags, nas, rixs), dim=1)
        print(f"params shape: {params.shape}")

        print("first labels:")
        for label in labels[:4]:
            print(label.item())

        # Only the first batch is inspected.
        break