• Pytorch基础:Tensor的reshape方法


    相关阅读 

    Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


            在Pytorch中,reshape是Tensor的一个重要方法,它与Numpy中的reshape类似,用于返回一个改变了形状但数据和数据顺序和原来一致的新Tensor对象。注意:此时返回的新tensor中的数据对象并不一定是新的,这取决于应用此方法的Tensor是否是连续的。

            reshape方法的语法如下所示:

    1. Tensor.reshape(*shape) → Tensor
    2. shape (tuple of ints or int...) - the desired shape

            reshape的用法如下所示:

    1. import torch
    2. # 创建一个张量
    3. x = torch.randn(3, 4)
    4. tensor([[ 0.1961, -0.9038, 0.9196, -1.1851],
    5. [ 1.1321, 0.3153, 0.3485, 0.7977],
    6. [-0.5279, 0.2062, -0.4224, -0.3993]])
    7. # 使用reshape方法将其重新塑造为26列的形状
    8. y = x.reshape(2, 6)
    9. y = x.reshape((2,6)) #两种形式均可,y = x.reshape([2,6])也可
    10. tensor([[ 0.1961, -0.9038, 0.9196, -1.1851, 1.1321, 0.3153],
    11. [ 0.3485, 0.7977, -0.5279, 0.2062, -0.4224, -0.3993]])

            可以看到,给出的参数既可以是多个整数(其中每个整数代表一个维度的大小,而整数的数量代表维度的数量),也可以是一个元组或是列表(其中每个元素代表一个维度的大小,而元素数量代表维度的数量)。而且reshape不改变Tensor中数据的排列顺序(指的是从上到下从左到右遍历的顺序),只改变形状,这也就对reshape各维度大小的乘积有要求,要与原Tensor一致。在上例中即3*4=2*6。

            另外reshape还有一个trick,即某一维的实参可以是-1,此时会自动根据原Tensor大小和给出的其他维度参数的大小,推断出这一维度的大小,举例如下:

    1. import torch
    2. # 创建一个张量
    3. x = torch.randn(3, 4)
    4. tensor([[ 0.1961, -0.9038, 0.9196, -1.1851],
    5. [ 1.1321, 0.3153, 0.3485, 0.7977],
    6. [-0.5279, 0.2062, -0.4224, -0.3993]])
    7. # 使用reshape方法将其重新塑造为6行n列的形状,n为自动推断出的值
    8. y = x.reshape(6, -1)
    9. tensor([[ 0.1961, -0.9038],
    10. [ 0.9196, -1.1851],
    11. [ 1.1321, 0.3153],
    12. [ 0.3485, 0.7977],
    13. [-0.5279, 0.2062],
    14. [-0.4224, -0.3993]])
    15. # 使用reshape方法将其重新塑造为(2,2,n)的形状,n为自动推断出的值
    16. y = x.reshape(2, 2, -1)
    17. tensor([[[ 0.1961, -0.9038, 0.9196],
    18. [-1.1851, 1.1321, 0.3153]],
    19. [[ 0.3485, 0.7977, -0.5279],
    20. [ 0.2062, -0.4224, -0.3993]]])
    21. # 不能在两个维度都指定-1,这时无法推断出唯一结果
    22. y = x.reshape(2, -1, -1)
    23. Traceback (most recent call last):
    24. File "", line 1, in <module>
    25. RuntimeError: only one dimension can be inferred

            除此之外,还可以使用torch.reshape()函数,这与使用reshape方式效果一致,torch.reshape()的语法如下所示。

    1. torch.reshape(input, shape) → Tensor
    2. input (Tensor) – the tensor to be reshaped
    3. shape (tuple of python:int) – the new shape
    4. import torch
    5. # 创建一个张量
    6. x = torch.randn(3, 4)
    7. tensor([[ 0.1961, -0.9038, 0.9196, -1.1851],
    8. [ 1.1321, 0.3153, 0.3485, 0.7977],
    9. [-0.5279, 0.2062, -0.4224, -0.3993]])
    10. # 使用reshape函数将其重新塑造为6行n列的形状,n为自动推断出的值
    11. y = torch.reshape(x, (6, -1))
    12. tensor([[ 0.1961, -0.9038],
    13. [ 0.9196, -1.1851],
    14. [ 1.1321, 0.3153],
    15. [ 0.3485, 0.7977],
    16. [-0.5279, 0.2062],
    17. [-0.4224, -0.3993]])

  • 相关阅读:
    day02 mybatis
    本地部署企业邮箱,让企业办公更安全高效
    【云原生】-Docker安装部署分布式数据库 OceanBase
    corosync+packmaker+drbd+nfs高可用存储
    CTF/AWD竞赛标准参考书+实战指南
    洛谷P7529 Permutation G
    「Nature领衔」8月BIOTREE成功助力发表文章17篇,总IF:190+!
    【生日快乐】SpringBoot SpringBoot 提高篇(第二篇) 第5章 SpringBoot 日志 5.1 日志介绍 & 5.2 日志框架
    小程序项目创建与Vant-UI引入
    浅谈C++重载、重写、重定义
  • 原文地址:https://blog.csdn.net/weixin_45791458/article/details/133445832