• nnUNet code analysis (1): training


    nnUNet is a complete segmentation pipeline, widely used in medical image analysis, and it works quite well.

    Let's start with the training code, run_training.py.

    Typical usage: nnUNet_train 2d nnUNetTrainerV2 TaskXXX_MYTASK FOLD --npz

    Here 2d selects the 2D U-Net configuration, nnUNetTrainerV2 is the trainer class, TaskXXX_MYTASK is the task ID, and FOLD is the cross-validation fold.

    There are other parameters as well; see the code for details.

    # run_training.py: resolve the configuration for this network / task / trainer combination
    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
        trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    Based on the network, task, trainer and plans identifier, this returns the plans file, output folder, dataset directory, the batch_dice flag, the stage, and the trainer_class to instantiate.
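
    run_training.py then builds and runs the trainer from these values. Roughly paraphrased (the real call passes a few extra flags such as unpack_data, deterministic and fp16):

        trainer = trainer_class(plans_file, fold, output_folder=output_folder_name,
                                dataset_directory=dataset_directory,
                                batch_dice=batch_dice, stage=stage)
        trainer.initialize(not validation_only)  # builds the network, dataloaders and augmentation
        trainer.run_training()                   # the main epoch loop
        # with --npz, the final validation also stores softmax outputs for later ensembling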

    Command: nnUNet_train 2d nnUNetTrainerV2 Task004_Hippocampus 1 --npz

    The output is as follows:

    ###############################################
    I am running the following nnUNet: 2d
    My trainer class is:  
    For that I will be using the following configuration:
    num_classes:  2
    modalities:  {0: 'MRI'}
    use_mask_for_norm OrderedDict([(0, False)])
    keep_only_largest_region None
    min_region_size_per_class None
    min_size_per_class None
    normalization_schemes OrderedDict([(0, 'nonCT')])
    stages...

    stage:  0
    {'batch_size': 366, 'num_pool_per_axis': [3, 3], 'patch_size': array([56, 40]), 'median_patient_size_in_voxels': array([36, 50, 35]), 'current_spacing': array([1., 1., 1.]), 'original_spacing': array([1., 1., 1.]), 'pool_op_kernel_sizes': [[2, 2], [2, 2], [2, 2]], 'conv_kernel_sizes': [[3, 3], [3, 3], [3, 3], [3, 3]], 'do_dummy_2D_data_aug': False}

    I am using stage 0 from these plans
    I am using batch dice + CE loss

    I am using data from this folder:  /mnt/nnUNet_preprocessed/Task004_Hippocampus/nnUNetData_plans_v2.1_2D
    ###############################################

    There are the concepts of batch dice vs. sample dice here, and of stages, which I don't fully understand yet.
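
    My current understanding (to be checked against the source): with sample dice the Dice is computed per sample and then averaged, while with batch dice the whole batch is pooled and scored as one big sample, which is more stable when a patch contains only a few foreground voxels. The stage seems to refer to the cascade resolutions (stage 0 = low resolution, stage 1 = full resolution); 2d only has stage 0. A minimal sketch of the two Dice variants with made-up tensors:

        import torch

        def soft_dice(pred, target, dims, eps=1e-5):
            # pred: softmax probabilities, target: one-hot labels, same shape
            inter = (pred * target).sum(dims)
            denom = pred.sum(dims) + target.sum(dims)
            return (2 * inter + eps) / (denom + eps)

        probs = torch.rand(4, 3, 56, 40)               # hypothetical batch: B=4, C=3 classes
        onehot = torch.zeros_like(probs)
        onehot[:, 0] = 1                               # toy one-hot target

        sample_dice = soft_dice(probs, onehot, dims=(2, 3)).mean()     # per sample, then averaged
        batch_dice = soft_dice(probs, onehot, dims=(0, 2, 3)).mean()   # pool the batch first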

    According to the output above, the trainer being used is nnunet.training.network_training.nnUNetTrainerV2.nnUNetTrainerV2,

    which inherits from class nnUNetTrainer(NetworkTrainer).

    Let's first look at NetworkTrainer.py and nnUNetTrainer.py.

    do_split: 5-fold cross-validation split.
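
    As a rough illustration only (nnUNet itself writes the split to splits_final.pkl during the first run), a 5-fold split over hypothetical case names could be built like this:

        from sklearn.model_selection import KFold

        cases = [f"hippocampus_{i:03d}" for i in range(260)]   # hypothetical case identifiers

        kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
        splits = [{"train": [cases[i] for i in tr], "val": [cases[i] for i in va]}
                  for tr, va in kfold.split(cases)]
        # the FOLD argument of nnUNet_train then picks one of these 5 entries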

    The training procedure itself is nothing special: SGD with a poly learning-rate schedule.
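
    The poly schedule just decays the learning rate polynomially over the training epochs; something along these lines (nnUNet uses an exponent of 0.9 if I remember correctly):

        def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
            # polynomial decay from initial_lr down to 0 over max_epochs
            return initial_lr * (1 - epoch / max_epochs) ** exponent

        # initial_lr=0.01, max_epochs=1000 -> epoch 1 gives ~0.009991,
        # which matches the "lr: 0.009991" line in the logs below
        print(poly_lr(1, 1000, 0.01))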

    The loss is Dice + CE.
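
    A simplified version of the combined loss (nnUNet's real DC_and_CE_loss has more options, e.g. deep supervision weighting), which also shows why the training/validation loss can become negative: the Dice term enters with a minus sign:

        import torch
        import torch.nn as nn

        class DiceCELoss(nn.Module):
            """Simplified Dice + CE, roughly in the spirit of nnUNet's DC_and_CE_loss."""
            def __init__(self, smooth=1e-5):
                super().__init__()
                self.ce = nn.CrossEntropyLoss()
                self.smooth = smooth

            def forward(self, logits, target):
                # logits: (B, C, H, W); target: (B, H, W) with integer class labels
                ce = self.ce(logits, target)
                probs = torch.softmax(logits, dim=1)
                onehot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1)
                dims = (0, 2, 3)                        # batch dice: pool batch + spatial dims
                inter = (probs * onehot).sum(dims)
                denom = probs.sum(dims) + onehot.sum(dims)
                dice = (2 * inter + self.smooth) / (denom + self.smooth)
                return ce - dice.mean()                 # minus Dice, so the total can drop below 0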

    The data augmentation is also very simple, only scaling and rotation:

        def setup_DA_params(self):
            """
            - we increase roation angle from [-15, 15] to [-30, 30]
            - scale range is now (0.7, 1.4), was (0.85, 1.25)
            - we don't do elastic deformation anymore
            ...
            """
    There is early stopping, with patience = 50.
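
    In essence this just tracks the best validation metric and stops once it has not improved for 50 consecutive epochs; a generic sketch (not nnUNet's exact implementation, which also smooths the metrics with moving averages):

        class EarlyStopping:
            """Stop when the monitored metric has not improved for `patience` epochs."""
            def __init__(self, patience=50, min_delta=0.0):
                self.patience, self.min_delta = patience, min_delta
                self.best = None
                self.bad_epochs = 0

            def step(self, metric):
                if self.best is None or metric > self.best + self.min_delta:
                    self.best, self.bad_epochs = metric, 0
                else:
                    self.bad_epochs += 1
                return self.bad_epochs >= self.patience   # True -> stop training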

    Now let's run the hippocampus segmentation example.

    nnUNet_convert_decathlon_task -i /xxx/Task04_Hippocampus converts MSD-format data into the nnUNet format.

    nnUNet_plan_and_preprocess -t 4 runs planning and preprocessing, a key step in nnUNet; I will dig into it later.

    Then comes training: nnUNet_train 3d_fullres nnUNetTrainerV2 4 0

    The MSD hippocampus dataset is under 30 MB. After preprocessing it produces two directories, nnUNetData_plans_v2.1_2D_stage0 and nnUNetData_plans_v2.1_stage0, each about 166 MB, plus a gt_segmentation directory, which presumably holds the labels?

    The images are only about 30*50*30 voxels, yet training is fairly slow.

    On a K80 one epoch takes more than 6 minutes:

    epoch:  0
    2022-09-03 00:48:13.676541: train loss : -0.3330
    2022-09-03 00:48:33.931392: validation loss: -0.7511
    2022-09-03 00:48:33.935500: Average global foreground Dice: [0.8332, 0.8122]
    2022-09-03 00:48:33.936098: (interpret this as an estimate for the Dice of the different classes. This is not exact.)
    2022-09-03 00:48:35.234157: lr: 0.009991
    2022-09-03 00:48:35.235549: This epoch took 360.665858 s

    2022-09-03 00:48:35.236174: 
    epoch:  1

    On a 2080 Ti one epoch takes about 30 s, roughly 12x faster:

    epoch:  0
    2022-09-03 00:57:25.939667: train loss : -0.3076
    2022-09-03 00:57:28.172749: validation loss: -0.7510
    2022-09-03 00:57:28.174898: Average global foreground Dice: [0.8301, 0.8239]
    2022-09-03 00:57:28.175579: (interpret this as an estimate for the Dice of the different classes. This is not exact.)
    2022-09-03 00:57:29.656255: lr: 0.009991
    2022-09-03 00:57:29.657619: This epoch took 30.441361 s

    2022-09-03 00:57:29.658239: 
    epoch:  1

    Why is the loss negative? As in the loss sketch above, the Dice term is added with a minus sign, so once the Dice score gets large the combined Dice + CE loss drops below zero.

    Hippocampus segmentation has two foreground regions; nnUNet's 2018 result was 0.90 and 0.89.

    After 10 epochs it already reaches 0.89 and 0.877.

    At epoch 33 it reaches [0.9017, 0.8877]:

    2022-09-03 01:13:15.730579: 
    epoch:  33
    2022-09-03 01:13:40.102154: train loss : -0.8661
    2022-09-03 01:13:41.597394: validation loss: -0.8594
    2022-09-03 01:13:41.599368: Average global foreground Dice: [0.9017, 0.8877]
    2022-09-03 01:13:41.599944: (interpret this as an estimate for the Dice of the different classes. This is not exact.)
    2022-09-03 01:13:43.128609: lr: 0.009693
    2022-09-03 01:13:43.135475: saving checkpoint...
    2022-09-03 01:13:44.503204: done, saving took 1.37 seconds
    2022-09-03 01:13:44.544393: This epoch took 28.813107 s

    2022-09-03 01:13:44.545232: 

    After switching to focal loss:

    10 epochs: [0.8562, 0.8408]

    33 epochs: [0.8769, 0.8605]; epoch 34: [0.8808, 0.8643]

  • Original article: https://blog.csdn.net/txdb/article/details/126655264