first commit

This commit is contained in:
kaiza_hikaru 2025-10-23 16:40:40 +08:00
commit f54d817702
8 changed files with 959 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
__pycache__/
ckpts/
configs/
runs/
results/
models.ipynb
ShuffleNetV2.txt

160
MOAFDatasets.py Normal file
View File

@ -0,0 +1,160 @@
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset
from PIL import Image
import torch
from torchvision import transforms
class MOAFDataset(Dataset):
def __init__(self, dataset_root, tvt='train', objectives_params_list=None, output_type='distance'):
"""
dataset_root: 根目录Pathable
tvt: 'train'|'val'|'test'用于选择 transform
objectives_params_list: 列表包含要加载的物镜目录名例如 ["10x-0.25-1.0000", ...]
output_type: 'distance'返回 nm 'ratio'返回 defocus / DoF
"""
super().__init__()
self.dataset_root = Path(dataset_root)
self.tvt = tvt
if objectives_params_list is None:
self.objectives_params_list = ["10x-0.25-1.0000"]
else:
self.objectives_params_list = objectives_params_list
# 处理 output_type非法输入回退到 'distance'
if isinstance(output_type, str) and output_type.lower() == "ratio":
self.output_type = "ratio"
else:
self.output_type = "distance"
# 根据 tvt 选择 transform
if self.tvt == "train":
self.transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
else:
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 直接在构造函数中读取并合并 csv
all_dfs = []
for param_dir in self.objectives_params_list:
csv_file_path = self.dataset_root / "tvtinfo" / param_dir / f"{self.tvt}.csv"
if not csv_file_path.exists():
raise FileNotFoundError(f"CSV not found: {csv_file_path}")
df = pd.read_csv(csv_file_path)
all_dfs.append(df)
if len(all_dfs) == 0:
raise ValueError("No csv files were loaded. Check objectives_params_list and dataset_root.")
combined_df = pd.concat(all_dfs, ignore_index=True)
# 过滤 relative 范围
self.dataframe = combined_df[(combined_df["relative"] >= -50) & (combined_df["relative"] <= 50)].reset_index(drop=True)
def __len__(self):
return len(self.dataframe)
@staticmethod
def _compute_dof_nm(mag, na, rix, wavelength_nm=550.0, pixel_size_nm=3450.0):
"""
公式 DoF = lambda * n / (NA ** 2) + (n * e) / (M * NA)
输入参数均为标量float返回 DoFnm
"""
# 防止除以零
if na == 0 or mag == 0:
return float('inf')
lam = float(wavelength_nm)
n = float(rix)
M = float(mag)
NA = float(na)
e = float(pixel_size_nm)
dof = (lam * n) / (NA ** 2) + ((n * e) / (M * NA))
return dof
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
row = self.dataframe.iloc[idx]
img_path = self.dataset_root / row['path']
image = Image.open(img_path).convert('RGB')
# 基本数值字段(注意 CSV 列名需匹配)
mag = float(row['mag'])
na = float(row['na'])
rix = float(row['rix'])
label_nm = float(row['label'])
image = self.transform(image)
mag_tensor = torch.tensor(mag, dtype=torch.float32)
na_tensor = torch.tensor(na, dtype=torch.float32)
rix_tensor = torch.tensor(rix, dtype=torch.float32)
label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32)
# min-max 归一化输入参数
mag_tensor = (mag_tensor - 10) / (100 - 10)
na_tensor = (na_tensor - 0) / (1.25 - 0)
rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0)
# 根据 output_type 决定输出 label
if self.output_type == "ratio":
dof_nm = self._compute_dof_nm(mag=mag, na=na, rix=rix, wavelength_nm=550.0, pixel_size_nm=3450.0)
# 若 DOF 为 inf 或极大,避免除零
if not (dof_nm is None or dof_nm == float('inf') or dof_nm == 0):
label_out = label_nm / dof_nm
else:
label_out = label_nm # 回退,虽然不太可能
label_out_tensor = torch.tensor(float(label_out), dtype=torch.float32)
else:
# distance 模式:直接返回 nm
label_out_tensor = label_nm_tensor
sample = {
'image': image,
'mag': mag_tensor,
'na': na_tensor,
'rix': rix_tensor,
'label': label_out_tensor,
'path': img_path.as_posix(),
}
return sample
def get_dataframe(self):
return self.dataframe
if __name__ == "__main__":
# 简单测试
train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train',
objectives_params_list=[
"100x-1.25-1.4730",
],
output_type='ratio')
from torch.utils.data import DataLoader
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
for batch in train_loader:
images = batch["image"]
labels = batch["label"]
print(f"images.shape: {images.shape}, labels.shape: {labels.shape}")
mags = batch["mag"]
nas = batch["na"]
rixs = batch["rix"]
print(f"mags: {mags}, nas: {nas}, rixs: {rixs}")
print(f"mags.shape: {mags.shape}, nas.shape: {nas.shape}, rixs.shape: {rixs.shape}")
params = torch.stack((mags, nas, rixs), dim=1)
print(f"params shape: {params.shape}")
print("first labels:")
for i in range(min(4, labels.shape[0])):
print(labels[i].item())
break

