• 【手撕DroneSSOD】(一)Density Crop Labeling 部分


    该部分为DroneSSOD的Density Crop Labeling 部分,主要用于半监督训练中对无标签数据生成具有“密集作物”的伪标签。
    【论文标题】Density Crop-guided Semi-supervised Object Detection in Aerial Images
    【代码地址】https://github.com/akhilpm/DroneSSOD

    一、概述

    Density Crop Labeling 部分的代码总共分为四个部分,分别是:crop_unlabeled_algo_dota.pyget_crop_unlabeled.pyget_crop_unlabeled_algo.pytrain_fcos.py
    其中:
    1、 train_fcos.py 主要是用于 创建配置并执行基本设置
    2、get_crop_unlabeled.py 主要是用于 为无标签图像数据创建密集作物类,并进行密集作物检测
    3、get_crop_unlabeled_algo.py 主要是用于 使用无标签的数据集(训练集),根据模型的预测结果,生成裁剪后的边界框
    4、crop_unlabeled_algo_dota.py 主要是用于 使用无标签的数据集(非训练集),根据模型的预测结果,生成裁剪后的边界框,并进行可视化展示

    Detectron2 简介

    Detectron2 是由 FAIR 出品的,基于 caffe2 框架的物体检测和分割开源项目,这个项目是 maskrcnn-benchmark 的代替者,detectron2 的优势主要体现在以下三点:

    • R-CNN系列最强实现,另外也支持新的特性,如全景分割(panoptic segmentation)和旋转框(rotated bounding boxes);
    • 高度模块化,扩展性好,具体讲detectron2可以看成一个基础库,当实现新的模型时是去import而不是modify;
    • 训练速度更快。

    Detectron2 官方文档

    Detectron2 Requirements

    • Python 的版本大于等于 3.7 的 Linux 系统或者 macOS 系统。
    • PyTorch 的版本大于等于 1.8 ,以及 torchvision 的版本应与 PyTorch 的版本相匹配。一起在 pytorch.org 上安装可以确保这一点。
    • OpenCV 是可选的,但演示和可视化需要它。

    二、train_fcos.py部分代码

    1.导入必要库

    import torch
    import os
    import detectron2.utils.comm as comm
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from detectron2.engine import default_argument_parser, default_setup, launch
    from detectron2.modeling import GeneralizedRCNN
    
    from croptrain import add_croptrainer_config, add_ubteacher_config, add_fcos_config
    from croptrain.engine.trainer import UBTeacherTrainer, BaselineTrainer
    # hacky way to register
    from croptrain.modeling.meta_arch.crop_rcnn import CropRCNN
    from croptrain.modeling.meta_arch.crop_fcos import CROP_FCOS
    import croptrain.data.datasets.builtin
    from croptrain.data.datasets.visdrone import register_visdrone
    from croptrain.data.datasets.dota import register_dota
    
    from croptrain.modeling.meta_arch.ts_ensemble import EnsembleTSModel
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    部分引用库介绍:

    • detectron2.utils.comm 中包含用于多 GPU 通信的基元,常用于进行分布式训练的设置
    • detectron2.checkpoint.DetectionCheckpointer 是一个检查指针,可以保存/加载模型及额外的可检查点对象。
      但能够:1、处理 detectron & detectron2 中的模型,并对遗留模型应用转换。2、正确加载仅在主工作线程上可用的检查点。
    • detectron2.config.get_cfg 用于获取默认配置的副本,返回 detectron2 CfgNode 实例。
    • detectron2.engine.default_argument_parser 用于解析命令行参数,返回一个 argparse.ArgumentParser 对象
    • detectron2.engine.default_setup 用于设置计算机视觉任务的默认配置,返回一个 detectron2.config.Config 对象
    • detectron2.engine.launch 用于启动多 GPU 或分布式训练
    • detectron2.modeling.GeneralizedRCNN 用于构建基于 R-CNN 模型。该类继承自 torch.nn.Module 类,并实现了 RPN 和 R-CNN 的计算图
    • croptrain.add_croptrainer_config 用于向 CropTrain 对象添加配置信息
    • croptrain.add_ubteacher_config 用于向 CropTrain 对象添加 Unbaise-Teacher 的配置信息
    • croptrain.add_fcos_config 用于向CropTrain对象添加fcos的配置信息
    • croptrain.engine.trainer.UBTeacherTrainer 用于向 CropTrain 中用于训练 Unbaise-Teacher 模型的默认训练器
    • croptrain.engine.trainer.BaselineTrainerCropTrain 中用于训练基准模型的默认训练器
    • croptrain.modeling.meta_arch.crop_rcnn.CropRCNNCropTrain 中用于构建 RPN 和 R-CNN 的模型的基类。该类继承自 torch.nn.Module 类,并实现了 RPN 和 R-CNN 模型的训练和推理过程
    • croptrain.modeling.meta_arch.crop_fcos.CROP_FCOSCropTrain 中用于构建 Feature Fusion Network 的模型的基类。该类继承自 torch.nn.Module 类,并实现了 Feature Fusion Network 模型的训练和推理过程
    • croptrain.data.datasets.builtin 中包含了一些内置数据集,可用于训练和测试CropTrain 模型
    • croptrain.data.datasets.visdrone.register_visdrone 用于注册 VisDrone 数据集
    • croptrain.data.datasets.dota.register_dota 用于注册 DOTA 数据集
    • croptrain.modeling.meta_arch.ts_ensemble.EnsembleTSModelCropTrain 中用于构建基于时间序列集成(Time Series Ensemble)的模型的基类。该类继承自 torch.nn.Module 类,并实现了时间序列集成模型的训练和推理过程

    2.setup 函数

    setup函数主要用于创建配置并执行基本设置pytpython

    def setup(args):
        cfg = get_cfg()    # 用于获取配置对象
        add_croptrainer_config(cfg)    # 用于向一个配置对象中添加 CropTrain 相关的配置信息
        add_ubteacher_config(cfg)    # 添加了一个针对半监督学习模型的配置
        add_fcos_config(cfg)    # 添加 FCOS 的相关配置
        cfg.merge_from_file(args.config_file)    # 将配置文件中的内容整合到 cfg 对象中
        cfg.merge_from_list(args.opts)    # 将命令行参数中的选项整合到 cfg 对象中
        cfg.freeze()    # 将配置冻结为不可更改的形式
        default_setup(cfg, args)    # 使用 detectron2.engine.default_setup() 函数设置默认的配置文件和命令行参数
        return cfg    # 将经过默认配置后的 cfg 对象返回给调用者
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    1.add_croptrainer_config(cfg)函数

    函数首先将 cfg 赋值给变量 _C,然后依次为 CROPTRAINCROPTESTMODEL.CUSTOM 三个部分添加配置信息。
    其中,CROPTRAINCROPTEST 分别表示训练集和测试集的配置信息,MODEL.CUSTOM 表示模型自定义配置信息

    def add_croptrainer_config(cfg):
        _C = cfg
        _C.CROPTRAIN = CN()
        _C.CROPTRAIN.USE_CROPS = False    # 是否使用作物(crops)作为训练集的一部分,这里设置为 False
        _C.CROPTRAIN.CLUSTER_THRESHOLD = 0.1    # 聚类阈值,用于将连续的像素点聚类为不同的物体,这里设置为 0.1
        _C.CROPTRAIN.CROPSIZE = (320, 476, 512, 640)    # 训练集中每个图像的大小,以 (宽度, 高度) 的形式给出,这里设置为 (320, 476, 512, 640)
        _C.CROPTRAIN.MAX_CROPSIZE = 800    # 训练集中最大的图像大小,这里设置为 800
        _C.CROPTEST = CN()
        _C.CROPTEST.CLUS_THRESH = 0.3    # 聚类阈值,用于确定一个像素点是否属于同一个物体,这里设置为 0.3
        _C.CROPTEST.MAX_CLUSTER = 5    # 最大聚类数,即测试集中最多允许有多少个不同的物体被检测到,这里设置为 5
        _C.CROPTEST.CROPSIZE = 800    # 测试集中每个图像的大小,以 (宽度, 高度) 的形式给出,这里设置为 800 x 800
        _C.CROPTEST.DETECTIONS_PER_IMAGE = 800    # 每个图像中要检测到的物体数量,这里设置为 800
        _C.MODEL.CUSTOM = CN()
        _C.MODEL.CUSTOM.FOCAL_LOSS_GAMMAS = []    # 焦距损失函数的伽马值列表,这里为空列表
        _C.MODEL.CUSTOM.FOCAL_LOSS_ALPHAS = []    # 焦距损失函数的阿尔法值列表,这里为空列表
    
        _C.MODEL.CUSTOM.CLS_WEIGHTS = []    # 分类器权重列表,这里为空列表
        _C.MODEL.CUSTOM.REG_WEIGHTS = []    # 回归器权重列表,这里为空列表
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    【补充】
    焦距损失函数是一种用于解决数据不平衡问题的损失函数,主要用于图像领域。它是通过动态缩放因子来降低训练过程中易区分样本的权重,从而将重心快速聚焦在那些难区分的样本上,从而提高模型的性能 。
    焦距损失函数的表达式

    2.add_ubteacher_config(cfg)函数

    add_ubteacher_config 添加了一个针对半监督学习模型的配置。它为不同的组件设置了各种参数,包括损失函数、数据加载器设置和模型架构。

    def add_ubteacher_config(cfg):
        """
        Add config for semisupnet.
        """
        _C = cfg
        _C.TEST.VAL_LOSS = True    #  将测试数据集的验证损失设置为True,表示在测试时使用验证集来计算损失
    
        _C.MODEL.RPN.UNSUP_LOSS_WEIGHT = 1.0    # 将 RPN 的无监督损失权重设置为1.0。这意味着在训练过程中,RPN模型将更关注有标签样本的学习
        _C.MODEL.RPN.LOSS = "CrossEntropy"    # 将 RPN 的损失函数设置为 “CrossEntropy"(交叉熵损失函数)
        _C.MODEL.ROI_HEADS.LOSS = "CrossEntropy"    # 将 Roi 头部的损失函数设置为 “CrossEntropy” 
    
        _C.SOLVER.IMG_PER_BATCH_LABEL = 1    # 设置每个批次中训练标签图像的数量为1
        _C.SOLVER.IMG_PER_BATCH_UNLABEL = 1    # 设置每个批次中未标记图像的数量为1
        _C.SOLVER.FACTOR_LIST = (1,)    # 设置数据加载器的采样器列表,每次只使用一个样本进行训练或评估
    
        _C.DATASETS.TRAIN_LABEL = ("visdrone_2019_train",)    # 将训练数据集的标签设为 “visdrone_2019_train”
        _C.DATASETS.TRAIN_UNLABEL = ("visdrone_2019_test",)    # 将训练数据集的未标记图像设置为"visdrone_2019_test"
        _C.DATASETS.CROSS_DATASET = False    # 不使用交叉数据集进行训练或评估
        _C.TEST.EVALUATOR = "COCOeval"    # 使用COCOeval进行模型评估
    
        _C.SEMISUPNET = CN()
    
        # Output dimension of the MLP projector after `res5` block
        """在 "res5" 块之后,MLP投影器的输出维度被设置为128
        res5指的是残差块中的第五个模块,常作为特征提取模块;
        MLP指的是多层感知机,是一种前馈神经网络模型,常用于分类和回归任务中"""
        _C.SEMISUPNET.MLP_DIM = 128    # 将 MLP 的配置成具有128个隐藏层单元模型
    
        # Semi-supervised training 半监督训练
        _C.SEMISUPNET.USE_SEMISUP = False    # 不使用半监督学习的方法
        _C.SEMISUPNET.AUG_CROPS_UNSUP = False    # 不使用数据增强的方式来处理未标注的数据
        _C.SEMISUPNET.BBOX_THRESHOLD = 0.7    # 将边界框阈值设为 0.7 ,用于筛选出置信度较高的边界框
        _C.SEMISUPNET.PSEUDO_BBOX_SAMPLE = "thresholding"    # 伪边界框采样方法设置为 “thresholding” (阈值处理)
        _C.SEMISUPNET.TEACHER_UPDATE_ITER = 1    # 教师网络更新迭代的次数,设置为 1
        _C.SEMISUPNET.BURN_UP_STEP = 12000    # 用于控制训练过程中模型参数的更新频率,设置为 12000
        _C.SEMISUPNET.EMA_KEEP_RATE = 0.0    # EMA 的保留率,设置为 0.0
        _C.SEMISUPNET.UNSUP_LOSS_WEIGHT = 4.0    # 未标注样本的损失权重,设置为 0.4
        _C.SEMISUPNET.SUP_LOSS_WEIGHT = 0.5    # 已标注样本的损失权重,设置为 0.5
        _C.SEMISUPNET.LOSS_WEIGHT_TYPE = "standard"    # 损失权重的类型,设置为 “standard”(标准)
    
        # dataloader
        # supervision level
        _C.DATALOADER.SUP_PERCENT = 100.0    # 使用全部数据集(100%)作为已标注的数据集【5 = 5% dataset as labeled set】
        _C.DATALOADER.RANDOM_DATA_SEED = 42    # 在读取数据时使用随机种子【random seed to read data】
        _C.DATALOADER.USE_RANDOM_SPLIT = False    # 不使用随机划分数据集的方法
        _C.DATALOADER.SEED_PATH = "dataseed/visdrone_sup_10.0.txt"    # 读取数据集的随机种子文件的路径,文件包含用于划分数据集的随机种子信息
    
        _C.EMAMODEL = CN()
        _C.EMAMODEL.SUP_CONSIST = True    # 使用一致性监督的方法来训练模型
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    【补充】
    COCOeval模型,是一种常用的目标检测评估工具,它基于COCO数据集设计了一系列评估指标。COCOeval可以用于计算目标检测结果的mAP,并提供了可视化的结果展示。

    # 安装 COCOeval 库(pycocotools)
    pip install pycocotools==2.0.0
    
    • 1
    • 2

    以下是在Python中使用COCOeval进行目标检测评估的示例代码:

    # 导入必要的库
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    
    # 加载ground truth和预测结果
    cocoGt = COCO('path/to/ground_truth_annotations.json')    # 替换实际标注文件
    cocoDt = cocoGt.loadRes('path/to/detection_results.json')    # 替换检测结果文件
    
    # 创建COCOeval对象
    cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
    
    # 进行评估
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    3. add_fcos_config(cfg)函数

    add_fcos_config 用于配置 FCOS(Fully Convolutional One-Stage Object Detection) 模型的相关参数,以便正确地加载和处理数据集。

    def add_fcos_config(cfg):
        _C = cfg
        _C.MODEL.FCOS = CN()
        _C.MODEL.FCOS.NORM = "GN"    # 归一化方法使用“批量归一化”
        _C.MODEL.FCOS.NUM_CLASSES = 80    # 分类任务中的类别数量(80)
        _C.MODEL.FCOS.NUM_CONVS = 4    # 卷积层层数(4)
        _C.MODEL.FCOS.SCORE_THRESH_TEST = 0.01    # 测试集上的预测分数阈值
        _C.MODEL.FCOS.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]    # 输入特征的名称和索引
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    FCOS算法 是一种基于 FCN 的逐像素目标检测算法,实现了无锚点(anchor-free)、无提议(proposal free)的解决方案,并提出了中心度(center-ness)的思想,同时在召回率等方面表现接近甚至超过目前很多先进主流的基于锚框目标检测算法。

    3.main 函数

    def main(args):
        cfg = setup(args)    # 将 setup 更新后的参数赋值给 cfg 对象
        if not torch.cuda.is_available():
            # 如果系统不能使用 GPU 加速
            cfg.defrost()    # 解冻配置对象 cfg
            cfg.MODEL.DEVICE = 'cpu'    # 模型部署设置为 cpu
            cfg.freeze()    # 冻结配置对象 cfg
        if cfg.SEMISUPNET.USE_SEMISUP:
            # 如果使用半监督训练方法(_C.SEMISUPNET.USE_SEMISUP = True)
            Trainer = UBTeacherTrainer    # 使用 Unbiased Teacher 的训练器
        else:
            # 如果不使用半监督训练方法(_C.SEMISUPNET.USE_SEMISUP = False )
            Trainer = BaselineTrainer    # 使用 Baseline 的训练器
    
        if cfg.CROPTRAIN.USE_CROPS:
            # 如果将作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = True)
            cfg.defrost()    # 解冻配置对象 cfg
            cfg.MODEL.FCOS.NUM_CLASSES += 1    # FCOS 分类任务种类 +1(新增“密集作物”类)
            cfg.freeze()    # 冻结配置对象 cfg
        if "visdrone" in cfg.DATASETS.TRAIN[0] or "visdrone" in cfg.DATASETS.TEST[0]:
            # cfg.DATASETS.TRAIN[0] 和 cfg.DATASETS.TEST[0] 分别是用来存放训练集和测试集的配置信息。
            # 通过判断是否存在 “visdrone” 字符串来判断 VisDrone 数据集是否被使用
            # 如果使用,则将数据集的存储路径设置为环境变量 SLURM_TMPDIR 下的 "VisDrone" 文件夹
            data_dir = os.path.join(os.environ['SLURM_TMPDIR'], "VisDrone")
            if not args.eval_only:
                # 判断是否只进行评估而不进行训练
                # 使用 register_visdrone() 函数在训练过程中将 VisDrone 数据集的图像和标注信息传递给回调函数进行处理
                # cfg.DATASETS.TRAIN[0]:表示训练集的配置信息
                # data_dir:表示数据集的存储路径
                # cfg:表示模型的配置信息
                # True:表示启用回调函数
                register_visdrone(cfg.DATASETS.TRAIN[0], data_dir, cfg, True)
            # 使用 register_visdrone() 函数在训练过程中不将 VisDrone 数据集的图像和标注信息传递给回调函数进行处理
            register_visdrone(cfg.DATASETS.TEST[0], data_dir, cfg, False)
        if "dota" in cfg.DATASETS.TRAIN[0] or "dota" in cfg.DATASETS.TEST[0]:
            data_dir = os.path.join(os.environ['SLURM_TMPDIR'], "DOTA")
            if not args.eval_only:
                register_dota(cfg.DATASETS.TRAIN[0], data_dir, cfg, True)
            register_dota(cfg.DATASETS.TEST[0], data_dir, cfg, False)
    
        if args.eval_only:
        # 如果只进行评估而不进行训练
            if cfg.SEMISUPNET.USE_SEMISUP:
            # 如果使用半监督学习(_C.SEMISUPNET.USE_SEMISUP = True)
                model = Trainer.build_model(cfg)    # 使用 cfg 配置信息构建模型
                model_teacher = Trainer.build_model(cfg)    # 使用 cfg 配置信息构建教师模型
                ensem_ts_model = EnsembleTSModel(model_teacher, model) # 使用 EnsembleTSModel() 函数构建模型,使用 model_teacher 和 model 的参数
    
                # 创建 ensem_te_model 的检查点 checkpointer
                DetectionCheckpointer(
                    ensem_ts_model, save_dir=cfg.OUTPUT_DIR
                ).resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)    # 使用 resume_or_load() 方法恢复或加载模型的检查点
                # cfg.MODEL.WEIGHTS:预训练模型的权重文件路径
                # resume=args.resume:是否从上次训练中断的地方恢复模型
                
                #res = Trainer.test(cfg, ensem_ts_model.modelTeacher)
                if cfg.CROPTRAIN.USE_CROPS:
                # 如果将密集作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = True)
                    res = Trainer.test_crop(cfg, ensem_ts_model.modelTeacher, 0)    # 使用测试数据对 ensem_ts_model.modelTeacher 模型进行评估
                else:
                # 如果不将密集作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = False)
                    if "dota" in cfg.DATASETS.TEST[0]:
                    # 如果测试集的配置信息中存在 “dota” 字符串
                        res = Trainer.test_crop(cfg, ensem_ts_model.modelTeacher, 0)    # 使用测试数据对 ensem_ts_model.modelTeacher 模型进行评估
                    else:
                    # 如果测试集的配置信息中不存在 “dota” 字符串
                        res = Trainer.test(cfg, ensem_ts_model.modelTeacher)    # 不使用测试数据对 ensem_ts_model.modelTeacher 模型进行评估
            else:
            # 如果不使用半监督学习(_C.SEMISUPNET.USE_SEMISUP = False)
                model = Trainer.build_model(cfg)    # 使用 cfg 配置信息构建模型
                # 创建 ensem_te_model 的检查点 checkpointer
                DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
                    cfg.MODEL.WEIGHTS, resume=args.resume
                )    # 使用 resume_or_load() 方法恢复或加载模型的检查点
                if cfg.CROPTRAIN.USE_CROPS:
                # 如果将密集作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = True)
                    res = Trainer.test_crop(cfg, model, 0)    # 使用测试数据对 model 模型进行评估
                else:
                # 如果不将密集作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = False) 
                    if "dota" in cfg.DATASETS.TEST[0]:
                    # 如果测试集的配置信息中存在 “dota” 字符串
                        res = Trainer.test_crop(cfg, model, 0)    # 使用测试数据对 model 模型进行评估
                    else:
                    # 如果测试集的配置信息中不存在 “dota” 字符串
                        res = Trainer.test(cfg, model)    # 不使用测试数据对 model 模型进行评估
            return res    # 返回 res 对象
    
        trainer = Trainer(cfg)    # 使用 cfg 配置文件创建 Trainer (训练器)
        trainer.resume_or_load(resume=args.resume)    # 使用 resume_or_load() 方法来恢复或加载模型的检查点
    
        return trainer.train()    # 返回 开始训练模型
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    1.ts_ensemble.py

    主要作用是通过创建 EnsembleTSModel 对象时传入两个模型对象作为参数,就可以将这两个模型进行集成训练,得到一个结合了教师模型和学生模型性能的新模型。

    from torch.nn.parallel import DataParallel, DistributedDataParallel
    import torch.nn as nn
    
    class EnsembleTSModel(nn.Module):
        def __init__(self, modelTeacher, modelStudent):
            super(EnsembleTSModel, self).__init__()
    
            if isinstance(modelTeacher, (DistributedDataParallel, DataParallel)):
            # 使用 isinstance(x, A_tuple) 函数判断 modelTeacher 是否属于(DistributedDataParallel, DataParallel)类
            # DistributedDataParallel 是用于实现分布式训练的数据并行的类
            # DataParallel 是用于实现数据并行
                modelTeacher = modelTeacher.module
            if isinstance(modelStudent, (DistributedDataParallel, DataParallel)):
                modelStudent = modelStudent.module
    
            self.modelTeacher = modelTeacher
            self.modelStudent = modelStudent
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    2.Trainer.test_crop()函数

    test_crop() 函数用于在训练过程中对模型进行评估

    def test_crop(cls, cfg, model, iter, evaluators=None):
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )
    
        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
    
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    logger.warn(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            if "dota" in cfg.DATASETS.TEST[0]:
                results_i = inference_dota(model, data_loader, evaluator, cfg, iter)
            else:
                results_i = inference_with_crops(model, data_loader, evaluator, cfg, iter)
            results[dataset_name] = results_i
            #experiment.log_metrics(results_i["bbox"], step=iter)
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)
    
        if len(results) == 1:
            results = list(results.values())[0] 
        return results
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43

    4.运行程序部分

    if __name__ == "__main__":
        args = default_argument_parser().parse_args()    # 解析命令行参数
        print("No of gpus used: {}".format(args.num_gpus))    # 打印当前使用的 GPU 数量
        print("Cuda detected {} gpus".format(torch.cuda.device_count()))    # 打印模型所部署的 GPU 数量
    
        print("Command Line Args:", args)    # 打印命令行参数
        launch(
            main,    # 主函数
            args.num_gpus,    # 当前使用的 GPU 数量
            num_machines=args.num_machines,    # 参与分布训练的机器数量
            machine_rank=args.machine_rank,    # 当前机器在分布式训练中的编号
            dist_url=args.dist_url,    # 分布式训练的通信 URL
            args=(args,),    # 将命令行参数传递给 main 函数
        )    # 启动分布式训练任务,并将训练过程中需要的各种参数传递给 main 函数
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    三、get_crop_unlabeled.py 代码

    1.导入必要库
    from croptrain.modeling.meta_arch.rcnn import TwoStagePseudoLabGeneralizedRCNN
    from croptrain.modeling.roi_heads.roi_heads import StandardROIHeadsPseudoLab
    from croptrain.modeling.proposal_generator.rpn import PseudoLabRPN
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from croptrain import add_croptrainer_config, add_ubteacher_config
    from detectron2.data import DatasetCatalog, MetadataCatalog
    import os
    from croptrain.data.datasets.visdrone import register_visdrone
    from croptrain.engine.trainer import UBTeacherTrainer, BaselineTrainer
    import numpy as np
    import torch
    import datetime
    import time
    import copy
    import cv2
    import json
    from utils.crop_utils import get_dict_from_crops
    from contextlib import ExitStack, contextmanager
    from detectron2.structures.instances import Instances
    from detectron2.structures.boxes import Boxes
    import matplotlib.pyplot as plt
    import logging
    from croptrain.modeling.meta_arch.ts_ensemble import EnsembleTSModel
    from detectron2.data.build import get_detection_dataset_dicts
    from detectron2.utils.logger import log_every_n_seconds
    from croptrain.data.datasets.visdrone import compute_crops
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    2. 基本日志设置
    logging.basicConfig(level = logging.INFO)    # 为程序设置基本日志文件
    
    • 1
    3. infere_context(model)函数

    infere_context(model) 函数,用于创建一个上下文环境,在该环境中模型会被临时更改为评估模式,并在退出该环境后恢复为之前的模式。

    @contextmanager
    def inference_context(model):
        """
        A context where the model is temporarily changed to eval mode,
        and restored to previous mode afterwards.
        模型临时更改为eval模式的上下文,然后恢复到以前的模式
    
        Args:
            model: a torch Module
        """
        training_mode = model.training    # 获取模型当前的训练模式
        model.eval()    # 模型设置为评估模式
        yield    # 使用 yield 暂停函数执行,等待外部代码进入该上下文环境
        model.train(training_mode)    # 当上下文环境退出时,恢复之前的训练模式
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    4. shift_crop_boxes(data_dict, cluster_boxes)函数

    shift_crop_boxes(data_dict, cluster_boxes) 函数,用于将给定的聚类框进行平移和裁剪操作。

    def shift_crop_boxes(data_dict, cluster_boxes):
        '''
        Args:
        data_dict: 数据信息的字典对象,包含“crop_are” 对应值(裁剪区域左上角坐标(x1,y1))
        cluster_boxes: 聚类框坐标的数组(左上角和右下角坐标(x1,y1,x2,y2))
        
        Return:
        cluster_boxes: 平移后的聚类框数组
        '''
        x1, y1 = data_dict["crop_area"][0], data_dict["crop_area"][1]    # 从 data_dict 中获取裁剪区域的左上角坐标(x1,y1)
        ref_point = np.array([x1, y1, x1, y1])    # 使用 np.array 创建一个一维数组,用于表示平移向量
        cluster_boxes = cluster_boxes + ref_point    # 使用相加实现平移操作
        return cluster_boxes
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    5. inference_crops(model, data_loader, cfg)函数

    inference_crops(model, data_loader, cfg) 函数,用于在给定的数据加载器上对模型进行推理,并从模型中提取出特定类别的边框。

    def inference_crops(model, data_loader, cfg):
        '''
        Args:
        model: 用于执行推理操作的模型
        data_loader: 用于按批次提供输入数据的数据加载器
        cfg: 配置参数
        
        Return:
        None
        '''
        #dataset_dicts = get_detection_dataset_dicts(cfg.DATASETS.TEST, filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS)
        dataset_name = cfg.DATASETS.TRAIN[0].split("_")[0]    # 使用 split() 函数实现分割,并获得数据集名称
        crop_file = os.path.join("dataseed", dataset_name + "_crops_{}.txt".format(cfg.DATALOADER.SUP_PERCENT))    # 获取裁剪文件路径(dataseed文件夹下'name_crops_{}.txt'文件)
        # cfg.DATALOADER.SUP_PERCENT 在 configs/*.yaml 中进行配置
        crop_storage = {}    # 创建一个空字典,用于存储每个图像文件中提取的边界框
    
        total = len(data_loader)  # inference data loader must have a fixed length(推理数据加载程序必须具有固定长度)
        cluster_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1    # 获取 ROI头部 中的类别数量(10),减 1 以获得聚类的索引
        with ExitStack() as stack:    # 使用 ExitStack 上下文管理器来管理资源,确保在推出上下文时正确释放资源
            if isinstance(model, torch.nn.Module):
            # 如果 model 是 torch.nn.Module 中的类型
                # enter_context() 方法用于将给定的上下文管理器添加到 ExitStack 中,并确保在退出上下文时自动恢复资源
                # stack:用于管理资源的获取和释放,通常与 ExitStack 一起使用。
                stack.enter_context(inference_context(model))    # 调用 inference_context(model) 函数,并将返回的上下文管理器添加到 stack 中
            # torch.no_grad() 用于禁用梯度计算,从而减少内存消耗并加速推理过程
            stack.enter_context(torch.no_grad())    # 将 torch.no_grad() 函数的上下文管理器添加到 stack 中。
            count = 0    # 计数器初始化
            n_crops = 0    # 总作物数量初始化
            for idx, inputs in enumerate(data_loader):
                outputs = model(inputs)    # 进行推理
                cluster_class_indices = (outputs[0]["instances"].pred_classes==cluster_class)    # 从模型的输出中获取预测类别的索引
                cluster_boxes = outputs[0]["instances"][cluster_class_indices]    # 使用索引来选择属于特定聚类类的边界框 
                cluster_boxes = cluster_boxes[cluster_boxes.scores>0.35]    # 过滤掉置信度低于0.35的边界框
                file_name = inputs[0]["file_name"].split('/')[-1]    # 使用 split() 函数分割获取输入文件名称
                if file_name not in crop_storage:
                # 如果 crop_stroage 不存在文件名
                    crop_storage[file_name] = []    # 添加到 crop_storage 中,并将其初始化为空列表
                if idx%100==0:
                    print("processing {}th image".format(idx))    # 每 100 张图片打印一条进度信息
                if len(cluster_boxes)>0:
                # 如果提取到了边界框
                    # 从 pred_boxes 中获取预测边界框张量
                    cluster_boxes = cluster_boxes.pred_boxes.tensor.cpu().numpy().astype(np.int32)    # 将 cluster_boxes 中的预测边界框从 GPU 中转移到 CPU 中,并转化为 Numpy 数组,np.int32 数据类型
                    if not inputs[0]["full_image"]:
                    # 如果输入数据中的第一个图像不是完整图像
                        cluster_boxes = shift_crop_boxes(inputs[0], cluster_boxes)    # 使用 shift_crop_boxes 函数对边界框进行平移裁剪操作
                    crop_storage[file_name] += cluster_boxes.tolist()    # 将相应的裁剪区域以字典值的形式储存,使用每个图像文件名作为键
                    count += 1    # 更新计数器
                    n_crops += len(cluster_boxes)    # 更新从输入数据中提取的裁剪数量
        with open(crop_file, "w") as f:
            json.dump(crop_storage, f)    # 将crop_storage 字典中的内容以 JSON 格式储存
        print("crops present in {}/{} images".format(count, len(data_loader)))    # 打印数据集中图像总数
        print("number of crops is {} ".format(n_crops))    # 打印裁剪数量
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    6. main()函数

    main() 函数包含了一些配置和模型的构建步骤,并调用了一些参数来执行特定任务。

    def main():
        cfg = get_cfg()    # 获取 cfg 值
        add_croptrainer_config(cfg)    # 获取 add_croptrainer_config(cfg) 中配置信息
        add_ubteacher_config(cfg)    # 获取 add_ubteacher_config(cfg) 中配置信息
        # 工作目录:/home/uesr/project
        # 配置文件目录:configs/visdrone/Semi-Sup-RCNN-FPN-CROP.yaml
        # 构建路径:/home/uesr/project/configs/visdrone/Semi-Sup-RCNN-FPN-CROP.yaml
        cfg.merge_from_file(os.path.join(os.getcwd(), 'configs', 'visdrone', 'Semi-Sup-RCNN-FPN-CROP.yaml'))    # 从指定文件(*.yaml)中加载配置信息到 cfg 对象中
        if cfg.CROPTRAIN.USE_CROPS:
        # 是否使用作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = True)
            cfg.MODEL.ROI_HEADS.NUM_CLASSES += 1    # ROI头部 种类 +1(新增“密集作物”类)
            cfg.MODEL.RETINANET.NUM_CLASSES += 1    # RetinaNet 种类 +1(新增“密集作物”类)
        data_dir = os.path.join(os.environ['SLURM_TMPDIR'], "VisDrone")    # 数据地址./SLURM_TMPDIR/VisDrone
        dataset_name = cfg.DATASETS.TRAIN[0]     # 从训练数据集中获取数据名字
        cfg.OUTPUT_DIR = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_10_06"    # 输出地址
        cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_10_06/model_0069999.pth"    # 输出权重
        if not dataset_name in DatasetCatalog:
        # dataset_name 不在 DatasetCatalog 中
            register_visdrone(dataset_name, data_dir, cfg, False)    # 调用 register_visdrone() 函数进行注册
        if cfg.SEMISUPNET.USE_SEMISUP:
        # 进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = True)
            Trainer = UBTeacherTrainer    # 使用 Unbaise-Teacher 作为训练器
        else:
        # 不进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = False)
            Trainer = BaselineTrainer    # 只用 基线模型 作为训练器
    
        model = Trainer.build_model(cfg)    # 按照 cfg 配置文件构建训练器模型
        if cfg.SEMISUPNET.USE_SEMISUP:
        # 进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = True)
            model_teacher = Trainer.build_model(cfg)    # 按照 cfg 配置文件构建教师模型
            ensem_ts_model = EnsembleTSModel(model_teacher, model)    # 使用 EnsembleTSModel() 函数构建模型,使用 model_teacher 和 model 的参数
            # 创建 ensem_te_model 的检查点 checkpointer
            DetectionCheckpointer(
                ensem_ts_model, save_dir=cfg.OUTPUT_DIR
            ).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)    # 使用 resume_or_load() 方法恢复或加载模型的检查点
        else:
        # 不进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = False)
            # 创建 model 的检查点 checkpointer
            DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)    # 使用 resume_or_load() 方法恢复或加载模型的检查点
    
    
        data_loader = Trainer.build_test_loader(cfg, dataset_name)    # 使用 cfg 配置文件和 dataset_name 创建测试数据集
        inference_crops(ensem_ts_model.modelTeacher, data_loader, cfg)    # 调用inference_crops()函数对测试数据进行裁剪检测
        
        
        if __name__ == "__main__":
        main()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47

    四、get_crop_unlabeled_algo.py 代码

    1. 导入必要库

    import sys
    from croptrain.modeling.meta_arch.rcnn import TwoStagePseudoLabGeneralizedRCNN
    from croptrain.modeling.roi_heads.roi_heads import StandardROIHeadsPseudoLab
    from croptrain.modeling.proposal_generator.rpn import PseudoLabRPN
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from croptrain import add_croptrainer_config, add_ubteacher_config
    from detectron2.data import DatasetCatalog, MetadataCatalog
    import os
    from PIL import Image
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    from croptrain.data.datasets.visdrone import register_visdrone
    from croptrain.engine.trainer import UBTeacherTrainer, BaselineTrainer
    import numpy as np
    import torch
    import datetime
    import time
    import copy
    import cv2
    import json
    from utils.crop_utils import get_dict_from_crops
    from contextlib import ExitStack, contextmanager
    from detectron2.structures.instances import Instances
    from detectron2.structures.boxes import Boxes
    import matplotlib.pyplot as plt
    import logging
    from utils.crop_utils import get_dict_from_crops
    from utils.box_utils import compute_one_stage_clusters, bbox_scale
    from croptrain.data.detection_utils import read_image
    from croptrain.modeling.meta_arch.ts_ensemble import EnsembleTSModel
    from detectron2.data.build import get_detection_dataset_dicts
    from detectron2.utils.logger import log_every_n_seconds
    from croptrain.data.datasets.dota import register_dota
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    2. 系统文件初始化

    sys.path.insert(0, '/home/akhil135/PhD/DroneSSOD')    # 将路径'/home/akhil135/PhD/DroneSSOD'添加到Python解释器的系统路径的开头
    logging.basicConfig(level = logging.INFO)    # 配置日志记录的级别为INFO
    
    • 1
    • 2

    3. inference_context(model) 函数

    inference_context(model) 函数,用于创建一个上下文环境,在该环境中模型会被临时更改为评估模式,并在退出该环境后恢复为之前的模式。

    @contextmanager
    def inference_context(model):
        """
        A context where the model is temporarily changed to eval mode,
        and restored to previous mode afterwards.
        模型临时更改为eval模式的上下文,然后恢复到以前的模式
    
        Args:
            model: a torch Module
        """
        training_mode = model.training    # 获取模型当前的训练模式
        model.eval()    # 模型设置为评估模式
        yield    # 使用 yield 暂停函数执行,等待外部代码进入该上下文环境
        model.train(training_mode)    # 当上下文环境退出时,恢复之前的训练模式
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    4. shift_crop_boxes(data_dict, cluster_boxes)函数

    shift_crop_boxes(data_dict, cluster_boxes) 函数,用于将给定的聚类框进行平移和裁剪操作。

    def shift_crop_boxes(data_dict, cluster_boxes):
        '''
        Args:
        data_dict: 数据信息的字典对象,包含“crop_are” 对应值(裁剪区域左上角坐标(x1,y1))
        cluster_boxes: 聚类框坐标的数组(左上角和右下角坐标(x1,y1,x2,y2))
        
        Return:
        cluster_boxes: 平移后的聚类框数组
        '''
        x1, y1 = data_dict["crop_area"][0], data_dict["crop_area"][1]    # 从 data_dict 中获取裁剪区域的左上角坐标(x1,y1)
        ref_point = np.array([x1, y1, x1, y1])    # 使用 np.array 创建一个一维数组,用于表示平移向量
        cluster_boxes = cluster_boxes + ref_point    # 使用相加实现平移操作
        return cluster_boxes
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    5. compute_crops_with_prediction(inputs, outputs, cfg) 函数【核心】

    compute_crops_with_prediction(inputs, outputs, cfg) 函数,用于根据模型的预测结果,生成裁剪后的边界框。

    def compute_crops_with_prediction(inputs, outputs, cfg):
        instances = outputs[0].get("instances", [])    # 获取输出中的实列列表
        instances = instances[instances.scores>0.6]    # 过滤掉得分低于0.6的实例
        crop_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1    # ROI头部获取的类别数量 -1 作为裁剪类别
        crop_class_indices = (instances.pred_classes==crop_class)    # 获取预测类别为裁剪类别的实例索引
        instances = instances[~crop_class_indices]    # 过滤掉预测类别为裁剪类别的实例
        gt_boxes = instances.pred_boxes.tensor.cpu().numpy().astype(np.int32)    # 获取实例的预测边框
        gt_classes = instances.pred_classes.cpu().numpy().astype(np.int32)    # 获取实例的预测类别
        scaled_boxes = bbox_scale(gt_boxes.copy(), inputs[0]['height'], inputs[0]['width'])    # 对边界框进行缩放
        seg_areas = Boxes(gt_boxes).area()    # 计算缩放后的边界框面积
        data_dict_this_image = copy.deepcopy(inputs[0])    # 深拷贝输入数据字典
        objs = []    # 创建一个空列表用于储存对象信息
        for i in range(len(gt_boxes)):
        # 遍历每个实例
            obj = {}    # 创建一个空字典用于存放实例的键与键值
            obj["bbox"] = gt_boxes[i].tolist()    # 将实列的边界框信息储存到 bbox 键中
            obj["category_id"] = gt_classes[i]    # 将实例的类别信息储存到 categroy_id 键中
            objs.append(obj)    # 将一个字典对象 obj 添加到 objs 列表中
        data_dict_this_image["annotations"] = objs    # 将对象信息添加到数据字典中
        #stage 1 - merging
        data_dict_this_image, new_boxes, new_seg_areas = compute_one_stage_clusters(data_dict_this_image, scaled_boxes, seg_areas, cfg, stage=1)
        #stage 2 - merging
        data_dict_this_image, new_boxes, new_seg_areas = compute_one_stage_clusters(data_dict_this_image, new_boxes, new_seg_areas, cfg, stage=2)
        return new_boxes
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    1. compute_one_stage_clusters(data_dict, bboxes, seg_areas, cfg, stage=1) 函数

    compute_one_stage_clusters(data_dict, bboxes, seg_areas, cfg, stage=1) 函数,主要目的是在给定的图像数据中检测并计算聚类。

    def compute_one_stage_clusters(data_dict, bboxes, seg_areas, cfg, stage=1):
        '''
        Args:
        data_dict: 包含图像数据的字典,包含高度、宽度和注释信息
        bboxes: 边界框的张量
        seg_areas: 分割区域的张量
        cfg: 配置信息,包含一些参数和阈值
        stage: 一个可选的参数,默认为 1
        Return:
        data_dict: 包含图像数据的字典,包含高度、宽度和注释信息
        new_boxes: 新边界框的张量
        np.array(new_seg_areas): 新分割区域的张量
        '''
        # Boxes 类将盒子列表存储为 Nx4 torch.Tensor。支持一些关于 boxes 的常用方法(区域、剪辑、非空等),并且行为也像张量(支持索引、to(device)、.device 和所有框的迭代)
        bboxes = Boxes(bboxes)    # 将边界框转化为 Boxes 对象
        # pairwise_iou 类将给定两个大小为 N 和 M 的框列表,计算 IoU (并集上的交集)在所有 N x M 对 boxes 之间。boxes 顺序必须是(xmin,ymin,xmax,ymax)
        overlaps = pairwise_iou(bboxes, bboxes)    # 计算边界框之间的重叠度
        connectivity = (overlaps > cfg.CROPTRAIN.CLUSTER_THRESHOLD)    # 根据重叠度确定是否属于同一个聚类
        new_boxes = np.zeros((0, 4), dtype=np.int32)    # 初始化 new_boxes ,用于后续储存和处理新的边界框数据
        new_seg_areas = []    # 创建一个空列表用于存放新分割区的数据
        image_area = data_dict["height"] * data_dict["width"]    # 计算图像的面积(高度 * 宽度)
        # 循环直到没有剩余的连接
        while len(connectivity)>0:
            connections = connectivity.sum(dim=1)    # 计算每个像素与相邻像素之间的连接情况
            max_connected, max_connections = torch.argmax(connections), torch.max(connections)    # 找到具有最大连接数的像素,并计算最大的连接数
            if max_connections==1:
            # 如果最大连接数为 1 ,则跳出循环
                break
            cluster_components = torch.nonzero(connectivity[max_connected]).view(-1)    # 找到与具有最大连接数的像素相连的虽有像素的索引,并将其保存在一个张量中
            other_boxes = torch.nonzero(~connectivity[max_connected]).view(-1)    # 找到与具有最大连接数的像素不相连的所有像素的索引,并将其保存在一个张量中
            cluster_member_areas = seg_areas[cluster_components]    # 根据聚类组件的索引,从分割区域seg_areas中获取每个聚类的像素平均面积
            cluster_member_areas = cluster_member_areas / float(image_area)    # 将每个聚类的像素平均面积除以整个图像的像素数量,得到每个聚类的像素平均面积
    
            # if the bounding boxes inside a cluster are sufficiently big, detect it from the original image itself.
            # 如果簇内的边界框足够大,请从原始图像本身检测它
            # 聚类任务
            if cluster_member_areas.min()>0.2:
            # 如果某个聚类的像素平均面积小于 0.2 ,则该聚类中的像素从原始的边界框、分割区域和连接矩阵中移除
                bboxes.tensor = bboxes.tensor[other_boxes]    # 使用 other_boxes 作为索引,从 bboxes 中选择需要保留的像素
                seg_areas = seg_areas[other_boxes]    # 使用 other_boxes 作为索引,从 seg_ares 中选择需要保留的像素
                connectivity = connectivity[:, other_boxes]    # 使用 other_boxes 作为索引,从 connectivity 中选择需要保留的像素
                connectivity = connectivity[other_boxes, :]    
                if stage==1:
                # 如果当前处于聚类阶段(即stage==1),需要从 data_dict["annotations"] 中移除 other_boxes 对应的边界框信息    
                    data_dict['annotations'] = list(compress(data_dict["annotations"], other_boxes))
                continue
    
            cluster_members = bboxes.tensor[cluster_components]    # 使用 cluster_components 作为索引,从 bboxes 中选择需要保留的像素
            
            # 获取聚类成员的边界框
            x1, y1 = cluster_members[:, 0].min()-20, cluster_members[:, 1].min()-20    # 计算聚类中心点在原始图像中的位置,并将其作为新的聚类中心点(x1,y1)
            x2, y2 = cluster_members[:, 2].max()+20, cluster_members[:, 3].max()+20    # 计算聚类中心点在原始图像中的位置,并将其作为新的聚类中心点(x2,y2)
            x1, y1 = torch.clamp(x1, min=0), torch.clamp(y1, min=0)    # 将变量 x1,y1 的值限制在0及其以上
            x2, y2 = torch.clamp(x2, max=data_dict['width']), torch.clamp(y2, max=data_dict['height'])    # 将变量 x2,y2 的值限制在图像的宽度和高度范围内 
            crop_area = np.array([int(x1), int(y1), int(x2), int(y2)]).astype(np.int32)    # 将聚类中心点在原始图像中的位置转换为一个四元素的数组
            bboxes.tensor = bboxes.tensor[other_boxes]    # 使用 other_boxes 作为索引,从 bboxes 中选择需要保留的像素
            seg_areas = seg_areas[other_boxes]    # 使用 other_boxes 作为索引,从 seg_areas 中选择需要保留的像素
    
            if stage==1:
            # 如果当前处于聚类阶段(即stage==1),需要从 data_dict["annotations"] 中移除 other_boxes 对应的边界框信息
                data_dict['annotations'] = list(compress(data_dict["annotations"], other_boxes))
            new_boxes = np.append(new_boxes, crop_area.reshape(1, -1), axis=0)    # 将 crop_area 数组添加到 new_boxes 数组的末尾
            new_seg_areas.append((x2-x1) * (y2- y1))    # 计算聚类中心点在原始图像中的实际大小
            connectivity = connectivity[:, other_boxes]    # 使用 other_boxes 作为索引,从 connectivity 中选择需要保留的像素
            connectivity = connectivity[other_boxes, :]
    
        return data_dict, new_boxes, np.array(new_seg_areas)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67

    6. inference_crops(model, data_loader, cfg) 函数

    inference_crops(model, data_loader, cfg) 函数,用于在给定的数据加载器上对模型进行推理,并从模型中提取出特定类别的边框。

    def inference_crops(model, data_loader, cfg):
        '''
        Args:
        model: 用于执行推理操作的模型
        data_loader: 数据加载器,用于存放数据
        cfg: 配置参数
        
        Return:
        None
        '''
        dataset_name = cfg.DATASETS.TRAIN[0].split("_")[0]    # 使用 split() 函数实现分割,并获得数据集名称
        crop_file = os.path.join("dataseed", dataset_name + "_crops_algo_{}.txt".format(cfg.DATALOADER.SUP_PERCENT))    # 获取裁剪文件路径(dataseed文件夹下'name_crops_algo_{}.txt'文件)
        # cfg.DATALOADER.SUP_PERCENT 在 configs/*.yaml 中进行配置
        crop_storage = {}    # 创建一个空字典,用于存储每个图像文件中提取的边界框
    
        total = len(data_loader)  # inference data loader must have a fixed length(推理数据加载程序必须具有固定长度)
        cluster_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1    # 获取 ROI头部 中的类别数量(10),减 1 以获得聚类的索引
        with ExitStack() as stack:    # 使用 ExitStack 上下文管理器来管理资源,确保在推出上下文时正确释放资源
            if isinstance(model, torch.nn.Module):
            # 如果 model 是 torch.nn.Module 中的类型
                # enter_context() 方法用于将给定的上下文管理器添加到 ExitStack 中,并确保在退出上下文时自动恢复资源
                # stack:用于管理资源的获取和释放,通常与 ExitStack 一起使用。
                stack.enter_context(inference_context(model))    # 调用 inference_context(model) 函数,并将返回的上下文管理器添加到 stack 中
            # torch.no_grad() 用于禁用梯度计算,从而减少内存消耗并加速推理过程
            stack.enter_context(torch.no_grad())    # 将 torch.no_grad() 函数的上下文管理器添加到 stack 中。
            count = 0    # 计数器初始化
            n_crops = 0    # 总作物数量初始化
            for idx, inputs in enumerate(data_loader):
                outputs = model(inputs)    # 进行推理
                file_name = inputs[0]["file_name"].split('/')[-1]    # 使用 split() 函数分割获取输入文件名称
                if file_name not in crop_storage:
                # 如果 crop_stroage 不存在文件名
                    crop_storage[file_name] = []    # 添加到 crop_storage 中,并将其初始化为空列表
                if idx%100==0:
                    print("processing {}th image".format(idx))    # 每 100 张图片打印一条进度信息
                if len(crop_boxes)>0:
                # 如果提取到了边界框
                    # 从 pred_boxes 中获取预测边界框张量
                    cluster_boxes = cluster_boxes.pred_boxes.tensor.cpu().numpy().astype(np.int32)    # 将 cluster_boxes 中的预测边界框从 GPU 中转移到 CPU 中,并转化为 Numpy 数组,np.int32 数据类型
                    if not inputs[0]["full_image"]:
                    # 如果输入数据中的第一个图像不是完整图像
                        cluster_boxes = shift_crop_boxes(inputs[0], crop_boxes)    # 使用 shift_crop_boxes 函数对边界框进行平移裁剪操作
                    crop_storage[file_name] += crop_boxes.tolist()    # 将相应的裁剪区域以字典值的形式储存,使用每个图像文件名作为键
                    count += 1    # 更新计数器
                    n_crops += len(crop_boxes)    # 更新从输入数据中提取的裁剪数量
        with open(crop_file, "w") as f:
            json.dump(crop_storage, f)    # 将crop_storage 字典中的内容以 JSON 格式储存
        print("crops present in {}/{} images".format(count, len(data_loader)))    # 打印数据集中图像总数
        print("number of crops is {} ".format(n_crops))    # 打印裁剪数量
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    7. main()函数

    main() 函数包含了一些配置和模型的构建步骤,并调用了一些参数来执行特定任务。

    def main():
        cfg = get_cfg()    # 获取 cfg 值
        add_croptrainer_config(cfg)    # 获取 add_croptrainer_config(cfg) 中配置信息
        add_ubteacher_config(cfg)    # 获取 add_ubteacher_config(cfg) 中配置信息
        # 构建路径:/home/akhil135/PhD/DroneSSOD/configs/dota/Semi-Sup-RCNN-FPN-CROP.yaml
        cfg.merge_from_file(os.path.join('/home/akhil135/PhD/DroneSSOD', 'configs', 'dota', 'Semi-Sup-RCNN-FPN-CROP.yaml'))    # 从指定文件(*.yaml)中加载配置信息到 cfg 对象中
        if cfg.CROPTRAIN.USE_CROPS:
         # 是否使用作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = True)
            cfg.MODEL.ROI_HEADS.NUM_CLASSES += 1    # ROI头部 种类 +1(新增“密集作物”类)
        data_dir = os.path.join(os.environ['SLURM_TMPDIR'], "DOTA")    # 数据地址./SLURM_TMPDIR/DOTA
        dataset_name = cfg.DATASETS.TRAIN[0]    # 从训练数据集中获取数据名字
        cfg.OUTPUT_DIR = "/home/akhil135/scratch/DroneSSOD/DOTA_CROP_SS_10_LR_02"    # 输出地址
        #cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_5/model_0047999.pth" # mAP = 22.52    # 模型使用权重
        cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_1_07/model_0062999.pth" # mAP= 16.74
        #cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_10_06/model_0071999.pth" # mAP = 26.48
        if not dataset_name in DatasetCatalog:
        # dataset_name 不在 DatasetCatalog 中
            #register_visdrone(dataset_name, data_dir, cfg, False)
            register_dota(dataset_name, data_dir, cfg, True)    # 调用 register_dota() 函数进行注册数据集(训练集)
        if cfg.SEMISUPNET.USE_SEMISUP:
        # 进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = True)
            Trainer = UBTeacherTrainer    # 使用 Unbaise-Teacher 作为训练器
        else:
        # 不进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = False)
            Trainer = BaselineTrainer    # 只用 基线模型 作为训练器
    
        model = Trainer.build_model(cfg)    # 按照 cfg 配置文件构建训练器模型
        if cfg.SEMISUPNET.USE_SEMISUP:
        # 进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = True)
            model_teacher = Trainer.build_model(cfg)    # 按照 cfg 配置文件构建教师模型
            ensem_ts_model = EnsembleTSModel(model_teacher, model)    # 使用 EnsembleTSModel() 函数构建模型,使用 model_teacher 和 model 的参数
            # 创建 ensem_te_model 的检查点 checkpointer
            DetectionCheckpointer(
                ensem_ts_model, save_dir=cfg.OUTPUT_DIR
            ).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)    # 使用 resume_or_load() 方法恢复或加载模型的检查点
        else:
        # 不进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = False)
            # 创建 model 的检查点 checkpointer
            DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)    # 使用 resume_or_load() 方法恢复或加载模型的检查点
    
    
        data_loader = Trainer.build_test_loader(cfg, dataset_name)    # 使用 cfg 配置文件和 dataset_name 创建测试数据集
        inference_crops(ensem_ts_model.modelTeacher, data_loader, cfg)    # 调用inference_crops()函数对测试数据进行裁剪检测
        
        
        if __name__ == "__main__":
        main()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47

    五、crop_unlabeled_algo_dota.py 代码

    1. 导入必要库

    import sys
    from croptrain.modeling.meta_arch.rcnn import TwoStagePseudoLabGeneralizedRCNN
    from croptrain.modeling.roi_heads.roi_heads import StandardROIHeadsPseudoLab
    from croptrain.modeling.proposal_generator.rpn import PseudoLabRPN
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from croptrain import add_croptrainer_config, add_ubteacher_config
    from detectron2.data import DatasetCatalog, MetadataCatalog
    import os
    from PIL import Image
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    from croptrain.data.datasets.visdrone import register_visdrone
    from croptrain.engine.trainer import UBTeacherTrainer, BaselineTrainer
    import numpy as np
    import torch
    import datetime
    import time
    import copy
    import cv2
    import json
    from utils.crop_utils import get_dict_from_crops
    from contextlib import ExitStack, contextmanager
    from detectron2.structures.instances import Instances
    from detectron2.structures.boxes import Boxes
    import matplotlib.pyplot as plt
    import logging
    from croptrain.data.datasets.dota import get_overlapping_sliding_window
    from utils.box_utils import compute_one_stage_clusters, bbox_scale
    from croptrain.data.detection_utils import read_image
    from croptrain.modeling.meta_arch.ts_ensemble import EnsembleTSModel
    from detectron2.data.build import get_detection_dataset_dicts
    from detectron2.utils.logger import log_every_n_seconds
    from croptrain.data.datasets.dota import register_dota
    from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
    from detectron2.evaluation import COCOEvaluator
    from utils.plot_utils import plot_detections
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37

    2. 系统文件初始化

    sys.path.insert(0, '/home/akhil135/PhD/DroneSSOD')    # 将路径'/home/akhil135/PhD/DroneSSOD'添加到Python解释器的系统路径的开头
    logging.basicConfig(level = logging.INFO)    # 配置日志记录的级别为INFO
    
    • 1
    • 2

    3. inference_context(model) 函数

    inference_context(model) 函数,用于创建一个上下文环境,在该环境中模型会被临时更改为评估模式,并在退出该环境后恢复为之前的模式。

    @contextmanager
    def inference_context(model):
        """
        A context where the model is temporarily changed to eval mode,
        and restored to previous mode afterwards.
        模型临时更改为eval模式的上下文,然后恢复到以前的模式
    
        Args:
            model: a torch Module
        """
        training_mode = model.training    # 获取模型当前的训练模式
        model.eval()    # 模型设置为评估模式
        yield    # 使用 yield 暂停函数执行,等待外部代码进入该上下文环境
        model.train(training_mode)    # 当上下文环境退出时,恢复之前的训练模式
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    4. shift_crop_boxes(data_dict, cluster_boxes)函数

    shift_crop_boxes(data_dict, cluster_boxes) 函数,用于将给定的聚类框进行平移和裁剪操作。

    def shift_crop_boxes(data_dict, cluster_boxes):
        '''
        Args:
        data_dict: 数据信息的字典对象,包含“crop_are” 对应值(裁剪区域左上角坐标(x1,y1))
        cluster_boxes: 聚类框坐标的数组(左上角和右下角坐标(x1,y1,x2,y2))
        
        Return:
        cluster_boxes: 平移后的聚类框数组
        '''
        x1, y1 = data_dict["crop_area"][0], data_dict["crop_area"][1]    # 从 data_dict 中获取裁剪区域的左上角坐标(x1,y1)
        ref_point = np.array([x1, y1, x1, y1])    # 使用 np.array 创建一个一维数组,用于表示平移向量
        cluster_boxes = cluster_boxes + ref_point    # 使用相加实现平移操作
        return cluster_boxes
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    5. plot_detection_boxes(predictions, cluster_boxes, data_dict) 函数【核心】

    plot_detection_boxes(predictions, cluster_boxes, data_dict) 函数,用于在图像上绘制检测到的边界框。

    def plot_detection_boxes(predictions, cluster_boxes, data_dict):
        img = Image.open(data_dict["file_name"])  # 打开图像文件
        plt.axis('off')  # 关闭坐标轴
        plt.imshow(img)  # 显示图像
        ax = plt.gca()  # 获取当前的坐标轴对象
    
        if len(predictions) != 0:
            predictions = predictions[predictions.scores > 0.6]  # 过滤置信度大于0.6的预测结果
            predictions = predictions.pred_boxes.tensor.cpu()  # 将预测结果转换为张量,并移动到CPU上
            for bbox in predictions:
                x1, y1 = bbox[0], bbox[1]  # 获取边界框的左上角坐标
                h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]  # 计算边界框的高度和宽度
                rect = Rectangle((x1, y1), w, h, linewidth=2, edgecolor='orange', facecolor='none')  # 创建橙色边框的矩形对象
                ax.add_patch(rect)  # 将矩形对象添加到图像上
    
        if len(cluster_boxes) != 0:
            if isinstance(cluster_boxes, Instances):
                cluster_boxes = cluster_boxes.pred_boxes.tensor.cpu()  # 将聚类结果转换为张量,并移动到CPU上
            for bbox in cluster_boxes:
                x1, y1 = bbox[0], bbox[1]  # 获取聚类边界框的左上角坐标
                h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]  # 计算聚类边界框的高度和宽度
                rect = Rectangle((x1, y1), w, h, linewidth=2, edgecolor='r', facecolor='none')  # 创建红色边框的矩形对象
                ax.add_patch(rect)  # 将矩形对象添加到图像上
    
        im_name = os.path.basename(data_dict["file_name"])[:-4]  # 获取图像文件名(去除后缀)
        plt.savefig(os.path.join('./temp', im_name + "_det.jpg"), dpi=90, bbox_inches='tight')  # 保存绘制了边界框的图像
        plt.clf()  # 清除当前图像
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    6. compute_crops_with_prediction(inputs, outputs, cfg) 函数【核心】

    compute_crops_with_prediction(inputs, outputs, cfg) 函数,用于根据模型的预测结果,生成裁剪后的边界框。

    def compute_crops_with_prediction(inputs, outputs, cfg):
        instances = outputs[0].get("instances", [])    # 获取输出中的实列列表
        instances = instances[instances.scores>0.6]    # 过滤掉得分低于0.6的实例
        crop_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1    # ROI头部获取的类别数量 -1 作为裁剪类别
        crop_class_indices = (instances.pred_classes==crop_class)    # 获取预测类别为裁剪类别的实例索引
        instances = instances[~crop_class_indices]    # 过滤掉预测类别为裁剪类别的实例
        gt_boxes = instances.pred_boxes.tensor.cpu().numpy().astype(np.int32)    # 获取实例的预测边框
        gt_classes = instances.pred_classes.cpu().numpy().astype(np.int32)    # 获取实例的预测类别
        scaled_boxes = bbox_scale(gt_boxes.copy(), inputs[0]['height'], inputs[0]['width'])    # 对边界框进行缩放
        seg_areas = Boxes(gt_boxes).area()    # 计算缩放后的边界框面积
        data_dict_this_image = copy.deepcopy(inputs[0])    # 深拷贝输入数据字典
        objs = []    # 创建一个空列表用于储存对象信息
        for i in range(len(gt_boxes)):
        # 遍历每个实例
            obj = {}    # 创建一个空字典用于存放实例的键与键值
            obj["bbox"] = gt_boxes[i].tolist()    # 将实列的边界框信息储存到 bbox 键中
            obj["category_id"] = gt_classes[i]    # 将实例的类别信息储存到 categroy_id 键中
            objs.append(obj)    # 将一个字典对象 obj 添加到 objs 列表中
        data_dict_this_image["annotations"] = objs    # 将对象信息添加到数据字典中
        #stage 1 - merging
        data_dict_this_image, new_boxes, new_seg_areas = compute_one_stage_clusters(data_dict_this_image, scaled_boxes, seg_areas, cfg, stage=1)
        #stage 2 - merging
        data_dict_this_image, new_boxes, new_seg_areas = compute_one_stage_clusters(data_dict_this_image, new_boxes, new_seg_areas, cfg, stage=2)
        return new_boxes
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    7. inference_crops(model, cfg) 函数

    inference_crops(model, cfg) 函数,用于在给定的数据加载器上对模型进行推理,并从模型中提取出特定类别的边框。

    def inference_crops(model, cfg):
        '''
        Args:
        model: 用于执行推理操作的模型
        cfg: 配置参数
        
        Return:
        None
        '''
        dataset_dicts = get_detection_dataset_dicts(cfg.DATASETS.TRAIN, filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS)    # 获取一个检测数据集的字典列表
        print("len of dataset dicts: {}".format(len(dataset_dicts)))    # 打印输出数据集字典的长度
        dataset_name = cfg.DATASETS.TRAIN[0].split("_")[0]    # 使用 split() 函数实现分割,并获得数据集名称
        crop_file = os.path.join("dataseed", dataset_name + "_crops_algo_{}.txt".format(cfg.DATALOADER.SUP_PERCENT))    # 获取裁剪文件路径(dataseed文件夹下'name_crops_algo_{}.txt'文件)
        # cfg.DATALOADER.SUP_PERCENT 在 configs/*.yaml 中进行配置
        crop_storage = {}    # 创建一个空字典,用于存储每个图像文件中提取的边界框
        
        cluster_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1    # 获取 ROI头部 中的类别数量(10),减 1 以获得聚类的索引
        with ExitStack() as stack:    # 使用 ExitStack 上下文管理器来管理资源,确保在推出上下文时正确释放资源
            if isinstance(model, torch.nn.Module):
            # 如果 model 是 torch.nn.Module 中的类型
                # enter_context() 方法用于将给定的上下文管理器添加到 ExitStack 中,并确保在退出上下文时自动恢复资源
                # stack:用于管理资源的获取和释放,通常与 ExitStack 一起使用。
                stack.enter_context(inference_context(model))    # 调用 inference_context(model) 函数,并将返回的上下文管理器添加到 stack 中
            # torch.no_grad() 用于禁用梯度计算,从而减少内存消耗并加速推理过程
            stack.enter_context(torch.no_grad())    # 将 torch.no_grad() 函数的上下文管理器添加到 stack 中。
            count = 0    # 计数器初始化
            n_crops = 0    # 总作物数量初始化
            for idx, inputs in enumerate(dataset_dicts):
            # 遍历整个检测数据集的字典列表
                # 调用get_overlapping_sliding_window函数获取当前数据集的重叠滑动窗口,返回一个新的边界框列表new_boxes
                new_boxes = get_overlapping_sliding_window(dataset_dicts[idx])
                # 调用get_dict_from_crops函数将new_boxes、当前数据集字典dataset_dicts[idx]以及最小测试图像尺寸cfg.INPUT.MIN_SIZE_TEST作为输入,返回一个新的数据字典列表new_data_dicts
                new_data_dicts = get_dict_from_crops(new_boxes, dataset_dicts[idx], cfg.INPUT.MIN_SIZE_TEST)
                # 创建一个形状为(height, width)的张量image_shapes,其中height和width分别为当前数据集字典中的图像高度和宽度
                # 创建形状,储存边界框
                boxes = torch.zeros(0, cfg.MODEL.ROI_HEADS.NUM_CLASSES*4).to(model.device)
                # 创建形状,储存得分
                scores = torch.zeros(0, cfg.MODEL.ROI_HEADS.NUM_CLASSES+1).to(model.device)
                for data_dict in new_data_dicts:
                # 遍历新检测数据集的字典列表
                    # 调用模型的前向传播函数model,传入data_dict、infer_on_crops=True和配置参数cfg,得到边界框和得分的输出boxes_patch和scores_patch
                    boxes_patch, scores_patch = model([data_dict], infer_on_crops=True, cfg=cfg)
                    # 将 boxes_patch[0] 和 boxes 在维度0上拼接起来,更新 boxes 值
                    boxes = torch.cat([boxes, boxes_patch[0]], dim=0)
                    # 将 scores_patch[0] 和 scores 在维度0上拼接起来,更新 scores 值
                    scores = torch.cat([scores, scores_patch[0]], dim=0)
                # 调用fast_rcnn_inference函数对边界框和得分进行后处理和预测,得到最终的预测结果pred_instances
                pred_instances, _ = fast_rcnn_inference([boxes], [scores], image_shapes, cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST, \
                                        cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, cfg.CROPTEST.DETECTIONS_PER_IMAGE)
                pred_instances = pred_instances[0]    # 取出第一个元素作为最终的预测结果
                outputs = [{"instances": pred_instances}]    # 将 pres_instances 的值保存在 instance 键中
                file_name = dataset_dicts[idx]["file_name"].split('/')[-1]    # 使用 split() 函数分割获取输入文件名称
                if file_name not in crop_storage:
                # 如果 crop_stroage 不存在文件名
                    crop_storage[file_name] = []    # 添加到 crop_storage 中,并将其初始化为空列表
                #try:
                crop_boxes = compute_crops_with_prediction(dataset_dicts[idx], outputs, cfg)[:10]    # 调用 compute_crops_with_prediction 函数,截取前10个元素进行保存
                #except:
                #    print("failed for image {}".format(idx+1))
                #    crop_boxes = []
                if idx%100==0:
                    print("processing {}th image".format(idx))    # 每 100 张图片打印一条进度信息
                    plot_detection_boxes(pred_instances, crop_boxes, dataset_dicts[idx])    # 调用 plot_detection_boxes 函数,可视化细节
                if len(crop_boxes)>0:
                # 如果提取到了边界框
                    if not dataset_dicts[idx]["full_image"]:
                    # 如果 dataset_dicts[idx] 字典中未出现 “full_image” 字符
                        crop_boxes = shift_crop_boxes(dataset_dicts[idx], crop_boxes)    # 使用 shift_crop_boxes 函数对边界框进行平移裁剪操作
                    crop_storage[file_name] += crop_boxes.tolist()    # 将相应的裁剪区域以字典值的形式储存,使用每个图像文件名作为键
                    count += 1    # 更新计数器
                    n_crops += len(crop_boxes)    # 更新从输入数据中提取的裁剪数量
            del boxes, scores, new_data_dicts        
        print("crops present in {}/{} images".format(count, len(data_loader)))    # 打印数据集中图像总数
        print("number of crops is {} ".format(n_crops))    # 打印裁剪数量
        with open(crop_file, "w") as f:
            json.dump(crop_storage, f)    # 将crop_storage 字典中的内容以 JSON 格式储存
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76

    8. main()函数

    main()函数包含了一些配置和模型的构建步骤,并调用了一些参数来执行特定任务。

    def main():
        cfg = get_cfg()    # 获取 cfg 值
        add_croptrainer_config(cfg)    # 获取 add_croptrainer_config(cfg) 中配置信息
        add_ubteacher_config(cfg)    # 获取 add_ubteacher_config(cfg) 中配置信息
        # 构建路径:/home/akhil135/PhD/DroneSSOD/configs/dota/Semi-Sup-RCNN-FPN-CROP.yaml
        cfg.merge_from_file(os.path.join('/home/akhil135/PhD/DroneSSOD', 'configs', 'dota', 'Semi-Sup-RCNN-FPN-CROP.yaml'))    # 从指定文件(*.yaml)中加载配置信息到 cfg 对象中
        if cfg.CROPTRAIN.USE_CROPS:
         # 是否使用作物作为训练集的一部分(_C.CROPTRAIN.USE_CROPS = True)
            cfg.MODEL.ROI_HEADS.NUM_CLASSES += 1    # ROI头部 种类 +1(新增“密集作物”类)
        data_dir = os.path.join(os.environ['SLURM_TMPDIR'], "DOTA")    # 数据地址./SLURM_TMPDIR/DOTA
        dataset_name = cfg.DATASETS.TRAIN[0]    # 从训练数据集中获取数据名字
        #cfg.OUTPUT_DIR = "/home/akhil135/scratch/DroneSSOD/DOTA_CROP_SS_10_LR_02"
        #cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/FPN_CROP_SS_1_07/model_0062999.pth" # mAP= 16.74
        #cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/DOTA_CROP_SS_10_LR_02/model_0092999.pth"
        #cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/DOTA_CROP_SS_5/model_0062999.pth"
        cfg.MODEL.WEIGHTS = "/home/akhil135/scratch/DroneSSOD/DOTA_CROP_SS_1_06/model_0020999.pth"    # 模型权重
    if not dataset_name in DatasetCatalog:
        # dataset_name 不在 DatasetCatalog 中
            #register_visdrone(dataset_name, data_dir, cfg, False)
            register_dota(dataset_name, data_dir, cfg, False)    # 调用 register_dota() 函数进行注册数据集(非训练集)
        if cfg.SEMISUPNET.USE_SEMISUP:
        # 进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = True)
            Trainer = UBTeacherTrainer    # 使用 Unbaise-Teacher 作为训练器
        else:
        # 不进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = False)
            Trainer = BaselineTrainer    # 只用 基线模型 作为训练器
    
        model = Trainer.build_model(cfg)    # 按照 cfg 配置文件构建训练器模型
        if cfg.SEMISUPNET.USE_SEMISUP:
        # 进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = True)
            model_teacher = Trainer.build_model(cfg)    # 按照 cfg 配置文件构建教师模型
            ensem_ts_model = EnsembleTSModel(model_teacher, model)    # 使用 EnsembleTSModel() 函数构建模型,使用 model_teacher 和 model 的参数
            # 创建 ensem_te_model 的检查点 checkpointer
            DetectionCheckpointer(
                ensem_ts_model, save_dir=cfg.OUTPUT_DIR
            ).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)    # 使用 resume_or_load() 方法恢复或加载模型的检查点
        else:
        # 不进行半监督学习(_C.SEMISUPNET.USE_SEMISUP = False)
            # 创建 model 的检查点 checkpointer
            DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)    # 使用 resume_or_load() 方法恢复或加载模型的检查点
            
        inference_crops(ensem_ts_model.modelTeacher, data_loader, cfg)    # 调用inference_crops()函数对测试数据进行裁剪检测
        
        
        if __name__ == "__main__":
        main()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46

    六、结束

    Drone SSOD 的 Density Crop Labeling 部分到此位置就结束了,作者是初入这个领域的小白,如果有什么错误,欢迎各位大佬指正!!

  • 相关阅读:
    springboot闲置衣物捐赠系统毕业设计源码021009
    《MongoDB》在docker中用得到关于MongoDB一行命令
    uview-plus的DatetimePicker如何只选年份?
    基于HASM模型的土壤高精度建模matlab仿真
    【数据结构】带头双向循环链表基本操作的实现(C语言)
    图文并茂演示小程序movable-view的可移动范围
    MySQL主从复制实现高可用性和负载均衡
    23..【摆脱list链表的束缚、让你爱上链表】
    Docker原理与基础命令
    WebSocket技术解析:实现Web实时双向通信的利器
  • 原文地址:https://blog.csdn.net/qq_44102942/article/details/132593070