MOAF/MOAFTrainDDP.py

import os
import time
import argparse
import tomllib
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
    model.train()
    train_loss = 0.0
    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
        images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
        params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(images, params)
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss
def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
            images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()
    return val_loss
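# Note: train_epoch and valid_epoch return the loss summed over this rank's batches only;
# fit() divides by len(loader) to obtain a per-rank average. The losses are not all-reduced
# across ranks, so the values logged by rank 0 describe only rank 0's shard of the data.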
def fit(rank, world_size, cfg):
"""
每个进程运行的主函数(单卡)。
rank: 该进程的全局 rank0 ~ world_size-1
world_size: 进程总数(通常等于可用 GPU 数)
cfg: 从 toml 读取的配置字典
"""
    # -------- init distributed env --------
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    if rank == 0:
        print_with_timestamp(f"Distributed initialized. World size: {world_size}")
    # -------- parse hyperparams from cfg --------
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    patience = int(cfg["patience"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]
    # -------- datasets & distributed sampler --------
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
    # DistributedSampler ensures each process sees a unique subset each epoch
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
    # When using DistributedSampler, do not shuffle at DataLoader level
    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=val_sampler
    )
    if rank == 0:
        print_with_timestamp("Dataset Loaded (Distributed)")
    # -------- model creation --------
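    # model_type encodes the fusion variant plus the fusion stage indices as trailing digits;
    # e.g. a hypothetical "film123" would select MOAFWithFiLM with fusion_depth_list [1, 2, 3].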
if "film" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[4:]]
model = MOAFWithFiLM(fusion_depth_list).to(device)
elif "cca" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[3:]]
model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
elif "se" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[2:]]
model = MOAFWithSE(fusion_depth_list).to(device)
else:
model = MOAFNoFusion().to(device)
    # Optionally load pretrained parameters from a previously saved checkpoint
    if checkpoint_load:
        checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        print_with_timestamp("Model Checkpoint Params Loaded")
    # wrap with DDP. device_ids ensures single-GPU per process
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    if rank == 0:
        print_with_timestamp("Model Wrapped with DDP")
    # -------- loss / optimizer / scheduler --------
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    )
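    # Illustration of the multiplier produced by lr_lambda above (hypothetical values,
    # warmup_epochs=5, epochs=50; the real numbers come from the TOML config):
    #   epoch 0..4 -> 0.2, 0.4, 0.6, 0.8, 1.0          (linear warmup)
    #   epoch 5    -> 0.5 * (1 + cos(0))        = 1.0  (cosine decay starts at full LR)
    #   epoch 27   -> 0.5 * (1 + cos(pi*22/45)) ~ 0.52
    #   epoch 49   -> 0.5 * (1 + cos(pi*44/45)) ~ 0.001 (nearly zero at the final epoch)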
    # -------- TensorBoard & checkpoint only on rank 0 --------
    if rank == 0:
        tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
        # tensorboard graph: use a small dummy input placed on correct device
        dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
        tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
    else:
        tb_writer = None
    # -------- training loop with early stopping (only rank 0 saves checkpoints) --------
    best_val_loss = float('inf')
    patience_counter = 0
    if rank == 0:
        print_with_timestamp("Start training (DDP)")
    for epoch in range(epochs):
        # The sampler epoch must be set at the start of every epoch so that shuffling
        # produces a different ordering across epochs
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        start_time = time.time()
        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        epoch_time = time.time() - start_time
        # Only rank 0 writes TensorBoard logs and saves checkpoints, to avoid duplicated IO
        should_stop = torch.zeros(1, device=device)
        if rank == 0:
            tb_writer.add_scalar('Loss/train', avg_train_loss, epoch)
            tb_writer.add_scalar('Loss/val', avg_val_loss, epoch)
            tb_writer.add_scalar('LearningRate', current_lr, epoch)
            tb_writer.add_scalar('Time/epoch', epoch_time, epoch)
            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                save_dict = {
                    "epoch": epoch,
                    # save module.state_dict(): with DDP, the underlying model lives in .module
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
            else:
                patience_counter += 1
                if patience_counter > patience:
                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
                    should_stop += 1
        # Broadcast rank 0's early-stopping decision so every rank leaves the loop together;
        # breaking on rank 0 alone would leave the other ranks blocked in their next all-reduce
        dist.broadcast(should_stop, src=0)
        if should_stop.item() > 0:
            break
    # cleanup
    if tb_writer is not None:
        tb_writer.close()
    dist.destroy_process_group()
    if rank == 0:
        print_with_timestamp("Training completed and distributed cleaned up")
def parse_args_and_cfg():
    parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)
    return args, cfg
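# A minimal sketch of the TOML layout this script reads. The key names are exactly the
# ones looked up in fit(); the values below are illustrative assumptions, not a tested
# configuration (train_objective_params_list is passed straight to MOAFDataset, so its
# format depends on that class):
#
#   model_type = "film123"
#   output_type = "regression"
#   dataset_dir = "/path/to/dataset"
#   batch_size = 32
#   num_workers = 8
#   lr = 1e-4
#   patience = 10
#   epochs = 50
#   warmup_epochs = 5
#   train_objective_params_list = []
#   checkpoint_load = false
#   master_port = "29500"        # optional; defaults to "29500"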
def main():
    args, cfg = parse_args_and_cfg()
    # Auto-detect the number of available GPUs unless the user specifies --nproc_per_node
    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.")
    nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus
    if nproc > available_gpus:
        raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available")
    # Launch nproc child processes with spawn; mp.spawn calls fit(rank, world_size, cfg) in each
    # child, supplying a different rank as the first argument automatically
    mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True)
if __name__ == "__main__":
    main()
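# Typical invocation (a sketch; single node, one process per GPU):
#   python MOAF/MOAFTrainDDP.py config.toml --nproc_per_node 2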