• 【torch.utils.data.sampler】采样器的解析和使用


    torch.utils.data.sampler

    内置的Sampler

    基类 Sampler

    sampler 采样器,是一个迭代器。PyTorch提供了多种采样器,用户也可以自定义采样器。所有sampler都是承 torch.utils.data.sampler.Sampler这个抽象类。

    class Sampler(object):
        r"""Base class for all Samplers.
        """
        def __init__(self, data_source):
            pass
        def __iter__(self):
            raise NotImplementedError
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    顺序采样 SequentialSampler

    • 功能
      • 顺序地对元素进行采样,总是以相同的顺序。
    • 参数
      • data_source(Dataset): 采样的数据集

    初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()只负责返回数据源包含的数据个数;iter()方法负责返回一个可迭代对象,这个可迭代对象是由range产生的顺序数值序列,也就是说迭代是按照顺序进行的。

    class SequentialSampler(Sampler):
        def __init__(self, data_source):
            self.data_source = data_source
        def __iter__(self):
            return iter(range(len(self.data_source)))
        def __len__(self):
            return len(self.data_source)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 例子
    # 定义数据和对应的采样器
    data = list([17, 22, 3, 41, 8])
    seq_sampler = sampler.SequentialSampler(data_source=data)
    # 迭代获取采样器生成的索引
    for index in seq_sampler:
        print("index: {}, data: {}".format(str(index), str(data[index])))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    得到下面的输出,说明Sequential Sampler产生的索引是顺序索引:

    index: 0, data: 17
    index: 1, data: 22
    index: 2, data: 3
    index: 3, data: 41
    index: 4, data: 8
    
    • 1
    • 2
    • 3
    • 4
    • 5

    随机采样 RandomSampler

    • 功能
      • 随机抽取元素。如果没有替换,则从打乱的数据集中采样。 如果有替换,则用户可以指定:attr:num_samples
    • 参数
      • data_source (Dataset): 采样的数据集
      • replacement (bool): 如果为 True抽取的样本是有放回的。默认是False
      • num_samples (int): 抽取样本的数量,默认是len(dataset)。当replacementTrue的时应该被被实例化
    class RandomSampler(Sampler):
        def __init__(self, data_source, replacement=False, num_samples=None):
            self.data_source = data_source
            # 这个参数控制的应该为是否重复采样
            self.replacement = replacement
            self._num_samples = num_samples
        def num_samples(self):
            # dataset size might change at runtime
            # 初始化时不传入num_samples的时候使用数据源的长度
            if self._num_samples is None:
                return len(self.data_source)
            return self._num_samples
        # 返回数据集长度
        def __len__(self):
            return self.num_samples
    		# 索引生成
    		def __iter__(self):
    		    n = len(self.data_source)
    		    if self.replacement:
    		        # 生成的随机数是可能重复的
    		        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    		    # 生成的随机数是不重复的
    		    return iter(torch.randperm(n).tolist())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    randint()函数生成的随机数学列是可能包含重复数值的,而randperm()函数生成的随机数序列是绝对不包含重复数值的

    • 例子
    '''不使用replacement,生成的随机索引不重复'''
    ran_sampler = sampler.RandomSampler(data_source=data)
    # 得到下面输出
    index: 0, data: 17
    index: 2, data: 3
    index: 3, data: 41
    index: 4, data: 8
    index: 1, data: 22
    
    '''使用replacement,生成的随机索引有重复'''
    ran_sampler = sampler.RandomSampler(data_source=data, replacement=True)
    # 得到下面的输出
    index: 0, data: 17
    index: 4, data: 8
    index: 3, data: 41
    index: 4, data: 8
    index: 2, data: 3
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    子集随机采样 SubsetRandomSampler

    • 功能
      • 从给定的索引列表中随机抽取元素,不进行替换。
    • 参数
      • indices (sequence): 索引列表
    class SubsetRandomSampler(Sampler):
        def __init__(self, indices):
            # 数据集的切片,比如划分训练集和测试集
            self.indices = indices
        def __iter__(self):
            # 以元组形式返回不重复打乱后的“数据”
            return (self.indices[i] for i in torch.randperm(len(self.indices)))
        def __len__(self):
            return len(self.indices)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    _iter__()返回的并不是随机数序列,而是通过随机数序列作为indices的索引,进而返回打乱的数据本身。需要注意的仍然是采样是不重复的,也是通过randperm()函数实现的。

    • 例子

    下面将data划分为train和val两个部分

    sub_sampler_train = sampler.SubsetRandomSampler(indices=data[0:2])
    sub_sampler_val = sampler.SubsetRandomSampler(indices=data[2:])
    # 下面是train输出
    index: 17
    index: 22
    *************
    # 下面是val输出
    index: 8
    index: 41
    index: 3
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    加权随机采样 WeightedRandomSampler

    • 功能
      • 按照给定的概率权重weights, 对元素进行采样
    • 参数
      • weights权重序列
      • num_samples采样数
      • replacement 抽取的样本是否有放回
    class WeightedRandomSampler(Sampler):
        def __init__(self, weights, num_samples, replacement=True):
             # ...省略类型检查
             # weights用于确定生成索引的权重
            self.weights = torch.as_tensor(weights, dtype=torch.double)
            self.num_samples = num_samples
            # 用于控制是否对数据进行有放回采样
            self.replacement = replacement
        def __iter__(self):
            # 按照加权返回随机索引值
            return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    __iter__()方法返回的数值为随机数序列,只不过生成的随机数序列是按照weights指定的权重确定的

    • 例子
    # 位置[0]的权重为0,位置[1]的权重为10,其余位置权重均为1.1
    weights = torch.Tensor([0, 10, 1.1, 1.1, 1.1, 1.1, 1.1])
    wei_sampler = sampler.WeightedRandomSampler(weights=weights, num_samples=6, replacement=True)
    # 下面是输出:
    index: 1
    index: 2
    index: 3
    index: 4
    index: 1
    index: 1
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    从输出可以看出,位置[1]由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。

    批采样 BatchSampler

    • 功能
      • 包装另一个采样器以生成一个小批量索引。
    • 参数
      • sampler对应前面介绍的XxxSampler类实例
      • batch_size 批量大小
      • drop_last为“True”时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据
    class BatchSampler(Sampler):
        def __init__(self, sampler, batch_size, drop_last):# ...省略类型检查
            # 定义使用何种采样器Sampler
            self.sampler = sampler
            self.batch_size = batch_size
            # 是否在采样个数小于batch_size时剔除本次采样
            self.drop_last = drop_last
        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                # 如果采样个数和batch_size相等则本次采样完成
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            # for结束后在不需要剔除不足batch_size的采样个数时返回当前batch        
            if len(batch) > 0 and not self.drop_last:
                yield batch
        def __len__(self):
            # 在不进行剔除时,数据的长度就是采样器索引的长度
            if self.drop_last:
                return len(self.sampler) // self.batch_size
            else:
                return (len(self.sampler) + self.batch_size - 1) // self.batch_size
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 例子

    下面的例子中batch sampler采用的采样器为顺序采样器:

    seq_sampler = sampler.SequentialSampler(data_source=data)
    batch_sampler = sampler.BatchSampler(seq_sampler, 3, False)
    # 下面是输出
    batch: [0, 1, 2]
    batch: [3, 4]
    
    • 1
    • 2
    • 3
    • 4
    • 5
  • 相关阅读:
    生成式人工智能促使社会转变
    nodeJs--querystring模块
    Python神经网络入门与实战,神经网络算法python实现
    SAP CO系统配置-成本中心会计
    【python】python的标准库——sys模块介绍
    PHP Laravel报错No application encryption key has been specified
    SpringBoot 如何使用 Ehcache 作为缓存
    吴恩达深度学习笔记:深度学习引言1.1-1.5
    Mathorcup数学建模竞赛第三届-【妈妈杯】C题:语音识别技术的应用(附带赛题解析&获奖论文&MATLAB代码)(一)
    如何快速地编译并且运行Github中的Android项目
  • 原文地址:https://blog.csdn.net/zyw2002/article/details/128176507