# SparseFocus/train_dpn.py
# Snapshot metadata: 2026-06-02 13:51:22 +08:00 | 307 lines | 9.0 KiB | Python

import argparse
import logging
import shutil
import time
from tqdm import tqdm
from datetime import datetime
from pathlib import Path
from pprint import pformat
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import tomllib
from models import DPNet
import old_datasets as dataset_F
import utils
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, writer, global_step):
    """Run one training epoch and return ``(mean_train_loss, global_step)``.

    The LR scheduler is stepped once per optimizer step (not once per epoch).
    After every step the current learning rate of the first param group is
    written to ``writer`` under the tag ``"lr"`` at the incremented
    ``global_step``.

    Args:
        model: network to train; its output is flattened with ``view(-1)``.
        train_loader: iterable yielding ``(img, label, image_name)`` batches.
        criterion: loss function called as ``criterion(output, label)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped after every ``optimizer.step()``.
        device: device the images and labels are moved to.
        writer: object with an ``add_scalar(tag, value, step)`` method
            (a TensorBoard ``SummaryWriter`` in this script).
        global_step: number of optimizer steps taken before this epoch.

    Returns:
        Tuple of the sample-weighted mean loss over the epoch and the updated
        ``global_step``.

    Raises:
        ValueError: if ``train_loader`` yields no samples (previously this
            surfaced as an opaque ZeroDivisionError).
    """
    model.train()
    running_loss = 0.0
    sample_count = 0
    for img, label, _image_name in tqdm(train_loader, desc="Train", bar_format="{l_bar}{bar:20}{r_bar}"):
        img = img.to(device, non_blocking=True)
        # view(-1) assumes one scalar target per sample, matching the
        # flattened model output below.
        label = label.to(device, non_blocking=True).float().view(-1)

        optimizer.zero_grad(set_to_none=True)
        output = model(img).view(-1)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        scheduler.step()

        global_step += 1
        writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)

        batch_size = img.size(0)
        # Assumes criterion returns a per-batch mean (as MSELoss(reduction="mean")
        # does); re-weighting by batch size keeps the epoch average exact when
        # the last batch is smaller.
        running_loss += loss.item() * batch_size
        sample_count += batch_size

    if sample_count == 0:
        raise ValueError("train_loader yielded no samples; cannot compute the epoch loss")
    return running_loss / sample_count, global_step
