SparseFocus/test.py

146 lines
4.2 KiB
Python

from pathlib import Path
import torch
import pandas as pd
from PIL import Image
from torchvision.transforms import functional as F
from tqdm import tqdm
import old_datasets as dataset_F
import utils
from models import DPNet, RINet
def strip_module_prefix(state_dict):
    """Drop a uniform ``module.`` prefix from the keys of a state dict.

    Such a prefix typically appears when a model was saved while wrapped
    (e.g. in ``nn.DataParallel``). Stripping is all-or-nothing: if even one
    key lacks the prefix, the mapping is returned unchanged (the very same
    object). Otherwise a new dict with the prefix removed from every key is
    returned; an empty mapping yields a new empty dict.
    """
    prefix = "module."
    if any(not name.startswith(prefix) for name in state_dict):
        return state_dict
    stripped = {}
    for name, tensor in state_dict.items():
        # Safe to slice: every key is known to start with the prefix here.
        stripped[name[len(prefix):]] = tensor
    return stripped
def load_model_weight(model, weight_path, device):
    """Load a checkpoint into *model*, move it to *device*, and set eval mode.

    The checkpoint is first deserialized onto the CPU. If it is a dict that
    contains a ``"state_dict"`` entry, that entry is used as the weights;
    otherwise the loaded object itself is treated as the state dict. A
    uniform ``module.`` key prefix is removed before a strict load, so any
    missing or unexpected key raises.

    Returns:
        The same *model* object, now on *device* and in eval mode.
    """
    # NOTE(review): torch.load unpickles the file — only point this at
    # trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location="cpu")
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    weights = checkpoint["state_dict"] if is_wrapped else checkpoint
    model.load_state_dict(strip_module_prefix(weights), strict=True)
    model.to(device)
    model.eval()
    return model
def build_patch_tensor(cropped_image, grid_size=9, patch_size=224):
    """Split an image into a row-major grid of square patches.

    Args:
        cropped_image: A PIL image (or anything ``F.to_tensor`` accepts), or
            an already converted ``(C, H, W)`` tensor.
        grid_size: Number of patches along each axis (default 9, i.e. 81
            patches).
        patch_size: Side length in pixels of each square patch (default 224).

    Returns:
        A tensor of shape ``(grid_size * grid_size, C, patch_size,
        patch_size)``. Patch ``k`` is taken from grid row ``k // grid_size``
        and column ``k % grid_size``. Only the top-left
        ``grid_size * patch_size`` square of the image is used.

    Raises:
        RuntimeError: If the image is smaller than
            ``grid_size * patch_size`` along either axis.
    """
    if isinstance(cropped_image, torch.Tensor):
        image_tensor = cropped_image
    else:
        image_tensor = F.to_tensor(cropped_image)

    span = grid_size * patch_size
    channels, height, width = image_tensor.shape
    if height < span or width < span:
        raise RuntimeError(
            f"Image of size {height}x{width} is too small for a "
            f"{grid_size}x{grid_size} grid of {patch_size}px patches "
            f"(needs at least {span}x{span})."
        )

    region = image_tensor[:, :span, :span]
    # (C, rows, ph, cols, pw) -> (rows, cols, C, ph, pw): flattening the first
    # two axes gives the same order as a nested `for row: for col:` loop.
    grid = region.reshape(channels, grid_size, patch_size, grid_size, patch_size)
    return grid.permute(1, 3, 0, 2, 4).reshape(
        grid_size * grid_size, channels, patch_size, patch_size
    )
