447 lines
15 KiB
Python
447 lines
15 KiB
Python
from pathlib import Path
|
||
import logging
|
||
import random
|
||
import sys
|
||
import time
|
||
|
||
try:
|
||
import numpy as np
|
||
except ImportError as exc:
|
||
raise ImportError("缺少 numpy。请在 torch271 环境中安装 numpy,或告诉我改用其他路线。") from exc
|
||
|
||
try:
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
except ImportError as exc:
|
||
raise ImportError("缺少 PyTorch。请确认当前 conda 环境为 torch271,并已配置 torch271+cu126。") from exc
|
||
|
||
try:
|
||
import timm
|
||
except ImportError as exc:
|
||
raise ImportError("缺少 timm。请在 torch271 环境中安装 timm,或告诉我改用其他路线。") from exc
|
||
|
||
from dataset import DATA_ROOT, DefocusDataset, as_posix_path, make_split_lists
|
||
|
||
|
||
# Basic hyperparameters. Run test_train() before launching a large-scale experiment.
SEED = 2026  # handed to set_seed() and make_split_lists()
MODEL_NAME = "mobilenetv4_conv_small"  # architecture name passed to timm.create_model
PRETRAINED = False  # forwarded as timm's `pretrained` flag
DEVICE = "cuda:0"  # default device string resolved by get_device()

# Settings used by the full training run (main()).
BATCH_SIZE = 64
NUM_WORKERS = 8
LEARNING_RATE = 1e-4
EPOCHS = 300

# Output locations: checkpoints, text logs and TensorBoard event files.
CHECKPOINT_DIR = Path("checkpoints")
BEST_CHECKPOINT_NAME = "best_mobilenetv4_defocus.pth"
LAST_CHECKPOINT_NAME = "last_mobilenetv4_defocus.pth"
LOG_DIR = Path("logs")
LOG_FILE = LOG_DIR / "train.log"
TENSORBOARD_DIR = Path("runs") / "mobilenetv4_defocus"
TQDM_NCOLS = 100  # fixed progress-bar width passed to tqdm as `ncols`

