from pathlib import Path
import logging
import random
import sys
import time

try:
    import numpy as np
except ImportError as exc:
    raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc

try:
    import torch
    from torch.utils.data import DataLoader
except ImportError as exc:
    raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc

try:
    import timm
except ImportError as exc:
    raise ImportError("缺少 timm。请在 torch271 环境中安装 timm,或告诉我改用其他路线。") from exc

from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists


# Base hyper-parameters. Run the small smoke test (test_train(), or this script
# with the ``--test`` flag) before launching a large-scale experiment.
SEED = 2026
MODEL_NAME = "mobilenetv4_conv_small"
PRETRAINED = False
DEVICE = "cuda:0"
BATCH_SIZE = 64
NUM_WORKERS = 8
LEARNING_RATE = 1e-4
EPOCHS = 300
CHECKPOINT_DIR = Path("checkpoints")
BEST_CHECKPOINT_NAME = "best_mobilenetv4_defocus.pth"
LAST_CHECKPOINT_NAME = "last_mobilenetv4_defocus.pth"
LOG_DIR = Path("logs")
LOG_FILE = LOG_DIR / "train.log"
TENSORBOARD_DIR = Path("runs") / "mobilenetv4_defocus"
TQDM_NCOLS = 100

# Settings used only by the small-scale smoke test.
TEST_BATCH_SIZE = 2
TEST_NUM_WORKERS = 0
TEST_MAX_SAMPLES = 8
TEST_EPOCHS = 1


def set_seed(seed=SEED):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def setup_logger(log_file=LOG_FILE):
    """Configure the "train" logger to write to stdout and to ``log_file``.

    Any handlers left over from a previous call are removed AND closed, so
    calling this repeatedly (e.g. smoke test followed by real training in one
    process) does not leak open log files.
    """
    logger = logging.getLogger("train")
    logger.setLevel(logging.INFO)
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()
    logger.propagate = False

    log_file = Path(log_file)
    log_file.parent.mkdir(parents=True, exist_ok=True)

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info(f"日志文件:{as_posix_path(log_file)}")
    return logger


def create_summary_writer(log_dir=TENSORBOARD_DIR):
    """Create (and mkdir) a TensorBoard ``SummaryWriter`` rooted at ``log_dir``.

    Raises:
        ImportError: if the TensorBoard dependencies are not installed.
    """
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError as exc:
        raise ImportError("缺少 TensorBoard 相关依赖。请安装 tensorboard,或告诉我改用其他记录方式。") from exc

    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    return SummaryWriter(log_dir=as_posix_path(log_dir))


def make_progress_bar(iterable, desc):
    """Wrap a sized iterable (e.g. a DataLoader) in a fixed-width tqdm bar.

    Raises:
        ImportError: if tqdm is not installed.
    """
    try:
        from tqdm import tqdm
    except ImportError as exc:
        raise ImportError("缺少 tqdm。请在 torch271 环境中安装 tqdm,或告诉我改用普通日志显示进度。") from exc

    return tqdm(
        iterable,
        desc=desc,
        total=len(iterable),
        ncols=TQDM_NCOLS,
        leave=False,
        dynamic_ncols=False,
    )


def get_device(device_name=DEVICE):
    """Return ``torch.device(device_name)``.

    Raises:
        RuntimeError: if a CUDA device is requested but CUDA is unavailable.
    """
    if device_name.startswith("cuda") and not torch.cuda.is_available():
        raise RuntimeError(
            f"当前指定设备为 {device_name},但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。"
        )
    return torch.device(device_name)


def create_model():
    """Build the timm model with a single regression output."""
    return timm.create_model(MODEL_NAME, pretrained=PRETRAINED, num_classes=1)


def regression_metrics(predictions, targets):
    """Return MAE and RMSE between predictions and targets.

    Both tensors are flattened first so that ``(batch, 1)`` model output and
    ``(batch,)`` labels are compared element-wise rather than broadcast to
    ``(batch, batch)``.

    Raises:
        ValueError: if the two tensors hold a different number of elements.
    """
    predictions = predictions.reshape(-1).float()
    targets = targets.reshape(-1).float()
    if predictions.numel() != targets.numel():
        raise ValueError(
            f"predictions has {predictions.numel()} elements but targets has {targets.numel()}."
        )
    errors = predictions - targets
    mae = torch.mean(torch.abs(errors)).item()
    rmse = torch.sqrt(torch.mean(errors ** 2)).item()
    return {
        "mae": mae,
        "rmse": rmse,
    }


def get_current_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    return optimizer.param_groups[0]["lr"]


def move_batch_to_device(batch, device):
    """Move the first two batch items (images, targets) to ``device``.

    Any further items in the batch are ignored.
    """
    images, targets = batch[:2]
    images = images.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    return images, targets


