# Source listing: 146 lines, 4.2 KiB, Python.
from pathlib import Path
|
|
|
|
import torch
|
|
import pandas as pd
|
|
from PIL import Image
|
|
from torchvision.transforms import functional as F
|
|
from tqdm import tqdm
|
|
|
|
import old_datasets as dataset_F
|
|
import utils
|
|
from models import DPNet, RINet
|
|
|
|
|
|
def strip_module_prefix(state_dict):
    """Remove a shared ``module.`` prefix from every key of a state dict.

    The prefix is stripped only when *all* keys carry it (presumably from a
    checkpoint saved while the model was wrapped, e.g. in DataParallel). If
    even one key lacks it, the input mapping is returned untouched.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        The very same ``state_dict`` object when at least one key lacks the
        prefix; otherwise a new dict whose keys have exactly one leading
        ``module.`` removed.
    """
    prefix = "module."
    if any(not name.startswith(prefix) for name in state_dict):
        return state_dict
    cut = len(prefix)
    return {name[cut:]: tensor for name, tensor in state_dict.items()}
|
|
|
|
|
|
def load_model_weight(model, weight_path, device):
    """Populate *model* from a checkpoint file and prepare it for inference.

    The checkpoint is loaded onto the CPU first. When it is a dict holding
    the weights under a ``"state_dict"`` key, that inner mapping is used;
    otherwise the loaded object itself is treated as the state dict. A
    ``module.`` prefix shared by all keys is removed, and loading is strict,
    so missing or unexpected keys raise.

    Args:
        model: The ``torch.nn.Module`` to fill with weights.
        weight_path: Checkpoint path handed to ``torch.load``.
        device: Device the model is moved to after loading.

    Returns:
        The same ``model`` object, moved to ``device`` and in eval mode.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints here.
    checkpoint = torch.load(weight_path, map_location="cpu")
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    weights = checkpoint["state_dict"] if is_wrapped else checkpoint

    model.load_state_dict(strip_module_prefix(weights), strict=True)
    model.to(device)
    model.eval()
    return model
|
|
|
|
|
|
def build_patch_tensor(cropped_image, grid_size=9, patch_size=224):
    """Split an image into a ``grid_size`` x ``grid_size`` stack of square patches.

    Patches are cut starting at the top-left corner in row-major order:
    patch ``i`` covers grid row ``i // grid_size`` and grid column
    ``i % grid_size``. Pixels beyond ``grid_size * patch_size`` in either
    direction are ignored.

    Args:
        cropped_image: Image accepted by ``F.to_tensor`` (PIL image or ndarray).
        grid_size: Number of patches along each axis (default 9).
        patch_size: Side length of each square patch in pixels (default 224).

    Returns:
        Tensor of shape ``(grid_size ** 2, C, patch_size, patch_size)``.

    Raises:
        ValueError: If ``grid_size`` or ``patch_size`` is not positive.
        RuntimeError: If the image is smaller than the grid in either
            dimension. Previously this surfaced as an opaque
            ``torch.stack`` size-mismatch error.
    """
    if grid_size < 1 or patch_size < 1:
        raise ValueError(
            f"grid_size and patch_size must be positive, got {grid_size} and {patch_size}."
        )

    image_tensor = F.to_tensor(cropped_image)  # (C, H, W)
    _, height, width = image_tensor.shape
    span = grid_size * patch_size
    if height < span or width < span:
        raise RuntimeError(
            f"Image of size {height}x{width} is too small for a "
            f"{grid_size}x{grid_size} grid of {patch_size}px patches."
        )

    # Outer loop over rows, inner over columns, so the order is row-major.
    patches = [
        image_tensor[:, top:top + patch_size, left:left + patch_size]
        for top in range(0, span, patch_size)
        for left in range(0, span, patch_size)
    ]
    return torch.stack(patches, dim=0)
|
|
|
|
|
|
def predict_one_image(image_path, importance_label, defocus_label, rin_model, dpn_model, device):
    """Run both models on one image and return one result row per ROI.

    The image is converted to RGB and center-cropped to 2016x2016.
    ``rin_model`` receives the crop resized to 512x512 as a batch of one.
    ``dpn_model`` receives the crop split into patches by
    ``build_patch_tensor``. Each model's output is flattened and must contain
    exactly 81 values, one per ROI.

    Args:
        image_path: Path of the image file to evaluate.
        importance_label: Per-ROI importance labels; must have 81 entries.
        defocus_label: Single defocus label for the whole image.
        rin_model: Model producing the importance predictions.
        dpn_model: Model producing the defocus predictions.
        device: Device on which inference runs.

    Returns:
        List of 81 rows. Each row is ``[str(image_path), roi_no,
        importance_label[roi_no], importance_prediction,
        defocus_label * 1000.0, defocus_prediction * 1000.0]``.

    Raises:
        RuntimeError: If the label or either model output does not contain
            exactly 81 entries.
    """
    roi_count = 81

    if len(importance_label) != roi_count:
        raise RuntimeError(f"Importance label count should be 81, got {len(importance_label)}.")

    # Image.open is lazy and keeps the file handle open. The context manager
    # releases it deterministically once the RGB copy has been made.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    cropped_image = F.center_crop(image, [2016, 2016])

    rin_tensor = F.to_tensor(F.resize(cropped_image, [512, 512])).unsqueeze(0).to(device)
    patch_tensor = build_patch_tensor(cropped_image).to(device)

    # no_grad already prevents autograd history, so no extra .detach() is needed.
    with torch.no_grad():
        importance_predictions = rin_model(rin_tensor).reshape(-1).cpu()
        defocus_predictions = dpn_model(patch_tensor).reshape(-1).cpu()

    if len(importance_predictions) != roi_count:
        raise RuntimeError(f"RIN output count should be 81, got {len(importance_predictions)}.")

    if len(defocus_predictions) != roi_count:
        raise RuntimeError(f"DPN output count should be 81, got {len(defocus_predictions)}.")

    # Convert each tensor to Python values once instead of indexing it per ROI.
    importance_values = [float(value) for value in importance_predictions.tolist()]
    defocus_values = [float(value) for value in defocus_predictions.tolist()]

    path_text = str(image_path)
    scaled_defocus_label = float(defocus_label) * 1000.0

    return [
        [
            path_text,
            roi_no,
            importance_label[roi_no],
            importance_values[roi_no],
            scaled_defocus_label,
            defocus_values[roi_no] * 1000.0,
        ]
        for roi_no in range(roi_count)
    ]
|
|
|
|
|
|
def save_results(rows, output_path):
    """Write the collected per-ROI rows to an Excel file.

    Args:
        rows: Sequence of 6-element rows ordered as image path, ROI number,
            importance label, importance prediction, defocus label and
            defocus prediction.
        output_path: Destination path of the workbook. The DataFrame index
            is not written.
    """
    header = [
        "image_path",
        "roi_no",
        "importance_label",
        "importance_prediction",
        "defocus_label",
        "defocus_prediction",
    ]
    table = pd.DataFrame(rows, columns=header)
    table.to_excel(output_path, index=False)
|
|
|
|
|
|
def main():
    """Evaluate both models on the test set and save per-ROI results.

    Steps:
      1. Read ``excel_file``, ``rin_weight`` and ``dpn_weight`` from the
         config returned by ``utils.get_hyperparams()``.
      2. Load RINet and DPNet onto ``cuda:0`` when available, else the CPU.
      3. Obtain the test images and labels from
         ``dataset_F.get_test_data_and_label`` with sheet argument ``"Sheet"``.
      4. Run ``predict_one_image`` for every image.
      5. Write all rows to ``test_results.xlsx`` in the current working
         directory.

    Raises:
        RuntimeError: If the loader returns image paths, importance labels
            and defocus labels of differing lengths.
    """
    config, config_path = utils.get_hyperparams()

    excel_file = config["excel_file"]
    rin_weight = config["rin_weight"]
    dpn_weight = config["dpn_weight"]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    output_path = Path.cwd() / "test_results.xlsx"

    print(f"Config path: {config_path}")
    print(f"Using device: {device}")
    print("Loading models")

    rin_model = load_model_weight(RINet(), rin_weight, device)
    dpn_model = load_model_weight(DPNet(), dpn_weight, device)

    print("Loading test data")
    image_paths, importance_labels, defocus_labels = dataset_F.get_test_data_and_label(
        excel_file,
        "Sheet",
    )

    # zip() silently drops trailing items when the sequences disagree in
    # length, which would evaluate fewer images than the sheet lists.
    # Materialise them so the lengths can be checked up front.
    image_paths = list(image_paths)
    importance_labels = list(importance_labels)
    defocus_labels = list(defocus_labels)
    if not (len(image_paths) == len(importance_labels) == len(defocus_labels)):
        raise RuntimeError(
            "Test data length mismatch: "
            f"{len(image_paths)} images, "
            f"{len(importance_labels)} importance labels, "
            f"{len(defocus_labels)} defocus labels."
        )

    print(f"Test image count: {len(image_paths)}")

    all_rows = []
    for image_path, importance_label, defocus_label in tqdm(
        zip(image_paths, importance_labels, defocus_labels),
        total=len(image_paths),
        desc="Test",
        bar_format="{l_bar}{bar:20}{r_bar}",
    ):
        all_rows.extend(
            predict_one_image(
                image_path,
                importance_label,
                defocus_label,
                rin_model,
                dpn_model,
                device,
            )
        )

    save_results(all_rows, output_path)
    print(f"Saved test results: {output_path}")
    print(f"Saved row count: {len(all_rows)}")
|
|
|
|
|
|
# Run the evaluation only when executed as a script; importing the module
# merely defines the functions above.
if __name__ == "__main__":
    main()
|