from pathlib import Path

import torch
import pandas as pd
from PIL import Image
from torchvision.transforms import functional as F
from tqdm import tqdm

import old_datasets as dataset_F
import utils
from models import DPNet, RINet

# Geometry of the ROI grid: the centre crop is tiled into GRID_SIZE x GRID_SIZE
# non-overlapping PATCH_SIZE x PATCH_SIZE patches, one patch per ROI.
GRID_SIZE = 9
PATCH_SIZE = 224
CROP_SIZE = GRID_SIZE * PATCH_SIZE  # 2016 px
ROI_COUNT = GRID_SIZE * GRID_SIZE  # 81 ROIs
# Side length of the downsampled whole-crop image fed to RINet.
RIN_INPUT_SIZE = 512
# Multiplier applied to both the defocus label and the defocus prediction
# before they are saved (presumably a unit conversion -- confirm with the
# dataset definition).
DEFOCUS_SCALE = 1000.0


def strip_module_prefix(state_dict):
    """Remove a leading ``"module."`` from every key of *state_dict*.

    The prefix is stripped only when *every* key carries it (as commonly
    happens for checkpoints saved from a ``torch.nn.DataParallel`` wrapper);
    otherwise the mapping is returned unchanged.
    """
    if not all(key.startswith("module.") for key in state_dict):
        return state_dict
    return {key.removeprefix("module."): value for key, value in state_dict.items()}


def load_model_weight(model, weight_path, device):
    """Load weights from *weight_path* into *model* and prepare it for inference.

    Accepts either a bare state dict or a checkpoint dict that nests it under
    a ``"state_dict"`` key. Loading is strict, so missing or unexpected keys
    raise. The model is moved to *device* and switched to eval mode.

    Returns:
        The same *model* instance.
    """
    # Map to CPU first so checkpoints saved on a GPU load on any machine.
    state_dict = torch.load(weight_path, map_location="cpu")
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(strip_module_prefix(state_dict), strict=True)
    model.to(device)
    model.eval()
    return model


def build_patch_tensor(cropped_image, grid_size=GRID_SIZE, patch_size=PATCH_SIZE):
    """Tile *cropped_image* into a row-major stack of square patches.

    Args:
        cropped_image: PIL image at least ``grid_size * patch_size`` pixels on
            each side; only the top-left region of that size is used.
        grid_size: number of patches along each axis.
        patch_size: side length of each patch in pixels.

    Returns:
        Tensor of shape ``(grid_size**2, C, patch_size, patch_size)`` where
        entry ``row * grid_size + col`` is the patch at grid cell ``(row, col)``.
    """
    cropped_tensor = F.to_tensor(cropped_image)  # (C, H, W)
    patches = []
    for row in range(grid_size):
        top = row * patch_size
        for col in range(grid_size):
            left = col * patch_size
            patches.append(
                cropped_tensor[:, top:top + patch_size, left:left + patch_size]
            )
    return torch.stack(patches, dim=0)


def predict_one_image(image_path, importance_label, defocus_label, rin_model, dpn_model, device):
    """Run RINet and DPNet on one image and return one result row per ROI.

    Args:
        image_path: path of the image to evaluate.
        importance_label: indexable sequence of ROI_COUNT per-ROI importance
            labels, in the same row-major order as the patch grid.
        defocus_label: single per-image defocus label (converted with float()).
        rin_model: model mapping a (1, C, 512, 512) tensor to ROI_COUNT values.
        dpn_model: model mapping the (ROI_COUNT, C, 224, 224) patch batch to
            one value per patch.
        device: torch device the inputs are moved to.

    Returns:
        List of ROI_COUNT rows ``[image_path, roi_no, importance_label,
        importance_prediction, defocus_label * 1000, defocus_prediction * 1000]``.

    Raises:
        RuntimeError: if the label or either model output does not contain
            exactly ROI_COUNT values.
    """
    if len(importance_label) != ROI_COUNT:
        raise RuntimeError(
            f"Importance label count should be {ROI_COUNT}, got {len(importance_label)}."
        )

    # The context manager releases the file handle; convert() returns an
    # independent in-memory image, so it stays usable after the file closes.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    # center_crop zero-pads images smaller than the crop size.
    cropped_image = F.center_crop(image, [CROP_SIZE, CROP_SIZE])
    # RINet sees the whole crop downsampled; DPNet sees each full-resolution patch.
    rin_image = F.resize(cropped_image, [RIN_INPUT_SIZE, RIN_INPUT_SIZE])
    rin_tensor = F.to_tensor(rin_image).unsqueeze(0).to(device)
    patch_tensor = build_patch_tensor(cropped_image).to(device)

    with torch.no_grad():
        importance_predictions = rin_model(rin_tensor).reshape(-1).cpu()
        defocus_predictions = dpn_model(patch_tensor).reshape(-1).cpu()

    if len(importance_predictions) != ROI_COUNT:
        raise RuntimeError(
            f"RIN output count should be {ROI_COUNT}, got {len(importance_predictions)}."
        )
    if len(defocus_predictions) != ROI_COUNT:
        raise RuntimeError(
            f"DPN output count should be {ROI_COUNT}, got {len(defocus_predictions)}."
        )

    # The label is the same for every ROI of this image, so scale it once.
    scaled_defocus_label = float(defocus_label) * DEFOCUS_SCALE
    return [
        [
            str(image_path),
            roi_no,
            importance_label[roi_no],
            float(importance_predictions[roi_no]),
            scaled_defocus_label,
            float(defocus_predictions[roi_no]) * DEFOCUS_SCALE,
        ]
        for roi_no in range(ROI_COUNT)
    ]


def save_results(rows, output_path):
    """Write result *rows* to an Excel workbook at *output_path*.

    Each row must follow the column order produced by ``predict_one_image``.
    The pandas index is not written. (pandas needs an Excel writer engine
    such as openpyxl installed for ``to_excel``.)
    """
    df = pd.DataFrame(
        rows,
        columns=[
            "image_path",
            "roi_no",
            "importance_label",
            "importance_prediction",
            "defocus_label",
            "defocus_prediction",
        ],
    )
    df.to_excel(output_path, index=False)


def main():
    """Evaluate RINet and DPNet on the configured test set and save the results.

    Reads ``excel_file``, ``rin_weight`` and ``dpn_weight`` from the config
    returned by ``utils.get_hyperparams()``. Writes ``test_results.xlsx`` to
    the current working directory.

    Raises:
        RuntimeError: if the loaded paths and label sequences differ in length.
    """
    config, config_path = utils.get_hyperparams()
    excel_file = config["excel_file"]
    rin_weight = config["rin_weight"]
    dpn_weight = config["dpn_weight"]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    output_path = Path.cwd() / "test_results.xlsx"

    print(f"Config path: {config_path}")
    print(f"Using device: {device}")

    print("Loading models")
    rin_model = load_model_weight(RINet(), rin_weight, device)
    dpn_model = load_model_weight(DPNet(), dpn_weight, device)

    print("Loading test data")
    # "Sheet" is presumably the worksheet name inside excel_file -- confirm
    # against old_datasets.get_test_data_and_label.
    image_paths, importance_labels, defocus_labels = dataset_F.get_test_data_and_label(
        excel_file,
        "Sheet",
    )
    # zip() would silently drop trailing items on a length mismatch, so
    # images could vanish from the report without any error.
    if not len(image_paths) == len(importance_labels) == len(defocus_labels):
        raise RuntimeError(
            "Test data length mismatch: "
            f"{len(image_paths)} image paths, "
            f"{len(importance_labels)} importance labels, "
            f"{len(defocus_labels)} defocus labels."
        )
    print(f"Test image count: {len(image_paths)}")

    all_rows = []
    for image_path, importance_label, defocus_label in tqdm(
        zip(image_paths, importance_labels, defocus_labels),
        total=len(image_paths),
        desc="Test",
        bar_format="{l_bar}{bar:20}{r_bar}",
    ):
        all_rows.extend(
            predict_one_image(
                image_path,
                importance_label,
                defocus_label,
                rin_model,
                dpn_model,
                device,
            )
        )

    save_results(all_rows, output_path)
    print(f"Saved test results: {output_path}")
    print(f"Saved row count: {len(all_rows)}")


if __name__ == "__main__":
    main()