# Reduced settings used only by test_train().
TEST_BATCH_SIZE = 2
TEST_NUM_WORKERS = 0
TEST_MAX_SAMPLES = 8  # upper bound on samples taken from each split
TEST_EPOCHS = 1
|
||
|
||
|
||
def set_seed(seed=SEED):
    """Seed every random number generator this script relies on.

    Args:
        seed: Value handed, in order, to Python's ``random``, NumPy,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
            Defaults to the module-level SEED.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
||
|
||
|
||
def setup_logger(log_file=LOG_FILE):
    """Configure and return the "train" logger with console and file output.

    Any handlers already attached to the logger are removed and closed first,
    so calling this function again (for example with a different log file)
    neither duplicates output nor leaves the previous log file open.

    Args:
        log_file: Destination of the log file; its parent directory is
            created when missing. Defaults to the module-level LOG_FILE.

    Returns:
        The configured ``logging.Logger`` named "train".
    """
    logger = logging.getLogger("train")
    logger.setLevel(logging.INFO)

    # Iterate over a copy: removeHandler mutates logger.handlers.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        # Closing releases the file held by a previous FileHandler.
        stale_handler.close()

    # Keep records out of the root logger so they are not emitted twice.
    logger.propagate = False

    log_file = Path(log_file)
    log_file.parent.mkdir(parents=True, exist_ok=True)

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info(f"日志文件:{as_posix_path(log_file)}")
    return logger
|
||
|
||
|
||
def create_summary_writer(log_dir=TENSORBOARD_DIR):
    """Create a TensorBoard ``SummaryWriter`` that writes into ``log_dir``.

    The TensorBoard import is deferred to call time so that a missing
    dependency only fails when a writer is actually requested.

    Args:
        log_dir: Directory for the event files; created when missing.
            Defaults to the module-level TENSORBOARD_DIR.

    Returns:
        A ``torch.utils.tensorboard.SummaryWriter``.

    Raises:
        ImportError: If the TensorBoard dependency cannot be imported.
    """
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError as exc:
        raise ImportError("缺少 TensorBoard 相关依赖。请安装 tensorboard,或告诉我改用其他记录方式。") from exc

    target_dir = Path(log_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    summary_writer = SummaryWriter(log_dir=as_posix_path(target_dir))
    return summary_writer
|
||
|
||
|
||
def make_progress_bar(iterable, desc):
    """Wrap ``iterable`` in a fixed-width tqdm progress bar.

    Args:
        iterable: Sized iterable to track; ``len(iterable)`` becomes the
            bar's total.
        desc: Text label of the bar.

    Returns:
        The tqdm wrapper around ``iterable``.

    Raises:
        ImportError: If tqdm cannot be imported.
    """
    try:
        from tqdm import tqdm
    except ImportError as exc:
        raise ImportError("缺少 tqdm。请在 torch271 环境中安装 tqdm,或告诉我改用普通日志显示进度。") from exc

    bar_options = {
        "desc": desc,
        "total": len(iterable),
        # Fixed width: dynamic resizing is switched off explicitly.
        "ncols": TQDM_NCOLS,
        "leave": False,
        "dynamic_ncols": False,
    }
    return tqdm(iterable, **bar_options)
|
||
|
||
|
||
def get_device(device_name=DEVICE):
    """Resolve a device specification into a ``torch.device``.

    Args:
        device_name: Device string such as "cuda:0" or "cpu"; a
            ``torch.device`` is accepted as well. Defaults to the
            module-level DEVICE.

    Returns:
        The corresponding ``torch.device``.

    Raises:
        RuntimeError: If a CUDA device is requested but PyTorch reports that
            CUDA is unavailable.
    """
    # Normalise first so that torch.device inputs work with startswith().
    device_name = str(device_name)
    if device_name.startswith("cuda") and not torch.cuda.is_available():
        # Report the device that was actually requested, not a fixed "cuda:0".
        raise RuntimeError(f"当前指定设备为 {device_name},但 PyTorch 没有检测到可用 CUDA。请检查 torch271+cu126 环境。")
    return torch.device(device_name)
|
||
|
||
|
||
def create_model():
    """Build the timm model named by MODEL_NAME with ``num_classes=1``.

    Returns:
        The model returned by ``timm.create_model``; pretrained weights are
        requested according to the module-level PRETRAINED flag.
    """
    build_options = {"pretrained": PRETRAINED, "num_classes": 1}
    return timm.create_model(MODEL_NAME, **build_options)
|
||
|
||
|
||
def regression_metrics(predictions, targets):
    """Compute MAE and RMSE between predictions and targets.

    When the two tensors differ only by singleton dimensions (for example a
    ``(N, 1)`` prediction against a ``(N,)`` target), the predictions are
    reshaped to the target shape first. Without this, the subtraction would
    broadcast to ``(N, N)`` and silently report wrong metrics. Any other
    shape combination is left to PyTorch's regular broadcasting rules.

    Args:
        predictions: Tensor of predicted values.
        targets: Tensor of ground-truth values.

    Returns:
        A dict with the Python floats ``"mae"`` and ``"rmse"``.
    """
    if predictions.shape != targets.shape:
        prediction_dims = [dim for dim in predictions.shape if dim != 1]
        target_dims = [dim for dim in targets.shape if dim != 1]
        if prediction_dims == target_dims:
            predictions = predictions.reshape(targets.shape)

    errors = predictions - targets
    mae = torch.mean(torch.abs(errors)).item()
    rmse = torch.sqrt(torch.mean(errors ** 2)).item()
    return {
        "mae": mae,
        "rmse": rmse,
    }
|
||
|
||
|
||
def get_current_lr(optimizer):
    """Return the learning rate stored in the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group["lr"]
|
||
|
||
|
||
def move_batch_to_device(batch, device):
    """Move the first two entries of a batch onto ``device``.

    Args:
        batch: Sequence whose first two items are the images and the targets;
            any further items are ignored.
        device: Destination passed to each tensor's ``.to``.

    Returns:
        Tuple ``(images, targets)`` after the transfer, requested with
        ``non_blocking=True``.
    """
    raw_images, raw_targets = batch[:2]
    return (
        raw_images.to(device, non_blocking=True),
        raw_targets.to(device, non_blocking=True),
    )
