76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
import random
|
|
import torch
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from torchvision.transforms import functional as F
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
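# Synchronized augmentation: each geometric transform is applied identically to the image and its label grid.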
|
|
class RINPairTransform:
    """Paired image/label transform for RIN samples.

    Resizes the image to a square ``image_size`` and reshapes the label into a
    ``grid_size`` x ``grid_size`` float32 tensor. In training mode it also
    applies a random element of the dihedral group D4 to both the image and the
    label grid: a rotation by a multiple of 90 degrees, optionally followed by a
    horizontal flip. Colour jitter is then applied to the image only.

    Args:
        train: Enable random augmentation.
        image_size: Side length, in pixels, of the square output image.
        grid_size: Side length of the square label grid. The label must contain
            exactly ``grid_size ** 2`` elements (default 9, i.e. 81 values).
    """

    def __init__(self, train=True, image_size=512, grid_size=9):
        self.train = train
        self.image_size = image_size
        self.grid_size = grid_size
        # Brightness/contrast/saturation factor ranges; hue is not passed and
        # is therefore left unchanged.
        self.color_jitter = transforms.ColorJitter(
            brightness=(0.9, 1.4),
            contrast=(0.8, 1.5),
            saturation=(0.8, 1.5),
        )

    def __call__(self, image, label):
        """Transform one ``(image, label)`` pair.

        Args:
            image: PIL image (it is rotated via ``Image.transpose``).
            label: Anything ``torch.as_tensor`` accepts that has
                ``grid_size ** 2`` elements, e.g. a flat list or a square array.

        Returns:
            ``(image_tensor, label_grid)``:
            - ``image_tensor`` is produced by ``F.to_tensor``.
            - ``label_grid`` is a float32 tensor of shape ``(grid_size, grid_size)``.
        """
        image = F.resize(image, size=(self.image_size, self.image_size))
        # reshape (unlike view) also accepts non-contiguous tensor labels.
        label = torch.as_tensor(label, dtype=torch.float32).reshape(
            self.grid_size, self.grid_size
        )

        if self.train:
            image, label = self._random_d4(image, label)
            # Photometric augmentation affects the image only.
            image = self.color_jitter(image)

        return F.to_tensor(image), label

    @staticmethod
    def _random_d4(image, label):
        """Apply one random D4 transform identically to a PIL image and its label grid."""
        # Random number of counter-clockwise quarter turns (0-3).
        # PIL's ROTATE_* transposes and torch.rot90 over dims (0, 1) both
        # rotate counter-clockwise.
        # NOTE(review): this assumes label[row, col] follows the image's
        # (y, x) layout — confirm against the labelling convention.
        k = random.randint(0, 3)
        if k:
            rotation = (
                Image.Transpose.ROTATE_90,
                Image.Transpose.ROTATE_180,
                Image.Transpose.ROTATE_270,
            )[k - 1]
            image = image.transpose(rotation)
            label = torch.rot90(label, k=k, dims=(0, 1))

        # Random horizontal flip: image width corresponds to label columns (dim 1).
        if random.random() < 0.5:
            image = F.hflip(image)
            label = torch.flip(label, dims=(1,))

        return image, label
|
|
|
|
|
|
class RIN_Dataset(Dataset):
    """Map-style dataset of RIN images and their label grids.

    Args:
        image_path_list: Sequence of image file paths.
        label_list: Sequence of labels, aligned by index with ``image_path_list``.
        transform: Optional callable ``(image, label) -> (image, label)``, such
            as ``RINPairTransform``. When ``None``:
            - the image is resized to ``image_size`` x ``image_size`` and
              converted with ``F.to_tensor``;
            - the label is reshaped to a ``grid_size`` x ``grid_size`` float32
              tensor.
        image_size: Side length used by the no-transform fallback (default 512).
        grid_size: Label grid side used by the no-transform fallback (default 9).
    """

    def __init__(self, image_path_list, label_list, transform=None,
                 image_size=512, grid_size=9):
        self.image_path_list = image_path_list
        self.label_list = label_list
        self.transform = transform
        self.image_size = image_size
        self.grid_size = grid_size

    def __len__(self):
        # Length is driven by the image paths. label_list is presumably the
        # same length — verify at the call site.
        return len(self.image_path_list)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``."""
        image_path = self.image_path_list[index]
        label = self.label_list[index]

        # The context manager releases the file handle even if decoding or
        # conversion fails. convert() returns an independent, fully loaded
        # image, so it stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image, label = self.transform(image, label)
        else:
            image = F.resize(image, size=(self.image_size, self.image_size))
            image = F.to_tensor(image)
            # reshape (unlike view) also accepts non-contiguous tensor labels.
            label = torch.as_tensor(label, dtype=torch.float32).reshape(
                self.grid_size, self.grid_size
            )

        return image, label
|