2026-05-17 21:06:33 +08:00

164 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
import csv
try:
import torch
from torch.utils.data import DataLoader
except ImportError as exc:
raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271并已配置 torch271+cu126。") from exc
try:
import timm
except ImportError as exc:
raise ImportError("缺少 timm。请在 torch271 环境中安装 timm或告诉我改用其他路线。") from exc
from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists
# --- Reproducibility / model definition -------------------------------------
SEED = 2026  # passed to make_split_lists so the test split is reproducible
MODEL_NAME = "mobilenetv4_conv_small"  # timm architecture name
PRETRAINED = False  # weights come from CHECKPOINT_PATH, not from timm

# --- Full test-set run (main) ------------------------------------------------
DEVICE = "cpu"
BATCH_SIZE = 64
NUM_WORKERS = 8
CHECKPOINT_PATH = Path("checkpoints", "best_mobilenetv4_defocus.pth")
OUTPUT_DIR = Path("predictions")
OUTPUT_CSV = OUTPUT_DIR.joinpath("test_predictions.csv")

# --- Small smoke run (test_eval) ---------------------------------------------
TEST_BATCH_SIZE = 2
TEST_NUM_WORKERS = 0
TEST_MAX_SAMPLES = 8
TEST_OUTPUT_CSV = OUTPUT_DIR.joinpath("test_predictions_smoke.csv")
def get_device(device_name=DEVICE):
    """Resolve ``device_name`` to a ``torch.device`` and check that it is usable.

    Args:
        device_name: A device string such as ``"cpu"``, ``"cuda"`` or
            ``"cuda:1"``, or an existing ``torch.device``. Defaults to the
            module-level ``DEVICE`` setting.

    Returns:
        The corresponding ``torch.device``.

    Raises:
        RuntimeError: If a CUDA device is requested but PyTorch reports no
            available CUDA, if the requested CUDA index does not exist, or if
            ``device_name`` is not a valid device specification.
    """
    # Parsing through torch.device accepts both strings and torch.device objects
    # (str.startswith would crash on the latter).
    device = torch.device(device_name)
    if device.type == "cuda":
        if not torch.cuda.is_available():
            # Report the device that was actually requested rather than a fixed "cuda:0".
            raise RuntimeError(f"当前指定设备为 {device}但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。")
        device_count = torch.cuda.device_count()
        # Catch a bad ordinal here instead of letting model.to() fail later.
        if device.index is not None and device.index >= device_count:
            raise RuntimeError(f"当前指定设备为 {device}但 PyTorch 只检测到 {device_count} 个 CUDA 设备。")
    return device
def create_model():
    """Build the timm network named by ``MODEL_NAME`` with a single output unit.

    ``PRETRAINED`` controls whether timm downloads its own weights; in this
    script the real weights are loaded afterwards from a checkpoint.
    """
    return timm.create_model(MODEL_NAME, pretrained=PRETRAINED, num_classes=1)
def load_model(checkpoint_path=CHECKPOINT_PATH, device_name=DEVICE):
    """Create the network, restore its weights and prepare it for inference.

    Args:
        checkpoint_path: File written by ``torch.save``. It may hold either a
            dict with a ``"model_state"`` entry or a bare state dict.
        device_name: Device specification forwarded to ``get_device``.

    Returns:
        ``(model, device)``: the model moved to ``device`` and put in eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    ckpt_file = Path(checkpoint_path)
    if not ckpt_file.exists():
        raise FileNotFoundError(f"找不到检查点文件:{as_posix_path(ckpt_file)}")

    device = get_device(device_name)
    payload = torch.load(ckpt_file, map_location=device)
    network = create_model()

    # Training checkpoints may wrap the weights under "model_state"; otherwise
    # the whole file is taken to be the state dict itself.
    is_wrapped = isinstance(payload, dict) and "model_state" in payload
    network.load_state_dict(payload["model_state"] if is_wrapped else payload)

    network.to(device)
    network.eval()
    return network, device
def make_test_loader(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED, max_samples=None, *, pin_memory=None):
    """Build a non-shuffled DataLoader over the test split.

    Args:
        data_root: Dataset root passed to ``make_split_lists``.
        batch_size: Samples per batch.
        num_workers: Number of DataLoader worker processes.
        seed: Seed passed to ``make_split_lists`` to reproduce the split.
        max_samples: If given, keep only the first ``max_samples`` test
            samples (used for smoke runs). Must be non-negative.
        pin_memory: Whether the DataLoader pins host memory. ``None``
            (default) pins only when CUDA is available. Pinning gives no
            benefit without a GPU, and PyTorch warns when it is requested
            in that case.

    Returns:
        A ``DataLoader`` over ``DefocusDataset(..., return_path=True)``.

    Raises:
        ValueError: If ``max_samples`` is negative.
        RuntimeError: If the (possibly truncated) test split is empty.
    """
    # A negative value would silently drop samples from the end via slicing.
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples 不能为负数:{max_samples}")

    split_data = make_split_lists(data_root=data_root, seed=seed)
    image_paths = split_data["test"]["image_paths"]
    labels = split_data["test"]["labels"]
    if max_samples is not None:
        image_paths = image_paths[:max_samples]
        labels = labels[:max_samples]
    if not image_paths:
        raise RuntimeError("测试集为空,请检查数据目录结构和 field 层面的划分结果。")

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    dataset = DefocusDataset(image_paths, labels, return_path=True)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep a stable order so the CSV rows are reproducible
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
@torch.no_grad()
def predict(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and collect per-sample rows.

    Gradients are disabled by the decorator. The model's train/eval mode is
    not changed here; ``load_model`` already calls ``eval()``.

    Args:
        model: Network to evaluate. It is expected to produce exactly one
            scalar output per input sample.
        dataloader: Iterable yielding ``(images, targets, image_paths)``
            batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        A list of dicts with keys ``"image_path"``, ``"label"`` and
        ``"prediction"``, one per sample, in loader order.

    Raises:
        RuntimeError: If a batch's paths, labels and predictions differ in
            count. Without this check, ``zip`` would truncate silently and
            misalign the rows.
    """
    rows = []
    for images, targets, image_paths in dataloader:
        images = images.to(device, non_blocking=True)
        # reshape (unlike view) also works on non-contiguous outputs.
        predictions = model(images).detach().cpu().reshape(-1)
        targets = targets.detach().cpu().reshape(-1)

        n_paths = len(image_paths)
        if not (n_paths == targets.numel() == predictions.numel()):
            raise RuntimeError(
                f"批次内样本数不一致:路径 {n_paths} 个、标签 {targets.numel()} 个、"
                f"预测 {predictions.numel()} 个。模型应对每个样本只输出一个标量。"
            )

        for image_path, label, prediction in zip(image_paths, targets.tolist(), predictions.tolist()):
            rows.append({
                "image_path": as_posix_path(image_path),
                "label": label,
                "prediction": prediction,
            })
    return rows
