269 lines
14 KiB
Python
269 lines
14 KiB
Python
import os
|
||
import numpy as np
|
||
import torch.nn.functional as F
|
||
import torchvision.transforms as transforms
|
||
from torch.utils.data import DataLoader
|
||
import torch.nn as nn
|
||
import torch
|
||
|
||
from model.DPNet import register_model
|
||
from model.RINet import MobileNetV3_small as con_classification
|
||
|
||
from dataset import datasets as dataset_F
|
||
import util.utils as util
|
||
from tqdm import tqdm
|
||
import openpyxl
|
||
|
||
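# REVIEW:
# This script is the project's main test/inference entry point, and it is where the two stages of the method actually meet:
# RINet first scores the validity of the patches of the whole image, DPNet then regresses the defocus distance of the candidate patches,
# and the whole-image prediction is finally obtained by median aggregation.
#
# REVIEW:
# To understand the authors' method quickly, this is the most important file, because it shows most completely the order in which the algorithm runs during real evaluation.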
|
||
|
||
# Expose GPUs 0 and 1 to CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES='0,1')

# Side length, in pixels, of one square image patch.
size = 224

# Excel workbooks that list the test data.
# REVIEW: the test-set path is fully hard-coded, which suggests this script is
# the authors' own experiment tool rather than a general command-line program.
root_path_list = [
    r'E:\suqiang\DenseSparse\testData_20241227.xlsx',
]

# Sheet name to read from each workbook, and the sample category (cell or
# tissue) of each test set. Both lists are indexed by the position of the
# workbook in root_path_list.
# NOTE(review): the three lists have different lengths (1, 4 and 2 entries), so
# only the leading entries of the two lists below are ever read.
data_type_list = [
    'Sheet',
    'Sheet1',
    'Sheet1',
    'Sheet1',
]
cell_or_tissue_list = [
    'tissue',
    'tissue',
]
|
||
|
||
# Load the patch-scoring model (RINet) and its pre-trained weights.
classification_model = con_classification()
state_load = torch.load('../weight/RINet_best_model.pt')
classification_model = nn.DataParallel(classification_model).to('cuda')

# REVIEW: there is no version check between the RINet weights and the model
# structure. Loading is non-strict, so mismatched keys are skipped instead of
# raising; report them here, otherwise an incompatible checkpoint would leave
# (part of) the network at its initial weights without any message.
rinet_incompatible_keys = classification_model.load_state_dict(state_load, False)
if rinet_incompatible_keys.missing_keys or rinet_incompatible_keys.unexpected_keys:
    util.print_info(
        f'RINet checkpoint does not fully match the model: '
        f'{len(rinet_incompatible_keys.missing_keys)} missing keys, '
        f'{len(rinet_incompatible_keys.unexpected_keys)} unexpected keys'
    )

classification_model.eval()  # Inference only: switch to evaluation mode.
|
||
|
||
# Evaluate every configured regression model (e.g. DPNet) on every test set and
# write one Excel report per (model, test set) pair.

# Side length, in pixels, of the square centre crop taken from each test image.
# REVIEW: together with size=224 this gives a fixed 9x9 patch grid, so the test
# pipeline makes a strong assumption about the input image size.
crop_size = 2016

# A patch counts as usable when the absolute value of its RINet score is above
# this value.
# REVIEW: this 0.8 threshold is a key empirical hyper-parameter.
score_threshold = 0.8

# Patch budgets compared in the report; each one fills a '<k>_prediction' column.
# REVIEW: comparing several top-k aggregates side by side shows the authors care
# not only about the final prediction but also about how many patches to keep.
top_k_list = (61, 51, 41, 31, 25, 19, 15, 9, 5, 3, 1)

report_header = (
    ['groundTruth', 'all_prediction', 'all_can_use_prediction']
    + [f'{k}_prediction' for k in top_k_list]
    + ['path', 'max_can_use_patch_num', 'avg_prediction']
)

# (report name, checkpoint path, model factory). Add entries to compare models.
model_spec_list = [
    ('ours', '../weight/DP_best_model.pt', register_model),
]


def _median_of_top_patches(patch_predictions, patch_scores, k):
    """Return the median prediction over the ``k`` highest-scoring patches.

    Args:
        patch_predictions: list of per-patch predictions (floats).
        patch_scores: RINet output tensor; the positions returned by
            ``torch.topk`` are matched against positions in
            ``patch_predictions``.
        k: number of top-scoring patches to keep.
    """
    _, top_index = torch.topk(patch_scores, k)
    chosen = set(top_index.reshape(-1).tolist())
    kept = [value for position, value in enumerate(patch_predictions) if position in chosen]
    return float(np.median(kept))


def _build_report_row(img, image_name, regression_model, cell_or_tissue):
    """Run both stages on one image and return its report row.

    Args:
        img: 4-D image batch produced by the test loader (batch size 1).
        image_name: batch of image paths; the first entry is the image path.
        regression_model: model that predicts one value per patch.
        cell_or_tissue: category string forwarded to the visualisation helper.

    Returns:
        The list of cell values in ``report_header`` order, or ``None`` when no
        patch score passes ``score_threshold``.

    Raises:
        ValueError: if the file name stem is not an integer.
    """
    image_path = image_name[0]
    path_parts = image_path.split('\\')
    # The ground-truth label is encoded as the integer file name stem.
    image_label = int(path_parts[-1].split('.')[0])

    # The scoring model sees a 512x512 resized copy; the regression model sees
    # the full-resolution patches. Resize before moving the data to the GPU.
    img2 = F.interpolate(img, size=(512, 512), mode='bilinear', align_corners=False)
    img = img.to('cuda', dtype=torch.float32)
    img2 = img2.to('cuda', dtype=torch.float32)

    output_weight = classification_model(img2)
    # NOTE(review): usability is counted on abs(score) while torch.topk below
    # ranks the signed score; the two agree only for non-negative scores -
    # confirm RINet's output range.
    output_weight_userful = output_weight[torch.abs(output_weight) > score_threshold]
    usable_patch_num = len(output_weight_userful)

    # REVIEW: an image with no patch above the threshold is skipped entirely,
    # so the result file then has fewer rows than the test set - an implicit
    # behaviour worth knowing about.
    if usable_patch_num == 0:
        util.print_info(f'{image_name} don\'t have enough info')
        return None

    # Use an odd number of patches so every median is an actual patch value.
    can_use_num = usable_patch_num if usable_patch_num % 2 else usable_patch_num - 1

    # Cut the crop into size x size patches, row by row.
    # REVIEW: sending all patches through DPNet in one batch avoids the
    # overhead of one forward pass per patch.
    patch_batch = torch.cat(
        [
            img[:, :, top:top + size, left:left + size]
            for top in range(0, crop_size, size)
            for left in range(0, crop_size, size)
        ],
        dim=0,
    )
    out_all = regression_model(patch_batch)
    # One scalar per patch; the reshape fails if a patch yields several values.
    patch_predictions = out_all.reshape(len(out_all)).tolist()

    # Every prediction is scaled by 1000 before it is written to the report.
    row = [
        image_label,
        float(np.median(patch_predictions)) * 1000,
        _median_of_top_patches(patch_predictions, output_weight, can_use_num) * 1000,
    ]
    # Each budget is capped by the number of usable patches.
    # NOTE(review): torch.topk raises if RINet emits fewer scores than the
    # requested budget - presumably it emits one score per patch; confirm.
    row.extend(
        _median_of_top_patches(patch_predictions, output_weight, min(k, can_use_num)) * 1000
        for k in top_k_list
    )
    row.extend([
        '\\'.join(path_parts[-3:]),  # last three path components
        usable_patch_num,
        float(np.mean(patch_predictions)) * 1000,
    ])

    # REVIEW: besides the numbers, a grid visualisation is kept, which helps to
    # explain model behaviour and to investigate abnormal samples.
    util.save_val_image(1, img,
                        score=out_all, image_name=image_name,
                        type_str=cell_or_tissue, type='test', type_label=output_weight)
    return row


for names, weight_path, build_model in model_spec_list:
    state_load = torch.load(weight_path)
    model = nn.DataParallel(build_model()).to('cuda')
    # Non-strict loading skips mismatched keys; report them instead of silently
    # evaluating a partly initialised network.
    incompatible_keys = model.load_state_dict(state_load, False)
    if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
        util.print_info(
            f'{weight_path} does not fully match the model: '
            f'{len(incompatible_keys.missing_keys)} missing keys, '
            f'{len(incompatible_keys.unexpected_keys)} unexpected keys'
        )
    model.eval()  # Inference only: switch to evaluation mode.

    # Iterate over the test sets; sheet name and category are looked up by the
    # position of the workbook in root_path_list.
    for dataset_index, root_path in enumerate(root_path_list):
        data_type = data_type_list[dataset_index]
        cell_or_tissue = cell_or_tissue_list[dataset_index]

        # Load test data and labels from the specified Excel sheet.
        (train_image_path_list, train_patch_effective_list,
         train_defocus_distance_list) = dataset_F.get_test_data_and_label(root_path=root_path, type=data_type)

        train_transform = transforms.Compose([
            transforms.CenterCrop((crop_size, crop_size)),  # Crop the centre region of the image
            transforms.ToTensor(),  # Convert the image to a tensor
        ])

        train_data = dataset_F.MyDataset_test(train_image_path_list, train_defocus_distance_list,
                                              train_patch_effective_list, train_transform)
        train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, pin_memory=True, num_workers=0)

        # Results go to a fresh sheet inserted at the front of a new workbook.
        workbook = openpyxl.Workbook()
        sheet = workbook.create_sheet(index=0)
        sheet.append(report_header)

        try:
            # No gradients are needed for evaluation.
            with torch.no_grad():
                for img, _label, image_name, _effective in tqdm(train_loader):
                    row = _build_report_row(img, image_name, model, cell_or_tissue)
                    if row is not None:
                        sheet.append(row)
        finally:
            # Save even when an image fails so the rows computed so far are kept.
            # REVIEW: Excel output fits the paper's experiments well; for batch
            # evaluation a structured text format would be easier to analyse.
            workbook.save(f'./{names}_to_{cell_or_tissue}_{data_type}_2.xlsx')
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|