DefocusEstimate/train.py
2026-05-17 21:06:33 +08:00

447 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
import logging
import random
import sys
import time
# Third-party dependencies are imported behind guards so that a missing package
# fails immediately with an actionable message (in Chinese) that names the conda
# environment ("torch271") the author expects; the original ImportError is chained.
try:
    import numpy as np
except ImportError as exc:
    raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy或告诉我改用其他路线。") from exc
try:
    import torch
    from torch.utils.data import DataLoader
except ImportError as exc:
    raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271并已配置 torch271+cu126。") from exc
try:
    import timm
except ImportError as exc:
    raise ImportError("缺少 timm。请在 torch271 环境中安装 timm或告诉我改用其他路线。") from exc
from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists
# Baseline hyperparameters. Run test_train() before launching a large-scale run.

# --- Reproducibility / model -------------------------------------------------
SEED = 2026
MODEL_NAME = "mobilenetv4_conv_small"
PRETRAINED = False  # passed to timm.create_model(pretrained=...)
DEVICE = "cuda:0"

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 64
NUM_WORKERS = 8
LEARNING_RATE = 0.0001
EPOCHS = 300

# --- Output locations --------------------------------------------------------
CHECKPOINT_DIR = Path("checkpoints")
BEST_CHECKPOINT_NAME = "best_mobilenetv4_defocus.pth"
LAST_CHECKPOINT_NAME = "last_mobilenetv4_defocus.pth"
LOG_DIR = Path("logs")
LOG_FILE = LOG_DIR.joinpath("train.log")
TENSORBOARD_DIR = Path("runs", "mobilenetv4_defocus")
TQDM_NCOLS = 100  # fixed progress-bar width in characters

# --- Smoke-test settings used by test_train() -------------------------------
TEST_BATCH_SIZE = 2
TEST_NUM_WORKERS = 0
TEST_MAX_SAMPLES = 8
TEST_EPOCHS = 1
def set_seed(seed=SEED):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def setup_logger(log_file=LOG_FILE):
    """Configure the "train" logger to write to stdout and to ``log_file``.

    Handlers left over from an earlier call are closed and removed first, so
    repeated calls in one process neither duplicate output nor leak open log
    files.

    Args:
        log_file: path of the UTF-8 log file; missing parent directories are created.

    Returns:
        The ``logging.Logger`` named "train", at INFO level, not propagating to
        the root logger.
    """
    logger = logging.getLogger("train")
    logger.setLevel(logging.INFO)
    # handlers.clear() alone would drop a previous FileHandler without closing
    # it, leaking its file descriptor on every re-configuration.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.propagate = False
    log_file = Path(log_file)
    log_file.parent.mkdir(parents=True, exist_ok=True)
    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(f"日志文件:{as_posix_path(log_file)}")
    return logger
def create_summary_writer(log_dir=TENSORBOARD_DIR):
    """Create a TensorBoard ``SummaryWriter`` rooted at ``log_dir``, creating the directory.

    TensorBoard is imported lazily, so it is only required when a writer is made.

    Raises:
        ImportError: if ``torch.utils.tensorboard`` cannot be imported.
    """
    try:
        from torch.utils.tensorboard import SummaryWriter as _Writer
    except ImportError as exc:
        raise ImportError("缺少 TensorBoard 相关依赖。请安装 tensorboard或告诉我改用其他记录方式。") from exc
    directory = Path(log_dir)
    directory.mkdir(exist_ok=True, parents=True)
    posix_dir = as_posix_path(directory)
    return _Writer(log_dir=posix_dir)
def make_progress_bar(iterable, desc):
    """Wrap ``iterable`` in a fixed-width tqdm bar that is cleared when finished.

    Args:
        iterable: what to iterate over, typically a DataLoader.
        desc: label shown in front of the bar.

    Returns:
        A ``tqdm`` iterator over ``iterable``.

    Raises:
        ImportError: if tqdm is not installed (it is imported lazily).
    """
    try:
        from tqdm import tqdm
    except ImportError as exc:
        raise ImportError("缺少 tqdm。请在 torch271 环境中安装 tqdm或告诉我改用普通日志显示进度。") from exc
    # ``total`` is left to tqdm. It uses len(iterable) when available and falls
    # back to an open-ended bar otherwise, so unsized iterables no longer raise
    # TypeError here.
    return tqdm(
        iterable,
        desc=desc,
        ncols=TQDM_NCOLS,
        leave=False,
        dynamic_ncols=False,
    )
