import random

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from torch.utils.data import Dataset


def _to_grid_label(label, grid_size):
    """Convert a label to a float32 tensor of shape (grid_size, grid_size).

    ``reshape`` is used instead of ``view`` so that labels which are not
    viewable into the target shape (e.g. non-contiguous tensors/arrays) are
    copied instead of raising. Raises RuntimeError if the label does not hold
    exactly ``grid_size * grid_size`` values.
    """
    return torch.as_tensor(label, dtype=torch.float32).reshape(grid_size, grid_size)


# Synchronized (image + label) augmentation.
class RINPairTransform:
    """Jointly transform a PIL image and its square grid label.

    Geometric augmentations (90-degree rotations, horizontal flip) are applied
    identically to the image and the label grid so they stay aligned; colour
    jitter is applied to the image only.

    NOTE(review): the alignment assumes label rows index image rows
    (top -> bottom) and label columns index image columns (left -> right) --
    TODO confirm against the label producer.

    Args:
        train: If True, apply random D4 (rotation + flip) and colour-jitter
            augmentation; otherwise only resize and convert to tensors.
        image_size: Side length the image is resized to (output is square).
        grid_size: Side length of the square label grid; each label must hold
            exactly ``grid_size * grid_size`` values. Defaults to 9.
    """

    # PIL transpose op for a rotation by k * 90 degrees. PIL's ROTATE_* and
    # torch.rot90(..., dims=(0, 1)) both rotate counter-clockwise, so the same
    # k keeps image and label grid consistent.
    _PIL_ROTATIONS = {
        1: Image.Transpose.ROTATE_90,
        2: Image.Transpose.ROTATE_180,
        3: Image.Transpose.ROTATE_270,
    }

    def __init__(self, train=True, image_size=512, grid_size=9):
        self.train = train
        self.image_size = image_size
        self.grid_size = grid_size
        self.color_jitter = transforms.ColorJitter(
            brightness=(0.9, 1.4),
            contrast=(0.8, 1.5),
            saturation=(0.8, 1.5),
        )

    def __call__(self, image, label):
        """Return ``(image_tensor, label_tensor)``.

        ``image_tensor`` comes from ``F.to_tensor`` of the (possibly augmented)
        resized image; ``label_tensor`` is a float32
        ``(grid_size, grid_size)`` tensor.
        """
        # Resize to a square first so 90-degree rotations keep the size.
        image = F.resize(image, size=(self.image_size, self.image_size))
        label = _to_grid_label(label, self.grid_size)

        if self.train:
            # D4 augmentation: random multiple of 90 degrees ...
            k = random.randint(0, 3)
            if k:
                image = image.transpose(self._PIL_ROTATIONS[k])
                label = torch.rot90(label, k=k, dims=(0, 1))
            # ... followed by a random horizontal flip (8 symmetries total).
            if random.random() < 0.5:
                image = F.hflip(image)
                label = torch.flip(label, dims=(1,))
            # Photometric augmentation: image only, the label is unaffected.
            image = self.color_jitter(image)

        image = F.to_tensor(image)
        return image, label


class RIN_Dataset(Dataset):
    """Dataset of (image, grid label) pairs read from image files.

    Args:
        image_path_list: Sequence of image file paths.
        label_list: Sequence of labels, parallel to ``image_path_list``.
        transform: Optional callable ``(PIL image, label) -> (image, label)``,
            e.g. RINPairTransform. When None, the image is resized to
            ``image_size`` x ``image_size`` and converted with
            ``F.to_tensor``, and the label becomes a float32
            ``(grid_size, grid_size)`` tensor.
        image_size: Resize target used only when ``transform`` is None.
        grid_size: Label grid side used only when ``transform`` is None.

    Raises:
        ValueError: if ``image_path_list`` and ``label_list`` differ in length.
    """

    def __init__(self, image_path_list, label_list, transform=None,
                 image_size=512, grid_size=9):
        if len(image_path_list) != len(label_list):
            raise ValueError(
                f"image_path_list has {len(image_path_list)} entries but "
                f"label_list has {len(label_list)}"
            )
        self.image_path_list = image_path_list
        self.label_list = label_list
        self.transform = transform
        self.image_size = image_size
        self.grid_size = grid_size

    def __len__(self):
        return len(self.image_path_list)

    def __getitem__(self, index):
        image_path = self.image_path_list[index]
        label = self.label_list[index]

        # The context manager guarantees the file handle is closed; convert()
        # fully loads the pixels into a new image that outlives the file.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image, label = self.transform(image, label)
        else:
            image = F.resize(image, size=(self.image_size, self.image_size))
            image = F.to_tensor(image)
            label = _to_grid_label(label, self.grid_size)
        return image, label