From f54d81770202a3be9c056cf917434744f7ed432d Mon Sep 17 00:00:00 2001 From: kaiza_hikaru Date: Thu, 23 Oct 2025 16:40:40 +0800 Subject: [PATCH] first commit --- .gitignore | 7 ++ MOAFDatasets.py | 160 +++++++++++++++++++++++++++++ MOAFModels.py | 240 ++++++++++++++++++++++++++++++++++++++++++++ MOAFTest.py | 88 ++++++++++++++++ MOAFTrain.py | 178 ++++++++++++++++++++++++++++++++ MOAFTrainDDP.py | 239 +++++++++++++++++++++++++++++++++++++++++++ MOAFUtils.py | 21 ++++ config_example.toml | 26 +++++ 8 files changed, 959 insertions(+) create mode 100644 .gitignore create mode 100644 MOAFDatasets.py create mode 100644 MOAFModels.py create mode 100644 MOAFTest.py create mode 100644 MOAFTrain.py create mode 100644 MOAFTrainDDP.py create mode 100644 MOAFUtils.py create mode 100644 config_example.toml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4ae37e6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +__pycache__/ +ckpts/ +configs/ +runs/ +results/ +models.ipynb +ShuffleNetV2.txt \ No newline at end of file diff --git a/MOAFDatasets.py b/MOAFDatasets.py new file mode 100644 index 0000000..b394099 --- /dev/null +++ b/MOAFDatasets.py @@ -0,0 +1,160 @@ +import pandas as pd +from pathlib import Path +from torch.utils.data import Dataset +from PIL import Image +import torch +from torchvision import transforms + + +class MOAFDataset(Dataset): + def __init__(self, dataset_root, tvt='train', objectives_params_list=None, output_type='distance'): + """ + dataset_root: 根目录(Pathable) + tvt: 'train'|'val'|'test'(用于选择 transform) + objectives_params_list: 列表,包含要加载的物镜目录名,例如 ["10x-0.25-1.0000", ...] + output_type: 'distance'(返回 nm)或 'ratio'(返回 defocus / DoF) + """ + super().__init__() + self.dataset_root = Path(dataset_root) + self.tvt = tvt + if objectives_params_list is None: + self.objectives_params_list = ["10x-0.25-1.0000"] + else: + self.objectives_params_list = objectives_params_list + + # 处理 output_type,非法输入回退到 'distance' + if isinstance(output_type, str) and output_type.lower() == "ratio": + self.output_type = "ratio" + else: + self.output_type = "distance" + + # 根据 tvt 选择 transform + if self.tvt == "train": + self.transform = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + else: + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # 直接在构造函数中读取并合并 csv + all_dfs = [] + for param_dir in self.objectives_params_list: + csv_file_path = self.dataset_root / "tvtinfo" / param_dir / f"{self.tvt}.csv" + if not csv_file_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_file_path}") + df = pd.read_csv(csv_file_path) + all_dfs.append(df) + + if len(all_dfs) == 0: + raise ValueError("No csv files were loaded. Check objectives_params_list and dataset_root.") + + combined_df = pd.concat(all_dfs, ignore_index=True) + + # 过滤 relative 范围 + self.dataframe = combined_df[(combined_df["relative"] >= -50) & (combined_df["relative"] <= 50)].reset_index(drop=True) + + def __len__(self): + return len(self.dataframe) + + @staticmethod + def _compute_dof_nm(mag, na, rix, wavelength_nm=550.0, pixel_size_nm=3450.0): + """ + 公式: DoF = lambda * n / (NA ** 2) + (n * e) / (M * NA) + 输入参数均为标量(float),返回 DoF(nm) + """ + # 防止除以零 + if na == 0 or mag == 0: + return float('inf') + lam = float(wavelength_nm) + n = float(rix) + M = float(mag) + NA = float(na) + e = float(pixel_size_nm) + dof = (lam * n) / (NA ** 2) + ((n * e) / (M * NA)) + return dof + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + row = self.dataframe.iloc[idx] + img_path = self.dataset_root / row['path'] + image = Image.open(img_path).convert('RGB') + + # 基本数值字段(注意 CSV 列名需匹配) + mag = float(row['mag']) + na = float(row['na']) + rix = float(row['rix']) + label_nm = float(row['label']) + + image = self.transform(image) + + mag_tensor = torch.tensor(mag, dtype=torch.float32) + na_tensor = torch.tensor(na, dtype=torch.float32) + rix_tensor = torch.tensor(rix, dtype=torch.float32) + label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32) + + # min-max 归一化输入参数 + mag_tensor = (mag_tensor - 10) / (100 - 10) + na_tensor = (na_tensor - 0) / (1.25 - 0) + rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0) + + # 根据 output_type 决定输出 label + if self.output_type == "ratio": + dof_nm = self._compute_dof_nm(mag=mag, na=na, rix=rix, wavelength_nm=550.0, pixel_size_nm=3450.0) + # 若 DOF 为 inf 或极大,避免除零 + if not (dof_nm is None or dof_nm == float('inf') or dof_nm == 0): + label_out = label_nm / dof_nm + else: + label_out = label_nm # 回退,虽然不太可能 + label_out_tensor = torch.tensor(float(label_out), dtype=torch.float32) + else: + # distance 模式:直接返回 nm + label_out_tensor = label_nm_tensor + + sample = { + 'image': image, + 'mag': mag_tensor, + 'na': na_tensor, + 'rix': rix_tensor, + 'label': label_out_tensor, + 'path': img_path.as_posix(), + } + + return sample + + def get_dataframe(self): + return self.dataframe + + +if __name__ == "__main__": + # 简单测试 + train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train', + objectives_params_list=[ + "100x-1.25-1.4730", + ], + output_type='ratio') + from torch.utils.data import DataLoader + train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2) + for batch in train_loader: + images = batch["image"] + labels = batch["label"] + print(f"images.shape: {images.shape}, labels.shape: {labels.shape}") + mags = batch["mag"] + nas = batch["na"] + rixs = batch["rix"] + print(f"mags: {mags}, nas: {nas}, rixs: {rixs}") + print(f"mags.shape: {mags.shape}, nas.shape: {nas.shape}, rixs.shape: {rixs.shape}") + params = torch.stack((mags, nas, rixs), dim=1) + print(f"params shape: {params.shape}") + print("first labels:") + for i in range(min(4, labels.shape[0])): + print(labels[i].item()) + break diff --git a/MOAFModels.py b/MOAFModels.py new file mode 100644 index 0000000..52eb3dd --- /dev/null +++ b/MOAFModels.py @@ -0,0 +1,240 @@ +import torch +from torch import nn +from torchvision import models + + +# 无融合模型 +class MOAFNoFusion(nn.Module): + def __init__(self): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.features = nn.Sequential( + shuff.conv1, shuff.maxpool, + shuff.stage2, shuff.stage3, + shuff.stage4, shuff.conv5 + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.features(image) + x = self.avgpool(x) + x = self.regressor(x) + return x + + +# 参数嵌入模型 +class ParamEmbedding(nn.Module): + def __init__(self, in_dim=3, hidden_dim=64, out_dim=128): + super().__init__() + self.embedding = nn.Sequential( + nn.Linear(in_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, out_dim), + nn.ReLU(), + nn.LayerNorm(out_dim) + ) + + def forward(self, x): + return self.embedding(x) + + +# FiLM 融合块 +class FiLMBlock(nn.Module): + def __init__(self, param_emb_dim=128, feat_channels=128): + super().__init__() + self.gamma_gen = nn.Linear(param_emb_dim, feat_channels) + self.beta_gen = nn.Linear(param_emb_dim, feat_channels) + + def forward(self, feature_map, param_emb): + gamma = self.gamma_gen(param_emb).unsqueeze(-1).unsqueeze(-1) + beta = self.beta_gen(param_emb).unsqueeze(-1).unsqueeze(-1) + return feature_map * (1.0 + gamma) + beta + + +# 通道交叉注意力融合块 +class ChannelCrossAttention(nn.Module): + def __init__(self, param_emb_dim=128, feat_channels=128): + super().__init__() + self.param_emb_dim = param_emb_dim + self.feat_channels = feat_channels + self.hidden_dim = max(16, self.feat_channels // 4) + + self.global_pool = nn.AdaptiveAvgPool2d(1) + + self.bottle_neck = nn.Sequential( + nn.Linear(self.param_emb_dim + self.feat_channels, self.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, self.feat_channels), + nn.Sigmoid() + ) + + def forward(self, feature_map, param_emb): + b, c, h, w = feature_map.shape + pooled = self.global_pool(feature_map).view(b, c) + cat = torch.cat([pooled, param_emb], dim=1) + weights = self.bottle_neck(cat) + weights4d = weights.view(b, c, 1, 1) + return feature_map * (1.0 + weights4d) + + +# SE块 +class SEBlock(nn.Module): + def __init__(self, feat_channels=128): + super().__init__() + self.feat_channels = feat_channels + self.hidden_dim = max(16, self.feat_channels // 4) + + self.global_pool = nn.AdaptiveAvgPool2d(1) + + self.bottle_neck = nn.Sequential( + nn.Linear(self.feat_channels, self.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, self.feat_channels), + nn.Sigmoid() + ) + + def forward(self, feature_map): + b, c, h, w = feature_map.shape + pooled = self.global_pool(feature_map).view(b, c) + weights = self.bottle_neck(pooled) + weights4d = weights.view(b, c, 1, 1) + return feature_map * (1.0 + weights4d) + + +# 仅返回特征图的恒等变换 +class FusionIdentity(nn.Module): + def forward(self, feature_map, param_emb): + return feature_map + + +# 使用 FiLM 融合的模型 +class MOAFWithFiLM(nn.Module): + def __init__(self, fusion_level=None): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) + self.stage2 = shuff.stage2 + self.stage3 = shuff.stage3 + self.stage4 = shuff.stage4 + self.cbr2 = shuff.conv5 + + self.param_embedding = ParamEmbedding() + + if fusion_level is None: + fusion_level = [2] + + self.film_block0 = FiLMBlock(feat_channels=48) if 0 in fusion_level else FusionIdentity() + self.film_block1 = FiLMBlock(feat_channels=96) if 1 in fusion_level else FusionIdentity() + self.film_block2 = FiLMBlock(feat_channels=192) if 2 in fusion_level else FusionIdentity() + + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.cbrm(image) + x = self.stage2(x) + param_emb = self.param_embedding(params) + x = self.film_block0(x, param_emb) + x = self.stage3(x) + x = self.film_block1(x, param_emb) + x = self.stage4(x) + x = self.film_block2(x, param_emb) + x = self.cbr2(x) + x = self.avgpool(x) + x = self.regressor(x) + return x + + +# 使用交叉注意力融合的模型 +class MOAFWithChannelCrossAttention(nn.Module): + def __init__(self, fusion_level=None): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) + self.stage2 = shuff.stage2 + self.stage3 = shuff.stage3 + self.stage4 = shuff.stage4 + self.cbr2 = shuff.conv5 + + self.param_embedding = ParamEmbedding() + + if fusion_level is None: + fusion_level = [2] + + self.cca_block0 = ChannelCrossAttention(feat_channels=48) if 0 in fusion_level else FusionIdentity() + self.cca_block1 = ChannelCrossAttention(feat_channels=96) if 1 in fusion_level else FusionIdentity() + self.cca_block2 = ChannelCrossAttention(feat_channels=192) if 2 in fusion_level else FusionIdentity() + + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.cbrm(image) + x = self.stage2(x) + param_emb = self.param_embedding(params) + x = self.cca_block0(x, param_emb) + x = self.stage3(x) + x = self.cca_block1(x, param_emb) + x = self.stage4(x) + x = self.cca_block2(x, param_emb) + x = self.cbr2(x) + x = self.avgpool(x) + x = self.regressor(x) + return x + + +# 使用 SE 块但无融合的模型 +class MOAFWithSE(nn.Module): + def __init__(self, fusion_level=None): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) + self.stage2 = shuff.stage2 + self.stage3 = shuff.stage3 + self.stage4 = shuff.stage4 + self.cbr2 = shuff.conv5 + + if fusion_level is None: + fusion_level = [0] + + self.se_block0 = SEBlock(feat_channels=48) if 0 in fusion_level else nn.Identity() + self.se_block1 = SEBlock(feat_channels=96) if 1 in fusion_level else nn.Identity() + self.se_block2 = SEBlock(feat_channels=192) if 2 in fusion_level else nn.Identity() + + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.cbrm(image) + x = self.stage2(x) + x = self.se_block0(x) + x = self.stage3(x) + x = self.se_block1(x) + x = self.stage4(x) + x = self.se_block2(x) + x = self.cbr2(x) + x = self.avgpool(x) + x = self.regressor(x) + return x + diff --git a/MOAFTest.py b/MOAFTest.py new file mode 100644 index 0000000..2b25f8c --- /dev/null +++ b/MOAFTest.py @@ -0,0 +1,88 @@ +import torch +from torch.utils.data import DataLoader + +import tomllib +import argparse +import numpy as np +from tqdm import tqdm +from pathlib import Path + +from MOAFUtils import print_with_timestamp +from MOAFDatasets import MOAFDataset +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE + + +def test(model, test_loader, device, model_type, output_type): + model.eval() + results = [] + + with torch.no_grad(): + for data in tqdm(test_loader, desc=f"[Test]"): + images = data["image"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + outputs = model(images, params) + results.extend(outputs.squeeze(1).cpu().numpy()) + + results = np.array(results) + + # 直接预测结果保存为 excel 文件 + df = test_loader.dataset.get_dataframe() + df["pred"] = results + Path("results").mkdir(exist_ok=True, parents=True) + # !pip install openpyxl + df.to_excel(f"results/{model_type}_{output_type}.xlsx", index=False) + + +def main(): + # 解析命令行参数 + parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.") + parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)") + args = parser.parse_args() + with open(args.config, "rb") as f: + cfg = tomllib.load(f) + + # 确定超参数 + model_type = cfg["model_type"] + output_type = cfg["output_type"] + dataset_dir = cfg["dataset_dir"] + batch_size = int(cfg["batch_size"]) + num_workers = int(cfg["num_workers"]) + objective_params_list = cfg["test_objective_params_list"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + print_with_timestamp(f"Using device {device}") + + # 加载数据集 + test_set = MOAFDataset(dataset_dir, "test", objective_params_list, output_type) + print_with_timestamp("Dataset Done") + + test_loader = DataLoader( + test_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True + ) + print_with_timestamp("Dataset Loaded") + + # 模型选择 + if "film" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[4:]] + model = MOAFWithFiLM(fusion_depth_list).to(device) + elif "cca" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[3:]] + model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device) + elif "se" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[2:]] + model = MOAFWithSE(fusion_depth_list).to(device) + else: + model = MOAFNoFusion().to(device) + + checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + print_with_timestamp("Model Loaded") + + print_with_timestamp("Start testing") + test(model, test_loader, device, model_type, output_type) + print_with_timestamp("Testing completed!") + + +if __name__ == "__main__": + main() diff --git a/MOAFTrain.py b/MOAFTrain.py new file mode 100644 index 0000000..c5d6578 --- /dev/null +++ b/MOAFTrain.py @@ -0,0 +1,178 @@ +import torch +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import time +import tomllib +import argparse +import numpy as np +from tqdm import tqdm +from pathlib import Path + +from MOAFUtils import print_with_timestamp +from MOAFDatasets import MOAFDataset +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE + + +def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn): + model.train() + train_loss = 0.0 + + for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + optimizer.zero_grad() + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + loss.backward() + optimizer.step() + train_loss += loss.item() + + return train_loss + + +def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn): + model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + val_loss += loss.item() + + return val_loss + + +def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type): + best_val_loss = float('inf') + patience_counter = 0 + # !pip install tensorboard + with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer: + # Tensorboard 上显示模型结构 + dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device) + writer.add_graph(model, (dummy_input1, dummy_input2)) + + for epoch in range(epochs): + # 训练/验证 + start_time = time.time() + avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader) + avg_val_loss = valid_epoch(model,val_loader, epoch, epochs, device, loss_fn) / len(val_loader) + current_lr = optimizer.param_groups[0]['lr'] + scheduler.step() + epoch_time = time.time() - start_time + + # 显示记录数据 + writer.add_scalar('Loss/train', avg_train_loss, epoch) + writer.add_scalar('Loss/val', avg_val_loss, epoch) + writer.add_scalar('LearningRate', current_lr, epoch) + writer.add_scalar('Time/epoch', epoch_time, epoch) + + print_with_timestamp(f"Epoch {epoch+1}/{epochs} | " + f"Train Loss: {avg_train_loss:.6f} | " + f"Val Loss: {avg_val_loss:.6f} | " + f"LR: {current_lr:.2e} | " + f"Time: {epoch_time:.2f}s") + + # 记录检查点 + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + patience_counter = 0 + save_dict = { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "train_loss": avg_train_loss, + "val_loss": avg_val_loss + } + Path("ckpts").mkdir(exist_ok=True, parents=True) + torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt") + print_with_timestamp(f"New best model saved at epoch {epoch+1}") + else: + patience_counter += 1 + if patience_counter > patience: + print_with_timestamp(f"Early stopping at {epoch+1} epochs") + break + + +def main(): + # 解析命令行参数 + parser = argparse.ArgumentParser(description="Train script that loads config from a TOML file.") + parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)") + args = parser.parse_args() + with open(args.config, "rb") as f: + cfg = tomllib.load(f) + + # 确定超参数 + model_type = cfg["model_type"] + output_type = cfg["output_type"] + dataset_dir = cfg["dataset_dir"] + batch_size = int(cfg["batch_size"]) + num_workers = int(cfg["num_workers"]) + lr = float(cfg["lr"]) + patience = int(cfg["patience"]) + epochs = int(cfg["epochs"]) + warmup_epochs = int(cfg["warmup_epochs"]) + objective_params_list = cfg["train_objective_params_list"] + checkpoint_load = cfg["checkpoint_load"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + print_with_timestamp(f"Using device {device}") + + # 加载数据集 + train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type) + val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type) + print_with_timestamp("Dataset Done") + + train_loader = DataLoader( + train_set, batch_size=batch_size, num_workers=num_workers, + shuffle=True, pin_memory=True, persistent_workers=True + ) + val_loader = DataLoader( + val_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True + ) + print_with_timestamp("Dataset Loaded") + + # 模型选择 + if "film" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[4:]] + model = MOAFWithFiLM(fusion_depth_list).to(device) + elif "cca" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[3:]] + model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device) + elif "se" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[2:]] + model = MOAFWithSE(fusion_depth_list).to(device) + else: + model = MOAFNoFusion().to(device) + print_with_timestamp("Model Loaded") + + # 形式化预训练参数加载 + if checkpoint_load: + checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + print_with_timestamp("Model Checkpoint Params Loaded") + + # 损失函数、优化器、学习率调度器 + loss_fn = nn.HuberLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs + else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))) + ) + + print_with_timestamp("Start trainning") + fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type) + print_with_timestamp("Training completed!") + + +if __name__ == "__main__": + main() diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py new file mode 100644 index 0000000..8ac7778 --- /dev/null +++ b/MOAFTrainDDP.py @@ -0,0 +1,239 @@ +import os +import time +import argparse +import tomllib +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from MOAFUtils import print_with_timestamp +from MOAFDatasets import MOAFDataset +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE + + +def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn): + model.train() + train_loss = 0.0 + + for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + optimizer.zero_grad() + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + loss.backward() + optimizer.step() + train_loss += loss.item() + + return train_loss + + +def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn): + model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + val_loss += loss.item() + + return val_loss + + +def fit(rank, world_size, cfg): + """ + 每个进程运行的主函数(单卡)。 + rank: 该进程的全局 rank(0 ~ world_size-1) + world_size: 进程总数(通常等于可用 GPU 数) + cfg: 从 toml 读取的配置字典 + """ + # -------- init distributed env -------- + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = cfg.get("master_port", "29500") + dist.init_process_group(backend='nccl', rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + device = torch.device(f'cuda:{rank}') + if rank == 0: + print_with_timestamp(f"Distributed initialized. World size: {world_size}") + + # -------- parse hyperparams from cfg -------- + model_type = cfg["model_type"] + output_type = cfg["output_type"] + dataset_dir = cfg["dataset_dir"] + batch_size = int(cfg["batch_size"]) + num_workers = int(cfg["num_workers"]) + lr = float(cfg["lr"]) + patience = int(cfg["patience"]) + epochs = int(cfg["epochs"]) + warmup_epochs = int(cfg["warmup_epochs"]) + objective_params_list = cfg["train_objective_params_list"] + checkpoint_load = cfg["checkpoint_load"] + + # -------- datasets & distributed sampler -------- + train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type) + val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type) + + # DistributedSampler ensures each process sees a unique subset each epoch + train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True) + val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False) + + # When using DistributedSampler, do not shuffle at DataLoader level + train_loader = DataLoader( + train_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler + ) + val_loader = DataLoader( + val_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True, sampler=val_sampler + ) + + if rank == 0: + print_with_timestamp("Dataset Loaded (Distributed)") + + # -------- model creation -------- + if "film" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[4:]] + model = MOAFWithFiLM(fusion_depth_list).to(device) + elif "cca" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[3:]] + model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device) + elif "se" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[2:]] + model = MOAFWithSE(fusion_depth_list).to(device) + else: + model = MOAFNoFusion().to(device) + + # 形式化预训练参数加载 + if checkpoint_load: + checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + print_with_timestamp("Model Checkpoint Params Loaded") + + # wrap with DDP. device_ids ensures single-GPU per process + model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False) + if rank == 0: + print_with_timestamp("Model Wrapped with DDP") + + # -------- loss / optimizer / scheduler -------- + loss_fn = nn.HuberLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs + else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))) + ) + + # -------- TensorBoard & checkpoint only on rank 0 -------- + if rank == 0: + tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") + # tensorboard graph: use a small dummy input placed on correct device + dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device) + tb_writer.add_graph(model.module, (dummy_input1, dummy_input2)) + + else: + tb_writer = None + + # -------- training loop with early stopping (only rank 0 saves checkpoints) -------- + best_val_loss = float('inf') + patience_counter = 0 + + if rank == 0: + print_with_timestamp("Start training (DDP)") + + for epoch in range(epochs): + # 必须在每个 epoch 开头设置 epoch 给 sampler,这样 shuffle 能跨 epoch 工作 + train_sampler.set_epoch(epoch) + val_sampler.set_epoch(epoch) + + start_time = time.time() + avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader) + avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader) + current_lr = optimizer.param_groups[0]['lr'] + scheduler.step() + epoch_time = time.time() - start_time + + # 只有 rank 0 写 tensorboard 和保存 checkpoint,避免重复 IO + if rank == 0: + tb_writer.add_scalar('Loss/train', avg_train_loss, epoch) + tb_writer.add_scalar('Loss/val', avg_val_loss, epoch) + tb_writer.add_scalar('LearningRate', current_lr, epoch) + tb_writer.add_scalar('Time/epoch', epoch_time, epoch) + + print_with_timestamp(f"Epoch {epoch+1}/{epochs} | " + f"Train Loss: {avg_train_loss:.6f} | " + f"Val Loss: {avg_val_loss:.6f} | " + f"LR: {current_lr:.2e} | " + f"Time: {epoch_time:.2f}s") + + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + patience_counter = 0 + save_dict = { + "epoch": epoch, + # 保存 module.state_dict()(DDP 包裹时用 module) + "model_state_dict": model.module.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "train_loss": avg_train_loss, + "val_loss": avg_val_loss + } + Path("ckpts").mkdir(exist_ok=True, parents=True) + torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt") + print_with_timestamp(f"New best model saved at epoch {epoch+1}") + else: + patience_counter += 1 + if patience_counter > patience: + print_with_timestamp(f"Early stopping at {epoch+1} epochs") + break + + # cleanup + if tb_writer is not None: + tb_writer.close() + dist.destroy_process_group() + if rank == 0: + print_with_timestamp("Training completed and distributed cleaned up") + + +def parse_args_and_cfg(): + parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.") + parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)") + parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node") + args = parser.parse_args() + with open(args.config, "rb") as f: + cfg = tomllib.load(f) + return args, cfg + + +def main(): + args, cfg = parse_args_and_cfg() + # 自动检测可用 GPU 数量,除非用户指定 nproc_per_node + available_gpus = torch.cuda.device_count() + if available_gpus == 0: + raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.") + + nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus + if nproc > available_gpus: + raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available") + + # 使用 spawn 启动 nproc 个子进程,每个进程运行 fit(rank, world_size, cfg) + # spawn 会在每个子进程中调用 fit(rank, world_size, cfg) 并传入不同的 rank + mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True) + + +if __name__ == "__main__": + main() diff --git a/MOAFUtils.py b/MOAFUtils.py new file mode 100644 index 0000000..69c4757 --- /dev/null +++ b/MOAFUtils.py @@ -0,0 +1,21 @@ +import sys +from datetime import datetime + +def print_with_timestamp(*args, sep=' ', end='\n', file=None, flush=False, time_format="%Y-%m-%d %H:%M:%S"): + """ + 与 print 类似,但在输出内容前加时间戳前缀。 + - *args: 要打印的对象(会用 sep 连接) + - sep, end, file, flush: 与内置 print 相同 + - time_format: 时间格式,默认 '%Y-%m-%d %H:%M:%S',遵循 datetime.strftime 格式 + """ + if file is None: + file = sys.stdout + + ts = datetime.now().strftime(time_format) + body = sep.join(str(a) for a in args) + output = f'{ts} --> {body}' + print(output, end=end, file=file, flush=flush) + + +if __name__ == "__main__": + print_with_timestamp(f"Test output") \ No newline at end of file diff --git a/config_example.toml b/config_example.toml new file mode 100644 index 0000000..a92112d --- /dev/null +++ b/config_example.toml @@ -0,0 +1,26 @@ +# 模型与数据 +model_type = "cca2" +output_type = "distance" +dataset_dir = "F:/Datasets/MODatasetD" +# 训练参数 +batch_size = 64 +num_workers = 8 +lr = 1e-4 +patience = 5 +epochs = 5 +warmup_epochs = 1 +# 其它 +train_objective_params_list = [ + "10x-0.25-1.0000", "10x-0.30-1.0000", + "20x-0.70-1.0000", "20x-0.80-1.0000", + "40x-0.65-1.0000", "100x-0.80-1.0000", + "100x-1.25-1.4730" +] +test_objective_params_list = [ + "10x-0.25-1.0000", "10x-0.30-1.0000", + "20x-0.70-1.0000", "20x-0.80-1.0000", + "40x-0.65-1.0000", "100x-0.80-1.0000", + "100x-1.25-1.4730" +] +# 加载形式化预训练参数 +checkpoint_load = true \ No newline at end of file