218 lines
6.8 KiB
Python
218 lines
6.8 KiB
Python
import shutil
|
|
import time
|
|
from tqdm import tqdm
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.optim import Adam
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from models import RINet
|
|
import old_datasets as dataset_F
|
|
from datasets import RIN_Dataset, RINPairTransform
|
|
import utils
|
|
|
|
|
|
# Train for one epoch.
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch and return the sample-weighted mean loss.

    Args:
        model: Network to optimize; put into training mode with ``model.train()``.
        loader: Iterable of ``(images, labels)`` batches. ``images`` must be a
            tensor whose first dimension is the batch size.
        criterion: Loss called as ``criterion(outputs, labels)`` on the
            flattened (1-D) model outputs and labels.
        optimizer: Optimizer whose gradients are reset and stepped once per batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: ``sum(loss * batch_size) / total_samples`` over the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples (previously this surfaced
            as an unexplained ZeroDivisionError).
    """
    model.train()

    running_loss = 0.0
    total_samples = 0

    for images, labels in tqdm(
        loader,
        desc="Train:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
    ):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).reshape(-1)

        optimizer.zero_grad(set_to_none=True)

        # Flatten so predictions and targets align element-wise; reshape()
        # (unlike view()) also works when the output is non-contiguous.
        outputs = model(images).reshape(-1)
        # Losses such as BCELoss require float targets with the same dtype as
        # the predictions; the cast is a no-op when labels already match.
        loss = criterion(outputs, labels.to(outputs.dtype))

        loss.backward()
        optimizer.step()

        # Assuming the criterion averages over the batch (e.g. the default
        # reduction="mean"), weight by batch size so a short final batch is
        # not over-counted in the epoch mean.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    return running_loss / total_samples


# Validate for one epoch.
@torch.no_grad()
def valid_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return the sample-weighted mean loss.

    Runs under ``torch.no_grad()`` and puts the model into evaluation mode
    with ``model.eval()``; no parameters are updated.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(images, labels)`` batches. ``images`` must be a
            tensor whose first dimension is the batch size.
        criterion: Loss called as ``criterion(outputs, labels)`` on the
            flattened (1-D) model outputs and labels.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: ``sum(loss * batch_size) / total_samples`` over the loader.

    Raises:
        ValueError: If ``loader`` yields no samples (previously this surfaced
            as an unexplained ZeroDivisionError).
    """
    model.eval()

    running_loss = 0.0
    total_samples = 0

    for images, labels in tqdm(
        loader,
        desc="Valid:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
    ):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).reshape(-1)

        # Flatten so predictions and targets align element-wise; reshape()
        # (unlike view()) also works when the output is non-contiguous.
        outputs = model(images).reshape(-1)
        # Losses such as BCELoss require float targets with the same dtype as
        # the predictions; the cast is a no-op when labels already match.
        loss = criterion(outputs, labels.to(outputs.dtype))

        # Assuming the criterion averages over the batch (e.g. the default
        # reduction="mean"), weight by batch size so a short final batch is
        # not over-counted in the mean.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("valid_epoch: loader yielded no samples")

    return running_loss / total_samples


