313 lines
10 KiB
Python
313 lines
10 KiB
Python
import shutil
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
import torchvision.transforms as transforms
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
from torch.optim import Adam
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
import old_datasets as dataset_F
|
|
import utils
|
|
from models import DPNet
|
|
|
|
|
|
def reduce_epoch_loss(running_loss, total_samples, device):
    """Return the sample-weighted mean loss aggregated over all processes.

    Each process passes the sum of its per-sample losses and the number of
    samples it processed. Both values are summed across the process group
    and the global sum of losses is divided by the global sample count.

    Args:
        running_loss: Sum of per-sample losses accumulated by this process.
        total_samples: Number of samples this process contributed.
        device: Device on which the temporary reduction tensor is created.

    Returns:
        The global mean loss as a Python float. If no sample was counted on
        any process the division is 0/0 and the result is NaN.
    """
    # Pack both values into one tensor so a single collective call suffices.
    stats = torch.tensor(
        [running_loss, total_samples],
        dtype=torch.float64,
        device=device,
    )

    # Without an initialized process group there is nothing to reduce over:
    # the local sums already are the global sums. Skipping the collective
    # lets the epoch helpers run in a plain single-process session as well.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)

    return (stats[0] / stats[1]).item()
|
|
|
|
|
|
# Train for one epoch
|
|
def train_epoch(model, loader, criterion, optimizer, device, is_main_process):
    """Run one optimization pass over ``loader`` and return the epoch loss.

    Args:
        model: Network to train; it is switched to training mode first.
        loader: Iterable yielding ``(images, labels, image_names)`` batches.
        criterion: Loss callable applied to flattened outputs and labels.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        is_main_process: When False the progress bar is disabled.

    Returns:
        The value returned by ``reduce_epoch_loss`` for the accumulated
        sample-weighted loss sum and sample count of this process.
    """
    model.train()
    loss_sum, sample_count = 0.0, 0

    progress = tqdm(
        loader,
        desc="Train:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
        disable=not is_main_process,
    )

    # The third batch element (image names) is not needed for training.
    for images, labels, _image_names in progress:
        images = images.to(device, non_blocking=True)
        targets = labels.to(device, non_blocking=True).float().view(-1)

        optimizer.zero_grad(set_to_none=True)

        predictions = model(images).view(-1)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that a smaller final
        # batch does not skew the epoch average (assumes the criterion
        # returns a per-batch mean, as the MSELoss built in main() does).
        batch_count = images.size(0)
        loss_sum += loss.item() * batch_count
        sample_count += batch_count

    return reduce_epoch_loss(loss_sum, sample_count, device)
|
|
|
|
|
|
# Validate for one epoch
|
|
@torch.no_grad()
def valid_epoch(model, loader, criterion, device, is_main_process):
    """Evaluate ``model`` on ``loader`` without gradients and return the loss.

    Args:
        model: Network to evaluate; it is switched to evaluation mode first.
        loader: Iterable yielding ``(images, labels, image_names)`` batches.
        criterion: Loss callable applied to flattened outputs and labels.
        device: Device the batch tensors are moved to.
        is_main_process: When False the progress bar is disabled.

    Returns:
        The value returned by ``reduce_epoch_loss`` for the accumulated
        sample-weighted loss sum and sample count of this process.
    """
    model.eval()
    loss_sum, sample_count = 0.0, 0

    progress = tqdm(
        loader,
        desc="Valid:",
        bar_format="{l_bar}{bar:20}{r_bar}",
        leave=False,
        disable=not is_main_process,
    )

    # The third batch element (image names) is not needed for validation.
    for images, labels, _image_names in progress:
        images = images.to(device, non_blocking=True)
        targets = labels.to(device, non_blocking=True).float().view(-1)

        predictions = model(images).view(-1)
        loss = criterion(predictions, targets)

        # Same batch-size weighting as in training, so both epoch losses
        # are averaged per sample.
        batch_count = images.size(0)
        loss_sum += loss.item() * batch_count
        sample_count += batch_count

    return reduce_epoch_loss(loss_sum, sample_count, device)
|
|
|
|
|
|
def create_run_dir(config_path, is_main_process):
    """Create the timestamped run directory and share its path with all ranks.

    The main process creates ``<cwd>/<timestamp>_dpn_ddp``, copies the config
    file into it and broadcasts the resulting path. If that work fails, the
    failure is broadcast as well so that every process raises instead of the
    other ranks blocking inside the collective call.

    Args:
        config_path: Path object of the config file to copy (its ``.name``
            is used as the file name inside the run directory).
        is_main_process: True on the process that creates the directory.

    Returns:
        ``pathlib.Path`` of the run directory, identical on every process.

    Raises:
        Exception: On the main process, whatever directory creation or the
            config copy raised (e.g. ``FileExistsError``), unchanged.
        RuntimeError: On the other processes when the main process failed,
            or when no directory path is available at all.
    """
    # Without a process group there is nobody to share the path with.
    distributed = dist.is_available() and dist.is_initialized()

    # [run_dir_text, error_text]; only the main process fills it in.
    shared_value = [None, None]
    failure = None

    if is_main_process:
        try:
            run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_dpn_ddp")
            run_dir = Path.cwd() / run_name
            # exist_ok=False: never silently reuse an existing run directory.
            run_dir.mkdir(parents=True, exist_ok=False)
            shutil.copy2(config_path, run_dir / config_path.name)
            shared_value[0] = str(run_dir)
        except Exception as exc:
            # Remember the error but still take part in the broadcast below;
            # raising here would leave the other ranks waiting for rank 0.
            failure = exc
            shared_value[1] = f"{type(exc).__name__}: {exc}"

    if distributed:
        # src=0 presumes the main process is global rank 0 — confirm against
        # utils.setup_distributed.
        dist.broadcast_object_list(shared_value, src=0)

    if failure is not None:
        raise failure

    run_dir_text, error_text = shared_value
    if error_text is not None:
        raise RuntimeError(
            f"Main process failed to create the run directory: {error_text}"
        )
    if run_dir_text is None:
        raise RuntimeError(
            "No run directory available: not the main process and no "
            "process group is initialized"
        )

    if distributed:
        dist.barrier()

    return Path(run_dir_text)
