commit f54d81770202a3be9c056cf917434744f7ed432d Author: kaiza_hikaru Date: Thu Oct 23 16:40:40 2025 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4ae37e6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +__pycache__/ +ckpts/ +configs/ +runs/ +results/ +models.ipynb +ShuffleNetV2.txt \ No newline at end of file diff --git a/MOAFDatasets.py b/MOAFDatasets.py new file mode 100644 index 0000000..b394099 --- /dev/null +++ b/MOAFDatasets.py @@ -0,0 +1,160 @@ +import pandas as pd +from pathlib import Path +from torch.utils.data import Dataset +from PIL import Image +import torch +from torchvision import transforms + + +class MOAFDataset(Dataset): + def __init__(self, dataset_root, tvt='train', objectives_params_list=None, output_type='distance'): + """ + dataset_root: 根目录(Pathable) + tvt: 'train'|'val'|'test'(用于选择 transform) + objectives_params_list: 列表,包含要加载的物镜目录名,例如 ["10x-0.25-1.0000", ...] + output_type: 'distance'(返回 nm)或 'ratio'(返回 defocus / DoF) + """ + super().__init__() + self.dataset_root = Path(dataset_root) + self.tvt = tvt + if objectives_params_list is None: + self.objectives_params_list = ["10x-0.25-1.0000"] + else: + self.objectives_params_list = objectives_params_list + + # 处理 output_type,非法输入回退到 'distance' + if isinstance(output_type, str) and output_type.lower() == "ratio": + self.output_type = "ratio" + else: + self.output_type = "distance" + + # 根据 tvt 选择 transform + if self.tvt == "train": + self.transform = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + else: + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # 直接在构造函数中读取并合并 csv + all_dfs = [] + for param_dir in self.objectives_params_list: + csv_file_path = self.dataset_root / "tvtinfo" / param_dir / f"{self.tvt}.csv" + if not csv_file_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_file_path}") + df = pd.read_csv(csv_file_path) + all_dfs.append(df) + + if len(all_dfs) == 0: + raise ValueError("No csv files were loaded. Check objectives_params_list and dataset_root.") + + combined_df = pd.concat(all_dfs, ignore_index=True) + + # 过滤 relative 范围 + self.dataframe = combined_df[(combined_df["relative"] >= -50) & (combined_df["relative"] <= 50)].reset_index(drop=True) + + def __len__(self): + return len(self.dataframe) + + @staticmethod + def _compute_dof_nm(mag, na, rix, wavelength_nm=550.0, pixel_size_nm=3450.0): + """ + 公式: DoF = lambda * n / (NA ** 2) + (n * e) / (M * NA) + 输入参数均为标量(float),返回 DoF(nm) + """ + # 防止除以零 + if na == 0 or mag == 0: + return float('inf') + lam = float(wavelength_nm) + n = float(rix) + M = float(mag) + NA = float(na) + e = float(pixel_size_nm) + dof = (lam * n) / (NA ** 2) + ((n * e) / (M * NA)) + return dof + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + row = self.dataframe.iloc[idx] + img_path = self.dataset_root / row['path'] + image = Image.open(img_path).convert('RGB') + + # 基本数值字段(注意 CSV 列名需匹配) + mag = float(row['mag']) + na = float(row['na']) + rix = float(row['rix']) + label_nm = float(row['label']) + + image = self.transform(image) + + mag_tensor = torch.tensor(mag, dtype=torch.float32) + na_tensor = torch.tensor(na, dtype=torch.float32) + rix_tensor = torch.tensor(rix, dtype=torch.float32) + label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32) + + # min-max 归一化输入参数 + mag_tensor = (mag_tensor - 10) / (100 - 10) + na_tensor = (na_tensor - 0) / (1.25 - 0) + rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0) + + # 根据 output_type 决定输出 label + if self.output_type == "ratio": + dof_nm = self._compute_dof_nm(mag=mag, na=na, rix=rix, wavelength_nm=550.0, pixel_size_nm=3450.0) + # 若 DOF 为 inf 或极大,避免除零 + if not (dof_nm is None or dof_nm == float('inf') or dof_nm == 0): + label_out = label_nm / dof_nm + else: + label_out = label_nm # 回退,虽然不太可能 + label_out_tensor = torch.tensor(float(label_out), dtype=torch.float32) + else: + # distance 模式:直接返回 nm + label_out_tensor = label_nm_tensor + + sample = { + 'image': image, + 'mag': mag_tensor, + 'na': na_tensor, + 'rix': rix_tensor, + 'label': label_out_tensor, + 'path': img_path.as_posix(), + } + + return sample + + def get_dataframe(self): + return self.dataframe + + +if __name__ == "__main__": + # 简单测试 + train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train', + objectives_params_list=[ + "100x-1.25-1.4730", + ], + output_type='ratio') + from torch.utils.data import DataLoader + train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2) + for batch in train_loader: + images = batch["image"] + labels = batch["label"] + print(f"images.shape: {images.shape}, labels.shape: {labels.shape}") + mags = batch["mag"] + nas = batch["na"] + rixs = batch["rix"] + print(f"mags: {mags}, nas: {nas}, rixs: {rixs}") + print(f"mags.shape: {mags.shape}, nas.shape: {nas.shape}, rixs.shape: {rixs.shape}") + params = torch.stack((mags, nas, rixs), dim=1) + print(f"params shape: {params.shape}") + print("first labels:") + for i in range(min(4, labels.shape[0])): + print(labels[i].item()) + break diff --git a/MOAFModels.py b/MOAFModels.py new file mode 100644 index 0000000..52eb3dd --- /dev/null +++ b/MOAFModels.py @@ -0,0 +1,240 @@ +import torch +from torch import nn +from torchvision import models + + +# 无融合模型 +class MOAFNoFusion(nn.Module): + def __init__(self): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.features = nn.Sequential( + shuff.conv1, shuff.maxpool, + shuff.stage2, shuff.stage3, + shuff.stage4, shuff.conv5 + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.features(image) + x = self.avgpool(x) + x = self.regressor(x) + return x + + +# 参数嵌入模型 +class ParamEmbedding(nn.Module): + def __init__(self, in_dim=3, hidden_dim=64, out_dim=128): + super().__init__() + self.embedding = nn.Sequential( + nn.Linear(in_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, out_dim), + nn.ReLU(), + nn.LayerNorm(out_dim) + ) + + def forward(self, x): + return self.embedding(x) + + +# FiLM 融合块 +class FiLMBlock(nn.Module): + def __init__(self, param_emb_dim=128, feat_channels=128): + super().__init__() + self.gamma_gen = nn.Linear(param_emb_dim, feat_channels) + self.beta_gen = nn.Linear(param_emb_dim, feat_channels) + + def forward(self, feature_map, param_emb): + gamma = self.gamma_gen(param_emb).unsqueeze(-1).unsqueeze(-1) + beta = self.beta_gen(param_emb).unsqueeze(-1).unsqueeze(-1) + return feature_map * (1.0 + gamma) + beta + + +# 通道交叉注意力融合块 +class ChannelCrossAttention(nn.Module): + def __init__(self, param_emb_dim=128, feat_channels=128): + super().__init__() + self.param_emb_dim = param_emb_dim + self.feat_channels = feat_channels + self.hidden_dim = max(16, self.feat_channels // 4) + + self.global_pool = nn.AdaptiveAvgPool2d(1) + + self.bottle_neck = nn.Sequential( + nn.Linear(self.param_emb_dim + self.feat_channels, self.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, self.feat_channels), + nn.Sigmoid() + ) + + def forward(self, feature_map, param_emb): + b, c, h, w = feature_map.shape + pooled = self.global_pool(feature_map).view(b, c) + cat = torch.cat([pooled, param_emb], dim=1) + weights = self.bottle_neck(cat) + weights4d = weights.view(b, c, 1, 1) + return feature_map * (1.0 + weights4d) + + +# SE块 +class SEBlock(nn.Module): + def __init__(self, feat_channels=128): + super().__init__() + self.feat_channels = feat_channels + self.hidden_dim = max(16, self.feat_channels // 4) + + self.global_pool = nn.AdaptiveAvgPool2d(1) + + self.bottle_neck = nn.Sequential( + nn.Linear(self.feat_channels, self.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, self.feat_channels), + nn.Sigmoid() + ) + + def forward(self, feature_map): + b, c, h, w = feature_map.shape + pooled = self.global_pool(feature_map).view(b, c) + weights = self.bottle_neck(pooled) + weights4d = weights.view(b, c, 1, 1) + return feature_map * (1.0 + weights4d) + + +# 仅返回特征图的恒等变换 +class FusionIdentity(nn.Module): + def forward(self, feature_map, param_emb): + return feature_map + + +# 使用 FiLM 融合的模型 +class MOAFWithFiLM(nn.Module): + def __init__(self, fusion_level=None): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) + self.stage2 = shuff.stage2 + self.stage3 = shuff.stage3 + self.stage4 = shuff.stage4 + self.cbr2 = shuff.conv5 + + self.param_embedding = ParamEmbedding() + + if fusion_level is None: + fusion_level = [2] + + self.film_block0 = FiLMBlock(feat_channels=48) if 0 in fusion_level else FusionIdentity() + self.film_block1 = FiLMBlock(feat_channels=96) if 1 in fusion_level else FusionIdentity() + self.film_block2 = FiLMBlock(feat_channels=192) if 2 in fusion_level else FusionIdentity() + + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.cbrm(image) + x = self.stage2(x) + param_emb = self.param_embedding(params) + x = self.film_block0(x, param_emb) + x = self.stage3(x) + x = self.film_block1(x, param_emb) + x = self.stage4(x) + x = self.film_block2(x, param_emb) + x = self.cbr2(x) + x = self.avgpool(x) + x = self.regressor(x) + return x + + +# 使用交叉注意力融合的模型 +class MOAFWithChannelCrossAttention(nn.Module): + def __init__(self, fusion_level=None): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) + self.stage2 = shuff.stage2 + self.stage3 = shuff.stage3 + self.stage4 = shuff.stage4 + self.cbr2 = shuff.conv5 + + self.param_embedding = ParamEmbedding() + + if fusion_level is None: + fusion_level = [2] + + self.cca_block0 = ChannelCrossAttention(feat_channels=48) if 0 in fusion_level else FusionIdentity() + self.cca_block1 = ChannelCrossAttention(feat_channels=96) if 1 in fusion_level else FusionIdentity() + self.cca_block2 = ChannelCrossAttention(feat_channels=192) if 2 in fusion_level else FusionIdentity() + + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.cbrm(image) + x = self.stage2(x) + param_emb = self.param_embedding(params) + x = self.cca_block0(x, param_emb) + x = self.stage3(x) + x = self.cca_block1(x, param_emb) + x = self.stage4(x) + x = self.cca_block2(x, param_emb) + x = self.cbr2(x) + x = self.avgpool(x) + x = self.regressor(x) + return x + + +# 使用 SE 块但无融合的模型 +class MOAFWithSE(nn.Module): + def __init__(self, fusion_level=None): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool) + self.stage2 = shuff.stage2 + self.stage3 = shuff.stage3 + self.stage4 = shuff.stage4 + self.cbr2 = shuff.conv5 + + if fusion_level is None: + fusion_level = [0] + + self.se_block0 = SEBlock(feat_channels=48) if 0 in fusion_level else nn.Identity() + self.se_block1 = SEBlock(feat_channels=96) if 1 in fusion_level else nn.Identity() + self.se_block2 = SEBlock(feat_channels=192) if 2 in fusion_level else nn.Identity() + + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.regressor = nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, image, params): + x = self.cbrm(image) + x = self.stage2(x) + x = self.se_block0(x) + x = self.stage3(x) + x = self.se_block1(x) + x = self.stage4(x) + x = self.se_block2(x) + x = self.cbr2(x) + x = self.avgpool(x) + x = self.regressor(x) + return x + diff --git a/MOAFTest.py b/MOAFTest.py new file mode 100644 index 0000000..2b25f8c --- /dev/null +++ b/MOAFTest.py @@ -0,0 +1,88 @@ +import torch +from torch.utils.data import DataLoader + +import tomllib +import argparse +import numpy as np +from tqdm import tqdm +from pathlib import Path + +from MOAFUtils import print_with_timestamp +from MOAFDatasets import MOAFDataset +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE + + +def test(model, test_loader, device, model_type, output_type): + model.eval() + results = [] + + with torch.no_grad(): + for data in tqdm(test_loader, desc=f"[Test]"): + images = data["image"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + outputs = model(images, params) + results.extend(outputs.squeeze(1).cpu().numpy()) + + results = np.array(results) + + # 直接预测结果保存为 excel 文件 + df = test_loader.dataset.get_dataframe() + df["pred"] = results + Path("results").mkdir(exist_ok=True, parents=True) + # !pip install openpyxl + df.to_excel(f"results/{model_type}_{output_type}.xlsx", index=False) + + +def main(): + # 解析命令行参数 + parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.") + parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)") + args = parser.parse_args() + with open(args.config, "rb") as f: + cfg = tomllib.load(f) + + # 确定超参数 + model_type = cfg["model_type"] + output_type = cfg["output_type"] + dataset_dir = cfg["dataset_dir"] + batch_size = int(cfg["batch_size"]) + num_workers = int(cfg["num_workers"]) + objective_params_list = cfg["test_objective_params_list"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + print_with_timestamp(f"Using device {device}") + + # 加载数据集 + test_set = MOAFDataset(dataset_dir, "test", objective_params_list, output_type) + print_with_timestamp("Dataset Done") + + test_loader = DataLoader( + test_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True + ) + print_with_timestamp("Dataset Loaded") + + # 模型选择 + if "film" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[4:]] + model = MOAFWithFiLM(fusion_depth_list).to(device) + elif "cca" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[3:]] + model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device) + elif "se" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[2:]] + model = MOAFWithSE(fusion_depth_list).to(device) + else: + model = MOAFNoFusion().to(device) + + checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + print_with_timestamp("Model Loaded") + + print_with_timestamp("Start testing") + test(model, test_loader, device, model_type, output_type) + print_with_timestamp("Testing completed!") + + +if __name__ == "__main__": + main() diff --git a/MOAFTrain.py b/MOAFTrain.py new file mode 100644 index 0000000..c5d6578 --- /dev/null +++ b/MOAFTrain.py @@ -0,0 +1,178 @@ +import torch +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import time +import tomllib +import argparse +import numpy as np +from tqdm import tqdm +from pathlib import Path + +from MOAFUtils import print_with_timestamp +from MOAFDatasets import MOAFDataset +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE + + +def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn): + model.train() + train_loss = 0.0 + + for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + optimizer.zero_grad() + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + loss.backward() + optimizer.step() + train_loss += loss.item() + + return train_loss + + +def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn): + model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + val_loss += loss.item() + + return val_loss + + +def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type): + best_val_loss = float('inf') + patience_counter = 0 + # !pip install tensorboard + with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer: + # Tensorboard 上显示模型结构 + dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device) + writer.add_graph(model, (dummy_input1, dummy_input2)) + + for epoch in range(epochs): + # 训练/验证 + start_time = time.time() + avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader) + avg_val_loss = valid_epoch(model,val_loader, epoch, epochs, device, loss_fn) / len(val_loader) + current_lr = optimizer.param_groups[0]['lr'] + scheduler.step() + epoch_time = time.time() - start_time + + # 显示记录数据 + writer.add_scalar('Loss/train', avg_train_loss, epoch) + writer.add_scalar('Loss/val', avg_val_loss, epoch) + writer.add_scalar('LearningRate', current_lr, epoch) + writer.add_scalar('Time/epoch', epoch_time, epoch) + + print_with_timestamp(f"Epoch {epoch+1}/{epochs} | " + f"Train Loss: {avg_train_loss:.6f} | " + f"Val Loss: {avg_val_loss:.6f} | " + f"LR: {current_lr:.2e} | " + f"Time: {epoch_time:.2f}s") + + # 记录检查点 + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + patience_counter = 0 + save_dict = { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "train_loss": avg_train_loss, + "val_loss": avg_val_loss + } + Path("ckpts").mkdir(exist_ok=True, parents=True) + torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt") + print_with_timestamp(f"New best model saved at epoch {epoch+1}") + else: + patience_counter += 1 + if patience_counter > patience: + print_with_timestamp(f"Early stopping at {epoch+1} epochs") + break + + +def main(): + # 解析命令行参数 + parser = argparse.ArgumentParser(description="Train script that loads config from a TOML file.") + parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)") + args = parser.parse_args() + with open(args.config, "rb") as f: + cfg = tomllib.load(f) + + # 确定超参数 + model_type = cfg["model_type"] + output_type = cfg["output_type"] + dataset_dir = cfg["dataset_dir"] + batch_size = int(cfg["batch_size"]) + num_workers = int(cfg["num_workers"]) + lr = float(cfg["lr"]) + patience = int(cfg["patience"]) + epochs = int(cfg["epochs"]) + warmup_epochs = int(cfg["warmup_epochs"]) + objective_params_list = cfg["train_objective_params_list"] + checkpoint_load = cfg["checkpoint_load"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + print_with_timestamp(f"Using device {device}") + + # 加载数据集 + train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type) + val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type) + print_with_timestamp("Dataset Done") + + train_loader = DataLoader( + train_set, batch_size=batch_size, num_workers=num_workers, + shuffle=True, pin_memory=True, persistent_workers=True + ) + val_loader = DataLoader( + val_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True + ) + print_with_timestamp("Dataset Loaded") + + # 模型选择 + if "film" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[4:]] + model = MOAFWithFiLM(fusion_depth_list).to(device) + elif "cca" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[3:]] + model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device) + elif "se" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[2:]] + model = MOAFWithSE(fusion_depth_list).to(device) + else: + model = MOAFNoFusion().to(device) + print_with_timestamp("Model Loaded") + + # 形式化预训练参数加载 + if checkpoint_load: + checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + print_with_timestamp("Model Checkpoint Params Loaded") + + # 损失函数、优化器、学习率调度器 + loss_fn = nn.HuberLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs + else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))) + ) + + print_with_timestamp("Start trainning") + fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type) + print_with_timestamp("Training completed!") + + +if __name__ == "__main__": + main() diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py new file mode 100644 index 0000000..8ac7778 --- /dev/null +++ b/MOAFTrainDDP.py @@ -0,0 +1,239 @@ +import os +import time +import argparse +import tomllib +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from MOAFUtils import print_with_timestamp +from MOAFDatasets import MOAFDataset +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE + + +def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn): + model.train() + train_loss = 0.0 + + for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + optimizer.zero_grad() + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + loss.backward() + optimizer.step() + train_loss += loss.item() + + return train_loss + + +def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn): + model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"): + images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True) + params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True) + + outputs = model(images, params) + loss = loss_fn(outputs.squeeze(1), labels) + val_loss += loss.item() + + return val_loss + + +def fit(rank, world_size, cfg): + """ + 每个进程运行的主函数(单卡)。 + rank: 该进程的全局 rank(0 ~ world_size-1) + world_size: 进程总数(通常等于可用 GPU 数) + cfg: 从 toml 读取的配置字典 + """ + # -------- init distributed env -------- + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = cfg.get("master_port", "29500") + dist.init_process_group(backend='nccl', rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + device = torch.device(f'cuda:{rank}') + if rank == 0: + print_with_timestamp(f"Distributed initialized. World size: {world_size}") + + # -------- parse hyperparams from cfg -------- + model_type = cfg["model_type"] + output_type = cfg["output_type"] + dataset_dir = cfg["dataset_dir"] + batch_size = int(cfg["batch_size"]) + num_workers = int(cfg["num_workers"]) + lr = float(cfg["lr"]) + patience = int(cfg["patience"]) + epochs = int(cfg["epochs"]) + warmup_epochs = int(cfg["warmup_epochs"]) + objective_params_list = cfg["train_objective_params_list"] + checkpoint_load = cfg["checkpoint_load"] + + # -------- datasets & distributed sampler -------- + train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type) + val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type) + + # DistributedSampler ensures each process sees a unique subset each epoch + train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True) + val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False) + + # When using DistributedSampler, do not shuffle at DataLoader level + train_loader = DataLoader( + train_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler + ) + val_loader = DataLoader( + val_set, batch_size=batch_size, num_workers=num_workers, + shuffle=False, pin_memory=True, persistent_workers=True, sampler=val_sampler + ) + + if rank == 0: + print_with_timestamp("Dataset Loaded (Distributed)") + + # -------- model creation -------- + if "film" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[4:]] + model = MOAFWithFiLM(fusion_depth_list).to(device) + elif "cca" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[3:]] + model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device) + elif "se" in model_type: + fusion_depth_list = [int(ch) for ch in model_type[2:]] + model = MOAFWithSE(fusion_depth_list).to(device) + else: + model = MOAFNoFusion().to(device) + + # 形式化预训练参数加载 + if checkpoint_load: + checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + print_with_timestamp("Model Checkpoint Params Loaded") + + # wrap with DDP. device_ids ensures single-GPU per process + model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False) + if rank == 0: + print_with_timestamp("Model Wrapped with DDP") + + # -------- loss / optimizer / scheduler -------- + loss_fn = nn.HuberLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs + else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))) + ) + + # -------- TensorBoard & checkpoint only on rank 0 -------- + if rank == 0: + tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") + # tensorboard graph: use a small dummy input placed on correct device + dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device) + tb_writer.add_graph(model.module, (dummy_input1, dummy_input2)) + + else: + tb_writer = None + + # -------- training loop with early stopping (only rank 0 saves checkpoints) -------- + best_val_loss = float('inf') + patience_counter = 0 + + if rank == 0: + print_with_timestamp("Start training (DDP)") + + for epoch in range(epochs): + # 必须在每个 epoch 开头设置 epoch 给 sampler,这样 shuffle 能跨 epoch 工作 + train_sampler.set_epoch(epoch) + val_sampler.set_epoch(epoch) + + start_time = time.time() + avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader) + avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader) + current_lr = optimizer.param_groups[0]['lr'] + scheduler.step() + epoch_time = time.time() - start_time + + # 只有 rank 0 写 tensorboard 和保存 checkpoint,避免重复 IO + if rank == 0: + tb_writer.add_scalar('Loss/train', avg_train_loss, epoch) + tb_writer.add_scalar('Loss/val', avg_val_loss, epoch) + tb_writer.add_scalar('LearningRate', current_lr, epoch) + tb_writer.add_scalar('Time/epoch', epoch_time, epoch) + + print_with_timestamp(f"Epoch {epoch+1}/{epochs} | " + f"Train Loss: {avg_train_loss:.6f} | " + f"Val Loss: {avg_val_loss:.6f} | " + f"LR: {current_lr:.2e} | " + f"Time: {epoch_time:.2f}s") + + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + patience_counter = 0 + save_dict = { + "epoch": epoch, + # 保存 module.state_dict()(DDP 包裹时用 module) + "model_state_dict": model.module.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "train_loss": avg_train_loss, + "val_loss": avg_val_loss + } + Path("ckpts").mkdir(exist_ok=True, parents=True) + torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt") + print_with_timestamp(f"New best model saved at epoch {epoch+1}") + else: + patience_counter += 1 + if patience_counter > patience: + print_with_timestamp(f"Early stopping at {epoch+1} epochs") + break + + # cleanup + if tb_writer is not None: + tb_writer.close() + dist.destroy_process_group() + if rank == 0: + print_with_timestamp("Training completed and distributed cleaned up") + + +def parse_args_and_cfg(): + parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.") + parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)") + parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node") + args = parser.parse_args() + with open(args.config, "rb") as f: + cfg = tomllib.load(f) + return args, cfg + + +def main(): + args, cfg = parse_args_and_cfg() + # 自动检测可用 GPU 数量,除非用户指定 nproc_per_node + available_gpus = torch.cuda.device_count() + if available_gpus == 0: + raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.") + + nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus + if nproc > available_gpus: + raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available") + + # 使用 spawn 启动 nproc 个子进程,每个进程运行 fit(rank, world_size, cfg) + # spawn 会在每个子进程中调用 fit(rank, world_size, cfg) 并传入不同的 rank + mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True) + + +if __name__ == "__main__": + main() diff --git a/MOAFUtils.py b/MOAFUtils.py new file mode 100644 index 0000000..69c4757 --- /dev/null +++ b/MOAFUtils.py @@ -0,0 +1,21 @@ +import sys +from datetime import datetime + +def print_with_timestamp(*args, sep=' ', end='\n', file=None, flush=False, time_format="%Y-%m-%d %H:%M:%S"): + """ + 与 print 类似,但在输出内容前加时间戳前缀。 + - *args: 要打印的对象(会用 sep 连接) + - sep, end, file, flush: 与内置 print 相同 + - time_format: 时间格式,默认 '%Y-%m-%d %H:%M:%S',遵循 datetime.strftime 格式 + """ + if file is None: + file = sys.stdout + + ts = datetime.now().strftime(time_format) + body = sep.join(str(a) for a in args) + output = f'{ts} --> {body}' + print(output, end=end, file=file, flush=flush) + + +if __name__ == "__main__": + print_with_timestamp(f"Test output") \ No newline at end of file diff --git a/config_example.toml b/config_example.toml new file mode 100644 index 0000000..a92112d --- /dev/null +++ b/config_example.toml @@ -0,0 +1,26 @@ +# 模型与数据 +model_type = "cca2" +output_type = "distance" +dataset_dir = "F:/Datasets/MODatasetD" +# 训练参数 +batch_size = 64 +num_workers = 8 +lr = 1e-4 +patience = 5 +epochs = 5 +warmup_epochs = 1 +# 其它 +train_objective_params_list = [ + "10x-0.25-1.0000", "10x-0.30-1.0000", + "20x-0.70-1.0000", "20x-0.80-1.0000", + "40x-0.65-1.0000", "100x-0.80-1.0000", + "100x-1.25-1.4730" +] +test_objective_params_list = [ + "10x-0.25-1.0000", "10x-0.30-1.0000", + "20x-0.70-1.0000", "20x-0.80-1.0000", + "40x-0.65-1.0000", "100x-0.80-1.0000", + "100x-1.25-1.4730" +] +# 加载形式化预训练参数 +checkpoint_load = true \ No newline at end of file