def save_predictions_csv(rows, output_csv=OUTPUT_CSV):
    """Write prediction rows to a UTF-8 CSV file, creating parent directories.

    Args:
        rows: Iterable of dicts keyed by ``image_path``, ``label`` and
            ``prediction``.
        output_csv: Destination path. Any existing file is overwritten.

    Returns:
        The destination as a ``Path``.
    """
    target = Path(output_csv)
    target.parent.mkdir(parents=True, exist_ok=True)

    columns = ["image_path", "label", "prediction"]
    # newline="" is what the csv module requires to control line endings itself.
    with target.open(mode="w", encoding="utf-8", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(rows)
    return target
def run_test(checkpoint_path=CHECKPOINT_PATH, output_csv=OUTPUT_CSV, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, max_samples=None, *, device_name=None, data_root=None, seed=None):
    """Load the checkpoint, predict on the test split and save the results as CSV.

    Args:
        checkpoint_path: Checkpoint file passed to ``load_model``.
        output_csv: Destination CSV path.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        max_samples: Optional cap on the number of test samples.
        device_name: Device override. ``None`` uses the module-level ``DEVICE``.
        data_root: Dataset root override. ``None`` uses ``DATA_ROOT``.
        seed: Split seed override. ``None`` uses ``SEED``.

    Returns:
        The list of prediction rows that was written.
    """
    # Resolve the defaults at call time so changes to the module-level
    # settings are still picked up, as in the original behaviour.
    if device_name is None:
        device_name = DEVICE
    if data_root is None:
        data_root = DATA_ROOT
    if seed is None:
        seed = SEED

    model, device = load_model(checkpoint_path=checkpoint_path, device_name=device_name)
    dataloader = make_test_loader(
        data_root=data_root,
        batch_size=batch_size,
        num_workers=num_workers,
        seed=seed,
        max_samples=max_samples,
    )
    rows = predict(model, dataloader, device)
    output_csv = save_predictions_csv(rows, output_csv=output_csv)
    print(f"已保存测试预测结果:{as_posix_path(output_csv)}")
    print(f"写入样本数:{len(rows)}")
    return rows
def test_eval():
    """Smoke run: predict on a few test samples to check loading, inference and CSV output.

    If the dataset root or the best checkpoint is missing, print a notice and
    return without running anything.
    """
    skip_reason = None
    if not DATA_ROOT.exists():
        skip_reason = "数据集根目录不存在,先跳过小规模测试。"
    elif not CHECKPOINT_PATH.exists():
        skip_reason = f"最佳检查点不存在,先跳过小规模测试:{as_posix_path(CHECKPOINT_PATH)}"

    if skip_reason is not None:
        print(skip_reason)
        return

    smoke_settings = {
        "checkpoint_path": CHECKPOINT_PATH,
        "output_csv": TEST_OUTPUT_CSV,
        "batch_size": TEST_BATCH_SIZE,
        "num_workers": TEST_NUM_WORKERS,
        "max_samples": TEST_MAX_SAMPLES,
    }
    run_test(**smoke_settings)
def main():
    """Predict on the full test split with the default settings and write ``OUTPUT_CSV``."""
    full_run_settings = {
        "checkpoint_path": CHECKPOINT_PATH,
        "output_csv": OUTPUT_CSV,
        "batch_size": BATCH_SIZE,
        "num_workers": NUM_WORKERS,
    }
    run_test(**full_run_settings)
if __name__ == "__main__":
    # Local import: argparse is only needed when the file runs as a script.
    import argparse

    # A --full flag replaces the commented-out "# main()" toggle.
    # With no arguments the script still runs only the smoke test.
    _parser = argparse.ArgumentParser(
        description="Run test-split inference and write predictions to CSV."
    )
    _parser.add_argument(
        "--full",
        action="store_true",
        help="predict on the whole test split (main); default is the small smoke run (test_eval)",
    )
    if _parser.parse_args().full:
        main()
    else:
        test_eval()