"""Evaluate a trained MOAF model on the test split and export its predictions.

Usage::

    python <this script> config.toml

The TOML config must provide ``model_type``, ``dataset_type``, ``dataset_dir``,
``batch_size``, ``num_workers`` and ``test_objective_params_list``. Weights are
read from ``ckpts/{model_type}_{dataset_type}_best_model.pt`` and predictions
are written to ``results/{model_type}_{dataset_type}.xlsx``.
"""

import argparse
import tomllib
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import (
    MOAFNoFusion,
    MOAFWithChannelCrossAttention,
    MOAFWithFiLM,
    MOAFWithMMLP,
    MOAFWithSE,
)

# Fusion model families, keyed by the prefix of ``model_type``. The characters
# after the prefix (e.g. the "123" in "film123") are single digits that are
# passed to the model constructor as the fusion depth list. No prefix here is a
# prefix of another, so the check order does not change which model is chosen.
_FUSION_MODELS = (
    ("film", MOAFWithFiLM),
    ("cca", MOAFWithChannelCrossAttention),
    ("se", MOAFWithSE),
    ("mmlp", MOAFWithMMLP),
)

_DIGITS = "0123456789"


def test(model, test_loader, device, model_type, dataset_type):
    """Run ``model`` over ``test_loader`` and save predictions to an Excel file.

    Each batch is a dict with "image", "mag", "na" and "rix" entries. The three
    scalar fields are stacked along dim 1 into the parameter tensor passed as
    the model's second argument (presumably shape (batch, 3) if each field is a
    1-D tensor — TODO confirm against MOAFDataset).

    Predictions are stored in a new "pred" column of
    ``test_loader.dataset.get_dataframe()`` and written to
    ``results/{model_type}_{dataset_type}.xlsx``. Writing .xlsx files requires
    the ``openpyxl`` package (``pip install openpyxl``).
    """
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="[Test]"):
            images = batch["image"].to(device, non_blocking=True)
            params = torch.stack(
                (batch["mag"], batch["na"], batch["rix"]), dim=1
            ).to(device, non_blocking=True)
            outputs = model(images, params)
            # Drop the singleton output dim; assumes outputs are (batch, 1) — TODO confirm.
            preds.extend(outputs.squeeze(1).cpu().numpy())

    # The loader is built with shuffle=False, so prediction order follows dataset
    # index order. This assumes get_dataframe() rows are in that same order — verify.
    df = test_loader.dataset.get_dataframe()
    df["pred"] = np.array(preds)

    out_dir = Path("results")
    out_dir.mkdir(exist_ok=True, parents=True)
    df.to_excel(out_dir / f"{model_type}_{dataset_type}.xlsx", index=False)


def _parse_fusion_depths(model_type, prefix):
    """Return the digits following ``prefix`` in ``model_type`` as a list of ints.

    Raises:
        ValueError: if any character after the prefix is not an ASCII digit.
    """
    suffix = model_type[len(prefix):]
    if not all(ch in _DIGITS for ch in suffix):
        raise ValueError(
            f"model_type {model_type!r}: expected only digits after {prefix!r}, got {suffix!r}"
        )
    return [int(ch) for ch in suffix]


def _build_model(model_type, device):
    """Instantiate the architecture selected by ``model_type`` and move it to ``device``.

    A name starting with one of the prefixes in ``_FUSION_MODELS`` selects that
    fusion model. Any other name falls back to ``MOAFNoFusion``.
    """
    for prefix, model_cls in _FUSION_MODELS:
        if model_type.startswith(prefix):
            return model_cls(_parse_fusion_depths(model_type, prefix)).to(device)
    return MOAFNoFusion().to(device)


def main():
    """Load the TOML config, dataset and checkpoint, then run :func:`test`."""
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Read hyperparameters from the config.
    model_type = cfg["model_type"]
    dataset_type = cfg["dataset_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    objective_params_list = cfg["test_objective_params_list"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print_with_timestamp(f"Using device {device}")

    # Load the dataset.
    test_set = MOAFDataset(dataset_dir, "test", objective_params_list)
    print_with_timestamp("Dataset Done")
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        num_workers=num_workers,
        # Keep dataset order so predictions line up with the exported dataframe.
        shuffle=False,
        # Pinned host memory only benefits CUDA transfers; on CPU it just warns.
        pin_memory=device.type == "cuda",
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )
    print_with_timestamp("Dataset Loaded")

    # Select the model architecture and restore its weights.
    model = _build_model(model_type, device)
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # checkpoint; only load checkpoints from a trusted source.
    checkpoint = torch.load(
        f"ckpts/{model_type}_{dataset_type}_best_model.pt",
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    print_with_timestamp("Model Loaded")

    print_with_timestamp("Start testing")
    test(model, test_loader, device, model_type, dataset_type)
    print_with_timestamp("Testing completed!")


if __name__ == "__main__":
    main()