• Pytorch中DataLoader的collate_fn()参数学习笔记


    1 Dataset和DataLoader创建和加载数据

    使用pytorch训练网络之前的数据准备部分都要有两个流程:Dataset和DataLoader。前者用来定义自己的数据集类型(内部实现最主要的是__getitem__()方法);而后者则是负责真正在运行的时侯给网络递送数据。

    1.1 Dataset类

    继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
    其中__len__返回数据集样本的数量,而__getitem__应该编写支持数据集索引的函数,返回数据和对应label,例如:通过dataset[i]可以得到数据集中的第i+1个数据。

    1.2 DataLoader类

    DataLoader完整的参数表如下:

    class torch.utils.data.DataLoader(
     dataset,
     batch_size=1,
     shuffle=False,
     sampler=None,
     batch_sampler=None,
     num_workers=0,
     collate_fn=,
     pin_memory=False,
     drop_last=False,
     timeout=0,
     worker_init_fn=None)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    几个关键的参数意思:
    dataset:PyTorch已有的数据读取接口或自定义数据接口的输出
    batch_size:根据具体情况设置
    shuffle:设置为True的时候,每个迭代都会打乱数据集,一般在训练数据中会采用
    num_workers:这个参数必须大于等于0,0表示数据导入在主进程中进行,大于0表示通过多个进程来导入数据,可以加快数据导入速度。
    collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
    drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True抛弃,否则保留

    通常说来,我们在编写完Dataset之后,其内部的__getitem__会弹出一个[data, label]的一条数据,DataLoader中的collate_fn函数将这些一条一条的数据组织成一个batch。

    注意:
    通常的,默认的collate_fn函数是要求一个batch中的图片都具备相同size,当一个batch中的图片大小都不一样时(或者想要定值batch的输出形式),使用自定义的collate_fn函数。

    1.3 自定义batch

    通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。

    可以使用collate_fn的同时,结合使用默认的default_collate。

    from torch.utils.data.dataloader import default_collate  # 导入这个函数
    def my_collate_fn(batch):
        """
        params:
            batch :是一个列表,列表的长度是 batch_size,list中的每个元素都是__getitem__得到的结果。
                   列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
                   大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
        returns:
            整理之后的新的batch
        """
        # 这一部分是对 batch 进行重新 “校对、整理”的代码
        return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    然后调用时使用:

    trainset = DataLoader(dataset=train_dataset,
                          batch_size=24,
                          shuffle=True,
                          collate_fn=my_collate_fn,
    
    • 1
    • 2
    • 3
    • 4

    2 实例

    图神经网络时,将多张图合并为一张大图。

    """
    Combine multiple graphs into one large graph.
    """
    
    #- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
    # collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果。
    def collate_fn(batch):
        nodes_list = [b[0] for b in batch] #b[0]=p.array(nodes)
        nodes = np.concatenate(nodes_list, axis=0) #所有节点拼到一起,不扩展维度,拼成一个array
        #map 对于node_list每一组,计算shape(即节点个数)。也就是一个可迭代对象。返回array(每个图的节点个数)
        nodes_lens = np.fromiter(map(lambda l: l.shape[0], nodes_list), dtype=np.int64)
        nodes_inds = np.cumsum(nodes_lens) #计算一个数组各行的累加值
        nodes_num = nodes_inds[-1] #最后一个值,即总节点个数
        nodes_inds = np.insert(nodes_inds, 0, 0) #在第一个位置插入0这个值
        nodes_inds = np.delete(nodes_inds, -1) #按行展开后 删除最后一个元素
    
        edges_list = [b[1] for b in batch] #np.array(edges)
        edges_list = [e + i for e, i in zip(edges_list, nodes_inds)] #e是边的连接(每个图都从0开始) i是总节点个数
        edges = np.concatenate(edges_list, axis=0)
        m = edges_to_matrix(nodes_num, edges) #将每个batch拼接成一个邻接矩阵
    
        labels = [b[2] for b in batch] #np.array([float(label)])
        labels = np.concatenate(labels, axis=0)
        # batch中第i个图的节点个数为k batch_mask数组分别为[0,..,0] [1,..,1] 其中...为节点个数
        batch_mask = [np.array([i] * k, dtype=np.int32) for i, k in zip(range(len(batch)), nodes_lens)]
        batch_mask = np.concatenate(batch_mask, axis=0)
    
        #返回节点类型 邻接矩阵 预测值 batch_mask
        return torch.from_numpy(nodes), torch.from_numpy(m).float(), torch.from_numpy(labels), torch.from_numpy(batch_mask)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
  • 相关阅读:
    MySQL标准差和方差函数使用
    Elasticsearch-Rest风格
    22道js输出顺序问题,你能做出几道
    智能矩阵,引领商业新纪元!拓世方案:打破线上线下界限,开启无限营销可能!
    炫酷HTML蜘蛛侠登录页面
    【LVGL】组件的样式的设置、更改、删除API函数
    【ES实战】ES分页与去重
    设计循环队列,解决假溢出问题
    [SWPUCTF 2018]SimplePHP
    Unity 设置窗口置顶超级详解版
  • 原文地址:https://blog.csdn.net/weixin_45928096/article/details/126950619