def _align_predictions(predictions, targets):
    """Reshape model output to the targets' shape and give targets its dtype.

    The model is created with ``num_classes=1`` and returns ``(batch, 1)``.
    If the dataset yields one scalar label per sample, collated targets are
    ``(batch,)`` and ``SmoothL1Loss`` would silently broadcast the pair to
    ``(batch, batch)``. Reshaping is a no-op when the shapes already agree.
    """
    targets = targets.to(dtype=predictions.dtype)
    predictions = predictions.reshape(targets.shape)
    return predictions, targets


def _summarize_epoch(total_loss, total_count, all_predictions, all_targets):
    """Combine per-batch results into the metrics dict returned by an epoch.

    Raises:
        RuntimeError: if the dataloader produced no batches at all.
    """
    if not all_predictions:
        raise RuntimeError("数据加载器没有产生任何 batch,无法计算本轮指标。")
    metrics = regression_metrics(torch.cat(all_predictions), torch.cat(all_targets))
    metrics["loss"] = total_loss / max(total_count, 1)
    return metrics


def train_epoch(model, dataloader, criterion, optimizer, device, epoch=None):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        dict with ``mae``, ``rmse`` and the sample-weighted mean ``loss``.
    """
    model.train()
    total_loss = 0.0
    total_count = 0
    all_predictions = []
    all_targets = []

    desc = "训练" if epoch is None else f"训练 {epoch:03d}"
    # ``with`` closes the bar even if the epoch is interrupted (e.g. Ctrl-C).
    with make_progress_bar(dataloader, desc) as progress_bar:
        for batch in progress_bar:
            images, targets = move_batch_to_device(batch, device)

            optimizer.zero_grad(set_to_none=True)
            predictions, targets = _align_predictions(model(images), targets)
            loss = criterion(predictions, targets)
            loss.backward()
            optimizer.step()

            # Read the scalar once: every .item() forces a host/device sync.
            loss_value = loss.item()
            batch_size = images.size(0)
            total_loss += loss_value * batch_size
            total_count += batch_size
            all_predictions.append(predictions.detach().cpu())
            all_targets.append(targets.detach().cpu())
            progress_bar.set_postfix(loss=f"{loss_value:<10.4f}")

    return _summarize_epoch(total_loss, total_count, all_predictions, all_targets)


@torch.no_grad()
def valid_epoch(model, dataloader, criterion, device, epoch=None):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Returns:
        dict with ``mae``, ``rmse`` and the sample-weighted mean ``loss``.
    """
    model.eval()
    total_loss = 0.0
    total_count = 0
    all_predictions = []
    all_targets = []

    desc = "验证" if epoch is None else f"验证 {epoch:03d}"
    with make_progress_bar(dataloader, desc) as progress_bar:
        for batch in progress_bar:
            images, targets = move_batch_to_device(batch, device)

            predictions, targets = _align_predictions(model(images), targets)
            loss = criterion(predictions, targets)

            loss_value = loss.item()
            batch_size = images.size(0)
            total_loss += loss_value * batch_size
            total_count += batch_size
            all_predictions.append(predictions.detach().cpu())
            all_targets.append(targets.detach().cpu())
            progress_bar.set_postfix(loss=f"{loss_value:<10.4f}")

    return _summarize_epoch(total_loss, total_count, all_predictions, all_targets)


def write_tensorboard_scalars(writer, epoch, train_metrics, valid_metrics, best_valid_loss,
                              train_seconds, valid_seconds, epoch_seconds, learning_rate):
    """Write one epoch's metrics, timings and learning rate to TensorBoard, then flush."""
    for split, metrics in (("train", train_metrics), ("valid", valid_metrics)):
        for key in ("loss", "mae", "rmse"):
            writer.add_scalar(f"{split}/{key}", metrics[key], epoch)
    writer.add_scalar("summary/best_valid_loss", best_valid_loss, epoch)
    writer.add_scalar("time/train_seconds", train_seconds, epoch)
    writer.add_scalar("time/valid_seconds", valid_seconds, epoch)
    writer.add_scalar("time/epoch_seconds", epoch_seconds, epoch)
    writer.add_scalar("optimizer/learning_rate", learning_rate, epoch)
    writer.flush()


