# SparseFocus/train_rin.py
#
# 218 lines, 6.8 KiB, Python

import shutil
import time
from tqdm import tqdm
from datetime import datetime
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models import RINet
import old_datasets as dataset_F
from datasets import RIN_Dataset, RINPairTransform
import utils
# Train for one epoch.
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch and return the sample-weighted mean loss.

    For every ``(images, labels)`` batch the tensors are moved to ``device``,
    the model output and the labels are flattened to 1-D, and
    ``criterion(outputs, labels)`` is back-propagated before an
    ``optimizer.step()``.

    Args:
        model: Network to train; switched to ``train()`` mode here.
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        criterion: Loss module called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        The per-batch losses averaged over the epoch, each weighted by its
        batch size (``images.size(0)``).

    Raises:
        ValueError: If ``loader`` yields no samples, so no average exists.
    """
    model.train()
    running_loss = 0.0
    total_samples = 0
    for images, labels in tqdm(
        loader,
        desc="Train:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
    ):
        images = images.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        # reshape (unlike view) also works on non-contiguous tensors.
        outputs = model(images).reshape(-1)
        # Cast targets to the prediction dtype: BCE-style losses raise on
        # integer or float64 targets paired with float32 predictions.
        labels = labels.to(device, dtype=outputs.dtype, non_blocking=True).reshape(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        this_batch_size = images.size(0)
        running_loss += loss.item() * this_batch_size
        total_samples += this_batch_size
    if total_samples == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return running_loss / total_samples
# Validate for one epoch.
@torch.no_grad()
def valid_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return the sample-weighted mean loss.

    Runs under ``torch.no_grad()`` with the model in ``eval()`` mode. The
    model output and the labels are flattened to 1-D before
    ``criterion(outputs, labels)`` is evaluated.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode here.
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        criterion: Loss module called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        The per-batch losses averaged over the epoch, each weighted by its
        batch size (``images.size(0)``).

    Raises:
        ValueError: If ``loader`` yields no samples, so no average exists.
    """
    model.eval()
    running_loss = 0.0
    total_samples = 0
    for images, labels in tqdm(
        loader,
        desc="Valid:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
    ):
        images = images.to(device, non_blocking=True)
        # reshape (unlike view) also works on non-contiguous tensors.
        outputs = model(images).reshape(-1)
        # Cast targets to the prediction dtype: BCE-style losses raise on
        # integer or float64 targets paired with float32 predictions.
        labels = labels.to(device, dtype=outputs.dtype, non_blocking=True).reshape(-1)
        loss = criterion(outputs, labels)
        this_batch_size = images.size(0)
        running_loss += loss.item() * this_batch_size
        total_samples += this_batch_size
    if total_samples == 0:
        raise ValueError("valid_epoch: loader yielded no samples")
    return running_loss / total_samples
