218 lines
10 KiB
Python
218 lines
10 KiB
Python
import os
|
||
|
||
import time
|
||
|
||
from PIL import Image
|
||
from torch.utils.data import Dataset
|
||
from tqdm import tqdm
|
||
import numpy as np
|
||
import random
|
||
from torchvision.transforms import ToPILImage
|
||
import torch
|
||
import glob
|
||
import torchvision.transforms as transforms
|
||
from openpyxl import load_workbook
|
||
import json
|
||
import re
|
||
|
||
# REVIEW:
|
||
# 这个文件是项目的数据入口层,负责把 Excel、JSON 和图像文件组织成训练/测试阶段可消费的数据。
|
||
# 整体上它解决了“如何把作者本地数据组织接到模型上”的问题,但工程抽象较弱,路径与数据格式都强依赖作者环境。
|
||
#
|
||
# REVIEW:
|
||
# 当前文件同时服务于两个任务:
|
||
# 1. DPNet 的 defocus distance 回归;
|
||
# 2. RINet 的 patch 有效性二分类。
|
||
# 这种复用在小型研究项目里很常见,但随着任务差异增大,后续维护成本会逐渐上升。
|
||
|
||
|
||
# DPNet load datasets
def _dpnet_label_from_path(image_path):
    """Return the integer encoded in an image file name's stem.

    ``.../-1200.jpg`` -> ``-1200``. Backslashes are normalised to ``/`` first so
    Windows-style paths parse identically on any OS.
    """
    file_name = image_path.replace('\\', '/').split('/')[-1]
    return int(file_name.split('.')[0])


def _collect_dpnet_sheet(sheet, image_root, counter, sample_interval):
    """Collect sub-sampled (image path, defocus label) pairs from one worksheet.

    The first cell of each row names a sub-directory of ``image_root``; every
    ``<image_root>/<row[0]>/*/*.jpg`` file is a candidate. Candidates whose raw
    file-name label lies in ``[-25000, 25000)`` advance ``counter``; only those
    seen while ``counter % sample_interval == 0`` are kept. Kept labels are the
    raw value divided by 1000.

    Returns:
        ``(image_paths, labels)`` as two parallel lists.
    """
    image_paths = []
    labels = []
    for row in sheet.iter_rows(values_only=True):
        if not row or row[0] is None:
            # Blank spreadsheet rows carry no field name; os.path.join would
            # raise on None, so skip them.
            continue
        pattern = os.path.join(image_root, str(row[0]), '*/*.jpg')
        for image_path in glob.glob(pattern):
            raw_label = _dpnet_label_from_path(image_path)
            if -25000 <= raw_label < 25000:
                if counter % sample_interval == 0:
                    image_paths.append(image_path)
                    labels.append(raw_label / 1000)
                # NOTE(review): the counter advances once per in-range image
                # (keep 1 of every `sample_interval`); the reviewed copy had lost
                # its indentation, so confirm this matches the intended sampling.
                counter += 1
    return image_paths, labels


def get_DPNet_train_data_and_label(root_path_list: list,
                                   image_root='E:/Datasets/SparseFocusDataset/224_image',
                                   sample_interval=5,
                                   rng=None):
    """Load DPNet training/validation image paths and defocus-distance labels.

    Each entry of ``root_path_list`` is a glob pattern for ``.xlsx`` workbooks.
    Every workbook must contain a ``'train'`` and a ``'val'`` sheet whose first
    column names image sub-directories under ``image_root``. Labels are parsed
    from image file names (see ``_dpnet_label_from_path``) and divided by 1000.

    Per sheet, a random start offset in ``[0, 9]`` is drawn, and then one of every
    ``sample_interval`` in-range images is kept.

    Args:
        root_path_list: Glob patterns locating the ``.xlsx`` split files.
        image_root: Directory holding the per-field image folders. The default
            is the original author's local path.
        sample_interval: Keep one of every ``sample_interval`` eligible images.
        rng: Optional ``random.Random`` used for the start offsets, which makes
            the split reproducible. Defaults to the global ``random`` module.

    Returns:
        Four parallel lists: train image paths, train labels, val image paths,
        val labels.
    """
    randint = (rng if rng is not None else random).randint
    print("begin load data")
    train_image_path_list = []
    train_defocus_distance_list = []
    val_image_path_list = []
    val_defocus_distance_list = []
    start_time = time.time()
    for root_path in root_path_list:
        for xlsx in glob.glob(root_path):
            wb = load_workbook(xlsx)
            # `wb[name]` replaces the deprecated `get_sheet_by_name`.
            train_sheet = wb['train']
            val_sheet = wb['val']
            # Draw both offsets before reading any rows (same order as before),
            # so a seeded rng reproduces the original sampling.
            train_offset = randint(0, 9)
            val_offset = randint(0, 9)
            paths, labels = _collect_dpnet_sheet(train_sheet, image_root, train_offset, sample_interval)
            train_image_path_list.extend(paths)
            train_defocus_distance_list.extend(labels)
            paths, labels = _collect_dpnet_sheet(val_sheet, image_root, val_offset, sample_interval)
            val_image_path_list.extend(paths)
            val_defocus_distance_list.extend(labels)
    print('need time:', time.time() - start_time)
    print('train data nums:', len(train_image_path_list))
    print('val data nums:', len(val_image_path_list))
    return train_image_path_list, train_defocus_distance_list, val_image_path_list, val_defocus_distance_list
