delete patience and only show pbar when rank is 0

kaiza_hikaru 2025-11-01 16:28:05 +08:00
parent 699d7448b9
commit 011eae0107
4 changed files with 15 additions and 34 deletions

.gitignore

@@ -1,7 +1,8 @@
 __pycache__/
 ckpts/
-configs/
+configs/*
+!configs/config_example.toml
 runs/
 results/
-models.ipynb
+*.ipynb
 ShuffleNetV2.txt
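Note on the hunk above: Git cannot re-include a file whose parent directory is excluded, so ignoring configs/ outright would make the !configs/config_example.toml whitelist a no-op. Switching to configs/* ignores the directory's contents instead, which is what lets the example config stay tracked.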


@@ -49,9 +49,8 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
     return val_loss
 
-def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type):
+def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type):
     best_val_loss = float('inf')
-    patience_counter = 0
     # !pip install tensorboard
     with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
         # Show the model graph in TensorBoard
@@ -82,7 +81,6 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type):
             # Record a checkpoint
             if avg_val_loss < best_val_loss:
                 best_val_loss = avg_val_loss
-                patience_counter = 0
                 save_dict = {
                     "epoch": epoch,
                     "model_state_dict": model.state_dict(),
@@ -94,11 +92,6 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type):
                 Path("ckpts").mkdir(exist_ok=True, parents=True)
                 torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                 print_with_timestamp(f"New best model saved at epoch {epoch+1}")
-            else:
-                patience_counter += 1
-                if patience_counter > patience:
-                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
-                    break
 
 def main():
@@ -116,7 +109,6 @@ def main():
     batch_size = int(cfg["batch_size"])
     num_workers = int(cfg["num_workers"])
     lr = float(cfg["lr"])
-    patience = int(cfg["patience"])
     epochs = int(cfg["epochs"])
     warmup_epochs = int(cfg["warmup_epochs"])
     objective_params_list = cfg["train_objective_params_list"]
@@ -173,7 +165,7 @@ def main():
     )
     print_with_timestamp("Start training")
-    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type)
+    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type)
     print_with_timestamp("Training completed!")


@@ -21,11 +21,13 @@ from MOAFDatasets import MOAFDataset
 from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
 
-def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
+def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
     model.train()
     train_loss = 0.0
-    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
+    data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]") if rank in {-1, 0} else train_loader
+    for data in data_iter:
         images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
         params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
@@ -39,12 +41,14 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
     return train_loss
 
-def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
+def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
     model.eval()
     val_loss = 0.0
+    data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]") if rank in {-1, 0} else val_loader
     with torch.no_grad():
-        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
+        for data in data_iter:
             images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
             params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
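The two hunks above are the whole progress-bar change: under DDP every process runs the same loop, so an unguarded tqdm prints one bar per GPU. A minimal standalone sketch of the pattern (the helper name maybe_pbar is illustrative; rank == -1 denotes a non-distributed run):

from tqdm import tqdm

def maybe_pbar(loader, desc, rank):
    # Wrap the loader in tqdm only on the main process (rank 0) or in a
    # non-distributed run (rank == -1); all other ranks iterate silently.
    return tqdm(loader, desc=desc) if rank in {-1, 0} else loader

# e.g. for data in maybe_pbar(train_loader, f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", rank): ...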
@@ -56,12 +60,6 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
 
 def fit(rank, world_size, cfg):
-    """
-    Main function run by each process (one GPU per process).
-    rank: global rank of this process (0 to world_size-1)
-    world_size: total number of processes (usually the number of available GPUs)
-    cfg: configuration dict read from the toml file
-    """
     # Initialize distributed parameters
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
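For context, the surviving lines above are the single-node rendezvous setup. A minimal sketch of the per-process initialization fit() performs, assuming the NCCL backend and one GPU per rank (the function name setup_ddp is illustrative):

import os
import torch
import torch.distributed as dist

def setup_ddp(rank, world_size, master_port="29500"):
    os.environ["MASTER_ADDR"] = "127.0.0.1"   # all ranks rendezvous on localhost
    os.environ["MASTER_PORT"] = master_port    # must be identical across ranks
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)                # bind this process to its own GPU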
@@ -78,7 +76,6 @@ def fit(rank, world_size, cfg):
     batch_size = int(cfg["batch_size"])
     num_workers = int(cfg["num_workers"])
     lr = float(cfg["lr"])
-    patience = int(cfg["patience"])
     epochs = int(cfg["epochs"])
     warmup_epochs = int(cfg["warmup_epochs"])
     objective_params_list = cfg["train_objective_params_list"]
@@ -143,7 +140,6 @@ def fit(rank, world_size, cfg):
     # Show the model graph in TensorBoard
     if rank == 0:
         tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
-        # tensorboard graph: use a small dummy input placed on correct device
         dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
         tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
@@ -152,7 +148,6 @@ def fit(rank, world_size, cfg):
 
     # Training
     best_val_loss = float('inf')
-    patience_counter = 0
 
     if rank == 0:
         print_with_timestamp("Start training (DDP)")
@@ -163,8 +158,8 @@ def fit(rank, world_size, cfg):
         val_sampler.set_epoch(epoch)
         start_time = time.time()
 
-        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
-        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
+        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank) / len(train_loader)
+        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank) / len(val_loader)
 
         current_lr = optimizer.param_groups[0]['lr']
         scheduler.step()
         epoch_time = time.time() - start_time
@@ -184,7 +179,6 @@ def fit(rank, world_size, cfg):
             if avg_val_loss < best_val_loss:
                 best_val_loss = avg_val_loss
-                patience_counter = 0
                 save_dict = {
                     "epoch": epoch,
                     # Save module.state_dict() (use .module when the model is wrapped in DDP)
@@ -197,11 +191,6 @@ def fit(rank, world_size, cfg):
                 Path("ckpts").mkdir(exist_ok=True, parents=True)
                 torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                 print_with_timestamp(f"New best model saved at epoch {epoch+1}")
-            else:
-                patience_counter += 1
-                if patience_counter > patience:
-                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
-                    break
 
     # Clean up the process group
     if tb_writer is not None:
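Since the DDP path saves model.module.state_dict(), the checkpoint loads straight into an unwrapped model. A hedged sketch of the loading side (the helper load_best and the resume logic are assumptions; the path and dict keys follow the save code above):

import torch

def load_best(model, model_type, output_type):
    # Path mirrors the torch.save() call in fit(); keys come from save_dict.
    ckpt = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])  # no DDP wrapper needed
    return ckpt["epoch"] + 1  # epoch to resume from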


@@ -6,7 +6,6 @@ dataset_dir = "F:/Datasets/MODatasetD"
 batch_size = 64
 num_workers = 8
 lr = 1e-4
-patience = 5
 epochs = 5
 warmup_epochs = 1
 # Miscellaneous
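A minimal sketch of consuming this config the way both training scripts read cfg, assuming Python 3.11+ tomllib and the whitelisted example path from .gitignore (the repo's actual loader may differ):

import tomllib

with open("configs/config_example.toml", "rb") as f:
    cfg = tomllib.load(f)

batch_size = int(cfg["batch_size"])  # 64
lr = float(cfg["lr"])                # 1e-4
epochs = int(cfg["epochs"])          # 5; patience is gone after this commit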