def get_device(device_name=DEVICE):
    """Resolve ``device_name`` to a ``torch.device``, failing early if it is unusable.

    Args:
        device_name: a torch device spec such as "cuda:0", "cuda" or "cpu".

    Returns:
        The corresponding ``torch.device``.

    Raises:
        RuntimeError: if a CUDA device is requested but CUDA is unavailable, or
            the requested CUDA index is not among the visible GPUs.
    """
    device = torch.device(device_name)
    if device.type != "cuda":
        return device
    if not torch.cuda.is_available():
        # Report the device that was actually requested, not a hard-coded "cuda:0".
        raise RuntimeError(f"当前指定设备为 {device_name},但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。")
    visible_gpus = torch.cuda.device_count()
    # Without this check a bad index only fails later, at the first .to(device).
    if device.index is not None and device.index >= visible_gpus:
        raise RuntimeError(f"当前指定设备为 {device_name},但只检测到 {visible_gpus} 块可用 GPU。")
    return device
def create_model():
    """Instantiate the timm network named by MODEL_NAME with a one-unit output head.

    ``num_classes=1`` turns the classifier head into a single regressed value per
    image; the ``pretrained`` flag follows PRETRAINED.
    """
    regressor = timm.create_model(
        MODEL_NAME,
        num_classes=1,
        pretrained=PRETRAINED,
    )
    return regressor
def regression_metrics(predictions, targets):
    """Compute mean absolute error and root-mean-square error.

    Both tensors are flattened before comparison. Otherwise a ``(B, 1)`` model
    output and ``(B,)`` labels would broadcast to a ``(B, B)`` error matrix and
    silently corrupt both metrics.

    Args:
        predictions: tensor of predicted values.
        targets: tensor of ground-truth values with the same number of elements.

    Returns:
        dict with Python-float entries ``"mae"`` and ``"rmse"`` (NaN for empty input).

    Raises:
        ValueError: if the two tensors hold a different number of elements.
    """
    flat_predictions = predictions.reshape(-1)
    flat_targets = targets.reshape(-1)
    if flat_predictions.numel() != flat_targets.numel():
        raise ValueError(f"预测值与标签数量不一致:{flat_predictions.numel()} vs {flat_targets.numel()}。")
    errors = flat_predictions - flat_targets
    mae = torch.mean(torch.abs(errors)).item()
    rmse = torch.sqrt(torch.mean(errors ** 2)).item()
    return {
        "mae": mae,
        "rmse": rmse,
    }
def get_current_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    leading_group = optimizer.param_groups[0]
    current_lr = leading_group["lr"]
    return current_lr
def move_batch_to_device(batch, device):
    """Move the first two batch items (images, targets) to ``device``; extra items are ignored.

    ``non_blocking=True`` lets host-to-GPU copies overlap with compute when the
    source tensors are in pinned memory.
    """
    images, targets = (tensor.to(device, non_blocking=True) for tensor in batch[:2])
    return images, targets
def train_epoch(model, dataloader, criterion, optimizer, device, epoch=None):
    """Run one optimisation pass over ``dataloader`` and return epoch metrics.

    Args:
        model: network to train; switched to train mode.
        dataloader: yields batches whose first two items are (images, targets).
        criterion: loss called as ``criterion(predictions, targets)``.
        optimizer: stepped once per batch.
        device: device each batch is moved to.
        epoch: optional epoch number, used only in the progress-bar label.

    Returns:
        dict with the sample-weighted mean ``"loss"`` plus ``"mae"`` and ``"rmse"``.
        The error metrics are computed from the predictions made while the
        weights were still being updated during this epoch.
    """
    model.train()
    total_loss = 0.0
    total_count = 0
    all_predictions = []
    all_targets = []
    desc = "训练" if epoch is None else f"训练 {epoch:03d}"
    progress_bar = make_progress_bar(dataloader, desc)
    for batch in progress_bar:
        images, targets = move_batch_to_device(batch, device)
        optimizer.zero_grad(set_to_none=True)
        # Match the model output to the label layout. A (B, 1) head against (B,)
        # labels would otherwise broadcast to (B, B) inside the loss; this is a
        # no-op when the shapes already agree.
        predictions = model(images).reshape_as(targets)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()
        # Read the scalar once: every .item() forces a device synchronisation.
        batch_loss = loss.item()
        batch_size = images.size(0)
        total_loss += batch_loss * batch_size
        total_count += batch_size
        all_predictions.append(predictions.detach().cpu())
        all_targets.append(targets.detach().cpu())
        progress_bar.set_postfix(loss=f"{batch_loss:<10.4f}")
    mean_loss = total_loss / max(total_count, 1)
    metrics = regression_metrics(torch.cat(all_predictions), torch.cat(all_targets))
    metrics["loss"] = mean_loss
    return metrics