|
||
|
||
# Load RINet need datasets
def get_RINet_data(root_path: str, type: str,
                   patch_root='E:/Datasets/SparseFocusDataset/complete_image/complete_image/90_patch_dataset_lap_hsv',
                   legacy_patch_root=r"E:\suqiang\new_microscope_data\90_patch_dataset_hsv"):
    """Load RINet patch-classification samples listed in one workbook sheet.

    Each row's first cell names a field directory under ``patch_root``. That
    directory must contain ``224_patch_effective.json``, which holds an
    ``image_info_list`` of entries with ``image_path`` and ``patch_effective``.

    Args:
        root_path: Path of the ``.xlsx`` workbook listing the fields.
        type: Name of the sheet to read (e.g. ``'train'`` or ``'val'``). It
            shadows the builtin name; it is kept for caller compatibility.
        patch_root: Directory containing the per-field patch folders. The
            default is the original author's local path.
        legacy_patch_root: Path prefix stored inside the JSON files. It is
            rewritten to ``patch_root`` so the paths point at the local copy.

    Returns:
        ``(image_path_list, patch_effective_list)`` as two parallel lists.
    """
    wb = load_workbook(root_path)
    # `wb[name]` replaces the deprecated `get_sheet_by_name`.
    sheet = wb[type]
    train_image_path_list = []
    train_patch_effective_list = []
    for row in sheet.iter_rows(values_only=True):
        if not row or row[0] is None:
            # Blank rows name no field; skip instead of joining None.
            continue
        json_data_path = os.path.join(patch_root, str(row[0]), '224_patch_effective.json')
        with open(json_data_path, 'r', encoding='utf-8') as fs:
            json_data = json.load(fs)
        # Some upstream files store the payload as a JSON-encoded *string*
        # (double-encoded); decode again only in that case so plain JSON
        # objects are accepted too.
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        for image_info in json_data['image_info_list']:
            train_image_path_list.append(image_info['image_path'].replace(legacy_patch_root, patch_root))
            train_patch_effective_list.append(image_info['patch_effective'])
    return train_image_path_list, train_patch_effective_list
|
||
# Load test need datasets
def get_test_data_and_label(root_path: str, type: str,
                            image_root=r'E:\suqiang\DenseSparse\cropped'):
    """Collect test patch image paths with their defocus-distance labels.

    Each row's first cell names a sub-directory of ``image_root``; every
    ``<image_root>/<row[0]>/*.jpg`` file is a candidate. The integer file stem is
    the raw label. Candidates within ``[-25000, 25000]`` are kept, and their
    labels are divided by 1000.

    NOTE(review): this bound includes +25000, while the DPNet loader in this file
    excludes it (``< 25000``). Confirm which bound is intended.

    Args:
        root_path: Path of the ``.xlsx`` workbook listing the fields.
        type: Name of the sheet to read (e.g. ``'train'`` or ``'val'``). It
            shadows the builtin name; it is kept for caller compatibility.
        image_root: Directory holding the per-field cropped patches. The default
            is the original author's local path.

    Returns:
        ``(image_path_list, patch_effective_list, defocus_distance_list)``.
        The effective flags are all 0 placeholders.
    """
    wb = load_workbook(root_path)
    # `wb[name]` replaces the deprecated `get_sheet_by_name`.
    test_sheet = wb[type]
    train_image_path_list = []
    train_patch_effective_list = []
    train_defocus_distance_list = []
    num_i = 0
    for row in test_sheet.iter_rows(values_only=True):
        if not row or row[0] is None:
            # Blank rows name no field; skip instead of joining None.
            continue
        field_path = os.path.join(image_root, str(row[0]), '*.jpg')
        for image_patch in glob.glob(field_path):
            # Normalise separators before taking the file name. Splitting on
            # '\\' alone returns the whole path on POSIX, which breaks int().
            file_name = image_patch.replace('\\', '/').split('/')[-1]
            labels = int(file_name.split('.')[0])
            if -25000 <= labels <= 25000:
                train_image_path_list.append(image_patch)
                # Placeholder: effectiveness is not annotated in this data.
                train_patch_effective_list.append(0)
                train_defocus_distance_list.append(labels / 1000)
                num_i = num_i + 1
    print('num_i: ', num_i)
    return (train_image_path_list, train_patch_effective_list,
            train_defocus_distance_list)
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class MyDataset(Dataset):
    """Generic training dataset shared by the DPNet and RINet training code.

    Pairs a sequence of image file paths with an optional parallel sequence of
    labels. Each item is ``(image, label_tensor, image_path)``. When no labels
    are given (``None`` or empty), the label is an empty tensor.
    """

    def __init__(self, data, label=None, transform=None):
        self.data = data            # sequence of image file paths
        self.label = label          # parallel labels, or None/empty for unlabeled data
        self.transform = transform  # optional callable applied to the PIL image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data[index]
        # NOTE(review): no convert('RGB') is applied, so grayscale/RGBA files
        # keep their original mode and channel count.
        image = Image.open(image_path)
        if self.label is not None and len(self.label) > 0:
            label = torch.tensor(self.label[index])
        else:
            label = torch.tensor([])
        if self.transform is not None:
            image = self.transform(image)
        # The path is returned as well, to help with logging and debugging.
        return image, label, image_path
|
||
|
||
|
||
|
||
class MyDataset_test(Dataset):
    """Test-time dataset yielding ``(image, label, image_path, effective)``.

    Unlike ``MyDataset``, the label is returned exactly as stored (no tensor
    conversion). ``effective`` is the per-sample flag passed in by the caller.
    """

    def __init__(self, data, label, effective, transform=None):
        self.data = data              # sequence of image file paths
        self.label = label            # parallel labels, returned unchanged
        self.effective = effective    # parallel effectiveness flags
        self.transform = transform    # optional callable applied to the PIL image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_name = self.data[index]
        image = Image.open(image_name)
        label = self.label[index]
        if self.transform is not None:
            image = self.transform(image)
        effective = self.effective[index]
        return image, label, image_name, effective
|