• tensor的索引、切片、拼接和压缩等


    ensor的索引、切片和拼接

    一、相关命令

    命令1:拼接-torch.cat()

    • 格式: torch.cat(tensors, dim=0, out=None) → Tensor
    • 解释:在指定维度上拼接两个tensor
    >>> x = torch.randn(2, 3)
    >>> x
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 0)  # 在dim=0上拼接,也就是行方向
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 1) # 在dim=1上拼接
    tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
             -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
             -0.5790,  0.1497]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    命令2:拼接-torch.stack()

    • 格式:torch.stack(tensors, dim=0, out=None) → Tensor
    • 解释:沿着一个新的维度对张量进行拼接。序列中的tensors必须具有相同的size。
    • **直白的说:**它可以将二维tensor变三维tensor,三维变4维
    • 与torch.cat()的区别。stack()属于扩张再拼接的函数。通常用于NLP和CV领域

    在这里插入图片描述
    注意: 下图及实例来自torch.stack()的官方解释,详解以及例子
    在这里插入图片描述

    # 假设是时间步T1的输出
    T1 = torch.tensor([[1, 2, 3],
            		[4, 5, 6],
            		[7, 8, 9]])
    # 假设是时间步T2的输出
    T2 = torch.tensor([[10, 20, 30],
            		[40, 50, 60],
            		[70, 80, 90]])
    print(torch.stack((T1,T2),dim=0))
    print(torch.stack((T1,T2),dim=0).shape)
    print(==================================)
    print(torch.stack((T1,T2),dim=1))
    print(torch.stack((T1,T2),dim=1).shape)
    print(==================================)
    print(torch.stack((T1,T2),dim=2))
    print(torch.stack((T1,T2),dim=2).shape)
    print(==================================)
    print(torch.stack((T1,T2),dim=3).shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    输出:

    tensor([[[ 1,  2,  3],
             [ 4,  5,  6],
             [ 7,  8,  9]],
    
            [[10, 20, 30],
             [40, 50, 60],
             [70, 80, 90]]])
             
    torch.Size([2, 3, 3])
    ==================================
    tensor([[[ 1,  2,  3],
             [10, 20, 30]],
    
            [[ 4,  5,  6],
             [40, 50, 60]],
    
            [[ 7,  8,  9],
             [70, 80, 90]]])
             
    torch.Size([3, 2, 3])
    ==================================
    tensor([[[ 1, 10],
             [ 2, 20],
             [ 3, 30]],
    
            [[ 4, 40],
             [ 5, 50],
             [ 6, 60]],
    
            [[ 7, 70],
             [ 8, 80],
             [ 9, 90]]])
    
    torch.Size([3, 3, 2])
    ==================================
    '选择的dim>len(outputs),所以报错'
    IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37

    命令3:分割-torch.chunk()

    • 格式: torch.chunk(input, chunks, dim=0) → List of Tensors
    • chunks只能是整数
    • 解释:可以看做是torch.cat()的反向操作,分割tensor

    注意:

    如果沿着指定轴不能整数切分,那么最后一个块将会是最小的

    >>> c
    tensor([[0.9387, 0.5666, 0.3289, 0.7775, 0.5938],
            [0.0968, 0.0961, 0.6976, 0.9121, 0.0796],
            [0.4676, 0.7772, 0.2398, 0.5254, 0.9906],
            [0.0588, 0.7729, 0.2259, 0.6438, 0.8299]], dtype=torch.float64)
    >>> cc = torch.chunk(c,2,1) # 很明显size_c = (4,5)不能争分
    >>> cc  # 因此cc的两个块分别为(4,3)和(4,2)
    (tensor([[0.9387, 0.5666, 0.3289],
            [0.0968, 0.0961, 0.6976],
            [0.4676, 0.7772, 0.2398],
            [0.0588, 0.7729, 0.2259]], dtype=torch.float64), tensor([[0.7775, 0.5938],
            [0.9121, 0.0796],
            [0.5254, 0.9906],
            [0.6438, 0.8299]], dtype=torch.float64))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    命令4:分割-torch.split()

    • 格式: torch.split(tensor, split_size_or_sections, dim=0)
    • split_size_or_sections 可以是整数,也可以是一个列表。如果是int,则先尽可能整除,然后最后一个最小;如果是列表,则将分割为len(list)个块,并按照list的元素进行分配
    • 解释:可以看做是torch.cat()的反向操作,分割tensor
    >>> c
    tensor([[0.9387, 0.5666, 0.3289, 0.7775, 0.5938],
            [0.0968, 0.0961, 0.6976, 0.9121, 0.0796],
            [0.4676, 0.7772, 0.2398, 0.5254, 0.9906],
            [0.0588, 0.7729, 0.2259, 0.6438, 0.8299]], dtype=torch.float64)
    # 是一个列表
    >>>cc = torch.split(c,[2,2,1],1) 
    >>> cc
    (tensor([[0.9387, 0.5666],
            [0.0968, 0.0961],
            [0.4676, 0.7772],
            [0.0588, 0.7729]], dtype=torch.float64), tensor([[0.3289, 0.7775],
            [0.6976, 0.9121],
            [0.2398, 0.5254],
            [0.2259, 0.6438]], dtype=torch.float64), tensor([[0.5938],
            [0.0796],
            [0.9906],
            [0.8299]], dtype=torch.float64))
    # 是一个整数
    >>> cc=torch.split(c,2,1)
    >>> cc
    (tensor([[0.9387, 0.5666],
            [0.0968, 0.0961],
            [0.4676, 0.7772],
            [0.0588, 0.7729]], dtype=torch.float64), tensor([[0.3289, 0.7775],
            [0.6976, 0.9121],
            [0.2398, 0.5254],
            [0.2259, 0.6438]], dtype=torch.float64), tensor([[0.5938],
            [0.0796],
            [0.9906],
            [0.8299]], dtype=torch.float64))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    命令5:采集指定维度数据-torch.gather()

    • 格式:torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
    • 解释:从原tensor中按照指定轴dim和索引index获取数据
    • index是一个tensor。输出的维度与index的维度一定时相同的

    在这里插入图片描述
    实例:摘自知乎:图解PyTorch中的torch.gather函数

    先定义一个原始tensor:

    >>> a = torch.arange(3, 12).view(3, 3)
    >>> a
    tensor([[ 3,  4,  5],
            [ 6,  7,  8],
            [ 9, 10, 11]])
    
    • 1
    • 2
    • 3
    • 4
    • 5

    指定index

    index = torch.tensor([[2,1,0]])
    
    • 1

    在dim=0方向采集

    >>> b = torch.gather(a,0,index)
    >>> b
    tensor([[9, 7, 5]])
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    命令6:维度压缩-torch.squeeze()

    • 格式:torch.squeeze(input, dim=None, out=None) → Tensor
    • 解释:大小为1的维度都被删除

    官网实例
    如果 input 的size=(A×1×B×C×1×D),则返回的tensor size= (A×B×C×D)

    如果指定dim上的维度不为1,则返回不变

    >>> x = torch.zeros(2, 1, 2, 1, 2)
    >>> x.size()
    torch.Size([2, 1, 2, 1, 2])
    >>> y = torch.squeeze(x)
    >>> y.size()
    torch.Size([2, 2, 2])
    >>> y = torch.squeeze(x, 0)
    >>> y.size()
    torch.Size([2, 1, 2, 1, 2])
    >>> y = torch.squeeze(x, 1)
    >>> y.size()
    torch.Size([2, 2, 1, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    命令7:维度恢复或添加-torch.unsqueeze()

    • 格式:torch.unsqueeze(input, dim, out=None) → Tensor
    • 解释:在指定轴添加1。

    在这里插入图片描述

    >>> x = torch.tensor([1, 2, 3, 4])
    >>> torch.unsqueeze(x, 0)
    tensor([[ 1,  2,  3,  4]])
    >>> torch.unsqueeze(x, 1)
    tensor([[ 1],
            [ 2],
            [ 3],
            [ 4]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    命令8、tensor扁平化torch.flatten()

    • torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
      在这里插入图片描述

    举例

    >>> t = torch.tensor([[[1, 2],
                           [3, 4]],
                          [[5, 6],
                           [7, 8]]])
    >>>t.size # (2,2,2)
    >>> torch.flatten(t) # 默认从dim=0开始
    tensor([1, 2, 3, 4, 5, 6, 7, 8])
    
    >>> torch.flatten(t,start_dim = 1) # 
    tensor([[1, 2, 3, 4],
            [5, 6, 7, 8]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    参考

    torch官网:https://pytorch.org/docs/1.2.0/torch.html

  • 相关阅读:
    JMeter笔记14 | JMeter场景设计和设置
    《C++新经典》第14章 类
    C++异常捕获
    【.Net实用方法总结】 整理并总结System.IO中Directory类及其方法介绍
    uni-app 开发调试自动打开手机屏幕大小界面(Aidex移动端开发项目)
    客户端性能优化实践
    2023年9月青少年软件编程(C 语言) 等级考试试卷(八级)
    flutter 身兼数职的getx —— 依赖管理
    react-router v6使用createHashHistory进行history.push时,url改变页面不渲染
    NLP(六十八)使用Optimum进行模型量化
  • 原文地址:https://blog.csdn.net/weixin_44769214/article/details/126321466