@torch.no_grad()
def valid_epoch(model, dataloader, criterion, device, epoch=None):
    """Evaluate ``model`` over ``dataloader`` without gradients and return epoch metrics.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: yields batches whose first two items are (images, targets).
        criterion: loss called as ``criterion(predictions, targets)``.
        device: device each batch is moved to.
        epoch: optional epoch number, used only in the progress-bar label.

    Returns:
        dict with the sample-weighted mean ``"loss"`` plus ``"mae"`` and ``"rmse"``.
    """
    model.eval()
    total_loss = 0.0
    total_count = 0
    all_predictions = []
    all_targets = []
    desc = "验证" if epoch is None else f"验证 {epoch:03d}"
    progress_bar = make_progress_bar(dataloader, desc)
    for batch in progress_bar:
        images, targets = move_batch_to_device(batch, device)
        # Match the model output to the label layout. A (B, 1) head against (B,)
        # labels would otherwise broadcast to (B, B) inside the loss; this is a
        # no-op when the shapes already agree.
        predictions = model(images).reshape_as(targets)
        loss = criterion(predictions, targets)
        # Read the scalar once: every .item() forces a device synchronisation.
        batch_loss = loss.item()
        batch_size = images.size(0)
        total_loss += batch_loss * batch_size
        total_count += batch_size
        all_predictions.append(predictions.detach().cpu())
        all_targets.append(targets.detach().cpu())
        progress_bar.set_postfix(loss=f"{batch_loss:<10.4f}")
    mean_loss = total_loss / max(total_count, 1)
    metrics = regression_metrics(torch.cat(all_predictions), torch.cat(all_targets))
    metrics["loss"] = mean_loss
    return metrics
def write_tensorboard_scalars(writer, epoch, train_metrics, valid_metrics, best_valid_loss, train_seconds, valid_seconds, epoch_seconds, learning_rate):
    """Record one epoch's metrics, timings and learning rate in TensorBoard, then flush.

    Every scalar uses ``epoch`` as its global step. Tags are grouped under
    train/*, valid/*, summary/*, time/* and optimizer/*.
    """
    metric_table = (
        ("train/loss", train_metrics, "loss"),
        ("train/mae", train_metrics, "mae"),
        ("train/rmse", train_metrics, "rmse"),
        ("valid/loss", valid_metrics, "loss"),
        ("valid/mae", valid_metrics, "mae"),
        ("valid/rmse", valid_metrics, "rmse"),
    )
    for tag, source, key in metric_table:
        writer.add_scalar(tag, source[key], epoch)
    scalar_table = (
        ("summary/best_valid_loss", best_valid_loss),
        ("time/train_seconds", train_seconds),
        ("time/valid_seconds", valid_seconds),
        ("time/epoch_seconds", epoch_seconds),
        ("optimizer/learning_rate", learning_rate),
    )
    for tag, value in scalar_table:
        writer.add_scalar(tag, value, epoch)
    writer.flush()
def make_dataloaders(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED):
    """Build the train/validation DataLoaders from the split produced by ``make_split_lists``.

    Args:
        data_root: dataset root forwarded to ``make_split_lists``.
        batch_size: batch size for both loaders.
        num_workers: loader worker processes; 0 loads data in the main process.
        seed: forwarded to ``make_split_lists`` (presumably controls the split; verify there).

    Returns:
        ``(train_loader, valid_loader, split_data)``, where ``split_data`` is the
        unmodified result of ``make_split_lists``. The train loader shuffles;
        the valid loader does not.

    Raises:
        RuntimeError: if the training or validation split is empty.
    """
    split_data = make_split_lists(data_root=data_root, seed=seed)
    train_dataset = DefocusDataset(
        split_data["train"]["image_paths"],
        split_data["train"]["labels"],
    )
    valid_dataset = DefocusDataset(
        split_data["valid"]["image_paths"],
        split_data["valid"]["labels"],
    )
    if len(train_dataset) == 0:
        raise RuntimeError("训练集为空,请检查数据目录结构。")
    if len(valid_dataset) == 0:
        raise RuntimeError("验证集为空,请检查 field 层面的划分结果。")
    # DataLoader raises ValueError for persistent_workers=True with num_workers=0,
    # so only keep workers alive between epochs when there are workers.
    keep_workers = num_workers > 0
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=keep_workers,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=keep_workers,
    )
    return train_loader, valid_loader, split_data
def save_checkpoint(path, model, optimizer, epoch, train_metrics, valid_metrics, best_valid_loss):
    """Serialise model/optimizer state, metrics and run configuration to ``path``.

    The checkpoint is written to a sibling ``<name>.tmp`` file first and then
    renamed over ``path``. An interruption during ``torch.save`` (for example
    a second Ctrl-C while the last checkpoint is being written) therefore
    cannot leave a truncated file in place of a previously valid checkpoint.

    Note: the "config" entry records the module-level constants, not necessarily
    the values actually used to build the loaders (test_train, for instance,
    uses TEST_BATCH_SIZE / TEST_NUM_WORKERS).
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    checkpoint = {
        "epoch": epoch,
        "model_name": MODEL_NAME,
        "pretrained": PRETRAINED,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "train_metrics": train_metrics,
        "valid_metrics": valid_metrics,
        "best_valid_loss": best_valid_loss,
        "config": {
            "data_root": as_posix_path(DATA_ROOT),
            "batch_size": BATCH_SIZE,
            "num_workers": NUM_WORKERS,
            "learning_rate": LEARNING_RATE,
            "device": DEVICE,
            "loss": "SmoothL1Loss",
            "optimizer": "Adam",
        },
    }
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(checkpoint, tmp_path)
    # Path.replace uses os.replace, which is atomic on POSIX when both paths are
    # on the same filesystem; the temp file lives in the same directory for that reason.
    tmp_path.replace(path)
def log_epoch_metrics(logger, epoch, train_metrics, valid_metrics, train_seconds, valid_seconds, epoch_seconds, learning_rate):
    """Emit one INFO line summarising an epoch: metrics, timings and learning rate."""
    segments = [
        f"{epoch:03d} 轮",
        f"训练 loss={train_metrics['loss']:.6f}, MAE={train_metrics['mae']:.6f}, RMSE={train_metrics['rmse']:.6f}",
        f"验证 loss={valid_metrics['loss']:.6f}, MAE={valid_metrics['mae']:.6f}, RMSE={valid_metrics['rmse']:.6f}",
        f"耗时 train={train_seconds:.2f}s, valid={valid_seconds:.2f}s, epoch={epoch_seconds:.2f}s",
        f"学习率={learning_rate:.6g}",
    ]
    logger.info(" | ".join(segments))
def fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=None, writer=None):
    """Train a freshly created model, keeping a "best" and a "last" checkpoint.

    Each epoch does the following:
      * trains and validates;
      * saves the best checkpoint whenever the validation loss improves;
      * logs the metrics to ``logger`` and TensorBoard.
    Ctrl-C stops the loop gracefully. Whenever at least one epoch finished, the
    last checkpoint (the most recent fully completed epoch) is written on exit,
    and any other exception still propagates afterwards.

    Args:
        train_loader: DataLoader for training batches.
        valid_loader: DataLoader for validation batches.
        epochs: number of epochs to run.
        device_name: device spec passed to ``get_device``.
        logger: logger to use; a default one is created when ``None``.
        writer: TensorBoard writer. When ``None`` one is created here and
            closed on exit; otherwise the caller owns (and must close) it.

    Returns:
        The trained model, including after a Ctrl-C interruption.
    """
    logger = logger or setup_logger()
    # Resolve the device and build the model/optimizer before creating a writer,
    # so a setup failure (no CUDA, OOM, ...) cannot leak a self-created writer.
    device = get_device(device_name)
    model = create_model().to(device)
    criterion = torch.nn.SmoothL1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    close_writer = writer is None
    if writer is None:
        writer = create_summary_writer()
        logger.info(f"TensorBoard 目录:{as_posix_path(TENSORBOARD_DIR)}")
    best_valid_loss = float("inf")
    best_path = CHECKPOINT_DIR / BEST_CHECKPOINT_NAME
    last_path = CHECKPOINT_DIR / LAST_CHECKPOINT_NAME
    # Snapshot of the most recent epoch that finished both training and validation.
    last_epoch = None
    last_train_metrics = None
    last_valid_metrics = None
    try:
        for epoch in range(1, epochs + 1):
            epoch_start = time.perf_counter()
            train_start = time.perf_counter()
            train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch=epoch)
            train_seconds = time.perf_counter() - train_start
            valid_start = time.perf_counter()
            valid_metrics = valid_epoch(model, valid_loader, criterion, device, epoch=epoch)
            valid_seconds = time.perf_counter() - valid_start
            # Measured before checkpointing, so save time is not included.
            epoch_seconds = time.perf_counter() - epoch_start
            learning_rate = get_current_lr(optimizer)
            if valid_metrics["loss"] < best_valid_loss:
                best_valid_loss = valid_metrics["loss"]
                save_checkpoint(best_path, model, optimizer, epoch, train_metrics, valid_metrics, best_valid_loss)
                logger.info(f"已保存最佳检查点:{as_posix_path(best_path)}")
            last_epoch = epoch
            last_train_metrics = train_metrics
            last_valid_metrics = valid_metrics
            log_epoch_metrics(
                logger,
                epoch,
                train_metrics,
                valid_metrics,
                train_seconds,
                valid_seconds,
                epoch_seconds,
                learning_rate,
            )
            write_tensorboard_scalars(
                writer,
                epoch,
                train_metrics,
                valid_metrics,
                best_valid_loss,
                train_seconds,
                valid_seconds,
                epoch_seconds,
                learning_rate,
            )
    except KeyboardInterrupt:
        logger.warning("检测到 Ctrl-C 手动中止训练,准备保存最后一个完整 epoch 的 last 检查点。")
    finally:
        try:
            if last_epoch is not None:
                save_checkpoint(
                    last_path,
                    model,
                    optimizer,
                    last_epoch,
                    last_train_metrics,
                    last_valid_metrics,
                    best_valid_loss,
                )
                logger.info(f"已保存 last 检查点:{as_posix_path(last_path)},对应第 {last_epoch:03d} 轮。")
            else:
                logger.warning("训练尚未完整完成任何 epoch未保存 last 检查点。")
        finally:
            # Close a self-created writer even if the final save raised (e.g.
            # disk full), so its event file is flushed and released.
            if close_writer:
                writer.close()
    return model
def test_train():
    """Small-scale training smoke test: train on a few samples to validate the whole pipeline.

    Uses at most TEST_MAX_SAMPLES images per split, TEST_BATCH_SIZE and
    TEST_EPOCHS. Logs to a separate file and TensorBoard sub-directory. Returns
    early, with a log message, if the data root is missing or a split is empty.
    """
    logger = setup_logger(LOG_DIR / "test_train.log")
    set_seed(SEED)
    logger.info(f"数据根目录:{as_posix_path(DATA_ROOT)}")
    logger.info(f"模型:{MODEL_NAME}pretrained={PRETRAINED}")
    logger.info(f"设备:{DEVICE}")
    if not DATA_ROOT.exists():
        logger.info("数据集根目录不存在,先跳过小规模训练测试。")
        return
    split_data = make_split_lists(data_root=DATA_ROOT, seed=SEED)
    train_paths = split_data["train"]["image_paths"][:TEST_MAX_SAMPLES]
    train_labels = split_data["train"]["labels"][:TEST_MAX_SAMPLES]
    valid_paths = split_data["valid"]["image_paths"][:TEST_MAX_SAMPLES]
    valid_labels = split_data["valid"]["labels"][:TEST_MAX_SAMPLES]
    if not train_paths or not valid_paths:
        logger.info("训练集或验证集为空,请检查数据目录结构和 field 划分结果。")
        return
    train_dataset = DefocusDataset(train_paths, train_labels)
    valid_dataset = DefocusDataset(valid_paths, valid_labels)
    train_loader = DataLoader(
        train_dataset,
        batch_size=TEST_BATCH_SIZE,
        shuffle=True,
        num_workers=TEST_NUM_WORKERS,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        num_workers=TEST_NUM_WORKERS,
        pin_memory=True,
    )
    writer = create_summary_writer(TENSORBOARD_DIR / "test_train")
    logger.info(f"TensorBoard 测试目录:{as_posix_path(TENSORBOARD_DIR / 'test_train')}")
    try:
        fit(train_loader, valid_loader, epochs=TEST_EPOCHS, device_name=DEVICE, logger=logger, writer=writer)
    finally:
        # fit() only closes writers it created itself, so this caller-owned
        # writer must be closed here, including when fit() raises.
        writer.close()
    logger.info("小规模训练测试完成。")
def main():
    """Full training entry point: set up logging and seeding, build the loaders, run fit()."""
    logger = setup_logger(LOG_FILE)
    set_seed(SEED)
    loaders_and_split = make_dataloaders(
        data_root=DATA_ROOT,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        seed=SEED,
    )
    train_loader, valid_loader, split_data = loaders_and_split
    train_count = len(split_data["train"]["image_paths"])
    valid_count = len(split_data["valid"]["image_paths"])
    logger.info(f"训练图像数:{train_count}")
    logger.info(f"验证图像数:{valid_count}")
    fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=logger)
if __name__ == "__main__":
    # Pass --test to run the small smoke test (test_train) instead of editing
    # this file to toggle a commented-out call. With no arguments, the full
    # training run (main) starts as before.
    if "--test" in sys.argv[1:]:
        test_train()
    else:
        main()