中原彼得堡航空学院网站的建设,买的服务器做两个网站,网站建设怎么设置多语言,建立公司网站视频在上一个关于3D 目标的任务#xff0c;是基于普通CNN网络的3D分类任务。在这个任务中#xff0c;分类数据采用的是CT结节的LIDC-IDRI数据集#xff0c;其中对结节的良恶性、毛刺、分叶征等等特征进行了各自的等级分类。感兴趣的可以直接点击下方的链接#xff0c;直达学习是基于普通CNN网络的3D分类任务。在这个任务中分类数据采用的是CT结节的LIDC-IDRI数据集其中对结节的良恶性、毛刺、分叶征等等特征进行了各自的等级分类。感兴趣的可以直接点击下方的链接直达学习
【3D图像分类】基于Pytorch的3D立体图像分类1基础篇【3D图像分类】基于Pytorch的3D立体图像分类2数据增强篇
在开始本次关于3D 目标的分割任务前呢我还是建议先去看看上述较为简单的分类任务毕竟大多数是相似的有很高的借鉴意义。
一、导言
准备一个训练需要下面这些内容组成
准备数据准备网络搭建训练主模型 train one epochvalid one epoch存储模型存储指标 loss 函数dice coeff 评估指标optimizer优化方式
其中在本项目中
网络采用vnet 3d模型数据采用patch裁剪大小loss函数未dice loss评价指标是dice coeff optimizer优化方式是SGD
二、搭建主结构
训练的主体结构骨架总数包括几个部分
config可调参数定义包括数据路径、图像大小、类别数量、学习率、batch size等等main主函数包括 构建模型构建数据优化器学习率变化方式损失函数评估指标训练batch循环验证batch循环 后处理包括模型参数存储指标走势绘图等等。
上面这些个内容基本上是囊括了深度学习模型训练的整体结构了后面的工作就是对每一部分进行补充。就犹如已经有了骨架后续就是补充肉身了。
后面给出的这个pytorch骨架案例也是后面再构建训练任务一个可以参考的依据可收藏。
2.1、导入库和配置参数
import os
import matplotlib.pyplot as plt
import torch.utils.data
import torch.optim as optimfrom datasets.datasets import myDatasetos.environ[CUDA_VISIBLE_DEVICES] 0, 1, 2, 3 # 使用gpu0
DEVICE torch.device(cuda if torch.cuda.is_available() else cpu) # 没gpu就用cpu
print(DEVICE)############################################################
# Configuration
############################################################
class Configuration(object):train_path r./database/sk_output/trainvalid_path r./database/sk_output/validmodel_path r./checkpointsCrop_Size (48, 96, 96)num_outs 2Batch_Train 32Batch_Test 16Max_epoch 220Num_Workers 8Dice_Best 0LR 0.0003momentum 0.99weight_decay 1e-8def display(self):Display Configuration values.print(\nConfigurations:)print()for a in dir(self):if not a.startswith(__) and not callable(getattr(self, a)):print({:30} {}.format(a, getattr(self, a)))print(\n)2.2、构建main主函数
def main():Config Configuration()Config.display()train_loader, valid_loader get_Dataloader(Config)model get_model(Config).to(DEVICE)# ---- OPTIMIZER ----optimizer optim.SGD(model.parameters(), lrConfig.LR, momentumConfig.momentum, weight_decayConfig.weight_decay)train_loss_list [] # 用来记录训练损失valid_loss_list [] # 用来记录验证损失valid_dice_list []epoch_list []for epoch in range(1, Config.Max_epoch 1):epoch_list.append(epoch)train_loss train_model(model, DEVICE, train_loader, optimizer, epoch) # 训练valid_loss, valid_dice valid_model(model, DEVICE, valid_loader, epoch) # 验证train_loss_list.append(train_loss)valid_loss_list.append(valid_loss)valid_dice_list.append(valid_dice)draw_plot(epoch_list, valid_dice_list, valid_dice)draw_plot(epoch_list, valid_loss_list, valid_loss)draw_plot(epoch_list, train_loss_list, train_loss)if valid_dice Config.Dice_Best:path_ckpt os.path.join(Config.model_path, best_model.pth)save_model(path_ckpt, model)Config.Dice_Best valid_diceelse:path_ckpt os.path.join(Config.model_path, last_model.pth)save_model(path_ckpt, model)print(best val Dice is , Config.Dice_Best)if __name__ __main__:main()2.3、构建获取模型和数据的函数
def get_model(config):from models.vnet3d import VNet3Dmodel VNet3D(num_outsconfig.num_outs, channels16)model model.to(DEVICE) # 模型部署到gpu或cpu里model torch.nn.DataParallel(model).to(DEVICE)return modeldef get_Dataloader(config):# get train datadataset_train myDataset(config.train_path, config.Crop_Size, isTrainTrue)print(len(dataset_train))train_loader torch.utils.data.DataLoader(dataset_train,batch_sizeconfig.Batch_Train, shuffleTrue,num_workersconfig.Num_Workers, drop_lastFalse)# get valid datadataset_valid myDataset(config.valid_path, config.Crop_Size, isTrainFalse)valid_loader torch.utils.data.DataLoader(dataset_valid,batch_sizeconfig.Batch_Test, shuffleFalse,num_workersconfig.Num_Workers, drop_lastFalse)return train_loader, valid_loader2.4、构建训练循环和验证循环
def train_model(model, device, train_loader, optimizer, epoch):config Configuration()model.train()for batch_index, (data, target) in enumerate(train_loader): # 取batch索引datatarget也就是图和标签data, target data.to(device), target.to(device)output model(data)loss Loss(output, target)optimizer.zero_grad() # 梯度归零loss.backward() # 反向传播optimizer.step() # 优化器走一步return losses.avg # 返回平均损失损失列表def valid_model(model, device, test_loader, epoch):config Configuration()model.eval()with torch.no_grad(): # 不进行 梯度计算反向传播for batch_index, (data, target) in enumerate(test_loader): # 枚举batch索引图标签data, target data.to(device), target.to(device)output model(data)loss Loss(output, target) # 计算损失return losses.avg, multi_dices.avg
2.5、后处理
保存模型的参数和绘制训练过程中train loss、valid loss以及valid dice走势图如下
def draw_plot(x_list, y_list, title_name):plt.plot(x_list, y_list, labeltitle_name)plt.xlabel(x, fontsize15)plt.ylabel(y, fontsize15)plt.title(title_name, fontsize15)plt.savefig(./logs/cure.png)def save_model(path, model):if isinstance(model, torch.nn.DataParallel):state_dict model.module.state_dict()else:state_dict model.state_dict()torch.save(state_dict, path)至此每一个模块都有了对应的归宿后面就是如何将缺漏的地方补全过程了。反倒是这部分的代码相对较少两大需要单独验证的数据和模型是大头其他就好办了。
三、总结
本文是关于Pytorch 的 VNet 3D 图像分割的第一篇也就是一个综述篇主要是对这个项目的任务目的以及其中的一个流程进行了梳理。
上述的骨干代码还不能够作为训练使用还需要补充进去骨肉才能够适应不同的任务这一块的内容将会在后面的几个篇章中一一陈述。
如果你也在做类似的事情欢迎点赞、收藏mark住。对于这部分的内容可以一起交流欢迎多多评论。