From d1242b2cc053a95e2ef367d37c0a99bab4fd1727 Mon Sep 17 00:00:00 2001
From: kaiza_hikaru
Date: Thu, 23 Oct 2025 21:28:38 +0800
Subject: [PATCH] add data process

---
 MOAFStat.py | 259 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 259 insertions(+)
 create mode 100644 MOAFStat.py

diff --git a/MOAFStat.py b/MOAFStat.py
new file mode 100644
index 0000000..8aeabec
--- /dev/null
+++ b/MOAFStat.py
@@ -0,0 +1,259 @@
+import tomllib  # requires Python >= 3.11
+import argparse
+import numpy as np
+import pandas as pd
+from tqdm import trange
+
+from MOAFUtils import print_with_timestamp
+
+
+def calculate_advanced_statistics(df_result):
+    # Base names of the columns to aggregate
+    error_cols = ['error_mean', 'error_median', 'error_z_mean', 'error_z_median']
+    error_ratio_cols = ['error_ratio_mean', 'error_ratio_median',
+                        'error_ratio_z_mean', 'error_ratio_z_median']
+    direction_cols = ['direction_mean', 'direction_median',
+                      'direction_z_mean', 'direction_z_median']
+
+    result_stats = {}
+
+    # 1. MAE and STD of the error_x and error_ratio_x columns
+    for col in error_cols + error_ratio_cols:
+        if col in df_result.columns:
+            result_stats[f'MAE_{col}'] = df_result[col].mean()
+            result_stats[f'STD_{col}'] = df_result[col].std()
+
+    # 2. Fraction of error_ratio_x values at or below 1/3, 1/2, and 1,
+    #    pooling all error_ratio columns before computing the fractions
+    ratio_data = []
+    for col in error_ratio_cols:
+        if col in df_result.columns:
+            ratio_data.extend(df_result[col].dropna().values)
+
+    if ratio_data:
+        ratio_series = pd.Series(ratio_data)
+        result_stats['DoFAcc_1_3'] = (ratio_series <= 1/3).mean()
+        result_stats['DoFAcc_1_2'] = (ratio_series <= 1/2).mean()
+        result_stats['DoFAcc_1'] = (ratio_series <= 1).mean()
+    else:
+        result_stats['DoFAcc_1_3'] = np.nan
+        result_stats['DoFAcc_1_2'] = np.nan
+        result_stats['DoFAcc_1'] = np.nan
+
+    # 3. Fraction of 1s in the direction_x columns (DSS - Direction Sign Score)
+    direction_data = []
+    for col in direction_cols:
+        if col in df_result.columns:
+            direction_data.extend(df_result[col].dropna().values)
+
+    if direction_data:
+        direction_series = pd.Series(direction_data)
+        result_stats['DSS'] = direction_series.mean()  # the fraction of 1s is simply the mean
+    else:
+        result_stats['DSS'] = np.nan
+
+    return pd.Series(result_stats)
+
+
+def process_autofocus_results(input_file, output_type="distance", group_size=30):
+    # ========== Per-image statistics ==========
+    # Read the raw data
+    df_original = pd.read_excel(input_file)
+
+    # Check whether the row count divides evenly into groups of group_size
+    total_rows = len(df_original)
+    if total_rows % group_size != 0:
+        print_with_timestamp(f"Warning: row count ({total_rows}) is not divisible by {group_size}; only the first {total_rows // group_size * group_size} rows will be processed")
+
+    # Number of complete groups
+    n_groups = total_rows // group_size
+    processed_data = []
+
+    print_with_timestamp(f"Processing {total_rows} rows in {n_groups} complete groups")
+
+    for group_idx in trange(n_groups):
+        start_idx = group_idx * group_size
+        end_idx = (group_idx + 1) * group_size
+
+        # Slice out the current group
+        group_data = df_original.iloc[start_idx:end_idx].copy()
+
+        # Take the image path and the other fixed columns from the first row
+        # (these values are assumed identical within a group)
+        first_row = group_data.iloc[0]
+        path = first_row['path']
+        mag = first_row['mag']
+        na = first_row['na']
+        rix = first_row['rix']
+        sharpness = first_row['sharpness']
+        relative = first_row['relative']
+        label = first_row['label']
+
+        # Predictions for the current group
+        pred_values = group_data['pred'].values
+
+        # 1. Plain mean
+        pred_mean = np.mean(pred_values)
+
+        # 2. Plain median
+        pred_median = np.median(pred_values)
+
+        # 3. Mean after z-score filtering (points beyond 3 sigma are dropped;
+        #    a zero standard deviation would leave the z-scores undefined, so
+        #    all points are kept in that case)
+        pred_std = np.std(pred_values)
+        if pred_std > 0:
+            z_scores = (pred_values - np.mean(pred_values)) / pred_std
+            within_3sigma = np.abs(z_scores) <= 3
+        else:
+            within_3sigma = np.ones(len(pred_values), dtype=bool)
+        pred_z_mean = np.mean(pred_values[within_3sigma]) if np.any(within_3sigma) else np.nan
+
+        # 4. Median after z-score filtering
+        pred_z_median = np.median(pred_values[within_3sigma]) if np.any(within_3sigma) else np.nan
+
+        # Number of points removed by the z-score filter
+        removed_count = np.sum(~within_3sigma)
+
+        # Assemble the result row
+        result_row = {
+            'path': path,
+            'mag': mag,
+            'na': na,
+            'rix': rix,
+            'sharpness': sharpness,
+            'relative': relative,
+            'label': label,
+            'pred_mean': pred_mean,
+            'pred_median': pred_median,
+            'pred_z_mean': pred_z_mean,
+            'pred_z_median': pred_z_median,
+            'z_filter_removed': removed_count
+        }
+        processed_data.append(result_row)
+
+    # Build the result DataFrame
+    df_result = pd.DataFrame(processed_data)
+
+    # Compute errors, direction correctness, etc.
+    print_with_timestamp("Computing image-level statistics")
+    # Absolute and signed errors
+    df_result["error_mean"] = np.abs(df_result["label"] - df_result["pred_mean"])
+    df_result["error_median"] = np.abs(df_result["label"] - df_result["pred_median"])
+    df_result["error_z_mean"] = np.abs(df_result["label"] - df_result["pred_z_mean"])
+    df_result["error_z_median"] = np.abs(df_result["label"] - df_result["pred_z_median"])
+    df_result["serror_mean"] = df_result["label"] - df_result["pred_mean"]
+    df_result["serror_median"] = df_result["label"] - df_result["pred_median"]
+    df_result["serror_z_mean"] = df_result["label"] - df_result["pred_z_mean"]
+    df_result["serror_z_median"] = df_result["label"] - df_result["pred_z_median"]
+    # Absolute and signed errors as a fraction of the depth of field,
+    # dof = lambda*n/NA^2 + n*e/(M*NA), with lambda = 550 and e = 3450
+    # (units follow these constants, in the same scale as the predictions)
+    df_result["dof"] = ((550.0 * df_result["rix"]) / (df_result["na"] ** 2)) + ((df_result["rix"] * 3450.0) / (df_result["mag"] * df_result["na"]))
+    if output_type == "ratio":
+        # Predictions are already DoF ratios, so the errors carry over as-is
+        df_result["error_ratio_mean"] = df_result["error_mean"]
+        df_result["error_ratio_median"] = df_result["error_median"]
+        df_result["error_ratio_z_mean"] = df_result["error_z_mean"]
+        df_result["error_ratio_z_median"] = df_result["error_z_median"]
+        df_result["serror_ratio_mean"] = df_result["serror_mean"]
+        df_result["serror_ratio_median"] = df_result["serror_median"]
+        df_result["serror_ratio_z_mean"] = df_result["serror_z_mean"]
+        df_result["serror_ratio_z_median"] = df_result["serror_z_median"]
+    else:
+        df_result["error_ratio_mean"] = df_result["error_mean"] / df_result["dof"]
+        df_result["error_ratio_median"] = df_result["error_median"] / df_result["dof"]
+        df_result["error_ratio_z_mean"] = df_result["error_z_mean"] / df_result["dof"]
+        df_result["error_ratio_z_median"] = df_result["error_z_median"] / df_result["dof"]
+        df_result["serror_ratio_mean"] = df_result["serror_mean"] / df_result["dof"]
+        df_result["serror_ratio_median"] = df_result["serror_median"] / df_result["dof"]
+        df_result["serror_ratio_z_mean"] = df_result["serror_z_mean"] / df_result["dof"]
+        df_result["serror_ratio_z_median"] = df_result["serror_z_median"] / df_result["dof"]
+    # Direction correctness: 1 when label and prediction share a sign
+    df_result["direction_mean"] = (df_result['label'] * df_result['pred_mean'] >= 0).astype(int)
+    df_result["direction_median"] = (df_result['label'] * df_result['pred_median'] >= 0).astype(int)
+    df_result["direction_z_mean"] = (df_result['label'] * df_result['pred_z_mean'] >= 0).astype(int)
+    df_result["direction_z_median"] = (df_result['label'] * df_result['pred_z_median'] >= 0).astype(int)
+
+    # Append the grouped results to the input workbook alongside the raw data;
+    # if_sheet_exists lets the script be re-run without a sheet-name collision
+    with pd.ExcelWriter(input_file, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer:
+        # Write the grouped results to a new sheet
+        df_result.to_excel(writer, sheet_name='grouped', index=False)
+
+    # ========== Advanced statistics ==========
print_with_timestamp("开始进行高级数据处理") + # 获取物镜分组定义 + lens_groups = { + "all": {"mag": None, "na": None, "rix": None}, # 全部数据 + "obj1": {"mag": 10, "na": 0.25, "rix": 1.0000}, + "obj2": {"mag": 10, "na": 0.30, "rix": 1.0000}, + "obj3": {"mag": 20, "na": 0.70, "rix": 1.0000}, + "obj4": {"mag": 20, "na": 0.80, "rix": 1.0000}, + "obj5": {"mag": 40, "na": 0.65, "rix": 1.0000}, + "obj6": {"mag": 100, "na": 0.80, "rix": 1.0000}, + "obj7": {"mag": 100, "na": 1.25, "rix": 1.4730} + } + + # 存储所有统计结果 + all_stats = [] + + # 1. 整体数据统计 (all) + print_with_timestamp("正在处理整体数据 (all)...") + overall_stats = calculate_advanced_statistics(df_result) + overall_stats.name = 'all' + all_stats.append(overall_stats) + + # 2. 按物镜参数分组统计 + obj_keys = [f"obj{i}" for i in range(1, 8)] + + for obj_key in obj_keys: + if obj_key in lens_groups: + params = lens_groups[obj_key] + print_with_timestamp(f"正在处理 {obj_key}: {params}...") + + # 筛选数据 + mask = ((df_result['mag'] == params['mag']) & + (df_result['na'] == params['na']) & + (df_result['rix'] == params['rix'])) + + df_subset = df_result[mask] + + if len(df_subset) > 0: + obj_stats = calculate_advanced_statistics(df_subset) + obj_stats.name = obj_key + all_stats.append(obj_stats) + else: + print_with_timestamp(f"警告: 没有找到匹配 {obj_key} 的数据") + + # 合并所有统计结果 + stats_df = pd.concat(all_stats, axis=1).T + + # 设置索引名称顺序 + desired_index = ['all'] + [f'obj{i}' for i in range(1, 8)] + # 只保留存在的索引 + stats_df = stats_df.reindex([idx for idx in desired_index if idx in stats_df.index]) + + with pd.ExcelWriter(input_file, engine='openpyxl', mode='a') as writer: + # 保存原始统计结果 + stats_df.to_excel(writer, sheet_name='stat', index=True) + + print_with_timestamp("处理结束") + return df_result + + +def main(): + # 解析命令行参数 + parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.") + parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)") + args = parser.parse_args() + with open(args.config, "rb") as f: + cfg = tomllib.load(f) + + # 确定超参数 + model_type = cfg["model_type"] + output_type = cfg["output_type"] + + result_path = f"results/{model_type}_{output_type}.xlsx" + + _ = process_autofocus_results(result_path, output_type) + + +if __name__ == "__main__": + main() +