MOAF/MOAFTest.py

91 lines
3.2 KiB
Python

import torch
from torch.utils.data import DataLoader
import tomllib
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP
def test(model, test_loader, device, model_type, dataset_type, output_dir="results"):
    """Run inference over ``test_loader`` and save the predictions to an Excel file.

    Args:
        model: network called as ``model(images, params)``; it is put into eval mode here.
        test_loader: DataLoader whose ``dataset`` provides ``get_dataframe()``. Predictions
            are assigned to that dataframe positionally, so the loader must iterate in the
            same order as the dataframe rows (e.g. ``shuffle=False``).
        device: device that the input batches are moved to.
        model_type: used only to name the output file.
        dataset_type: used only to name the output file.
        output_dir: directory for ``{model_type}_{dataset_type}.xlsx`` (created if missing).

    Returns:
        The dataframe returned by ``get_dataframe()`` with a ``pred`` column added.
    """
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for data in tqdm(test_loader, desc="[Test]", ncols=180):
            images = data["image"].to(device, non_blocking=True)
            # Stack mag / na / rix along dim 1 (a (batch, 3) tensor if each field is 1-D
            # per batch -- TODO confirm the dataset's field shapes).
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
            outputs = model(images, params)
            # squeeze(1) drops dim 1 when it has size 1, e.g. (batch, 1) -> (batch,).
            batch_preds.append(outputs.squeeze(1).cpu().numpy())
    # Join once along the batch axis rather than collecting per-sample numpy scalars;
    # an empty loader yields an empty array, as before.
    preds = np.concatenate(batch_preds) if batch_preds else np.array([])

    # Save the raw predictions next to the dataset's own columns.
    df = test_loader.dataset.get_dataframe()
    df["pred"] = preds
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Writing .xlsx needs the optional `openpyxl` package (pip install openpyxl).
    df.to_excel(out_dir / f"{model_type}_{dataset_type}.xlsx", index=False)
    return df
def _build_model(model_type):
    """Instantiate the (untrained) architecture named by ``model_type``.

    ``film`` / ``cca`` / ``se`` names are the prefix followed by digits. Each digit becomes
    one entry of the fusion-depth list (e.g. ``"film12"`` -> ``MOAFWithFiLM([1, 2])``).
    Any name containing ``"mmlp"`` selects ``MOAFWithMMLP``. Every other name falls back to
    ``MOAFNoFusion``.

    Raises:
        ValueError: if a ``film`` / ``cca`` / ``se`` prefix is followed by a non-digit.
    """
    depth_prefixed = (
        ("film", MOAFWithFiLM),
        ("cca", MOAFWithChannelCrossAttention),
        ("se", MOAFWithSE),
    )
    for prefix, model_cls in depth_prefixed:
        # Match the prefix, not an arbitrary substring: the depth digits are read from
        # the characters right after the prefix.
        if model_type.startswith(prefix):
            suffix = model_type[len(prefix):]
            try:
                fusion_depth_list = [int(ch) for ch in suffix]
            except ValueError as err:
                raise ValueError(
                    f"Invalid model_type {model_type!r}: expected '{prefix}' followed by fusion-depth digits"
                ) from err
            return model_cls(fusion_depth_list)
    if "mmlp" in model_type:
        return MOAFWithMMLP()
    # NOTE(review): unrecognized names silently fall back to the no-fusion model.
    return MOAFNoFusion()


def main():
    """Load a TOML config, build the model, restore its best checkpoint, and run ``test``.

    Config keys read: ``model_type``, ``dataset_type``, ``dataset_dir``, ``batch_size``,
    ``num_workers``, ``test_objective_params_list``. The checkpoint is read from
    ``ckpts/{model_type}_{dataset_type}_best_model.pt`` and must contain a
    ``model_state_dict`` entry.
    """
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:  # tomllib.load requires a binary file object
        cfg = tomllib.load(f)

    # Hyperparameters.
    model_type = cfg["model_type"]
    dataset_type = cfg["dataset_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    objective_params_list = cfg["test_objective_params_list"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print_with_timestamp(f"Using device {device}")

    # Dataset.
    test_set = MOAFDataset(dataset_dir, "test", objective_params_list)
    print_with_timestamp("Dataset Done")
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        num_workers=num_workers,
        # Keep loader order aligned with the dataset dataframe that `test` writes into.
        shuffle=False,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=device.type == "cuda",
        # DataLoader raises ValueError for persistent_workers=True with num_workers == 0.
        persistent_workers=num_workers > 0,
    )
    print_with_timestamp("Dataset Loaded")

    # Model selection and checkpoint restore.
    model = _build_model(model_type).to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects and can execute code;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    print_with_timestamp("Model Loaded")

    print_with_timestamp("Start testing")
    test(model, test_loader, device, model_type, dataset_type)
    print_with_timestamp("Testing completed!")
# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()