148 lines
4.6 KiB
Python
148 lines
4.6 KiB
Python
from pathlib import Path
|
||
|
||
try:
|
||
import numpy as np
|
||
except ImportError as exc:
|
||
raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc
|
||
|
||
try:
|
||
import pandas as pd
|
||
except ImportError as exc:
|
||
raise ImportError("缺少 pandas。请在 torch271 环境中安装 pandas,或告诉我改用其他路线。") from exc
|
||
|
||
from dataset import as_posix_path
|
||
|
||
|
||
# Locations of the prediction files, relative to the current working directory.
_PREDICTIONS_DIR = Path("predictions")
INPUT_CSV = _PREDICTIONS_DIR / "test_predictions.csv"
# The enriched CSV (with error columns) is written back over the input file.
OUTPUT_CSV = INPUT_CSV
SUMMARY_TXT = _PREDICTIONS_DIR / "test_summary.txt"

# Depth-of-field reference value used for the DoF-Acc thresholds.
# Adjust it manually to match your experimental setup.
DEPTH_OF_FIELD = 1.0

# Columns read from the predictions CSV.
LABEL_COLUMN = "label"
PREDICTION_COLUMN = "prediction"
# Columns added by the statistics step.
SIGNED_ERROR_COLUMN = "signed_error"
ABSOLUTE_ERROR_COLUMN = "absolute_error"
DIRECTION_SIGN_COLUMN = "direction_sign"
|
||
|
||
|
||
def check_columns(dataframe):
    """Verify that ``dataframe`` carries both the label and prediction columns.

    Raises:
        ValueError: if either column is absent; the message lists the missing
            column names in sorted order, comma-separated.
    """
    absent = sorted({LABEL_COLUMN, PREDICTION_COLUMN}.difference(dataframe.columns))
    if not absent:
        return
    missing_text = ", ".join(absent)
    raise ValueError(f"预测结果 CSV 缺少必要列:{missing_text}")
|
||
|
||
|
||
def add_error_columns(dataframe):
    """Return a copy of ``dataframe`` with per-row error columns appended.

    Added columns:
      * signed error    = label - prediction
      * absolute error  = |label - prediction|
      * direction sign  = sign(label) * sign(prediction) as int:
        1 when both have the same sign, -1 when the signs are opposite,
        0 when either value is zero.

    The input frame is not modified.

    Raises:
        ValueError: if a required column is missing, or if any label or
            prediction is NaN. pandas may also raise ValueError when a value
            cannot be converted to float.
    """
    check_columns(dataframe)

    dataframe = dataframe.copy()
    labels = dataframe[LABEL_COLUMN].astype(float)
    predictions = dataframe[PREDICTION_COLUMN].astype(float)

    # Without this check, NaN would make the int cast below fail with an opaque
    # pandas error (also a ValueError). Report the problem explicitly instead.
    missing_count = int((labels.isna() | predictions.isna()).sum())
    if missing_count:
        raise ValueError(
            f"预测结果 CSV 中有 {missing_count} 行的 {LABEL_COLUMN}/{PREDICTION_COLUMN} 为空值(NaN),无法计算误差。"
        )

    dataframe[SIGNED_ERROR_COLUMN] = labels - predictions
    dataframe[ABSOLUTE_ERROR_COLUMN] = dataframe[SIGNED_ERROR_COLUMN].abs()
    # Multiply the signs, not the raw values. The product of two very small
    # magnitudes can underflow to 0.0 and silently lose the direction.
    dataframe[DIRECTION_SIGN_COLUMN] = (np.sign(labels) * np.sign(predictions)).astype(int)

    return dataframe
|
||
|
||
|
||
def percent_text(value):
    """Format a fraction as a percentage string with four decimals (0.5 -> '50.0000%')."""
    scaled = value * 100
    return format(scaled, ".4f") + "%"
|
||
|
||
|
||
def compute_summary(dataframe, depth_of_field=DEPTH_OF_FIELD):
    """Aggregate the per-row error columns into summary metrics.

    Args:
        dataframe: frame containing the absolute-error and direction-sign
            columns, i.e. the output of ``add_error_columns``.
        depth_of_field: reference length for the DoF-Acc thresholds.

    Returns:
        dict with, in this order:
          ``sample_count``: number of rows;
          ``mae``: mean absolute error;
          ``std``: population standard deviation (ddof=0) of the absolute error;
          ``dof_acc_1`` / ``dof_acc_1_2`` / ``dof_acc_1_3``: fraction of rows whose
          absolute error is <= depth_of_field, depth_of_field / 2, depth_of_field / 3;
          ``dss``: fraction of rows whose direction sign is not -1.

    Raises:
        ValueError: if ``depth_of_field`` <= 0.
    """
    if depth_of_field <= 0:
        raise ValueError("DEPTH_OF_FIELD 必须大于 0。")

    errors = dataframe[ABSOLUTE_ERROR_COLUMN].astype(float)
    signs = dataframe[DIRECTION_SIGN_COLUMN]

    def within(limit):
        # Fraction of rows whose absolute error does not exceed ``limit``.
        return (errors <= limit).mean()

    summary = {"sample_count": len(dataframe)}
    summary["mae"] = errors.mean()
    summary["std"] = errors.std(ddof=0)
    summary["dof_acc_1"] = within(depth_of_field)
    summary["dof_acc_1_2"] = within(depth_of_field / 2)
    summary["dof_acc_1_3"] = within(depth_of_field / 3)
    summary["dss"] = (signs != -1).mean()
    return summary