def _make_loader(dataset, batch_size, num_workers, shuffle):
    """Build a DataLoader whose worker/pinning options are valid for the settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory is only useful when copying to a CUDA device.
        pin_memory=torch.cuda.is_available(),
        # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
        persistent_workers=num_workers > 0,
    )


def make_dataloaders(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED):
    """Split the data and build train/valid DataLoaders.

    Returns:
        (train_loader, valid_loader, split_data), where ``split_data`` is the
        dict returned by ``make_split_lists``.

    Raises:
        RuntimeError: if either split is empty.
    """
    split_data = make_split_lists(data_root=data_root, seed=seed)
    train_dataset = DefocusDataset(
        split_data["train"]["image_paths"],
        split_data["train"]["labels"],
    )
    valid_dataset = DefocusDataset(
        split_data["valid"]["image_paths"],
        split_data["valid"]["labels"],
    )

    if len(train_dataset) == 0:
        raise RuntimeError("训练集为空,请检查数据目录结构。")
    if len(valid_dataset) == 0:
        raise RuntimeError("验证集为空,请检查 field 层面的划分结果。")

    train_loader = _make_loader(train_dataset, batch_size, num_workers, shuffle=True)
    valid_loader = _make_loader(valid_dataset, batch_size, num_workers, shuffle=False)
    return train_loader, valid_loader, split_data


def save_checkpoint(path, model, optimizer, epoch, train_metrics, valid_metrics,
                    best_valid_loss, config_overrides=None):
    """Save model/optimizer state, metrics and run configuration to ``path``.

    ``config_overrides`` (optional) replaces or extends entries of the stored
    ``config`` dict so the checkpoint records the settings actually used for
    the run, instead of only the module-level defaults.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    config = {
        "data_root": as_posix_path(DATA_ROOT),
        "batch_size": BATCH_SIZE,
        "num_workers": NUM_WORKERS,
        "learning_rate": LEARNING_RATE,
        "device": DEVICE,
        "loss": "SmoothL1Loss",
        "optimizer": "Adam",
    }
    if config_overrides:
        config.update(config_overrides)

    checkpoint = {
        "epoch": epoch,
        "model_name": MODEL_NAME,
        "pretrained": PRETRAINED,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "train_metrics": train_metrics,
        "valid_metrics": valid_metrics,
        "best_valid_loss": best_valid_loss,
        "config": config,
    }
    torch.save(checkpoint, path)


def log_epoch_metrics(logger, epoch, train_metrics, valid_metrics, train_seconds,
                      valid_seconds, epoch_seconds, learning_rate):
    """Log a one-line summary of an epoch's metrics, timings and learning rate."""
    logger.info(
        f"第 {epoch:03d} 轮 | "
        f"训练 loss={train_metrics['loss']:.6f}, MAE={train_metrics['mae']:.6f}, RMSE={train_metrics['rmse']:.6f} | "
        f"验证 loss={valid_metrics['loss']:.6f}, MAE={valid_metrics['mae']:.6f}, RMSE={valid_metrics['rmse']:.6f} | "
        f"耗时 train={train_seconds:.2f}s, valid={valid_seconds:.2f}s, epoch={epoch_seconds:.2f}s | "
        f"学习率={learning_rate:.6g}"
    )


def fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=None, writer=None,
        checkpoint_dir=CHECKPOINT_DIR):
    """Train a fresh model for ``epochs`` epochs and return it.

    The best checkpoint (lowest validation loss) is saved whenever it
    improves. The "last" checkpoint for the most recent fully-completed epoch
    is saved on exit, including after Ctrl-C. A writer created here is closed
    here; a writer passed in by the caller is left open.

    Args:
        checkpoint_dir: directory for the best/last checkpoints. Defaults to
            ``CHECKPOINT_DIR``; pass another directory to keep smoke-test runs
            from overwriting real checkpoints.
    """
    logger = logger or setup_logger()

    # Resolve the device before creating a writer so a missing GPU cannot
    # leave an unclosed writer behind.
    device = get_device(device_name)
    model = create_model().to(device)
    criterion = torch.nn.SmoothL1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    close_writer = writer is None
    if writer is None:
        writer = create_summary_writer()
        logger.info(f"TensorBoard 目录:{as_posix_path(TENSORBOARD_DIR)}")

    checkpoint_dir = Path(checkpoint_dir)
    best_path = checkpoint_dir / BEST_CHECKPOINT_NAME
    last_path = checkpoint_dir / LAST_CHECKPOINT_NAME
    # Settings actually used by this run, recorded in every checkpoint.
    run_config = {
        "batch_size": getattr(train_loader, "batch_size", BATCH_SIZE),
        "num_workers": getattr(train_loader, "num_workers", NUM_WORKERS),
        "device": device_name,
        "epochs": epochs,
    }

    best_valid_loss = float("inf")
    last_epoch = None
    last_train_metrics = None
    last_valid_metrics = None

    try:
        for epoch in range(1, epochs + 1):
            epoch_start = time.perf_counter()

            train_start = time.perf_counter()
            train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch=epoch)
            train_seconds = time.perf_counter() - train_start

            valid_start = time.perf_counter()
            valid_metrics = valid_epoch(model, valid_loader, criterion, device, epoch=epoch)
            valid_seconds = time.perf_counter() - valid_start

            epoch_seconds = time.perf_counter() - epoch_start
            learning_rate = get_current_lr(optimizer)

            if valid_metrics["loss"] < best_valid_loss:
                best_valid_loss = valid_metrics["loss"]
                save_checkpoint(
                    best_path, model, optimizer, epoch, train_metrics, valid_metrics,
                    best_valid_loss, config_overrides=run_config,
                )
                logger.info(f"已保存最佳检查点:{as_posix_path(best_path)}")

            # Only epochs that finished both train and valid count as "last".
            last_epoch = epoch
            last_train_metrics = train_metrics
            last_valid_metrics = valid_metrics

            log_epoch_metrics(
                logger, epoch, train_metrics, valid_metrics,
                train_seconds, valid_seconds, epoch_seconds, learning_rate,
            )
            write_tensorboard_scalars(
                writer, epoch, train_metrics, valid_metrics, best_valid_loss,
                train_seconds, valid_seconds, epoch_seconds, learning_rate,
            )
    except KeyboardInterrupt:
        logger.warning("检测到 Ctrl-C 手动中止训练,准备保存最后一个完整 epoch 的 last 检查点。")
    finally:
        if last_epoch is not None:
            save_checkpoint(
                last_path, model, optimizer, last_epoch, last_train_metrics,
                last_valid_metrics, best_valid_loss, config_overrides=run_config,
            )
            logger.info(f"已保存 last 检查点:{as_posix_path(last_path)},对应第 {last_epoch:03d} 轮。")
        else:
            logger.warning("训练尚未完整完成任何 epoch,未保存 last 检查点。")
        if close_writer:
            writer.close()

    return model


