• PyTorch: 计算图与动态图机制


    本文已收录于Pytorch系列专栏: Pytorch入门与实践 专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下坚实的基础。免费订阅,持续更新。

    计算图

    计算图是用来描述运算的有向无环图

    计算图有两个主要元素:

    • 结点 Node

    • 边 Edge

    结点表示数据:如向量,矩阵,张量

    边表示运算:如加减乘除卷积等

    用计算图表示:y = (x+ w) * (w+1)
    a = x + w
    b = w + 1
    y = a * b

    image-20221007140501247

    计算图与梯度求导

    y = (x+ w) * (w+1)
    a = x + w
    b = w + 1
    y = a * b

    image-20221007142316040

    ∂ y ∂ w = ∂ y ∂ a ∂ a ∂ w + ∂ y ∂ b ∂ b ∂ w = b ∗ 1 + a ∗ 1 = b + a = ( w + 1 ) + ( x + w ) = 2 ∗ w + x + 1 = 2 ∗ 1 + 2 + 1 = 5

    yw=yaaw+ybbw=b1+a1=b+a=(w+1)+(x+w)=2w+x+1=21+2+1=5" role="presentation" style="position: relative;">yw=yaaw+ybbw=b1+a1=b+a=(w+1)+(x+w)=2w+x+1=21+2+1=5
    wy=aywa+bywb=b1+a1=b+a=(w+1)+(x+w)=2w+x+1=21+2+1=5

    可见,对于变量w的求导过程就是寻找它在计算图中的所有路径的求导之和。

    code:

    import torch
    
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    
    a = torch.add(w, x)     # retain_grad()
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    
    y.backward()
    print(w.grad)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    tensor([5.])
    
    • 1

    计算图与梯度求导
    y = (x+ w) * (w+1)

    叶子结点 :用户创建的结点称为叶子结点,如 X 与 W

    is_leaf: 指示张量是否为叶子结点

    叶子节点的作用是标志存储叶子节点的梯度,而清除在反向传播过程中的变量的梯度,以达到节省内存的目的。

    当然,如果想要保存过程中变量的梯度值,可以采用retain_grad()

    grad_fn: 记录创建该张量时所用的方法(函数)

    • y.grad_fn=
    • a.grad_fn=
    • b.grad_fn=

    image-20221007142938198

    PyTorch的动态图机制

    根据计算图搭建方式,可将计算图分为动态图静态图

    • 动态图

      运算与搭建同时进行

      灵活 易调节

      例如动态图 PyTorch:

      image-20221007144304367

    • 静态

      先搭建图, 后运算

      高效 不灵活。

      静态图 TensorFlow

      image-20221007144319338

  • 相关阅读:
    ORA-22992 cannot use LOB locators selected from remote tables
    项目在linux上的简单部署
    Graphviz安装教程
    5 分钟,教你用 Docker 部署一个 Python 应用
    Windows 11家庭中文版安装安卓子系统
    好奇宝宝看 Docker 底层原理(上)
    maven 的多镜像次序生效问题
    Oracle 19c 可插拔数据库PDB的创建方式
    C#高级--XML详解
    py8_Python 类和对象最通俗易懂的解释
  • 原文地址:https://blog.csdn.net/m0_52316372/article/details/127609246