• pytorch中unsqueeze用法说明


    在指定的位置插入一个维度,有两个参数,input是输入的tensor,dim是要插到的维度

    需要注意的是dim的范围是[-input.dim()-1, input.dim()+1),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1,类似于使用负数对python列表进行切片。

    1. import torch
    2. a = torch.randn(2,5)
    3. print(a)
    4. print("")
    5. b = a.unsqueeze(0)
    6. print(b.shape)
    7. print("")
    8. c = a.unsqueeze(a.dim())
    9. print(c.shape)
    10. 输出:
    11. tensor([[-0.4734, 0.4115, -0.9415, -1.1280, -0.1065],
    12. [ 0.1613, 1.2594, 1.1261, 1.3881, 0.1112]])
    13. torch.Size([1, 2, 5])
    14. torch.Size([2, 5, 1])

    以上是二维数据情况:

    首先生成了一个二维矩阵,其大小为[2,5]

    然后,在0维度上插入一个维度,可以看到现在新矩阵a的形状变为[1,2,5],第0维度的大小默认是1

    最后,在最后一个维度上插入一个维度,形状变为[2, 5, 1]

    1. a=torch.rand(2,3,2)
    2. print("")
    3. print("torch.unsqueeze(a,3) size: {}".format(torch.unsqueeze(a,3).size()))
    4. print("")
    5. print("torch.unsqueeze(a,2) size: {}".format(torch.unsqueeze(a,2).size()))
    6. print("")
    7. print("torch.unsqueeze(a,1) size: {}".format(torch.unsqueeze(a,1).size()))
    8. print("")
    9. print("torch.unsqueeze(a,0) size: {}".format(torch.unsqueeze(a,0).size()))
    10. print("")
    11. print("torch.unsqueeze(a,-1) size: {}".format(torch.unsqueeze(a,-1).size()))
    12. print("")
    13. print("torch.unsqueeze(a,-2) size: {}".format(torch.unsqueeze(a,-2).size()))
    14. print("")
    15. print("torch.unsqueeze(a,-3) size: {}".format(torch.unsqueeze(a,-3).size()))
    16. print("")
    17. print("torch.unsqueeze(a,-4) size: {}".format(torch.unsqueeze(a,-4).size()))
    18. 输出:
    19. torch.unsqueeze(a,3) size: torch.Size([2, 3, 2, 1])
    20. torch.unsqueeze(a,2) size: torch.Size([2, 3, 1, 2])
    21. torch.unsqueeze(a,1) size: torch.Size([2, 1, 3, 2])
    22. torch.unsqueeze(a,0) size: torch.Size([1, 2, 3, 2])
    23. torch.unsqueeze(a,-1) size: torch.Size([2, 3, 2, 1])
    24. torch.unsqueeze(a,-2) size: torch.Size([2, 3, 1, 2])
    25. torch.unsqueeze(a,-3) size: torch.Size([2, 1, 3, 2])
    26. torch.unsqueeze(a,-4) size: torch.Size([1, 2, 3, 2])

    对于三维数据input.dim() = 3,因此dim的范围是[-4, 4)

  • 相关阅读:
    Hifiasm-meta | 你没看错!基于宏基因组的完成图!!
    使用Excel导入和导出数据
    华为机试真题 C++ 实现【We Are A Team】
    vivo手机如何隐藏应用 vivo手机隐藏应用方法
    互联网大厂女工抑郁症自救指南
    如何面向组件跨层级通信
    设置matplotlib绘图的y轴为百分比格式
    深度解析SpringBoot内嵌Web容器
    View绘制流程-Vsync信号是如何发送和接受的
    基于Vue+SpringBoot的大病保险管理系统 开源项目
  • 原文地址:https://blog.csdn.net/ym62033/article/details/137855194