import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import tomllib  # stdlib in Python 3.11+
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE


def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
    """Run one training epoch and return the summed batch loss."""
    model.train()
    train_loss = 0.0
    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
        images = data["image"].to(device, non_blocking=True)
        labels = data["label"].to(device, non_blocking=True)
        # Stack the per-sample objective parameters (mag, na, rix) into a (B, 3) tensor
        params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(images, params)
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss


def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
    """Run one validation epoch and return the summed batch loss."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
            images = data["image"].to(device, non_blocking=True)
            labels = data["label"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()
    return val_loss


def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler,
        patience, model_type, output_type):
    best_val_loss = float("inf")
    patience_counter = 0
    # Requires the tensorboard package (pip install tensorboard)
    with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
        # Log the model graph so it can be inspected in TensorBoard
        dummy_input1 = torch.randn(5, 3, 384, 384).to(device)
        dummy_input2 = torch.randn(5, 3).to(device)
        writer.add_graph(model, (dummy_input1, dummy_input2))

        for epoch in range(epochs):
            # Train and validate, averaging the summed losses over batches
            start_time = time.time()
            avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
            avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
            current_lr = optimizer.param_groups[0]["lr"]
            scheduler.step()
            epoch_time = time.time() - start_time

            # Log metrics to TensorBoard
            writer.add_scalar("Loss/train", avg_train_loss, epoch)
            writer.add_scalar("Loss/val", avg_val_loss, epoch)
            writer.add_scalar("LearningRate", current_lr, epoch)
            writer.add_scalar("Time/epoch", epoch_time, epoch)
            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")

            # Checkpoint on improvement; otherwise count toward early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                save_dict = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss,
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
            else:
                patience_counter += 1
                if patience_counter > patience:
                    print_with_timestamp(f"Early stopping after {epoch+1} epochs")
                    break


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
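    # For reference, a config.toml consumed below might look like the following.
    # This is a sketch: the keys match exactly what main() reads, but every value
    # (and the structure of train_objective_params_list) is an illustrative
    # assumption, not taken from the original project.
    #
    #   model_type = "film1234"    # fusion variant name + fusion-depth digits
    #   output_type = "defocus"    # hypothetical output label type
    #   dataset_dir = "datasets/moaf"
    #   batch_size = 32
    #   num_workers = 4
    #   lr = 1e-3
    #   patience = 10
    #   epochs = 100
    #   warmup_epochs = 5
    #   train_objective_params_list = [[20.0, 0.4, 1.0]]  # assumed (mag, na, rix) triples
    #   checkpoint_load = false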
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Read hyperparameters from the config
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    patience = int(cfg["patience"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print_with_timestamp(f"Using device {device}")

    # Build datasets and data loaders
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
    print_with_timestamp("Datasets created")
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        persistent_workers=True,
    )
    print_with_timestamp("Data loaders ready")

    # Select the model variant; the digits trailing the variant name encode the
    # fusion depths (e.g. "film12" fuses at depths 1 and 2). startswith() is used
    # because the slices below assume the variant name is a prefix.
    if model_type.startswith("film"):
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
    elif model_type.startswith("cca"):
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
    elif model_type.startswith("se"):
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list).to(device)
    else:
        model = MOAFNoFusion().to(device)
    print_with_timestamp("Model Loaded")

    # Optionally resume from the best checkpoint for this model/output combination
    if checkpoint_load:
        ckpt_path = Path(f"ckpts/{model_type}_{output_type}_best_model.pt")
        if ckpt_path.exists():
            checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint["model_state_dict"])
            print_with_timestamp("Model checkpoint params loaded")
        else:
            print_with_timestamp("Model checkpoint not found; training from scratch")

    # Loss function, optimizer, and LR schedule: linear warmup for warmup_epochs,
    # then cosine decay over the remaining epochs
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))),
    )

    print_with_timestamp("Start training")
    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler,
        patience, model_type, output_type)
    print_with_timestamp("Training completed!")


if __name__ == "__main__":
    main()
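# Usage sketch (the file name train.py is an assumption, not from the original):
#   python train.py config.toml
# TensorBoard logs land in runs/<model_type>_<output_type>/ and can be viewed with:
#   tensorboard --logdir runs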