|
||
|
||
|
||
def train_epoch(model, dataloader, criterion, optimizer, device, epoch=None):
    """Run one optimisation pass over ``dataloader`` and report its metrics.

    Args:
        model: Network to train; switched to training mode here.
        dataloader: Sized iterable of batches whose first two entries are
            images and targets.
        criterion: Loss callable invoked as ``criterion(predictions, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Optional epoch number shown in the progress-bar label.

    Returns:
        Dict with ``"mae"``, ``"rmse"`` (computed over all batches) and
        ``"loss"`` (sample-weighted mean of the per-batch losses).
    """
    model.train()

    label = f"训练 {epoch:03d}" if epoch is not None else "训练"
    progress_bar = make_progress_bar(dataloader, label)

    loss_sum = 0.0
    sample_count = 0
    prediction_chunks = []
    target_chunks = []

    for batch in progress_bar:
        images, targets = move_batch_to_device(batch, device)

        optimizer.zero_grad(set_to_none=True)
        predictions = model(images)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        samples_in_batch = images.size(0)
        loss_value = loss.item()
        # Weight by batch size so a smaller final batch does not skew the
        # mean (presumably the criterion returns a per-batch mean; verify).
        loss_sum += loss_value * samples_in_batch
        sample_count += samples_in_batch

        # Keep CPU copies so the epoch-level metrics cover every sample.
        prediction_chunks.append(predictions.detach().cpu())
        target_chunks.append(targets.detach().cpu())
        progress_bar.set_postfix(loss=f"{loss_value:<10.4f}")

    metrics = regression_metrics(torch.cat(prediction_chunks), torch.cat(target_chunks))
    metrics["loss"] = loss_sum / max(sample_count, 1)
    return metrics
|
||
|
||
|
||
@torch.no_grad()
def valid_epoch(model, dataloader, criterion, device, epoch=None):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        dataloader: Sized iterable of batches whose first two entries are
            images and targets.
        criterion: Loss callable invoked as ``criterion(predictions, targets)``.
        device: Device the batches are moved to.
        epoch: Optional epoch number shown in the progress-bar label.

    Returns:
        Dict with ``"mae"``, ``"rmse"`` (computed over all batches) and
        ``"loss"`` (sample-weighted mean of the per-batch losses).
    """
    model.eval()

    label = f"验证 {epoch:03d}" if epoch is not None else "验证"
    progress_bar = make_progress_bar(dataloader, label)

    loss_sum = 0.0
    sample_count = 0
    prediction_chunks = []
    target_chunks = []

    for batch in progress_bar:
        images, targets = move_batch_to_device(batch, device)

        predictions = model(images)
        loss = criterion(predictions, targets)

        samples_in_batch = images.size(0)
        loss_value = loss.item()
        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_sum += loss_value * samples_in_batch
        sample_count += samples_in_batch

        prediction_chunks.append(predictions.detach().cpu())
        target_chunks.append(targets.detach().cpu())
        progress_bar.set_postfix(loss=f"{loss_value:<10.4f}")

    metrics = regression_metrics(torch.cat(prediction_chunks), torch.cat(target_chunks))
    metrics["loss"] = loss_sum / max(sample_count, 1)
    return metrics
|
||
|
||
|
||
def write_tensorboard_scalars(writer, epoch, train_metrics, valid_metrics, best_valid_loss, train_seconds, valid_seconds, epoch_seconds, learning_rate):
    """Record one epoch's scalars on ``writer`` and flush it.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)`` and ``flush()``.
        epoch: Step index used for every scalar.
        train_metrics: Dict with ``"loss"``, ``"mae"`` and ``"rmse"`` for training.
        valid_metrics: Dict with the same keys for validation.
        best_valid_loss: Lowest validation loss seen so far.
        train_seconds: Duration of the training pass.
        valid_seconds: Duration of the validation pass.
        epoch_seconds: Duration of the whole epoch.
        learning_rate: Learning rate in effect for the epoch.
    """
    per_split = (("train", train_metrics), ("valid", valid_metrics))
    for split_name, split_metrics in per_split:
        for metric_name in ("loss", "mae", "rmse"):
            writer.add_scalar(f"{split_name}/{metric_name}", split_metrics[metric_name], epoch)

    remaining = (
        ("summary/best_valid_loss", best_valid_loss),
        ("time/train_seconds", train_seconds),
        ("time/valid_seconds", valid_seconds),
        ("time/epoch_seconds", epoch_seconds),
        ("optimizer/learning_rate", learning_rate),
    )
    for tag, value in remaining:
        writer.add_scalar(tag, value, epoch)

    writer.flush()
|
||
|
||
|
||
def make_dataloaders(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        data_root: Dataset root forwarded to ``make_split_lists``.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes per loader; ``0`` loads data
            in the main process.
        seed: Seed forwarded to ``make_split_lists``.

    Returns:
        Tuple ``(train_loader, valid_loader, split_data)`` where
        ``split_data`` is the value returned by ``make_split_lists``.

    Raises:
        RuntimeError: If the training or the validation dataset is empty.
    """
    split_data = make_split_lists(data_root=data_root, seed=seed)

    train_dataset = DefocusDataset(
        split_data["train"]["image_paths"],
        split_data["train"]["labels"],
    )
    valid_dataset = DefocusDataset(
        split_data["valid"]["image_paths"],
        split_data["valid"]["labels"],
    )

    if len(train_dataset) == 0:
        raise RuntimeError("训练集为空,请检查数据目录结构。")
    if len(valid_dataset) == 0:
        raise RuntimeError("验证集为空,请检查 field 层面的划分结果。")

    shared_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only request it when workers exist.
        "persistent_workers": num_workers > 0,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_options)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **shared_options)

    return train_loader, valid_loader, split_data
