import os
import time
import argparse
import tomllib
from pathlib import Path

import numpy as np
from tqdm import tqdm

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP

def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
    model.train()
    train_loss = 0.0

    # Only the main process shows a progress bar, to keep multi-process output readable
    data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180) if rank in {-1, 0} else train_loader

    for data in data_iter:
        images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
        # Stack the per-sample objective parameters (mag, na, rix) into a (B, 3) tensor
        params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images, params)
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Per-rank sum of batch losses; the caller divides by len(train_loader)
    return train_loss

def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
    model.eval()
    val_loss = 0.0

    data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180) if rank in {-1, 0} else val_loader

    with torch.no_grad():
        for data in data_iter:
            images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()

    # Per-rank sum of batch losses; the caller divides by len(val_loader)
    return val_loss

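# Optional helper (a minimal sketch, not called anywhere in this script): the loss sums returned by
# train_epoch/valid_epoch are per-rank values, so rank 0 logs the loss of its own data shard only.
# If a globally averaged metric is wanted, the per-rank means could be reduced like this, assuming
# the default process group is already initialized:
def reduce_mean_across_ranks(value, world_size, device):
    tensor = torch.tensor([value], dtype=torch.float32, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # sum the per-rank values in place
    return (tensor / world_size).item()            # divide by world size for the global mean
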
def fit(rank, world_size, cfg):
    # Initialize the distributed process group
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    # MASTER_PORT must be a string even if the config stores the port as an integer
    os.environ['MASTER_PORT'] = str(cfg.get("master_port", "29500"))
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    if rank == 0:
        print_with_timestamp(f"Distributed initialized. World size: {world_size}")

    # Resolve hyperparameters from the config
    model_type = cfg["model_type"]
    dataset_type = cfg["dataset_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]

    # Load the datasets
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list)

    # Distributed samplers shard the data across ranks; shuffling is handled by the sampler,
    # which is why the DataLoaders below keep shuffle=False
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)

    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=val_sampler
    )

    if rank == 0:
        print_with_timestamp("Dataset Loaded (Distributed)")

    # Model selection: the suffix of model_type encodes the fusion depths, one digit per entry
    # (e.g. model_type "film234" yields fusion_depth_list [2, 3, 4])
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
    elif "cca" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
    elif "se" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list).to(device)
    elif "mmlp" in model_type:
        model = MOAFWithMMLP().to(device)
    else:
        model = MOAFNoFusion().to(device)

    # Optionally load pretrained parameters from the best checkpoint
    if checkpoint_load:
        if Path(f"ckpts/{model_type}_{dataset_type}_best_model.pt").exists():
            checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False)
            model.load_state_dict(checkpoint["model_state_dict"])
            if rank == 0:
                print_with_timestamp("Model Checkpoint Params Loaded")
        else:
            if rank == 0:
                print_with_timestamp("Model Checkpoint Params Not Found")

    # Wrap the model with DDP
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    if rank == 0:
        print_with_timestamp("Model Wrapped with DDP")

    # Loss function, optimizer and learning-rate scheduler (linear warmup followed by cosine decay)
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    )

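    # Note (derived from the lambda above): LambdaLR multiplies the base lr by the returned factor,
    #   epoch <  warmup_epochs : factor = (epoch + 1) / warmup_epochs   (linear warmup)
    #   epoch >= warmup_epochs : factor = 0.5 * (1 + cos(pi * t))       (cosine decay to 0)
    # with t = (epoch - warmup_epochs) / (epochs - warmup_epochs).
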
    # Log the model graph to TensorBoard (rank 0 only); the dummy inputs mirror the real batch
    # format: 3-channel 384x384 images plus a (B, 3) parameter vector. Trace model.module, i.e.
    # the unwrapped model, since DDP wrappers are not traceable by add_graph.
    if rank == 0:
        tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{dataset_type}")
        dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
        tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
    else:
        tb_writer = None

    # Training loop
    best_val_loss = float('inf')

    if rank == 0:
        print_with_timestamp("Start training (DDP)")

    for epoch in range(epochs):
        # The sampler epoch must be set at the start of every epoch so that shuffling
        # produces a different ordering across epochs
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        start_time = time.time()
        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank) / len(train_loader)
        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank) / len(val_loader)
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        epoch_time = time.time() - start_time

        # Only rank 0 writes TensorBoard logs and checkpoints, to avoid duplicated I/O
        if rank == 0:
            tb_writer.add_scalar('Loss/train', avg_train_loss, epoch)
            tb_writer.add_scalar('Loss/val', avg_val_loss, epoch)
            tb_writer.add_scalar('LearningRate', current_lr, epoch)
            tb_writer.add_scalar('Time/epoch', epoch_time, epoch)

            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_dict = {
                    "epoch": epoch,
                    # Save module.state_dict() (unwrap the DDP-wrapped model via .module)
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{dataset_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")

    # Cleanup
    if tb_writer is not None:
        tb_writer.close()
    dist.destroy_process_group()
    if rank == 0:
        print_with_timestamp("Training completed and distributed cleaned up")

def parse_args_and_cfg():
    parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)
    return args, cfg

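# Illustrative config.toml layout (assumed, not part of this file): the key names follow what
# fit() reads, while the values are placeholders and must be adapted to the actual setup.
#
#   model_type = "film234"
#   dataset_type = "sim"
#   dataset_dir = "data/moaf"
#   batch_size = 16
#   num_workers = 4
#   lr = 0.001
#   epochs = 100
#   warmup_epochs = 5
#   train_objective_params_list = [[20, 0.45, 1.0]]   # structure depends on MOAFDataset
#   checkpoint_load = false
#   master_port = "29500"
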
def main():
    args, cfg = parse_args_and_cfg()
    # Auto-detect the number of available GPUs unless nproc_per_node is given
    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.")

    nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus
    if nproc > available_gpus:
        raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available")

    # Spawn nproc child processes; each one runs fit(rank, world_size, cfg) with its own rank
    mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True)

if __name__ == "__main__":
    main()
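# Example launch (script name illustrative):
#   python train_ddp.py config.toml --nproc_per_node 4
# Omitting --nproc_per_node spawns one process per visible GPU on the node.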