2026-05-17 21:06:33 +08:00

148 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
try:
import numpy as np
except ImportError as exc:
raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy或告诉我改用其他路线。") from exc
try:
import pandas as pd
except ImportError as exc:
raise ImportError("缺少 pandas。请在 torch271 环境中安装 pandas或告诉我改用其他路线。") from exc
from dataset import as_posix_path
# Files used by the statistics pipeline. The annotated CSV is written back
# over the input file: OUTPUT_CSV is the very same Path object as INPUT_CSV.
INPUT_CSV = Path("predictions", "test_predictions.csv")
OUTPUT_CSV = INPUT_CSV
SUMMARY_TXT = Path("predictions", "test_summary.txt")

# Depth-of-field reference value. Edit it manually to match your experiment setup.
DEPTH_OF_FIELD = 1.0

# Column names read from (label, prediction) and added to (errors, sign) the CSV.
LABEL_COLUMN, PREDICTION_COLUMN = "label", "prediction"
SIGNED_ERROR_COLUMN, ABSOLUTE_ERROR_COLUMN, DIRECTION_SIGN_COLUMN = (
    "signed_error",
    "absolute_error",
    "direction_sign",
)
def check_columns(dataframe):
    """Validate that *dataframe* carries both the label and prediction columns.

    Raises:
        ValueError: if either column is absent; the message lists every
            missing column name, sorted and comma-separated.
    """
    present = set(dataframe.columns)
    absent = sorted(
        {name for name in (LABEL_COLUMN, PREDICTION_COLUMN) if name not in present}
    )
    if not absent:
        return
    missing_text = ", ".join(absent)
    raise ValueError(f"预测结果 CSV 缺少必要列:{missing_text}")
def add_error_columns(dataframe):
    """Return a copy of *dataframe* with per-row error columns appended.

    Added columns:
        SIGNED_ERROR_COLUMN: label - prediction.
        ABSOLUTE_ERROR_COLUMN: abs(label - prediction).
        DIRECTION_SIGN_COLUMN: sign(label) * sign(prediction) as int, i.e. 1 when
            both have the same sign, -1 when they disagree, 0 when either is 0.

    The input frame is left unmodified.

    Raises:
        ValueError: if a required column is missing (via check_columns), if a
            label/prediction value cannot be converted to float, or if any
            label/prediction is missing (NaN).
    """
    check_columns(dataframe)
    dataframe = dataframe.copy()
    labels = dataframe[LABEL_COLUMN].astype(float)
    predictions = dataframe[PREDICTION_COLUMN].astype(float)

    # A NaN would otherwise surface later as an opaque integer-cast error on the
    # direction-sign column; report it up front with a clear message instead.
    missing_rows = labels.isna() | predictions.isna()
    if missing_rows.any():
        raise ValueError(
            f"预测结果 CSV 中有 {int(missing_rows.sum())} 行的 "
            f"{LABEL_COLUMN} 或 {PREDICTION_COLUMN} 为空,无法计算误差。"
        )

    dataframe[SIGNED_ERROR_COLUMN] = labels - predictions
    dataframe[ABSOLUTE_ERROR_COLUMN] = dataframe[SIGNED_ERROR_COLUMN].abs()
    # Multiply the individual signs instead of taking sign(label * prediction):
    # the raw product can underflow to 0 for very small magnitudes, which would
    # wrongly report "no direction" for a pair that actually agrees or disagrees.
    dataframe[DIRECTION_SIGN_COLUMN] = (np.sign(labels) * np.sign(predictions)).astype(int)
    return dataframe
def percent_text(value):
    """Render a ratio as a percentage with four decimals, e.g. 0.5 -> '50.0000%'."""
    scaled = value * 100
    return "{:.4f}%".format(scaled)
def compute_summary(dataframe, depth_of_field=DEPTH_OF_FIELD):
    """Aggregate the per-row error columns into summary metrics.

    Args:
        dataframe: frame that already holds ABSOLUTE_ERROR_COLUMN and
            DIRECTION_SIGN_COLUMN (as produced by add_error_columns).
        depth_of_field: reference tolerance; must be a finite number > 0.

    Returns:
        dict with keys, in this order:
            sample_count: number of rows.
            mae: mean of the absolute error.
            std: population (ddof=0) standard deviation of the absolute error.
            dof_acc_1 / dof_acc_1_2 / dof_acc_1_3: fraction of rows whose
                absolute error is <= depth_of_field, depth_of_field / 2 and
                depth_of_field / 3 respectively.
            dss: fraction of rows whose direction sign is not -1 (rows with
                sign 0 count as correct).

    Raises:
        ValueError: if depth_of_field is not a finite positive number.
    """
    # `depth_of_field <= 0` alone lets NaN through (every NaN comparison is
    # False), which silently reports all accuracies as 0; inf would make them
    # all trivially 1. Reject both along with non-positive values.
    if not (np.isfinite(depth_of_field) and depth_of_field > 0):
        raise ValueError("DEPTH_OF_FIELD 必须是大于 0 的有限数。")

    absolute_error = dataframe[ABSOLUTE_ERROR_COLUMN].astype(float)
    direction_sign = dataframe[DIRECTION_SIGN_COLUMN]

    return {
        "sample_count": len(dataframe),
        "mae": absolute_error.mean(),
        "std": absolute_error.std(ddof=0),
        "dof_acc_1": (absolute_error <= depth_of_field).mean(),
        "dof_acc_1_2": (absolute_error <= depth_of_field / 2).mean(),
        "dof_acc_1_3": (absolute_error <= depth_of_field / 3).mean(),
        "dss": (direction_sign != -1).mean(),
    }
