# MOAF/MOAFTrainDDP.py — multi-GPU (DDP) training script for MOAF models.
import os
import time
import argparse
import tomllib
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
model.train()
train_loss = 0.0
data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]") if rank in {-1, 0} else train_loader
for data in data_iter:
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
optimizer.zero_grad()
outputs = model(images, params)
loss = loss_fn(outputs.squeeze(1), labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
return train_loss
def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
    """Evaluate ``model`` on ``val_loader`` and return the summed loss.

    Gradients are disabled; only rank 0 (or -1 in non-distributed runs)
    shows a tqdm progress bar. The caller averages the returned total
    over the number of batches.
    """
    model.eval()
    total_loss = 0.0
    iterator = val_loader
    if rank in {-1, 0}:
        iterator = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]")
    with torch.no_grad():
        for batch in iterator:
            images = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            # Objective parameters (magnification, NA, refractive index) as a (B, 3) tensor.
            params = torch.stack((batch["mag"], batch["na"], batch["rix"]), dim=1).to(device, non_blocking=True)
            predictions = model(images, params)
            total_loss += loss_fn(predictions.squeeze(1), labels).item()
    return total_loss
def _build_model(model_type, device):
    """Pick the MOAF variant encoded by ``model_type`` and move it to ``device``.

    The trailing digits of the type string encode the fusion depth list,
    e.g. ``"film123"`` -> ``MOAFWithFiLM([1, 2, 3])``. A string matching
    none of the fusion keywords yields the no-fusion baseline.
    """
    if "film" in model_type:
        return MOAFWithFiLM([int(ch) for ch in model_type[4:]]).to(device)
    if "cca" in model_type:
        return MOAFWithChannelCrossAttention([int(ch) for ch in model_type[3:]]).to(device)
    if "se" in model_type:
        return MOAFWithSE([int(ch) for ch in model_type[2:]]).to(device)
    return MOAFNoFusion().to(device)


def fit(rank, world_size, cfg):
    """Per-process DDP training entry point, spawned once per GPU by ``main``.

    Args:
        rank: index of this process (and its GPU) in ``[0, world_size)``.
        world_size: total number of spawned processes.
        cfg: configuration dict parsed from the TOML file.

    Side effects (rank 0 only): TensorBoard logs under ``runs/`` and the
    best checkpoint under ``ckpts/``.
    """
    # --- Distributed initialization ---
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    # TOML may parse the port as an integer; environment values must be str.
    os.environ['MASTER_PORT'] = str(cfg.get("master_port", "29500"))
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    if rank == 0:
        print_with_timestamp(f"Distributed initialized. World size: {world_size}")
    # --- Hyperparameters ---
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]
    # --- Datasets with distributed samplers ---
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
    # shuffle=False because shuffling is delegated to the DistributedSampler.
    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=val_sampler
    )
    if rank == 0:
        print_with_timestamp("Dataset Loaded (Distributed)")
    # --- Model selection ---
    model = _build_model(model_type, device)
    # --- Optional warm start from the best saved checkpoint ---
    if checkpoint_load:
        ckpt_path = Path(f"ckpts/{model_type}_{output_type}_best_model.pt")
        if ckpt_path.exists():
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints produced by this project.
            checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint["model_state_dict"])
            if rank == 0:  # gate like the other status prints to avoid one line per process
                print_with_timestamp("Model Checkpoint Params Loaded")
        elif rank == 0:
            print_with_timestamp("Model Checkpoint Params Not Exist")
    # --- Wrap with DDP for gradient synchronization ---
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    if rank == 0:
        print_with_timestamp("Model Wrapped with DDP")
    # --- Loss, optimizer, linear-warmup + cosine-decay schedule ---
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # max(1, ...) guards the divisions when warmup_epochs == 0 or
    # warmup_epochs == epochs; branch selection is unchanged otherwise.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda e: (e + 1) / max(1, warmup_epochs) if e < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (e - warmup_epochs) / max(1, epochs - warmup_epochs)))
    )
    # --- TensorBoard: log the model graph once (rank 0 only) ---
    if rank == 0:
        tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
        # Dummy inputs: a batch of 384x384 RGB images plus the 3 objective
        # parameters; trace the unwrapped module, not the DDP wrapper.
        dummy_input1 = torch.randn(5, 3, 384, 384).to(device)
        dummy_input2 = torch.randn(5, 3).to(device)
        tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
    else:
        tb_writer = None
    # --- Training loop ---
    best_val_loss = float('inf')
    if rank == 0:
        print_with_timestamp("Start training (DDP)")
    for epoch in range(epochs):
        # set_epoch must be called every epoch so the samplers reshuffle
        # with a different seed across epochs.
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        start_time = time.time()
        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank) / len(train_loader)
        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank) / len(val_loader)
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        epoch_time = time.time() - start_time
        # Only rank 0 writes TensorBoard and checkpoints, avoiding duplicate IO.
        if rank == 0:
            tb_writer.add_scalar('Loss/train', avg_train_loss, epoch)
            tb_writer.add_scalar('Loss/val', avg_val_loss, epoch)
            tb_writer.add_scalar('LearningRate', current_lr, epoch)
            tb_writer.add_scalar('Time/epoch', epoch_time, epoch)
            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_dict = {
                    "epoch": epoch,
                    # Save the unwrapped module's weights so the checkpoint
                    # loads without a DDP wrapper.
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
    # --- Cleanup ---
    if tb_writer is not None:
        tb_writer.close()
    dist.destroy_process_group()
    if rank == 0:
        print_with_timestamp("Training completed and distributed cleaned up")
def parse_args_and_cfg():
parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.")
parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node")
args = parser.parse_args()
with open(args.config, "rb") as f:
cfg = tomllib.load(f)
return args, cfg
def main():
    """Entry point: resolve the process count and spawn one ``fit`` per GPU."""
    args, cfg = parse_args_and_cfg()
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.")
    # Default to every visible GPU unless --nproc_per_node overrides it.
    nproc = gpu_count if args.nproc_per_node is None else args.nproc_per_node
    if nproc > gpu_count:
        raise ValueError(f"Requested {nproc} processes but only {gpu_count} GPUs available")
    # Each spawned child runs fit(rank, world_size, cfg) with rank 0..nproc-1.
    mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True)


if __name__ == "__main__":
    main()