MOAF/MOAFStat.py
import tomllib
import argparse
import numpy as np
import pandas as pd
from tqdm import trange
from MOAFUtils import print_with_timestamp


def calculate_advanced_statistics(df_result):
    """Aggregate error, DoF-accuracy, and direction statistics over a result frame."""
    # Base names of the columns to summarize
    error_cols = ['error_mean', 'error_median', 'error_z_mean', 'error_z_median']
    error_ratio_cols = ['error_ratio_mean', 'error_ratio_median',
                        'error_ratio_z_mean', 'error_ratio_z_median']
    direction_cols = ['direction_mean', 'direction_median',
                      'direction_z_mean', 'direction_z_median']
    result_stats = {}

    # 1-2. MAE and STD of the error_* and error_ratio_* columns
    for col in error_cols + error_ratio_cols:
        if col in df_result.columns:
            result_stats[f'MAE_{col}'] = df_result[col].mean()
            result_stats[f'STD_{col}'] = df_result[col].std()

    # 3. Fraction of error_ratio_* values at or below 1/3, 1/2, and 1
    #    (all error_ratio columns are pooled before computing the fractions)
    ratio_data = []
    for col in error_ratio_cols:
        if col in df_result.columns:
            ratio_data.extend(df_result[col].dropna().values)
    if ratio_data:
        ratio_series = pd.Series(ratio_data)
        result_stats['DoFAcc_1_3'] = (ratio_series <= 1/3).mean()
        result_stats['DoFAcc_1_2'] = (ratio_series <= 1/2).mean()
        result_stats['DoFAcc_1'] = (ratio_series <= 1).mean()
    else:
        result_stats['DoFAcc_1_3'] = np.nan
        result_stats['DoFAcc_1_2'] = np.nan
        result_stats['DoFAcc_1'] = np.nan

    # 4. DSS (Direction Sign Score): fraction of 1s in the direction_* columns
    direction_data = []
    for col in direction_cols:
        if col in df_result.columns:
            direction_data.extend(df_result[col].dropna().values)
    if direction_data:
        direction_series = pd.Series(direction_data)
        result_stats['DSS'] = direction_series.mean()  # mean of 0/1 values = fraction of 1s
    else:
        result_stats['DSS'] = np.nan

    return pd.Series(result_stats)
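
# Hypothetical usage sketch: on a frame carrying the error_*/error_ratio_*/
# direction_* columns built below, the returned Series is keyed
#   MAE_<col>, STD_<col>, DoFAcc_1_3, DoFAcc_1_2, DoFAcc_1, DSS
# so e.g. calculate_advanced_statistics(df)['DoFAcc_1'] is the fraction of
# images whose focus error falls within one depth of field.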


def process_autofocus_results(input_file, dataset_type, group_size=30):
    # ========== Image-level statistics ==========
    # Load the raw per-prediction data
    df_original = pd.read_excel(input_file)

    # Check whether the row count is divisible by group_size
    total_rows = len(df_original)
    if total_rows % group_size != 0:
        print_with_timestamp(
            f"Warning: row count ({total_rows}) is not divisible by {group_size}; "
            f"only the first {total_rows // group_size * group_size} rows will be processed"
        )

    # Number of complete groups
    n_groups = total_rows // group_size
    processed_data = []
    print_with_timestamp(f"Starting processing: {total_rows} rows, {n_groups} complete groups")
    for group_idx in trange(n_groups):
        start_idx = group_idx * group_size
        end_idx = (group_idx + 1) * group_size

        # Slice out the current group
        group_data = df_original.iloc[start_idx:end_idx].copy()

        # Take the image path and the other fixed columns from the first row
        # (these values are assumed identical within a group)
        first_row = group_data.iloc[0]
        path = first_row['path']
        mag = first_row['mag']
        na = first_row['na']
        rix = first_row['rix']
        sharpness = first_row['sharpness']
        relative = first_row['relative']
        label = first_row['label']

        # Predictions for the current group
        pred_values = group_data['pred'].values

        # 1. Plain mean
        pred_mean = np.mean(pred_values)
        # 2. Plain median
        pred_median = np.median(pred_values)
        # 3. Mean after Z-score filtering (drop points more than 3 sigma from the mean);
        #    guard against a zero standard deviation when all predictions are equal
        pred_std = np.std(pred_values)
        if pred_std > 0:
            z_scores = (pred_values - np.mean(pred_values)) / pred_std
            within_3sigma = np.abs(z_scores) <= 3
        else:
            within_3sigma = np.ones_like(pred_values, dtype=bool)
        pred_z_mean = np.mean(pred_values[within_3sigma]) if np.any(within_3sigma) else np.nan
        # 4. Median after Z-score filtering
        pred_z_median = np.median(pred_values[within_3sigma]) if np.any(within_3sigma) else np.nan
        # Number of points removed by the Z-score filter
        removed_count = np.sum(~within_3sigma)

        # Assemble the result row
        result_row = {
            'path': path,
            'mag': mag,
            'na': na,
            'rix': rix,
            'sharpness': sharpness,
            'relative': relative,
            'label': label,
            'pred_mean': pred_mean,
            'pred_median': pred_median,
            'pred_z_mean': pred_z_mean,
            'pred_z_median': pred_z_median,
            'z_filter_removed': removed_count
        }
        processed_data.append(result_row)

    # Build the per-image result DataFrame
    df_result = pd.DataFrame(processed_data)
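
    # Each per-image row now looks roughly like (hypothetical values):
    #   {'path': 'img_001.png', 'mag': 10, 'na': 0.25, 'rix': 1.0, ...,
    #    'label': 2.5, 'pred_mean': 2.3, 'pred_z_median': 2.4, 'z_filter_removed': 1}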
    # Compute errors, direction correctness, etc.
    print_with_timestamp("Starting image-level processing")
    # Depth of field: lambda*n/NA^2 + n*e/(M*NA) with lambda = 550 and e = 3450
    # (presumably nm; n = rix, M = mag)
    df_result["dof"] = ((550.0 * df_result["rix"]) / (df_result["na"] ** 2)) + ((df_result["rix"] * 3450.0) / (df_result["mag"] * df_result["na"]))
    # Absolute error of each estimator against the label
    df_result["error_mean"] = np.abs(df_result["label"] - df_result["pred_mean"])
    df_result["error_median"] = np.abs(df_result["label"] - df_result["pred_median"])
    df_result["error_z_mean"] = np.abs(df_result["label"] - df_result["pred_z_mean"])
    df_result["error_z_median"] = np.abs(df_result["label"] - df_result["pred_z_median"])
    # Signed error
    df_result["serror_mean"] = df_result["label"] - df_result["pred_mean"]
    df_result["serror_median"] = df_result["label"] - df_result["pred_median"]
    df_result["serror_z_mean"] = df_result["label"] - df_result["pred_z_mean"]
    df_result["serror_z_median"] = df_result["label"] - df_result["pred_z_median"]
    # Errors as a fraction of the depth of field
    df_result["error_ratio_mean"] = df_result["error_mean"] / df_result["dof"]
    df_result["error_ratio_median"] = df_result["error_median"] / df_result["dof"]
    df_result["error_ratio_z_mean"] = df_result["error_z_mean"] / df_result["dof"]
    df_result["error_ratio_z_median"] = df_result["error_z_median"] / df_result["dof"]
    df_result["serror_ratio_mean"] = df_result["serror_mean"] / df_result["dof"]
    df_result["serror_ratio_median"] = df_result["serror_median"] / df_result["dof"]
    df_result["serror_ratio_z_mean"] = df_result["serror_z_mean"] / df_result["dof"]
    df_result["serror_ratio_z_median"] = df_result["serror_z_median"] / df_result["dof"]
    # Direction correctness: 1 when the prediction and the label share a sign
    df_result["direction_mean"] = (df_result['label'] * df_result['pred_mean'] >= 0).astype(int)
    df_result["direction_median"] = (df_result['label'] * df_result['pred_median'] >= 0).astype(int)
    df_result["direction_z_mean"] = (df_result['label'] * df_result['pred_z_mean'] >= 0).astype(int)
    df_result["direction_z_median"] = (df_result['label'] * df_result['pred_z_median'] >= 0).astype(int)
    # Append the grouped results to the workbook alongside the original data;
    # if_sheet_exists='replace' lets the script be re-run on the same file
    with pd.ExcelWriter(input_file, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer:
        df_result.to_excel(writer, sheet_name='grouped', index=False)
    # ========== Advanced statistics ==========
    print_with_timestamp("Starting advanced processing")

    # Objective-lens group definitions
    lens_groups = {
        "all": {"mag": None, "na": None, "rix": None},  # all data
        "obj1": {"mag": 10, "na": 0.25, "rix": 1.0000},
        "obj2": {"mag": 10, "na": 0.30, "rix": 1.0000},
        "obj3": {"mag": 20, "na": 0.70, "rix": 1.0000},
        "obj4": {"mag": 20, "na": 0.80, "rix": 1.0000},
        "obj5": {"mag": 40, "na": 0.65, "rix": 1.0000},
        "obj6": {"mag": 100, "na": 0.80, "rix": 1.0000},
        "obj7": {"mag": 100, "na": 1.25, "rix": 1.4730}
    }

    # Collect all statistics
    all_stats = []

    # 1. Overall statistics (all)
    print_with_timestamp("Processing overall data (all)...")
    overall_stats = calculate_advanced_statistics(df_result)
    overall_stats.name = 'all'
    all_stats.append(overall_stats)

    # 2. Statistics grouped by objective parameters
    obj_keys = [f"obj{i}" for i in range(1, 8)]
    for obj_key in obj_keys:
        if obj_key in lens_groups:
            params = lens_groups[obj_key]
            print_with_timestamp(f"Processing {obj_key}: {params}...")
            # Select the matching rows
            mask = ((df_result['mag'] == params['mag']) &
                    (df_result['na'] == params['na']) &
                    (df_result['rix'] == params['rix']))
            df_subset = df_result[mask]
            if len(df_subset) > 0:
                obj_stats = calculate_advanced_statistics(df_subset)
                obj_stats.name = obj_key
                all_stats.append(obj_stats)
            else:
                print_with_timestamp(f"Warning: no data found matching {obj_key}")

    # Merge all statistics
    stats_df = pd.concat(all_stats, axis=1).T

    # Order the index, keeping only groups that are present
    desired_index = ['all'] + [f'obj{i}' for i in range(1, 8)]
    stats_df = stats_df.reindex([idx for idx in desired_index if idx in stats_df.index])

    with pd.ExcelWriter(input_file, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer:
        # Write the summary statistics
        stats_df.to_excel(writer, sheet_name='stat', index=True)

    print_with_timestamp("Processing finished")
    return df_result
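
# Hypothetical usage (the path is illustrative): reads the per-prediction rows
# from the first sheet, then appends 'grouped' and 'stat' sheets in place:
#   df = process_autofocus_results("results/mymodel_val.xlsx", "val", group_size=30)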


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Resolve hyperparameters
    model_type = cfg["model_type"]
    dataset_type = cfg["dataset_type"]
    result_path = f"results/{model_type}_{dataset_type}.xlsx"
    _ = process_autofocus_results(result_path, dataset_type)


if __name__ == "__main__":
    main()
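
# A minimal example config (hypothetical values; only the two keys read above
# are required):
#
#   # config.toml
#   model_type = "resnet50"
#   dataset_type = "test"
#
# which points the script at results/resnet50_test.xlsx.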