# SparseFocus/train_dpn_ddp.py
#
# Source-viewer metadata: 313 lines, 10 KiB, Python.

import shutil
import time
from datetime import datetime
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import old_datasets as dataset_F
import utils
from models import DPNet
def reduce_epoch_loss(running_loss, total_samples, device):
    """Return the sample-weighted mean loss over all ranks.

    Each rank passes the sum of ``loss * batch_size`` it accumulated and the
    number of samples it processed. Both are summed across the process group
    and then divided, so ranks that saw different numbers of samples are
    weighted correctly.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized, the local values are used as-is. The helper therefore also
    works in single-process runs.

    Args:
        running_loss: Local sum of batch-size-weighted losses (Python float).
        total_samples: Local number of samples processed.
        device: Device for the temporary reduction tensor. It must be a
            device the process-group backend can communicate on.

    Returns:
        The global mean loss as a Python float. If no rank saw any samples,
        the result is ``nan`` (0 / 0 in float64).
    """
    # float64 keeps the summed loss and sample count exact for large epochs.
    stats = torch.tensor(
        [running_loss, total_samples],
        dtype=torch.float64,
        device=device,
    )
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    return (stats[0] / stats[1]).item()
# Train for one epoch.
def train_epoch(model, loader, criterion, optimizer, device, is_main_process):
    """Run one optimization pass over ``loader`` and return the global mean loss.

    ``loader`` yields ``(images, labels, image_names)`` triples; the names are
    ignored. Labels are cast to float, and both labels and model outputs are
    flattened to 1-D before ``criterion`` is applied. Each batch's loss is
    weighted by its batch size. The totals are then combined across ranks by
    ``reduce_epoch_loss``. This weighting assumes ``criterion`` returns a
    per-batch mean, as ``MSELoss(reduction="mean")`` does.

    The progress bar is only displayed on the main process.

    Returns:
        Sample-weighted mean training loss over every rank (float).
    """
    model.train()
    loss_sum, seen = 0.0, 0
    progress = tqdm(
        loader,
        desc="Train:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
        disable=not is_main_process,
    )
    for batch_images, batch_labels, _ in progress:
        inputs = batch_images.to(device, non_blocking=True)
        targets = batch_labels.to(device, non_blocking=True).float().view(-1)
        # set_to_none frees gradient memory instead of filling it with zeros.
        optimizer.zero_grad(set_to_none=True)
        predictions = model(inputs).view(-1)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()
        count = inputs.size(0)
        loss_sum += batch_loss.item() * count
        seen += count
    return reduce_epoch_loss(loss_sum, seen, device)
# Validate for one epoch.
@torch.no_grad()
def valid_epoch(model, loader, criterion, device, is_main_process):
    """Evaluate ``model`` on ``loader`` without gradients and return the global mean loss.

    The model is put in eval mode. ``loader`` yields
    ``(images, labels, image_names)`` triples; the names are ignored.
    Labels are cast to float, and both labels and outputs are flattened to
    1-D before ``criterion`` is applied. Per-batch losses are weighted by
    batch size and combined across ranks by ``reduce_epoch_loss``. This
    assumes ``criterion`` returns a per-batch mean.

    The progress bar is only displayed on the main process.

    Returns:
        Sample-weighted mean validation loss over every rank (float).
    """
    model.eval()
    loss_sum, seen = 0.0, 0
    progress = tqdm(
        loader,
        desc="Valid:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
        disable=not is_main_process,
    )
    for batch_images, batch_labels, _ in progress:
        inputs = batch_images.to(device, non_blocking=True)
        targets = batch_labels.to(device, non_blocking=True).float().view(-1)
        predictions = model(inputs).view(-1)
        batch_loss = criterion(predictions, targets)
        count = inputs.size(0)
        loss_sum += batch_loss.item() * count
        seen += count
    return reduce_epoch_loss(loss_sum, seen, device)
def create_run_dir(config_path, is_main_process):
    """Create a timestamped run directory on rank 0 and share its path with every rank.

    Rank 0 creates ``<cwd>/<YYYY_mm_dd_HH_MM_SS>_dpn_ddp`` and copies the
    config file into it. The resulting path is then broadcast, so all ranks
    write into the same directory. If no process group is initialized, the
    broadcast and barrier are skipped and the local result is returned.

    Args:
        config_path: Path (``Path`` or ``str``) of the config file to snapshot.
        is_main_process: Must be True exactly on rank 0, which is the
            broadcast source.

    Returns:
        ``Path`` of the run directory, identical on every rank.

    Raises:
        OSError: On rank 0, when the directory already exists (e.g. two runs
            started within the same second) or copying the config fails.
        RuntimeError: On the other ranks when rank 0 failed. This replaces
            waiting inside the broadcast until the process-group timeout.
    """
    config_path = Path(config_path)
    distributed = dist.is_available() and dist.is_initialized()
    shared_value = [None]
    local_error = None
    if is_main_process:
        try:
            run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn_ddp")
            run_dir = Path.cwd() / run_name
            run_dir.mkdir(parents=True, exist_ok=False)
            shutil.copy2(config_path, run_dir / config_path.name)
            shared_value[0] = str(run_dir)
        except OSError as exc:
            # Do not raise before the broadcast: the other ranks are (or soon
            # will be) waiting in it. Broadcast None instead and raise after.
            local_error = exc
    if distributed:
        dist.broadcast_object_list(shared_value, src=0)
    if shared_value[0] is None:
        if local_error is not None:
            raise local_error
        raise RuntimeError("Rank 0 failed to create the run directory")
    if distributed:
        dist.barrier()
    return Path(shared_value[0])