def test_train():
    """Small-scale training smoke test: use a few samples to check the full pipeline.

    Checkpoints go to a separate subdirectory so that the smoke test never
    overwrites checkpoints produced by real training.
    """
    logger = setup_logger(LOG_DIR / "test_train.log")
    set_seed(SEED)

    logger.info(f"数据根目录:{as_posix_path(DATA_ROOT)}")
    logger.info(f"模型:{MODEL_NAME},pretrained={PRETRAINED}")
    logger.info(f"设备:{DEVICE}")

    if not DATA_ROOT.exists():
        logger.info("数据集根目录不存在,先跳过小规模训练测试。")
        return

    split_data = make_split_lists(data_root=DATA_ROOT, seed=SEED)
    train_paths = split_data["train"]["image_paths"][:TEST_MAX_SAMPLES]
    train_labels = split_data["train"]["labels"][:TEST_MAX_SAMPLES]
    valid_paths = split_data["valid"]["image_paths"][:TEST_MAX_SAMPLES]
    valid_labels = split_data["valid"]["labels"][:TEST_MAX_SAMPLES]

    if not train_paths or not valid_paths:
        logger.info("训练集或验证集为空,请检查数据目录结构和 field 划分结果。")
        return

    train_dataset = DefocusDataset(train_paths, train_labels)
    valid_dataset = DefocusDataset(valid_paths, valid_labels)
    train_loader = _make_loader(train_dataset, TEST_BATCH_SIZE, TEST_NUM_WORKERS, shuffle=True)
    valid_loader = _make_loader(valid_dataset, TEST_BATCH_SIZE, TEST_NUM_WORKERS, shuffle=False)

    writer = create_summary_writer(TENSORBOARD_DIR / "test_train")
    logger.info(f"TensorBoard 测试目录:{as_posix_path(TENSORBOARD_DIR / 'test_train')}")
    try:
        fit(
            train_loader,
            valid_loader,
            epochs=TEST_EPOCHS,
            device_name=DEVICE,
            logger=logger,
            writer=writer,
            checkpoint_dir=CHECKPOINT_DIR / "test_train",
        )
    finally:
        # fit() leaves caller-supplied writers open, so close it here.
        writer.close()
    logger.info("小规模训练测试完成。")


def main():
    """Full training run using the module-level configuration."""
    logger = setup_logger(LOG_FILE)
    set_seed(SEED)

    train_loader, valid_loader, split_data = make_dataloaders(
        data_root=DATA_ROOT,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        seed=SEED,
    )
    logger.info(f"训练图像数:{len(split_data['train']['image_paths'])}")
    logger.info(f"验证图像数:{len(split_data['valid']['image_paths'])}")

    fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=logger)


if __name__ == "__main__":
    # Pass --test to run the small smoke test instead of full training.
    if "--test" in sys.argv[1:]:
        test_train()
    else:
        main()