• Pytorch基础:Tensor的transpose方法


    相关阅读 

    Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


             在Pytorch中,transpose是Tensor类的一个重要方法,同时它也是torch模块中的一个函数,它们的语法如下所示。

    1. Tensor.transpose(dim0, dim1) → Tensor
    2. torch.transpose(input, dim0, dim1) → Tensor
    3. input (Tensor) – the input tensor.
    4. dim0 (int) – the first dimension to be transposed
    5. dim1 (int) – the second dimension to be transposed

    官方的解释如下:

            返回一个张量,它是输入张量的转置版本,其中将给定的维度dim0和dim1交换。

            如果输入是一个具有步幅的张量(常规稠密张量),那么输出张量将与输入张量共享底层存储,因此改变一个张量的内容将改变另一个张量的内容。

            如果输入是一个稀疏张量,那么输出张量不与输入张量共享底层存储。

            如果输入是压缩布局的稀疏张量(SparseCSR, SparseBSR, SparseCSC或SparseBSC),参数dim0和dim1必须同时是批处理维度,或者必须同时是稀疏维度(稀疏张量的批处理维度是稀疏维之前的维度)。

            在详细说明之前,我们需要明确tensor的形状相关的概念,对于一个tensor来说,它的维度数和维度大小是两个概念,维度数是从0开始依次递增的,分别称为第0维、第1维...,每一个维度又有自己的大小。官方解释中的dim0和dim1指的是维度数。

    1. tensor([[1, 2, 3],
    2. [4, 5, 6]])
    3. 这个张量有2个维度,分别是0维和1维,第0维大小是2,第1维大小是3
    1. import torch
    2. # 创建一个张量
    3. x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    4. # 使用transpose操作
    5. y = x.transpose(0, 1)
    6. # 等价于y = x.transpose(1, 0)
    7. print(x, y)
    8. tensor([[1, 2, 3],
    9. [4, 5, 6]])
    10. tensor([[1, 4],
    11. [2, 5],
    12. [3, 6]])
    13. print(id(x),id(y))
    14. 4347791776 5006922112 # 说明两个张量对象不同
    15. print(x.storage().data_ptr(), y.storage().data_ptr())
    16. 5207345856 5207345856 # 说明两个张量对象里面保存的数据存储是共享的
    17. print(id(x[0,0]), id(y[0,0]))
    18. 4961223520 4961223520 # 进一步说明两个张量对象里面保存的数据存储是共享的
    19. y[0, 0] = 7
    20. print(x, y)
    21. tensor([[7, 2, 3],
    22. [4, 5, 6]])
    23. tensor([[7, 4],
    24. [2, 5],
    25. [3, 6]]) # 说明对新tensor的更改影响了原tensor
    26. print(x.is_contiguous(), y.is_contiguous())
    27. True False # 说明x是连续的,y不是连续的

            以上的内容,类似于之前在关于python中列表的浅拷贝中说到的那样,对新列表内部嵌套的列表中的元素的更改会影响原列表。如下所示。   列表的浅拷贝

    1. import copy
    2. my_list = [1, 2, [1, 2]]
    3. your_list = list(my_list) #工厂函数
    4. his_list = my_list[:] #切片操作
    5. her_list = copy.copy(my_list) #copy模块的copy函数
    6. your_list[2][0] = 3
    7. print(my_list)
    8. print(your_list)
    9. print(his_list)
    10. print(her_list)
    11. his_list[2][1] = 4
    12. print(my_list)
    13. print(your_list)
    14. print(his_list)
    15. print(her_list)
    16. her_list[2].append(5)
    17. print(my_list)
    18. print(your_list)
    19. print(his_list)
    20. print(her_list)
    21. 输出
    22. [1, 2, [3, 2]]
    23. [1, 2, [3, 2]]
    24. [1, 2, [3, 2]]
    25. [1, 2, [3, 2]]
    26. [1, 2, [3, 4]]
    27. [1, 2, [3, 4]]
    28. [1, 2, [3, 4]]
    29. [1, 2, [3, 4]]
    30. [1, 2, [3, 4, 5]]
    31. [1, 2, [3, 4, 5]]
    32. [1, 2, [3, 4, 5]]
    33. [1, 2, [3, 4, 5]]

            但不一样的是,在这里甚至对tensor中非嵌套的内容的修改,也会导致另一个tensor受到影响,如下所示。

    1. import torch
    2. # 创建一个张量
    3. x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    4. # 使用transpose操作
    5. y = x.transpose(0, 1)
    6. # 等价于y = x.transpose(1, 0)
    7. print(x, y)
    8. tensor([[1, 2, 3],
    9. [4, 5, 6]])
    10. tensor([[1, 4],
    11. [2, 5],
    12. [3, 6]])
    13. x[0] = torch.tensor[4, 4, 4] # 改变其中一个tensor的第0个元素
    14. print(x, y)
    15. tensor([[4, 4, 4],
    16. [4, 5, 6]])
    17. tensor([[4, 4],
    18. [4, 5],
    19. [4, 6]])

            在pytorch中,和transpose方法类似的还有方法permute,它比transpose功能更加强大,详细细节可以看下面的文章。

    Pytorch基础:Tensor的permute方法icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/article/details/133612401?spm=1001.2014.3001.5502

  • 相关阅读:
    mybatis缓存介绍
    Hbase 之KeyValue结构详解
    杭电多校-Shortest Path in GCD Graph-(二进制容斥+优化)
    ArrayList源码解析
    spring boot+redis 的快速入门
    centos oracle11g开启归档模式
    Caused by: java.lang.ClassNotFoundException: freemarker.template.Configuration
    【软考-中级】系统集成项目管理工程师-风险管理历年案例
    服务安全-应用协议rsync未授权&ssh漏洞复现
    Microolap DAC for MySQL驱动程序或其他库
  • 原文地址:https://blog.csdn.net/weixin_45791458/article/details/133470992