def _unwrap_model(model):
    """Return the module inside a DistributedDataParallel wrapper, or ``model`` itself.

    Setup can fail between creating the network and wrapping it in DDP, so
    checkpoint code must not assume ``.module`` exists.
    """
    if isinstance(model, DistributedDataParallel):
        return model.module
    return model


def _build_dataloaders(xlsx_files, batch_size, num_workers, world_size, rank):
    """Build the DPNet datasets, distributed samplers and data loaders.

    Args:
        xlsx_files: Passed to ``dataset_F.get_DPNet_train_data_and_label`` as
            ``root_path_list``. That call returns the train/valid image paths
            and defocus-distance labels.
        batch_size: Per-process batch size.
        num_workers: DataLoader worker processes per rank. Workers are kept
            alive across epochs only when this is > 0, because
            ``persistent_workers`` cannot be enabled with zero workers.
        world_size: Number of ranks sharing the data.
        rank: Rank of this process.

    Returns:
        ``(train_set, valid_set, train_sampler, train_loader, valid_loader)``.
        The train sampler is returned so the caller can call ``set_epoch``.
    """
    (
        train_image_path_list,
        train_defocus_distance_list,
        valid_image_path_list,
        valid_defocus_distance_list,
    ) = dataset_F.get_DPNet_train_data_and_label(root_path_list=xlsx_files)
    # Photometric augmentation for training only; validation just converts to tensors.
    train_transform = transforms.Compose(
        [
            transforms.ColorJitter(
                brightness=(0.9, 1.4),
                contrast=(0.8, 1.5),
                saturation=(0.8, 1.5),
            ),
            transforms.ToTensor(),
        ]
    )
    valid_transform = transforms.Compose([transforms.ToTensor()])
    train_set = dataset_F.MyDataset(
        train_image_path_list,
        train_defocus_distance_list,
        train_transform,
    )
    valid_set = dataset_F.MyDataset(
        valid_image_path_list,
        valid_defocus_distance_list,
        valid_transform,
    )
    train_sampler = DistributedSampler(
        train_set,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    # NOTE: with the default drop_last=False, DistributedSampler pads the
    # index list by repeating samples so every rank gets the same count. The
    # all-reduced validation loss therefore counts a few samples twice when
    # len(valid_set) is not a multiple of world_size.
    valid_sampler = DistributedSampler(
        valid_set,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
    )
    # The samplers already shard (and, for training, shuffle) the data, so the
    # DataLoader itself must not shuffle.
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
    train_loader = DataLoader(dataset=train_set, sampler=train_sampler, **loader_kwargs)
    valid_loader = DataLoader(dataset=valid_set, sampler=valid_sampler, **loader_kwargs)
    return train_set, valid_set, train_sampler, train_loader, valid_loader


def _build_model(init_weight_path, device, local_rank, is_main_process, logger):
    """Create DPNet on ``device``, optionally load initial weights, and wrap it in DDP.

    Args:
        init_weight_path: State-dict file loaded with ``strict=True``. A falsy
            value means training from scratch.
        device: Device the network is moved to before wrapping.
        local_rank: Used as DDP ``device_ids`` / ``output_device``.
        is_main_process: Only the main process logs.
        logger: Logger used on the main process (None on the other ranks).

    Returns:
        The ``DistributedDataParallel``-wrapped model.
    """
    net = DPNet().to(device)
    if init_weight_path:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the installed PyTorch
        # version supports it.
        state_dict = torch.load(init_weight_path, map_location="cpu")
        net.load_state_dict(state_dict, strict=True)
        if is_main_process:
            logger.info(f"Loaded init weight from: {init_weight_path}")
    elif is_main_process:
        logger.info("Training from scratch")
    return DistributedDataParallel(
        net,
        device_ids=[local_rank],
        output_device=local_rank,
    )


# Main training function.
def main():
    """Run distributed DPNet training configured by ``utils.get_hyperparams()``.

    The main process writes ``train.log``, TensorBoard events (``run/``),
    ``best_model.pt`` and ``last_model.pt`` into the run directory created by
    ``create_run_dir``. Cleanup (writer close, process-group teardown) always
    runs, even when setup or training raises.
    """
    local_rank, rank, world_size, device, is_main_process = utils.setup_distributed()
    logger = None
    writer = None
    model = None
    # Only the main process updates this value.
    best_valid_loss = float("inf")
    # Set to True only when the try-body completes (a caught KeyboardInterrupt
    # still counts). The final barrier is skipped otherwise, because peers of a
    # failed rank may be blocked in a training collective and would never
    # match it.
    finished_cleanly = False
    try:
        # ========== 1. Config file and hyperparameters ==========
        config, config_path = utils.get_hyperparams()
        XLSX_FILES = config["xlsx_files"]
        BATCH_SIZE = config["batch_size"]
        NUM_WORKERS = config["num_workers"]
        LEARNING_RATE = config["learning_rate"]
        NUM_EPOCHS = config["epochs"]
        SEED = config["seed"]
        INIT_WEIGHT_PATH = config["init_weight"]

        # ========== 2. Create the output directory ==========
        run_dir = create_run_dir(config_path, is_main_process)

        # ========== 3. Logging, TensorBoard, random seed and device ==========
        if is_main_process:
            logger = utils.get_logger(__name__, run_dir / "train.log")
            writer = SummaryWriter(str(run_dir / "run"))
        utils.set_seeds(SEED)
        if is_main_process:
            logger.info(f"Config path: {config_path}")
            logger.info(f"Loaded config: {str(config)}")
            logger.info(f"Run directory: {run_dir}")
            logger.info(f"Using world size: {world_size}")
            logger.info(f"Using device: {device}")

        # ========== 4. Data and loaders ==========
        (
            train_set,
            valid_set,
            train_sampler,
            train_loader,
            valid_loader,
        ) = _build_dataloaders(XLSX_FILES, BATCH_SIZE, NUM_WORKERS, world_size, rank)
        if is_main_process:
            logger.info(f"Train dataset size: {len(train_set)}")
            logger.info(f"Val dataset size: {len(valid_set)}")
            logger.info(f"Train steps per epoch per process: {len(train_loader)}")

        # ========== 5. Model, loss, optimizer, scheduler ==========
        # ``model`` stays None if building/loading fails, so the finally
        # clause does not try to checkpoint a half-initialized network.
        model = _build_model(INIT_WEIGHT_PATH, device, local_rank, is_main_process, logger)
        criterion = nn.MSELoss(reduction="mean")
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
        scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)
        if is_main_process:
            logger.info("Loss: MSELoss(reduction='mean')")
            logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
            logger.info("Scheduler: epoch-based warmup + cosine annealing")

        # ========== 6. Training ==========
        if is_main_process:
            logger.info("START TRAINING")
        try:
            for epoch in range(1, NUM_EPOCHS + 1):
                # Reseed the sampler so each epoch gets a different shuffle
                # that is still consistent across ranks.
                train_sampler.set_epoch(epoch)
                epoch_start_time = time.time()
                train_loss = train_epoch(
                    model,
                    train_loader,
                    criterion,
                    optimizer,
                    device,
                    is_main_process,
                )
                valid_loss = valid_epoch(
                    model,
                    valid_loader,
                    criterion,
                    device,
                    is_main_process,
                )
                # Record the LR used for this epoch before the scheduler advances it.
                epoch_lr = optimizer.param_groups[0]["lr"]
                scheduler.step()
                epoch_time_cost = time.time() - epoch_start_time
                if is_main_process and valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    torch.save(_unwrap_model(model).state_dict(), run_dir / "best_model.pt")
                    logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")
                if is_main_process:
                    logger.info(
                        f"Epoch [{epoch}/{NUM_EPOCHS}] "
                        f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
                        f"Best Valid Loss: {best_valid_loss:.4f} | "
                        f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
                        f"Epoch Learning Rate: {epoch_lr:.6e}"
                    )
                    writer.add_scalar("Loss/train", train_loss, epoch)
                    writer.add_scalar("Loss/valid", valid_loss, epoch)
                    writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
                    writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
                    writer.add_scalar("Time/learning_rate", epoch_lr, epoch)
        except KeyboardInterrupt:
            if is_main_process:
                logger.info("Training interrupted by user")
        finished_cleanly = True
    finally:
        try:
            # ``run_dir`` is always bound here: it is assigned before ``model``.
            if is_main_process and model is not None:
                torch.save(_unwrap_model(model).state_dict(), run_dir / "last_model.pt")
                logger.info("Last model saved")
        finally:
            # Runs even if saving the last checkpoint fails.
            if writer is not None:
                writer.close()
                logger.info("TensorBoard writer closed")
            if is_main_process and logger is not None:
                logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")
            if finished_cleanly and dist.is_initialized():
                dist.barrier()
            utils.cleanup_distributed()
if __name__ == "__main__":
    # Start training only when executed as a script, never on import.
    main()