# Build a DataLoader with the settings shared by the train and val splits.
def _make_loader(dataset, batch_size, num_workers, *, shuffle, pin_memory):
    """Wrap ``dataset`` in a DataLoader using the script's common settings.

    ``persistent_workers`` is only enabled when worker processes are used,
    because DataLoader rejects it together with ``num_workers == 0``.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),
    )


# Main training function.
def main():
    """Train RINet with the config from ``utils.get_hyperparams()``.

    Creates a timestamped run directory under the current working directory
    containing a copy of the config file, ``train.log`` and TensorBoard logs
    (``run/``). It builds train/val loaders from the first entry of
    ``xlsx_files`` and trains for ``epochs`` epochs. The weights with the
    lowest validation loss so far are written to ``best_model.pt``. The
    final weights are written to ``last_model.pt``, including when training
    is interrupted with Ctrl-C.

    Raises:
        ValueError: If the train or val split contains no samples.
    """
    # ========== 1 Config file and hyperparameters ==========
    config, config_path = utils.get_hyperparams()
    XLSX_FILES = config["xlsx_files"]
    BATCH_SIZE = config["batch_size"]
    NUM_WORKERS = config["num_workers"]
    LEARNING_RATE = config["learning_rate"]
    NUM_EPOCHS = config["epochs"]
    SEED = config["seed"]
    INIT_WEIGHT_PATH = config["init_weight"]

    # ========== 2 Create the output directory ==========
    # exist_ok=False raises FileExistsError instead of reusing an existing
    # directory (e.g. a run started within the same second).
    run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_rin")
    run_dir = Path.cwd() / run_name
    run_dir.mkdir(parents=True, exist_ok=False)
    shutil.copy2(config_path, run_dir / config_path.name)

    # ========== 3 Logging, TensorBoard, random seed and device ==========
    logger = utils.get_logger(__name__, run_dir / "train.log")
    writer = SummaryWriter(str(run_dir / "run"))
    utils.set_seeds(SEED)
    # Fall back to CPU so the script still runs where CUDA is unavailable.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    logger.info(f"Config path: {config_path}")
    logger.info(f"Loaded config: {str(config)}")
    logger.info(f"Run directory: {run_dir}")
    logger.info(f"Using device: {device}")

    # ========== 4 Data and loaders ==========
    train_image_path_list, train_patch_effective_list = (
        dataset_F.get_RINet_data(XLSX_FILES[0], "train")
    )
    valid_image_path_list, valid_patch_effective_list = (
        dataset_F.get_RINet_data(XLSX_FILES[0], "val")
    )
    train_transform = RINPairTransform(train=True, image_size=512)
    valid_transform = RINPairTransform(train=False, image_size=512)
    train_set = RIN_Dataset(
        train_image_path_list,
        train_patch_effective_list,
        train_transform,
    )
    valid_set = RIN_Dataset(
        valid_image_path_list,
        valid_patch_effective_list,
        valid_transform,
    )
    # Pinned host memory only helps host-to-GPU copies.
    train_loader = _make_loader(
        train_set, BATCH_SIZE, NUM_WORKERS, shuffle=True, pin_memory=use_cuda
    )
    valid_loader = _make_loader(
        valid_set, BATCH_SIZE, NUM_WORKERS, shuffle=False, pin_memory=use_cuda
    )
    logger.info(f"Train dataset size: {len(train_set)}")
    logger.info(f"Val dataset size: {len(valid_set)}")
    # Fail fast instead of dividing by zero after a full epoch.
    if len(train_set) == 0 or len(valid_set) == 0:
        raise ValueError(
            f"Empty dataset split (train={len(train_set)}, "
            f"val={len(valid_set)}) built from {XLSX_FILES[0]}"
        )
    logger.info(f"Train steps per epoch: {len(train_loader)}")

    # ========== 5 Model, loss, optimizer, scheduler ==========
    model = RINet().to(device)
    if INIT_WEIGHT_PATH:
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources (newer torch versions accept weights_only=True).
        state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
    else:
        logger.info("Training from scratch")
    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)
    logger.info("Loss: BCELoss()")
    logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
    logger.info("Scheduler: epoch-based warmup + cosine annealing")

    # ========== 6 Training ==========
    logger.info("START TRAINING")
    best_valid_loss = float("inf")
    try:
        for epoch in range(1, NUM_EPOCHS + 1):
            epoch_start_time = time.time()
            train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
            valid_loss = valid_epoch(model, valid_loader, criterion, device)
            # Read before scheduler.step(): the rate used during this epoch.
            epoch_lr = optimizer.param_groups[0]["lr"]
            scheduler.step()
            epoch_time_cost = time.time() - epoch_start_time
            # Save the weights whenever validation loss improves.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), run_dir / "best_model.pt")
                logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")
            # Log file and TensorBoard.
            logger.info(
                f"Epoch [{epoch}/{NUM_EPOCHS}] "
                f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
                f"Best Valid Loss: {best_valid_loss:.4f} | "
                f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
                f"Epoch Learning Rate: {epoch_lr:.6e}"
            )
            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/valid", valid_loss, epoch)
            writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
            writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
            writer.add_scalar("Time/learning_rate", epoch_lr, epoch)
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    finally:
        # Nested finally: the writer is closed even if saving fails.
        try:
            torch.save(model.state_dict(), run_dir / "last_model.pt")
            logger.info("Last model saved")
        finally:
            writer.close()
            logger.info("TensorBoard writer closed")
    logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")


if __name__ == "__main__":
    main()