111 lines
2.8 KiB
Python
111 lines
2.8 KiB
Python
import math
|
|
import torch
|
|
import random
|
|
import logging
|
|
import tomllib
|
|
import argparse
|
|
|
|
from pathlib import Path
|
|
from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
import numpy as np
|
|
import torch.distributed as dist
|
|
|
|
|
|
# Set fixed random seeds for reproducibility
|
|
def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    # Order: Python stdlib, NumPy, torch default generator, all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Turn off cuDNN benchmarking and request deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
|
|
|
|
# Build a logger that writes to both the console and a log file
|
|
def get_logger(name, log_file):
    """Return an INFO-level logger that writes to the console and to a file.

    The logger is reconfigured from scratch on every call: handlers already
    attached to it are detached *and closed*, then one fresh console handler
    and one fresh file handler are installed. Closing matters because a
    ``FileHandler`` that is only dropped from ``logger.handlers`` keeps its
    underlying file open, so repeated calls with the same ``name`` would
    otherwise leak one file descriptor each.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the log file; opened in append mode as UTF-8.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Do not forward records to ancestor loggers' handlers.
    logger.propagate = False

    # Iterate over a copy because removeHandler mutates logger.handlers.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    file_handler = logging.FileHandler(
        log_file,
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger
|
|
|
|
|
|
# Initialize DDP (distributed data parallel) training
|
|
def setup_distributed():
    """Join the NCCL process group and select this process's CUDA device.

    Returns:
        A tuple ``(local_rank, rank, world_size, device, is_main_process)``.
        ``device`` is the CUDA device indexed by ``local_rank`` and
        ``is_main_process`` is true only when the global rank is 0.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if torch.cuda.is_available():
        dist.init_process_group(backend="nccl")
    else:
        raise RuntimeError("DDP training requires CUDA.")

    gpu_index = dist.get_node_local_rank()
    global_rank = dist.get_rank()
    n_processes = dist.get_world_size()

    # Make the node-local GPU the current device for this process.
    torch.cuda.set_device(gpu_index)
    gpu = torch.device(f"cuda:{gpu_index}")

    return gpu_index, global_rank, n_processes, gpu, global_rank == 0
|
|
|
|
|
|
# Tear down DDP (distributed data parallel) training
|
|
def cleanup_distributed():
    """Destroy the default process group if one is currently initialized.

    Does nothing when ``torch.distributed`` is unavailable or no process
    group has been initialized, so it is safe to call unconditionally, for
    example from a ``finally`` block after a failed setup, or twice in a row.
    """
    # destroy_process_group() raises when there is no default group; a
    # cleanup helper should not replace the original error with its own.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
|
|
# Parse the command line and load the TOML config file it names
|
|
def get_hyperparams():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("config", help="Path to TOML config file")
|
|
args = parser.parse_args()
|
|
|
|
config_path = Path(args.config)
|
|
with config_path.open("rb") as f:
|
|
return tomllib.load(f), config_path
|
|
|
|
|
|
# Learning-rate scheduler: linear warmup followed by cosine annealing
|
|
def get_warmup_cosine_scheduler(
    optimizer,
    epochs,
    *,
    max_warmup_epochs=10,
    start_factor=0.1,
    eta_min_factor=0.0,
):
    """Build a ``LambdaLR`` with linear warmup followed by cosine annealing.

    The returned multiplier is applied to each parameter group's base
    learning rate. It rises linearly from ``start_factor`` to 1.0 over the
    warmup epochs, then follows a half cosine from 1.0 down to
    ``eta_min_factor`` over the remaining epochs, and stays at
    ``eta_min_factor`` if the scheduler is stepped past ``epochs``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        epochs: Total number of epochs. Warmup lasts
            ``min(epochs // 10, max_warmup_epochs)`` of them.
        max_warmup_epochs: Upper bound on the warmup length.
        start_factor: LR multiplier at epoch 0 when warmup is active.
        eta_min_factor: LR multiplier reached at the end of the cosine phase.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    warmup_epochs = min(epochs // 10, max_warmup_epochs)
    # LambdaLR evaluates the lambda at epoch 0 on construction, so the
    # divisor must be non-zero even for degenerate inputs such as epochs == 0.
    cosine_epochs = max(epochs - warmup_epochs, 1)

    def lr_lambda(current_epoch):
        # Linear warmup phase.
        if warmup_epochs > 0 and current_epoch < warmup_epochs:
            return start_factor + (1.0 - start_factor) * (
                current_epoch / warmup_epochs
            )

        # Cosine annealing phase. Clamp the position so the periodic cosine
        # cannot raise the LR again once training runs past `epochs`.
        cosine_epoch = min(current_epoch - warmup_epochs, cosine_epochs)
        return eta_min_factor + (1.0 - eta_min_factor) * 0.5 * (
            1.0 + math.cos(math.pi * cosine_epoch / cosine_epochs)
        )

    return LambdaLR(optimizer, lr_lambda=lr_lambda)
|