# SparseFocus/old_test.py
# Snapshot exported from a web repository viewer
# (2026-06-02 13:51:22 +08:00; 269 lines, 14 KiB, Python).
#
# NOTE: the viewer flagged this file as containing ambiguous Unicode characters
# that might be confused with other characters (presumably the full-width
# punctuation in the Chinese review comments). If that is unintentional, reveal
# and replace them.

import os
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from model.DPNet import register_model
from model.RINet import MobileNetV3_small as con_classification
from dataset import datasets as dataset_F
import util.utils as util
from tqdm import tqdm
import openpyxl
# REVIEW (translated):
# This script is the project's test/inference entry point and the place where
# the two stages of the method actually meet:
#   1. RINet scores every patch of the whole image for usability.
#   2. DPNet regresses a defocus distance for the candidate patches.
#   3. A median over those patches gives the image-level prediction.
#
# REVIEW (translated):
# To understand the method quickly, this is the key file, because it shows
# most completely the order in which the algorithm runs during evaluation.

# Restrict CUDA to GPUs 0 and 1. This must happen before the first CUDA call to
# take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Side length (pixels) of the square patches fed to DPNet.
size = 224

# Excel files listing the test images and their labels.
# REVIEW (translated): the test-set path is fully hard-coded. The script is
# more of the author's own experiment tool than a general command-line program.
root_path_list = [r'E:\suqiang\DenseSparse\testData_20241227.xlsx']

# Sheet name and sample category (cell or tissue) for each Excel file.
# Both lists are indexed in step with root_path_list. Only the first
# len(root_path_list) entries are used; the extra entries are unused leftovers.
data_type_list = ['Sheet', 'Sheet1', 'Sheet1', 'Sheet1']
cell_or_tissue_list = ['tissue', 'tissue']

# Stage 1: the RINet patch-usability classifier.
classification_model = nn.DataParallel(con_classification()).to('cuda')
# map_location='cpu' makes loading independent of the GPU the checkpoint was
# saved from. load_state_dict then copies the tensors onto the model's device.
state_load = torch.load('../weight/RINet_best_model.pt', map_location='cpu')
# REVIEW (translated): there is no version check between the RINet weights and
# the architecture.
# strict=False (kept as before) tolerates key mismatches but skips them
# silently. Because the model is wrapped in DataParallel first, checkpoint keys
# only match if they carry the 'module.' prefix. Report anything that was
# skipped so an incompatible checkpoint is not silently ignored.
load_result = classification_model.load_state_dict(state_load, False)
if load_result.missing_keys or load_result.unexpected_keys:
    util.print_info(
        f'RINet weights: {len(load_result.missing_keys)} missing and '
        f'{len(load_result.unexpected_keys)} unexpected keys were ignored')
classification_model.eval()  # evaluation behaviour for dropout / batch-norm layers
# Evaluation. For every regression model (currently only DPNet, "ours") and
# every test spreadsheet:
#   1. Score each image with RINet.
#   2. Regress a defocus distance for every patch with DPNet.
#   3. Write median aggregates over several top-k patch subsets to an Excel
#      workbook.
#
# REVIEW (translated): comparing several top-k aggregates side by side shows
# the author cares about more than the final prediction. They are also
# studying how many patches are worth keeping.

# Whole images are center-cropped to CROP x CROP pixels and tiled into
# non-overlapping size x size patches. With size=224 this is a 9 x 9 grid of
# 81 patches.
# REVIEW (translated): the test pipeline therefore makes strong assumptions
# about the input size.
CROP = 2016
# Sizes of the top-k patch subsets reported in the sheet. Each size is clipped
# to the (odd-rounded) number of usable patches.
TOPK_TIERS = (61, 51, 41, 31, 25, 19, 15, 9, 5, 3, 1)
# A patch counts as usable when |RINet score| exceeds this value.
# REVIEW (translated): this is a key empirical hyper-parameter.
USABLE_THRESHOLD = 0.8
# Column titles of the result sheet (row 1).
RESULT_HEADER = (['groundTruth', 'all_prediction', 'all_can_use_prediction']
                 + [f'{tier}_prediction' for tier in TOPK_TIERS]
                 + ['path', 'max_can_use_patch_num', 'avg_prediction'])
# (name used in the output file, weight file, model factory) for each model.
MODEL_SPECS = [('ours', '../weight/DP_best_model.pt', register_model)]


def _select_patches(patch_scores, index):
    """Return the entries of ``patch_scores`` whose position occurs in ``index``.

    Positions are visited in patch order. Values in ``index`` that do not
    address a patch are ignored.
    """
    wanted = set(index.flatten().tolist())
    return [score for pos, score in enumerate(patch_scores) if pos in wanted]


def _scaled_median(values):
    """Return the median of ``values`` multiplied by 1000, as written to the sheet."""
    # NOTE(review): the x1000 presumably converts DPNet's output to the unit of
    # the ground-truth labels -- confirm.
    return float(np.median(values)) * 1000


