• torch.as_tensor()、torch.Tensor() 、 torch.tensor() 、transforms.ToTensor()的区别


    1)torch.as_tensor(data, dtype=None,device=None)->Tensor : 为data生成tensor,保留 autograd 历史记录并尽量避免复制(dtype和devices相同,尽量浅拷贝

    如果data已经是tensor,且dtype和device与参数相同,则生成的tensor会和data共享内存(浅拷贝)。如果data是ndarray,且dtype对应,devices为cpu,则同样共享内存。其他情况则不共享内存。

    1. #1)数据类型和device相同,浅拷贝,共享内存
    2. import numpy
    3. a = numpy.array([1, 2, 3])
    4. t = torch.as_tensor(a)
    5. t[0] = -1
    6. a,t
    7. #Out[77]: (array([-1, 2, 3]), tensor([-1, 2, 3], dtype=torch.int32))
    8. #2)数据类型相同,但是device不同,深拷贝,不再共享内存
    9. import numpy
    10. a = numpy.array([1, 2, 3])
    11. t = torch.as_tensor(a, device=torch.device('cuda'))
    12. t[0] = -1
    13. a,t
    14. #Out[78]: (array([1, 2, 3]), tensor([-1, 2, 3], device='cuda:0', dtype=torch.int32))
    15. #3)device相同,但数据类型不同,深拷贝,不再共享内存
    16. import numpy
    17. a = numpy.array([1, 2, 3])
    18. t = torch.as_tensor(a, dtype=torch.float32)
    19. t[0] = -1
    20. a,t
    21. Out[80]: (array([1, 2, 3]), tensor([-1., 2., 3.]))

    2) torch.tensor() 是一个通过深拷贝数据,构造一个新张量的函数

    torch.tensor(data*dtype=Nonedevice=Nonerequires_grad=Falsepin_memory=False) →Tensor

    深拷贝数据数据类型和device,with no autograd history (also known as a “leaf tensor”)。

    重点是data的数据类型can be a list, tuple, NumPy ndarray, scalar, and other types,就没waring。
    但data是tensor类型,使用torch.tensor(data)就会报waring::7: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

    #没警告:data can be a list, tuple, NumPy ndarray, scalar, and other types

    1. #没警告:data can be a list, tuple, NumPy ndarray, scalar, and other types
    2. import torch
    3. import numpy
    4. a = numpy.array([1, 2, 3])
    5. t = torch.tensor(a)
    6. b = [1,2,3]
    7. t= torch.tensor(b)
    8. c = (1,2,3)
    9. t= torch.tensor(c)

    #data是tensor类型,有警告 

    1. #data是tensor类型,有警告
    2. import torch
    3. import numpy
    4. d = torch.tensor([[1,2,3],[1,2,3]])
    5. t= torch.tensor(d) #能深拷贝,但会报warning,建议用t = a.clone().detach()
    6. # detach是内存共享的,而clone()是不内存共享的。
    7. print(d.shape,d.dtype,t.shape,t.dtype)
    8. # torch.Size([2, 3]) torch.int64 torch.Size([2, 3]) torch.int64

     

     3)torch.Tensor() 是默认张量类型 (torch.FloatTensor()) 的别名。也就是说,torch.Tensor() 的作用实际上跟 torch.FloatTensor() 一样,都是生成一个数据类型为 32 位浮点数的张量,如果没传入数据就返回空张量,如果有列表或者 narray 的返回其对应张量。但无论传入数据本身的数据类型是什么,返回的都是 32 位浮点数的张量。

    1. >>> torch.Tensor()
    2. tensor([])
    3. >>> torch.Tensor().dtype
    4. torch.float32
    5. >>> torch.FloatTensor()
    6. tensor([])
    7. >>> torch.FloatTensor().dtype
    8. torch.float32

    4)transforms.ToTensor()

    ToTensor()将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单。

    1. # 归一化到(0,1)之后,再数据标准化处理 (x-mean)/std,归一化到(-1,1),数据中存在大于mean和小于mean
    2. transform2 = transforms.Compose([
    3. transforms.ToTensor(),
    4. transforms.Normalize(std=(0.5,0.5,0.5),
    5. mean=(0.5,0.5,0.5))])

    在transforms.Compose([transforms.ToTensor()])中加入transforms.Normalize(),如上面代码所示:则其作用就是先将输入归一化到(0,1),再使用公式"(x-mean)/std",将每个元素分布到(-1,1)。

  • 相关阅读:
    openmp 超越通用核心
    安卓Java面试题11-20
    R语言ggplot2可视化:使用ggpubr包的ggbarplot函数可视化柱状图、fill参数指定柱状图的填充色
    Unity编辑器扩展(一)编辑器扩展基础
    [补题记录] Atcoder Beginner Contest 299(E)
    【密码学篇】数字签名基础知识(无保密性)
    R语言ggplot2可视化:基于aes函数中的fill参数和shape参数自定义绘制分组折线图并添加数据点(散点)、设置可视化图像的主题为theme_bw
    @Valid与@Validated区别和详细使用及参数注解校验大全
    阿里巴巴面试题- - -多线程&并发篇(二十三)
    Java Object类详解
  • 原文地址:https://blog.csdn.net/qimo601/article/details/128014195