如有错误,恳请指出。
以下内容以训练few-shot的tfa模型为例,其中的tfa是一个小样本的目标检测模型。大致介绍如何使用MMFewShot,进行模型的训练(或者微调),以及推理验证。
配置安装,以下的配置亲测可用:

安装指令:
# install mmcv mmclassification mmdetection
pip install mmcv-full==1.6.1 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
pip install mmcls==0.23.2
pip install mmdet==2.25.0
pip install mmfewshot
# install mmfewshot
git clone https://github.com/open-mmlab/mmfewshot.git
cd mmfewshot
pip install -r requirements/build.txt
pip install -v -e . # or "python setup.py develop"
需要注意,一定要进行后续的 pip install -r requirements/build.txt 与 pip install -v -e . ,否则你的安装是不完整的,可能还会导致无法使用。
训练过程:此处以在vol数据集中训练tfa模型为例
python tools/detection/train.py configs/detection/tfa/voc/split1/tfa_r101_fpn_voc-split1_base-training.py --gpu-id 0
python tools/detection/misc/initialize_bbox_head.py \
--src1 work_dirs/tfa_r101_fpn_voc-split1_base-training/latest.pth \
--method random_init \
--save-dir work_dirs/tfa_r101_fpn_voc-split1_base-training/
由于现在需要进行few-shot微调,需要进行特定的数据集处理,下载few-shot的标注信息,分别下载完解压到 data/few_shot_ann/ 路径下即可
coco数据集:https://download.openmmlab.com/mmfewshot/few_shot_ann/coco.tar.gz
mmfewshot
├── mmfewshot
├── tools
├── configs
├── data
│ ├── coco
│ │ ├── annotations
│ │ ├── train2014
│ │ ├── val2014
│ │ ├── train2017 (optional)
│ │ ├── val2017 (optional)
│ ├── few_shot_ann
│ │ ├── coco
│ │ │ ├── annotations
│ │ │ │ ├── train.json
│ │ │ │ ├── val.json
│ │ │ ├── attention_rpn_10shot (for coco17)
│ │ │ ├── benchmark_10shot
│ │ │ ├── benchmark_30shot
voc数据集:https://download.openmmlab.com/mmfewshot/few_shot_ann/voc.tar.gz
mmfewshot
├── mmfewshot
├── tools
├── configs
├── data
│ ├── VOCdevkit
│ │ ├── VOC2007
│ │ ├── VOC2012
│ ├── few_shot_ann
│ │ ├── voc
│ │ │ ├── benchmark_1shot
│ │ │ ├── benchmark_2shot
│ │ │ ├── benchmark_3shot
│ │ │ ├── benchmark_5shot
│ │ │ ├── benchmark_10shot
处理完的数据集结构如下所示:

准备完数据集,即可进行训练
CUDA_VISIBLE_DEVICES=0,1,3 bash tools/detection/dist_train.sh \
configs/detection/tfa/voc/split1/tfa_r101_fpn_voc-split1_5shot-fine-tuning.py 3
ps:这里的配置文件会自动的在相关的路径下架子基础训练的模型,比如:work_dirs/tfa_r101_fpn_voc-split1_base-training/base_model_random_init_bbox_head.pth,在配置文件中设置如下
# base model needs to be initialized with following script:
# tools/detection/misc/initialize_bbox_head.py
# please refer to configs/detection/tfa/README.md for more details.
load_from = ('work_dirs/tfa_r101_fpn_voc-split1_base-training/'
'base_model_random_init_bbox_head.pth')
训练结束后,由于此时已经微调结束,所以在对应的目录下tfa_r101_fpn_voc-split1_5shot-fine-tuning,会生成相对应的权重,如下所示:

在训练结束后,打印的信息有一个log文件记录,如下所示:
+-------------+------+-------+--------+-------+
| class | gts | dets | recall | ap |
+-------------+------+-------+--------+-------+
| aeroplane | 285 | 2296 | 0.877 | 0.615 |
| bicycle | 337 | 2700 | 0.840 | 0.457 |
| boat | 263 | 2400 | 0.768 | 0.372 |
| bottle | 469 | 5199 | 0.795 | 0.484 |
| car | 1201 | 5936 | 0.938 | 0.780 |
| cat | 358 | 2728 | 0.925 | 0.495 |
| chair | 756 | 7951 | 0.799 | 0.400 |
| diningtable | 206 | 5288 | 0.859 | 0.367 |
| dog | 489 | 3675 | 0.894 | 0.405 |
| horse | 348 | 7908 | 0.937 | 0.462 |
| person | 4528 | 14036 | 0.917 | 0.691 |
| pottedplant | 480 | 4764 | 0.729 | 0.295 |
| sheep | 242 | 2452 | 0.884 | 0.467 |
| train | 282 | 2988 | 0.837 | 0.398 |
| tvmonitor | 308 | 5185 | 0.880 | 0.549 |
| bird | 459 | 4232 | 0.693 | 0.247 |
| bus | 213 | 4890 | 0.864 | 0.267 |
| cow | 244 | 3942 | 0.947 | 0.380 |
| motorbike | 325 | 6876 | 0.871 | 0.461 |
| sofa | 239 | 6499 | 0.787 | 0.195 |
+-------------+------+-------+--------+-------+
| mAP | | | | 0.439 |
+-------------+------+-------+--------+-------+
2022-09-08 18:46:13,426 - mmfewshot - INFO - BASE_CLASSES_SPLIT1 mAP: 0.4825480580329895
2022-09-08 18:46:13,426 - mmfewshot - INFO - NOVEL_CLASSES_SPLIT1 mAP: 0.3099679946899414
2022-09-08 18:46:13,432 - mmfewshot - INFO - Exp name: tfa_r101_fpn_voc-split1_5shot-fine-tuning.py
2022-09-08 18:46:13,432 - mmfewshot - INFO - Iter(val) [1651] AP50: 0.4390, BASE_CLASSES_SPLIT1: AP50: 0.4830, NOVEL_CLASSES_SPLIT1: AP50: 0.3100, mAP: 0.4394
将刚刚训练好的模型进行验证,以查看效果是否匹配,这里验证的是最新的模型,也就是latest.pth
python tools/detection/test.py \
configs/detection/tfa/voc/split1/tfa_r101_fpn_voc-split1_5shot-fine-tuning.py \
work_dirs/tfa_r101_fpn_voc-split1_5shot-fine-tuning/latest.pth \
--eval mAP --gpu-id 2
验证结果:

可以看见,最后的验证结果与训练时刻的验证结果是一致的
参考资料:
https://mmfewshot.readthedocs.io/en/latest/index.html