"""Run a trained single-output timm model over the test split and write predictions to CSV."""

from pathlib import Path
import csv

try:
    import torch
    from torch.utils.data import DataLoader
except ImportError as exc:
    raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc

try:
    import timm
except ImportError as exc:
    raise ImportError("缺少 timm。请在 torch271 环境中安装 timm,或告诉我改用其他路线。") from exc

from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists

# Seed forwarded to make_split_lists so the test split matches the one used in training
# (assumes training used the same seed — confirm against the training script).
SEED = 2026
MODEL_NAME = "mobilenetv4_conv_small"
# Weights come from the checkpoint, so no pretrained download is needed.
PRETRAINED = False
DEVICE = "cpu"
BATCH_SIZE = 64
NUM_WORKERS = 8
CHECKPOINT_PATH = Path("checkpoints") / "best_mobilenetv4_defocus.pth"
OUTPUT_DIR = Path("predictions")
OUTPUT_CSV = OUTPUT_DIR / "test_predictions.csv"

# Settings for the small smoke run performed by test_eval().
TEST_BATCH_SIZE = 2
TEST_NUM_WORKERS = 0
TEST_MAX_SAMPLES = 8
TEST_OUTPUT_CSV = OUTPUT_DIR / "test_predictions_smoke.csv"


def get_device(device_name=DEVICE):
    """Resolve ``device_name`` (e.g. ``"cpu"``, ``"cuda:0"``) to a ``torch.device``.

    Raises:
        RuntimeError: if a CUDA device is requested but PyTorch reports no CUDA.
    """
    if device_name.startswith("cuda") and not torch.cuda.is_available():
        # Name the device that was actually requested instead of a hard-coded "cuda:0".
        raise RuntimeError(
            f"当前指定设备为 {device_name},但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。"
        )
    return torch.device(device_name)


def create_model():
    """Build the timm architecture named by MODEL_NAME with a single output unit."""
    model = timm.create_model(MODEL_NAME, pretrained=PRETRAINED, num_classes=1)
    return model


def load_model(checkpoint_path=CHECKPOINT_PATH, device_name=DEVICE):
    """Load weights from ``checkpoint_path`` into a fresh model set up for inference.

    The checkpoint may be either a dict holding the weights under ``"model_state"``
    or a bare state dict.

    Returns:
        ``(model, device)``: the model on ``device``, in eval mode, and that device.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: propagated from ``get_device`` or ``load_state_dict``.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"找不到检查点文件:{as_posix_path(checkpoint_path)}")

    device = get_device(device_name)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = create_model()
    if isinstance(checkpoint, dict) and "model_state" in checkpoint:
        model.load_state_dict(checkpoint["model_state"])
    else:
        model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    return model, device


def make_test_loader(
    data_root=DATA_ROOT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    seed=SEED,
    max_samples=None,
    pin_memory=None,
):
    """Build an unshuffled DataLoader over the test split.

    Args:
        data_root: dataset root passed to ``make_split_lists``.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        seed: split seed passed to ``make_split_lists``.
        max_samples: if not None, keep only the first ``max_samples`` test samples.
        pin_memory: whether to pin host memory. None means "only when CUDA is
            available", because pinning speeds up host-to-CUDA copies only.

    Raises:
        RuntimeError: if the (possibly truncated) test split is empty.
    """
    split_data = make_split_lists(data_root=data_root, seed=seed)
    image_paths = split_data["test"]["image_paths"]
    labels = split_data["test"]["labels"]

    if max_samples is not None:
        image_paths = image_paths[:max_samples]
        labels = labels[:max_samples]

    if not image_paths:
        raise RuntimeError("测试集为空,请检查数据目录结构和 field 层面的划分结果。")

    if pin_memory is None:
        # Without CUDA, pinned memory is wasted work and PyTorch emits a warning.
        pin_memory = torch.cuda.is_available()

    dataset = DefocusDataset(image_paths, labels, return_path=True)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep the output order stable and aligned with the split lists
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return dataloader


@torch.no_grad()
def predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and return one dict per sample.

    Each batch is expected to be ``(images, targets, image_paths)``, matching
    ``DefocusDataset(..., return_path=True)``. Model outputs and targets are
    flattened to one value per sample.

    Returns:
        A list of ``{"image_path", "label", "prediction"}`` dicts in loader order.
    """
    rows = []
    for images, targets, image_paths in dataloader:
        images = images.to(device, non_blocking=True)
        predictions = model(images).detach().cpu().reshape(-1)
        targets = targets.detach().cpu().reshape(-1)

        for image_path, label, prediction in zip(image_paths, targets.tolist(), predictions.tolist()):
            rows.append({
                "image_path": as_posix_path(image_path),
                "label": label,
                "prediction": prediction,
            })
    return rows


def save_predictions_csv(rows, output_csv=OUTPUT_CSV):
    """Write prediction ``rows`` to ``output_csv`` (UTF-8, with header), creating parent dirs.

    Returns:
        The output path as a ``Path``.
    """
    output_csv = Path(output_csv)
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    with output_csv.open("w", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=["image_path", "label", "prediction"])
        writer.writeheader()
        writer.writerows(rows)
    return output_csv


def run_test(
    checkpoint_path=CHECKPOINT_PATH,
    output_csv=OUTPUT_CSV,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    max_samples=None,
    device_name=DEVICE,
    data_root=DATA_ROOT,
):
    """Load the checkpoint, predict over the test split, and save the results to CSV.

    ``device_name`` and ``data_root`` default to the module constants, so existing
    callers behave as before.

    Returns:
        The list of prediction rows that was written.
    """
    model, device = load_model(checkpoint_path=checkpoint_path, device_name=device_name)
    dataloader = make_test_loader(
        data_root=data_root,
        batch_size=batch_size,
        num_workers=num_workers,
        seed=SEED,
        max_samples=max_samples,
        # Pin host memory only when the batches are actually copied to a CUDA device.
        pin_memory=device.type == "cuda",
    )
    rows = predict(model, dataloader, device)
    output_csv = save_predictions_csv(rows, output_csv=output_csv)
    print(f"已保存测试预测结果:{as_posix_path(output_csv)}")
    print(f"写入样本数:{len(rows)}")
    return rows


def test_eval():
    """Smoke test: run only a few test samples to exercise loading, inference and CSV output.

    The run is skipped (with a printed message) when the dataset root or the
    checkpoint is missing.
    """
    if not DATA_ROOT.exists():
        print("数据集根目录不存在,先跳过小规模测试。")
        return
    if not CHECKPOINT_PATH.exists():
        print(f"最佳检查点不存在,先跳过小规模测试:{as_posix_path(CHECKPOINT_PATH)}")
        return

    run_test(
        checkpoint_path=CHECKPOINT_PATH,
        output_csv=TEST_OUTPUT_CSV,
        batch_size=TEST_BATCH_SIZE,
        num_workers=TEST_NUM_WORKERS,
        max_samples=TEST_MAX_SAMPLES,
    )


def main():
    """Run prediction over the full test split with the default settings."""
    run_test(
        checkpoint_path=CHECKPOINT_PATH,
        output_csv=OUTPUT_CSV,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )


if __name__ == "__main__":
    # The smoke test runs by default; swap in main() for the full test-set run.
    test_eval()
    # main()