diff --git a/MOAFTrain.py b/MOAFTrain.py index 1f36f49..a18963e 100644 --- a/MOAFTrain.py +++ b/MOAFTrain.py @@ -19,7 +19,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn): model.train() train_loss = 0.0 - for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"): + for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60): images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) @@ -38,7 +38,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn): val_loss = 0.0 with torch.no_grad(): - for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"): + for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60): images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py index cc18e8e..0674f35 100644 --- a/MOAFTrainDDP.py +++ b/MOAFTrainDDP.py @@ -25,7 +25,7 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, model.train() train_loss = 0.0 - data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]") if rank in {-1, 0} else train_loader + data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]", ncols=60) if rank in {-1, 0} else train_loader for data in data_iter: images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) @@ -45,7 +45,7 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank): model.eval() val_loss = 0.0 - data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]") if rank in {-1, 0} else val_loader + data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]", ncols=60) if rank in {-1, 0} else val_loader with torch.no_grad(): for data in data_iter: