SparseFocus/stat_test_results.py

121 lines
3.2 KiB
Python

import argparse
from pathlib import Path
import pandas as pd
# Columns that check_columns() requires to be present in the input sheet.
# One row of the sheet describes one ROI of one image.
REQUIRED_COLUMNS = [
    "image_path",
    "roi_no",
    "importance_label",
    "importance_prediction",
    "defocus_label",
    "defocus_prediction",
]
def get_args(argv=None):
    """Parse the command-line arguments of the aggregation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is the behavior of a plain ``get_args()``.

    Returns:
        argparse.Namespace with a single attribute ``result_file``: the path
        given on the command line, or ``"test_results.xlsx"`` when omitted.
    """
    parser = argparse.ArgumentParser(description="Aggregate SparseFocus test results")
    parser.add_argument(
        "result_file",
        nargs="?",  # positional but optional, so the default below applies
        default="test_results.xlsx",
        help="Path to test result xlsx file",
    )
    return parser.parse_args(argv)
def check_columns(df):
    """Verify that *df* contains every column named in REQUIRED_COLUMNS.

    Raises:
        ValueError: if one or more required columns are absent; the message
            lists them in REQUIRED_COLUMNS order.
    """
    present = df.columns
    absent = [name for name in REQUIRED_COLUMNS if name not in present]
    if not absent:
        return
    raise ValueError(f"Missing required columns: {absent}")
def get_sparsity_level(usable_count):
    """Map a usable-ROI count to a sparsity label.

    Returns:
        ``"invalid"`` when the count equals 0, ``"dense"`` when it is greater
        than 41, ``"sparse"`` when it is greater than 9, and
        ``"extremely_sparse"`` for every other value.
    """
    if usable_count == 0:
        return "invalid"
    # Thresholds are checked from highest to lowest; the first one exceeded wins.
    for lower_bound, level in ((41, "dense"), (9, "sparse")):
        if usable_count > lower_bound:
            return level
    return "extremely_sparse"
def iter_image_groups(df, rois_per_image=81):
    """Yield consecutive fixed-size row blocks of *df*, one block per image.

    The rows are split purely by position: rows ``0..rois_per_image-1`` form
    the first block, the next ``rois_per_image`` rows the second, and so on.
    Each block is validated by requiring its ``roi_no`` column (cast to int)
    to read exactly ``0, 1, ..., rois_per_image - 1`` in that order.

    This is a generator, so validation errors surface when it is iterated,
    not when it is called.

    Args:
        df: DataFrame with at least the columns ``roi_no`` and ``image_path``.
        rois_per_image: Number of rows that make up one image. Defaults to 81,
            the value previously hard-coded here.

    Yields:
        A copy of each block of ``rois_per_image`` rows.

    Raises:
        ValueError: if ``rois_per_image`` is not positive, if the row count is
            not a multiple of ``rois_per_image``, or if a block's ``roi_no``
            sequence is not ``0..rois_per_image-1``.
    """
    if rois_per_image <= 0:
        raise ValueError(f"rois_per_image must be positive, got {rois_per_image}.")

    total_rows = len(df)
    if total_rows % rois_per_image != 0:
        raise ValueError(
            f"Row count should be divisible by {rois_per_image}, got {total_rows}."
        )

    # Loop-invariant: every block must carry this exact ROI numbering.
    expected_roi_no = list(range(rois_per_image))

    for start in range(0, total_rows, rois_per_image):
        # Copy so callers can modify the block without touching df.
        group = df.iloc[start : start + rois_per_image].copy()
        actual_roi_no = group["roi_no"].astype(int).tolist()
        if actual_roi_no != expected_roi_no:
            image_path = group["image_path"].iloc[0]
            raise ValueError(f"ROI order mismatch near image: {image_path}")
        yield group
def aggregate_one_image(group):
    """Collapse the ROI rows of one image into a single statistics row.

    Args:
        group: DataFrame holding the rows of one image, with the columns
            ``image_path``, ``defocus_label``, ``importance_label``,
            ``importance_prediction`` and ``defocus_prediction``.

    Returns:
        dict with, in this key order:
          * ``image_path`` / ``defocus_label``: taken from the first row.
          * ``usable_roi_count``: number of rows with ``importance_label > 0``.
          * ``sparsity_level``: ``get_sparsity_level`` of that count.
          * ``pred_importance_gt_0_8``: median ``defocus_prediction`` over rows
            whose ``importance_prediction`` exceeds 0.8, or ``pd.NA`` when no
            row qualifies.
          * ``all_blocks_median``: median ``defocus_prediction`` over all rows.
          * ``top_81_median`` ... ``top_1_median``: median ``defocus_prediction``
            over the k rows with the highest ``importance_prediction``.
    """
    first_row = group.iloc[0]
    defocus_pred = group["defocus_prediction"]

    n_usable = int((group["importance_label"] > 0).sum())

    # Defocus predictions of the ROIs the model is confident about.
    confident = group.loc[group["importance_prediction"] > 0.8, "defocus_prediction"]
    confident_median = pd.NA if confident.empty else confident.median()

    # Stable sort keeps the original row order among tied importance scores.
    ranked_defocus = group.sort_values(
        by="importance_prediction",
        ascending=False,
        kind="mergesort",
    )["defocus_prediction"]

    row = {
        "image_path": first_row["image_path"],
        "defocus_label": first_row["defocus_label"],
        "usable_roi_count": n_usable,
        "sparsity_level": get_sparsity_level(n_usable),
        "pred_importance_gt_0_8": confident_median,
        "all_blocks_median": defocus_pred.median(),
    }
    # Columns are emitted from the widest window (81) down to the single best ROI.
    row.update(
        (f"top_{k}_median", ranked_defocus.iloc[:k].median())
        for k in range(81, 0, -1)
    )
    return row
def aggregate_results(df):
    """Build the per-image statistics table from the per-ROI result table.

    Validates the columns with ``check_columns``, walks the image blocks
    produced by ``iter_image_groups`` and turns each into one row via
    ``aggregate_one_image``.

    Returns:
        pandas.DataFrame with one row per image block.
    """
    check_columns(df)
    per_image_rows = list(map(aggregate_one_image, iter_image_groups(df)))
    return pd.DataFrame(per_image_rows)
def main():
    """Read the result workbook, aggregate it and store the table as Sheet2.

    The first sheet of the workbook named on the command line is read, the
    per-image statistics are computed, and the result is written back into the
    same workbook as sheet ``Sheet2``, replacing that sheet if it already
    exists.
    """
    workbook_path = Path(get_args().result_file)

    print(f"Reading test results: {workbook_path}")
    raw_results = pd.read_excel(workbook_path, sheet_name=0)

    summary = aggregate_results(raw_results)
    print(f"Image count: {len(summary)}")

    print("Writing Sheet2")
    # Append mode keeps the existing sheets; only Sheet2 is overwritten.
    writer_options = {
        "engine": "openpyxl",
        "mode": "a",
        "if_sheet_exists": "replace",
    }
    with pd.ExcelWriter(workbook_path, **writer_options) as writer:
        summary.to_excel(writer, sheet_name="Sheet2", index=False)

    print(f"Saved statistics to Sheet2: {workbook_path}")
# Run the aggregation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()