• YOLOv5 分类模型 Top 1和Top 5 指标实现


    YOLOv5 分类模型 Top 1和Top 5 指标实现

    flyfish

    import time
    from models.common import DetectMultiBackend
    import os
    import os.path
    from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
    import cv2
    import numpy as np
    
    import torch
    from utils.augmentations import classify_transforms
    
    
    class DatasetFolder:
        """Folder-per-class image dataset.

        Expects the layout ``root/<class_name>/.../<file>``. Class names are the
        immediate sub-directories of ``root``, sorted alphabetically and mapped to
        consecutive integer indices in that order. Every file found under a class
        folder (recursively, following symlinks) becomes one ``(path, index)``
        sample; no file-type filtering is performed.
        """

        def __init__(
            self,
            root: str,
        ) -> None:
            """Scan ``root`` and build the sample list.

            Args:
                root: Dataset root directory containing one sub-directory per class.

            Raises:
                FileNotFoundError: If ``root`` contains no class sub-directories.
            """
            self.root = root
            classes, class_to_idx = self.find_classes(self.root)
            samples = self.make_dataset(self.root, class_to_idx)

            self.classes = classes
            self.class_to_idx = class_to_idx
            self.samples = samples
            # Ground-truth class indices, aligned with ``self.samples``.
            self.targets = [s[1] for s in samples]

        @staticmethod
        def make_dataset(
            directory: str,
            class_to_idx: Optional[Dict[str, int]] = None,
        ) -> List[Tuple[str, int]]:
            """Collect ``(path, class_index)`` pairs for every file under each class folder.

            Classes are visited in sorted name order, and files within each folder
            in sorted order, so the result is deterministic.

            Args:
                directory: Dataset root (``~`` is expanded).
                class_to_idx: Mapping from class folder name to index. When ``None``,
                    it is derived from the sub-directories of ``directory``.

            Returns:
                List of ``(file_path, class_index)`` tuples.

            Raises:
                ValueError: If ``class_to_idx`` is given but empty.
                FileNotFoundError: If ``class_to_idx`` is ``None`` and ``directory``
                    has no class sub-directories.

            Warns:
                UserWarning: If some classes have no files at all.
            """
            import warnings

            directory = os.path.expanduser(directory)

            if class_to_idx is None:
                # This is a staticmethod (no ``self``), so use the static helper.
                _, class_to_idx = DatasetFolder._scan_classes(directory)
            elif not class_to_idx:
                raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

            instances = []
            available_classes = set()
            for target_class in sorted(class_to_idx):
                class_index = class_to_idx[target_class]
                target_dir = os.path.join(directory, target_class)
                if not os.path.isdir(target_dir):
                    continue
                for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                    for fname in sorted(fnames):
                        # NOTE: no validation hook (e.g. extension filter); every file is a sample.
                        instances.append((os.path.join(root, fname), class_index))
                        available_classes.add(target_class)

            empty_classes = set(class_to_idx) - available_classes
            if empty_classes:
                # Previously this message was built and silently dropped; surface it.
                warnings.warn(
                    f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. ",
                    stacklevel=2,
                )

            return instances

        @staticmethod
        def _scan_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
            """Return sorted class folder names under ``directory`` and their index mapping."""
            with os.scandir(directory) as entries:
                classes = sorted(entry.name for entry in entries if entry.is_dir())
            if not classes:
                raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

            class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
            return classes, class_to_idx

        def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
            """Find class folders in ``directory``.

            Returns:
                ``(classes, class_to_idx)``: alphabetically sorted folder names and a
                mapping from each name to its position in that list.

            Raises:
                FileNotFoundError: If ``directory`` has no sub-directories.
            """
            return self._scan_classes(directory)

        def __getitem__(self, index: int) -> Tuple[Any, Any]:
            """Return ``(image, class_index)`` for sample ``index``.

            Raises ``IndexError`` past the end, which also makes plain ``for``
            iteration over the dataset terminate.
            """
            path, target = self.samples[index]
            sample = self.loader(path)

            return sample, target

        def __len__(self) -> int:
            """Number of samples."""
            return len(self.samples)

        def loader(self, path):
            """Read the image at ``path`` with OpenCV (BGR, HWC).

            Raises:
                OSError: If OpenCV cannot read the file (``cv2.imread`` returns
                    ``None`` instead of raising on failure).
            """
            print("path:", path)
            img = cv2.imread(path)  # BGR HWC
            if img is None:
                raise OSError(f"Failed to read image: {path}")
            return img
    
    
    def time_sync():
        """Return a high-resolution timestamp (seconds) for measuring intervals.

        Waits for pending CUDA work first (when CUDA is available) so GPU kernels
        are included in the measured interval. Uses ``time.perf_counter``, which is
        monotonic, instead of ``time.time``, which can jump when the system clock is
        adjusted. Only the difference between two calls is meaningful.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter()
    
    
    # Evaluate a YOLOv5 classification checkpoint on a folder-per-class validation
    # set and report top-1 / top-5 accuracy, one image at a time (batch size 1).
    dataset = DatasetFolder(root="/media/flyfish/test/val")

    # image, label = dataset[7]
    # print(image.shape)
    weights = "/media/flyfish/yolov5-6.2/classes10.pt"
    device = "cpu"
    model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
    model.eval()

    # YOLOv5 classification preprocessing at 224 px (see utils.augmentations.classify_transforms).
    transforms = classify_transforms(224)

    # pred: per-image top-5 class indices (highest score first); targets: ground-truth indices.
    # dt: accumulated seconds for [to-device copy, inference, post-processing].
    pred, targets, dt = [], [], [0.0, 0.0, 0.0]

    with torch.no_grad():  # inference only: do not build an autograd graph
        for i, (image, label) in enumerate(dataset):
            print("i:", i)
            print(image.shape, label)
            im = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert before transforms
            im = transforms(im).unsqueeze(0)  # add a batch dimension
            print(im.shape)

            t1 = time_sync()
            im = im.to(device, non_blocking=True)
            t2 = time_sync()
            dt[0] += t2 - t1

            y = model(im)
            y = y.cpu().numpy()  # .cpu() so this also works when `device` is a GPU
            print("y:", y)
            t3 = time_sync()
            dt[1] += t3 - t2

            # Indices of the 5 highest scores, descending (fewer if the model has < 5 classes).
            tmp1 = y.argsort()[:, ::-1][:, :5]
            print("tmp1:", tmp1)
            pred.append(tmp1)

            print("labels:", label)
            targets.append(label)

            dt[2] += time_sync() - t3

    if not pred:
        raise SystemExit("No samples found in the dataset; nothing to evaluate.")

    pred, targets = np.concatenate(pred), np.array(targets)
    print("pred:", pred)
    print("pred:", pred.shape)
    print("targets:", targets)
    print("targets:", targets.shape)
    # correct[i, j] == 1.0 when the j-th ranked prediction for image i equals its label.
    correct = (targets[:, None] == pred).astype(np.float32)
    print("correct:", correct.shape)
    print("correct:", correct)
    acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
    print("acc:", acc.shape)
    print("acc:", acc)
    top = acc.mean(0)
    print("top1:", top[0])
    print("top5:", top[1])

    n = len(targets)
    print(
        f"Speed: {1e3 * dt[0] / n:.1f}ms to-device, {1e3 * dt[1] / n:.1f}ms inference, "
        f"{1e3 * dt[2] / n:.1f}ms post-process per image"
    )
    

    输出

    pred: [[7 4 0 5 9]
     [9 2 4 6 7]
     [8 9 6 2 1]
     [8 9 6 2 7]
     [9 2 4 6 3]
     [6 7 1 2 9]
     [4 2 1 8 9]
     [6 8 9 5 2]
     [8 7 4 2 6]
     [9 8 2 6 4]
     [2 9 8 0 6]
     [7 4 8 6 3]]
    pred: (12, 5)
    targets: [0 0 0 0 1 1 1 1 2 2 2 2]
    targets: (12,)
    correct: (12, 5)
    correct: [[          0           0           1           0           0]
     [          0           0           0           0           0]
     [          0           0           0           0           0]
     [          0           0           0           0           0]
     [          0           0           0           0           0]
     [          0           0           1           0           0]
     [          0           0           1           0           0]
     [          0           0           0           0           0]
     [          0           0           0           1           0]
     [          0           0           1           0           0]
     [          1           0           0           0           0]
     [          0           0           0           0           0]]
    acc: (12, 2)
    acc: [[          0           1]
     [          0           0]
     [          0           0]
     [          0           0]
     [          0           0]
     [          0           1]
     [          0           1]
     [          0           0]
     [          0           1]
     [          0           1]
     [          1           1]
     [          0           0]]
    top1: 0.083333336
    top5: 0.5
    

    Yolov5 6.2 原版输出

    pred: tensor([[6, 7, 1, 2, 9],
            [9, 2, 4, 6, 3],
            [7, 4, 0, 5, 9],
            [9, 8, 2, 6, 4],
            [6, 8, 9, 5, 2],
            [8, 7, 4, 2, 6],
            [9, 2, 4, 6, 7],
            [2, 9, 8, 0, 6],
            [8, 9, 6, 2, 7],
            [7, 4, 8, 6, 3],
            [4, 2, 1, 8, 9],
            [8, 9, 6, 2, 1]])
    pred: torch.Size([12, 5])
    targets: tensor([1, 1, 0, 2, 1, 2, 0, 2, 0, 2, 1, 0])
    targets: torch.Size([12])
    correct: torch.Size([12, 5])
    acc: torch.Size([12, 2])
    top1: 0.0833333358168602
    top5: 0.5
    

    本文代码按标签（即文件夹名）顺序遍历样本，pred 与 targets 逐行一一对应；虽然样本顺序与 YOLOv5 6.2 原版不同（原版的 targets 顺序是打乱的），但计算出的 top1、top5 指标与原版相同。

  • 相关阅读:
    [工业自动化-8]:西门子S7-15xxx编程 - PLC主站 - CPU模块
    hdc_std安装配置以及常用命令
    Django(ORM事务操作|ORM常见字段类型|ORM常见字段参数|关系字段|Meta元信息)
    需突破技术成本瓶颈 秸秆饲料前景如何国稻种芯现代饲料规划
    基于Vue+Nodejs实现医药商城销售系统
    Redis基础
    【Python自然语言处理+tkinter图形化界面】实现智能医疗客服问答机器人实战(附源码、数据集、演示 超详细)
    OpenCV图像处理学习二十一,直方图比较方法
    Bean容器里的单例是根据什么识别它是同一个类呢?(比如容器里创建了A类,再去用这个A类的时候,Bean容器怎么知道这个就是A类?)
    2.DesignForClines\3.QuickBusRouting
  • 原文地址:https://blog.csdn.net/flyfish1986/article/details/134447456