add mmlp and other functions
This commit is contained in:
parent
2bdefda64e
commit
a8969ea76e
@ -7,12 +7,11 @@ from torchvision import transforms
|
|||||||
|
|
||||||
|
|
||||||
class MOAFDataset(Dataset):
|
class MOAFDataset(Dataset):
|
||||||
def __init__(self, dataset_root, tvt='train', objectives_params_list=None, output_type='distance'):
|
def __init__(self, dataset_root, tvt='train', objectives_params_list=None):
|
||||||
"""
|
"""
|
||||||
dataset_root: 根目录(Pathable)
|
dataset_root: 根目录(Pathable)
|
||||||
tvt: 'train'|'val'|'test'(用于选择 transform)
|
tvt: 'train'|'val'|'test'(用于选择 transform)
|
||||||
objectives_params_list: 列表,包含要加载的物镜目录名,例如 ["10x-0.25-1.0000", ...]
|
objectives_params_list: 列表,包含要加载的物镜目录名,例如 ["10x-0.25-1.0000", ...]
|
||||||
output_type: 'distance'(返回 nm)或 'ratio'(返回 defocus / DoF)
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dataset_root = Path(dataset_root)
|
self.dataset_root = Path(dataset_root)
|
||||||
@ -22,12 +21,6 @@ class MOAFDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
self.objectives_params_list = objectives_params_list
|
self.objectives_params_list = objectives_params_list
|
||||||
|
|
||||||
# 处理 output_type,非法输入回退到 'distance'
|
|
||||||
if isinstance(output_type, str) and output_type.lower() == "ratio":
|
|
||||||
self.output_type = "ratio"
|
|
||||||
else:
|
|
||||||
self.output_type = "distance"
|
|
||||||
|
|
||||||
# 根据 tvt 选择 transform
|
# 根据 tvt 选择 transform
|
||||||
if self.tvt == "train":
|
if self.tvt == "train":
|
||||||
self.transform = transforms.Compose([
|
self.transform = transforms.Compose([
|
||||||
@ -101,30 +94,17 @@ class MOAFDataset(Dataset):
|
|||||||
rix_tensor = torch.tensor(rix, dtype=torch.float32)
|
rix_tensor = torch.tensor(rix, dtype=torch.float32)
|
||||||
label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32)
|
label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32)
|
||||||
|
|
||||||
# min-max 归一化输入参数
|
# # min-max 归一化输入参数
|
||||||
mag_tensor = (mag_tensor - 10) / (100 - 10)
|
# mag_tensor = (mag_tensor - 10) / (100 - 10)
|
||||||
na_tensor = (na_tensor - 0) / (1.25 - 0)
|
# na_tensor = (na_tensor - 0) / (1.25 - 0)
|
||||||
rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0)
|
# rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0)
|
||||||
|
|
||||||
# 根据 output_type 决定输出 label
|
|
||||||
if self.output_type == "ratio":
|
|
||||||
dof_nm = self._compute_dof_nm(mag=mag, na=na, rix=rix, wavelength_nm=550.0, pixel_size_nm=3450.0)
|
|
||||||
# 若 DOF 为 inf 或极大,避免除零
|
|
||||||
if not (dof_nm is None or dof_nm == float('inf') or dof_nm == 0):
|
|
||||||
label_out = label_nm / dof_nm
|
|
||||||
else:
|
|
||||||
label_out = label_nm # 回退,虽然不太可能
|
|
||||||
label_out_tensor = torch.tensor(float(label_out), dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
# distance 模式:直接返回 nm
|
|
||||||
label_out_tensor = label_nm_tensor
|
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
'image': image,
|
'image': image,
|
||||||
'mag': mag_tensor,
|
'mag': mag_tensor,
|
||||||
'na': na_tensor,
|
'na': na_tensor,
|
||||||
'rix': rix_tensor,
|
'rix': rix_tensor,
|
||||||
'label': label_out_tensor,
|
'label': label_nm_tensor,
|
||||||
'path': img_path.as_posix(),
|
'path': img_path.as_posix(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,8 +119,7 @@ if __name__ == "__main__":
|
|||||||
train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train',
|
train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train',
|
||||||
objectives_params_list=[
|
objectives_params_list=[
|
||||||
"100x-1.25-1.4730",
|
"100x-1.25-1.4730",
|
||||||
],
|
])
|
||||||
output_type='ratio')
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
|
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
|
||||||
for batch in train_loader:
|
for batch in train_loader:
|
||||||
|
|||||||
@ -40,8 +40,14 @@ class ParamEmbedding(nn.Module):
|
|||||||
nn.LayerNorm(out_dim)
|
nn.LayerNorm(out_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, params):
|
||||||
return self.embedding(x)
|
# min-max 归一化参数
|
||||||
|
normalized_params = torch.stack([
|
||||||
|
(params[:, 0] - 10.0) / 90.0,
|
||||||
|
params[:, 1] / 1.25,
|
||||||
|
(params[:, 2] - 1.0) / 0.5
|
||||||
|
], dim=1)
|
||||||
|
return self.embedding(normalized_params)
|
||||||
|
|
||||||
|
|
||||||
# FiLM 融合块
|
# FiLM 融合块
|
||||||
@ -238,3 +244,73 @@ class MOAFWithSE(nn.Module):
|
|||||||
x = self.regressor(x)
|
x = self.regressor(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
# 多回归头模型
|
||||||
|
class MOAFWithMMLP(nn.Module):
    """Multi-head regressor: one shared backbone, one MLP head per objective.

    The image goes through a ShuffleNetV2 x0.5 feature extractor followed by
    global average pooling.  Each sample is then routed to the regression
    head that belongs to its objective lens, identified from the raw
    (un-normalised) ``(magnification, NA, refractive index)`` triple carried
    in ``params``.

    Args:
        num_lenses: Number of regression heads to create (default 7, one per
            entry of ``LENS_TABLE``).  Must be at least 1.
    """

    # (magnification, NA, refractive index) of every known objective.  The
    # position of a row in this table is the index of its regression head.
    LENS_TABLE = (
        (10.0, 0.25, 1.0000),   # obj1
        (10.0, 0.30, 1.0000),   # obj2
        (20.0, 0.70, 1.0000),   # obj3
        (20.0, 0.80, 1.0000),   # obj4
        (40.0, 0.65, 1.0000),   # obj5
        (100.0, 0.80, 1.0000),  # obj6
        (100.0, 1.25, 1.4730),  # obj7
    )
    # Absolute tolerance used when matching parameters against LENS_TABLE.
    # Far larger than float32 rounding error, far smaller than the gap
    # between any two table entries.
    LENS_MATCH_ATOL = 1e-3

    def __init__(self, num_lenses=7):
        super().__init__()
        if num_lenses < 1:
            raise ValueError("num_lenses must be at least 1")
        shuff = models.shufflenet_v2_x0_5(weights="DEFAULT")
        self.features = nn.Sequential(
            shuff.conv1, shuff.maxpool,
            shuff.stage2, shuff.stage3,
            shuff.stage4, shuff.conv5
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.num_lenses = num_lenses
        self.regressors = nn.ModuleList([
            self._create_regressor_head() for _ in range(num_lenses)
        ])

    def _create_regressor_head(self):
        """Build one regression head: flatten -> 1024 -> 128 -> 1."""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def _find_lens_id(self, params):
        """Map each row of ``params`` to the index of its regression head.

        Args:
            params: Tensor whose first three columns are the raw
                (magnification, NA, refractive index) of each sample.

        Returns:
            ``torch.long`` tensor with one head index per row, on
            ``params.device``.  Rows that match no entry of ``LENS_TABLE``
            fall back to head 0.
        """
        # Compare with a tolerance instead of ``==``: the parameters arrive
        # as float32, and values such as 0.30, 0.65, 0.70, 0.80 or 1.4730 are
        # not exactly representable, so an exact comparison against Python
        # float literals never succeeds for those objectives.
        values = params.detach()[:, :3].float()
        table = torch.tensor(self.LENS_TABLE, dtype=values.dtype, device=values.device)
        # (batch, 1, 3) against (1, known, 3) -> (batch, known) match matrix.
        hits = torch.isclose(
            values[:, None, :], table[None, :, :],
            rtol=0.0, atol=self.LENS_MATCH_ATOL,
        ).all(dim=-1)
        lens_ids = hits.to(torch.long).argmax(dim=1)
        # Make the "unknown objective -> head 0" fallback explicit.
        return torch.where(hits.any(dim=1), lens_ids, torch.zeros_like(lens_ids))

    def forward(self, image, params):
        """Predict one value per sample using the head of its objective.

        Args:
            image: Batch of input images fed to the backbone.
            params: Tensor of raw objective parameters, one row per sample
                (see ``_find_lens_id``).

        Returns:
            Tensor with one row per sample and a single column, rows in the
            same order as the input batch.
        """
        x = self.avgpool(self.features(image))
        batch_size = params.size(0)

        lens_ids = self._find_lens_id(params)
        # Head indices outside the available heads are routed to head 0.
        lens_ids = torch.where(
            lens_ids < self.num_lenses, lens_ids, torch.zeros_like(lens_ids)
        )

        # Run every head once on the whole group of samples assigned to it
        # rather than once per sample.
        out = None
        for head_idx, head in enumerate(self.regressors):
            rows = torch.nonzero(lens_ids == head_idx, as_tuple=True)[0]
            if rows.numel() == 0:
                continue
            head_out = head(x[rows])
            if out is None:
                # Allocate from the head output so dtype/device always match.
                out = head_out.new_zeros((batch_size, head_out.size(1)))
            out[rows] = head_out

        if out is None:
            # Only reachable with an empty batch.
            return x.new_zeros((batch_size, 1))
        return out
|
||||||
|
|
||||||
|
|||||||
16
MOAFTest.py
16
MOAFTest.py
@ -9,10 +9,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
from MOAFUtils import print_with_timestamp
|
from MOAFUtils import print_with_timestamp
|
||||||
from MOAFDatasets import MOAFDataset
|
from MOAFDatasets import MOAFDataset
|
||||||
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
|
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP
|
||||||
|
|
||||||
|
|
||||||
def test(model, test_loader, device, model_type, output_type):
|
def test(model, test_loader, device, model_type, dataset_type):
|
||||||
model.eval()
|
model.eval()
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ def test(model, test_loader, device, model_type, output_type):
|
|||||||
df["pred"] = results
|
df["pred"] = results
|
||||||
Path("results").mkdir(exist_ok=True, parents=True)
|
Path("results").mkdir(exist_ok=True, parents=True)
|
||||||
# !pip install openpyxl
|
# !pip install openpyxl
|
||||||
df.to_excel(f"results/{model_type}_{output_type}.xlsx", index=False)
|
df.to_excel(f"results/{model_type}_{dataset_type}.xlsx", index=False)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -43,7 +43,7 @@ def main():
|
|||||||
|
|
||||||
# 确定超参数
|
# 确定超参数
|
||||||
model_type = cfg["model_type"]
|
model_type = cfg["model_type"]
|
||||||
output_type = cfg["output_type"]
|
dataset_type = cfg["dataset_type"]
|
||||||
dataset_dir = cfg["dataset_dir"]
|
dataset_dir = cfg["dataset_dir"]
|
||||||
batch_size = int(cfg["batch_size"])
|
batch_size = int(cfg["batch_size"])
|
||||||
num_workers = int(cfg["num_workers"])
|
num_workers = int(cfg["num_workers"])
|
||||||
@ -53,7 +53,7 @@ def main():
|
|||||||
print_with_timestamp(f"Using device {device}")
|
print_with_timestamp(f"Using device {device}")
|
||||||
|
|
||||||
# 加载数据集
|
# 加载数据集
|
||||||
test_set = MOAFDataset(dataset_dir, "test", objective_params_list, output_type)
|
test_set = MOAFDataset(dataset_dir, "test", objective_params_list)
|
||||||
print_with_timestamp("Dataset Done")
|
print_with_timestamp("Dataset Done")
|
||||||
|
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
@ -72,15 +72,17 @@ def main():
|
|||||||
elif "se" in model_type:
|
elif "se" in model_type:
|
||||||
fusion_depth_list = [int(ch) for ch in model_type[2:]]
|
fusion_depth_list = [int(ch) for ch in model_type[2:]]
|
||||||
model = MOAFWithSE(fusion_depth_list).to(device)
|
model = MOAFWithSE(fusion_depth_list).to(device)
|
||||||
|
elif "mmlp" in model_type:
|
||||||
|
model = MOAFWithMMLP(fusion_depth_list).to(device)
|
||||||
else:
|
else:
|
||||||
model = MOAFNoFusion().to(device)
|
model = MOAFNoFusion().to(device)
|
||||||
|
|
||||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False)
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
print_with_timestamp("Model Loaded")
|
print_with_timestamp("Model Loaded")
|
||||||
|
|
||||||
print_with_timestamp("Start testing")
|
print_with_timestamp("Start testing")
|
||||||
test(model, test_loader, device, model_type, output_type)
|
test(model, test_loader, device, model_type, dataset_type)
|
||||||
print_with_timestamp("Testing completed!")
|
print_with_timestamp("Testing completed!")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
22
MOAFTrain.py
22
MOAFTrain.py
@ -12,7 +12,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from MOAFUtils import print_with_timestamp
|
from MOAFUtils import print_with_timestamp
|
||||||
from MOAFDatasets import MOAFDataset
|
from MOAFDatasets import MOAFDataset
|
||||||
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
|
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
|
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn):
|
||||||
@ -49,10 +49,10 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn):
|
|||||||
return val_loss
|
return val_loss
|
||||||
|
|
||||||
|
|
||||||
def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type):
|
def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, dataset_type):
|
||||||
best_val_loss = float('inf')
|
best_val_loss = float('inf')
|
||||||
# !pip install tensorboard
|
# !pip install tensorboard
|
||||||
with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer:
|
with SummaryWriter(log_dir=f"runs/{model_type}_{dataset_type}") as writer:
|
||||||
# Tensorboard 上显示模型结构
|
# Tensorboard 上显示模型结构
|
||||||
dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
|
dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
|
||||||
writer.add_graph(model, (dummy_input1, dummy_input2))
|
writer.add_graph(model, (dummy_input1, dummy_input2))
|
||||||
@ -90,7 +90,7 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, sch
|
|||||||
"val_loss": avg_val_loss
|
"val_loss": avg_val_loss
|
||||||
}
|
}
|
||||||
Path("ckpts").mkdir(exist_ok=True, parents=True)
|
Path("ckpts").mkdir(exist_ok=True, parents=True)
|
||||||
torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
|
torch.save(save_dict, f"ckpts/{model_type}_{dataset_type}_best_model.pt")
|
||||||
print_with_timestamp(f"New best model saved at epoch {epoch+1}")
|
print_with_timestamp(f"New best model saved at epoch {epoch+1}")
|
||||||
|
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ def main():
|
|||||||
|
|
||||||
# 确定超参数
|
# 确定超参数
|
||||||
model_type = cfg["model_type"]
|
model_type = cfg["model_type"]
|
||||||
output_type = cfg["output_type"]
|
dataset_type = cfg["dataset_type"]
|
||||||
dataset_dir = cfg["dataset_dir"]
|
dataset_dir = cfg["dataset_dir"]
|
||||||
batch_size = int(cfg["batch_size"])
|
batch_size = int(cfg["batch_size"])
|
||||||
num_workers = int(cfg["num_workers"])
|
num_workers = int(cfg["num_workers"])
|
||||||
@ -118,8 +118,8 @@ def main():
|
|||||||
print_with_timestamp(f"Using device {device}")
|
print_with_timestamp(f"Using device {device}")
|
||||||
|
|
||||||
# 加载数据集
|
# 加载数据集
|
||||||
train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
|
train_set = MOAFDataset(dataset_dir, "train", objective_params_list)
|
||||||
val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
|
val_set = MOAFDataset(dataset_dir, "val", objective_params_list)
|
||||||
print_with_timestamp("Dataset Done")
|
print_with_timestamp("Dataset Done")
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
@ -142,14 +142,16 @@ def main():
|
|||||||
elif "se" in model_type:
|
elif "se" in model_type:
|
||||||
fusion_depth_list = [int(ch) for ch in model_type[2:]]
|
fusion_depth_list = [int(ch) for ch in model_type[2:]]
|
||||||
model = MOAFWithSE(fusion_depth_list).to(device)
|
model = MOAFWithSE(fusion_depth_list).to(device)
|
||||||
|
elif "mmlp" in model_type:
|
||||||
|
model = MOAFWithMMLP(fusion_depth_list).to(device)
|
||||||
else:
|
else:
|
||||||
model = MOAFNoFusion().to(device)
|
model = MOAFNoFusion().to(device)
|
||||||
print_with_timestamp("Model Loaded")
|
print_with_timestamp("Model Loaded")
|
||||||
|
|
||||||
# 形式化预训练参数加载
|
# 形式化预训练参数加载
|
||||||
if checkpoint_load:
|
if checkpoint_load:
|
||||||
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
|
if Path(f"ckpts/{model_type}_{dataset_type}_best_model.pt").exists():
|
||||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False)
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||||
else:
|
else:
|
||||||
@ -165,7 +167,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
print_with_timestamp("Start trainning")
|
print_with_timestamp("Start trainning")
|
||||||
fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type)
|
fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, dataset_type)
|
||||||
print_with_timestamp("Training completed!")
|
print_with_timestamp("Training completed!")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
|
|
||||||
from MOAFUtils import print_with_timestamp
|
from MOAFUtils import print_with_timestamp
|
||||||
from MOAFDatasets import MOAFDataset
|
from MOAFDatasets import MOAFDataset
|
||||||
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE
|
from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
|
def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank):
|
||||||
@ -71,7 +71,7 @@ def fit(rank, world_size, cfg):
|
|||||||
|
|
||||||
# 确定超参数
|
# 确定超参数
|
||||||
model_type = cfg["model_type"]
|
model_type = cfg["model_type"]
|
||||||
output_type = cfg["output_type"]
|
dataset_type = cfg["dataset_type"]
|
||||||
dataset_dir = cfg["dataset_dir"]
|
dataset_dir = cfg["dataset_dir"]
|
||||||
batch_size = int(cfg["batch_size"])
|
batch_size = int(cfg["batch_size"])
|
||||||
num_workers = int(cfg["num_workers"])
|
num_workers = int(cfg["num_workers"])
|
||||||
@ -82,8 +82,8 @@ def fit(rank, world_size, cfg):
|
|||||||
checkpoint_load = cfg["checkpoint_load"]
|
checkpoint_load = cfg["checkpoint_load"]
|
||||||
|
|
||||||
# 加载数据集
|
# 加载数据集
|
||||||
train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type)
|
train_set = MOAFDataset(dataset_dir, "train", objective_params_list)
|
||||||
val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type)
|
val_set = MOAFDataset(dataset_dir, "val", objective_params_list)
|
||||||
|
|
||||||
# 分布式化数据集
|
# 分布式化数据集
|
||||||
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
|
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
|
||||||
@ -111,13 +111,15 @@ def fit(rank, world_size, cfg):
|
|||||||
elif "se" in model_type:
|
elif "se" in model_type:
|
||||||
fusion_depth_list = [int(ch) for ch in model_type[2:]]
|
fusion_depth_list = [int(ch) for ch in model_type[2:]]
|
||||||
model = MOAFWithSE(fusion_depth_list).to(device)
|
model = MOAFWithSE(fusion_depth_list).to(device)
|
||||||
|
elif "mmlp" in model_type:
|
||||||
|
model = MOAFWithMMLP(fusion_depth_list).to(device)
|
||||||
else:
|
else:
|
||||||
model = MOAFNoFusion().to(device)
|
model = MOAFNoFusion().to(device)
|
||||||
|
|
||||||
# 形式化预训练参数加载
|
# 形式化预训练参数加载
|
||||||
if checkpoint_load:
|
if checkpoint_load:
|
||||||
if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists():
|
if Path(f"ckpts/{model_type}_{dataset_type}_best_model.pt").exists():
|
||||||
checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False)
|
checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False)
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print_with_timestamp("Model Checkpoint Params Loaded")
|
print_with_timestamp("Model Checkpoint Params Loaded")
|
||||||
@ -141,7 +143,7 @@ def fit(rank, world_size, cfg):
|
|||||||
|
|
||||||
# Tensorboard 上显示模型结构
|
# Tensorboard 上显示模型结构
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}")
|
tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{dataset_type}")
|
||||||
dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
|
dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device)
|
||||||
tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
|
tb_writer.add_graph(model.module, (dummy_input1, dummy_input2))
|
||||||
|
|
||||||
@ -191,7 +193,7 @@ def fit(rank, world_size, cfg):
|
|||||||
"val_loss": avg_val_loss
|
"val_loss": avg_val_loss
|
||||||
}
|
}
|
||||||
Path("ckpts").mkdir(exist_ok=True, parents=True)
|
Path("ckpts").mkdir(exist_ok=True, parents=True)
|
||||||
torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt")
|
torch.save(save_dict, f"ckpts/{model_type}_{dataset_type}_best_model.pt")
|
||||||
print_with_timestamp(f"New best model saved at epoch {epoch+1}")
|
print_with_timestamp(f"New best model saved at epoch {epoch+1}")
|
||||||
|
|
||||||
# 清除进程
|
# 清除进程
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# 模型与数据
|
# 模型与数据, 其中 dataset_type 应当和 train_objective_params_list 对应起来
|
||||||
model_type = "cca2"
|
model_type = "cca2"
|
||||||
output_type = "distance"
|
dataset_type = "objall"
|
||||||
dataset_dir = "F:/Datasets/MODatasetD"
|
dataset_dir = "F:/Datasets/MODatasetD"
|
||||||
# 训练参数
|
# 训练参数
|
||||||
batch_size = 64
|
batch_size = 64
|
||||||
@ -21,5 +21,5 @@ test_objective_params_list = [
|
|||||||
"40x-0.65-1.0000", "100x-0.80-1.0000",
|
"40x-0.65-1.0000", "100x-0.80-1.0000",
|
||||||
"100x-1.25-1.4730"
|
"100x-1.25-1.4730"
|
||||||
]
|
]
|
||||||
# 加载形式化预训练参数
|
# 断点加载
|
||||||
checkpoint_load = true
|
checkpoint_load = true
|
||||||
Loading…
x
Reference in New Issue
Block a user