change pbar length

parent 0ff21dcdc6
commit 2bdefda64e
@@ -19,7 +19,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
     model.train()
     train_loss = 0.0
 
-    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60):
+    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180):
         images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
         params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
 
@@ -38,7 +38,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
     val_loss = 0.0
 
     with torch.no_grad():
-        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60):
+        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180):
             images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
             params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
 
@@ -25,7 +25,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn,
     model.train()
     train_loss = 0.0
 
-    data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60) if rank in {-1, 0} else train_loader
+    data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180) if rank in {-1, 0} else train_loader
 
     for data in data_iter:
         images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
@@ -45,7 +45,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
     model.eval()
     val_loss = 0.0
 
-    data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60) if rank in {-1, 0} else val_loader
+    data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=180) if rank in {-1, 0} else val_loader
 
     with torch.no_grad():
        for data in data_iter:
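For context: tqdm's ncols argument fixes the total character width of the rendered progress bar line, so this change widens the bar from 60 to 180 columns, presumably to leave more room for the epoch description and the rate/ETA counters. Below is a minimal standalone sketch of the pattern being edited; run_epoch and loader are illustrative placeholders, not names from this repository.

from tqdm import tqdm

def run_epoch(loader, epoch, epochs, rank=-1):
    # Only the main process (rank -1 single-process, rank 0 under DDP) wraps the
    # loader in tqdm; other ranks iterate the raw loader so only one bar is drawn.
    # ncols=180 sets the full width of the bar line in terminal columns.
    data_iter = (
        tqdm(loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=180)
        if rank in {-1, 0}
        else loader
    )
    for data in data_iter:
        pass  # per-batch work (forward/backward or validation) goes here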