From 2bdefda64ecf44a379f3141c347471dbd0392093 Mon Sep 17 00:00:00 2001 From: kaiza_hikaru Date: Sat, 1 Nov 2025 17:06:49 +0800 Subject: [PATCH] change pbar length --- MOAFTrain.py | 4 ++-- MOAFTrainDDP.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/MOAFTrain.py b/MOAFTrain.py index a18963e..8ba7763 100644 --- a/MOAFTrain.py +++ b/MOAFTrain.py @@ -19,7 +19,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn): model.train() train_loss = 0.0 - for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60): + for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180): images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) @@ -38,7 +38,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn): val_loss = 0.0 with torch.no_grad(): - for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60): + for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180): images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py index 0674f35..e9e65f9 100644 --- a/MOAFTrainDDP.py +++ b/MOAFTrainDDP.py @@ -25,7 +25,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, model.train() train_loss = 0.0 - data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60) if rank in {-1, 0} else train_loader + data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180) if rank in {-1, 0} else train_loader for data in data_iter: images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) @@ -45,7 +45,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank): model.eval() val_loss = 0.0 - data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60) if rank in {-1, 0} else val_loader + data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180) if rank in {-1, 0} else val_loader with torch.no_grad(): for data in data_iter: