# MOAF/MOAFTrain.py
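"""Training entry point for MOAF models.

Reads hyperparameters from a TOML config (tomllib requires Python 3.11+),
builds one of the parameter-fusion variants (FiLM, channel cross-attention,
SE, or no fusion), trains with warmup + cosine LR scheduling, logs to
TensorBoard, and checkpoints the model with the best validation loss.
"""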
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import tomllib
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE


def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
    model.train()
    train_loss = 0.0
    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60):
        images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
        # Stack the per-sample objective parameters (mag, na, rix) into a (B, 3) tensor
        params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(images, params)
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss


def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60):
            images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()
    return val_loss


def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type):
    best_val_loss = float('inf')
    # Requires tensorboard: pip install tensorboard
    with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
        # Log the model graph to TensorBoard using dummy inputs
        # (a batch of 5 images at 3x384x384 plus the 3 objective parameters)
        dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
        writer.add_graph(model, (dummy_input1, dummy_input2))
        for epoch in range(epochs):
            # Train / validate
            start_time = time.time()
            avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
            avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step()
            epoch_time = time.time() - start_time
            # Log metrics
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('LearningRate', current_lr, epoch)
            writer.add_scalar('Time/epoch', epoch_time, epoch)
            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")
            # Checkpoint whenever the validation loss improves
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_dict = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss,
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
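

# The checkpoint saved above also carries optimizer and scheduler state. main()
# below only warm-starts the weights; a full resume could look like this sketch
# (hypothetical snippet, assuming the same model/optimizer/scheduler objects):
#
#   ckpt = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt",
#                     map_location=device, weights_only=False)
#   model.load_state_dict(ckpt["model_state_dict"])
#   optimizer.load_state_dict(ckpt["optimizer_state_dict"])
#   scheduler.load_state_dict(ckpt["scheduler_state_dict"])
#   start_epoch = ckpt["epoch"] + 1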


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)
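    # A minimal config sketch with illustrative values (only the key names are
    # taken from the reads below; every value here is a placeholder guess):
    #
    #   model_type = "film1234"
    #   output_type = "reg"
    #   dataset_dir = "datasets/moaf"
    #   batch_size = 32
    #   num_workers = 4
    #   lr = 1e-4
    #   epochs = 100
    #   warmup_epochs = 5
    #   train_objective_params_list = [[20.0, 0.7, 1.0]]  # structure defined by MOAFDataset
    #   checkpoint_load = false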
    # Read hyperparameters
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print_with_timestamp(f"Using device {device}")
    # Load datasets
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
    print_with_timestamp("Dataset Done")
    # persistent_workers=True requires num_workers > 0
    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=True, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True
    )
    print_with_timestamp("Dataset Loaded")
    # Model selection: digits after the prefix form fusion_depth_list,
    # e.g. "film124" -> [1, 2, 4]. Matching is by substring, so model_type
    # names must not accidentally contain "film", "cca", or "se".
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
    elif "cca" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
    elif "se" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list).to(device)
    else:
        model = MOAFNoFusion().to(device)
    print_with_timestamp("Model Loaded")
    # Optionally warm-start from a previously saved best checkpoint
    if checkpoint_load:
        if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
            checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
            # Only the model weights are restored; the saved optimizer/scheduler states are not used here
            model.load_state_dict(checkpoint["model_state_dict"])
            print_with_timestamp("Model Checkpoint Params Loaded")
        else:
            print_with_timestamp("Model Checkpoint Params Not Found")
    # Loss function, optimizer, and learning-rate scheduler
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Linear warmup for the first warmup_epochs, then cosine annealing to zero
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    )
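    # Sanity check of the schedule above with illustrative numbers (assuming
    # lr = 1e-4, epochs = 100, warmup_epochs = 5; not values from the repo):
    #   epoch 0  -> 1e-4 * 1/5 = 2e-5   (warmup ramp)
    #   epoch 4  -> 1e-4 * 5/5 = 1e-4   (peak, end of warmup)
    #   epoch 52 -> ~5e-5, cosine factor 0.5*(1 + cos(pi*47/95)) ~ 0.5
    #   epoch 99 -> ~0,    cosine factor 0.5*(1 + cos(pi*94/95)) ~ 0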
    print_with_timestamp("Start training")
    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type)
    print_with_timestamp("Training completed!")


if __name__ == "__main__":
    main()
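
# Usage sketch (paths are assumptions based on this file's location):
#   python MOAF/MOAFTrain.py config.toml
# TensorBoard logs land in runs/<model_type>_<output_type>; view them with:
#   tensorboard --logdir runs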