import os
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from model.DPNet import register_model
from model.RINet import MobileNetV3_small as con_classification
from dataset import datasets as dataset_F
import util.utils as util
from tqdm import tqdm
import openpyxl

# Test / inference entry point of the two-stage method:
#   1. RINet (classification_model) scores every patch of the whole image;
#   2. DPNet (model) regresses one value per patch (the defocus distance);
#   3. each image's prediction is the median of the DPNet outputs over
#      several top-k subsets of the patches RINet scores highest.
# Dataset and weight paths are hard-coded: this is an experiment script,
# not a general command-line tool.

# Restrict the visible GPUs; this must happen before the first CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Side length (pixels) of one square patch.
size = 224
# Images are center-cropped to CROP_SIZE x CROP_SIZE, which together with
# `size` yields a 9 x 9 grid of 81 patches. The pipeline assumes inputs are
# at least this large.
CROP_SIZE = 2016
# A patch counts as "usable" when |RINet score| exceeds this empirical value.
USABLE_THRESHOLD = 0.8
# Fixed top-k subset sizes compared in result columns 4-14.
TOPK_SIZES = (61, 51, 41, 31, 25, 19, 15, 9, 5, 3, 1)

# Paths to test data Excel files
root_path_list = [r'E:\suqiang\DenseSparse\testData_20241227.xlsx']
# Sheet name and sample category for each entry of root_path_list (matched by index).
data_type_list = ['Sheet', 'Sheet1', 'Sheet1', 'Sheet1']
cell_or_tissue_list = ['tissue', 'tissue']

# DPNet variants to evaluate: (name used in the output file, constructor, weights).
dp_models = [('ours', register_model, '../weight/DP_best_model.pt')]

# Result sheet header, one title per column (17 columns).
header = (['groundTruth', 'all_prediction', 'all_can_use_prediction']
          + [f'{k}_prediction' for k in TOPK_SIZES]
          + ['path', 'max_can_use_patch_num', 'avg_prediction'])


def _load_weights(model, weight_path):
    """Load the state dict at ``weight_path`` into ``model`` non-strictly.

    Loading stays non-strict as before, but any missing/unexpected keys are
    printed. Without this, a name mismatch would silently leave parameters at
    their initial values.
    """
    state_load = torch.load(weight_path)
    incompatible = model.load_state_dict(state_load, False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        util.print_info(f'{weight_path}: missing keys {incompatible.missing_keys}, '
                        f'unexpected keys {incompatible.unexpected_keys}')


def _split_path(path):
    """Split a path into components, accepting Windows and POSIX separators."""
    return path.replace('/', '\\').split('\\')


def _median_ms(values):
    """Return the median of ``values`` multiplied by 1000, as written to the sheet."""
    return float(np.median(values)) * 1000


def _topk_patch_scores(output_weight, k, patch_scores):
    """Return the DPNet outputs of the ``k`` patches RINet scores highest.

    Outputs are kept in patch-grid order (the median is order independent).
    Indices outside ``patch_scores`` are ignored.
    """
    chosen = set(torch.topk(output_weight, k).indices.reshape(-1).tolist())
    return [score for z, score in enumerate(patch_scores) if z in chosen]


# Load the RINet patch classifier (no architecture/weight version check:
# incompatibilities only surface at load time).
classification_model = nn.DataParallel(con_classification()).to('cuda')
_load_weights(classification_model, '../weight/RINet_best_model.pt')
classification_model.eval()

for names, build_model, weight_path in dp_models:
    model = nn.DataParallel(build_model()).to('cuda')
    _load_weights(model, weight_path)
    model.eval()

    for dataset_idx, root_path in enumerate(root_path_list):
        data_type = data_type_list[dataset_idx]
        cell_or_tissue = cell_or_tissue_list[dataset_idx]

        # Load test image paths and labels from the given Excel sheet.
        (test_image_path_list, test_patch_effective_list,
         test_defocus_distance_list) = dataset_F.get_test_data_and_label(
            root_path=root_path, type=f'{data_type}')

        test_transform = transforms.Compose([
            transforms.CenterCrop((CROP_SIZE, CROP_SIZE)),  # fixed-size patch grid
            transforms.ToTensor(),
        ])
        test_data = dataset_F.MyDataset_test(test_image_path_list, test_defocus_distance_list,
                                             test_patch_effective_list, test_transform)
        test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False,
                                 pin_memory=True, num_workers=0)

        workbook = openpyxl.Workbook()
        sheet = workbook.create_sheet(index=0)
        for col, title in enumerate(header, start=1):
            sheet.cell(1, col).value = title
        n = 0  # number of data rows written (row n + 2 is the next free row)

        # Pure inference: no autograd graph is needed for either network.
        with torch.no_grad():
            for img, _label, image_name, _effective in tqdm(test_loader):
                image_path = image_name[0]
                path_parts = _split_path(image_path)
                # Ground truth is the integer file stem, e.g. ".../1200.png" -> 1200.
                image_label = int(path_parts[-1].split('.')[0])

                # RINet sees a 512x512 downscaled copy of the whole image.
                img2 = F.interpolate(img, size=(512, 512), mode='bilinear', align_corners=False)
                img = img.to('cuda', dtype=torch.float32)
                img2 = img2.to('cuda', dtype=torch.float32)

                output_weight = classification_model(img2)
                # NOTE(review): usability uses |score| but topk below ranks by the
                # signed score; these only differ if RINet can emit scores below
                # -USABLE_THRESHOLD — confirm the intended semantics.
                usable_count = int((torch.abs(output_weight) > USABLE_THRESHOLD).sum())

                if usable_count == 0:
                    # Skipped images produce no row in the result sheet.
                    util.print_info(f'{image_name} don\'t have enough info')
                    continue

                # Keep the count odd, presumably so the median is an actual patch value.
                can_use_num = usable_count if usable_count % 2 else usable_count - 1

                # Cut the crop into the row-major patch grid and regress all
                # patches in a single DPNet forward pass.
                patches = [img[:, :, top:top + size, left:left + size]
                           for top in range(0, CROP_SIZE, size)
                           for left in range(0, CROP_SIZE, size)]
                out_all = model(torch.cat(patches, dim=0))
                # One device->host transfer instead of a per-patch .item() sync.
                patch_scores = out_all.reshape(-1).tolist()

                row = [
                    image_label,
                    _median_ms(patch_scores),
                    _median_ms(_topk_patch_scores(output_weight, can_use_num, patch_scores)),
                ]
                # Each fixed subset size is clamped to the usable count.
                row += [_median_ms(_topk_patch_scores(output_weight, min(k, can_use_num), patch_scores))
                        for k in TOPK_SIZES]
                row += [
                    '\\'.join(path_parts[-3:]),
                    usable_count,
                    float(np.mean(patch_scores)) * 1000,
                ]
                for col, value in enumerate(row, start=1):
                    sheet.cell(n + 2, col).value = value
                n += 1

                # Grid visualisation of per-patch outputs, for inspecting outliers.
                util.save_val_image(1, img, score=out_all, image_name=image_name,
                                    type_str=cell_or_tissue, type='test', type_label=output_weight)

        workbook.save(f'./{names}_to_{cell_or_tissue}_{data_type}_2.xlsx')