SparseFocus/datasets.py
2026-06-02 13:51:22 +08:00

76 lines
2.3 KiB
Python

import random
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from torch.utils.data import Dataset
# Synchronized augmentation: the image and its label are transformed together.
class RINPairTransform:
    """Callable that transforms an image together with its 9x9 label grid.

    Every call resizes the image to ``image_size`` x ``image_size`` and turns
    the label into a float32 tensor viewed as ``(9, 9)``. When ``train`` is
    true, the same random quarter-turn rotation and horizontal flip are
    applied to both the image and the label, and colour jitter is applied to
    the image only. The image is returned as a tensor via ``F.to_tensor``.
    """

    def __init__(self, train=True, image_size=512):
        self.train = train
        self.image_size = image_size
        # Sampling ranges for the photometric (image-only) augmentation.
        jitter_ranges = {
            "brightness": (0.9, 1.4),
            "contrast": (0.8, 1.5),
            "saturation": (0.8, 1.5),
        }
        self.color_jitter = transforms.ColorJitter(**jitter_ranges)

    def __call__(self, image, label):
        """Return ``(image_tensor, label_grid)`` for one sample.

        Args:
            image: PIL image to transform.
            label: Data convertible by ``torch.as_tensor`` that can be viewed
                as a ``(9, 9)`` grid.
        """
        side = self.image_size
        image = F.resize(image, size=(side, side))
        label = torch.as_tensor(label, dtype=torch.float32).view(9, 9)

        # Evaluation mode: no augmentation, just tensor conversion.
        if not self.train:
            return F.to_tensor(image), label

        # D4 augmentation, part 1: random number of quarter turns. The same
        # count is used for the image and the label so they stay aligned.
        quarter_turns = random.randint(0, 3)
        if quarter_turns:
            pil_rotation = {
                1: Image.Transpose.ROTATE_90,
                2: Image.Transpose.ROTATE_180,
                3: Image.Transpose.ROTATE_270,
            }[quarter_turns]
            image = image.transpose(pil_rotation)
            label = torch.rot90(label, k=quarter_turns, dims=(0, 1))

        # D4 augmentation, part 2: random horizontal flip (label columns are
        # reversed to match).
        if random.random() < 0.5:
            image = F.hflip(image)
            label = torch.flip(label, dims=(1,))

        # Colour augmentation touches the image only.
        image = self.color_jitter(image)
        return F.to_tensor(image), label
class RIN_Dataset(Dataset):
    """Map-style dataset that pairs image files with their labels.

    Args:
        image_path_list: Sequence of image file paths.
        label_list: Sequence of labels, indexed in parallel with
            ``image_path_list``.
        transform: Optional callable taking ``(image, label)`` and returning
            the transformed ``(image, label)`` pair. When it is ``None``, the
            image is resized to 512x512 and converted to a tensor, and the
            label becomes a float32 tensor viewed as ``(9, 9)``.
    """

    def __init__(self, image_path_list, label_list, transform=None):
        self.image_path_list = image_path_list
        self.label_list = label_list
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_path_list)

    def __getitem__(self, index):
        """Load the sample at ``index`` and return ``(image, label)``."""
        image_path = self.image_path_list[index]
        label = self.label_list[index]

        # Image.open is lazy and keeps the file open. The context manager
        # closes the handle deterministically once convert() has produced an
        # independent in-memory RGB image, instead of relying on garbage
        # collection to release it.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        if self.transform is not None:
            image, label = self.transform(image, label)
        else:
            # Default preprocessing when no paired transform is supplied.
            image = F.resize(image, size=(512, 512))
            image = F.to_tensor(image)
            label = torch.as_tensor(label, dtype=torch.float32).view(9, 9)
        return image, label