164 lines
5.1 KiB
Python
164 lines
5.1 KiB
Python
from pathlib import Path
|
||
import csv
|
||
|
||
try:
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
except ImportError as exc:
|
||
raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc
|
||
|
||
try:
|
||
import timm
|
||
except ImportError as exc:
|
||
raise ImportError("缺少 timm。请在 torch271 环境中安装 timm,或告诉我改用其他路线。") from exc
|
||
|
||
from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists
|
||
|
||
|
||
# Seed forwarded to make_split_lists when building the test split.
SEED = 2026

# Arguments for timm.create_model and the device the model is placed on.
MODEL_NAME = "mobilenetv4_conv_small"
PRETRAINED = False
DEVICE = "cpu"

# Settings for the full test-set run (see main / run_test).
BATCH_SIZE = 64
NUM_WORKERS = 8
CHECKPOINT_PATH = Path("checkpoints", "best_mobilenetv4_defocus.pth")
OUTPUT_DIR = Path("predictions")
OUTPUT_CSV = OUTPUT_DIR.joinpath("test_predictions.csv")

# Settings for the small smoke run (see test_eval).
TEST_BATCH_SIZE = 2
TEST_NUM_WORKERS = 0
TEST_MAX_SAMPLES = 8
TEST_OUTPUT_CSV = OUTPUT_DIR.joinpath("test_predictions_smoke.csv")
|
||
|
||
|
||
def get_device(device_name=DEVICE):
    """Resolve a device specifier into a ``torch.device``.

    Args:
        device_name: Device string such as ``"cpu"`` or ``"cuda:0"``, or an
            existing ``torch.device``. Defaults to the module-level ``DEVICE``.

    Returns:
        The ``torch.device`` built from ``device_name``.

    Raises:
        RuntimeError: If a CUDA device is requested while
            ``torch.cuda.is_available()`` reports no usable CUDA.
    """
    # Normalise to text so a torch.device argument also passes the prefix
    # check, and so the error message names the device actually requested.
    device_text = str(device_name)
    if device_text.startswith("cuda") and not torch.cuda.is_available():
        raise RuntimeError(
            f"当前指定设备为 {device_text},但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。"
        )
    return torch.device(device_name)
|
||
|
||
|
||
def create_model():
    """Build the network described by the module-level model settings.

    Returns:
        The object returned by ``timm.create_model`` for ``MODEL_NAME`` with
        ``pretrained=PRETRAINED`` and ``num_classes=1``.
    """
    return timm.create_model(
        MODEL_NAME,
        num_classes=1,
        pretrained=PRETRAINED,
    )