for names, weight_path, build_model in MODEL_SPECS:
    model = nn.DataParallel(build_model()).to('cuda')
    state_load = torch.load(weight_path, map_location='cpu')
    # strict=False skips mismatched keys silently (DataParallel expects a
    # 'module.' prefix). Report them so an incompatible checkpoint is visible.
    load_result = model.load_state_dict(state_load, False)
    if load_result.missing_keys or load_result.unexpected_keys:
        util.print_info(
            f'{names} weights: {len(load_result.missing_keys)} missing and '
            f'{len(load_result.unexpected_keys)} unexpected keys were ignored')
    model.eval()

    for j, root_path in enumerate(root_path_list):
        data_type = data_type_list[j]
        cell_or_tissue = cell_or_tissue_list[j]
        # Load the test image paths and labels from the given Excel sheet.
        (train_image_path_list, train_patch_effective_list,
         train_defocus_distance_list) = dataset_F.get_test_data_and_label(
            root_path=root_path, type=f'{data_type}')
        train_transform = transforms.Compose([
            transforms.CenterCrop((CROP, CROP)),  # keep only the central region
            transforms.ToTensor(),
        ])
        train_data = dataset_F.MyDataset_test(
            train_image_path_list, train_defocus_distance_list,
            train_patch_effective_list, train_transform)
        train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False,
                                  pin_memory=True, num_workers=0)

        workbook = openpyxl.Workbook()
        sheet = workbook.create_sheet(index=0)
        sheet.append(RESULT_HEADER)  # row 1; each image then appends one row

        # Inference only. no_grad stops autograd from recording the RINet pass
        # and the 81-patch DPNet pass, which would otherwise cost GPU memory.
        with torch.no_grad():
            for img, _label, image_name, _effective in tqdm(train_loader):
                image_path = image_name[0]  # batch_size=1
                # Accept both '\' and '/' separators. The ground truth is the
                # file name up to its first '.', parsed as an int.
                path_parts = image_path.replace('/', '\\').split('\\')
                image_label = int(path_parts[-1].split('.')[0])

                # RINet sees a 512x512 bilinear downsample of the whole crop.
                img2 = F.interpolate(img, size=(512, 512), mode='bilinear',
                                     align_corners=False)
                img = img.to('cuda', dtype=torch.float32)
                img2 = img2.to('cuda', dtype=torch.float32)

                output_weight = classification_model(img2)
                # NOTE(review): usability is counted on |score|, but topk below
                # ranks by signed score. These only agree if RINet cannot emit
                # large negative scores -- confirm its output range.
                usable_count = int((torch.abs(output_weight) > USABLE_THRESHOLD).sum())
                if usable_count == 0:
                    # REVIEW (translated): samples with no usable patch are
                    # skipped, so the result sheet can have fewer rows than the
                    # input list. This is an implicit behaviour to be aware of.
                    util.print_info(f"{image_path} don't have enough info")
                    continue
                # Round down to an odd count, so that the median of the selected
                # patches is a single patch value rather than an average of two.
                can_use_num = usable_count if usable_count % 2 else usable_count - 1

                # Indices of the highest-scoring patches for the "all usable"
                # set and for each clipped top-k tier.
                # NOTE(review): assumes RINet's score at index z refers to the
                # z-th patch of the row-major tiling below -- confirm.
                _, all_can_use_index = torch.topk(output_weight, can_use_num)
                tier_indices = [torch.topk(output_weight, min(tier, can_use_num))[1]
                                for tier in TOPK_TIERS]

                # Tile the crop row by row and regress all patches in one batch.
                # REVIEW (translated): a single DPNet forward pass over all 81
                # patches avoids repeated per-patch passes.
                patches = [img[:, :, top:top + size, left:left + size]
                           for top in range(0, CROP, size)
                           for left in range(0, CROP, size)]
                out_all = model(torch.cat(patches, dim=0))
                patch_scores = out_all.flatten().tolist()

                row = [image_label,
                       _scaled_median(patch_scores),
                       _scaled_median(_select_patches(patch_scores, all_can_use_index))]
                row += [_scaled_median(_select_patches(patch_scores, index))
                        for index in tier_indices]
                row += ['\\'.join(path_parts[-3:]),  # last three path components
                        usable_count,                # raw count before odd rounding
                        float(np.mean(patch_scores)) * 1000]
                sheet.append(row)

                # REVIEW (translated): besides the numbers, a grid visualisation
                # is kept, which helps explain model behaviour and debug odd
                # samples. (The helper is defined in util; not verified here.)
                util.save_val_image(1, img, score=out_all, image_name=image_name,
                                    type_str=cell_or_tissue, type='test',
                                    type_label=output_weight)

        workbook.save(f'./{names}_to_{cell_or_tissue}_{data_type}_2.xlsx')
        # REVIEW (translated): saving results as Excel suits the paper's
        # experiments. For batch evaluation, a structured text format (e.g. CSV)
        # would be easier to analyse.