From 011eae0107c1eee179ae296b75dbd41f534ad13c Mon Sep 17 00:00:00 2001
From: kaiza_hikaru
Date: Sat, 1 Nov 2025 16:28:05 +0800
Subject: [PATCH] delete patience and only show pbar when rank is 0

---
 .gitignore              |  5 +--
 MOAFTrain.py            | 12 ++-----
 MOAFTrainDDP.py         | 31 ++++++-------------
 .../config_example.toml |  1 -
 4 files changed, 15 insertions(+), 34 deletions(-)
 rename config_example.toml => configs/config_example.toml (97%)

diff --git a/.gitignore b/.gitignore
index 4ae37e6..a759263 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,8 @@
 __pycache__/
 ckpts/
-configs/
+configs/*
+!configs/config_example.toml
 runs/
 results/
-models.ipynb
+*.ipynb
 ShuffleNetV2.txt
\ No newline at end of file
diff --git a/MOAFTrain.py b/MOAFTrain.py
index 8be46a0..1f36f49 100644
--- a/MOAFTrain.py
+++ b/MOAFTrain.py
@@ -49,9 +49,8 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
     return val_loss
 
 
-def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type):
+def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type):
     best_val_loss = float('inf')
-    patience_counter = 0
     # !pip install tensorboard
     with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
         # Show the model structure in TensorBoard
@@ -82,7 +81,6 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, sch
             # Save a checkpoint
             if avg_val_loss < best_val_loss:
                 best_val_loss = avg_val_loss
-                patience_counter = 0
                 save_dict = {
                     "epoch": epoch,
                     "model_state_dict": model.state_dict(),
@@ -94,11 +92,6 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, sch
                 Path("ckpts").mkdir(exist_ok=True, parents=True)
                 torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                 print_with_timestamp(f"New best model saved at epoch {epoch+1}")
-            else:
-                patience_counter += 1
-                if patience_counter > patience:
-                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
-                    break
 
 
 def main():
@@ -116,7 +109,6 @@ def main():
     batch_size = int(cfg["batch_size"])
     num_workers = int(cfg["num_workers"])
     lr = float(cfg["lr"])
-    patience = int(cfg["patience"])
     epochs = int(cfg["epochs"])
     warmup_epochs = int(cfg["warmup_epochs"])
     objective_params_list = cfg["train_objective_params_list"]
@@ -173,7 +165,7 @@ def main():
     )
 
     print_with_timestamp("Start trainning")
-    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type)
+    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type)
     print_with_timestamp("Training completed!")
 
 
diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py
index da515d5..6cf7258 100644
--- a/MOAFTrainDDP.py
+++ b/MOAFTrainDDP.py
@@ -21,11 +21,13 @@ from MOAFDatasets import MOAFDataset
 from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
 
 
-def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
+def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
     model.train()
     train_loss = 0.0
 
-    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
+    data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]") if rank in {-1, 0} else train_loader
+
+    for data in data_iter:
         images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
         params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device,
                                                                                non_blocking=True)
@@ -39,12 +41,14 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
     return train_loss
 
 
-def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
+def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
     model.eval()
     val_loss = 0.0
 
+    data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]") if rank in {-1, 0} else val_loader
+
     with torch.no_grad():
-        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
+        for data in data_iter:
             images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
             params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device,
                                                                                    non_blocking=True)
@@ -56,12 +60,6 @@
 
 
 def fit(rank, world_size, cfg):
-    """
-    Main function run by each process (one GPU per process).
-    rank: the global rank of this process (0 ~ world_size-1)
-    world_size: the total number of processes (usually the number of available GPUs)
-    cfg: the configuration dict read from the toml file
-    """
     # Initialize distributed parameters
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
@@ -78,7 +76,6 @@ def fit(rank, world_size, cfg):
     batch_size = int(cfg["batch_size"])
     num_workers = int(cfg["num_workers"])
     lr = float(cfg["lr"])
-    patience = int(cfg["patience"])
     epochs = int(cfg["epochs"])
     warmup_epochs = int(cfg["warmup_epochs"])
     objective_params_list = cfg["train_objective_params_list"]
@@ -143,7 +140,6 @@ def fit(rank, world_size, cfg):
     # Show the model structure in TensorBoard
     if rank == 0:
         tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
-        # tensorboard graph: use a small dummy input placed on correct device
         dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
         tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
 
@@ -152,7 +148,6 @@ def fit(rank, world_size, cfg):
 
     # Training
     best_val_loss = float('inf')
-    patience_counter = 0
 
     if rank == 0:
         print_with_timestamp("Start training (DDP)")
@@ -163,8 +158,8 @@ def fit(rank, world_size, cfg):
         val_sampler.set_epoch(epoch)
 
         start_time = time.time()
-        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
-        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
+        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank) / len(train_loader)
+        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank) / len(val_loader)
         current_lr = optimizer.param_groups[0]['lr']
         scheduler.step()
         epoch_time = time.time() - start_time
@@ -184,7 +179,6 @@ def fit(rank, world_size, cfg):
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
-               patience_counter = 0
                save_dict = {
                    "epoch": epoch,
                    # Save module.state_dict() (use module when wrapped in DDP)
                    "model_state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
@@ -197,11 +191,6 @@ def fit(rank, world_size, cfg):
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
-           else:
-               patience_counter += 1
-               if patience_counter > patience:
-                   print_with_timestamp(f"Early stopping at {epoch+1} epochs")
-                   break
 
     # Clean up processes
     if tb_writer is not None:
diff --git a/config_example.toml b/configs/config_example.toml
similarity index 97%
rename from config_example.toml
rename to configs/config_example.toml
index a92112d..ad66d2a 100644
--- a/config_example.toml
+++ b/configs/config_example.toml
@@ -6,7 +6,6 @@ dataset_dir = "F:/Datasets/MODatasetD"
 batch_size = 64
 num_workers = 8
 lr = 1e-4
-patience = 5
 epochs = 5
 warmup_epochs = 1
 # Other
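
A minimal standalone sketch of the progress-bar gating idiom the train_epoch/valid_epoch hunks apply, assuming only that tqdm is installed; iter_with_pbar and the range(...) stand-in loader are illustrative names, not part of this repo:

    from tqdm import tqdm

    def iter_with_pbar(loader, rank, desc):
        # Wrap the loader in tqdm only on the main process: rank 0 under
        # DDP, or rank -1 for a plain single-process run. Every other rank
        # iterates the loader unchanged, so only one bar is printed.
        return tqdm(loader, desc=desc) if rank in {-1, 0} else loader

    # rank 0 shows a bar; rank 1 iterates silently.
    for rank in (0, 1):
        for batch in iter_with_pbar(range(100), rank, desc="Epoch 001/005 [Train]"):
            pass  # the training step would run here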