• NNDL 作业10:第六章课后题(LSTM | GRU)


    习题6-3 当使用公式(6.50)作为循环神经网络得状态更新公式时,分析其可能存在梯度爆炸的原因并给出解决办法.

    令Zk= Uhk-1 + Wxk + b 为在第 k 时刻函数g(.) 的输入,在计算公式6.34中的误差项在这里插入图片描述

    时,梯度可能过大,从而导致梯度过大问题。

    解决方法:增加门控机制,例如:使用长短期记忆神经网络。

    习题6-4 推导LSTM网络中参数的梯度,并分析其避免梯度消失的效果

    在这里插入图片描述
    LSTM中通过门控机制解决梯度问题,遗忘门,输入门和输出门是非0就是1的,并且三者之间都是相加关系,梯度能够很好的在LSTM传递,减轻了梯度消失发生的概率,门为0时,上一刻的信息对当前时刻无影响,没必要接受传递更新参数了。

    习题6-5 推导GRU网络中参数的梯度,并分析其避免梯度消失的效果

    GRU结构图:
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    GRU具有调节信息流动的门单元,但没有一个单独的记忆单元,GRU将输入门和遗忘门整合成一个升级门,通过门来控制梯度。这种门控的方式,让网络学会如何设置门控数值,来决定何时让梯度消失,何时保持梯度。

    附加题 6-1P 什么时候应该用GRU? 什么时候用LSTM?

    • GRU和LSTM的区别:
      • LSTM利用输出门(output gate)可以选择性的使用细胞状态(cell state),而GRU总是不加选择的使用细胞状态
      • LSTM利用更新门(update gate)可以独立控制加入多少新的“记忆”,与老“记忆”无关,而GRU对新“记忆”的加入会受老“记忆”的约束,老“记忆”留存越多新“记忆”加入越少。
      • GRU的优点是其模型的简单性 ,因此更适用于构建较大的网络。它只有两个门控,从计算角度看,它的效率更高,它的可扩展性有利于构筑较大的模型;但是LSTM更加的强大和灵活,因为它具有三个门控。LSTM是经过历史检验的方法。

    附加题 6-2P LSTM BP推导,并用Numpy实现

    首先求出它们在t时刻的梯度,然后再求出他们最终的梯度。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    import numpy as np
    def sigmoid(x):
        return 1/(1+np.exp(-x))
    
    def softmax(x):
        e_x = np.exp(x-np.max(x))# 防溢出
        return e_x/e_x.sum(axis=0)
    
    
    def LSTM_CELL_Forward(xt, h_prev, C_prev, parameters):
       
        # 获取参数字典中各个参数
        Wf = parameters["Wf"]
        bf = parameters["bf"]
        Wi = parameters["Wi"]
        bi = parameters["bi"]
        Wc = parameters["Wc"]
        bc = parameters["bc"]
        Wo = parameters["Wo"]
        bo = parameters["bo"]
        Wy = parameters["Wy"]
        by = parameters["by"]
    
        # 获取 xt 和 Wy 的维度参数
        n_x, m = xt.shape
        n_y, n_h = Wy.shape
    
        # 拼接 h_prev 和 xt
        concat = np.zeros((n_x + n_h, m))
        concat[: n_h, :] = h_prev
        concat[n_h:, :] = xt
    
        # 计算遗忘门、输入门、记忆细胞候选值、下一时间步的记忆细胞、输出门和下一时间步的隐状态值
        ft = sigmoid(np.dot(Wf, concat) + bf)
        it = sigmoid(np.dot(Wi, concat) + bi)
        cct = np.tanh(np.dot(Wc, concat) + bc)
        c_next = ft * C_prev + it * cct
        ot = sigmoid(np.dot(Wo, concat) + bo)
        h_next = ot * np.tanh(c_next)
    
        # LSTM单元的计算预测
        yt_pred = softmax(np.dot(Wy, h_next) + by)
    
        return h_next, c_next, yt_pred
    np.random.seed(1)
    xt = np.random.randn(3,10)
    h_prev = np.random.randn(5,10)
    c_prev = np.random.randn(5,10)
    Wf = np.random.randn(5, 5+3)
    bf = np.random.randn(5,1)
    Wi = np.random.randn(5, 5+3)
    bi = np.random.randn(5,1)
    Wo = np.random.randn(5, 5+3)
    bo = np.random.randn(5,1)
    Wc = np.random.randn(5, 5+3)
    bc = np.random.randn(5,1)
    Wy = np.random.randn(2,5)
    by = np.random.randn(2,1)
    
    parameters = {"Wf": Wf, "Wi": Wi, "Wo": Wo, "Wc": Wc, "Wy": Wy, "bf": bf, "bi": bi, "bo": bo, "bc": bc, "by": by}
    
    h_next, c_next, yt = LSTM_CELL_Forward(xt, h_prev, c_prev, parameters)
    print("a_next[4] = ", h_next[4])
    print("a_next.shape = ", c_next.shape)
    print("c_next[2] = ", c_next[2])
    print("c_next.shape = ", c_next.shape)
    print("yt[1] =", yt[1])
    print("yt.shape = ", yt.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68

    运行结果:
    在这里插入图片描述

    参考

    deeplearning.ai - 网易云课堂 (163.com)

  • 相关阅读:
    怎么将PDF文件转换成图片呢?
    管理团队相关的梳理
    95-Java的对象序列化、反序列化
    Git 的介绍、安装及其基本操作
    OS2.3.2:进程互斥的软件实现方法
    Python实现机器学习(下)— 数据预处理、模型训练和模型评估
    JavaScript查找最长的公共前缀
    leetcode 32. 最长有效括号
    jni场景下c++代码种,编写jstring转char*
    Linux常用命令——clockdiff命令
  • 原文地址:https://blog.csdn.net/qq_51713698/article/details/128059972