• 举例说明PyTorch函数torch.cat与torch.stack的区别


    一、torch.cat与torch.stack的区别

    torch.cat用于在给定的维度上连接多个张量,它将这些张量沿着指定维度堆叠在一起。

    torch.stack用于在新的维度上堆叠多个张量,它会创建一个新的维度,并将这些张量沿着这个新维度堆叠在一起。

    二、torch.cat

    在这里插入图片描述

    Example1:

    import torch
    
    tensor1 = torch.tensor([[1, 2], [3, 4]])
    tensor2 = torch.tensor([[5, 6], [7, 8]])
    
    result1 = torch.cat((tensor1, tensor2), dim=0)
    result2 = torch.cat((tensor1, tensor2), dim=1)
    
    print(result1.shape)
    print(result1)
    print(result2.shape)
    print(result2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    torch.Size([4, 2])
    tensor([[1, 2],
            [3, 4],
            [5, 6],
            [7, 8]])
    torch.Size([2, 4])
    tensor([[1, 2, 5, 6],
            [3, 4, 7, 8]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    三、torch.stack

    在这里插入图片描述

    Example1:

    import torch
    
    tensor1 = torch.tensor([1, 2, 3])
    tensor2 = torch.tensor([4, 5, 6])
    
    result1 = torch.stack((tensor1, tensor2), dim=0)
    result2 = torch.stack((tensor1, tensor2), dim=1)
    
    print(result1.shape)
    print(result1)
    print(result2.shape)
    print(result2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    torch.Size([2, 3])
    tensor([[1, 2, 3],
            [4, 5, 6]])
    torch.Size([3, 2])
    tensor([[1, 4],
            [2, 5],
            [3, 6]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    Example2:

    import torch
    
    tensor1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
    tensor2 = torch.tensor([[7, 8], [9, 10], [11, 12]])
    tensor3 = torch.tensor([[13, 14], [15, 16], [17, 18]])
    
    result1 = torch.stack((tensor1, tensor2, tensor3), dim=0)
    result2 = torch.stack((tensor1, tensor2, tensor3), dim=1)
    
    print(result1.shape)
    print(result1)
    print(result2.shape)
    print(result2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    torch.Size([3, 3, 2])
    tensor([[[ 1,  2],
             [ 3,  4],
             [ 5,  6]],
    
            [[ 7,  8],
             [ 9, 10],
             [11, 12]],
    
            [[13, 14],
             [15, 16],
             [17, 18]]])
    torch.Size([3, 3, 2])
    tensor([[[ 1,  2],
             [ 7,  8],
             [13, 14]],
    
            [[ 3,  4],
             [ 9, 10],
             [15, 16]],
    
            [[ 5,  6],
             [11, 12],
             [17, 18]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
  • 相关阅读:
    线代(高斯消元法、线性基)
    NLP | XLNet :用于语言理解的广义自回归预训练 论文详解
    Go构建模式:GOPATH、vendor、Go Module
    请查收 | Navicat 热门技术问答
    网络面试-ox09 http是如何维持用户的状态?
    springboot security基本配置
    @ConfigurationProperties的使用方式
    红队专题-Cobalt strike从小白到飞升手册
    Python进阶(更新中)
    企业如何通过CRM获得竞争力?
  • 原文地址:https://blog.csdn.net/weixin_45953673/article/details/132747733