• 【PyTorch】详细总结-如何创建和初始化Pytoch张量 (2022年最新)



    1. 什么是张量

    张量 (tensor) 是一种数据结构,将始终贯穿PyTorch的全过程。

    • 向量数据的延申方向是一条线,称为一维张量。描述向量元素的位置和形状只需要用一个数就可以,例如[x];
    • 矩阵数据的延申方式是一个平面,称为二维张量。描述矩阵元素的位置和形状需要用两个数,例如[x, y];
    • 而张量数据延申方向是三个或三个以上,称为多维张量。如三维张量的延申方向是立体,描述三维张量的位置和形状需要用三个数,例如[x, y, z];

    2. 创建张量

    在Pytorch中,创建张量主要有以下四种方式:


    2.1 直接生成张量

    先创建一个二维数组 (即矩阵) ,再把这个二维数组作为参数,传入函数 torch.tensor() 中,就会直接生成 torch.Tensor 张量。

    会自动推断张量的数据类型 (在深度学习中,数据类型绝大多数是浮点型,其次是整形) 。

    import torch
    
    data = [[1, 2], [3, 4]]
    print("转换前的数据类型:", type(data))
    x_data = torch.tensor(data)
    print("转换后的数据类型:", type(x_data))
    print(x_data.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    转换前的数据类型: <class 'list'>
    转换后的数据类型: <class 'torch.Tensor'>
    torch.Size([2, 2])
    
    • 1
    • 2
    • 3

    2.2 通过Numpy arrays创建张量

    torch tensor 和 Numpy arrays之间是可以互相转化的。

    函数作用
    torch.from_numpy(ndarray)把Numpy数组转化为torch张量
    已经创建好的张量.numpy()把torch张量转化为Numpy数组

    举个例子:


    1.Numpy数组 –> torch张量

    import torch
    import numpy as np
    
    data = [[1, 2], [3, 4]]
    np_array = np.array(data)
    print("转换前Numpy数组的数据类型:", type(np_array))
    x_np = torch.from_numpy(np_array)
    print("转换后torch tensor的数据类型:", type(x_np))
    print("张量的形状:", x_np.shape)
    print(x_np)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    输出:

    转换前Numpy数组的数据类型: <class 'numpy.ndarray'>
    转换后torch tensor的数据类型: <class 'torch.Tensor'>
    张量的形状: torch.Size([2, 2])
    tensor([[1, 2],
            [3, 4]], dtype=torch.int32)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    2.torch张量 –> Numpy数组

    import torch
    import numpy as np
    
    tensor = torch.tensor([[1, 2], [3, 4]])
    print("转化前类型:", type(tensor))
    np_array = tensor.numpy()
    print("转化后类型:", type(np_array))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    转化前类型: <class 'torch.Tensor'>
    转化后类型: <class 'numpy.ndarray'>
    
    • 1
    • 2

    2.3 通过已有的张量创建新的张量

    在Pytorch中,可以通过已有的张量,“复制”出另一个形状相同的新张量。其中,新张量的数据类型可以继承原来的张量,也可以指定为新的数据类型。但是两个张量的形状一定是相同的。常用的函数主要有以下两个:

    函数作用
    torch.ones_like(torch.Tensor x)创建一个和张量x形状相同,但值全为1的新张量
    torch.rand_like(torch.Tensor x, dtype=torch.float)创建一个和张量x形状相同,但值全为0~1随机数的新张量,且数据类型改为 torch.float

    举个栗子:

    import torch
    
    tensor1 = torch.tensor([[1, 2], [3, 4]])
    tensor2 = torch.ones_like(tensor1)  # 保留tensor1的所有属性
    print(f"Ones Tensor:\n {tensor2} \n")   # 这种写法类似于C中的printf()
    tensor3 = torch.rand_like(tensor1, dtype=torch.float)   # 重新指定tensor1的数据类型
    print(f"Random Tensor:\n {tensor3} \n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    Ones Tensor:
     tensor([[1, 1],
            [1, 1]]) 
    
    Random Tensor:
     tensor([[0.8884, 0.5649],
            [0.5172, 0.0832]]) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    2.4 根据张量维度创建张量

    可以通过指定张量的维度 (即张量形状) shape ,来创建以下三种张量:

    函数作用
    torch.ones(shape)创建维度为shape的全1张量
    torch.rand(shape)创建维度为shape的随机值 (范围0~1) 张量
    torch.zeros(shape)创建维度为shape的全0张量

    举个栗子:

    import torch
    
    shape = (2, 3)  # 声明张量维度
    ones_tensor = torch.ones(shape)
    rand_tensor = torch.rand(shape)
    zeros_tensor = torch.zeros(shape)
    
    print(f"Ones Tensor:\n {ones_tensor} \n")
    print(f"Random Tensor:\n {rand_tensor} \n")
    print(f"Zeros Tensor:\n {zeros_tensor} \n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    输出:

    Ones Tensor:
     tensor([[1., 1., 1.],
            [1., 1., 1.]]) 
    
    Random Tensor:
     tensor([[0.5321, 0.1629, 0.1860],
            [0.3626, 0.5045, 0.7865]]) 
    
    Zeros Tensor:
     tensor([[0., 0., 0.],
            [0., 0., 0.]]) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
  • 相关阅读:
    C# redis通过stream实现消息队列以及ack机制
    使用FFmpeg+ubuntu系统转化flac无损音频为mp3
    【SpringBoot】怎么在一个大的SpringBoot项目中创建多个小的SpringBoot项目,从而形成子父依赖
    联想笔记本电脑触摸板失灵了怎么办
    电脑技巧:Win10粘贴文件到C盘提示没有权限的解决方法
    【每日一题Day35】LC878第N个神奇数字 | 二分查找 找规律 + 数学
    明白这3个规则,行走职场简直没有难度
    netty Recycler对象池
    maven的pom.xml文件爆红,并且刷新maven无法下载依赖的解决方案
    [ 笔记 ] 计算机网络安全_4_网络扫描和网络监听
  • 原文地址:https://blog.csdn.net/Sihang_Xie/article/details/125628576