first commit
commit f54d817702
.gitignore (vendored, new file, +7)
@@ -0,0 +1,7 @@
__pycache__/
ckpts/
configs/
runs/
results/
models.ipynb
ShuffleNetV2.txt
MOAFDatasets.py (new file, +160)
@@ -0,0 +1,160 @@
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset
from PIL import Image
import torch
from torchvision import transforms


class MOAFDataset(Dataset):
    def __init__(self, dataset_root, tvt='train', objectives_params_list=None, output_type='distance'):
        """
        dataset_root: root directory (anything accepted by Path)
        tvt: 'train' | 'val' | 'test' (selects the transform)
        objectives_params_list: list of objective directory names to load, e.g. ["10x-0.25-1.0000", ...]
        output_type: 'distance' (returns nm) or 'ratio' (returns defocus / DoF)
        """
        super().__init__()
        self.dataset_root = Path(dataset_root)
        self.tvt = tvt
        if objectives_params_list is None:
            self.objectives_params_list = ["10x-0.25-1.0000"]
        else:
            self.objectives_params_list = objectives_params_list

        # Normalize output_type; invalid values fall back to 'distance'
        if isinstance(output_type, str) and output_type.lower() == "ratio":
            self.output_type = "ratio"
        else:
            self.output_type = "distance"

        # Choose the transform based on tvt
        if self.tvt == "train":
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

        # Read and merge the CSVs directly in the constructor
        all_dfs = []
        for param_dir in self.objectives_params_list:
            csv_file_path = self.dataset_root / "tvtinfo" / param_dir / f"{self.tvt}.csv"
            if not csv_file_path.exists():
                raise FileNotFoundError(f"CSV not found: {csv_file_path}")
            df = pd.read_csv(csv_file_path)
            all_dfs.append(df)

        if len(all_dfs) == 0:
            raise ValueError("No csv files were loaded. Check objectives_params_list and dataset_root.")

        combined_df = pd.concat(all_dfs, ignore_index=True)

        # Keep only rows whose 'relative' value lies in [-50, 50]
        self.dataframe = combined_df[(combined_df["relative"] >= -50) & (combined_df["relative"] <= 50)].reset_index(drop=True)

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _compute_dof_nm(mag, na, rix, wavelength_nm=550.0, pixel_size_nm=3450.0):
        """
        Formula: DoF = lambda * n / (NA ** 2) + (n * e) / (M * NA)
        All inputs are scalar floats; returns the DoF in nm.
        """
        # Guard against division by zero
        if na == 0 or mag == 0:
            return float('inf')
        lam = float(wavelength_nm)
        n = float(rix)
        M = float(mag)
        NA = float(na)
        e = float(pixel_size_nm)
        dof = (lam * n) / (NA ** 2) + ((n * e) / (M * NA))
        return dof
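    # Worked example (illustrative, using the defaults above): for the
    # "100x-1.25-1.4730" objective, i.e. M = 100, NA = 1.25, n = 1.473,
    # lambda = 550 nm, e = 3450 nm:
    #   DoF = 550 * 1.473 / 1.25**2 + 1.473 * 3450 / (100 * 1.25)
    #       = 518.5 + 40.7 ~= 559.2 nm
    # so in 'ratio' mode a label of 1.0 corresponds to roughly 0.56 um of defocus.
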
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.dataframe.iloc[idx]
        img_path = self.dataset_root / row['path']
        image = Image.open(img_path).convert('RGB')

        # Basic numeric fields (the CSV column names must match)
        mag = float(row['mag'])
        na = float(row['na'])
        rix = float(row['rix'])
        label_nm = float(row['label'])

        image = self.transform(image)

        mag_tensor = torch.tensor(mag, dtype=torch.float32)
        na_tensor = torch.tensor(na, dtype=torch.float32)
        rix_tensor = torch.tensor(rix, dtype=torch.float32)
        label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32)

        # Min-max normalize the input parameters
        mag_tensor = (mag_tensor - 10) / (100 - 10)
        na_tensor = (na_tensor - 0) / (1.25 - 0)
        rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0)

        # Choose the output label according to output_type
        if self.output_type == "ratio":
            dof_nm = self._compute_dof_nm(mag=mag, na=na, rix=rix, wavelength_nm=550.0, pixel_size_nm=3450.0)
            # Avoid dividing by a DoF that is inf or otherwise degenerate
            if not (dof_nm is None or dof_nm == float('inf') or dof_nm == 0):
                label_out = label_nm / dof_nm
            else:
                label_out = label_nm  # fallback, though unlikely
            label_out_tensor = torch.tensor(float(label_out), dtype=torch.float32)
        else:
            # 'distance' mode: return nm directly
            label_out_tensor = label_nm_tensor

        sample = {
            'image': image,
            'mag': mag_tensor,
            'na': na_tensor,
            'rix': rix_tensor,
            'label': label_out_tensor,
            'path': img_path.as_posix(),
        }

        return sample

    def get_dataframe(self):
        return self.dataframe


if __name__ == "__main__":
    # Quick smoke test
    train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train',
                            objectives_params_list=[
                                "100x-1.25-1.4730",
                            ],
                            output_type='ratio')
    from torch.utils.data import DataLoader
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
    for batch in train_loader:
        images = batch["image"]
        labels = batch["label"]
        print(f"images.shape: {images.shape}, labels.shape: {labels.shape}")
        mags = batch["mag"]
        nas = batch["na"]
        rixs = batch["rix"]
        print(f"mags: {mags}, nas: {nas}, rixs: {rixs}")
        print(f"mags.shape: {mags.shape}, nas.shape: {nas.shape}, rixs.shape: {rixs.shape}")
        params = torch.stack((mags, nas, rixs), dim=1)
        print(f"params shape: {params.shape}")
        print("first labels:")
        for i in range(min(4, labels.shape[0])):
            print(labels[i].item())
        break
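For reference, the loader above implies the per-objective CSV schema: each {tvt}.csv under tvtinfo/<objective>/ needs at least the columns path, mag, na, rix, label, and relative. A minimal illustrative row (values hypothetical, not taken from the dataset):

path,mag,na,rix,label,relative
100x-1.25-1.4730/img_0001.png,100,1.25,1.473,-250.0,-5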
MOAFModels.py (new file, +240)
@@ -0,0 +1,240 @@
import torch
from torch import nn
from torchvision import models


# Baseline model with no parameter fusion
class MOAFNoFusion(nn.Module):
    def __init__(self):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.features = nn.Sequential(
            shuff.conv1, shuff.maxpool,
            shuff.stage2, shuff.stage3,
            shuff.stage4, shuff.conv5
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.features(image)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x


# Parameter embedding module
class ParamEmbedding(nn.Module):
    def __init__(self, in_dim=3, hidden_dim=64, out_dim=128):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.LayerNorm(out_dim)
        )

    def forward(self, x):
        return self.embedding(x)


# FiLM fusion block
class FiLMBlock(nn.Module):
    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.gamma_gen = nn.Linear(param_emb_dim, feat_channels)
        self.beta_gen = nn.Linear(param_emb_dim, feat_channels)

    def forward(self, feature_map, param_emb):
        gamma = self.gamma_gen(param_emb).unsqueeze(-1).unsqueeze(-1)
        beta = self.beta_gen(param_emb).unsqueeze(-1).unsqueeze(-1)
        return feature_map * (1.0 + gamma) + beta
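# Shape note (for reading the block above): param_emb is (B, 128) and
# feature_map is (B, C, H, W); gamma and beta leave the Linear layers as
# (B, C), are unsqueezed to (B, C, 1, 1), and broadcast over H and W, so each
# channel gets its own parameter-conditioned affine modulation.
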
# Channel cross-attention fusion block
class ChannelCrossAttention(nn.Module):
    def __init__(self, param_emb_dim=128, feat_channels=128):
        super().__init__()
        self.param_emb_dim = param_emb_dim
        self.feat_channels = feat_channels
        self.hidden_dim = max(16, self.feat_channels // 4)

        self.global_pool = nn.AdaptiveAvgPool2d(1)

        self.bottle_neck = nn.Sequential(
            nn.Linear(self.param_emb_dim + self.feat_channels, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.feat_channels),
            nn.Sigmoid()
        )

    def forward(self, feature_map, param_emb):
        b, c, h, w = feature_map.shape
        pooled = self.global_pool(feature_map).view(b, c)
        cat = torch.cat([pooled, param_emb], dim=1)
        weights = self.bottle_neck(cat)
        weights4d = weights.view(b, c, 1, 1)
        return feature_map * (1.0 + weights4d)


# SE block
class SEBlock(nn.Module):
    def __init__(self, feat_channels=128):
        super().__init__()
        self.feat_channels = feat_channels
        self.hidden_dim = max(16, self.feat_channels // 4)

        self.global_pool = nn.AdaptiveAvgPool2d(1)

        self.bottle_neck = nn.Sequential(
            nn.Linear(self.feat_channels, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.feat_channels),
            nn.Sigmoid()
        )

    def forward(self, feature_map):
        b, c, h, w = feature_map.shape
        pooled = self.global_pool(feature_map).view(b, c)
        weights = self.bottle_neck(pooled)
        weights4d = weights.view(b, c, 1, 1)
        return feature_map * (1.0 + weights4d)


# Identity fusion that simply returns the feature map
class FusionIdentity(nn.Module):
    def forward(self, feature_map, param_emb):
        return feature_map


# Model with FiLM fusion
class MOAFWithFiLM(nn.Module):
    def __init__(self, fusion_level=None):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5

        self.param_embedding = ParamEmbedding()

        if fusion_level is None:
            fusion_level = [2]

        self.film_block0 = FiLMBlock(feat_channels=48) if 0 in fusion_level else FusionIdentity()
        self.film_block1 = FiLMBlock(feat_channels=96) if 1 in fusion_level else FusionIdentity()
        self.film_block2 = FiLMBlock(feat_channels=192) if 2 in fusion_level else FusionIdentity()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.cbrm(image)
        x = self.stage2(x)
        param_emb = self.param_embedding(params)
        x = self.film_block0(x, param_emb)
        x = self.stage3(x)
        x = self.film_block1(x, param_emb)
        x = self.stage4(x)
        x = self.film_block2(x, param_emb)
        x = self.cbr2(x)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x


# Model with channel cross-attention fusion
class MOAFWithChannelCrossAttention(nn.Module):
    def __init__(self, fusion_level=None):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5

        self.param_embedding = ParamEmbedding()

        if fusion_level is None:
            fusion_level = [2]

        self.cca_block0 = ChannelCrossAttention(feat_channels=48) if 0 in fusion_level else FusionIdentity()
        self.cca_block1 = ChannelCrossAttention(feat_channels=96) if 1 in fusion_level else FusionIdentity()
        self.cca_block2 = ChannelCrossAttention(feat_channels=192) if 2 in fusion_level else FusionIdentity()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.cbrm(image)
        x = self.stage2(x)
        param_emb = self.param_embedding(params)
        x = self.cca_block0(x, param_emb)
        x = self.stage3(x)
        x = self.cca_block1(x, param_emb)
        x = self.stage4(x)
        x = self.cca_block2(x, param_emb)
        x = self.cbr2(x)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x


# Model with SE blocks but no parameter fusion
class MOAFWithSE(nn.Module):
    def __init__(self, fusion_level=None):
        super().__init__()
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.cbrm = nn.Sequential(shuff.conv1, shuff.maxpool)
        self.stage2 = shuff.stage2
        self.stage3 = shuff.stage3
        self.stage4 = shuff.stage4
        self.cbr2 = shuff.conv5

        if fusion_level is None:
            fusion_level = [0]

        self.se_block0 = SEBlock(feat_channels=48) if 0 in fusion_level else nn.Identity()
        self.se_block1 = SEBlock(feat_channels=96) if 1 in fusion_level else nn.Identity()
        self.se_block2 = SEBlock(feat_channels=192) if 2 in fusion_level else nn.Identity()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, image, params):
        x = self.cbrm(image)
        x = self.stage2(x)
        x = self.se_block0(x)
        x = self.stage3(x)
        x = self.se_block1(x)
        x = self.stage4(x)
        x = self.se_block2(x)
        x = self.cbr2(x)
        x = self.avgpool(x)
        x = self.regressor(x)
        return x
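A quick shape check for these models (a minimal sketch, not part of the commit; the 384x384 input size matches the dummy inputs used for the TensorBoard graph in the training scripts below):

import torch
from MOAFModels import MOAFWithFiLM

model = MOAFWithFiLM(fusion_level=[0, 1, 2]).eval()
images = torch.randn(2, 3, 384, 384)   # dummy image batch
params = torch.randn(2, 3)             # normalized (mag, na, rix)
with torch.no_grad():
    out = model(images, params)
print(out.shape)                       # expected: torch.Size([2, 1])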
MOAFTest.py (new file, +88)
@@ -0,0 +1,88 @@
import torch
from torch.utils.data import DataLoader

import tomllib
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE


def test(model, test_loader, device, model_type, output_type):
    model.eval()
    results = []

    with torch.no_grad():
        for data in tqdm(test_loader, desc="[Test]"):
            images = data["image"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)
            outputs = model(images, params)
            results.extend(outputs.squeeze(1).cpu().numpy())

    results = np.array(results)

    # Save the raw predictions to an Excel file
    df = test_loader.dataset.get_dataframe()
    df["pred"] = results  # row order matches the predictions because shuffle=False
    Path("results").mkdir(exist_ok=True, parents=True)
    # !pip install openpyxl
    df.to_excel(f"results/{model_type}_{output_type}.xlsx", index=False)


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Test script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Resolve hyperparameters
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    objective_params_list = cfg["test_objective_params_list"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print_with_timestamp(f"Using device {device}")

    # Load the dataset
    test_set = MOAFDataset(dataset_dir, "test", objective_params_list, output_type)
    print_with_timestamp("Dataset Done")

    test_loader = DataLoader(
        test_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True
    )
    print_with_timestamp("Dataset Loaded")

    # Model selection: the digits after the prefix give the fusion levels,
    # e.g. "film02" -> [0, 2], "cca2" -> [2], "se1" -> [1]
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
    elif "cca" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
    elif "se" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list).to(device)
    else:
        model = MOAFNoFusion().to(device)

    checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    print_with_timestamp("Model Loaded")

    print_with_timestamp("Start testing")
    test(model, test_loader, device, model_type, output_type)
    print_with_timestamp("Testing completed!")


if __name__ == "__main__":
    main()
MOAFTrain.py (new file, +178)
@@ -0,0 +1,178 @@
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import time
import tomllib
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE


def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
    model.train()
    train_loss = 0.0

    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
        images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
        params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images, params)
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    return train_loss


def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
            images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()

    return val_loss


def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type):
    best_val_loss = float('inf')
    patience_counter = 0
    # !pip install tensorboard
    with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
        # Render the model graph in TensorBoard
        dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
        writer.add_graph(model, (dummy_input1, dummy_input2))

        for epoch in range(epochs):
            # Train / validate
            start_time = time.time()
            avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
            avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step()
            epoch_time = time.time() - start_time

            # Log metrics
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('LearningRate', current_lr, epoch)
            writer.add_scalar('Time/epoch', epoch_time, epoch)

            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")

            # Checkpointing
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                save_dict = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
            else:
                patience_counter += 1
                if patience_counter > patience:
                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
                    break


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)

    # Resolve hyperparameters
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    patience = int(cfg["patience"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print_with_timestamp(f"Using device {device}")

    # Load the datasets
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
    print_with_timestamp("Dataset Done")

    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=True, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True
    )
    print_with_timestamp("Dataset Loaded")

    # Model selection
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
    elif "cca" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
    elif "se" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list).to(device)
    else:
        model = MOAFNoFusion().to(device)
    print_with_timestamp("Model Loaded")

    # Load pretrained checkpoint parameters if requested (warm start)
    if checkpoint_load:
        checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        print_with_timestamp("Model Checkpoint Params Loaded")

    # Loss function, optimizer, learning-rate scheduler (linear warmup, then cosine decay)
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    )
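    # With the example config (warmup_epochs = 1, epochs = 5), the lr_lambda
    # multiplier works out to (illustrative values):
    #   epoch 0: (0 + 1) / 1 = 1.0                       linear warmup
    #   epoch 1: 0.5 * (1 + cos(pi * 0 / 4)) = 1.0       cosine decay begins
    #   epoch 2: 0.5 * (1 + cos(pi * 1 / 4)) ~= 0.854
    #   epoch 4: 0.5 * (1 + cos(pi * 3 / 4)) ~= 0.146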
    print_with_timestamp("Start training")
    fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, patience, model_type, output_type)
    print_with_timestamp("Training completed!")


if __name__ == "__main__":
    main()
MOAFTrainDDP.py (new file, +239)
@@ -0,0 +1,239 @@
import os
import time
import argparse
import tomllib
from pathlib import Path

import numpy as np
from tqdm import tqdm

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from MOAFUtils import print_with_timestamp
from MOAFDatasets import MOAFDataset
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE


def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
    model.train()
    train_loss = 0.0

    for data in tqdm(train_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Train]"):
        images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
        params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images, params)
        loss = loss_fn(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    return train_loss


def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Epoch {epoch+1:03d}/{epochs:03d} [Valid]"):
            images, labels = data["image"].to(device, non_blocking=True), data["label"].to(device, non_blocking=True)
            params = torch.stack((data["mag"], data["na"], data["rix"]), dim=1).to(device, non_blocking=True)

            outputs = model(images, params)
            loss = loss_fn(outputs.squeeze(1), labels)
            val_loss += loss.item()

    return val_loss


def fit(rank, world_size, cfg):
    """
    Per-process entry point (one GPU each).
    rank: global rank of this process (0 to world_size - 1)
    world_size: total number of processes (usually the number of available GPUs)
    cfg: configuration dict read from the TOML file
    """
    # -------- init distributed env --------
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = cfg.get("master_port", "29500")
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    if rank == 0:
        print_with_timestamp(f"Distributed initialized. World size: {world_size}")

    # -------- parse hyperparams from cfg --------
    model_type = cfg["model_type"]
    output_type = cfg["output_type"]
    dataset_dir = cfg["dataset_dir"]
    batch_size = int(cfg["batch_size"])
    num_workers = int(cfg["num_workers"])
    lr = float(cfg["lr"])
    patience = int(cfg["patience"])
    epochs = int(cfg["epochs"])
    warmup_epochs = int(cfg["warmup_epochs"])
    objective_params_list = cfg["train_objective_params_list"]
    checkpoint_load = cfg["checkpoint_load"]

    # -------- datasets & distributed sampler --------
    train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
    val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)

    # DistributedSampler ensures each process sees a unique subset each epoch
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False)

    # When using DistributedSampler, do not shuffle at DataLoader level
    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=train_sampler
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        shuffle=False, pin_memory=True, persistent_workers=True, sampler=val_sampler
    )

    if rank == 0:
        print_with_timestamp("Dataset Loaded (Distributed)")

    # -------- model creation --------
    if "film" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[4:]]
        model = MOAFWithFiLM(fusion_depth_list).to(device)
    elif "cca" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[3:]]
        model = MOAFWithChannelCrossAttention(fusion_depth_list).to(device)
    elif "se" in model_type:
        fusion_depth_list = [int(ch) for ch in model_type[2:]]
        model = MOAFWithSE(fusion_depth_list).to(device)
    else:
        model = MOAFNoFusion().to(device)

    # Load pretrained checkpoint parameters if requested (warm start)
    if checkpoint_load:
        checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        print_with_timestamp("Model Checkpoint Params Loaded")

    # wrap with DDP. device_ids ensures single-GPU per process
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    if rank == 0:
        print_with_timestamp("Model Wrapped with DDP")

    # -------- loss / optimizer / scheduler --------
    loss_fn = nn.HuberLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs
        else 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    )

    # -------- TensorBoard & checkpoint only on rank 0 --------
    if rank == 0:
        tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
        # tensorboard graph: use a small dummy input placed on the correct device
        dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
        tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
    else:
        tb_writer = None

    # -------- training loop with early stopping (only rank 0 saves checkpoints) --------
    best_val_loss = float('inf')
    patience_counter = 0

    if rank == 0:
        print_with_timestamp("Start training (DDP)")

    for epoch in range(epochs):
        # The sampler epoch must be set at the start of every epoch so that
        # shuffling varies across epochs
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        start_time = time.time()
        avg_train_loss = train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn) / len(train_loader)
        avg_val_loss = valid_epoch(model, val_loader, epoch, epochs, device, loss_fn) / len(val_loader)
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        epoch_time = time.time() - start_time

        # Only rank 0 writes TensorBoard logs and saves checkpoints, avoiding duplicate IO
        if rank == 0:
            tb_writer.add_scalar('Loss/train', avg_train_loss, epoch)
            tb_writer.add_scalar('Loss/val', avg_val_loss, epoch)
            tb_writer.add_scalar('LearningRate', current_lr, epoch)
            tb_writer.add_scalar('Time/epoch', epoch_time, epoch)

            print_with_timestamp(f"Epoch {epoch+1}/{epochs} | "
                                 f"Train Loss: {avg_train_loss:.6f} | "
                                 f"Val Loss: {avg_val_loss:.6f} | "
                                 f"LR: {current_lr:.2e} | "
                                 f"Time: {epoch_time:.2f}s")

            # Best-model tracking and early stopping are evaluated on rank 0's
            # validation shard only
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                save_dict = {
                    "epoch": epoch,
                    # Save module.state_dict() (use .module when the model is DDP-wrapped)
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss
                }
                Path("ckpts").mkdir(exist_ok=True, parents=True)
                torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
                print_with_timestamp(f"New best model saved at epoch {epoch+1}")
            else:
                patience_counter += 1
                if patience_counter > patience:
                    print_with_timestamp(f"Early stopping at {epoch+1} epochs")
                    break

    # cleanup
    if tb_writer is not None:
        tb_writer.close()
    dist.destroy_process_group()
    if rank == 0:
        print_with_timestamp("Training completed and distributed cleaned up")


def parse_args_and_cfg():
    parser = argparse.ArgumentParser(description="DDP train script that loads config from a TOML file.")
    parser.add_argument("config", help="Path to TOML config file (e.g., config.toml)")
    parser.add_argument("--nproc_per_node", type=int, default=None, help="number of processes (GPUs) to use on this node")
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        cfg = tomllib.load(f)
    return args, cfg


def main():
    args, cfg = parse_args_and_cfg()
    # Auto-detect the number of available GPUs unless nproc_per_node is given
    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        raise RuntimeError("No CUDA devices available for DDP. Use a GPU machine.")

    nproc = args.nproc_per_node if args.nproc_per_node is not None else available_gpus
    if nproc > available_gpus:
        raise ValueError(f"Requested {nproc} processes but only {available_gpus} GPUs available")

    # Use spawn to launch nproc child processes; spawn calls fit(rank, world_size, cfg)
    # in each child, passing a different rank as the first argument
    mp.spawn(fit, args=(nproc, cfg), nprocs=nproc, join=True)


if __name__ == "__main__":
    main()
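For reference, a single-node launch of this script looks like the following (the config path is illustrative):

python MOAFTrainDDP.py config_example.toml --nproc_per_node 2

Omitting --nproc_per_node uses every visible GPU, per the auto-detection in main().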
MOAFUtils.py (new file, +21)
@@ -0,0 +1,21 @@
import sys
from datetime import datetime


def print_with_timestamp(*args, sep=' ', end='\n', file=None, flush=False, time_format="%Y-%m-%d %H:%M:%S"):
    """
    Like print, but prefixes the output with a timestamp.
    - *args: objects to print (joined with sep)
    - sep, end, file, flush: same as the built-in print
    - time_format: time format, default '%Y-%m-%d %H:%M:%S', following datetime.strftime
    """
    if file is None:
        file = sys.stdout

    ts = datetime.now().strftime(time_format)
    body = sep.join(str(a) for a in args)
    output = f'{ts} --> {body}'
    print(output, end=end, file=file, flush=flush)


if __name__ == "__main__":
    print_with_timestamp("Test output")
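Running the module directly exercises the helper; with the default format the output line is shaped like (timestamp illustrative):

2025-01-01 12:00:00 --> Test output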
config_example.toml (new file, +26)
@@ -0,0 +1,26 @@
# Model and data
model_type = "cca2"
output_type = "distance"
dataset_dir = "F:/Datasets/MODatasetD"
# Training parameters
batch_size = 64
num_workers = 8
lr = 1e-4
patience = 5
epochs = 5
warmup_epochs = 1
# Other
train_objective_params_list = [
    "10x-0.25-1.0000", "10x-0.30-1.0000",
    "20x-0.70-1.0000", "20x-0.80-1.0000",
    "40x-0.65-1.0000", "100x-0.80-1.0000",
    "100x-1.25-1.4730"
]
test_objective_params_list = [
    "10x-0.25-1.0000", "10x-0.30-1.0000",
    "20x-0.70-1.0000", "20x-0.80-1.0000",
    "40x-0.65-1.0000", "100x-0.80-1.0000",
    "100x-1.25-1.4730"
]
# Load pretrained checkpoint parameters (warm start)
checkpoint_load = true
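With this file saved as config_example.toml, the single-GPU entry points are invoked as follows (inferred from the argparse setup in the scripts above):

python MOAFTrain.py config_example.toml
python MOAFTest.py config_example.toml

Here model_type = "cca2" selects MOAFWithChannelCrossAttention with fusion at level 2 only, and checkpoint_load = true means ckpts/cca2_distance_best_model.pt must already exist before training starts.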