import math import os import torch import random import logging import tomllib import argparse from pathlib import Path from torch.optim.lr_scheduler import LambdaLR import numpy as np import torch.distributed as dist # 设置固定的随机数种子 def set_seeds(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # 获取日志句柄 def get_logger(name, log_file): logger = logging.getLogger(name) logger.setLevel(logging.INFO) logger.propagate = False logger.handlers.clear() formatter = logging.Formatter( fmt="%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_handler.setFormatter(formatter) file_handler = logging.FileHandler( log_file, mode="a", encoding="utf-8", ) file_handler.setLevel(logging.INFO) file_handler.setFormatter(formatter) logger.addHandler(console_handler) logger.addHandler(file_handler) return logger # 初始化 DDP 并行 def setup_distributed(): if not torch.cuda.is_available(): raise RuntimeError("DDP training requires CUDA.") dist.init_process_group(backend="nccl") local_rank = int(os.environ["LOCAL_RANK"]) rank = dist.get_rank() world_size = dist.get_world_size() torch.cuda.set_device(local_rank) device = torch.device("cuda", local_rank) is_main_process = (rank == 0) return local_rank, rank, world_size, device, is_main_process # 释放 DDP 并行 def cleanup_distributed(): if dist.is_initialized(): dist.destroy_process_group() # 命令行参数解析配置文件 def get_hyperparams(): parser = argparse.ArgumentParser() parser.add_argument("config", help="Path to TOML config file") args = parser.parse_args() config_path = Path(args.config) with config_path.open("rb") as f: return tomllib.load(f), config_path # 线性预热与余弦退火调度器 def get_warmup_cosine_scheduler(optimizer, epochs): max_warmup_epochs, start_factor, eta_min_factor = 10, 0.1, 0.0 warmup_epochs = min(epochs // 10, max_warmup_epochs) cosine_epochs = epochs - warmup_epochs def lr_lambda(current_epoch): # 线性预热阶段 if warmup_epochs > 0 and current_epoch < warmup_epochs: return start_factor + (1.0 - start_factor) * ( current_epoch / warmup_epochs ) # 余弦退火阶段 cosine_epoch = current_epoch - warmup_epochs return eta_min_factor + (1.0 - eta_min_factor) * 0.5 * ( 1.0 + math.cos(math.pi * cosine_epoch / cosine_epochs) ) return LambdaLR(optimizer, lr_lambda=lr_lambda)