|
||
|
||
|
||
def save_checkpoint(path, model, optimizer, epoch, train_metrics, valid_metrics, best_valid_loss):
    """Write a training checkpoint to ``path``.

    The checkpoint is first written to a sibling ``<name>.tmp`` file and then
    moved over ``path``, so an interrupted or failed write never leaves a
    truncated file in place of an existing, valid checkpoint.

    Args:
        path: Destination file; parent directories are created when missing.
        model: Module whose ``state_dict()`` is stored under ``"model_state"``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``"optimizer_state"``.
        epoch: Epoch number recorded in the checkpoint.
        train_metrics: Training metrics recorded in the checkpoint.
        valid_metrics: Validation metrics recorded in the checkpoint.
        best_valid_loss: Best validation loss recorded in the checkpoint.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "model_name": MODEL_NAME,
        "pretrained": PRETRAINED,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "train_metrics": train_metrics,
        "valid_metrics": valid_metrics,
        "best_valid_loss": best_valid_loss,
        # NOTE: these are the module-level defaults, not necessarily the
        # values a caller passed to make_dataloaders()/fit().
        "config": {
            "data_root": as_posix_path(DATA_ROOT),
            "batch_size": BATCH_SIZE,
            "num_workers": NUM_WORKERS,
            "learning_rate": LEARNING_RATE,
            "device": DEVICE,
            "loss": "SmoothL1Loss",
            "optimizer": "Adam",
        },
    }

    temp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(checkpoint, temp_path)
        # Swap the finished file into place only after the write succeeded.
        temp_path.replace(path)
    except BaseException:
        # Also covers Ctrl-C during the write: drop the partial file.
        temp_path.unlink(missing_ok=True)
        raise
|
||
|
||
|
||
def log_epoch_metrics(logger, epoch, train_metrics, valid_metrics, train_seconds, valid_seconds, epoch_seconds, learning_rate):
    """Emit a single INFO line summarising one epoch.

    Args:
        logger: Object exposing ``info(message)``.
        epoch: Epoch number, zero-padded to three digits in the message.
        train_metrics: Dict with ``"loss"``, ``"mae"`` and ``"rmse"`` for training.
        valid_metrics: Dict with the same keys for validation.
        train_seconds: Duration of the training pass.
        valid_seconds: Duration of the validation pass.
        epoch_seconds: Duration of the whole epoch.
        learning_rate: Learning rate in effect for the epoch.
    """

    def describe(label, metrics):
        # One "<label> loss=…, MAE=…, RMSE=…" segment with six decimals each.
        return f"{label} loss={metrics['loss']:.6f}, MAE={metrics['mae']:.6f}, RMSE={metrics['rmse']:.6f}"

    segments = [
        f"第 {epoch:03d} 轮",
        describe("训练", train_metrics),
        describe("验证", valid_metrics),
        f"耗时 train={train_seconds:.2f}s, valid={valid_seconds:.2f}s, epoch={epoch_seconds:.2f}s",
        f"学习率={learning_rate:.6g}",
    ]
    logger.info(" | ".join(segments))
|
||
|
||
|
||
def fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=None, writer=None):
    """Train a freshly created model and keep best/last checkpoints.

    The "last" checkpoint is written at the end of every completed epoch
    rather than once on exit. If training is interrupted in the middle of a
    later epoch, the file on disk therefore holds the weights of the last
    *complete* epoch together with that epoch's metrics. Saving only on exit
    would store partially updated weights under the previous epoch's label.

    Args:
        train_loader: Loader passed to ``train_epoch``.
        valid_loader: Loader passed to ``valid_epoch``.
        epochs: Number of epochs to run.
        device_name: Device specification resolved by ``get_device``.
        logger: Logger to use; ``setup_logger()`` is called when omitted.
        writer: TensorBoard writer to use. When omitted, one is created here
            and closed before returning; a caller-supplied writer is left open.

    Returns:
        The model. Ctrl-C (``KeyboardInterrupt``) is caught, so the model is
        also returned after a manual stop.
    """
    logger = logger or setup_logger()
    close_writer = writer is None
    if writer is None:
        writer = create_summary_writer()
        logger.info(f"TensorBoard 目录:{as_posix_path(TENSORBOARD_DIR)}")

    device = get_device(device_name)
    model = create_model().to(device)
    criterion = torch.nn.SmoothL1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_valid_loss = float("inf")
    best_path = CHECKPOINT_DIR / BEST_CHECKPOINT_NAME
    last_path = CHECKPOINT_DIR / LAST_CHECKPOINT_NAME
    last_epoch = None
    last_train_metrics = None
    last_valid_metrics = None
    # False only while the newest completed epoch is still being written.
    last_saved = True

    try:
        for epoch in range(1, epochs + 1):
            epoch_start = time.perf_counter()

            train_start = time.perf_counter()
            train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch=epoch)
            train_seconds = time.perf_counter() - train_start

            valid_start = time.perf_counter()
            valid_metrics = valid_epoch(model, valid_loader, criterion, device, epoch=epoch)
            valid_seconds = time.perf_counter() - valid_start

            epoch_seconds = time.perf_counter() - epoch_start
            learning_rate = get_current_lr(optimizer)

            if valid_metrics["loss"] < best_valid_loss:
                best_valid_loss = valid_metrics["loss"]
                save_checkpoint(best_path, model, optimizer, epoch, train_metrics, valid_metrics, best_valid_loss)
                logger.info(f"已保存最佳检查点:{as_posix_path(best_path)}")

            # The weights now belong to a fully finished epoch: persist them
            # before the next epoch starts to modify the model.
            last_epoch = epoch
            last_train_metrics = train_metrics
            last_valid_metrics = valid_metrics
            last_saved = False
            save_checkpoint(
                last_path,
                model,
                optimizer,
                last_epoch,
                last_train_metrics,
                last_valid_metrics,
                best_valid_loss,
            )
            last_saved = True

            log_epoch_metrics(
                logger,
                epoch,
                train_metrics,
                valid_metrics,
                train_seconds,
                valid_seconds,
                epoch_seconds,
                learning_rate,
            )
            write_tensorboard_scalars(
                writer,
                epoch,
                train_metrics,
                valid_metrics,
                best_valid_loss,
                train_seconds,
                valid_seconds,
                epoch_seconds,
                learning_rate,
            )
    except KeyboardInterrupt:
        logger.warning("检测到 Ctrl-C 手动中止训练,准备保存最后一个完整 epoch 的 last 检查点。")
    finally:
        try:
            if last_epoch is None:
                logger.warning("训练尚未完整完成任何 epoch,未保存 last 检查点。")
            else:
                if not last_saved:
                    # The write for the newest epoch was cut short. No
                    # optimizer step has run since, so the model still holds
                    # that epoch's weights and the save can be repeated.
                    save_checkpoint(
                        last_path,
                        model,
                        optimizer,
                        last_epoch,
                        last_train_metrics,
                        last_valid_metrics,
                        best_valid_loss,
                    )
                logger.info(f"已保存 last 检查点:{as_posix_path(last_path)},对应第 {last_epoch:03d} 轮。")
        finally:
            # Close a writer created here even if the final save failed.
            if close_writer:
                writer.close()

    return model
|
||
|
||
|
||
def test_train():
    """Small-scale training test: run the whole pipeline on a few samples.

    At most TEST_MAX_SAMPLES samples are taken from each split and trained
    for TEST_EPOCHS epochs, logging to ``test_train.log`` and to a
    ``test_train`` sub-directory of the TensorBoard directory. The function
    returns early (after logging the reason) when the dataset root does not
    exist or either split is empty.
    """
    logger = setup_logger(LOG_DIR / "test_train.log")
    set_seed(SEED)
    logger.info(f"数据根目录:{as_posix_path(DATA_ROOT)}")
    logger.info(f"模型:{MODEL_NAME},pretrained={PRETRAINED}")
    logger.info(f"设备:{DEVICE}")

    if not DATA_ROOT.exists():
        logger.info("数据集根目录不存在,先跳过小规模训练测试。")
        return

    split_data = make_split_lists(data_root=DATA_ROOT, seed=SEED)
    train_paths = split_data["train"]["image_paths"][:TEST_MAX_SAMPLES]
    train_labels = split_data["train"]["labels"][:TEST_MAX_SAMPLES]
    valid_paths = split_data["valid"]["image_paths"][:TEST_MAX_SAMPLES]
    valid_labels = split_data["valid"]["labels"][:TEST_MAX_SAMPLES]

    if not train_paths or not valid_paths:
        logger.info("训练集或验证集为空,请检查数据目录结构和 field 划分结果。")
        return

    train_dataset = DefocusDataset(train_paths, train_labels)
    valid_dataset = DefocusDataset(valid_paths, valid_labels)
    train_loader = DataLoader(
        train_dataset,
        batch_size=TEST_BATCH_SIZE,
        shuffle=True,
        num_workers=TEST_NUM_WORKERS,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        num_workers=TEST_NUM_WORKERS,
        pin_memory=True,
    )

    test_log_dir = TENSORBOARD_DIR / "test_train"
    writer = create_summary_writer(test_log_dir)
    try:
        logger.info(f"TensorBoard 测试目录:{as_posix_path(test_log_dir)}")
        fit(train_loader, valid_loader, epochs=TEST_EPOCHS, device_name=DEVICE, logger=logger, writer=writer)
    finally:
        # fit() leaves a caller-supplied writer open, so close it here even
        # when training raises.
        writer.close()
    logger.info("小规模训练测试完成。")
|
||
|
||
|
||
def main():
    """Run the full training job with the module-level configuration."""
    logger = setup_logger(LOG_FILE)
    set_seed(SEED)

    loaders = make_dataloaders(
        data_root=DATA_ROOT,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        seed=SEED,
    )
    train_loader, valid_loader, split_data = loaders

    # Report how many images ended up in each split before training starts.
    for label, split_name in (("训练", "train"), ("验证", "valid")):
        logger.info(f"{label}图像数:{len(split_data[split_name]['image_paths'])}")

    fit(train_loader, valid_loader, epochs=EPOCHS, device_name=DEVICE, logger=logger)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point. Swap the two calls below to run the small-sample
    # smoke test instead of the full training job.
    # test_train()
    main()
|