"""Append per-sample error columns to a prediction CSV and save summary stats.

The input CSV must contain ``label`` and ``prediction`` columns. The script adds
``signed_error``, ``absolute_error`` and ``direction_sign`` columns, writes the
table back (by default over the input file), and saves a short text summary
with MAE/STD, DoF-Acc@{1, 1/2, 1/3} and DSS.
"""

from pathlib import Path

try:
    import numpy as np
except ImportError as exc:
    raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc

try:
    import pandas as pd
except ImportError as exc:
    raise ImportError("缺少 pandas。请在 torch271 环境中安装 pandas,或告诉我改用其他路线。") from exc

from dataset import as_posix_path

INPUT_CSV = Path("predictions") / "test_predictions.csv"
# The CSV is updated in place by default: error columns are appended to the input file.
OUTPUT_CSV = INPUT_CSV
SUMMARY_TXT = Path("predictions") / "test_summary.txt"

# Depth-of-field reference value. Adjust manually to match your experimental setup.
DEPTH_OF_FIELD = 1.0

LABEL_COLUMN = "label"
PREDICTION_COLUMN = "prediction"
SIGNED_ERROR_COLUMN = "signed_error"
ABSOLUTE_ERROR_COLUMN = "absolute_error"
DIRECTION_SIGN_COLUMN = "direction_sign"


def check_columns(dataframe):
    """Raise ``ValueError`` if the label or prediction column is missing."""
    required_columns = {LABEL_COLUMN, PREDICTION_COLUMN}
    missing_columns = required_columns - set(dataframe.columns)
    if missing_columns:
        missing_text = ", ".join(sorted(missing_columns))
        raise ValueError(f"预测结果 CSV 缺少必要列:{missing_text}")


def add_error_columns(dataframe):
    """Return a copy of *dataframe* with per-sample error columns added.

    Added columns:
      * ``signed_error``   -- label - prediction
      * ``absolute_error`` -- |signed_error|
      * ``direction_sign`` -- sign(label) * sign(prediction) as int:
        1 for same sign, -1 for opposite signs, 0 if either value is zero.

    The input frame is not modified.

    Raises:
        ValueError: if a required column is missing, if a value cannot be
            converted to float, or if any label/prediction is NaN or infinite.
    """
    check_columns(dataframe)
    dataframe = dataframe.copy()
    labels = dataframe[LABEL_COLUMN].astype(float)
    predictions = dataframe[PREDICTION_COLUMN].astype(float)

    # Non-finite values would either crash the int cast below with an opaque
    # pandas error (NaN) or silently turn MAE/STD into inf, so reject them here.
    invalid = ~(np.isfinite(labels) & np.isfinite(predictions))
    if invalid.any():
        bad_rows = ", ".join(str(index) for index in dataframe.index[invalid.to_numpy()][:5])
        raise ValueError(f"预测结果 CSV 中存在缺失值或非有限数值(行索引示例:{bad_rows})")

    dataframe[SIGNED_ERROR_COLUMN] = labels - predictions
    dataframe[ABSOLUTE_ERROR_COLUMN] = dataframe[SIGNED_ERROR_COLUMN].abs()
    # Multiply the individual signs instead of taking sign(labels * predictions):
    # the product of two tiny values can underflow to 0 and mask a sign mismatch.
    dataframe[DIRECTION_SIGN_COLUMN] = (np.sign(labels) * np.sign(predictions)).astype(int)
    return dataframe


def percent_text(value):
    """Format a fraction (e.g. 0.5) as a percentage string with 4 decimals."""
    return f"{value * 100:.4f}%"


