add data processing
This commit is contained in:
parent 05eb174091
commit d1242b2cc0
MOAFStat.py (Normal file) | 259 lines
@@ -0,0 +1,259 @@
import tomllib
import argparse

import numpy as np
import pandas as pd
from tqdm import trange

from MOAFUtils import print_with_timestamp


def calculate_advanced_statistics(df_result):
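    """Aggregate image-level error and direction columns into summary statistics.

    Returns a pandas Series with MAE/STD for each error and error-ratio column,
    the DoFAcc_* fractions (share of error ratios below 1/3, 1/2 and 1) and the
    direction sign score (DSS).
    """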
    # Base names of the columns to aggregate
    error_cols = ['error_mean', 'error_median', 'error_z_mean', 'error_z_median']
    error_ratio_cols = ['error_ratio_mean', 'error_ratio_median',
                        'error_ratio_z_mean', 'error_ratio_z_median']
    direction_cols = ['direction_mean', 'direction_median',
                      'direction_z_mean', 'direction_z_median']

    result_stats = {}

    # 1. MAE and STD of the error_x columns
    for col in error_cols:
        if col in df_result.columns:
            result_stats[f'MAE_{col}'] = df_result[col].mean()
            result_stats[f'STD_{col}'] = df_result[col].std()

    # 2. MAE and STD of the error_ratio_x columns
    for col in error_ratio_cols:
        if col in df_result.columns:
            result_stats[f'MAE_{col}'] = df_result[col].mean()
            result_stats[f'STD_{col}'] = df_result[col].std()

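    # Note: DoFAcc_t below is the fraction of error ratios <= t, i.e. the share of
    # images predicted within t depths of field of the focus label (when
    # output_type == "ratio" the errors are assumed to already be in DoF units).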
    # 3. Fraction of error_ratio_x values that are <= 1/3, 1/2 and 1
    # Pool every error_ratio column before computing the fractions
    ratio_data = []
    for col in error_ratio_cols:
        if col in df_result.columns:
            ratio_data.extend(df_result[col].dropna().values)

    if ratio_data:
        ratio_series = pd.Series(ratio_data)
        result_stats['DoFAcc_1_3'] = (ratio_series <= 1/3).mean()
        result_stats['DoFAcc_1_2'] = (ratio_series <= 1/2).mean()
        result_stats['DoFAcc_1'] = (ratio_series <= 1).mean()
    else:
        result_stats['DoFAcc_1_3'] = np.nan
        result_stats['DoFAcc_1_2'] = np.nan
        result_stats['DoFAcc_1'] = np.nan

    # 4. Fraction of 1s in the direction_x columns (DSS - Direction Sign Score)
    direction_data = []
    for col in direction_cols:
        if col in df_result.columns:
            direction_data.extend(df_result[col].dropna().values)

    if direction_data:
        direction_series = pd.Series(direction_data)
        result_stats['DSS'] = direction_series.mean()  # the fraction of 1s is simply the mean
    else:
        result_stats['DSS'] = np.nan

    return pd.Series(result_stats)


def process_autofocus_results(input_file, output_type="distance", group_size=30):
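    """Turn per-prediction autofocus results into image-level and per-objective statistics.

    The input workbook is expected to hold one prediction per row, ordered so that
    each consecutive block of `group_size` rows belongs to the same image (the
    metadata columns are assumed constant within a block). Two sheets are appended
    to the same file: 'grouped' with image-level aggregates and 'stat' with the
    advanced statistics, and the image-level DataFrame is returned.
    """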
    # ========== Image-level statistics ==========
    # Read the raw data
    df_original = pd.read_excel(input_file)

    # Check whether the number of rows is divisible by group_size
    total_rows = len(df_original)
    if total_rows % group_size != 0:
        print_with_timestamp(f"Warning: the number of rows ({total_rows}) is not divisible by {group_size}; only the first {total_rows // group_size * group_size} rows will be processed")

    # Number of complete groups
    n_groups = total_rows // group_size
    processed_data = []

    print_with_timestamp(f"Starting processing: {total_rows} rows, {n_groups} complete groups")

    for group_idx in trange(n_groups):
        start_idx = group_idx * group_size
        end_idx = (group_idx + 1) * group_size

        # Slice out the current group
        group_data = df_original.iloc[start_idx:end_idx].copy()

        # Take the image path and the other fixed columns from the first row
        # (assumed identical within a group)
        first_row = group_data.iloc[0]
        path = first_row['path']
        mag = first_row['mag']
        na = first_row['na']
        rix = first_row['rix']
        sharpness = first_row['sharpness']
        relative = first_row['relative']
        label = first_row['label']

        # Predictions of the current group
        pred_values = group_data['pred'].values

        # 1. Plain mean
        pred_mean = np.mean(pred_values)

        # 2. Plain median
        pred_median = np.median(pred_values)

        # 3. Mean after Z-score filtering (guard the zero-spread case so that a
        #    constant group keeps all of its samples instead of producing NaN)
        pred_std = np.std(pred_values)
        if pred_std > 0:
            z_scores = (pred_values - np.mean(pred_values)) / pred_std
            within_3sigma = np.abs(z_scores) <= 3
        else:
            within_3sigma = np.ones(len(pred_values), dtype=bool)
        pred_z_mean = np.mean(pred_values[within_3sigma]) if np.any(within_3sigma) else np.nan

        # 4. Median after Z-score filtering
        pred_z_median = np.median(pred_values[within_3sigma]) if np.any(within_3sigma) else np.nan

        # Number of points removed by the Z-score filter
        removed_count = np.sum(~within_3sigma)

        # Assemble the result row
        result_row = {
            'path': path,
            'mag': mag,
            'na': na,
            'rix': rix,
            'sharpness': sharpness,
            'relative': relative,
            'label': label,
            'pred_mean': pred_mean,
            'pred_median': pred_median,
            'pred_z_mean': pred_z_mean,
            'pred_z_median': pred_z_median,
            'z_filter_removed': removed_count
        }
        processed_data.append(result_row)

    # Build the result DataFrame
    df_result = pd.DataFrame(processed_data)

    # Compute errors, direction correctness, etc.
    print_with_timestamp("Starting image-level processing")
    # Absolute / signed errors
    df_result["error_mean"] = np.abs(df_result["label"] - df_result["pred_mean"])
    df_result["error_median"] = np.abs(df_result["label"] - df_result["pred_median"])
    df_result["error_z_mean"] = np.abs(df_result["label"] - df_result["pred_z_mean"])
    df_result["error_z_median"] = np.abs(df_result["label"] - df_result["pred_z_median"])
    df_result["serror_mean"] = df_result["label"] - df_result["pred_mean"]
    df_result["serror_median"] = df_result["label"] - df_result["pred_median"]
    df_result["serror_z_mean"] = df_result["label"] - df_result["pred_z_mean"]
    df_result["serror_z_median"] = df_result["label"] - df_result["pred_z_median"]
    # Absolute / signed errors as a ratio of the depth of field
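    # The expression below matches the common microscope depth-of-field
    # approximation DoF = lambda*n/NA^2 + n*e/(M*NA), with lambda = 550 nm and
    # e = 3450 nm (presumably the detector pixel size); this reading is inferred
    # from the constants rather than stated in the data.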
    df_result["dof"] = ((550.0 * df_result["rix"]) / (df_result["na"] ** 2)) + ((df_result["rix"] * 3450.0) / (df_result["mag"] * df_result["na"]))
    if output_type == "ratio":
        df_result["error_ratio_mean"] = df_result["error_mean"]
        df_result["error_ratio_median"] = df_result["error_median"]
        df_result["error_ratio_z_mean"] = df_result["error_z_mean"]
        df_result["error_ratio_z_median"] = df_result["error_z_median"]
        df_result["serror_ratio_mean"] = df_result["serror_mean"]
        df_result["serror_ratio_median"] = df_result["serror_median"]
        df_result["serror_ratio_z_mean"] = df_result["serror_z_mean"]
        df_result["serror_ratio_z_median"] = df_result["serror_z_median"]
    else:
        df_result["error_ratio_mean"] = df_result["error_mean"] / df_result["dof"]
        df_result["error_ratio_median"] = df_result["error_median"] / df_result["dof"]
        df_result["error_ratio_z_mean"] = df_result["error_z_mean"] / df_result["dof"]
        df_result["error_ratio_z_median"] = df_result["error_z_median"] / df_result["dof"]
        df_result["serror_ratio_mean"] = df_result["serror_mean"] / df_result["dof"]
        df_result["serror_ratio_median"] = df_result["serror_median"] / df_result["dof"]
        df_result["serror_ratio_z_mean"] = df_result["serror_z_mean"] / df_result["dof"]
        df_result["serror_ratio_z_median"] = df_result["serror_z_median"] / df_result["dof"]
    # Direction-prediction correctness (sign agreement between label and prediction)
    df_result["direction_mean"] = (df_result['label'] * df_result['pred_mean'] >= 0).astype(int)
    df_result["direction_median"] = (df_result['label'] * df_result['pred_median'] >= 0).astype(int)
    df_result["direction_z_mean"] = (df_result['label'] * df_result['pred_z_mean'] >= 0).astype(int)
    df_result["direction_z_median"] = (df_result['label'] * df_result['pred_z_median'] >= 0).astype(int)

    # Use ExcelWriter to keep the original data and add the grouped results to the same file
    # (if_sheet_exists='replace' lets the script be re-run; requires pandas >= 1.4)
    with pd.ExcelWriter(input_file, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer:
        # Save the grouped results to a new sheet
        df_result.to_excel(writer, sheet_name='grouped', index=False)

    # ========== Advanced statistics ==========
    print_with_timestamp("Starting advanced statistics processing")
    # Objective-lens group definitions
    lens_groups = {
        "all": {"mag": None, "na": None, "rix": None},  # all data
        "obj1": {"mag": 10, "na": 0.25, "rix": 1.0000},
        "obj2": {"mag": 10, "na": 0.30, "rix": 1.0000},
        "obj3": {"mag": 20, "na": 0.70, "rix": 1.0000},
        "obj4": {"mag": 20, "na": 0.80, "rix": 1.0000},
        "obj5": {"mag": 40, "na": 0.65, "rix": 1.0000},
        "obj6": {"mag": 100, "na": 0.80, "rix": 1.0000},
        "obj7": {"mag": 100, "na": 1.25, "rix": 1.4730}
    }

    # Collect all statistics results
    all_stats = []

    # 1. Overall statistics (all)
    print_with_timestamp("Processing the full dataset (all)...")
    overall_stats = calculate_advanced_statistics(df_result)
    overall_stats.name = 'all'
    all_stats.append(overall_stats)

    # 2. Statistics grouped by objective parameters
    obj_keys = [f"obj{i}" for i in range(1, 8)]

    for obj_key in obj_keys:
        if obj_key in lens_groups:
            params = lens_groups[obj_key]
            print_with_timestamp(f"Processing {obj_key}: {params}...")

            # Filter the data
            mask = ((df_result['mag'] == params['mag']) &
                    (df_result['na'] == params['na']) &
                    (df_result['rix'] == params['rix']))

            df_subset = df_result[mask]

            if len(df_subset) > 0:
                obj_stats = calculate_advanced_statistics(df_subset)
                obj_stats.name = obj_key
                all_stats.append(obj_stats)
            else:
                print_with_timestamp(f"Warning: no data found for {obj_key}")

    # Combine all statistics results
    stats_df = pd.concat(all_stats, axis=1).T

    # Desired index order
    desired_index = ['all'] + [f'obj{i}' for i in range(1, 8)]
    # Keep only the indices that actually exist
    stats_df = stats_df.reindex([idx for idx in desired_index if idx in stats_df.index])

    with pd.ExcelWriter(input_file, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer:
        # Save the summary statistics
        stats_df.to_excel(writer, sheet_name='stat', index=True)

    print_with_timestamp("Processing finished")
    return df_result


def main():
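    """Load model_type/output_type from a TOML config and post-process the matching results file."""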
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Post-process autofocus results using settings from a TOML config file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Pick up the hyperparameters
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]

    result_path = f"results/{model_type}_{output_type}.xlsx"

    _ = process_autofocus_results(result_path, output_type)


if __name__ == "__main__":
    main()
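
# Example usage (hypothetical config.toml containing the two keys read above):
#   model_type = "resnet"
#   output_type = "distance"
# Run:  python MOAFStat.py config.toml
# The script expects results/<model_type>_<output_type>.xlsx to already contain the
# per-prediction data read by pd.read_excel, and appends the 'grouped' and 'stat'
# sheets to that same workbook.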