• torch.Tensor详解


    torch.Tensor详解

    参考:

    torch.Tensor — PyTorch 1.12 documentation

    提供的数据类型:

    Data typedtypeCPU tensorGPU tensor
    32-bit floating pointtorch.float32 or torch.floattorch.FloatTensortorch.cuda.FloatTensor
    64-bit floating pointtorch.float64 or torch.doubletorch.DoubleTensortorch.cuda.DoubleTensor
    16-bit floating point [1]torch.float16 or torch.halftorch.HalfTensortorch.cuda.HalfTensor
    16-bit floating point [2]torch.bfloat16torch.BFloat16Tensortorch.cuda.BFloat16Tensor
    32-bit complextorch.complex32 or torch.chalf
    64-bit complextorch.complex64 or torch.cfloat
    128-bit complextorch.complex128 or torch.cdouble
    8-bit integer (unsigned)torch.uint8torch.ByteTensortorch.cuda.ByteTensor
    8-bit integer (signed)torch.int8torch.CharTensortorch.cuda.CharTensor
    16-bit integer (signed)torch.int16 or torch.shorttorch.ShortTensortorch.cuda.ShortTensor
    32-bit integer (signed)torch.int32 or torch.inttorch.IntTensortorch.cuda.IntTensor
    64-bit integer (signed)torch.int64 or torch.longtorch.LongTensortorch.cuda.LongTensor
    Booleantorch.booltorch.BoolTensortorch.cuda.BoolTensor
    quantized 8-bit integer (unsigned)torch.quint8torch.ByteTensor/
    quantized 8-bit integer (signed)torch.qint8torch.CharTensor/
    quantized 32-bit integer (signed)torch.qint32torch.IntTensor/
    quantized 4-bit integer (unsigned)torch.quint4x2torch.ByteTensor/

    除了编码常见的类型,还有几种不常见的类型:

    1、16-bit floating point[1]:使用 1 个符号、5 个指数和 10 个有效位。 当精度很重要以牺牲范围为代价时很有用。

    2、16-bit floating point[2]:使用 1 个符号、8 个指数和 7 个有效位。 当范围很重要时很有用,因为它具有与 float32 相同数量的指数位

    3、quantized 4-bit integer (unsigned):量化的 4 位整数存储为 8 位有符号整数。 目前仅在 EmbeddingBag 运算符中支持。

    不指定类型的话,默认是:torch.FloatTensor

    初始化:

    1、使用列表或者序列:

    import torch
    
    torch.tensor([[1., -1.], [1., -1.]])
    torch.tensor(np.array([[1, 2, 3], [4, 5, 6]]))
    
    • 1
    • 2
    • 3
    • 4

    注意:torch.tensor会拷贝数据,如果在改变requires_grad时避免拷贝数据,需要使用requires_grad_() or detach()。如果在使用numpy array初始化想避免拷贝,需要使用torch.as_tensor()

    2、设置类型和设备
    >>> torch.zeros([2, 4], dtype=torch.int32)
    tensor([[ 0,  0,  0,  0],
            [ 0,  0,  0,  0]], dtype=torch.int32)
    >>> cuda0 = torch.device('cuda:0')
    >>> torch.ones([2, 4], dtype=torch.float64, device=cuda0)
    tensor([[ 1.0000,  1.0000,  1.0000,  1.0000],
            [ 1.0000,  1.0000,  1.0000,  1.0000]], dtype=torch.float64, device='cuda:0')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    3、可以使用 Python 的索引和切片符号访问和修改张量的内容
    >>> x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    >>> print(x[1][2])
    tensor(6)
    >>> x[0][1] = 8
    >>> print(x)
    tensor([[ 1,  8,  3],
            [ 4,  5,  6]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    4、使用 torch.Tensor.item() 从包含单个值的张量中获取 Python 数字:
    >>> x = torch.tensor([[1]])
    >>> x
    tensor([[ 1]])
    >>> x.item()
    1
    >>> x = torch.tensor(2.5)
    >>> x
    tensor(2.5000)
    >>> x.item()
    2.5
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    5、requires_grad=True自动梯度计算
    >>> x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
    >>> out = x.pow(2).sum()
    >>> out.backward()
    >>> x.grad
    tensor([[ 2.0000, -2.0000],
            [ 2.0000,  2.0000]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
  • 相关阅读:
    python小项目之利用pygame实现代码雨动画效果(附源码 可供学习)
    尚硅谷Vue3入门到实战,最新版vue3+TypeScript前端开发教程
    网络方向知识点梳理
    基于双层共识控制的直流微电网优化调度(Matlab代码实现)
    【linux】从linux学软件开发 | 变量与echo
    Perl脚本获取.bash_profile中变量
    Shell三剑客之sed命令详解
    python不同版本常用功能差异
    王杰qtday2
    数据分析 第一周 折线图笔记
  • 原文地址:https://blog.csdn.net/KPer_Yang/article/details/126295538