|
|
|
|
|
|
# Main training function
|
|
def main():
    """Run the full distributed (DDP) training job for DPNet.

    Steps, in order: set up the distributed environment, read the config,
    create the run directory, set up logging/TensorBoard on the main
    process, build datasets and loaders, build the model/loss/optimizer/
    scheduler, then train and validate for the configured number of epochs.
    The best model (lowest validation loss) is saved during training; the
    last model is saved in the ``finally`` block, which also closes the
    TensorBoard writer and tears down the process group.

    A ``KeyboardInterrupt`` during the epoch loop is logged and ends the
    loop; any other exception propagates after the cleanup has run.
    """
    local_rank, rank, world_size, device, is_main_process = utils.setup_distributed()

    # Initialized up front so the ``finally`` block can tell how far setup got.
    logger = None
    writer = None
    model = None
    best_valid_loss = float("inf")

    try:
        # ========== 1 Config file and hyperparameters ==========
        config, config_path = utils.get_hyperparams()

        XLSX_FILES = config["xlsx_files"]
        BATCH_SIZE = config["batch_size"]
        NUM_WORKERS = config["num_workers"]
        LEARNING_RATE = config["learning_rate"]
        NUM_EPOCHS = config["epochs"]
        SEED = config["seed"]
        INIT_WEIGHT_PATH = config["init_weight"]

        # ========== 2 Create the output directory ==========
        run_dir = create_run_dir(config_path, is_main_process)

        # ========== 3 Logging, TensorBoard, random seed and device ==========
        # Only the main process writes log files and TensorBoard events.
        if is_main_process:
            logger = utils.get_logger(__name__, run_dir / "train.log")
            writer = SummaryWriter(str(run_dir / "run"))

        utils.set_seeds(SEED)

        if is_main_process:
            logger.info(f"Config path: {config_path}")
            logger.info(f"Loaded config: {str(config)}")
            logger.info(f"Run directory: {run_dir}")
            logger.info(f"Using world size: {world_size}")
            logger.info(f"Using device: {device}")

        # ========== 4 Data and loaders ==========
        (
            train_image_path_list,
            train_defocus_distance_list,
            valid_image_path_list,
            valid_defocus_distance_list,
        ) = dataset_F.get_DPNet_train_data_and_label(root_path_list=XLSX_FILES)

        # Color augmentation is applied to the training images only.
        train_transform = transforms.Compose(
            [
                transforms.ColorJitter(
                    brightness=(0.9, 1.4),
                    contrast=(0.8, 1.5),
                    saturation=(0.8, 1.5),
                ),
                transforms.ToTensor(),
            ]
        )
        valid_transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

        train_set = dataset_F.MyDataset(
            train_image_path_list,
            train_defocus_distance_list,
            train_transform,
        )
        valid_set = dataset_F.MyDataset(
            valid_image_path_list,
            valid_defocus_distance_list,
            valid_transform,
        )

        train_sampler = DistributedSampler(
            train_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
        )
        # NOTE(review): with the default drop_last=False, DistributedSampler
        # pads the index list so every rank gets the same number of samples;
        # a few validation samples may therefore be counted twice when
        # len(valid_set) is not divisible by world_size.
        valid_sampler = DistributedSampler(
            valid_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
        )

        # shuffle stays False on the loaders: ordering is owned by the samplers.
        train_loader = DataLoader(
            dataset=train_set,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=train_sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            persistent_workers=(NUM_WORKERS > 0),
        )
        valid_loader = DataLoader(
            dataset=valid_set,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=valid_sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            persistent_workers=(NUM_WORKERS > 0),
        )

        if is_main_process:
            logger.info(f"Train dataset size: {len(train_set)}")
            logger.info(f"Val dataset size: {len(valid_set)}")
            logger.info(f"Train steps per epoch per process: {len(train_loader)}")

        # ========== 5 Model, loss, optimizer, scheduler ==========
        model = DPNet().to(device)
        if INIT_WEIGHT_PATH:
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from trusted sources; consider weights_only=True
            # if the minimum supported PyTorch version provides it.
            state_dict = torch.load(INIT_WEIGHT_PATH, map_location="cpu")
            model.load_state_dict(state_dict, strict=True)
            if is_main_process:
                logger.info(f"Loaded init weight from: {INIT_WEIGHT_PATH}")
        elif is_main_process:
            logger.info("Training from scratch")

        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )

        criterion = nn.MSELoss(reduction="mean")
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
        scheduler = utils.get_warmup_cosine_scheduler(optimizer, NUM_EPOCHS)

        if is_main_process:
            logger.info("Loss: MSELoss(reduction='mean')")
            logger.info(f"Optimizer: Adam(lr={LEARNING_RATE}, betas=(0.9, 0.999))")
            logger.info("Scheduler: epoch-based warmup + cosine annealing")

        # ========== 6 Start training ==========
        if is_main_process:
            logger.info("START TRAINING")

        try:
            for epoch in range(1, NUM_EPOCHS + 1):
                # Re-seed the sampler so each epoch uses a different shuffle.
                train_sampler.set_epoch(epoch)
                epoch_start_time = time.time()

                train_loss = train_epoch(
                    model,
                    train_loader,
                    criterion,
                    optimizer,
                    device,
                    is_main_process,
                )
                valid_loss = valid_epoch(
                    model,
                    valid_loader,
                    criterion,
                    device,
                    is_main_process,
                )
                # Read the learning rate before stepping the scheduler so the
                # logged value is the one this epoch actually trained with.
                epoch_lr = optimizer.param_groups[0]["lr"]
                scheduler.step()
                epoch_time_cost = time.time() - epoch_start_time

                if is_main_process and valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    # Save the unwrapped module so the checkpoint keys carry
                    # no DDP "module." prefix.
                    torch.save(model.module.state_dict(), run_dir / "best_model.pt")
                    logger.info(f"Best model saved, valid_loss = {best_valid_loss:.4f}")

                if is_main_process:
                    logger.info(
                        f"Epoch [{epoch}/{NUM_EPOCHS}] "
                        f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | "
                        f"Best Valid Loss: {best_valid_loss:.4f} | "
                        f"Epoch Time Cost: {epoch_time_cost:.2f} s | "
                        f"Epoch Learning Rate: {epoch_lr:.6e}"
                    )

                    writer.add_scalar("Loss/train", train_loss, epoch)
                    writer.add_scalar("Loss/valid", valid_loss, epoch)
                    writer.add_scalar("Loss/best_valid", best_valid_loss, epoch)
                    writer.add_scalar("Time/epoch", epoch_time_cost, epoch)
                    writer.add_scalar("Time/learning_rate", epoch_lr, epoch)

        except KeyboardInterrupt:
            if is_main_process:
                logger.info("Training interrupted by user")

    finally:
        if is_main_process and model is not None:
            # A failure between model construction and DDP wrapping (e.g.
            # while loading the initial weights) leaves a bare module that
            # has no ``.module``; unwrap only when it really is wrapped, so
            # this cleanup cannot mask the original exception.
            if isinstance(model, DistributedDataParallel):
                last_state = model.module.state_dict()
            else:
                last_state = model.state_dict()
            torch.save(last_state, run_dir / "last_model.pt")
            logger.info("Last model saved")

        if writer is not None:
            writer.close()
            logger.info("TensorBoard writer closed")

        if is_main_process and logger is not None:
            logger.info(f"Training finished, best validation loss: {best_valid_loss:.8f}")

        if dist.is_initialized():
            # NOTE(review): if only some ranks get here after an error, this
            # barrier may block until the backend timeout expires.
            dist.barrier()
            # NOTE(review): kept under the same guard as the barrier —
            # confirm utils.cleanup_distributed is not needed when no
            # process group exists.
            utils.cleanup_distributed()
|
|
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|