def compute_summary(dataframe, depth_of_field=DEPTH_OF_FIELD):
    """Compute aggregate error statistics for a prediction table.

    If the error columns are not present yet, they are derived first via
    ``add_error_columns``.

    Returned keys:
      * ``sample_count`` -- number of rows
      * ``mae`` / ``std`` -- mean and population std (ddof=0) of absolute error
      * ``dof_acc_1`` / ``dof_acc_1_2`` / ``dof_acc_1_3`` -- fraction of rows with
        absolute error <= depth_of_field, depth_of_field / 2, depth_of_field / 3
      * ``dss`` -- fraction of rows whose direction_sign is not -1 (label and
        prediction do not have strictly opposite signs)

    Raises:
        ValueError: if depth_of_field is not a finite positive number, or the
            table has no rows.
    """
    # ``x <= 0`` alone would let NaN through (every comparison with NaN is False).
    if not np.isfinite(depth_of_field) or depth_of_field <= 0:
        raise ValueError("DEPTH_OF_FIELD 必须是大于 0 的有限数。")
    if len(dataframe) == 0:
        # Every metric would be NaN and the summary would read "nan%".
        raise ValueError("预测结果中没有样本,无法计算统计量。")
    if ABSOLUTE_ERROR_COLUMN not in dataframe.columns or DIRECTION_SIGN_COLUMN not in dataframe.columns:
        dataframe = add_error_columns(dataframe)

    absolute_error = dataframe[ABSOLUTE_ERROR_COLUMN].astype(float)
    direction_sign = dataframe[DIRECTION_SIGN_COLUMN]
    mae = absolute_error.mean()
    std = absolute_error.std(ddof=0)
    dof_acc_1 = (absolute_error <= depth_of_field).mean()
    dof_acc_1_2 = (absolute_error <= depth_of_field / 2).mean()
    dof_acc_1_3 = (absolute_error <= depth_of_field / 3).mean()
    # A direction_sign of 0 (label or prediction exactly zero) counts as consistent.
    dss = (direction_sign != -1).mean()
    return {
        "sample_count": len(dataframe),
        "mae": mae,
        "std": std,
        "dof_acc_1": dof_acc_1,
        "dof_acc_1_2": dof_acc_1_2,
        "dof_acc_1_3": dof_acc_1_3,
        "dss": dss,
    }


def save_summary(summary, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
    """Write *summary* (as returned by ``compute_summary``) to a UTF-8 text file.

    Creates the parent directory if needed and returns the written path.
    """
    summary_txt = Path(summary_txt)
    summary_txt.parent.mkdir(parents=True, exist_ok=True)
    lines = [
        "测试结果凝练统计",
        f"样本数:{summary['sample_count']}",
        f"景深参考值:{depth_of_field:.6g}",
        f"MAE and STD:{summary['mae']:.6g}, {summary['std']:.6g}",
        f"DoF-Acc@1:{percent_text(summary['dof_acc_1'])}",
        f"DoF-Acc@1/2:{percent_text(summary['dof_acc_1_2'])}",
        f"DoF-Acc@1/3:{percent_text(summary['dof_acc_1_3'])}",
        f"DSS:{percent_text(summary['dss'])}",
    ]
    summary_txt.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return summary_txt


def run_stats(input_csv=INPUT_CSV, output_csv=OUTPUT_CSV, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
    """Run the full pipeline: read CSV, add error columns, write CSV and summary.

    Returns:
        (dataframe, summary): the table with error columns and the summary dict.

    Raises:
        FileNotFoundError: if *input_csv* does not exist.
        ValueError: propagated from column/value/summary validation.
    """
    input_csv = Path(input_csv)
    output_csv = Path(output_csv)
    if not input_csv.exists():
        raise FileNotFoundError(f"找不到预测结果 CSV:{as_posix_path(input_csv)}")

    dataframe = pd.read_csv(input_csv)
    dataframe = add_error_columns(dataframe)
    # Compute the summary before writing anything so invalid input leaves files untouched.
    summary = compute_summary(dataframe, depth_of_field=depth_of_field)

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    dataframe.to_csv(output_csv, index=False, encoding="utf-8")
    summary_txt = save_summary(summary, summary_txt=summary_txt, depth_of_field=depth_of_field)

    print(f"已更新预测结果 CSV:{as_posix_path(output_csv)}")
    print(f"已保存统计摘要:{as_posix_path(summary_txt)}")
    return dataframe, summary


def test_stats():
    """Smoke test: run the statistics pipeline only if the prediction CSV exists."""
    if not INPUT_CSV.exists():
        print(f"预测结果 CSV 不存在,先跳过统计测试:{as_posix_path(INPUT_CSV)}")
        return

    run_stats(
        input_csv=INPUT_CSV,
        output_csv=OUTPUT_CSV,
        summary_txt=SUMMARY_TXT,
        depth_of_field=DEPTH_OF_FIELD,
    )


def main():
    """Run the pipeline with the module-level defaults (fails if the CSV is missing)."""
    run_stats(
        input_csv=INPUT_CSV,
        output_csv=OUTPUT_CSV,
        summary_txt=SUMMARY_TXT,
        depth_of_field=DEPTH_OF_FIELD,
    )


if __name__ == "__main__":
    test_stats()
    # Switch to main() for a run that raises instead of skipping when the CSV is missing.
    # main()