240
MOAFModels.py Normal file
View File

@ -0,0 +1,240 @@
import torch
from torch import nn
from torchvision import models
# 无融合模型
class MOAFNoFusion(nn.Module):
def __init__(self):
super().__init__()
shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
self.features = nn.Sequential(
shuff.conv1, shuff.maxpool,
shuff.stage2, shuff.stage3,
shuff.stage4, shuff.conv5
)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.regressor = nn.Sequential(
nn.Flatten(),
nn.Linear(1024, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 1)
)
def forward(self, image, params):
x = self.features(image)
x = self.avgpool(x)
x = self.regressor(x)
return x
# 参数嵌入模型
class ParamEmbedding(nn.Module):
def __init__(self, in_dim=3, hidden_dim=64, out_dim=128):
super().__init__()
self.embedding = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, out_dim),
nn.ReLU(),
nn.LayerNorm(out_dim)
)
def forward(self, x):
return self.embedding(x)
# FiLM 融合块
class FiLMBlock(nn.Module):
def __init__(self, param_emb_dim=128, feat_channels=128):
super().__init__()
self.gamma_gen = nn.Linear(param_emb_dim, feat_channels)
self.beta_gen = nn.Linear(param_emb_dim, feat_channels)
def forward(self, feature_map, param_emb):
gamma = self.gamma_gen(param_emb).unsqueeze(-1).unsqueeze(-1)
beta = self.beta_gen(param_emb).unsqueeze(-1).unsqueeze(-1)
return feature_map * (1.0 + gamma) + beta
# 通道交叉注意力融合块
class ChannelCrossAttention(nn.Module):
def __init__(self, param_emb_dim=128, feat_channels=128):
super().__init__()
self.param_emb_dim = param_emb_dim
self.feat_channels = feat_channels
self.hidden_dim = max(16, self.feat_channels // 4)
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.bottle_neck = nn.Sequential(
nn.Linear(self.param_emb_dim + self.feat_channels, self.hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(self.hidden_dim, self.feat_channels),
nn.Sigmoid()
)
def forward(self, feature_map, param_emb):
b, c, h, w = feature_map.shape
pooled = self.global_pool(feature_map).view(b, c)
cat = torch.cat([pooled, param_emb], dim=1)
weights = self.bottle_neck(cat)
weights4d = weights.view(b, c, 1, 1)
return feature_map * (1.0 + weights4d)
# SE块
class SEBlock(nn.Module):
def __init__(self, feat_channels=128):
super().__init__()
self.feat_channels = feat_channels
self.hidden_dim = max(16, self.feat_channels // 4)
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.bottle_neck = nn.Sequential(
nn.Linear(self.feat_channels, self.hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(self.hidden_dim, self.feat_channels),
nn.Sigmoid()
)
def forward(self, feature_map):
b, c, h, w = feature_map.shape
pooled = self.global_pool(feature_map).view(b, c)
weights = self.bottle_neck(pooled)
weights4d = weights.view(b, c, 1, 1)
return feature_map * (1.0 + weights4d)
# 仅返回特征图的恒等变换
class FusionIdentity(nn.Module):
def forward(self, feature_map, param_emb):
return feature_map
# 使用 FiLM 融合的模型
class MOAFWithFiLM(nn.Module):
def __init__(self, fusion_level=None):
super().__init__()
shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
self.stage2 = shuff.stage2
self.stage3 = shuff.stage3
self.stage4 = shuff.stage4
self.cbr2 = shuff.conv5
self.param_embedding = ParamEmbedding()
if fusion_level is None:
fusion_level = [2]
self.film_block0 = FiLMBlock(feat_channels=48) if 0 in fusion_level else FusionIdentity()
self.film_block1 = FiLMBlock(feat_channels=96) if 1 in fusion_level else FusionIdentity()
self.film_block2 = FiLMBlock(feat_channels=192) if 2 in fusion_level else FusionIdentity()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.regressor = nn.Sequential(
nn.Flatten(),
nn.Linear(1024, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 1)
)
def forward(self, image, params):
x = self.cbrm(image)
x = self.stage2(x)
param_emb = self.param_embedding(params)
x = self.film_block0(x, param_emb)
x = self.stage3(x)
x = self.film_block1(x, param_emb)
x = self.stage4(x)
x = self.film_block2(x, param_emb)
x = self.cbr2(x)
x = self.avgpool(x)
x = self.regressor(x)
return x
# 使用交叉注意力融合的模型
class MOAFWithChannelCrossAttention(nn.Module):
def __init__(self, fusion_level=None):
super().__init__()
shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
self.stage2 = shuff.stage2
self.stage3 = shuff.stage3
self.stage4 = shuff.stage4
self.cbr2 = shuff.conv5
self.param_embedding = ParamEmbedding()
if fusion_level is None:
fusion_level = [2]
self.cca_block0 = ChannelCrossAttention(feat_channels=48) if 0 in fusion_level else FusionIdentity()
self.cca_block1 = ChannelCrossAttention(feat_channels=96) if 1 in fusion_level else FusionIdentity()
self.cca_block2 = ChannelCrossAttention(feat_channels=192) if 2 in fusion_level else FusionIdentity()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.regressor = nn.Sequential(
nn.Flatten(),
nn.Linear(1024, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 1)
)
def forward(self, image, params):
x = self.cbrm(image)
x = self.stage2(x)
param_emb = self.param_embedding(params)
x = self.cca_block0(x, param_emb)
x = self.stage3(x)
x = self.cca_block1(x, param_emb)
x = self.stage4(x)
x = self.cca_block2(x, param_emb)
x = self.cbr2(x)
x = self.avgpool(x)
x = self.regressor(x)
return x
# 使用 SE 块但无融合的模型
class MOAFWithSE(nn.Module):
def __init__(self, fusion_level=None):
super().__init__()
shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
self.stage2 = shuff.stage2
self.stage3 = shuff.stage3
self.stage4 = shuff.stage4
self.cbr2 = shuff.conv5
if fusion_level is None:
fusion_level = [0]
self.se_block0 = SEBlock(feat_channels=48) if 0 in fusion_level else nn.Identity()
self.se_block1 = SEBlock(feat_channels=96) if 1 in fusion_level else nn.Identity()
self.se_block2 = SEBlock(feat_channels=192) if 2 in fusion_level else nn.Identity()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.regressor = nn.Sequential(
nn.Flatten(),
nn.Linear(1024, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 1)
)
def forward(self, image, params):
x = self.cbrm(image)
x = self.stage2(x)
x = self.se_block0(x)
x = self.stage3(x)
x = self.se_block1(x)
x = self.stage4(x)
x = self.se_block2(x)
x = self.cbr2(x)
x = self.avgpool(x)
x = self.regressor(x)
return x

88
MOAFTest.py Normal file
View File

@ -0,0 +1,88 @@
import torch
from torch.utils.data import DataLoader
import tomllib
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
def test(model, test_loader, device, model_type, output_type):
model.eval()
results = []
with torch.no_grad():
for data in tqdm(test_loader, desc=f"[Test]"):
images = data["image"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
outputs = model(images, params)
results.extend(outputs.squeeze(1).cpu().numpy())
results = np.array(results)
# 直接预测结果保存为 excel 文件
df = test_loader.dataset.get_dataframe()
df["pred"] = results
Path("results").mkdir(exist_ok=True, parents=True)
# !pip install openpyxl
df.to_excel(f"results/{model_type}_{output_type}.xlsx", index=False)
def main():
# 解析命令行参数
parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.")
parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
args = parser.parse_args()
with open(args.config, "rb") as f:
cfg = tomllib.load(f)
# 确定超参数
model_type = cfg["model_type"]
output_type = cfg["output_type"]
dataset_dir = cfg["dataset_dir"]
batch_size = int(cfg["batch_size"])
num_workers = int(cfg["num_workers"])
objective_params_list = cfg["test_objective_params_list"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_with_timestamp(f"Using device {device}")
# 加载数据集
test_set = MOAFDataset(dataset_dir, "test", objective_params_list, output_type)
print_with_timestamp("Dataset Done")
test_loader = DataLoader(
test_set, batch_size=batch_size, num_workers=num_workers,
shuffle=False, pin_memory=True, persistent_workers=True
)
print_with_timestamp("Dataset Loaded")
# 模型选择
if "film" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[4:]]
model = MOAFWithFiLM(fusion_depth_list).to(device)
elif "cca" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[3:]]
model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
elif "se" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[2:]]
model = MOAFWithSE(fusion_depth_list).to(device)
else:
model = MOAFNoFusion().to(device)
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print_with_timestamp("Model Loaded")
print_with_timestamp("Start testing")
test(model, test_loader, device, model_type, output_type)
print_with_timestamp("Testing completed!")
if __name__ == "__main__":
main()

178
MOAFTrain.py Normal file
View File

@ -0,0 +1,178 @@
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import tomllib
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
model.train()
train_loss = 0.0
for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
optimizer.zero_grad()
outputs = model(images, params)
loss = loss_fn(outputs.squeeze(1), labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
return train_loss
def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
model.eval()
val_loss = 0.0
with torch.no_grad():
for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
outputs = model(images, params)
loss = loss_fn(outputs.squeeze(1), labels)
val_loss += loss.item()
return val_loss
def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type):
best_val_loss = float('inf')
patience_counter = 0
# !pip install tensorboard
with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
# Tensorboard 上显示模型结构
dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
writer.add_graph(model, (dummy_input1, dummy_input2))
for epoch in range(epochs):
# 训练/验证
start_time = time.time()
avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
avg_val_loss = valid_epoch(model,val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
current_lr = optimizer.param_groups[0]['lr']
scheduler.step()
epoch_time = time.time() - start_time
# 显示记录数据
writer.add_scalar('Loss/train', avg_train_loss, epoch)
writer.add_scalar('Loss/val', avg_val_loss, epoch)
writer.add_scalar('LearningRate', current_lr, epoch)
writer.add_scalar('Time/epoch', epoch_time, epoch)
print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
f"Train Loss: {avg_train_loss:.6f} | "
f"Val Loss: {avg_val_loss:.6f} | "
f"LR: {current_lr:.2e} | "
f"Time: {epoch_time:.2f}s")
# 记录检查点
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
patience_counter = 0
save_dict = {
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"train_loss": avg_train_loss,
"val_loss": avg_val_loss
}
Path("ckpts").mkdir(exist_ok=True, parents=True)
torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
print_with_timestamp(f"New best model saved at epoch {epoch+1}")
else:
patience_counter += 1
if patience_counter > patience:
print_with_timestamp(f"Early stopping at {epoch+1} epochs")
break
def main():
# 解析命令行参数
parser = argparse.ArgumentParser(description="Train script that loads config from a TOML file.")
parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
args = parser.parse_args()
with open(args.config, "rb") as f:
cfg = tomllib.load(f)
# 确定超参数
model_type = cfg["model_type"]
output_type = cfg["output_type"]
dataset_dir = cfg["dataset_dir"]
batch_size = int(cfg["batch_size"])
num_workers = int(cfg["num_workers"])
lr = float(cfg["lr"])
patience = int(cfg["patience"])
epochs = int(cfg["epochs"])
warmup_epochs = int(cfg["warmup_epochs"])
objective_params_list = cfg["train_objective_params_list"]
checkpoint_load = cfg["checkpoint_load"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_with_timestamp(f"Using device {device}")
# 加载数据集
train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
print_with_timestamp("Dataset Done")
train_loader = DataLoader(
train_set, batch_size=batch_size, num_workers=num_workers,
shuffle=True, pin_memory=True, persistent_workers=True
)
val_loader = DataLoader(
val_set, batch_size=batch_size, num_workers=num_workers,
shuffle=False, pin_memory=True, persistent_workers=True
)
print_with_timestamp("Dataset Loaded")
# 模型选择
if "film" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[4:]]
model = MOAFWithFiLM(fusion_depth_list).to(device)
elif "cca" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[3:]]
model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
elif "se" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[2:]]
model = MOAFWithSE(fusion_depth_list).to(device)
else:
model = MOAFNoFusion().to(device)
print_with_timestamp("Model Loaded")
# 形式化预训练参数加载
if checkpoint_load:
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print_with_timestamp("Model Checkpoint Params Loaded")
# 损失函数、优化器、学习率调度器
loss_fn = nn.HuberLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
)
print_with_timestamp("Start trainning")
fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type)
print_with_timestamp("Training completed!")
if __name__ == "__main__":
main()

