SparseFocus/utils.py

113 lines
2.9 KiB
Python

import math
import os
import torch
import random
import logging
import tomllib
import argparse
from pathlib import Path
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
import torch.distributed as dist
def set_seeds(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``.

    Also forces cuDNN to pick deterministic kernels and turns off its
    benchmark-mode autotuner, trading some speed for run-to-run
    reproducibility.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def get_logger(name, log_file):
    """Return a named INFO-level logger that writes to the console and ``log_file``.

    The logger is reconfigured from scratch on every call. Handlers already
    attached to it are closed and removed. Then two handlers are attached:
    one console ``StreamHandler`` (stderr by default) and one UTF-8
    ``FileHandler`` in append mode. Both use the same
    ``time | level | message`` format. Propagation to ancestor loggers is
    disabled so records are not emitted twice via the root logger.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the log file; it is opened in append mode.

    Returns:
        The configured ``logging.Logger`` (same object for the same ``name``).

    Raises:
        OSError: If ``log_file`` cannot be opened (e.g. missing directory).
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    # Close stale handlers before dropping them: just clearing the list would
    # leave a previous FileHandler's file descriptor open on every re-call.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    file_handler = logging.FileHandler(
        log_file,
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    return logger
def setup_distributed():
    """Initialise an NCCL process group and bind this process to its GPU.

    Reads the ``LOCAL_RANK`` environment variable. It is presumably
    populated by the launcher (e.g. ``torchrun``); confirm against the
    launch script.

    Returns:
        Tuple ``(local_rank, rank, world_size, device, is_main_process)``,
        where ``device`` is ``torch.device("cuda", local_rank)`` and
        ``is_main_process`` is ``True`` only for global rank 0.

    Raises:
        RuntimeError: If CUDA is not available.
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
    """
    cuda_ok = torch.cuda.is_available()
    if not cuda_ok:
        raise RuntimeError("DDP training requires CUDA.")
    dist.init_process_group(backend="nccl")
    gpu_index = int(os.environ["LOCAL_RANK"])
    global_rank = dist.get_rank()
    num_procs = dist.get_world_size()
    torch.cuda.set_device(gpu_index)
    return (
        gpu_index,
        global_rank,
        num_procs,
        torch.device("cuda", gpu_index),
        global_rank == 0,
    )
def cleanup_distributed():
    """Destroy the default process group if one exists; otherwise do nothing."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
# 命令行参数解析配置文件
def get_hyperparams():
parser = argparse.ArgumentParser()
parser.add_argument("config", help="Path to TOML config file")
args = parser.parse_args()
config_path = Path(args.config)
with config_path.open("rb") as f:
return tomllib.load(f), config_path
def get_warmup_cosine_scheduler(
    optimizer,
    epochs,
    *,
    max_warmup_epochs=10,
    start_factor=0.1,
    eta_min_factor=0.0,
):
    """Build a ``LambdaLR`` with linear warmup followed by cosine annealing.

    The warmup length is ``min(epochs // 10, max_warmup_epochs)``, so runs
    shorter than 10 epochs get no warmup.

    - Warmup phase: the LR multiplier rises linearly from ``start_factor``
      towards 1.0.
    - Cosine phase: over the remaining epochs it follows a half-cosine from
      1.0 down to ``eta_min_factor``.
    - After the cosine phase: the multiplier stays at ``eta_min_factor``
      instead of climbing back up.

    The lambda's argument is treated as an epoch index, so ``step()`` is
    presumably called once per epoch (confirm with the training loop).

    Args:
        optimizer: Optimizer whose param-group learning rates are scaled.
        epochs: Total number of training epochs; must be >= 1.
        max_warmup_epochs: Upper bound on the warmup length.
        start_factor: Multiplier at epoch 0 when warmup is active.
        eta_min_factor: Final multiplier reached at the end of the cosine phase.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.

    Raises:
        ValueError: If ``epochs < 1``. The cosine phase would have zero or
            negative length, which previously surfaced as a
            ``ZeroDivisionError`` or a nonsensical schedule.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    warmup_epochs = min(epochs // 10, max_warmup_epochs)
    # Always >= 1 here because warmup_epochs <= epochs // 10 < epochs.
    cosine_epochs = epochs - warmup_epochs

    def lr_lambda(current_epoch):
        # Linear warmup phase.
        if warmup_epochs > 0 and current_epoch < warmup_epochs:
            return start_factor + (1.0 - start_factor) * (
                current_epoch / warmup_epochs
            )
        # Cosine annealing phase; progress is clamped so stepping past
        # `epochs` holds the LR at eta_min instead of swinging back up.
        cosine_epoch = min(current_epoch - warmup_epochs, cosine_epochs)
        return eta_min_factor + (1.0 - eta_min_factor) * 0.5 * (
            1.0 + math.cos(math.pi * cosine_epoch / cosine_epochs)
        )

    return LambdaLR(optimizer, lr_lambda=lr_lambda)