import tomllib
import argparse

import numpy as np
import pandas as pd
from tqdm import tqdm, trange

from MOAFUtils import print_with_timestamp


def calculate_advanced_statistics(df_result):
    """Compute aggregate evaluation metrics over one (sub)set of image-level results.

    Parameters
    ----------
    df_result : pd.DataFrame
        Image-level result table produced by ``process_autofocus_results``.
        Only the error / error-ratio / direction columns that are actually
        present are used, so partial tables are accepted.

    Returns
    -------
    pd.Series
        MAE/STD of the absolute errors and DoF-normalized error ratios,
        depth-of-field accuracy (share of error ratios <= 1/3, 1/2 and 1 DoF)
        and the DSS (Direction Sign Score, share of correct sign predictions).
    """
    error_cols = ['error_mean', 'error_median', 'error_z_mean', 'error_z_median']
    error_ratio_cols = ['error_ratio_mean', 'error_ratio_median',
                        'error_ratio_z_mean', 'error_ratio_z_median']
    direction_cols = ['direction_mean', 'direction_median',
                      'direction_z_mean', 'direction_z_median']

    result_stats = {}

    # 1./2. MAE and STD for error_x and error_ratio_x columns.
    # (Single loop over the concatenated list keeps the original key order:
    # all error_* entries first, then all error_ratio_* entries.)
    for col in error_cols + error_ratio_cols:
        if col in df_result.columns:
            result_stats[f'MAE_{col}'] = df_result[col].mean()
            result_stats[f'STD_{col}'] = df_result[col].std()

    # 3. DoF accuracy: fraction of error ratios within 1/3, 1/2 and 1 DoF.
    for col in error_ratio_cols:
        if col in df_result.columns:
            result_stats[f'DoFAcc_1_3_{col}'] = (df_result[col] <= 1 / 3).mean()
            result_stats[f'DoFAcc_1_2_{col}'] = (df_result[col] <= 1 / 2).mean()
            result_stats[f'DoFAcc_1_{col}'] = (df_result[col] <= 1).mean()

    # 4. DSS (Direction Sign Score): mean of the 0/1 direction columns.
    for col in direction_cols:
        if col in df_result.columns:
            result_stats[f'DSS_{col}'] = df_result[col].mean()

    return pd.Series(result_stats)


def _aggregate_group_predictions(df_original, group_size):
    """Collapse each consecutive block of ``group_size`` rows into one image-level row.

    Each group is assumed to share its metadata columns (path/mag/na/rix/
    sharpness/relative/label — taken from the group's first row) and differ
    only in ``pred``. Four aggregates of ``pred`` are computed: mean, median,
    and mean/median after a |z| <= 3 outlier filter.

    Returns
    -------
    pd.DataFrame
        One row per complete group; trailing rows of an incomplete last
        group are ignored.
    """
    total_rows = len(df_original)
    n_groups = total_rows // group_size
    processed_data = []

    print_with_timestamp(f"开始处理数据,共{total_rows}行,{n_groups}个完整分组")

    for group_idx in trange(n_groups, ncols=180):
        group_data = df_original.iloc[group_idx * group_size:(group_idx + 1) * group_size]
        first_row = group_data.iloc[0]
        pred_values = group_data['pred'].values

        # Z-score outlier filter (|z| <= 3). Guard against a zero standard
        # deviation (all predictions identical): dividing by 0 would yield
        # NaN z-scores and silently drop the entire group, so in that case
        # every point is kept instead.
        std = np.std(pred_values)
        if std > 0:
            z_scores = (pred_values - np.mean(pred_values)) / std
            within_3sigma = np.abs(z_scores) <= 3
        else:
            within_3sigma = np.ones(len(pred_values), dtype=bool)

        kept = pred_values[within_3sigma]
        processed_data.append({
            'path': first_row['path'],
            'mag': first_row['mag'],
            'na': first_row['na'],
            'rix': first_row['rix'],
            'sharpness': first_row['sharpness'],
            'relative': first_row['relative'],
            'label': first_row['label'],
            'pred_mean': np.mean(pred_values),
            'pred_median': np.median(pred_values),
            'pred_z_mean': np.mean(kept) if kept.size else np.nan,
            'pred_z_median': np.median(kept) if kept.size else np.nan,
            # number of points removed by the Z-score filter
            'z_filter_removed': np.sum(~within_3sigma),
        })

    return pd.DataFrame(processed_data)


def _add_error_metrics(df_result):
    """Add depth-of-field, error, error-ratio and direction columns in place.

    Column insertion order matches the original layout of the 'grouped'
    sheet: error_*, serror_*, error_ratio_*, serror_ratio_*, direction_*.
    """
    # Depth of field from magnification, numerical aperture and refractive
    # index (constants 550.0 and 3450.0 as used throughout this pipeline).
    df_result["dof"] = ((550.0 * df_result["rix"]) / (df_result["na"] ** 2)) \
        + ((df_result["rix"] * 3450.0) / (df_result["mag"] * df_result["na"]))

    suffixes = ("mean", "median", "z_mean", "z_median")

    # Absolute errors.
    for s in suffixes:
        df_result[f"error_{s}"] = np.abs(df_result["label"] - df_result[f"pred_{s}"])
    # Signed errors.
    for s in suffixes:
        df_result[f"serror_{s}"] = df_result["label"] - df_result[f"pred_{s}"]
    # Errors normalized by depth of field.
    for s in suffixes:
        df_result[f"error_ratio_{s}"] = df_result[f"error_{s}"] / df_result["dof"]
    for s in suffixes:
        df_result[f"serror_ratio_{s}"] = df_result[f"serror_{s}"] / df_result["dof"]
    # Direction correctness: 1 when label and prediction agree in sign
    # (a zero on either side counts as correct).
    for s in suffixes:
        df_result[f"direction_{s}"] = (
            df_result["label"] * df_result[f"pred_{s}"] >= 0
        ).astype(int)