239
MOAFTrainDDP.py Normal file
View File

@ -0,0 +1,239 @@
import os
import time
import argparse
import tomllib
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
model.train()
train_loss = 0.0
for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
optimizer.zero_grad()
outputs = model(images, params)
loss = loss_fn(outputs.squeeze(1), labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
return train_loss
def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
model.eval()
val_loss = 0.0
with torch.no_grad():
for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
outputs = model(images, params)
loss = loss_fn(outputs.squeeze(1), labels)
val_loss += loss.item()
return val_loss
def fit(rank, world_size, cfg):
"""
每个进程运行的主函数单卡
rank: 该进程的全局 rank0 ~ world_size-1
world_size: 进程总数通常等于可用 GPU
cfg: toml 读取的配置字典
"""
# -------- init distributed env --------
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f'cuda:{rank}')
if rank == 0:
print_with_timestamp(f"Distributed initialized. World size: {world_size}")
# -------- parse hyperparams from cfg --------
model_type = cfg["model_type"]
output_type = cfg["output_type"]
dataset_dir = cfg["dataset_dir"]
batch_size = int(cfg["batch_size"])
num_workers = int(cfg["num_workers"])
lr = float(cfg["lr"])
patience = int(cfg["patience"])
epochs = int(cfg["epochs"])
warmup_epochs = int(cfg["warmup_epochs"])
objective_params_list = cfg["train_objective_params_list"]
checkpoint_load = cfg["checkpoint_load"]
# -------- datasets & distributed sampler --------
train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
# DistributedSampler ensures each process sees a unique subset each epoch
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)
# When using DistributedSampler, do not shuffle at DataLoader level
train_loader = DataLoader(
train_set, batch_size=batch_size, num_workers=num_workers,
shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
)
val_loader = DataLoader(
val_set, batch_size=batch_size, num_workers=num_workers,
shuffle=False, pin_memory=True, persistent_workers=True, sampler=val_sampler
)
if rank == 0:
print_with_timestamp("Dataset Loaded (Distributed)")
# -------- model creation --------
if "film" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[4:]]
model = MOAFWithFiLM(fusion_depth_list).to(device)
elif "cca" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[3:]]
model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
elif "se" in model_type:
fusion_depth_list = [int(ch) for ch in model_type[2:]]
model = MOAFWithSE(fusion_depth_list).to(device)
else:
model = MOAFNoFusion().to(device)
# 形式化预训练参数加载
if checkpoint_load:
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print_with_timestamp("Model Checkpoint Params Loaded")
# wrap with DDP. device_ids ensures single-GPU per process
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
if rank == 0:
print_with_timestamp("Model Wrapped with DDP")
# -------- loss / optimizer / scheduler --------
loss_fn = nn.HuberLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
)
# -------- TensorBoard & checkpoint only on rank 0 --------
if rank == 0:
tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
# tensorboard graph: use a small dummy input placed on correct device
dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
else:
tb_writer = None
# -------- training loop with early stopping (only rank 0 saves checkpoints) --------
best_val_loss = float('inf')
patience_counter = 0
if rank == 0:
print_with_timestamp("Start training (DDP)")
for epoch in range(epochs):
# 必须在每个 epoch 开头设置 epoch 给 sampler这样 shuffle 能跨 epoch 工作
train_sampler.set_epoch(epoch)
val_sampler.set_epoch(epoch)
start_time = time.time()
avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
current_lr = optimizer.param_groups[0]['lr']
scheduler.step()
epoch_time = time.time() - start_time
# 只有 rank 0 写 tensorboard 和保存 checkpoint避免重复 IO
if rank == 0:
tb_writer.add_scalar('Loss/train', avg_train_loss, epoch)
tb_writer.add_scalar('Loss/val', avg_val_loss, epoch)
tb_writer.add_scalar('LearningRate', current_lr, epoch)
tb_writer.add_scalar('Time/epoch', epoch_time, epoch)
print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
f"Train Loss: {avg_train_loss:.6f} | "
f"Val Loss: {avg_val_loss:.6f} | "
f"LR: {current_lr:.2e} | "
f"Time: {epoch_time:.2f}s")
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
patience_counter = 0
save_dict = {
"epoch": epoch,
# 保存 module.state_dict()DDP 包裹时用 module
"model_state_dict": model.module.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"train_loss": avg_train_loss,
"val_loss": avg_val_loss
}
Path("ckpts").mkdir(exist_ok=True, parents=True)
torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
print_with_timestamp(f"New best model saved at epoch {epoch+1}")
else:
patience_counter += 1
if patience_counter > patience:
print_with_timestamp(f"Early stopping at {epoch+1} epochs")
break
# cleanup
if tb_writer is not None:
tb_writer.close()
dist.destroy_process_group()
if rank == 0:
print_with_timestamp("Training completed and distributed cleaned up")
def parse_args_and_cfg():
parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.")
parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node")
args = parser.parse_args()
with open(args.config, "rb") as f:
cfg = tomllib.load(f)
return args, cfg
def main():
args, cfg = parse_args_and_cfg()
# 自动检测可用 GPU 数量,除非用户指定 nproc_per_node
available_gpus = torch.cuda.device_count()
if available_gpus == 0:
raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.")
nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus
if nproc > available_gpus:
raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available")
# 使用 spawn 启动 nproc 个子进程,每个进程运行 fit(rank, world_size, cfg)
# spawn 会在每个子进程中调用 fit(rank, world_size, cfg) 并传入不同的 rank
mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True)
if __name__ == "__main__":
main()

