change pbar length

This commit is contained in:
kaiza_hikaru 2025-11-01 17:06:49 +08:00
parent 0ff21dcdc6
commit 2bdefda64e
2 changed files with 4 additions and 4 deletions

View File

@ -19,7 +19,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
model.train() model.train()
train_loss = 0.0 train_loss = 0.0
for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60): for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180):
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
@ -38,7 +38,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
val_loss = 0.0 val_loss = 0.0
with torch.no_grad(): with torch.no_grad():
for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60): for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180):
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

View File

@ -25,7 +25,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn,
model.train() model.train()
train_loss = 0.0 train_loss = 0.0
data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60) if rank in {-1, 0} else train_loader data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180) if rank in {-1, 0} else train_loader
for data in data_iter: for data in data_iter:
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
@ -45,7 +45,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
model.eval() model.eval()
val_loss = 0.0 val_loss = 0.0
data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60) if rank in {-1, 0} else val_loader data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180) if rank in {-1, 0} else val_loader
with torch.no_grad(): with torch.no_grad():
for data in data_iter: for data in data_iter: