diff --git a/MOAFDatasets.py b/MOAFDatasets.py index b394099..4d385d4 100644 --- a/MOAFDatasets.py +++ b/MOAFDatasets.py @@ -7,12 +7,11 @@ from torchvision import transforms class MOAFDataset(Dataset): - def __init__(self, dataset_root, tvt='train', objectives_params_list=None, output_type='distance'): + def __init__(self, dataset_root, tvt='train', objectives_params_list=None): """ dataset_root: 根目录(Pathable) tvt: 'train'|'val'|'test'(用于选择 transform) objectives_params_list: 列表,包含要加载的物镜目录名,例如 ["10x-0.25-1.0000", ...] - output_type: 'distance'(返回 nm)或 'ratio'(返回 defocus / DoF) """ super().__init__() self.dataset_root = Path(dataset_root) @@ -22,12 +21,6 @@ class MOAFDataset(Dataset): else: self.objectives_params_list = objectives_params_list - # 处理 output_type,非法输入回退到 'distance' - if isinstance(output_type, str) and output_type.lower() == "ratio": - self.output_type = "ratio" - else: - self.output_type = "distance" - # 根据 tvt 选择 transform if self.tvt == "train": self.transform = transforms.Compose([ @@ -101,30 +94,17 @@ class MOAFDataset(Dataset): rix_tensor = torch.tensor(rix, dtype=torch.float32) label_nm_tensor = torch.tensor(label_nm, dtype=torch.float32) - # min-max 归一化输入参数 - mag_tensor = (mag_tensor - 10) / (100 - 10) - na_tensor = (na_tensor - 0) / (1.25 - 0) - rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0) - - # 根据 output_type 决定输出 label - if self.output_type == "ratio": - dof_nm = self._compute_dof_nm(mag=mag, na=na, rix=rix, wavelength_nm=550.0, pixel_size_nm=3450.0) - # 若 DOF 为 inf 或极大,避免除零 - if not (dof_nm is None or dof_nm == float('inf') or dof_nm == 0): - label_out = label_nm / dof_nm - else: - label_out = label_nm # 回退,虽然不太可能 - label_out_tensor = torch.tensor(float(label_out), dtype=torch.float32) - else: - # distance 模式:直接返回 nm - label_out_tensor = label_nm_tensor + # # min-max 归一化输入参数 + # mag_tensor = (mag_tensor - 10) / (100 - 10) + # na_tensor = (na_tensor - 0) / (1.25 - 0) + # rix_tensor = (rix_tensor - 1.0) / (1.5 - 1.0) sample = { 'image': image, 'mag': mag_tensor, 'na': na_tensor, 'rix': rix_tensor, - 'label': label_out_tensor, + 'label': label_nm_tensor, 'path': img_path.as_posix(), } @@ -139,8 +119,7 @@ if __name__ == "__main__": train_set = MOAFDataset("F:/Datasets/MODatasetD", tvt='train', objectives_params_list=[ "100x-1.25-1.4730", - ], - output_type='ratio') + ]) from torch.utils.data import DataLoader train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2) for batch in train_loader: diff --git a/MOAFModels.py b/MOAFModels.py index 52eb3dd..b102a32 100644 --- a/MOAFModels.py +++ b/MOAFModels.py @@ -40,8 +40,14 @@ class ParamEmbedding(nn.Module): nn.LayerNorm(out_dim) ) - def forward(self, x): - return self.embedding(x) + def forward(self, params): + # min-max 归一化参数 + normalized_params = torch.stack([ + (params[:, 0] - 10.0) / 90.0, + params[:, 1] / 1.25, + (params[:, 2] - 1.0) / 0.5 + ], dim=1) + return self.embedding(normalized_params) # FiLM 融合块 @@ -238,3 +244,73 @@ class MOAFWithSE(nn.Module): x = self.regressor(x) return x +# 多回归头模型 +class MOAFWithMMLP(nn.Module): + def __init__(self, num_lenses=7): + super().__init__() + shuff = models.shufflenet_v2_x0_5(weights="DEFAULT") + self.features = nn.Sequential( + shuff.conv1, shuff.maxpool, + shuff.stage2, shuff.stage3, + shuff.stage4, shuff.conv5 + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.num_lenses = num_lenses + self.regressors = nn.ModuleList([ + self._create_regressor_head() for _ in range(num_lenses) + ]) + + def _create_regressor_head(self): + return nn.Sequential( + nn.Flatten(), + nn.Linear(1024, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def _find_lens_id(self, params): + batch_size = params.shape[0] + lens_ids = [] + + for i in range(batch_size): + mag, na, rix = params[i][0].item(), params[i][1].item(), params[i][2].item() + + # 直接的条件匹配(类似于switch-case) + if mag == 10 and na == 0.25 and rix == 1.0000: + lens_ids.append(0) # obj1 + elif mag == 10 and na == 0.30 and rix == 1.0000: + lens_ids.append(1) # obj2 + elif mag == 20 and na == 0.70 and rix == 1.0000: + lens_ids.append(2) # obj3 + elif mag == 20 and na == 0.80 and rix == 1.0000: + lens_ids.append(3) # obj4 + elif mag == 40 and na == 0.65 and rix == 1.0000: + lens_ids.append(4) # obj5 + elif mag == 100 and na == 0.80 and rix == 1.0000: + lens_ids.append(5) # obj6 + elif mag == 100 and na == 1.25 and rix == 1.4730: + lens_ids.append(6) # obj7 + else: + lens_ids.append(0) + + return torch.tensor(lens_ids, dtype=torch.long, device=params.device) + + def forward(self, image, params): + x = self.features(image) + x = self.avgpool(x) + + lens_ids = self._find_lens_id(params) + batch_size = params.size(0) + outputs = [] + for i in range(batch_size): + current_lens_id = lens_ids[i] + # 确保lens_id在有效范围内 + if current_lens_id < 0 or current_lens_id >= self.num_lenses: + current_lens_id = 0 + + # 选择对应的回归头 + head_output = self.regressors[current_lens_id](x[i].unsqueeze(0)) + outputs.append(head_output) + + return torch.cat(outputs, dim=0) + diff --git a/MOAFTest.py b/MOAFTest.py index 2b25f8c..a0761fa 100644 --- a/MOAFTest.py +++ b/MOAFTest.py @@ -9,10 +9,10 @@ from pathlib import Path from MOAFUtils import print_with_timestamp from MOAFDatasets import MOAFDataset -from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP -def test(model, test_loader, device, model_type, output_type): +def test(model, test_loader, device, model_type, dataset_type): model.eval() results = [] @@ -30,7 +30,7 @@ def test(model, test_loader, device, model_type, output_type): df["pred"] = results Path("results").mkdir(exist_ok=True, parents=True) # !pip install openpyxl - df.to_excel(f"results/{model_type}_{output_type}.xlsx", index=False) + df.to_excel(f"results/{model_type}_{dataset_type}.xlsx", index=False) def main(): @@ -43,7 +43,7 @@ def main(): # 确定超参数 model_type = cfg["model_type"] - output_type = cfg["output_type"] + dataset_type = cfg["dataset_type"] dataset_dir = cfg["dataset_dir"] batch_size = int(cfg["batch_size"]) num_workers = int(cfg["num_workers"]) @@ -53,7 +53,7 @@ def main(): print_with_timestamp(f"Using device {device}") # 加载数据集 - test_set = MOAFDataset(dataset_dir, "test", objective_params_list, output_type) + test_set = MOAFDataset(dataset_dir, "test", objective_params_list) print_with_timestamp("Dataset Done") test_loader = DataLoader( @@ -72,15 +72,17 @@ def main(): elif "se" in model_type: fusion_depth_list = [int(ch) for ch in model_type[2:]] model = MOAFWithSE(fusion_depth_list).to(device) + elif "mmlp" in model_type: + model = MOAFWithMMLP(fusion_depth_list).to(device) else: model = MOAFNoFusion().to(device) - checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False) model.load_state_dict(checkpoint["model_state_dict"]) print_with_timestamp("Model Loaded") print_with_timestamp("Start testing") - test(model, test_loader, device, model_type, output_type) + test(model, test_loader, device, model_type, dataset_type) print_with_timestamp("Testing completed!") diff --git a/MOAFTrain.py b/MOAFTrain.py index 8ba7763..8dde796 100644 --- a/MOAFTrain.py +++ b/MOAFTrain.py @@ -12,7 +12,7 @@ from pathlib import Path from MOAFUtils import print_with_timestamp from MOAFDatasets import MOAFDataset -from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn): @@ -49,10 +49,10 @@ def valid_epoch(model, val_loader, epoch, epochs, device, loss_fn): return val_loss -def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type): +def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, dataset_type): best_val_loss = float('inf') # !pip install tensorboard - with SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") as writer: + with SummaryWriter(log_dir=f"runs/{model_type}_{dataset_type}") as writer: # Tensorboard 上显示模型结构 dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device) writer.add_graph(model, (dummy_input1, dummy_input2)) @@ -90,7 +90,7 @@ def fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, sch "val_loss": avg_val_loss } Path("ckpts").mkdir(exist_ok=True, parents=True) - torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt") + torch.save(save_dict, f"ckpts/{model_type}_{dataset_type}_best_model.pt") print_with_timestamp(f"New best model saved at epoch {epoch+1}") @@ -104,7 +104,7 @@ def main(): # 确定超参数 model_type = cfg["model_type"] - output_type = cfg["output_type"] + dataset_type = cfg["dataset_type"] dataset_dir = cfg["dataset_dir"] batch_size = int(cfg["batch_size"]) num_workers = int(cfg["num_workers"]) @@ -118,8 +118,8 @@ def main(): print_with_timestamp(f"Using device {device}") # 加载数据集 - train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type) - val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type) + train_set = MOAFDataset(dataset_dir, "train", objective_params_list) + val_set = MOAFDataset(dataset_dir, "val", objective_params_list) print_with_timestamp("Dataset Done") train_loader = DataLoader( @@ -142,14 +142,16 @@ def main(): elif "se" in model_type: fusion_depth_list = [int(ch) for ch in model_type[2:]] model = MOAFWithSE(fusion_depth_list).to(device) + elif "mmlp" in model_type: + model = MOAFWithMMLP(fusion_depth_list).to(device) else: model = MOAFNoFusion().to(device) print_with_timestamp("Model Loaded") # 形式化预训练参数加载 if checkpoint_load: - if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists(): - checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + if Path(f"ckpts/{model_type}_{dataset_type}_best_model.pt").exists(): + checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False) model.load_state_dict(checkpoint["model_state_dict"]) print_with_timestamp("Model Checkpoint Params Loaded") else: @@ -165,7 +167,7 @@ def main(): ) print_with_timestamp("Start trainning") - fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, output_type) + fit(model, train_loader, val_loader, epochs, device, optimizer, loss_fn, scheduler, model_type, dataset_type) print_with_timestamp("Training completed!") diff --git a/MOAFTrainDDP.py b/MOAFTrainDDP.py index e9e65f9..bf95fb8 100644 --- a/MOAFTrainDDP.py +++ b/MOAFTrainDDP.py @@ -18,7 +18,7 @@ from torch.utils.tensorboard import SummaryWriter from MOAFUtils import print_with_timestamp from MOAFDatasets import MOAFDataset -from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE +from MOAFModels import MOAFNoFusion, MOAFWithFiLM, MOAFWithChannelCrossAttention, MOAFWithSE, MOAFWithMMLP def train_epoch(model, train_loader, epoch, epochs, device, optimizer, loss_fn, rank): @@ -71,7 +71,7 @@ def fit(rank, world_size, cfg): # 确定超参数 model_type = cfg["model_type"] - output_type = cfg["output_type"] + dataset_type = cfg["dataset_type"] dataset_dir = cfg["dataset_dir"] batch_size = int(cfg["batch_size"]) num_workers = int(cfg["num_workers"]) @@ -82,8 +82,8 @@ def fit(rank, world_size, cfg): checkpoint_load = cfg["checkpoint_load"] # 加载数据集 - train_set = MOAFDataset(dataset_dir, "train", objective_params_list, output_type) - val_set = MOAFDataset(dataset_dir, "val", objective_params_list, output_type) + train_set = MOAFDataset(dataset_dir, "train", objective_params_list) + val_set = MOAFDataset(dataset_dir, "val", objective_params_list) # 分布式化数据集 train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True) @@ -111,13 +111,15 @@ def fit(rank, world_size, cfg): elif "se" in model_type: fusion_depth_list = [int(ch) for ch in model_type[2:]] model = MOAFWithSE(fusion_depth_list).to(device) + elif "mmlp" in model_type: + model = MOAFWithMMLP(fusion_depth_list).to(device) else: model = MOAFNoFusion().to(device) # 形式化预训练参数加载 if checkpoint_load: - if Path(f"ckpts/{model_type}_{output_type}_best_model.pt").exists(): - checkpoint = torch.load(f"ckpts/{model_type}_{output_type}_best_model.pt", map_location=device, weights_only=False) + if Path(f"ckpts/{model_type}_{dataset_type}_best_model.pt").exists(): + checkpoint = torch.load(f"ckpts/{model_type}_{dataset_type}_best_model.pt", map_location=device, weights_only=False) model.load_state_dict(checkpoint["model_state_dict"]) if rank == 0: print_with_timestamp("Model Checkpoint Params Loaded") @@ -141,7 +143,7 @@ def fit(rank, world_size, cfg): # Tensorboard 上显示模型结构 if rank == 0: - tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{output_type}") + tb_writer = SummaryWriter(log_dir=f"runs/{model_type}_{dataset_type}") dummy_input1, dummy_input2 = torch.randn(5, 3, 384, 384).to(device), torch.randn(5, 3).to(device) tb_writer.add_graph(model.module, (dummy_input1, dummy_input2)) @@ -191,7 +193,7 @@ def fit(rank, world_size, cfg): "val_loss": avg_val_loss } Path("ckpts").mkdir(exist_ok=True, parents=True) - torch.save(save_dict, f"ckpts/{model_type}_{output_type}_best_model.pt") + torch.save(save_dict, f"ckpts/{model_type}_{dataset_type}_best_model.pt") print_with_timestamp(f"New best model saved at epoch {epoch+1}") # 清除进程 diff --git a/configs/config_example.toml b/configs/config_example.toml index ad66d2a..3192b07 100644 --- a/configs/config_example.toml +++ b/configs/config_example.toml @@ -1,6 +1,6 @@ -# 模型与数据 +# 模型与数据, 其中 dataset_type 应当和 train_objective_params_list 对应起来 model_type = "cca2" -output_type = "distance" +dataset_type = "objall" dataset_dir = "F:/Datasets/MODatasetD" # 训练参数 batch_size = 64 @@ -21,5 +21,5 @@ test_objective_params_list = [ "40x-0.65-1.0000", "100x-0.80-1.0000", "100x-1.25-1.4730" ] -# 加载形式化预训练参数 +# 断点加载 checkpoint_load = true \ No newline at end of file