YOLOv5 分类模型 数据集加载 3


    YOLOv5 分类模型 数据集加载 3 自定义类别

    flyfish

    YOLOv5 分类模型 数据集加载 1 样本处理
    YOLOv5 分类模型 数据集加载 2 切片处理
    YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
    YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize
    YOLOv5 分类模型 Top 1和Top 5 指标说明
    YOLOv5 分类模型 Top 1和Top 5 指标实现

    之前的处理方式是:类别名字取文件夹名字,类别ID按照文件夹名字的字母顺序分配。
    现在的处理方式是:类别名字仍取文件夹名字,但类别ID按照自定义列表 classes_name 中的顺序分配,例如

    classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 
    'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']
    

    n02086240 类别ID是0
    n02087394 类别ID是1
    代码处理如下:

    # Choose how class IDs are assigned: use the custom classes_name list when it
    # is non-empty, otherwise fall back to the alphabetical order of the folders.
    if classes_name is None or not classes_name:
        # No custom list: find_classes sorts the sub-folder names and numbers them from 0.
        classes, class_to_idx = self.find_classes(self.root)
        print("not classes_name")
    
    else:
        # Custom list: the i-th name in classes_name gets class ID i.
        classes = classes_name
        class_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}
        print("is classes_name")
    

    完整代码如下:

    import time
    from models.common import DetectMultiBackend
    import os
    import os.path
    from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
    import cv2
    import numpy as np
    
    import torch
    from PIL import Image
    import torchvision.transforms as transforms
    
    import sys
    
    # Custom class order used by DatasetFolder: the i-th folder name below gets
    # class ID i. Set to None or [] to fall back to alphabetical folder order.
    # NOTE(review): presumably this must match the model's output-class order — confirm.
    classes_name = [
        'n02086240',
        'n02087394',
        'n02088364',
        'n02089973',
        'n02093754',
        'n02096294',
        'n02099601',
        'n02105641',
        'n02111889',
        'n02115641',
    ]
                  
    class DatasetFolder:
        """Minimal folder-per-class image dataset (modeled on torchvision's DatasetFolder).

        Expected layout: ``root/<class_name>/**/<file>``. Each sample is a
        ``(path, class_index)`` pair; images are decoded lazily in ``__getitem__``.

        Class-index assignment:
          * If a non-empty list of class names is available (the ``class_names``
            argument, or else the module-level ``classes_name`` list), the i-th
            name gets index i, independent of alphabetical folder order.
          * Otherwise the sub-folders of ``root`` are sorted alphabetically and
            numbered from 0 (see ``find_classes``).
        """

        def __init__(
            self,
            root: str,
            class_names: Optional[List[str]] = None,
        ) -> None:
            """
            Args:
                root: dataset directory containing one sub-folder per class.
                class_names: explicit class order; ``None`` falls back to the
                    module-level ``classes_name`` list.
            """
            self.root = root

            # Only consult the module-level list when no explicit order was passed.
            names = class_names if class_names is not None else classes_name
            if not names:
                classes, class_to_idx = self.find_classes(self.root)
                print("not classes_name")
            else:
                classes = list(names)  # copy so later edits to the caller's list don't leak in
                class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
                print("is classes_name")

            print("classes:", classes)
            print("class_to_idx:", class_to_idx)
            samples = self.make_dataset(self.root, class_to_idx)

            self.classes = classes
            self.class_to_idx = class_to_idx
            self.samples = samples
            self.targets = [s[1] for s in samples]

        @staticmethod
        def make_dataset(
            directory: str,
            class_to_idx: Optional[Dict[str, int]] = None,
        ) -> List[Tuple[str, int]]:
            """Collect ``(file_path, class_index)`` pairs for every file under each class folder.

            Args:
                directory: dataset root; ``~`` is expanded.
                class_to_idx: folder name -> class index. ``None`` derives the
                    mapping from the sub-folders of ``directory`` (alphabetical).

            Returns:
                Samples ordered by class-folder name, then by walk order and file
                name. Names in ``class_to_idx`` without a folder are skipped.

            Raises:
                ValueError: if ``class_to_idx`` is an empty mapping.
            """
            directory = os.path.expanduser(directory)

            if class_to_idx is None:
                # Static method: there is no ``self`` here, so use the shared scanner.
                _, class_to_idx = DatasetFolder._scan_class_folders(directory)
            elif not class_to_idx:
                raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

            instances = []
            available_classes = set()
            for target_class in sorted(class_to_idx):
                class_index = class_to_idx[target_class]
                target_dir = os.path.join(directory, target_class)
                if not os.path.isdir(target_dir):
                    continue
                for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                    for fname in sorted(fnames):
                        # No extension/validity filter: every file counts as a sample.
                        instances.append((os.path.join(root, fname), class_index))
                        available_classes.add(target_class)

            empty_classes = set(class_to_idx) - available_classes
            if empty_classes:
                # Best effort: report classes that yielded no files instead of dropping the message.
                print(f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. ")

            return instances

        def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
            """Return the sub-folder names of ``directory`` (sorted) and a name -> index map.

            Raises:
                FileNotFoundError: if ``directory`` has no sub-folders.
            """
            return DatasetFolder._scan_class_folders(directory)

        @staticmethod
        def _scan_class_folders(directory: str) -> Tuple[List[str], Dict[str, int]]:
            """Instance-free implementation of ``find_classes`` (also used by ``make_dataset``)."""
            with os.scandir(directory) as entries:
                classes = sorted(entry.name for entry in entries if entry.is_dir())
            if not classes:
                raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
            return classes, {cls_name: i for i, cls_name in enumerate(classes)}

        def __getitem__(self, index: int) -> Tuple[Any, Any]:
            """Return ``(image, class_index)`` for sample ``index``; the image is decoded on demand."""
            path, target = self.samples[index]
            sample = self.loader(path)
            return sample, target

        def __len__(self) -> int:
            """Number of collected samples."""
            return len(self.samples)

        def loader(self, path):
            """Decode the file at ``path`` into an RGB ``PIL.Image``."""
            print("path:", path)
            # Use an explicit handle so the file is closed immediately; convert()
            # reads the pixel data before the handle goes away.
            with open(path, "rb") as f:
                img = Image.open(f)
                return img.convert("RGB")
    
    
    def time_sync():
        """Return the current wall-clock time in seconds, after pending CUDA work finishes.

        CUDA kernels are launched asynchronously, so a timestamp taken right after a
        GPU forward pass could be recorded before the GPU is actually done. When
        CUDA is not available this is just ``time.time()``.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()
    
    # Validation set: one sub-folder per class under root (see DatasetFolder).
    dataset = DatasetFolder(root="/media/a/flyfish/source/yolov5/datasets/imagewoof/val")

    # Load the trained classification weights on CPU with fp16 disabled.
    weights = "/home/a/classes.pt"
    device = "cpu"
    model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
    model.eval()
    # Class names stored with the model; useful to check they follow the same
    # order as classes_name, since that order defines the label IDs.
    print(model.names)
    print(type(model.names))

    # Per-channel (RGB, matching DatasetFolder.loader) mean/std used by preprocess()
    # for Normalize. NOTE(review): presumably the usual ImageNet statistics — confirm
    # they match the values used in training.
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    def preprocess(images, target_size=224, norm_mean=None, norm_std=None):
        """Apply Resize -> CenterCrop -> ToTensor -> Normalize (torchvision semantics) to one image.

        Args:
            images: a single PIL image (DatasetFolder.loader returns RGB images).
            target_size: length the shorter side is resized to, and the side of the
                square center crop. Defaults to 224, the previously hard-coded value.
            norm_mean: per-channel mean for Normalize; ``None`` uses module-level ``mean``.
            norm_std: per-channel std for Normalize; ``None`` uses module-level ``std``.

        Returns:
            A contiguous float32 ``torch.Tensor`` in CHW layout with shape
            (channels, target_size, target_size). No batch dimension is added.
        """
        if norm_mean is None:
            norm_mean = mean
        if norm_std is None:
            norm_std = std

        # Resize (torchvision Resize with an int size): the shorter side becomes
        # target_size; the longer side keeps the aspect ratio, truncated to int.
        img_w, img_h = images.width, images.height
        if img_h >= img_w:
            new_size = (target_size, int(target_size * img_h / img_w))
        else:
            new_size = (int(target_size * img_w / img_h), target_size)
        resize_img = images.resize(new_size, Image.BILINEAR)  # PIL sizes are (width, height)

        # CenterCrop: offsets rounded the same way as torchvision's center_crop.
        # right/bottom are derived from left/top so the crop is exactly target_size
        # even when target_size is odd (center +/- target_size//2 lost a pixel then).
        left = int(round((resize_img.width - target_size) / 2.0))
        top = int(round((resize_img.height - target_size) / 2.0))
        cropped_img = resize_img.crop((left, top, left + target_size, top + target_size))

        # ToTensor + Normalize. Cast the constants to float32: subtracting a plain
        # Python list would silently promote the whole array to float64.
        images = np.asarray(cropped_img)
        print("preprocess:", images.shape)
        mean_arr = np.asarray(norm_mean, dtype=np.float32)
        std_arr = np.asarray(norm_std, dtype=np.float32)
        images = (images.astype(np.float32) / 255 - mean_arr) / std_arr
        images = images.transpose((2, 0, 1))  # HWC -> CHW
        print("preprocess:", images.shape)
        return torch.from_numpy(np.ascontiguousarray(images))
    
    # Top-1 / top-5 evaluation, one image per forward pass (batch size 1).
    # dt accumulates seconds spent in [preprocess + transfer, inference, post-processing].
    pred, targets, dt = [], [], [0.0, 0.0, 0.0]

    # Pure inference: no autograd graph is needed, and Tensor.numpy() refuses
    # tensors that require grad.
    with torch.no_grad():
        for i, (images, labels) in enumerate(dataset):
            print("i:", i)
            t1 = time_sync()
            # (C, H, W) tensor -> (1, C, H, W) float batch on the target device.
            images = preprocess(images).unsqueeze(0).float().to(device, non_blocking=True)
            print(images.shape)
            t2 = time_sync()
            dt[0] += t2 - t1

            # Scores are 2-D (batch, num_classes), e.g. for batch size 1:
            # [[4.0855 -1.1707 -1.4998 -0.935 -1.9979 -2.258 -1.4691 -1.0867 -1.9042 -0.99979]]
            # .cpu() so this also works when device is a GPU.
            y = model(images).cpu().numpy()
            t3 = time_sync()
            dt[1] += t3 - t2

            # Indices of the (up to) 5 highest scores per row, best first.
            pred.append(y.argsort()[:, ::-1][:, :5])
            targets.append(labels)
            dt[2] += time_sync() - t3

    if not pred:
        # np.concatenate([]) would otherwise fail with an obscure error.
        raise RuntimeError(f"No samples found under {dataset.root}")

    # pred: (N, min(5, num_classes)) ranked class indices; targets: (N,) true labels.
    pred, targets = np.concatenate(pred), np.array(targets)
    print("pred:", pred)
    print("pred:", pred.shape)
    print("targets:", targets)
    print("targets:", targets.shape)
    # correct[n, k] == 1.0 when the k-th ranked prediction of sample n equals its label.
    correct = (targets[:, None] == pred).astype(np.float32)
    print("correct:", correct.shape)
    print("correct:", correct)
    # Column 0: top-1 hit (first-ranked prediction correct); column 1: top-5 hit (any rank correct).
    acc = np.stack((correct[:, 0], correct.max(1)), axis=1)
    print("acc:", acc.shape)
    print("acc:", acc)
    top = acc.mean(0)
    print("top1:", top[0])
    print("top5:", top[1])
    
  • 相关阅读:
    需要SMB签名的漏洞解决方案
    图像分割算法
    小程序实现人脸识别功能
    常用DOS命令
    高德地图JSAPI 2.0使用Java代码代替Nginx进行反向代理产生CORS跨域
    Java中SnowFlake 雪花算法生成全局唯一id中的问题,时间不连续全为偶数解决
    备忘录模式
    springboot企业人力资源管理系统毕业设计源码291816
    7-119 奇偶分家
    mybatis plus遇到invalid bound statement(not found)报错
  • 原文地址:https://blog.csdn.net/flyfish1986/article/details/134552641