|
||
|
||
|
||
def save_summary(summary, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
    """Write the metrics dict from ``compute_summary`` to a UTF-8 text file.

    The parent directory is created if needed. One line is written per metric,
    each terminated by a newline.

    Returns:
        The output location as a ``Path``.
    """
    target = Path(summary_txt)
    target.parent.mkdir(parents=True, exist_ok=True)

    report = [
        "测试结果凝练统计",
        f"样本数:{summary['sample_count']}",
        f"景深参考值:{depth_of_field:.6g}",
        f"MAE and STD:{summary['mae']:.6g}, {summary['std']:.6g}",
    ]
    # Ratio metrics share the same "<title>:<percentage>" layout.
    percent_rows = (
        ("DoF-Acc@1", "dof_acc_1"),
        ("DoF-Acc@1/2", "dof_acc_1_2"),
        ("DoF-Acc@1/3", "dof_acc_1_3"),
        ("DSS", "dss"),
    )
    report += [f"{title}:{percent_text(summary[key])}" for title, key in percent_rows]

    target.write_text("".join(line + "\n" for line in report), encoding="utf-8")
    return target
|
||
|
||
|
||
def run_stats(input_csv=INPUT_CSV, output_csv=OUTPUT_CSV, summary_txt=SUMMARY_TXT, depth_of_field=DEPTH_OF_FIELD):
    """Run the full statistics pipeline on a predictions CSV.

    Steps:
      1. Read ``input_csv`` and append the error columns.
      2. Compute the summary metrics.
      3. Write the enriched CSV to ``output_csv`` and the summary to ``summary_txt``.
      4. Print both output paths.

    Returns:
        ``(dataframe, summary)``: the enriched frame and the metrics dict.

    Raises:
        FileNotFoundError: if ``input_csv`` does not exist.
        ValueError: propagated from the column checks or an invalid
            ``depth_of_field``. In either case no output file is written.
    """
    input_csv = Path(input_csv)
    output_csv = Path(output_csv)

    if not input_csv.exists():
        raise FileNotFoundError(f"找不到预测结果 CSV:{as_posix_path(input_csv)}")

    dataframe = add_error_columns(pd.read_csv(input_csv))
    # Compute the summary, which also validates depth_of_field, BEFORE touching
    # any output. Otherwise a bad configuration would rewrite the CSV (which by
    # default is the input file) and then fail without producing a summary.
    summary = compute_summary(dataframe, depth_of_field=depth_of_field)

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    dataframe.to_csv(output_csv, index=False, encoding="utf-8")
    summary_txt = save_summary(summary, summary_txt=summary_txt, depth_of_field=depth_of_field)

    print(f"已更新预测结果 CSV:{as_posix_path(output_csv)}")
    print(f"已保存统计摘要:{as_posix_path(summary_txt)}")
    return dataframe, summary
|
||
|
||
|
||
def test_stats():
    """Small-scale check: run the statistics pipeline on the current predictions CSV if it exists.

    When the CSV is missing, a message is printed and nothing else happens.
    """
    if INPUT_CSV.exists():
        run_stats(INPUT_CSV, OUTPUT_CSV, SUMMARY_TXT, DEPTH_OF_FIELD)
    else:
        print(f"预测结果 CSV 不存在,先跳过统计测试:{as_posix_path(INPUT_CSV)}")
|
||
|
||
|
||
def main():
    """Run the statistics pipeline on the configured module-level paths.

    The input CSV is not checked beforehand. ``run_stats`` raises
    FileNotFoundError if it is missing.
    """
    settings = {
        "input_csv": INPUT_CSV,
        "output_csv": OUTPUT_CSV,
        "summary_txt": SUMMARY_TXT,
        "depth_of_field": DEPTH_OF_FIELD,
    }
    run_stats(**settings)
|
||
|
||
|
||
if __name__ == "__main__":
    # Default entry point: test_stats() skips (with a message) when the
    # predictions CSV is missing. Swap in main() to run the pipeline
    # unconditionally; it raises FileNotFoundError if the CSV is absent.
    test_stats()
    # main()
|