# Main training entry point.
def _make_loader(dataset, batch_size, num_workers, shuffle, pin_memory):
    """Build a DataLoader with the settings shared by the train and valid splits.

    ``persistent_workers`` is enabled only when ``num_workers > 0`` because
    DataLoader rejects it for single-process loading.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),
    )


def main():
    """Train RINet using the hyperparameters returned by ``utils.get_hyperparams()``.

    Steps:
        1. Read the config and extract the hyperparameters.
        2. Create a timestamped run directory under the current working
           directory and copy the config file into it.
        3. Set up the file logger, TensorBoard writer, random seeds and device
           (first GPU if CUDA is available, otherwise CPU).
        4. Build the train/valid datasets and loaders from the first entry of
           ``config["xlsx_files"]``.
        5. Build the model (optionally loading ``config["init_weight"]``),
           BCELoss, Adam and the warmup + cosine LR scheduler.
        6. Train for ``config["epochs"]`` epochs. ``best_model.pt`` is saved
           whenever the validation loss improves. ``last_model.pt`` is saved
           when the loop finishes, is interrupted with Ctrl-C, or fails.
    """
    # ========== 1 Config file and hyperparameters ==========
    config, config_path = utils.get_hyperparams()

    XLSX_FILES = config["xlsx_files"]
    BATCH_SIZE = config["batch_size"]
    NUM_WORKERS = config["num_workers"]
    LEARNING_RATE = config["learning_rate"]
    NUM_EPOCHS = config["epochs"]
    SEED = config["seed"]
    INIT_WEIGHT_PATH = config["init_weight"]

    # ========== 2 Output directory for this run ==========
    run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_RIN")
    run_dir = Path.cwd() / run_name
    # exist_ok=False: fail rather than overwrite a run started in the same second.
    run_dir.mkdir(parents=True, exist_ok=False)
    # Keep a copy of the exact config next to the outputs it produced.
    shutil.copy2(config_path, run_dir / config_path.name)

    # ========== 3 Logging, TensorBoard, random seeds and device ==========
    logger = utils.get_logger(__name__, run_dir / "train.log")
    writer = SummaryWriter(str(run_dir / "run"))
    utils.set_seeds(SEED)

    # Use the first GPU when available. Otherwise fall back to CPU instead of
    # crashing later at model.to(device).
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    logger.info(f"Config path: {config_path}")
    logger.info(f"Loaded config: {config}")
    logger.info(f"Run directory: {run_dir}")
    logger.info(f"Using device: {device}")
    if not use_cuda:
        logger.warning("CUDA is not available; training on CPU will be slow")

    # ========== 4 Data and loaders ==========
    # NOTE(review): only the first spreadsheet in xlsx_files is used for both splits.
    train_image_path_list, train_patch_effective_list = dataset_F.get_RINet_data(
        XLSX_FILES[0], "train"
    )
    valid_image_path_list, valid_patch_effective_list = dataset_F.get_RINet_data(
        XLSX_FILES[0], "val"
    )

    train_transform = RINPairTransform(train=True, image_size=512)
    valid_transform = RINPairTransform(train=False, image_size=512)

    train_set = RIN_Dataset(
        train_image_path_list,
        train_patch_effective_list,
        train_transform,
    )
    valid_set = RIN_Dataset(
        valid_image_path_list,
        valid_patch_effective_list,
        valid_transform,
    )

    # Pinned host memory only speeds up host-to-GPU copies, so enable it
    # only when training on CUDA.
    train_loader = _make_loader(
        train_set, BATCH_SIZE, NUM_WORKERS, shuffle=True, pin_memory=use_cuda
    )
    valid_loader = _make_loader(
        valid_set, BATCH_SIZE, NUM_WORKERS, shuffle=False, pin_memory=use_cuda
    )

    logger.info(f"Train dataset size: {len(train_set)}")
    logger.info(f"Val dataset size: {len(valid_set)}")
    logger.info(f"Train steps per epoch: {len(train_loader)}")

    # ========== 5 Model, loss, optimizer, scheduler ==========
    model = RINet().to(device)
    if INIT_WEIGHT_PATH:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True on PyTorch versions that support it.
        state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
    else:
        logger.info("Training from scratch")

    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)

    logger.info("Loss: BCELoss()")
    logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
    logger.info("Scheduler: epoch-based warmup + cosine annealing")

    # ========== 6 Training loop ==========
    logger.info("START TRAINING")
    best_valid_loss = float("inf")

    try:
        for epoch in range(1, NUM_EPOCHS + 1):
            epoch_start_time = time.time()
            train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
            valid_loss = valid_epoch(model, valid_loader, criterion, device)
            # Read before scheduler.step() so the logged value is the
            # learning rate that was used for this epoch.
            epoch_lr = optimizer.param_groups[0]["lr"]
            scheduler.step()
            epoch_time_cost = time.time() - epoch_start_time

            # Checkpoint whenever the validation loss improves.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), run_dir / "best_model.pt")
                logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")

            # Per-epoch log line and TensorBoard scalars.
            logger.info(
                f"Epoch [{epoch}/{NUM_EPOCHS}] "
                f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
                f"Best Valid Loss: {best_valid_loss:.4f} | "
                f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
                f"Epoch Learning Rate: {epoch_lr:.6e}"
            )

            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/valid", valid_loss, epoch)
            writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
            writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
            writer.add_scalar("Time/learning_rate", epoch_lr, epoch)

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")

    except Exception:
        # Record the traceback in train.log before propagating it; the
        # finally block below still saves the last weights.
        logger.exception("Training failed")
        raise

    finally:
        torch.save(model.state_dict(), run_dir / "last_model.pt")
        logger.info("Last model saved")

        writer.close()
        logger.info("TensorBoard writer closed")

    logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()