def save_summary(summary, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
    """Write the metrics in *summary* (as returned by compute_summary) to a text file.

    Args:
        summary: dict with sample_count, mae, std, dof_acc_1, dof_acc_1_2,
            dof_acc_1_3 and dss entries.
        summary_txt: output path (str or Path); parent directories are created.
        depth_of_field: reference value echoed into the report.

    Returns:
        The output path as a Path. The file is written as UTF-8, one metric per line.
    """
    summary_txt = Path(summary_txt)
    summary_txt.parent.mkdir(parents=True, exist_ok=True)
    # Every label needs a separator: without one, e.g. "DoF-Acc@1" runs straight
    # into "50.0000%" and reads as "DoF-Acc@150.0000%".
    lines = [
        "测试结果凝练统计",
        f"样本数:{summary['sample_count']}",
        f"景深参考值:{depth_of_field:.6g}",
        f"MAE and STD: {summary['mae']:.6g}, {summary['std']:.6g}",
        f"DoF-Acc@1: {percent_text(summary['dof_acc_1'])}",
        f"DoF-Acc@1/2: {percent_text(summary['dof_acc_1_2'])}",
        f"DoF-Acc@1/3: {percent_text(summary['dof_acc_1_3'])}",
        f"DSS: {percent_text(summary['dss'])}",
    ]
    summary_txt.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return summary_txt
def run_stats(input_csv=INPUT_CSV, output_csv=OUTPUT_CSV, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
    """Annotate a prediction CSV with error columns and write a metric summary.

    Reads *input_csv*, adds the error columns, and computes the summary metrics.
    Only after all of that succeeds does it write the annotated CSV to
    *output_csv* and the summary to *summary_txt*. With the module defaults,
    the annotated CSV overwrites the input file.

    Returns:
        (annotated DataFrame, summary dict).

    Raises:
        FileNotFoundError: if *input_csv* is not an existing file.
        ValueError: propagated from add_error_columns / compute_summary.
    """
    input_csv = Path(input_csv)
    output_csv = Path(output_csv)
    if not input_csv.is_file():
        raise FileNotFoundError(f"找不到预测结果 CSV:{as_posix_path(input_csv)}")

    dataframe = add_error_columns(pd.read_csv(input_csv))
    # Compute (and thereby validate) the summary before touching any output, so
    # e.g. an invalid depth_of_field does not leave a rewritten CSV with no summary.
    summary = compute_summary(dataframe, depth_of_field=depth_of_field)

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    dataframe.to_csv(output_csv, index=False, encoding="utf-8")
    summary_txt = save_summary(summary, summary_txt=summary_txt, depth_of_field=depth_of_field)
    print(f"已更新预测结果 CSV:{as_posix_path(output_csv)}")
    print(f"已保存统计摘要:{as_posix_path(summary_txt)}")
    return dataframe, summary
def test_stats():
    """Small smoke test: run the statistics pipeline on INPUT_CSV if that file exists.

    When the file is absent, a notice is printed and the pipeline is skipped.
    Returns None in both cases.
    """
    if INPUT_CSV.exists():
        run_stats(
            input_csv=INPUT_CSV,
            output_csv=OUTPUT_CSV,
            summary_txt=SUMMARY_TXT,
            depth_of_field=DEPTH_OF_FIELD,
        )
    else:
        print(f"预测结果 CSV 不存在,先跳过统计测试:{as_posix_path(INPUT_CSV)}")
def main():
    """Run the full statistics pipeline with the module-level configuration."""
    settings = dict(
        input_csv=INPUT_CSV,
        output_csv=OUTPUT_CSV,
        summary_txt=SUMMARY_TXT,
        depth_of_field=DEPTH_OF_FIELD,
    )
    run_stats(**settings)
if __name__ == "__main__":
    # Smoke-test entry point: skips with a notice when INPUT_CSV is missing.
    # Call main() instead to run the pipeline and fail loudly on a missing input.
    test_stats()