def _compute_lens_statistics(df_result):
    """Build the per-objective statistics table: 'all' plus obj1..obj7.

    Each objective is identified by its (mag, na, rix) triple; rows of
    ``df_result`` matching the triple are summarized with
    ``calculate_advanced_statistics``. Objectives with no matching data are
    skipped with a warning.
    """
    # Objective lens definitions; "all" means no filtering.
    lens_groups = {
        "all": {"mag": None, "na": None, "rix": None},
        "obj1": {"mag": 10, "na": 0.25, "rix": 1.0000},
        "obj2": {"mag": 10, "na": 0.30, "rix": 1.0000},
        "obj3": {"mag": 20, "na": 0.70, "rix": 1.0000},
        "obj4": {"mag": 20, "na": 0.80, "rix": 1.0000},
        "obj5": {"mag": 40, "na": 0.65, "rix": 1.0000},
        "obj6": {"mag": 100, "na": 0.80, "rix": 1.0000},
        "obj7": {"mag": 100, "na": 1.25, "rix": 1.4730},
    }

    all_stats = []

    # 1. Whole-dataset statistics.
    print_with_timestamp("正在处理整体数据 (all)...")
    overall_stats = calculate_advanced_statistics(df_result)
    overall_stats.name = 'all'
    all_stats.append(overall_stats)

    # 2. Per-objective statistics.
    for obj_key in (f"obj{i}" for i in range(1, 8)):
        if obj_key not in lens_groups:
            continue
        params = lens_groups[obj_key]
        print_with_timestamp(f"正在处理 {obj_key}: {params}...")
        mask = ((df_result['mag'] == params['mag'])
                & (df_result['na'] == params['na'])
                & (df_result['rix'] == params['rix']))
        df_subset = df_result[mask]
        if len(df_subset) > 0:
            obj_stats = calculate_advanced_statistics(df_subset)
            obj_stats.name = obj_key
            all_stats.append(obj_stats)
        else:
            print_with_timestamp(f"警告: 没有找到匹配 {obj_key} 的数据")

    stats_df = pd.concat(all_stats, axis=1).T
    # Fix the row order ('all', obj1..obj7), keeping only rows that exist.
    desired_index = ['all'] + [f'obj{i}' for i in range(1, 8)]
    return stats_df.reindex([idx for idx in desired_index if idx in stats_df.index])


def process_autofocus_results(input_file, group_size=30):
    """Post-process raw autofocus predictions stored in an Excel workbook.

    Reads the first sheet of ``input_file`` (one row per prediction,
    ``group_size`` consecutive rows per image), aggregates each group into a
    single image-level row, derives error / DoF-ratio / direction metrics,
    and appends two sheets to the same workbook: 'grouped' (image-level
    rows) and 'stat' (per-objective summary statistics).

    Parameters
    ----------
    input_file : str
        Path to the .xlsx results file; it is modified in place.
    group_size : int, optional
        Number of consecutive raw rows per image (default 30).

    Returns
    -------
    pd.DataFrame
        The image-level result table written to the 'grouped' sheet.
    """
    # ========== image-level aggregation ==========
    df_original = pd.read_excel(input_file)
    total_rows = len(df_original)
    if total_rows % group_size != 0:
        print_with_timestamp(f"警告: 数据行数({total_rows})不能被{group_size}整除,将处理前{total_rows // group_size * group_size}行数据")

    df_result = _aggregate_group_predictions(df_original, group_size)

    print_with_timestamp("开始进行图像级数据处理")
    _add_error_metrics(df_result)

    # Append the image-level results as a new sheet. ``if_sheet_exists``
    # makes reruns idempotent: without it, an existing 'grouped' sheet
    # would raise ValueError on the second run.
    with pd.ExcelWriter(input_file, engine='openpyxl', mode='a',
                        if_sheet_exists='replace') as writer:
        df_result.to_excel(writer, sheet_name='grouped', index=False)

    # ========== advanced statistics ==========
    print_with_timestamp("开始进行高级数据处理")
    stats_df = _compute_lens_statistics(df_result)

    with pd.ExcelWriter(input_file, engine='openpyxl', mode='a',
                        if_sheet_exists='replace') as writer:
        stats_df.to_excel(writer, sheet_name='stat', index=True)

    print_with_timestamp("处理结束")
    return df_result


def main():
    """Load the TOML config given on the command line and run the pipeline."""
    parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()

    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Resolve the results workbook from the configured model/dataset pair.
    model_type = cfg["model_type"]
    dataset_type = cfg["dataset_type"]
    result_path = f"results/{model_type}_{dataset_type}.xlsx"

    _ = process_autofocus_results(result_path)


if __name__ == "__main__":
    main()