307 lines
9.0 KiB
Python
307 lines
9.0 KiB
Python
import argparse
|
|
import logging
|
|
import shutil
|
|
import time
|
|
from tqdm import tqdm
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from pprint import pformat
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.transforms as transforms
|
|
from torch.optim import Adam
|
|
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
import tomllib
|
|
|
|
from models import DPNet
|
|
import old_datasets as dataset_F
|
|
import utils
|
|
|
|
|
|
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, writer, global_step):
    """Run one optimisation pass over ``train_loader``.

    The scheduler is stepped once per batch (not once per epoch), and after
    every step the learning rate of the first parameter group is written to
    ``writer`` under the tag ``"lr"``.

    Args:
        model: Network to optimise; it is switched to training mode here.
        train_loader: Iterable yielding ``(img, label, image_name)`` batches.
        criterion: Loss callable applied to the flattened outputs and labels.
        optimizer: Optimizer that updates the model parameters.
        scheduler: Learning-rate scheduler stepped after each optimizer step.
        device: Device the image and label tensors are moved to.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        global_step: Number of optimisation steps taken before this call.

    Returns:
        A tuple ``(epoch_loss, global_step)``: the batch-size-weighted average
        of the per-batch losses and the updated step counter.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()

    running_loss = 0.0
    sample_count = 0

    # The third batch element (image name) is not needed for optimisation.
    for img, label, _ in tqdm(train_loader, desc="Train", bar_format="{l_bar}{bar:20}{r_bar}"):
        img = img.to(device, non_blocking=True)
        # Flatten so the label lines up with the flattened model output below.
        label = label.to(device, non_blocking=True).float().view(-1)

        optimizer.zero_grad(set_to_none=True)

        output = model(img).view(-1)
        loss = criterion(output, label)

        loss.backward()
        optimizer.step()
        scheduler.step()

        global_step += 1
        lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("lr", lr, global_step)

        # Weight each batch loss by its size so a smaller final batch does not
        # skew the average (assumes the criterion returns a per-batch mean).
        batch_size = img.size(0)
        running_loss += loss.item() * batch_size
        sample_count += batch_size

    if sample_count == 0:
        # Fail with an actionable message instead of a bare ZeroDivisionError.
        raise ValueError("train_loader yielded no samples; cannot compute the epoch loss")

    epoch_loss = running_loss / sample_count
    return epoch_loss, global_step
|
|
|
|
|
|
@torch.no_grad()
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to evaluation mode here.
        val_loader: Iterable yielding ``(img, label, image_name)`` batches.
        criterion: Loss callable applied to the flattened outputs and labels.
        device: Device the image and label tensors are moved to.

    Returns:
        The batch-size-weighted average of the per-batch losses.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()

    running_loss = 0.0
    sample_count = 0

    # The third batch element (image name) is not needed for evaluation.
    for img, label, _ in tqdm(val_loader, desc="Validate", bar_format="{l_bar}{bar:20}{r_bar}"):
        img = img.to(device, non_blocking=True)
        # Flatten so the label lines up with the flattened model output below.
        label = label.to(device, non_blocking=True).float().view(-1)

        output = model(img).view(-1)
        loss = criterion(output, label)

        # Weight each batch loss by its size so a smaller final batch does not
        # skew the average (assumes the criterion returns a per-batch mean).
        batch_size = img.size(0)
        running_loss += loss.item() * batch_size
        sample_count += batch_size

    if sample_count == 0:
        # Fail with an actionable message instead of a bare ZeroDivisionError.
        raise ValueError("val_loader yielded no samples; cannot compute the validation loss")

    epoch_loss = running_loss / sample_count
    return epoch_loss
|
|
|
|
|
|
def _create_run_logger(run_dir):
    """Configure and return the ``dpnet_train`` logger for one run.

    Records at INFO level and above go both to ``run_dir / "train.log"``
    (UTF-8) and to a console stream handler. Handlers already attached to the
    logger are dropped first, and propagation to ancestor loggers is disabled.
    """
    logger = logging.getLogger("dpnet_train")
    logger.setLevel(logging.INFO)
    logger.propagate = False
    logger.handlers.clear()

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    file_handler = logging.FileHandler(run_dir / "train.log", encoding="utf-8")
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def _build_step_scheduler(optimizer, total_steps, max_warmup_steps=2000):
    """Build the per-step schedule: linear warm-up, then cosine annealing.

    Returns a tuple ``(scheduler, warmup_steps)``. The warm-up ramps the
    learning rate from 0.1x to 1.0x of its configured value; the remaining
    steps anneal it towards 0.
    """
    # Keep at least one step for the cosine phase when there is more than one.
    # NOTE(review): for total_steps <= 1 the warm-up length is 0 -- confirm that
    # SequentialLR treats a zero-length warm-up as intended.
    warmup_steps = min(max_warmup_steps, total_steps - 1) if total_steps > 1 else 0

    warmup = LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_steps,
    )
    cosine = CosineAnnealingLR(
        optimizer,
        T_max=max(1, total_steps - warmup_steps),
        eta_min=0.0,
    )
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup, cosine],
        milestones=[warmup_steps],
    )
    return scheduler, warmup_steps


def main():
    """Train DPNet as configured by the TOML file given via ``--config``.

    The config must provide the keys ``xlsx_files``, ``batch_size``,
    ``learning_rate``, ``epochs``, ``num_workers``, ``seed`` and
    ``init_weight`` (a falsy ``init_weight`` means training from scratch).

    A fresh run directory named ``<timestamp>_dpn`` is created under the
    current working directory. It receives a copy of the config file,
    ``train.log``, the TensorBoard event files, ``last_dpn.pth`` (rewritten
    after every epoch) and ``best_dpn.pth`` (rewritten whenever the validation
    loss improves).

    Raises:
        KeyError: If a required config key is missing.
        FileExistsError: If the run directory already exists.
    """
    # =========================
    # 1. Read the configuration
    # =========================
    parser = argparse.ArgumentParser(description="Train DPNet")
    parser.add_argument("--config", type=str, required=True, help="Path to TOML config file")
    args = parser.parse_args()

    config_path = Path(args.config)
    with config_path.open("rb") as f:
        config = tomllib.load(f)

    # Read every required key up front so a bad config fails before any
    # output directory is created.
    xlsx_files = config["xlsx_files"]
    batch_size = config["batch_size"]
    learning_rate = config["learning_rate"]
    epochs = config["epochs"]
    num_workers = config["num_workers"]
    seed = config["seed"]
    init_weight = config["init_weight"]

    # =========================
    # 2. Create the output directory
    # =========================
    run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn")
    run_dir = Path.cwd() / run_name
    run_dir.mkdir(parents=True, exist_ok=False)

    # Keep a copy of the exact config next to the run's results.
    shutil.copy2(config_path, run_dir / config_path.name)

    # =========================
    # 3. Initialise logging and TensorBoard
    # =========================
    logger = _create_run_logger(run_dir)
    writer = SummaryWriter(log_dir=str(run_dir))

    try:
        logger.info(f"Run directory: {run_dir}")
        logger.info(f"Config path: {config_path}")
        logger.info(f"Loaded config:\n{pformat(config, sort_dicts=False)}")

        # =========================
        # 4. Set the random seed and pick the device
        # =========================
        utils.set_seeds(seed=seed)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # =========================
        # 5. Prepare the datasets and DataLoaders
        # =========================
        (
            train_image_path_list,
            train_defocus_distance_list,
            val_image_path_list,
            val_defocus_distance_list,
        ) = dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files)

        # Colour jitter is applied to the training images only.
        train_transform = transforms.Compose([
            transforms.ColorJitter(
                brightness=(0.9, 1.4),
                contrast=(0.8, 1.5),
                saturation=(0.8, 1.5),
            ),
            transforms.ToTensor(),
        ])

        val_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        train_dataset = dataset_F.MyDataset(
            train_image_path_list,
            train_defocus_distance_list,
            train_transform,
        )
        val_dataset = dataset_F.MyDataset(
            val_image_path_list,
            val_defocus_distance_list,
            val_transform,
        )

        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=(device.type == "cuda"),
            # persistent_workers=True is rejected by DataLoader when num_workers == 0.
            persistent_workers=(num_workers > 0),
        )

        val_loader = DataLoader(
            dataset=val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=(device.type == "cuda"),
            persistent_workers=(num_workers > 0),
        )

        logger.info(f"Train dataset size: {len(train_dataset)}")
        logger.info(f"Val dataset size: {len(val_dataset)}")
        logger.info(f"Train steps per epoch: {len(train_loader)}")

        # =========================
        # 6. Prepare the model
        # =========================
        model = DPNet().to(device)

        if init_weight:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot execute code.
            state_dict = torch.load(init_weight, map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict, strict=True)
            logger.info(f"Loaded init weight from: {init_weight}")
        else:
            logger.info("Training from scratch")

        # =========================
        # 7. Prepare the loss, optimizer and scheduler
        # =========================
        criterion = nn.MSELoss(reduction="mean")

        optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

        # The scheduler is stepped once per batch, so it is sized in steps.
        total_steps = epochs * len(train_loader)
        scheduler, warmup_steps = _build_step_scheduler(optimizer, total_steps)

        logger.info("Loss: MSELoss(reduction='mean')")
        logger.info(f"Optimizer: Adam(lr={learning_rate}, betas=(0.9, 0.999))")
        logger.info("Scheduler: step-based warmup + cosine annealing")
        logger.info(f"Total training steps: {total_steps}")
        logger.info(f"Warmup steps: {warmup_steps}")

        # =========================
        # 8. Training loop
        # =========================
        best_val_loss = float("inf")
        global_step = 0
        start_time = time.time()

        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()

            train_loss, global_step = train_epoch(
                model=model,
                train_loader=train_loader,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                writer=writer,
                global_step=global_step,
            )

            val_loss = validate_epoch(
                model=model,
                val_loader=val_loader,
                criterion=criterion,
                device=device,
            )

            epoch_total_time = time.time() - epoch_start_time

            writer.add_scalar("train_loss", train_loss, epoch)
            writer.add_scalar("val_loss", val_loss, epoch)
            writer.add_scalar("epoch_total_time", epoch_total_time, epoch)

            logger.info(
                f"Epoch [{epoch}/{epochs}] | "
                f"train_loss={train_loss:.8f} | "
                f"val_loss={val_loss:.8f} | "
                f"epoch_total_time={epoch_total_time:.2f}s | "
                f"global_step={global_step}"
            )

            # Always keep the latest weights; keep the best ones separately.
            torch.save(model.state_dict(), run_dir / "last_dpn.pth")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), run_dir / "best_dpn.pth")
                logger.info(f"Best model updated, best_val_loss={best_val_loss:.8f}")

        # =========================
        # 9. Wrap up
        # =========================
        total_time = time.time() - start_time
        logger.info("Training finished")
        logger.info(f"Best validation loss: {best_val_loss:.8f}")
        logger.info(f"Total time: {total_time:.2f} seconds")
    except Exception:
        # Record the traceback in train.log as well as on the console.
        logger.exception("Training aborted by an unexpected error")
        raise
    finally:
        # Flush and close the TensorBoard event file even on failure or Ctrl-C.
        writer.close()
|
|
|
|
|
|
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()