• PyTorch概述(八)---ADAM


    1. import torch
    2. torch.optim.Adam(params,
    3. lr=0.001,
    4. betas=(0.9,0.999),
    5. eps=1e-8,
    6. weight_decay=0,
    7. amsgrad=False,
    8. *,
    9. foreach=None,
    10. maximize=False,
    11. capturable=False,
    12. differentiable=False,
    13. fused=None)
    • 类torch.optim.Adam实现Adam算法

    Adam算法描述

    • 输入:\gamma(lr),\beta_1,\beta_2(betas),\theta_0(params),f(\theta)(objective),\lambda(weight decay),
    • amsgrad,maximize
    • 初始化:m0\leftarrow 0(first- moment),v0\leftarrow0(second-moment),\widehat{v0}^{max}\leftarrow 0
    • for t=1 to ...do
      • if maximize:
        • g_t\leftarrow -\bigtriangledown_{\theta}f_t(\theta_{t-1})
      • else
        • g_t\leftarrow \bigtriangledown_{\theta}f_t(\theta_{t-1})
      • if \lambda \neq 0
        • g_t \leftarrow g_t+\lambda \theta_{t-1}
      • m_t \leftarrow \beta_1m_{t-1}+(1-\beta_1)g_t
      • v_t\leftarrow\beta_2v_{t-1}+(1-\beta_2)g_t^2
      • \widehat{m_t}\leftarrow m_t/(1-\beta_1^t)
      • \widehat{v_t}\leftarrow v_t/(1-\beta_2^t)
      • if amsgrad
        • \widehat{v_t}^{max}\leftarrow max(\widehat{v_t}^{max},\widehat{v_t})
        • \theta_t\leftarrow \theta_{t-1}-\gamma\widehat{m_t}/(\sqrt{\widehat{v_t}^{max}}+\epsilon)
      • else
        • \theta_t\leftarrow \theta_{t-1}-\gamma\widehat{m_t}/(\sqrt{\widehat{v_t}}+\epsilon)
    • return \theta_t

    参数

    • params(iterable)-- 可迭代的优化参数或者定义参数组的字典;
    • lr(float,Tensor,optional)---学习率(默认1e-3),张量LR还没有被所有的算法实现所支持;如果没有指定fused为True或者capturable为True的情况下请使用一个浮点型LR;
    • betas(Tuple[float,float],optional)---用于计算运行中的梯度均值和他的平方的系数,默认为(0.9,0.999);
    • eps(float,optional)---加和到分母上的项以提高数值稳定性(默认为1e-8);
    • weight_decay(float,optional)---权重衰减(L2惩罚)(默认0);
    • amsgrad(bool,optional)---该算法是否使用AMSGrad变量;
    • foreach(bool,optional)---是否使用foreach实现的优化器;如果未设置,在CUDA上将使用foreach的实现而不是for-loop的实现,因为foreach的实现具有更优化的性能;注意由于张量列表而不是张量的原因,foreach的实现较for-loop实现使用更多的峰值内存;如果内存被限制,优化器一次批处理更少的参数或者将此项设置为False(默认为None);
    • maximize(bool,optional)---最大化相对于参数的目标而不是最小化(默认:False);
    • capturable(bool,optional)---在CUDA图中捕捉此实例是否安全,设置为True可以损坏未绘图的性能,所以如果不打算图形捕捉实例,将其设置为False(默认为False);
    • differentiable(bool,optional)---训练中,优化器迭代步中是否自动梯度,如果不使用,在上下文中step()函数以torch.no_grad()运行;设置为True可能损害性能,因此在训练中如果不想使用自动梯度,将其设置为False(默认False);
    • fused(bool,optional)---是否使用fused实现(仅支持CUDA),当前torch.float64,torch.float32,torch.float16,torch.bfloat16被支持(默认None);
  • 相关阅读:
    如何设计一套单点登录系统 ?
    网络安全进阶学习第二十一课——XXE
    MySQL 快速入门之第一章 账号管理、建库以及四大引擎
    8、Mybatis-Plus 分页插件、自定义分页
    Android安全专题-so逆向入门和使用ida动态调试
    Golang | Web开发之Gin框架快速入门基础实践
    jvisualvm 远程连接 jvm
    vuepress+gitee免费搭建个人在线博客(无保留版)
    搭建zabbix监控及邮件报警(超详细教学)
    如何使用baostock代码下载股票数据?
  • 原文地址:https://blog.csdn.net/newsymme/article/details/136318090