delete patience and only show pbar when rank is 0
commit 011eae0107
parent 699d7448b9
.gitignore (vendored): 5 lines changed
@@ -1,7 +1,8 @@
 __pycache__/
 ckpts/
-configs/
+configs/*
+!configs/config_example.toml
 runs/
 results/
-models.ipynb
+*.ipynb
 ShuffleNetV2.txt
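Note on the new pair: `configs/*` together with `!configs/config_example.toml` ignores everything under configs/ while keeping the example config tracked. The wildcard form matters here: git cannot re-include a file whose parent directory is itself ignored, so `configs/` alone would have swallowed the exception.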
MOAFTrain.py: 12 lines changed
@@ -49,9 +49,8 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
     return val_loss
 
 
-def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type):
+def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type):
     best_val_loss = float('inf')
-    patience_counter = 0
     # !pip install tensorboard
     with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
         # Show the model graph on TensorBoard
@@ -82,7 +81,6 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, sch
             # Record a checkpoint
             if avg_val_loss < best_val_loss:
                 best_val_loss = avg_val_loss
-                patience_counter = 0
                 save_dict = {
                     "epoch": epoch,
                     "model_state_dict": model.state_dict(),
@@ -94,11 +92,6 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, sch
                 Path("ckpts").mkdir(exist_ok=True, parents=True)
                 torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                 print_with_timestamp(f"New best model saved at epoch {epoch+1}")
-            else:
-                patience_counter += 1
-                if patience_counter > patience:
-                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
-                    break
 
 
 def main():
@@ -116,7 +109,6 @@ def main():
     batch_size = int(cfg["batch_size"])
     num_workers = int(cfg["num_workers"])
     lr = float(cfg["lr"])
-    patience = int(cfg["patience"])
     epochs = int(cfg["epochs"])
     warmup_epochs = int(cfg["warmup_epochs"])
     objective_params_list = cfg["train_objective_params_list"]
@@ -173,7 +165,7 @@ def main():
     )
 
     print_with_timestamp("Start trainning")
-    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type)
+    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type)
     print_with_timestamp("Training completed!")
 
 
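The deleted lines were the standard patience counter. For reference, a minimal sketch of that pattern, reconstructed from the removed code (the should_stop helper and the state dict are illustrative, not the repository's structure):

def should_stop(avg_val_loss: float, state: dict, patience: int) -> bool:
    # Classic early stopping: reset the counter whenever validation
    # loss improves; otherwise count stale epochs and stop once the
    # counter exceeds `patience`.
    if avg_val_loss < state["best_val_loss"]:
        state["best_val_loss"] = avg_val_loss
        state["patience_counter"] = 0
        return False
    state["patience_counter"] += 1
    return state["patience_counter"] > patience

With this gone, fit always runs the full epochs and keeps only the best-checkpoint tracking.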
@@ -21,11 +21,13 @@ from MOAFDatasets import MOAFDataset
 from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
 
 
-def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
+def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
     model.train()
     train_loss = 0.0
 
-    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
+    data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]") if rank in {-1, 0} else train_loader
+
+    for data in data_iter:
         images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
         params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
 
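This gate (repeated in valid_epoch in the next hunk) could be factored into a tiny helper; a minimal sketch, where the name maybe_pbar is hypothetical and rank == -1 conventionally marks a non-distributed run:

from tqdm import tqdm

def maybe_pbar(iterable, desc, rank):
    # Show a progress bar only on the main process (rank 0) or in a
    # non-distributed run (rank -1); all other ranks iterate silently,
    # so N processes do not print N interleaved bars.
    return tqdm(iterable, desc=desc) if rank in {-1, 0} else iterable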
@@ -39,12 +41,14 @@ def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
     return train_loss
 
 
-def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
+def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank):
     model.eval()
     val_loss = 0.0
 
+    data_iter = tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]") if rank in {-1, 0} else val_loader
+
     with torch.no_grad():
-        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
+        for data in data_iter:
             images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
             params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
 
@@ -56,12 +60,6 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
 
 
 def fit(rank, world_size, cfg):
-    """
-    Main function run by each process (one GPU per process).
-    rank: this process's global rank (0 to world_size-1)
-    world_size: total number of processes (usually the number of available GPUs)
-    cfg: the config dict read from the toml file
-    """
     # Initialize distributed parameters
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
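The MASTER_ADDR/MASTER_PORT lines are the rendezvous half of DDP initialization; the continuation presumably looks something like this sketch (assumed, not shown in the hunk):

import os
import torch
import torch.distributed as dist

def ddp_setup(rank: int, world_size: int, master_port: str = "29500"):
    # Single-node rendezvous: every process reads the same address and
    # port from the environment, then joins the NCCL process group.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = master_port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # pin this process to its own GPU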
@@ -78,7 +76,6 @@ def fit(rank, world_size, cfg):
     batch_size = int(cfg["batch_size"])
     num_workers = int(cfg["num_workers"])
     lr = float(cfg["lr"])
-    patience = int(cfg["patience"])
     epochs = int(cfg["epochs"])
     warmup_epochs = int(cfg["warmup_epochs"])
     objective_params_list = cfg["train_objective_params_list"]
@@ -143,7 +140,6 @@ def fit(rank, world_size, cfg):
     # Show the model graph on TensorBoard
     if rank == 0:
         tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
-        # tensorboard graph: use a small dummy input placed on correct device
         dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
         tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
 
@@ -152,7 +148,6 @@ def fit(rank, world_size, cfg):
 
     # Training
     best_val_loss = float('inf')
-    patience_counter = 0
 
     if rank == 0:
         print_with_timestamp("Start training (DDP)")
@@ -163,8 +158,8 @@ def fit(rank, world_size, cfg):
         val_sampler.set_epoch(epoch)
 
         start_time = time.time()
-        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
-        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
+        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank) / len(train_loader)
+        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn, rank) / len(val_loader)
         current_lr = optimizer.param_groups[0]['lr']
         scheduler.step()
         epoch_time = time.time() - start_time
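Worth noting: under a DistributedSampler, len(train_loader) and len(val_loader) are per-rank lengths, so these averages are each rank's local loss over its own shard; nothing in this hunk all-reduces them across ranks, and rank 0's logged values therefore reflect only its shard.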
@@ -184,7 +179,6 @@ def fit(rank, world_size, cfg):
 
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
-            patience_counter = 0
             save_dict = {
                 "epoch": epoch,
                 # Save module.state_dict() (use .module when wrapped in DDP)
@@ -197,11 +191,6 @@ def fit(rank, world_size, cfg):
             Path("ckpts").mkdir(exist_ok=True, parents=True)
             torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
             print_with_timestamp(f"New best model saved at epoch {epoch+1}")
-        else:
-            patience_counter += 1
-            if patience_counter > patience:
-                print_with_timestamp(f"Early stopping at {epoch+1} epochs")
-                break
 
     # Clean up processes
     if tb_writer is not None:
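A fit(rank, world_size, cfg) signature matches torch.multiprocessing.spawn's calling convention, so the launcher is presumably along these lines (a sketch; the module name MOAFTrainDDP is a guess, since the capture omits this file's name):

import tomllib  # stdlib in Python 3.11+; the repo reads a toml config

import torch
import torch.multiprocessing as mp

from MOAFTrainDDP import fit  # hypothetical module name

if __name__ == "__main__":
    with open("configs/config_example.toml", "rb") as f:
        cfg = tomllib.load(f)
    world_size = torch.cuda.device_count()  # one process per GPU
    # mp.spawn supplies the process index (rank) as the first argument
    # and appends the args tuple, yielding fit(rank, world_size, cfg).
    mp.spawn(fit, args=(world_size, cfg), nprocs=world_size)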
@@ -6,7 +6,6 @@ dataset_dir = "F:/Datasets/MODatasetD"
 batch_size = 64
 num_workers = 8
 lr = 1e-4
-patience = 5
 epochs = 5
 warmup_epochs = 1
 # Other