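"""Training script for MOAF models; all configuration is read from a TOML file.

Usage (script name illustrative):
    python train.py config.toml

A minimal example config — the key names are exactly those read in main()
below, but the values are illustrative placeholders only:

    model_type = "film1234"
    output_type = "reg"
    dataset_dir = "data/MOAF"
    batch_size = 32
    num_workers = 4
    lr = 1e-3
    epochs = 100
    warmup_epochs = 5
    train_objective_params_list = []    # dataset-specific; see MOAFDataset
    checkpoint_load = false
"""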
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import time
import tomllib
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE

def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
    """Run one training epoch and return the summed batch loss."""
    model.train()
    train_loss = 0.0

    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
        images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
        # Stack the objective parameters ("mag", "na", "rix": magnification,
        # numerical aperture, refractive index) into a (B, 3) conditioning tensor
        params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images, params)
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    return train_loss

def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
    """Run one validation epoch (no gradients) and return the summed batch loss."""
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
            images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()

    return val_loss

def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type):
    best_val_loss = float('inf')

    # Requires the tensorboard package (pip install tensorboard)
    with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
        # Log the model graph to TensorBoard using dummy inputs
        dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
        writer.add_graph(model, (dummy_input1, dummy_input2))
        for epoch in range(epochs):
            # Train / validate
            start_time = time.time()
            avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
            avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step()
            epoch_time = time.time() - start_time

            # Log metrics
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('LearningRate', current_lr, epoch)
            writer.add_scalar('Time/epoch', epoch_time, epoch)

            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")

            # Save a checkpoint whenever the validation loss improves
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_dict = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")

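# To monitor training, point TensorBoard at the log directory used in fit():
#   tensorboard --logdir runs
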
def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Read hyperparameters from the config
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print_with_timestamp(f"Using device {device}")

    # Load datasets
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
    print_with_timestamp("Dataset Done")

    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=True, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True
    )
    print_with_timestamp("Dataset Loaded")

    # Model selection: model_type names the fusion variant, and its trailing
    # digits select the fusion depths (e.g. "film1234" -> [1, 2, 3, 4])
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
    elif "cca" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
    elif "se" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list).to(device)
    else:
        model = MOAFNoFusion().to(device)
    print_with_timestamp("Model Loaded")

    # Optionally load pretrained parameters from a previous checkpoint
    if checkpoint_load:
        if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
            checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
            model.load_state_dict(checkpoint["model_state_dict"])
            print_with_timestamp("Model Checkpoint Params Loaded")
        else:
            print_with_timestamp("Model Checkpoint Params Not Found")

    # Loss function, optimizer, and LR scheduler (linear warmup, then cosine decay)
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    )

print_with_timestamp("Start trainning")
|
|
fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type)
|
|
print_with_timestamp("Training completed!")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
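
# Illustrative sketch (not part of the training flow): restoring the best
# checkpoint saved by fit(); "film1234"/"reg" stand in for your actual
# model_type/output_type values.
#
#   ckpt = torch.load("ckpts/film1234_reg_best_model.pt", map_location="cpu", weights_only=False)
#   model.load_state_dict(ckpt["model_state_dict"])
#   print_with_timestamp(f"Restored epoch {ckpt['epoch']} (val loss {ckpt['val_loss']:.6f})")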