• 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别


    一、前言

    基本的RNN存在梯度消失和梯度爆炸问题,会忘记它在较长序列中以前看到的内容,只具有短时记忆。得到比较广泛应用的是LSTM(Long Short Term Memory)——长短期记忆网络,它在一定程度上解决了这两个问题。

    二、网络结构

    我们来看一下LSTM网络的结构图:
    在这里插入图片描述
    咱们放大看看,由于网上找不到清晰版的示例图,亲绘了一幅:
    在这里插入图片描述
    LSTM包含遗忘门、输入门、输出门。分别用于LSTM的三个步骤:旧记忆的遗忘、新记忆的输入、最终结果的输出。

    三、可解释性

    为什么要这么设计LSTM网络呢?我们打个比方:

    小明上次考了数学,留下的大部分是数学的知识记忆 C t − 1 C_{t-1} Ct1;这次考生物,一些数学知识用不到,部分复杂的公式自然而然地被遗忘 f t ⊙ C t − 1 f_t\odot{C}_{t-1} ftCt1;复习生物知识一本书 C ~ t \tilde{C}_t C~t,大概记得八成 i t ⊙ C ~ t i_t\odot\tilde{C}_t itC~t,那么当前的记忆 C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t=f_t\odot{C}_{t-1}+i_t\odot\tilde{C}_t Ct=ftCt1+itC~t;考试时,成绩受到考题和当前记忆 C t C_t Ct的影响 h t = O t ⊙ tanh ⁡ C t h_t=O_t\odot\tanh{C_t} ht=OttanhCt

    注: ⊙ \odot 是矩阵的点乘符号,即两个矩阵对应元素相乘

    四、记忆主线

    在这里插入图片描述
    如上图所示,原有记忆是 C t − 1 C_{t-1} Ct1,经过遗忘(用矩阵参数进行点乘)、添加新记忆(加上新的记忆矩阵),当前最新的记忆就变成了 C t C_{t} Ct,如此循环,不重要的记忆就会忘记、重要的记忆就会一直流传下去。

    五、遗忘门

    第一步,我们会遗忘部分原有的记忆。
    在这里插入图片描述
    如上图所示, f t = σ ( W x f x t + W h f h t − 1 + b f ) f_t=\sigma(W_{xf}x_t+W_{hf} h_{t-1}+b_f) ft=σ(Wxfxt+Whfht1+bf)
    σ \sigma σ代表sigmoid函数。原有记忆是 C t − 1 C_{t-1} Ct1,遗忘后为 f t ⊙ C t − 1 f_t\odot{C}_{t-1} ftCt1

    六、输入门

    第二步,我们会新增部分新的记忆。我们要确定,哪些新信息要保留到记忆细胞里。
    在这里插入图片描述
    如上图所示,
    C ~ t = t a n h ( W x c x t + W h c h t − 1 + b c ) i t = σ ( W x i x t + W h i h t − 1 + b i ) ˜Ct=tanh(Wxcxt+Whcht1+bc)it=σ(Wxixt+Whiht1+bi)

    C~tit=tanh(Wxcxt+Whcht1+bc)=σ(Wxixt+Whiht1+bi)
    C~tit=tanh(Wxcxt+Whcht1+bc)=σ(Wxixt+Whiht1+bi)

    C ~ t \tilde{C}_t C~t表示所有的输入信息,但我们不是所有的都记得, i t i_t it控制记忆程度, i t ⊙ C ~ t i_t\odot\tilde{C}_t itC~t是本次输入所记住的信息。
    遗忘后的记忆是 f t ⊙ C t − 1 f_t\odot{C}_{t-1} ftCt1,输入新的记忆后, C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t=f_t\odot{C}_{t-1}+i_t\odot\tilde{C}_t Ct=ftCt1+itC~t

    七、输出门

    第三步,我们要根据现有记忆 C t C_t Ct,输出我们需要的内容。
    在这里插入图片描述
    如上图所示,
    O t = σ ( W x o x t + W h o h t − 1 + b o ) h t = O t ⊙ tanh ⁡ ( C t ) Ot=σ(Wxoxt+Whoht1+bo)ht=Ottanh(Ct)

    Otht=σ(Wxoxt+Whoht1+bo)=Ottanh(Ct)
    Otht=σ(Wxoxt+Whoht1+bo)=Ottanh(Ct)

    这就是LSTM网络的思想原理,接下来我们将用于手写数字识别实战。

    八、手写数字识别实战

    8.1 引入依赖库

    import torch
    import torch.nn as nn
    from torchvision import datasets,transforms
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    
    • 1
    • 2
    • 3
    • 4
    • 5

    8.2 加载数据

    train_data = datasets.MNIST(root="./data",train=True,transform=transforms.ToTensor(),download=False)
    batch_size=64
    
    train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
    
    test_data = datasets.MNIST(root="./data",train=False,transform=transforms.ToTensor(),download=False)
    test_x = test_data.data.type(torch.FloatTensor)[:2000]/255.   #取2000个样本数据并将其缩放为0~1范围
    test_y = test_data.targets[:2000]
    
    print(train_data.data.shape)
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    torch.Size([60000, 28, 28])
    
    • 1

    8.3 迭代训练

    
    #迭代次数
    epochs=1
    
    #学习率
    learning_rate=0.01
    
    plt_epoch=[]
    plt_loss=[]
    
    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.rnn = nn.LSTM(     # LSTM 效果要比 nn.RNN() 好多了
                input_size=28,      # 图片每行的数据像素点
                hidden_size=64,     # rnn hidden unit
                num_layers=1,       # 有几层 RNN layers
                batch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
            )
    
            self.out = nn.Linear(64, 10)    # 输出层
    
        def forward(self, x):
            # x shape (batch, time_step, input_size)
            # r_out shape (batch, time_step, output_size)
            # h_n shape (n_layers, batch, hidden_size)   LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
            # h_c shape (n_layers, batch, hidden_size)
            r_out, (h_n, h_c) = self.rnn(x, None)   # None 表示 hidden state 会用全0的 state
    
            # 选取最后一个时间点的 r_out 输出
            # 这里 r_out[:, -1, :] 的值也是 h_n 的值
            out = self.out(r_out[:, -1, :])
            return out
    
    model = MyModel()
    
    #损失函数
    cost=nn.CrossEntropyLoss()
    #迭代优化器
    optmizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
    
    for epoch in range(epochs):
    
        for step, (images, labels) in enumerate(train_loader):
    
            images=images.view(-1,28,28)
    
            #预测结果
            output=model(images) #调用__call__函数
    
            #计算损失值
            loss=cost(output,labels)
    
            #在反向传播前先把梯度清零
            optmizer.zero_grad()
    
            #反向传播,计算各参数对于损失loss的梯度
            loss.backward()
    
            #根据刚刚反向传播得到的梯度更新模型参数
            optmizer.step()
        
            plt_epoch.append(step+1)
            plt_loss.append(loss.item())
            
            #打印损失值
            if step % 50 == 0:
                pred_y = model(test_x)
                pred_y=pred_y.argmax(dim=1) #返回最大值的下标
                print(f"step:{step},loss:{loss.item():.4f},accuracy: {(torch.sum(pred_y == test_y)/test_y.size()[0]) * 100:.2f}%")
    
    
    # 保存模型
    torch.save(model, 'LSTM_Digits.pt')
    
    #绘制迭代次数与损失函数的关系
    plt.plot(plt_epoch,plt_loss)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    step:0,loss:2.3081,accuracy: 8.75%
    step:50,loss:1.0913,accuracy: 59.40%
    step:100,loss:0.7879,accuracy: 70.30%
    step:150,loss:0.7618,accuracy: 73.75%
    step:200,loss:0.4271,accuracy: 86.70%
    step:250,loss:0.3963,accuracy: 90.65%
    step:300,loss:0.2965,accuracy: 91.85%
    step:350,loss:0.3396,accuracy: 94.15%
    step:400,loss:0.2283,accuracy: 92.30%
    step:450,loss:0.4932,accuracy: 94.05%
    step:500,loss:0.2487,accuracy: 93.25%
    step:550,loss:0.1460,accuracy: 94.20%
    step:600,loss:0.1908,accuracy: 94.70%
    step:650,loss:0.1521,accuracy: 92.35%
    step:700,loss:0.1530,accuracy: 94.80%
    step:750,loss:0.1192,accuracy: 94.65%
    step:800,loss:0.0478,accuracy: 95.30%
    step:850,loss:0.0535,accuracy: 95.70%
    step:900,loss:0.1174,accuracy: 95.45%
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    在这里插入图片描述

    8.4 数据验证

    #加载模型
    model=torch.load('LSTM_Digits.pt')
    
    #预测结果
    pred_y=model(test_x)
    #计算损失值
    loss=cost(pred_y,test_y)
    
    print('loss:',loss.detach().item())
    
    pred_y=pred_y.argmax(dim=1) #返回最大值的下标
    print(f"准确率: {(torch.sum(pred_y == test_y)/test_y.size()[0]) * 100}%")
    
    # 打印10个预测结果
    pred_y = model(test_x[:10].view(-1, 28, 28))
    pred_y=pred_y.argmax(dim=1) #返回最大值的下标
    print('预测数字',pred_y)
    print( '真实数字',test_y[:10])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    loss: 0.11265470087528229
    准确率: 96.45000457763672%
    预测数字 tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
    真实数字 tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
    
    • 1
    • 2
    • 3
    • 4

    九、参考资料

    如何从RNN起步,一步一步通俗理解LSTM
    大白话讲解LSTM长短期记忆网络 如何缓解梯度消失,手把手公式推导反向传播
    Understanding LSTM Networks
    【Pytorch教程】:RNN 循环神经网络 (分类)

  • 相关阅读:
    多线程和并发问题详解
    库存管理方法有哪些?有哪些库存管理软件比较好用?
    盒式交换机堆叠配置
    计算机毕业设计springboot+vue+elementUI高校招生管理系统
    远程办公安全:共同守护数字时代的明日
    实验七 循环神经网络(3)LSTM的记忆能力实验
    数据结构(c语言版) 栈
    如何构建完整的财务体系?
    跨项目配置,nacos的动态更新配置,如何才能生效
    HTMLDOM中的API之btoa和atob
  • 原文地址:https://blog.csdn.net/Leytton/article/details/128135701