21
MOAFUtils.py Normal file
View File

@ -0,0 +1,21 @@
import sys
from datetime import datetime
def print_with_timestamp(*args, sep=' ', end='\n', file=None, flush=False, time_format="%Y-%m-%d %H:%M:%S"):
"""
print 类似但在输出内容前加时间戳前缀
- *args: 要打印的对象会用 sep 连接
- sep, end, file, flush: 与内置 print 相同
- time_format: 时间格式默认 '%Y-%m-%d %H:%M:%S'遵循 datetime.strftime 格式
"""
if file is None:
file = sys.stdout
ts = datetime.now().strftime(time_format)
body = sep.join(str(a) for a in args)
output = f'{ts} --> {body}'
print(output, end=end, file=file, flush=flush)
if __name__ == "__main__":
print_with_timestamp(f"Test output")

26
config_example.toml Normal file
View File

@ -0,0 +1,26 @@
# 模型与数据
model_type = "cca2"
output_type = "distance"
dataset_dir = "F:/Datasets/MODatasetD"
# 训练参数
batch_size = 64
num_workers = 8
lr = 1e-4
patience = 5
epochs = 5
warmup_epochs = 1
# 其它
train_objective_params_list = [
"10x-0.25-1.0000", "10x-0.30-1.0000",
"20x-0.70-1.0000", "20x-0.80-1.0000",
"40x-0.65-1.0000", "100x-0.80-1.0000",
"100x-1.25-1.4730"
]
test_objective_params_list = [
"10x-0.25-1.0000", "10x-0.30-1.0000",
"20x-0.70-1.0000", "20x-0.80-1.0000",
"40x-0.65-1.0000", "100x-0.80-1.0000",
"100x-1.25-1.4730"
]
# 加载形式化预训练参数
checkpoint_load = true