• pytorch应用于MNIST手写字体识别


    前言

    手写字体MNIST数据集是一组常见的图像,其常用于测评和比较机器学习算法的性能,本文使用pytorch框架来实现对该数据集的识别,并对结果进行逐步的优化。

    一、数据集

    MNIST数据集是由28x28大小的0-255像素值范围的灰度图像(如下图所示),其中6万张用于训练模型,1万张用于测试模型。
    在这里插入图片描述
    该数据集可从以下链接获取:
    训练数据集:
    https://pjreddie.com/media/files/mnist_train.csv
    测试数据集:
    https://pjreddie.com/media/files/mnist_test.csv
    数据集一行有785个值,第一个值为图像中的数字标签,其余784个值为图像的像素值。
    读取数据实例代码如下:

    import pandas
    import matplotlib.pyplot as plt
    
    df = pandas.read_csv(r'./data/mnist_train.csv', header=None)
    # print(df.head())  # 显示前5行
    # print(df.info())   # 显示DataFrame概况
    row = 0
    data = df.iloc[row]
    label = data[0],
    img = data[1:].values.reshape(28, 28)
    plt.title('label = ' + str(label))
    plt.imshow(img, interpolation='none', cmap='Blues')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    在这里插入图片描述

    二、建立模型

    # 构建模型
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset
    
    
    class Classifier(nn.Module):
        def __init__(self):
            # 初始化pytorch父类
            super().__init__()
    
            self.model = nn.Sequential(
                nn.Linear(784, 200),
                nn.Sigmoid(),
                nn.Linear(200, 10),
                nn.Sigmoid()
            )
            self.loss_function = nn.MSELoss()
            self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
            self.counter = 0
            self.progress = []
    
        def forward(self, inputs):
            return self.model(inputs)
    
        def train_model(self, inputs, targets):
            outputs = self.forward(inputs)
            loss = self.loss_function(outputs, targets)
    
            self.optimizer.zero_grad()  # 梯度归零 ,因为反向传播计算的梯度会累计
            loss.backward()  # 反向传播
            self.optimizer.step()  # 更新权重
            # 可视化训练过程
            self.counter += 1
            if self.counter % 10 == 0:
                self.progress.append(loss.item())  # 获取单张张量里的数字
                pass
            if self.counter % 10000 == 0:
                print('counter = ', self.counter)
                pass
    
        def plot_progress(self):
            df = pandas.DataFrame(self.progress, columns=['loss'])
            df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
            plt.show()
            pass
    
    
    class MnistDataset(Dataset):
        def __init__(self, csv_file):
            self.data_df = pandas.read_csv(csv_file, header=None)
            pass
    
        def __len__(self):
            return len(self.data_df)
    
        def __getitem__(self, index):
            label = self.data_df.iloc[index, 0]
            target = torch.zeros((10))
            target[label] = 1
            image_value = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0
            return label, image_value, target
    
        def plot_image(self, index):
            arr = self.data_df.iloc[index, 1:].values.reshape(28, 28)
            plt.title('label = ' + str(self.data_df.iloc[index, 0]))
            plt.imshow(arr, interpolation='none', cmap='Blues')
            plt.show()
            pass
        pass
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70

    以上建立模型框架,并对训练过程进行可视化,建立读取数据类。

    三、训练分类模型

    mnist_train_dataset = MnistDataset(r'./data/mnist_train.csv')
    # mnist_dataset.plot_image(9)
    
    # 训练分类模型
    start_time = time.time()
    C = Classifier()
    epochs = 3  # 训练3轮
    for i in range(epochs):
        print('training epoch ', i+1, 'of', epochs)
        for lable, image_tensor, target_tensor in mnist_train_dataset:
            C.train_model(image_tensor, target_tensor)
            pass
        pass
    C.plot_process()
    print('run time = ', (time.time()-start_time) / 60)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    训练3轮所花费的时间大约不到3min,效率还不错

    四、测试模型

    # 测试模型
    mnist_test_dataset = MnistDataset(r'./data/mnist_test.csv')
    record = 19
    mnist_test_dataset.plot_image(record)  # 图像里的数字
    image_data = mnist_test_dataset[record][1]
    output = C.forward(image_data)
    pandas.DataFrame(output.detach().numpy()).plot(kind='bar', legend=False, ylim=(0, 1))  # 预测的数字
    plt.show()
    
    score = 0
    items = 0
    for label, img_tensor, label_tensor in mnist_test_dataset:
        ans = C.forward(img_tensor)
        if ans.argmax() == label:
            score += 1
            pass
        items += 1
        pass
    print(score, items, score / items)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    模型的测试分数是87%,考虑到这是一个简单的网络,这个分数不算太差。

    五、模型优化

    模型的优化主要从四个方面着手:

    • 1、损失函数
      在上面的模型中设计损失函数为MSEloss,这里将其更改为二元交叉熵损失((binary cross entropy loss)
    self.loss_function = nn.BCELoss()
    
    • 1

    训练3轮,发现分数由87%提升到91%了

    • 2、激活函数
      Sigmoid激活函数的一个缺点是,当输入值变大时,梯度会变得非常小甚至消失。现在常用的是改进过的线性整流函数Leaky ReLU,也叫带泄露线性整流函数。
    self.model = nn.Sequential(
                nn.Linear(784, 200),
                # nn.Sigmoid(),
                nn.LeakyReLU(0.02),
                nn.Linear(200, 10),
                # nn.Sigmoid()
                nn.LeakyReLU(0.02)
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    损失函数为原来的MSEloss,训练3轮,分数由87%上升到97%,这是一个很大的提升。

    • 3 、优化器
      上面模型所使用的是梯度下降法,该方法的一个缺点是会陷入损失函数的局部最小值,另一个缺点是对所有可学习参数都使用同一学习率。常见的替代方案是Adam,它利用动量减少陷入局部最小的可能,另外它对每个可学习参数使用单独的学习率,这些学习率随着每个参数在训练期间的变化而变化。
    self.optimizer = torch.optim.Adam(self.parameters())
    
    • 1

    仅改变优化器发现模型达到和修改激活函数一样的效果,分数由87%提升到97%。

    • 4、标准化
      标准化是指减少网络中的参数和信号的取值范围,将均值转换为0,常见做法是在信号输入到神经网络前将其进行标准化。
    self.model = nn.Sequential(
                nn.Linear(784, 200),
                nn.Sigmoid(),
                # nn.LeakyReLU(0.02),
                nn.LayerNorm(200),     # 标准化
                nn.Linear(200, 10),
                nn.Sigmoid()
                # nn.LeakyReLU(0.02)
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    向网络中添加标准化,模型的分数87%提升到91%
    将以上所有方法进行整合,由于二元交叉熵函数只能处理0~1的值,而LeakyReLU可能会输出范围外的值,将后一层激活函数保留为原来的Sigmoid函数:

     self.model = nn.Sequential(
                nn.Linear(784, 200),
                # nn.Sigmoid(),
                nn.LeakyReLU(0.02),
                nn.LayerNorm(200),
                nn.Linear(200, 10),
                nn.Sigmoid()
                # nn.LeakyReLU(0.02)
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    3周期训练完后,模型的分数为97%,整合的优化方案无法使模型分数大于97%。

    END

    参考资料

    -[英]塔里克•拉希德(Tariq Rashid)著,韩江雷译. PyTorch生成对抗网络编程. 人民邮电出版社

  • 相关阅读:
    Argo rollouts + istio服务网格实现金丝雀灰度发布
    【图解 HTTP】简单的HTTP协议
    VR全景广告:让消费者体验沉浸式交互,让营销更有趣
    可视化大屏设计模板 | 主题皮肤(报表UI设计)
    计算经纬度坐标之间的真实距离
    leetcode-LCP 06. 拿硬币
    【Qt图书管理系统】1.项目设计与需求分析
    阿里云负载均衡配置只能域名访问
    保险丝的工作原理
    【AGC】引导用户购买提升用户留存率
  • 原文地址:https://blog.csdn.net/weixin_40356612/article/details/126074673