@torch.no_grad()
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and return the mean validation loss.

    Autograd is disabled for the whole call by the ``@torch.no_grad()``
    decorator. The model is switched to eval mode and left in eval mode on
    return; the caller is responsible for calling ``model.train()`` again.

    Args:
        model: network to evaluate; its output is flattened with ``view(-1)``.
        val_loader: iterable yielding ``(img, label, image_name)`` batches.
        criterion: loss function called as ``criterion(output, label)``.
        device: device the images and labels are moved to.

    Returns:
        The sample-weighted mean loss over all validation samples.

    Raises:
        ValueError: if ``val_loader`` yields no samples (previously this
            surfaced as an opaque ZeroDivisionError).
    """
    model.eval()
    running_loss = 0.0
    sample_count = 0
    for img, label, _image_name in tqdm(val_loader, desc="Validate", bar_format="{l_bar}{bar:20}{r_bar}"):
        img = img.to(device, non_blocking=True)
        # view(-1) assumes one scalar target per sample, matching the
        # flattened model output below.
        label = label.to(device, non_blocking=True).float().view(-1)
        output = model(img).view(-1)
        loss = criterion(output, label)

        batch_size = img.size(0)
        # Assumes criterion returns a per-batch mean; re-weighting by batch
        # size keeps the average exact when the last batch is smaller.
        running_loss += loss.item() * batch_size
        sample_count += batch_size

    if sample_count == 0:
        raise ValueError("val_loader yielded no samples; cannot compute the validation loss")
    return running_loss / sample_count
def _close_logger(logger):
    """Detach and close every handler on *logger* so open log files are released."""
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()


def _setup_logger(run_dir):
    """Configure and return the ``dpnet_train`` logger.

    Messages at INFO and above go to ``run_dir / "train.log"`` (UTF-8) and to
    stderr. Handlers left over from an earlier call in the same process are
    closed before new ones are attached, so their files are not leaked.
    """
    logger = logging.getLogger("dpnet_train")
    logger.setLevel(logging.INFO)
    logger.propagate = False
    _close_logger(logger)
    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler = logging.FileHandler(run_dir / "train.log", encoding="utf-8")
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def _build_dataloaders(xlsx_files, batch_size, num_workers, device):
    """Build the ``(train_loader, val_loader)`` pair from the xlsx index files.

    Training images get ColorJitter augmentation before ``ToTensor``;
    validation images are only converted with ``ToTensor``. The train loader
    shuffles, the validation loader does not.
    """
    (train_image_path_list, train_defocus_distance_list,
     val_image_path_list, val_defocus_distance_list) = \
        dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files)
    train_transform = transforms.Compose([
        transforms.ColorJitter(
            brightness=(0.9, 1.4),
            contrast=(0.8, 1.5),
            saturation=(0.8, 1.5),
        ),
        transforms.ToTensor(),
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dataset = dataset_F.MyDataset(
        train_image_path_list,
        train_defocus_distance_list,
        train_transform,
    )
    val_dataset = dataset_F.MyDataset(
        val_image_path_list,
        val_defocus_distance_list,
        val_transform,
    )
    # Settings shared by both loaders; pinned memory only helps CUDA copies,
    # and persistent workers are only valid when worker processes exist.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"),
        persistent_workers=(num_workers > 0),
    )
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader


def _build_scheduler(optimizer, total_steps, max_warmup_steps=2000):
    """Create a per-step LR schedule: linear warmup, then cosine decay to 0.

    Warmup ramps the LR from 0.1x to 1x of the optimizer's base LR over
    ``min(max_warmup_steps, total_steps - 1)`` steps, which leaves at least
    one step for the cosine phase. If that works out to 0 warmup steps, the
    cosine schedule is returned on its own instead of being wrapped in
    ``SequentialLR`` behind a zero-length warmup.

    Returns:
        ``(scheduler, warmup_steps)``.
    """
    warmup_steps = max(0, min(max_warmup_steps, total_steps - 1))
    cosine_steps = max(1, total_steps - warmup_steps)
    if warmup_steps == 0:
        scheduler = CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=0.0)
        return scheduler, warmup_steps
    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(
                optimizer,
                start_factor=0.1,
                end_factor=1.0,
                total_iters=warmup_steps,
            ),
            CosineAnnealingLR(
                optimizer,
                T_max=cosine_steps,
                eta_min=0.0,
            ),
        ],
        milestones=[warmup_steps],
    )
    return scheduler, warmup_steps


def main():
    """Train DPNet from a TOML config and save artifacts to a timestamped run dir.

    Required config keys: ``xlsx_files``, ``batch_size``, ``learning_rate``,
    ``epochs``, ``num_workers``, ``seed``. Optional keys: ``init_weight``
    (path to a state_dict; empty or missing means training from scratch) and
    ``warmup_steps`` (upper bound on LR warmup steps, default 2000).

    The run directory ``./<YYYY_mm_dd_HH_MM_SS>_dpn/`` receives a copy of the
    config, ``train.log``, TensorBoard event files, ``last_dpn.pth`` (written
    after every epoch) and ``best_dpn.pth`` (lowest validation loss so far).
    The TensorBoard writer and log files are closed even if training fails.
    """
    # 1. Read the config
    parser = argparse.ArgumentParser(description="Train DPNet")
    parser.add_argument("--config", type=str, required=True, help="Path to TOML config file")
    args = parser.parse_args()
    config_path = Path(args.config)
    with config_path.open("rb") as f:
        config = tomllib.load(f)
    xlsx_files = config["xlsx_files"]
    batch_size = config["batch_size"]
    learning_rate = config["learning_rate"]
    epochs = config["epochs"]
    num_workers = config["num_workers"]
    seed = config["seed"]
    # TOML has no null value, so "no initial weights" is written as an empty
    # string or by omitting the key; both are falsy in the check below.
    init_weight = config.get("init_weight")
    max_warmup_steps = int(config.get("warmup_steps", 2000))

    # 2. Create a timestamped output directory and keep a copy of the config
    run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn")
    run_dir = Path.cwd() / run_name
    run_dir.mkdir(parents=True, exist_ok=False)
    shutil.copy2(config_path, run_dir / config_path.name)

    # 3. Logging and TensorBoard; both are closed in `finally` so train.log
    # and the event files are flushed even when training crashes.
    logger = _setup_logger(run_dir)
    writer = SummaryWriter(log_dir=str(run_dir))
    try:
        logger.info(f"Run directory: {run_dir}")
        logger.info(f"Config path: {config_path}")
        logger.info(f"Loaded config:\n{pformat(config, sort_dicts=False)}")

        # 4. Random seed and device
        utils.set_seeds(seed=seed)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # 5. Datasets and DataLoaders
        train_loader, val_loader = _build_dataloaders(xlsx_files, batch_size, num_workers, device)
        logger.info(f"Train dataset size: {len(train_loader.dataset)}")
        logger.info(f"Val dataset size: {len(val_loader.dataset)}")
        logger.info(f"Train steps per epoch: {len(train_loader)}")
        if len(train_loader) == 0 or len(val_loader) == 0:
            # Fail up front instead of dividing by zero when the first
            # epoch's losses are averaged (possibly after a full train epoch).
            raise ValueError("Train and validation datasets must both be non-empty")

        # 6. Model
        model = DPNet().to(device)
        if init_weight:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, which is all a state_dict needs.
            state_dict = torch.load(init_weight, map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict, strict=True)
            logger.info(f"Loaded init weight from: {init_weight}")
        else:
            logger.info("Training from scratch")

        # 7. Loss, optimizer and per-step LR schedule
        criterion = nn.MSELoss(reduction="mean")
        optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
        total_steps = epochs * len(train_loader)
        scheduler, warmup_steps = _build_scheduler(optimizer, total_steps, max_warmup_steps)
        logger.info("Loss: MSELoss(reduction='mean')")
        logger.info(f"Optimizer: Adam(lr={learning_rate}, betas=(0.9, 0.999))")
        logger.info("Scheduler: step-based warmup + cosine annealing")
        logger.info(f"Total training steps: {total_steps}")
        logger.info(f"Warmup steps: {warmup_steps}")

        # 8. Training loop
        best_val_loss = float("inf")
        global_step = 0
        start_time = time.time()
        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()
            train_loss, global_step = train_epoch(
                model=model,
                train_loader=train_loader,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                writer=writer,
                global_step=global_step,
            )
            val_loss = validate_epoch(
                model=model,
                val_loader=val_loader,
                criterion=criterion,
                device=device,
            )
            epoch_total_time = time.time() - epoch_start_time
            writer.add_scalar("train_loss", train_loss, epoch)
            writer.add_scalar("val_loss", val_loss, epoch)
            writer.add_scalar("epoch_total_time", epoch_total_time, epoch)
            logger.info(
                f"Epoch [{epoch}/{epochs}] | "
                f"train_loss={train_loss:.8f} | "
                f"val_loss={val_loss:.8f} | "
                f"epoch_total_time={epoch_total_time:.2f}s | "
                f"global_step={global_step}"
            )
            torch.save(model.state_dict(), run_dir / "last_dpn.pth")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), run_dir / "best_dpn.pth")
                logger.info(f"Best model updated, best_val_loss={best_val_loss:.8f}")

        # 9. Wrap up
        total_time = time.time() - start_time
        logger.info("Training finished")
        logger.info(f"Best validation loss: {best_val_loss:.8f}")
        logger.info(f"Total time: {total_time:.2f} seconds")
    except Exception:
        # Record the traceback in train.log before propagating it.
        logger.exception("Training failed")
        raise
    finally:
        writer.close()
        _close_logger(logger)
if __name__ == "__main__":
    # Start training only when this file is executed as a script,
    # not when it is imported as a module.
    main()