import os import time from PIL import Image from torch.utils.data import Dataset from tqdm import tqdm import numpy as np import random from torchvision.transforms import ToPILImage import torch import glob import torchvision.transforms as transforms from openpyxl import load_workbook import json import re # REVIEW: # 这个文件是项目的数据入口层,负责把 Excel、JSON 和图像文件组织成训练/测试阶段可消费的数据。 # 整体上它解决了“如何把作者本地数据组织接到模型上”的问题,但工程抽象较弱,路径与数据格式都强依赖作者环境。 # # REVIEW: # 当前文件同时服务于两个任务: # 1. DPNet 的 defocus distance 回归; # 2. RINet 的 patch 有效性二分类。 # 这种复用在小型研究项目里很常见,但随着任务差异增大,后续维护成本会逐渐上升。 # DPNet load datasets #""" # param root_path_list:Loading path of xlsx. # return:The training and validation data both contain the path of the image and the defocus - distance label of the image. #""" def get_DPNet_train_data_and_label(root_path_list:list): # REVIEW: # 该函数读取 DPNet 的训练/验证数据,返回的是 4 个平行列表: # 训练图像路径、训练标签、验证图像路径、验证标签。 # 好处是简单直接,缺点是多个列表之间的“同位关系”完全靠调用方自己保证。 print("begin load data") train_image_path_list = [] train_defocus_distance_list = [] val_image_path_list = [] val_defocus_distance_list = [] start_time = time.time() for i, root_path in enumerate(root_path_list): xlsx_list = glob.glob(root_path) for xlsx in xlsx_list: wb = load_workbook(xlsx) # REVIEW: # 这里使用 openpyxl 的旧接口 get_sheet_by_name,说明代码更偏实验性质。 # 运行上通常没问题,但长期兼容性和可维护性一般。 train_sheet = wb.get_sheet_by_name('train') val_sheet = wb.get_sheet_by_name('val') # REVIEW: # train_random_num / val_random_num 被用来做一种“每隔若干样本取 1 个”的下采样。 # 这种写法能快速控量,但采样策略不够直观,也不便于复现实验数据划分。 train_random_num = random.randint(0, 9) val_random_num = random.randint(0, 9) for row in train_sheet.iter_rows(values_only=True): # REVIEW: # 这里直接把作者本地 Linux 绝对路径写进了数据查找逻辑,是本项目可移植性较差的核心原因之一。 image_path_list = glob.glob(os.path.join('E:/Datasets/SparseFocusDataset/224_image', row[0], '*/*.jpg')) for (i, image_path) in enumerate(image_path_list): labels = int(image_path.replace('\\', '/').split('/')[-1].split('.')[0]) # REVIEW: # 标签直接从文件名解析,说明数据组织约定非常强。如果文件命名规则变化,这里会立即失效。 if labels >= -25000 and labels < 25000: if train_random_num % 5 == 0: train_image_path_list.append(image_path) train_defocus_distance_list.append(labels / 1000) train_random_num += 1 for row in val_sheet.iter_rows(values_only=True): image_path_list = glob.glob( os.path.join('E:/Datasets/SparseFocusDataset/224_image', row[0], '*/*.jpg')) for (i, image_path) in enumerate(image_path_list): labels = int(image_path.replace('\\', '/').split('/')[-1].split('.')[0]) if labels >= -25000 and labels < 25000: if val_random_num % 5 == 0: val_image_path_list.append(image_path) val_defocus_distance_list.append(labels / 1000) val_random_num += 1 print('need time:', time.time() - start_time) print('train data nums:', len(train_image_path_list)) print('val data nums:', len(val_image_path_list)) return train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list # Load RINet need datasets # """ # param root_path: Loading path of xlsx. # param type: Control the type of data: rain or val. # return: The path of the image and the labels indicating whether different areas on the image contain content. # """ def get_RINet_data(root_path: str, type: str): # REVIEW: # 该函数负责准备 RINet 所需的 patch 分类数据。与 DPNet 不同,这里的标签来自 JSON, # 表示 patch 是否“有效/包含足够信息”。 wb = load_workbook(root_path) train_image_path_list = [] train_patch_effective_list = [] test_sheet = wb.get_sheet_by_name(type) for (i, row) in enumerate(test_sheet.iter_rows(values_only=True)): # REVIEW: # 此处改为依赖作者本地 Windows 绝对路径,说明仓库内容很可能来自多环境拼接。 field_path = os.path.join('E:/Datasets/SparseFocusDataset/complete_image/complete_image/90_patch_dataset_lap_hsv',row[0]) json_data_path = os.path.join(field_path, '224_patch_effective.json') with open(json_data_path, 'r') as fs: json_data = json.load(fs) # REVIEW: # json.load 后又 json.loads,意味着文件中存放的不是标准 JSON 对象,而是“JSON 字符串”。 # 这反映出上游数据预处理存在历史包袱,但这里做了兼容。 json_data = json.loads(json_data) json_data = json_data['image_info_list'] for (i, image_data) in enumerate(json_data): train_image_path_list.append(json_data[i]['image_path'].replace(r"E:\suqiang\new_microscope_data\90_patch_dataset_hsv", "E:/Datasets/SparseFocusDataset/complete_image/complete_image/90_patch_dataset_lap_hsv")) train_patch_effective_list.append(json_data[i]['patch_effective']) return train_image_path_list, train_patch_effective_list # Load test need datasets # """ # param root_path: Loading path of xlsx. # param type: Control the type of data: rain or val. # return: The path of the image and the labels indicating whether different areas on the image contain content. # """ def get_test_data_and_label(root_path: str, type: str): # REVIEW: # 该函数为测试脚本提供整图 patch 级样本列表及其真实标签。 # 它不仅做了 Excel 读取,还做了 patch 文件展开,职责比函数名看起来更重一些。 wb = load_workbook(root_path) train_image_path_list = [] train_patch_effective_list = [] train_defocus_distance_list = [] test_sheet = wb.get_sheet_by_name(type) num_i = 0 for (i, row) in enumerate(test_sheet.iter_rows(values_only=True)): # REVIEW: # 这里继续依赖固定数据目录结构,因此 test.py 的可运行性高度依赖作者原始数据目录。 field_path = os.path.join( r'E:\suqiang\DenseSparse\cropped', row[0], '*.jpg') image_path_list = glob.glob(field_path) for image_patch in image_path_list: labels = int(image_patch.split('\\')[-1].split('.')[0]) if labels >= -25000 and labels <= 25000: train_image_path_list.append(image_patch) # REVIEW: # effective 这里统一填 0,说明测试阶段 patch 是否可用不来自数据标注,而是后续 RINet 推理结果。 train_patch_effective_list.append(0) train_defocus_distance_list.append(labels / 1000) num_i = num_i + 1 print('num_i: ',num_i) return (train_image_path_list, train_patch_effective_list, train_defocus_distance_list) class MyDataset(Dataset): # REVIEW: # 这是训练阶段的通用 Dataset,DPNet 和 RINet 都复用它。 # 通过允许 label 为任意列表实现任务复用,优点是轻量,缺点是接口语义不够强约束。 def __init__(self, data, label, transform=None): self.data = data self.label = label self.transform = transform def __len__(self): # return self.data.size(0) return len(self.data) def __getitem__(self, index): imag = self.data[index] # REVIEW: # 这里未显式 convert('RGB'),如果数据源中出现灰度图或 RGBA 图,通道数可能和模型预期不一致。 image = Image.open(imag) if len(self.label) > 0: label = torch.tensor(self.label[index]) else: label = torch.tensor([]) if self.transform != None: image = self.transform(image) # REVIEW: # 额外返回 imag 路径有助于日志和调试,但也让 batch 中混入了非张量字段。 return image, label, imag class MyDataset_test(Dataset): # REVIEW: # 测试专用 Dataset 在返回值中加入 effective 字段,以兼容 test.py 的历史接口。 # 当前 effective 在数据层面基本只是占位。 def __init__(self,data,label,effective,transform=None): self.data = data self.label = label self.effective = effective self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, index): image_name = self.data[index] image = Image.open(image_name) # image = image.resize((2016, 2016)) # image = np.array(image) # H W C # image = image[:,16:2032, 216:2232] label = self.label[index] if self.transform != None: image = self.transform(image) effective = self.effective[index] # REVIEW: # 与 MyDataset 不同,这里 label 保持原始 Python 数值而非 tensor,两个 Dataset 的接口风格并不完全一致。 return image,label,image_name,effective