|
||
|
||
|
||
def load_model(checkpoint_path=CHECKPOINT_PATH, device_name=DEVICE):
    """Load a checkpoint into a freshly created model set to eval mode.

    Args:
        checkpoint_path: Location of the checkpoint file.
        device_name: Device specifier handed to ``get_device``.

    Returns:
        A ``(model, device)`` tuple; the model has been moved to ``device``
        and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    ckpt_file = Path(checkpoint_path)
    if not ckpt_file.exists():
        raise FileNotFoundError(f"找不到检查点文件:{as_posix_path(ckpt_file)}")

    device = get_device(device_name)
    # NOTE(review): torch.load deserialises with pickle; only point this at
    # checkpoints from a trusted source.
    payload = torch.load(ckpt_file, map_location=device)

    # Accept either a dict wrapping the weights under "model_state" or a bare
    # state dict.
    is_wrapped = isinstance(payload, dict) and "model_state" in payload
    state_dict = payload["model_state"] if is_wrapped else payload

    model = create_model()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, device
|
||
|
||
|
||
def make_test_loader(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED, max_samples=None):
    """Build a non-shuffled ``DataLoader`` over the test split.

    Args:
        data_root: Dataset root forwarded to ``make_split_lists``.
        batch_size: Number of samples per batch.
        num_workers: Number of ``DataLoader`` worker processes.
        seed: Seed forwarded to ``make_split_lists``.
        max_samples: If not ``None``, keep only the first ``max_samples``
            entries of the test split.

    Returns:
        A ``DataLoader`` over a ``DefocusDataset`` created with
        ``return_path=True``.

    Raises:
        RuntimeError: If the (possibly truncated) test split is empty.
    """
    test_split = make_split_lists(data_root=data_root, seed=seed)["test"]
    image_paths = test_split["image_paths"]
    labels = test_split["labels"]

    if max_samples is not None:
        image_paths = image_paths[:max_samples]
        labels = labels[:max_samples]

    if not image_paths:
        raise RuntimeError("测试集为空,请检查数据目录结构和 field 层面的划分结果。")

    dataset = DefocusDataset(image_paths, labels, return_path=True)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; requesting it
        # on a machine without CUDA has no benefit and makes DataLoader warn.
        pin_memory=torch.cuda.is_available(),
    )
|
||
|
||
|
||
@torch.no_grad()
def predict(model, dataloader, device):
    """Run the model over every batch and collect one record per sample.

    Args:
        model: Callable applied to each image batch.
        dataloader: Iterable yielding ``(images, targets, image_paths)``.
        device: Device the image batch is moved to before the forward pass.

    Returns:
        A list of dicts with keys ``"image_path"``, ``"label"`` and
        ``"prediction"``, in the order the dataloader produced the samples.
    """
    rows = []
    for batch_images, batch_targets, batch_paths in dataloader:
        outputs = model(batch_images.to(device, non_blocking=True))
        # Flatten both tensors to 1-D on the CPU so they pair up per sample.
        flat_predictions = outputs.detach().cpu().view(-1)
        flat_labels = batch_targets.detach().cpu().view(-1)

        rows.extend(
            {
                "image_path": as_posix_path(path),
                "label": label,
                "prediction": prediction,
            }
            for path, label, prediction in zip(
                batch_paths, flat_labels.tolist(), flat_predictions.tolist()
            )
        )
    return rows
|
||
|
||
|
||
def save_predictions_csv(rows, output_csv=OUTPUT_CSV):
    """Write prediction records to a UTF-8 CSV file.

    Args:
        rows: Iterable of dicts keyed by ``"image_path"``, ``"label"`` and
            ``"prediction"``.
        output_csv: Destination file; missing parent directories are created.

    Returns:
        The destination as a ``Path``.
    """
    destination = Path(output_csv)
    destination.parent.mkdir(parents=True, exist_ok=True)

    columns = ["image_path", "label", "prediction"]
    with destination.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)

    return destination
|
||
|
||
|
||
def run_test(checkpoint_path=CHECKPOINT_PATH, output_csv=OUTPUT_CSV, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, max_samples=None):
    """Load the model, predict on the test split, and save the results.

    Args:
        checkpoint_path: Checkpoint handed to ``load_model``.
        output_csv: Destination handed to ``save_predictions_csv``.
        batch_size: Batch size for the test loader.
        num_workers: Worker count for the test loader.
        max_samples: Optional cap on the number of test samples.

    Returns:
        The list of prediction records produced by ``predict``.
    """
    # The device always comes from the module-level DEVICE setting.
    model, device = load_model(checkpoint_path=checkpoint_path, device_name=DEVICE)

    loader_options = {
        "data_root": DATA_ROOT,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "seed": SEED,
        "max_samples": max_samples,
    }
    dataloader = make_test_loader(**loader_options)

    rows = predict(model, dataloader, device)
    saved_path = save_predictions_csv(rows, output_csv=output_csv)

    print(f"已保存测试预测结果:{as_posix_path(saved_path)}")
    print(f"写入样本数:{len(rows)}")
    return rows
|
||
|
||
|
||
def test_eval():
    """Smoke test: run a handful of test samples end to end.

    Confirms that model loading, inference and CSV writing work together.
    Prints a message and does nothing else when the dataset root or the
    checkpoint file is missing.
    """
    if not DATA_ROOT.exists():
        print("数据集根目录不存在,先跳过小规模测试。")
    elif not CHECKPOINT_PATH.exists():
        print(f"最佳检查点不存在,先跳过小规模测试:{as_posix_path(CHECKPOINT_PATH)}")
    else:
        smoke_settings = {
            "checkpoint_path": CHECKPOINT_PATH,
            "output_csv": TEST_OUTPUT_CSV,
            "batch_size": TEST_BATCH_SIZE,
            "num_workers": TEST_NUM_WORKERS,
            "max_samples": TEST_MAX_SAMPLES,
        }
        run_test(**smoke_settings)
|
||
|
||
|
||
def main():
    """Run the full test-set evaluation with the module-level settings."""
    full_settings = {
        "checkpoint_path": CHECKPOINT_PATH,
        "output_csv": OUTPUT_CSV,
        "batch_size": BATCH_SIZE,
        "num_workers": NUM_WORKERS,
    }
    run_test(**full_settings)
|
||
|
||
|
||
if __name__ == "__main__":
    # Running the file as a script performs the smoke test; enable the
    # main() call below instead for the full test-set evaluation.
    test_eval()
    # main()
|