def predict_one_image(image_path, importance_label, defocus_label, rin_model, dpn_model, device):
    """Run both networks on one image and return one result row per ROI.

    The image is converted to RGB and center-cropped to 2016x2016. ``rin_model``
    receives the crop resized to 512x512 as a batch of one. ``dpn_model``
    receives the crop split into 81 patches by ``build_patch_tensor``. Each
    model's output is flattened and must contain exactly 81 values.

    Args:
        image_path: Path of the image file to evaluate.
        importance_label: Per-ROI importance labels, indexed 0..80.
        defocus_label: Single defocus label for the image, repeated on every row.
        rin_model: Model producing the 81 importance predictions.
        dpn_model: Model producing the 81 defocus predictions.
        device: Torch device the input tensors are moved to.

    Returns:
        A list of 81 rows ``[str(image_path), roi_no, importance_label[roi_no],
        importance_prediction, defocus_label * 1000.0,
        defocus_prediction * 1000.0]``.

    Raises:
        RuntimeError: If the label count or either model's output count is not 81.
    """
    roi_count = 81
    if len(importance_label) != roi_count:
        raise RuntimeError(f"Importance label count should be 81, got {len(importance_label)}.")

    # The context manager closes the file handle deterministically; convert()
    # returns a loaded copy that stays valid after the file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    cropped_image = F.center_crop(image, [2016, 2016])
    rin_image = F.resize(cropped_image, [512, 512])
    rin_tensor = F.to_tensor(rin_image).unsqueeze(0).to(device)
    patch_tensor = build_patch_tensor(cropped_image).to(device)

    # No autograd graph is built under no_grad, so no .detach() is needed.
    with torch.no_grad():
        importance_predictions = rin_model(rin_tensor).reshape(-1).cpu()
        defocus_predictions = dpn_model(patch_tensor).reshape(-1).cpu()
    if len(importance_predictions) != roi_count:
        raise RuntimeError(f"RIN output count should be 81, got {len(importance_predictions)}.")
    if len(defocus_predictions) != roi_count:
        raise RuntimeError(f"DPN output count should be 81, got {len(defocus_predictions)}.")

    # Convert each tensor to Python scalars once instead of indexing per ROI.
    importance_values = importance_predictions.tolist()
    defocus_values = defocus_predictions.tolist()
    # Label and prediction are scaled by the same factor, so the two defocus
    # columns stay directly comparable.
    scaled_defocus_label = float(defocus_label) * 1000.0
    path_text = str(image_path)
    return [
        [
            path_text,
            roi_no,
            importance_label[roi_no],
            float(importance_values[roi_no]),
            scaled_defocus_label,
            float(defocus_values[roi_no]) * 1000.0,
        ]
        for roi_no in range(roi_count)
    ]
def save_results(rows, output_path):
    """Write per-ROI result rows to an Excel file at *output_path*.

    Each row is expected in the column order listed below (the order
    ``predict_one_image`` emits). The DataFrame index is not written.
    """
    column_names = [
        "image_path",
        "roi_no",
        "importance_label",
        "importance_prediction",
        "defocus_label",
        "defocus_prediction",
    ]
    results_frame = pd.DataFrame(rows, columns=column_names)
    results_frame.to_excel(output_path, index=False)
def main():
    """Evaluate the RINet/DPNet pair on the test sheet and save an Excel report.

    Reads ``excel_file``, ``rin_weight`` and ``dpn_weight`` from the config
    returned by ``utils.get_hyperparams()``. Runs on ``cuda:0`` when available,
    otherwise on the CPU. Calls ``predict_one_image`` for every test image and
    writes all per-ROI rows to ``test_results.xlsx`` in the current working
    directory.

    Raises:
        RuntimeError: If the test-data loader returns label collections whose
            lengths differ from the number of image paths.
    """
    config, config_path = utils.get_hyperparams()
    excel_file = config["excel_file"]
    rin_weight = config["rin_weight"]
    dpn_weight = config["dpn_weight"]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    output_path = Path.cwd() / "test_results.xlsx"

    print(f"Config path: {config_path}")
    print(f"Using device: {device}")

    print("Loading models")
    rin_model = load_model_weight(RINet(), rin_weight, device)
    dpn_model = load_model_weight(DPNet(), dpn_weight, device)

    print("Loading test data")
    image_paths, importance_labels, defocus_labels = dataset_F.get_test_data_and_label(
        excel_file,
        "Sheet",
    )
    print(f"Test image count: {len(image_paths)}")
    # zip() below stops at its shortest input, so mismatched loader output
    # would silently skip images. Fail loudly instead.
    if not len(image_paths) == len(importance_labels) == len(defocus_labels):
        raise RuntimeError(
            "Test data length mismatch: "
            f"{len(image_paths)} images, {len(importance_labels)} importance labels, "
            f"{len(defocus_labels)} defocus labels."
        )

    all_rows = []
    for image_path, importance_label, defocus_label in tqdm(
        zip(image_paths, importance_labels, defocus_labels),
        total=len(image_paths),
        desc="Test",
        bar_format="{l_bar}{bar:20}{r_bar}",
    ):
        all_rows.extend(
            predict_one_image(
                image_path,
                importance_label,
                defocus_label,
                rin_model,
                dpn_model,
                device,
            )
        )

    save_results(all_rows, output_path)
    print(f"Saved test results: {output_path}")
    print(f"Saved row count: {len(all_rows)}")


if __name__ == "__main__":
    main()