• YOLOv5 分类模型 数据集加载 3


    YOLOv5 分类模型 数据集加载 3 自定义类别

    flyfish

    YOLOv5 分类模型 数据集加载 1 样本处理
    YOLOv5 分类模型 数据集加载 2 切片处理
    YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
    YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize
    YOLOv5 分类模型 Top 1和Top 5 指标说明
    YOLOv5 分类模型 Top 1和Top 5 指标实现

    之前的处理方式：类别名是文件夹名，类别 ID 按文件夹名的字母顺序分配。
    现在的处理方式：类别名仍是文件夹名，但类别 ID 按自定义列表 classes_name 中的顺序分配，例如

    classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 
    'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']
    

    n02086240 的类别 ID 是 0
    n02087394 的类别 ID 是 1，其余依此类推。
    代码处理如下：

    # Choose how class IDs are assigned (`not classes_name` also covers None).
    if classes_name is None or not classes_name:
        # No custom list: discover the class folders under root; find_classes
        # sorts them, so IDs follow alphabetical folder-name order.
        classes, class_to_idx = self.find_classes(self.root)
        print("not classes_name")
    
    else:
        # Custom list: a class's ID is its position in classes_name.
        classes = classes_name
        class_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}
        print("is classes_name")
    

    完整代码如下：

    import time
    from models.common import DetectMultiBackend
    import os
    import os.path
    from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
    import cv2
    import numpy as np
    
    import torch
    from PIL import Image
    import torchvision.transforms as transforms
    
    import sys
    
    # Ordered class (folder) names: a class's ID is its index in this list,
    # e.g. 'n02086240' -> 0, 'n02087394' -> 1.
    classes_name = [
        'n02086240',
        'n02087394',
        'n02088364',
        'n02089973',
        'n02093754',
        'n02096294',
        'n02099601',
        'n02105641',
        'n02111889',
        'n02115641',
    ]
                  
    class DatasetFolder:
        """Minimal image-folder dataset mapping ``root/<class_name>/**/<file>`` to ``(path, class_index)``.

        Class indices come either from an explicit ordered list of class names
        (index = position in the list) or, when no list is available, from the
        alphabetical order of the sub-directory names under ``root``.
        """

        def __init__(
            self,
            root: str,
            classes_name: Optional[List[str]] = None,
        ) -> None:
            """Scan ``root`` and build the sample list.

            Args:
                root: Dataset directory whose immediate sub-directories are the classes.
                classes_name: Optional ordered list of class (folder) names; each
                    name's class ID is its position in the list. When omitted, a
                    module-level ``classes_name`` list is used if one is defined;
                    if that is also missing or empty, classes are discovered with
                    :meth:`find_classes` (alphabetical order).
            """
            self.root = root

            if classes_name is None:
                # Backward compatibility: the script originally configured the
                # class order only through a module-level ``classes_name`` global.
                classes_name = globals().get("classes_name")

            if not classes_name:
                classes, class_to_idx = self.find_classes(self.root)
                print("not classes_name")
            else:
                classes = list(classes_name)
                class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
                print("is classes_name")

            print("classes:", classes)
            print("class_to_idx:", class_to_idx)
            samples = self.make_dataset(self.root, class_to_idx)

            self.classes = classes
            self.class_to_idx = class_to_idx
            self.samples = samples
            self.targets = [s[1] for s in samples]

        @staticmethod
        def make_dataset(
            directory: str,
            class_to_idx: Optional[Dict[str, int]] = None,
        ) -> List[Tuple[str, int]]:
            """Collect ``(file_path, class_index)`` pairs for every file under each class folder.

            Class folders are visited in sorted name order, and files within them
            in sorted walk order. Every file is treated as a sample (no extension
            filtering). Classes listed in ``class_to_idx`` whose folder is missing
            or empty are skipped and reported with a warning.

            Args:
                directory: Dataset root; ``~`` is expanded.
                class_to_idx: Mapping from class folder name to class index. When
                    None, classes are discovered from the sub-directories of
                    ``directory``.

            Returns:
                List of ``(path, class_index)`` tuples.

            Raises:
                ValueError: If ``class_to_idx`` is an empty mapping.
                FileNotFoundError: If ``class_to_idx`` is None and ``directory``
                    has no sub-directories.
            """
            import warnings

            directory = os.path.expanduser(directory)

            if class_to_idx is None:
                # A staticmethod has no ``self``; use the shared static scanner.
                _, class_to_idx = DatasetFolder._scan_classes(directory)
            elif not class_to_idx:
                raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

            instances: List[Tuple[str, int]] = []
            available_classes = set()
            for target_class in sorted(class_to_idx):
                class_index = class_to_idx[target_class]
                target_dir = os.path.join(directory, target_class)
                if not os.path.isdir(target_dir):
                    continue
                for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                    for fname in sorted(fnames):
                        instances.append((os.path.join(root, fname), class_index))
                        available_classes.add(target_class)

            empty_classes = set(class_to_idx) - available_classes
            if empty_classes:
                # Previously this message was built and silently discarded.
                msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
                warnings.warn(msg)

            return instances

        def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
            """Return the sorted sub-directory names of ``directory`` and their index mapping.

            Raises:
                FileNotFoundError: If ``directory`` contains no sub-directories.
            """
            return self._scan_classes(directory)

        @staticmethod
        def _scan_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
            """List sub-directories of ``directory`` alphabetically and number them from 0."""
            # Close the scandir iterator deterministically.
            with os.scandir(directory) as entries:
                classes = sorted(entry.name for entry in entries if entry.is_dir())
            if not classes:
                raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

            class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
            return classes, class_to_idx

        def __getitem__(self, index: int) -> Tuple[Any, Any]:
            """Return ``(image, class_index)`` for sample ``index``; the image comes from :meth:`loader`."""
            path, target = self.samples[index]
            sample = self.loader(path)

            return sample, target

        def __len__(self) -> int:
            """Return the number of collected samples."""
            return len(self.samples)

        def loader(self, path: str):
            """Open the image at ``path`` and return it converted to RGB (PIL, HWC)."""
            print("path:", path)
            # ``convert`` loads the pixels and returns a new image, so the source
            # file can be closed as soon as the ``with`` block exits.
            with Image.open(path) as img:
                return img.convert("RGB")
    
    
    def time_sync():
        """Return the current wall-clock time in seconds, after pending CUDA work finishes.

        CUDA kernels are launched asynchronously, so reading the clock without a
        synchronize would only measure the launch. On a CPU-only machine this is
        equivalent to ``time.time()``.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()
    
    #sys.exit() 
    # Build the dataset from the imagewoof validation folder; DatasetFolder
    # prints the resolved class list and class_to_idx mapping.
    dataset = DatasetFolder(root="/media/a/flyfish/source/yolov5/datasets/imagewoof/val")
    
    #image, label=dataset[7]
    
    #
    # Load the classification weights on the CPU through YOLOv5's
    # DetectMultiBackend (dnn=False, fp16=False), switch to eval mode, and
    # print the class names stored with the model.
    weights = "/home/a/classes.pt"
    device = "cpu"
    model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
    model.eval()
    print(model.names)
    print(type(model.names))
    
    # Per-channel (R, G, B) normalization constants read by preprocess();
    # these are the commonly used ImageNet mean/std values.
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    def preprocess(images, target_size=224, norm_mean=None, norm_std=None):
        """Replicate torchvision Resize(size) + CenterCrop(size) + ToTensor + Normalize with PIL/NumPy.

        Args:
            images: A PIL image (the dataset's loader returns RGB images).
            target_size: Shorter-side length after resizing, and the side of the
                square center crop. Defaults to 224 as before.
            norm_mean: Per-channel mean for normalization; defaults to the
                module-level ``mean`` list.
            norm_std: Per-channel std for normalization; defaults to the
                module-level ``std`` list.

        Returns:
            A float32 ``torch.Tensor`` in CHW layout, shape
            ``(C, target_size, target_size)`` (C is 3 for RGB input).
        """
        # Resize (like torchvision Resize with an int size): scale so the
        # shorter side equals target_size while keeping the aspect ratio.
        img_w, img_h = images.width, images.height
        if img_h >= img_w:
            new_size = (target_size, int(target_size * img_h / img_w))
        else:
            new_size = (int(target_size * img_w / img_h), target_size)
        resize_img = images.resize(new_size, Image.BILINEAR)

        # CenterCrop: use the same offset rounding as torchvision's center_crop,
        # so odd margins match it and odd target sizes give a full-size crop.
        left = int(round((resize_img.width - target_size) / 2.0))
        top = int(round((resize_img.height - target_size) / 2.0))
        cropped_img = resize_img.crop((left, top, left + target_size, top + target_size))

        # ToTensor + Normalize: HWC uint8 -> [0, 1] float32 -> per-channel
        # normalization -> CHW. Float32 constants keep the result float32.
        arr = np.asarray(cropped_img)
        print("preprocess:", arr.shape)
        mean_arr = np.asarray(mean if norm_mean is None else norm_mean, dtype=np.float32)
        std_arr = np.asarray(std if norm_std is None else norm_std, dtype=np.float32)
        arr = (arr.astype(np.float32) / 255.0 - mean_arr) / std_arr
        arr = arr.transpose((2, 0, 1))  # HWC -> CHW
        print("preprocess:", arr.shape)

        return torch.from_numpy(np.ascontiguousarray(arr))
    
    # Evaluation: run every sample through the model one at a time (batch
    # size 1), collect the 5 highest-scoring class indices, then compute
    # top-1 / top-5 accuracy.
    pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]  # dt: seconds for [transfer, inference, post-process]
    with torch.no_grad():  # inference only: no autograd graph, so .numpy() is always allowed
        for i, (images, labels) in enumerate(dataset):
            print("i:", i)
            im = preprocess(images)
            images = im.unsqueeze(0).float()  # CHW -> NCHW with a batch of 1
            print(images.shape)

            t1 = time_sync()
            images = images.to(device, non_blocking=True)
            t2 = time_sync()
            dt[0] += t2 - t1

            y = model(images)
            # Copy to host memory first so this also works when device is a GPU.
            y = y.cpu().numpy()
            t3 = time_sync()
            dt[1] += t3 - t2

            # y is expected to be 2-D (batch, num_classes) even for one image;
            # sort scores descending along axis 1 and keep the 5 best indices.
            top5 = y.argsort()[:, ::-1][:, :5]
            pred.append(top5)
            targets.append(labels)
            dt[2] += time_sync() - t3

    if not pred:
        raise RuntimeError("No samples were evaluated; check the dataset root and class list.")

    pred, targets = np.concatenate(pred), np.array(targets)
    print("pred:", pred)
    print("pred:", pred.shape)
    print("targets:", targets)
    print("targets:", targets.shape)
    # correct[n, k] is 1.0 when the k-th ranked prediction of sample n equals its label.
    correct = (targets[:, None] == pred).astype(np.float32)
    print("correct:", correct.shape)
    print("correct:", correct)
    # top-1: the best-ranked prediction is right; top-5: any of the 5 is right.
    acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
    print("acc:", acc.shape)
    print("acc:", acc)
    top = acc.mean(0)
    print("top1:", top[0])
    print("top5:", top[1])
    
  • 相关阅读:
    【产品设计】素描最重要的目的是为了好看吗?
    Docker启动失败报错Failed to start Docker Application Container Engine解决方案
    七夕!专属于程序员的浪漫表白
    【经济研究】数字技术创新与中国企业高质量发展—来自企业数字专利的证据
    对求职面试者的一点小建议
    [附源码]java毕业设计拾穗在线培训考试系统
    JavaFx学习问题3---Jar包路径问题 (疑难杂症)
    【MySQL】CRUD (增删改查) 基础
    剪绳子(动态规划,贪心算法)
    记一次生产中使用CompletableFuture遇到的坑
  • 原文地址:https://blog.csdn.net/flyfish1986/article/details/134552641