• AI实战营第二期 第八节 《MMSegmentation代码课》——笔记9


    AI实战营第二期 第八节 《MMSegmentation代码课》

    【课程链接】https://www.bilibili.com/video/BV1uh411T73q/
    【讲师介绍】张子豪 OpenMMLab算法工程师
    【学习形式】录播+社群答疑
    【作业布置】本次课程为实战课,需提交笔记+作业。
    在这里插入图片描述

    课程大纲:

    • 环境配置
    • 预训练模型预测图片、视频
    • 航拍图像语义分割案例
    • 肾小球病理切片语义分割案例

    作业:

    • 西瓜瓤、西瓜皮、西瓜籽像素级语义分割

    安装配置MMSegmentation

    安装pytorch

    pip install install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
    
    • 1

    用MIM安装MMCV

    pip install -U openmim
    mim install mmengine
    mim install 'mmcv==2.0.0rc4'
    
    • 1
    • 2
    • 3

    安装其它工具包

    pip install opencv-python pillow matplotlib seaborn tqdm pytorch-lightning 'mmdet>=3.0.0rc1' -i https://pypi.tuna.tsinghua.edu.cn/simple
    
    • 1

    下载 与安装MMSegmentation

    # 删掉原有的 mmsegmentation 文件夹(如有)
    rm -rf mmsegmentation
    # 从 github 上下载最新的 mmsegmentation 源代码
    git clone https://github.com/open-mmlab/mmsegmentation.git -b dev-1.x
    pip install -v -e .
    
    • 1
    • 2
    • 3
    • 4
    • 5

    下载预训练模型权重文件和视频素材

    import os
    
    # 创建 checkpoint 文件夹,用于存放预训练模型权重文件
    os.mkdir('checkpoint')
    
    # 创建 outputs 文件夹,用于存放预测结果
    os.mkdir('outputs')
    
    # 创建 data 文件夹,用于存放图片和视频素材
    os.mkdir('data')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    下载预训练模型权重至checkpoint目录

    Model Zoo:https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/model_zoo.md

     从 Model Zoo 获取 PSPNet 预训练模型,下载并保存在 checkpoint 文件夹中
    !wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
    
    
    • 1
    • 2
    • 3

    下载素材至data目录

    如果报错Unable to establish SSL connection.,重新运行代码块即可。

    # 伦敦街景图片
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data
    
    # 上海驾车街景视频,视频来源:https://www.youtube.com/watch?v=ll8TgCZ0plk
    !wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data
    
    # 街拍视频,2022年3月30日
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    检查安装成功

    # 检查 Pytorch
    import torch, torchvision
    print('Pytorch 版本', torch.__version__)
    print('CUDA 是否可用',torch.cuda.is_available())
    
    • 1
    • 2
    • 3
    • 4
    # 检查 mmcv
    import mmcv
    from mmcv.ops import get_compiling_cuda_version, get_compiler_version
    print('MMCV版本', mmcv.__version__)
    print('CUDA版本', get_compiling_cuda_version())
    print('编译器版本', get_compiler_version())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    # 检查 mmsegmentation
    import mmseg
    from mmseg.utils import register_all_modules
    from mmseg.apis import inference_model, init_model
    print('mmsegmentation版本', mmseg.__version__)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    设置Matplotlib中文字体

    # # windows操作系统
    # plt.rcParams['font.sans-serif']=['SimHei']  # 用来正常显示中文标签 
    # plt.rcParams['axes.unicode_minus']=False  # 用来正常显示负号
    # Mac操作系统,参考 https://www.ngui.cc/51cto/show-727683.html
    # 下载 simhei.ttf 字体文件
    # !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf
    
    # Linux操作系统,例如 云GPU平台:https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1
    # 如果遇到 SSL 相关报错,重新运行本代码块即可
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf -O /environment/miniconda3/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf/SimHei.ttf
    !rm -rf /home/featurize/.cache/matplotlib
    
    import matplotlib 
    import matplotlib.pyplot as plt
    matplotlib.rc("font",family='SimHei') # 中文字体
    
    plt.plot([1,2,3], [100,500,300])
    plt.title('matplotlib中文字体测试', fontsize=25)
    plt.xlabel('X轴', fontsize=15)
    plt.ylabel('Y轴', fontsize=15)
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    预训练语义分割模型预测

    进入 mmsegmentation 主目录

    import os
    os.chdir('mmsegmentation')
    
    • 1
    • 2

    载入测试图像

    from PIL import Image
    
    • 1

    MMSegmentation模型库

    Model Zoo:https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/model_zoo.md

    常用config和checkpoint文件

    configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py

    https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth

    configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py

    https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth

    configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py

    https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth

    PSPNet

    PSPNet语义分割算法

    !python demo/image_demo.py \
            data/street_uk.jpeg \
            configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
            https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
            --out-file outputs/B1_uk_pspnet.jpg \
            --device cuda:0 \
            --opacity 0.5
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这里插入图片描述

    SegFormer

    !python demo/image_demo.py \
            data/street_uk.jpeg \
            configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
            https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
            --out-file outputs/B1_uk_segformer.jpg \
            --device cuda:0 \
            --opacity 0.5
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这里插入图片描述

    Mask2Former

    !python demo/image_demo.py \
            data/street_uk.jpeg \
            configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py \
            https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth \
            --out-file outputs/B1_uk_Mask2Former.jpg \
            --device cuda:0 \
            --opacity 0.5
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    ADE20K语义分割数据集

    mmsegmentation/mmseg/datasets/ade.py

    关于ADE20K的故事:https://www.zhihu.com/question/390783647/answer/1226097849

    !python demo/image_demo.py \
            data/street_uk.jpeg \
            configs/segformer/segformer_mit-b5_8xb2-160k_ade20k-640x640.py \
            https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20220617_203542-940a6bd8.pth \
            --out-file outputs/B1_Segformer_ade20k.jpg \
            --device cuda:0 \
            --opacity 0.5
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这里插入图片描述

    预训练语义分割模型预测-视频

    视频预测-命令行(不推荐,慢)

    !python demo/video_demo.py \
            data/street_20220330_174028.mp4 \
            configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py \
            https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth \
            --device cuda:0 \
            --output-file outputs/B3_video.mp4 \
            --opacity 0.5
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    视频预测-Python API(推荐,快)

    import numpy as np
    import time
    import shutil
    
    import torch
    
    from PIL import Image
    import cv2
    
    import mmcv
    import mmengine
    from mmseg.apis import inference_model
    from mmseg.utils import register_all_modules
    register_all_modules()
    
    from mmseg.datasets import CityscapesDataset
    
    # 模型 config 配置文件
    config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'
    
    # 模型 checkpoint 权重文件
    checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
    from mmseg.apis import init_model
    model = init_model(config_file, checkpoint_file, device='cuda:0')
    
    from mmengine.model.utils import revert_sync_batchnorm
    if not torch.cuda.is_available():
        model = revert_sync_batchnorm(model)
    # input_video = 'data/traffic.mp4'
    
    input_video = 'data/street_20220330_174028.mp4'
    
    temp_out_dir = time.strftime('%Y%m%d%H%M%S')
    os.mkdir(temp_out_dir)
    print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))
    # 获取 Cityscapes 街景数据集 类别名和调色板
    from mmseg.datasets import cityscapes
    classes = cityscapes.CityscapesDataset.METAINFO['classes']
    palette = cityscapes.CityscapesDataset.METAINFO['palette']
    def pridict_single_frame(img, opacity=0.2):
        
        result = inference_model(model, img)
        
        # 将分割图按调色板染色
        seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
        seg_img = Image.fromarray(seg_map).convert('P')
        seg_img.putpalette(np.array(palette, dtype=np.uint8))
        
        show_img = (np.array(seg_img.convert('RGB')))*(1-opacity) + img*opacity
        
        return show_img
    # 读入待预测视频
    imgs = mmcv.VideoReader(input_video)
    
    prog_bar = mmengine.ProgressBar(len(imgs))
    
    # 对视频逐帧处理
    for frame_id, img in enumerate(imgs):
        
        ## 处理单帧画面
        show_img = pridict_single_frame(img, opacity=0.15)
        temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg' # 保存语义分割预测结果图像至临时文件夹
        cv2.imwrite(temp_path, show_img)
    
        prog_bar.update() # 更新进度条
    
    # 把每一帧串成视频文件
    mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')
    
    shutil.rmtree(temp_out_dir) # 删除存放每帧画面的临时文件夹
    print('删除临时文件夹', temp_out_dir)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71

    Kaggle实战-迪拜卫星航拍多类别语义分割

    下载整理好的数据集

    Kaggle原版数据集:https://www.kaggle.com/datasets/humansintheloop/semantic-segmentation-of-aerial-imagery

    下载整理好之后的数据集

    !rm -rf Dubai-dataset.zip Dubai-dataset
    
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/Dubai-dataset.zip
    
    !unzip Dubai-dataset.zip >> /dev/null # 解压
    
    !rm -rf Dubai-dataset.zip # 删除压缩包
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    删除系统自动生成的多余文件

    查看待删除的多余文件

    !find . -iname '__MACOSX'
    !find . -iname '__MACOSX'
    !find . -iname '.DS_Store'
    !find . -iname '.ipynb_checkpoints'
    ./.ipynb_checkpoints
    
    • 1
    • 2
    • 3
    • 4
    • 5

    删除多余文件

    !for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
    !for i in `find . -iname '.DS_Store'`; do rm -rf $i;done
    !for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done
    
    • 1
    • 2
    • 3

    验证多余文件已删除

    !find . -iname '__MACOSX'
    !find . -iname '.DS_Store'
    !find . -iname '.ipynb_checkpoints'
    
    • 1
    • 2
    • 3

    ​可视化探索数据集

    import os
    
    import cv2
    import numpy as np
    from PIL import Image
    from tqdm import tqdm
    
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    # 指定单张图像路径
    img_path = 'Dubai-dataset/img_dir/train/14.jpg'
    mask_path = 'Dubai-dataset/ann_dir/train/14.png'
    
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    # 可视化语义分割标注
    plt.imshow(mask[:,:,0])
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    在这里插入图片描述

    plt.imshow(img[:,:,::-1])
    plt.imshow(mask[:,:,0], alpha=0.4) # alpha 高亮区域透明度,越小越接近原图
    plt.axis('off')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述
    批量可视化图像和标注

    # 指定图像和标注路径
    PATH_IMAGE = 'Dubai-dataset/img_dir/train'
    PATH_MASKS = 'Dubai-dataset/ann_dir/train'
    # n行n列可视化
    n = 5
    
    # 标注区域透明度
    opacity = 0.5
    
    fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))
    
    for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
        
        # 载入图像和标注
        img_path = os.path.join(PATH_IMAGE, file_name)
        mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0]+'.png')
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)
        
        # 可视化
        axes[i//n, i%n].imshow(img)
        axes[i//n, i%n].imshow(mask[:,:,0], alpha=opacity)
        axes[i//n, i%n].axis('off') # 关闭坐标轴显示
    fig.suptitle('Image and Semantic Label', fontsize=30)
    plt.tight_layout()
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    准备config配置文件

    import numpy as np
    from PIL import Image
    
    import os.path as osp
    from tqdm import tqdm
    
    import mmcv
    import mmengine
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    定义数据集类(各类别名称及配色)

    !rm -rf mmseg/datasets/DubaiDataset.py # 删除原有文件
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/DubaiDataset.py -P mmseg/datasets
    
    • 1
    • 2

    注册数据集类

    !rm -rf mmseg/datasets/__init__.py # 删除原有文件
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/__init__.py -P mmseg/datasets
    
    • 1
    • 2

    定义训练及测试pipeline

    !rm -rf configs/_base_/datasets/DubaiDataset_pipeline.py
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/DubaiDataset_pipeline.py -P configs/_base_/datasets
    
    • 1
    • 2

    下载模型config配置文件

    !rm -rf configs/pspnet/pspnet_r50-d8_4xb2-40k_DubaiDataset.py # 删除原有文件
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/pspnet_r50-d8_4xb2-40k_DubaiDataset.py -P configs/pspnet 
    
    • 1
    • 2

    载入config配置文件

    from mmengine import Config
    cfg = Config.fromfile('./configs/pspnet/pspnet_r50-d8_4xb2-40k_DubaiDataset.py')
    
    • 1
    • 2

    修改config配置文件

    cfg.norm_cfg = dict(type='BN', requires_grad=True) # 只使用GPU时,BN取代SyncBN
    cfg.crop_size = (256, 256)
    cfg.model.data_preprocessor.size = cfg.crop_size
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
    # modify num classes of the model in decode/auxiliary head
    
    # 模型 decode/auxiliary 输出头,指定为类别个数
    cfg.model.decode_head.num_classes = 6
    cfg.model.auxiliary_head.num_classes = 6
    
    cfg.train_dataloader.batch_size = 8
    
    cfg.test_dataloader = cfg.val_dataloader
    
    # 结果保存目录
    cfg.work_dir = './work_dirs/DubaiDataset'
    
    # 训练迭代次数
    cfg.train_cfg.max_iters = 3000
    # 评估模型间隔
    cfg.train_cfg.val_interval = 400
    # 日志记录间隔
    cfg.default_hooks.logger.interval = 100
    # 模型权重保存间隔
    cfg.default_hooks.checkpoint.interval = 1500
    
    # 随机数种子
    cfg['randomness'] = dict(seed=0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    查看完整config配置文件

    print(cfg.pretty_text)
    
    • 1

    保存config配置文件

    cfg.dump('pspnet-DubaiDataset_20230612.py')
    
    • 1

    MMSegmentation训练语义分割模型

    import numpy as np
    
    import os.path as osp
    from tqdm import tqdm
    
    import mmcv
    import mmengine
    
    from mmengine import Config
    cfg = Config.fromfile('pspnet-DubaiDataset_20230612.py')
    
    from mmengine.runner import Runner
    from mmseg.utils import register_all_modules
    
    # register all modules in mmseg into the registries
    # do not init the default scope here because it will be init in the runner
    register_all_modules(init_default_scope=False)
    runner = Runner.from_cfg(cfg)
    
    runner.train()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    开始训练

    如果遇到报错CUDA out of memeory,可尝试以下步骤:

    • 调小 batch size

    • 左上角内核-关闭所有内核

    • 重启实例,或者使用显存更高的实例即可。

    可视化训练日志

    设置Matplotlib中文字体

    # # windows操作系统
    # plt.rcParams['font.sans-serif']=['SimHei']  # 用来正常显示中文标签 
    # plt.rcParams['axes.unicode_minus']=False  # 用来正常显示负号
    # Mac操作系统,参考 https://www.ngui.cc/51cto/show-727683.html
    # 下载 simhei.ttf 字体文件
    # !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf
    
    # Linux操作系统,例如 云GPU平台:https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1
    # 如果遇到 SSL 相关报错,重新运行本代码块即可
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf -O /environment/miniconda3/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf/SimHei.ttf
    !rm -rf /home/featurize/.cache/matplotlib
    
    import matplotlib 
    import matplotlib.pyplot as plt
    matplotlib.rc("font",family='SimHei') # 中文字体
    
    plt.plot([1,2,3], [100,500,300])
    plt.title('matplotlib中文字体测试', fontsize=25)
    plt.xlabel('X轴', fontsize=15)
    plt.ylabel('Y轴', fontsize=15)
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    导入工具包
    import pandas as pd

    import matplotlib.pyplot as plt
    %matplotlib inline
    载入训练日志

    日志文件路径

    log_path = 'work_dirs/DubaiDataset/20230612_100725/vis_data/scalars.json'
    with open(log_path, "r") as f:
        json_list = f.readlines()
    len(json_list)
    
    eval(json_list[4])
    
    df_train = pd.DataFrame()
    df_test = pd.DataFrame()
    for each in json_list[:-1]:
        if 'aAcc' in each:
            df_test = df_test.append(eval(each), ignore_index=True)
        else:
            df_train = df_train.append(eval(each), ignore_index=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    导出训练日志表格

    df_train.to_csv('训练日志-训练集.csv', index=False)
    df_test.to_csv('训练日志-测试集.csv', index=False)
    
    • 1
    • 2

    可视化辅助函数

    from matplotlib import colors as mcolors
    import random
    random.seed(124)
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan', 'black', 'indianred', 'brown', 'firebrick', 'maroon', 'darkred', 'red', 'sienna', 'chocolate', 'yellow', 'olivedrab', 'yellowgreen', 'darkolivegreen', 'forestgreen', 'limegreen', 'darkgreen', 'green', 'lime', 'seagreen', 'mediumseagreen', 'darkslategray', 'darkslategrey', 'teal', 'darkcyan', 'dodgerblue', 'navy', 'darkblue', 'mediumblue', 'blue', 'slateblue', 'darkslateblue', 'mediumslateblue', 'mediumpurple', 'rebeccapurple', 'blueviolet', 'indigo', 'darkorchid', 'darkviolet', 'mediumorchid', 'purple', 'darkmagenta', 'fuchsia', 'magenta', 'orchid', 'mediumvioletred', 'deeppink', 'hotpink']
    markers = [".",",","o","v","^","<",">","1","2","3","4","8","s","p","P","*","h","H","+","x","X","D","d","|","_",0,1,2,3,4,5,6,7,8,9,10,11]
    linestyle = ['--', '-.', '-']def get_line_arg():
        '''
        随机产生一种绘图线型
        '''
        line_arg = {}
        line_arg['color'] = random.choice(colors)
        # line_arg['marker'] = random.choice(markers)
        line_arg['linestyle'] = random.choice(linestyle)
        line_arg['linewidth'] = random.randint(1, 4)
        # line_arg['markersize'] = random.randint(3, 5)
        return line_arg
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    训练集损失函数

    metrics = ['loss', 'decode.loss_ce', 'aux.loss_ce']
    plt.figure(figsize=(16, 8))
    ​
    x = df_train['step']
    for y in metrics:
        plt.plot(x, df_train[y], label=y, **get_line_arg())
    ​
    plt.tick_params(labelsize=20)
    plt.xlabel('step', fontsize=20)
    plt.ylabel('loss', fontsize=20)
    plt.title('训练集损失函数', fontsize=25)
    plt.savefig('训练集损失函数.pdf', dpi=120, bbox_inches='tight')
    ​
    plt.legend(fontsize=20)
    ​
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在这里插入图片描述

    训练集准确率

    plt.figure(figsize=(16, 8))
    ​
    x = df_train['step']
    for y in metrics:
        plt.plot(x, df_train[y], label=y, **get_line_arg())
    ​
    plt.tick_params(labelsize=20)
    plt.xlabel('step', fontsize=20)
    plt.ylabel('loss', fontsize=20)
    plt.title('训练集准确率', fontsize=25)
    plt.savefig('训练集准确率.pdf', dpi=120, bbox_inches='tight')
    ​
    plt.legend(fontsize=20)
    ​
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    在这里插入图片描述

    测试集评估指标

    plt.figure(figsize=(16, 8))
    ​
    x = df_test['step']
    for y in metrics:
        plt.plot(x, df_test[y], label=y, **get_line_arg())
    ​
    plt.tick_params(labelsize=20)
    plt.ylim([0, 100])
    plt.xlabel('step', fontsize=20)
    plt.ylabel(y, fontsize=20)
    plt.title('测试集评估指标', fontsize=25)
    plt.savefig('测试集分类评估指标.pdf', dpi=120, bbox_inches='tight')
    ​
    plt.legend(fontsize=20)
    ​
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在这里插入图片描述

    模型预测

    导入工具包

    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    ​
    from mmseg.apis import init_model, inference_model, show_result_pyplot
    import mmcv
    import cv2
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    载入配置文件

    # 载入 config 配置文件
    from mmengine import Config
    cfg = Config.fromfile('pspnet-DubaiDataset_20230612.py')
    from mmengine.runner import Runner
    from mmseg.utils import register_all_modules
    ​
    # register all modules in mmseg into the registries
    # do not init the default scope here because it will be init in the runner
    ​
    register_all_modules(init_default_scope=False)
    runner = Runner.from_cfg(cfg)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    载入模型

    checkpoint_path = './work_dirs/DubaiDataset/iter_3000.pth'
    model = init_model(cfg, checkpoint_path, 'cuda:0')
    
    • 1
    • 2

    载入测试集图像,或新图像

    img = mmcv.imread('Dubai-dataset/img_dir/val/71.jpg')
    
    • 1

    语义分割预测

    result = inference_model(model, img)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
    np.unique(pred_mask)
    array([0, 1, 2, 3, 4])
    plt.imshow(pred_mask)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    在这里插入图片描述

    可视化预测结果

    visualization = show_result_pyplot(model, img, result, opacity=0.7, out_file='pred.jpg')
    plt.imshow(mmcv.bgr2rgb(visualization))
    plt.show()
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    获取测试集标注

    获取测试集标注
    label = mmcv.imread('Dubai-dataset/ann_dir/val/71.png')
    label_mask = label[:,:,0]
    np.unique(label_mask)
    plt.imshow(label_mask)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    在这里插入图片描述

    对比测试集标注和语义分割预测结果

    # 测试集标注
    
    # 真实为前景,预测为前景
    TP = (label_mask == 1) & (pred_mask==1)
    # 真实为背景,预测为背景
    TN = (label_mask == 0) & (pred_mask==0)
    # 真实为前景,预测为背景
    FN = (label_mask == 1) & (pred_mask==0)
    # 真实为背景,预测为前景
    FP = (label_mask == 0) & (pred_mask==1)
    plt.imshow(TP)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    在这里插入图片描述

    confusion_map = TP * 255 + FP * 150 + FN * 80 + TN * 30
    plt.imshow(confusion_map)
    plt.show()
    
    • 1
    • 2
    • 3

    在这里插入图片描述
    混淆矩阵

    from sklearn.metrics import confusion_matrix
    confusion_matrix_model = confusion_matrix(label_mask.flatten(), pred_mask.flatten())
    confusion_matrix_model
    
    import itertools
    def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
        """
        传入混淆矩阵和标签名称列表,绘制混淆矩阵
        """
        plt.figure(figsize=(10, 10))
        
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        # plt.colorbar() # 色条
        tick_marks = np.arange(len(classes))
        
        plt.title('Confusion Matrix', fontsize=30)
        plt.xlabel('Pred', fontsize=25, c='r')
        plt.ylabel('True', fontsize=25, c='r')
        plt.tick_params(labelsize=16) # 设置类别文字大小
        plt.xticks(tick_marks, classes, rotation=90) # 横轴文字旋转
        plt.yticks(tick_marks, classes)
        
        # 写数字
        threshold = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, cm[i, j],
                     horizontalalignment="center",
                     color="white" if cm[i, j] > threshold else "black",
                     fontsize=12)
    ​
        plt.tight_layout()
    ​
        plt.savefig('混淆矩阵.pdf', dpi=300) # 保存图像
        plt.show()
    classes = ['Land', 'Road', 'Building', 'Vegetation', 'Water', 'Unlabeled']
    cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    在这里插入图片描述

    测试集性能评估

    测试集精度指标

    python tools/test.py pspnet-DubaiDataset_20230612.py work_dirs/DubaiDataset/iter_3000.pth
    
    • 1

    速度指标-FPS

    !python tools/analysis_tools/benchmark.py pspnet-DubaiDataset_20230612.py work_dirs/DubaiDataset/iter_3000.pth
    
    • 1

    Kaggle实战-小鼠肾小球组织病理切片语义分割

    下载整理好的数据集

    下载数据集

    !rm -rf Glomeruli-dataset.zip Glomeruli-dataset
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/Glomeruli-dataset.zip
    !unzip Glomeruli-dataset.zip >> /dev/null # 解压
    !rm -rf Glomeruli-dataset.zip # 删除压缩包
    
    • 1
    • 2
    • 3
    • 4

    删除系统自动生成的多余文件
    查看待删除的多余文件

    !find . -iname '__MACOSX'
    !find . -iname '.DS_Store'
    !find . -iname '.ipynb_checkpoints'
    ./.ipynb_checkpoints
    ./Glomeruli-dataset/.ipynb_checkpoints
    ./Glomeruli-dataset/splits/.ipynb_checkpoints
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    删除多余文件

    !for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
    !for i in `find . -iname '.DS_Store'`; do rm -rf $i;done
    !for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done
    
    • 1
    • 2
    • 3

    验证多余文件已删除

    !find . -iname '__MACOSX'
    !find . -iname '.DS_Store'
    !find . -iname '.ipynb_checkpoints'
    • 1
    • 2
    • 3
    • 4

    探索数据集

    导入工具包

    import os
    ​
    import cv2
    import numpy as np
    from PIL import Image
    from tqdm import tqdm
    ​
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    指定图像和标注文件夹路径

    PATH_IMAGE = 'Glomeruli-dataset/images'
    PATH_MASKS = 'Glomeruli-dataset/masks'
    print('图像个数', len(os.listdir(PATH_IMAGE)))
    print('标注个数', len(os.listdir(PATH_MASKS)))
    
    • 1
    • 2
    • 3
    • 4

    查看单张图像及其语义分割标注

    指定图像文件名

    file_name = 'SAS_21883_001_10.png'
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name)print('图像路径', img_path)
    print('标注路径', mask_path)
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    # 可视化图像
    plt.imshow(img)
    plt.show()
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    # mask 语义分割标注,与原图大小相同,0 为 背景, 1 为 肾小球
    np.unique(mask)
    
    • 1
    • 2

    array([0, 1], dtype=uint8)
    在本数据集中,只有一部分图像有肾小球语义分割标注(即mask中值为1的像素),其余图像mask的值均为0

    可视化语义分割标注

    plt.imshow(mask[:,:,0])
    plt.show()

    # 可视化语义分割标注
    plt.imshow(mask*255)
    plt.show()
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    可视化单张图像及其语义分割标注-代码模板

    plt.imshow(img)
    plt.imshow(mask*255, alpha=0.5) # alpha 高亮区域透明度,越小越接近原图
    plt.title(file_name)
    plt.axis('off')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述

    可视化模板-有前景标注

    # n行n列可视化
    n = 7# 标注区域透明度
    opacity = 0.5
    ​
    fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))
    ​
    i = 0for file_name in os.listdir(PATH_IMAGE):
        
        # 载入图像和标注
        img_path = os.path.join(PATH_IMAGE, file_name)
        mask_path = os.path.join(PATH_MASKS, file_name)
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)
        
        if 1 in mask:
            axes[i//n, i%n].imshow(img)
            axes[i//n, i%n].imshow(mask*255, alpha=opacity)
            axes[i//n, i%n].axis('off') # 关闭坐标轴显示
            i += 1
        if i > n**2-1:
            break
    fig.suptitle('Image and Semantic Label', fontsize=30)
    plt.tight_layout()
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

    在这里插入图片描述

    可视化模板-无论前景是否有标注

    # n行n列可视化
    n = 10# 标注区域透明度
    opacity = 0.5
    ​
    fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
        
        # 载入图像和标注
        img_path = os.path.join(PATH_IMAGE, file_name)
        mask_path = os.path.join(PATH_MASKS, file_name)
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)
        
        # 可视化
        axes[i//n, i%n].imshow(img)
        axes[i//n, i%n].imshow(mask*255, alpha=opacity)
        axes[i//n, i%n].axis('off') # 关闭坐标轴显示
    fig.suptitle('Image and Semantic Label', fontsize=30)
    plt.tight_layout()
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    在这里插入图片描述

    划分训练集和测试集

    导入工具包

    import os
    import random
    
    • 1
    • 2

    获取全部数据文件名列表

    PATH_IMAGE = 'Glomeruli-dataset/images'
    all_file_list = os.listdir(PATH_IMAGE)
    all_file_num = len(all_file_list)
    random.shuffle(all_file_list) # 随机打乱全部数据文件名列表
    
    • 1
    • 2
    • 3
    • 4

    指定训练集和测试集比例

    train_ratio = 0.8
    test_ratio = 1 - train_ratio
    train_file_list = all_file_list[:int(all_file_num*train_ratio)]
    test_file_list = all_file_list[int(all_file_num*train_ratio):]
    print('数据集图像总数', all_file_num)
    print('训练集划分比例', train_ratio)
    print('训练集图像个数', len(train_file_list))
    print('测试集图像个数', len(test_file_list))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    数据集图像总数 2576
    训练集划分比例 0.8
    训练集图像个数 2060
    测试集图像个数 516

    生成两个txt划分文件

    os.mkdir('Glomeruli-dataset/splits')
    with open('Glomeruli-dataset/splits/train.txt', 'w') as f:
        f.writelines(line.split('.')[0] + '\n' for line in train_file_list)
    with open('Glomeruli-dataset/splits/val.txt', 'w') as f:
        f.writelines(line.split('.')[0] + '\n' for line in test_file_list)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    MMSegmentation训练语义分割模型

    导入工具包

    import numpy as np
    from PIL import Image
    ​
    import os.path as osp
    from tqdm import tqdm
    ​
    import mmcv
    import mmengine
    import matplotlib.pyplot as plt
    %matplotlib inline
    # 数据集图片和标注路径
    data_root = 'Glomeruli-dataset'
    img_dir = 'images'
    ann_dir = 'masks'# 类别和对应的颜色
    classes = ('background', 'glomeruili')
    palette = [[128, 128, 128], [151, 189, 8]]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    修改数据集类(指定图像扩展名)

    from mmseg.registry import DATASETS
    from mmseg.datasets import BaseSegDataset
    ​
    @DATASETS.register_module()
    class StanfordBackgroundDataset(BaseSegDataset):
      METAINFO = dict(classes = classes, palette = palette)
      def __init__(self, **kwargs):
        super().__init__(img_suffix='.png', seg_map_suffix='.png', **kwargs)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    文档:https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/tutorials/customize_datasets.md#customize-datasets-by-reorganizing-data

    修改config配置文件

    # 下载 config 文件 和 预训练模型checkpoint权重文件
    !mim download mmsegmentation --config pspnet_r50-d8_4xb2-40k_cityscapes-512x1024 --dest .
    
    • 1
    • 2
    from mmengine import Config
    cfg = Config.fromfile('../mmsegmentation/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py')
    cfg.norm_cfg = dict(type='BN', requires_grad=True) # 只使用GPU时,BN取代SyncBN
    cfg.crop_size = (256, 256)
    cfg.model.data_preprocessor.size = cfg.crop_size
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
    # modify num classes of the model in decode/auxiliary head
    cfg.model.decode_head.num_classes = 2
    cfg.model.auxiliary_head.num_classes = 2# 修改数据集的 type 和 root
    cfg.dataset_type = 'StanfordBackgroundDataset'
    cfg.data_root = data_root
    ​
    cfg.train_dataloader.batch_size = 8
    ​
    cfg.train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='RandomResize', scale=(320, 240), ratio_range=(0.5, 2.0), keep_ratio=True),
        dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PackSegInputs')
    ]
    ​
    cfg.test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='Resize', scale=(320, 240), keep_ratio=True),
        # add loading annotation after ``Resize`` because ground truth
        # does not need to do resize data transform
        dict(type='LoadAnnotations'),
        dict(type='PackSegInputs')
    ]
    ​
    ​
    cfg.train_dataloader.dataset.type = cfg.dataset_type
    cfg.train_dataloader.dataset.data_root = cfg.data_root
    cfg.train_dataloader.dataset.data_prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    cfg.train_dataloader.dataset.pipeline = cfg.train_pipeline
    cfg.train_dataloader.dataset.ann_file = 'splits/train.txt'
    ​
    cfg.val_dataloader.dataset.type = cfg.dataset_type
    cfg.val_dataloader.dataset.data_root = cfg.data_root
    cfg.val_dataloader.dataset.data_prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    cfg.val_dataloader.dataset.pipeline = cfg.test_pipeline
    cfg.val_dataloader.dataset.ann_file = 'splits/val.txt'
    ​
    cfg.test_dataloader = cfg.val_dataloader
    ​
    ​
    # 载入预训练模型权重
    cfg.load_from = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'# 工作目录
    cfg.work_dir = './work_dirs/tutorial'# 训练迭代次数
    cfg.train_cfg.max_iters = 800
    # 评估模型间隔
    cfg.train_cfg.val_interval = 400
    # 日志记录间隔
    cfg.default_hooks.logger.interval = 100
    # 模型权重保存间隔
    cfg.default_hooks.checkpoint.interval = 400# 随机数种子
    cfg['randomness'] = dict(seed=0)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70

    查看完整config配置文件

    print(cfg.pretty_text)
    
    • 1

    保存config配置文件

    cfg.dump('new_cfg.py')
    
    • 1

    准备训练

    from mmengine.runner import Runner
    from mmseg.utils import register_all_modules
    ​
    # register all modules in mmseg into the registries
    # do not init the default scope here because it will be init in the runner
    register_all_modules(init_default_scope=False)
    runner = Runner.from_cfg(cfg)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    开始训练
    如果遇到报错CUDA out of memeory,重启实例或使用显存更高的实例即可。

    runner.train()
    
    • 1

    用训练得到的模型预测

    导入工具包

    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    ​
    from mmseg.apis import init_model, inference_model, show_result_pyplot
    import mmcv
    import cv2
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    载入模型

    # 载入 config 配置文件
    from mmengine import Config
    cfg = Config.fromfile('new_cfg.py')
    
    • 1
    • 2
    • 3
    from mmengine.runner import Runner
    from mmseg.utils import register_all_modules
    ​
    # register all modules in mmseg into the registries
    # do not init the default scope here because it will be init in the runner
    ​
    register_all_modules(init_default_scope=False)
    runner = Runner.from_cfg(cfg)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    # 初始化模型
    checkpoint_path = './work_dirs/tutorial/iter_800.pth'
    model = init_model(cfg, checkpoint_path, 'cuda:0')
    
    • 1
    • 2
    • 3

    载入测试集图像,或新图像

    img = mmcv.imread('Glomeruli-dataset/images/VUHSK_1702_39.png')
    
    • 1

    语义分割预测

    result = inference_model(model, img)
    result.keys()
    
    • 1
    • 2
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
    pred_mask.shape
    
    • 1
    • 2
    np.unique(pred_mask)
    
    • 1

    可视化语义分割预测结果

    plt.imshow(pred_mask)
    plt.show()
    
    • 1
    • 2

    在这里插入图片描述

    # 可视化预测结果
    visualization = show_result_pyplot(model, img, result, opacity=0.7, out_file='pred.jpg')
    plt.imshow(mmcv.bgr2rgb(visualization))
    plt.show()
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

    语义分割预测结果-连通域分析

    plt.imshow(np.uint8(pred_mask))
    plt.show()
    
    • 1
    • 2

    在这里插入图片描述

    connected = cv2.connectedComponentsWithStats(np.uint8(pred_mask), connectivity=4)
    
    • 1
    # 连通域个数(第一个有可能是全图,可以忽略)
    connected[0]
    
    • 1
    • 2
    # 用整数表示每个连通域区域
    connected[1].shape
    np.unique(connected[1])
    plt.imshow(connected[1])
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述

    # 每个连通域外接矩形的左上角X、左上角Y、宽度、高度、面积
    connected[2]
    
    • 1
    • 2
    # 每个连通域的质心坐标
    connected[3]
    
    • 1
    • 2

    获取测试集标注

    label = mmcv.imread('Glomeruli-dataset/masks/VUHSK_1702_39.png')
    label_mask = label[:,:,0]
    label_mask.shape
    np.unique(label_mask)
    plt.imshow(label_mask)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    在这里插入图片描述

    对比测试集标注和语义分割预测结果

    # 测试集标注
    label_mask.shape
    
    # 语义分割预测结果
    pred_mask.shape
    # 真实为前景,预测为前景
    TP = (label_mask == 1) & (pred_mask==1)
    # 真实为背景,预测为背景
    TN = (label_mask == 0) & (pred_mask==0)
    # 真实为前景,预测为背景
    FN = (label_mask == 1) & (pred_mask==0)
    # 真实为背景,预测为前景
    FP = (label_mask == 0) & (pred_mask==1)
    plt.imshow(TP)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    在这里插入图片描述

    confusion_map = TP * 255 + FP * 150 + FN * 80 + TN * 10
    plt.imshow(confusion_map)
    plt.show()
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    混淆矩阵

    from sklearn.metrics import confusion_matrix
    confusion_matrix_model = confusion_matrix(label_map.flatten(), pred_mask.flatten())
    import itertools
    def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
        """
        传入混淆矩阵和标签名称列表,绘制混淆矩阵
        """
        plt.figure(figsize=(10, 10))
        
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        # plt.colorbar() # 色条
        tick_marks = np.arange(len(classes))
        
        plt.title('Confusion Matrix', fontsize=30)
        plt.xlabel('Pred', fontsize=25, c='r')
        plt.ylabel('True', fontsize=25, c='r')
        plt.tick_params(labelsize=16) # 设置类别文字大小
        plt.xticks(tick_marks, classes, rotation=90) # 横轴文字旋转
        plt.yticks(tick_marks, classes)
        
        # 写数字
        threshold = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, cm[i, j],
                     horizontalalignment="center",
                     color="white" if cm[i, j] > threshold else "black",
                     fontsize=12)
    ​
        plt.tight_layout()
    ​
        plt.savefig('混淆矩阵.pdf', dpi=300) # 保存图像
        plt.show()
    classes = ('background', 'glomeruili')
    cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34


    在这里插入图片描述

    测试集性能评估

    添加数据集类

    # 数据集配置文件
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/StanfordBackgroundDataset.py -O ../mmsegmentation/mmseg/datasets/StanfordBackgroundDataset.py
    ​
    # 修改 ../mmsegmentation/mmseg/datasets/__init__.py,添加数据集
    !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/__init__.py -O ../mmsegmentation/mmseg/datasets/__init__.py
    
    • 1
    • 2
    • 3
    • 4
    • 5

    测试集精度指标

    !python ../mmsegmentation/tools/test.py new_cfg.py ./work_dirs/tutorial/iter_800.pth
    
    • 1

    速度指标-FPS

    !python ../mmsegmentation/tools/analysis_tools/benchmark.py new_cfg.py ./work_dirs/tutorial/iter_800.pth
    
    • 1
  • 相关阅读:
    网络安全(黑客)自学
    FlinkCDC 3.1.0 与 Flink 1.18.0 安装及使用 Mysql To Doris 整库同步,使用 pipepline连接器
    gwas数据获取如何获取完整的GWAS summary数据(1)------GWAS catalog数据库
    Nginx 优化
    大数据发展史
    前端图片转成base64
    "科来杯"第十届山东省大学生网络安全技能大赛决赛复现WP
    成都理工大学_Python程序设计_第1章
    couldn‘t find “libopencv_java3.so“
    Cmasher颜色包--共53种--全平台可用
  • 原文地址:https://blog.csdn.net/m0_47867638/article/details/131863176