• 第六章 网络学习相关技巧5(超参数验证)



    神经网络中,除了权重和偏置等参数,超参数(hyper-parameter)也经常出现。这里所说的超参数是指,比如各层的神经元数量、batch大小、参数更新时的学习率或权值衰减等。如果这些超参数没有设置合适的值,模型的性能就会很差。虽然超参数的取值非常重要,但是在决定超参数的过程中一般会伴随很多的试错。本节将介绍尽可能高效地寻找超参数的值的方法。

    6.1验证数据

    我们使用的数据集分成了训练数据和测试数据,训练数据用于学习,测试数据用于评估泛化能力。由此,就可以评估是否只过度拟合了训练数据(是否发生了过拟合),以及泛化能力如何等。

    然而,我们也需要对超参数设置各种各样的值以进行验证。因此,调整超参数时,必须使用超参数专用的确认数据。用于调整超参
    数的数据,一般称为验证数据(validation data)。我们使用这个验证数据来评估超参数的好坏。

    【注】训练数据用于参数(权重和偏置)的学习,验证数据用于超参数的性能评估。测试数据是为了确认泛化能力,要在最后使用(比较理想的是只用一次)。

    根据不同的数据集,有的会事先分成训练数据、验证数据、测试数据三部分,有的只分成训练数据和测试数据两部分,有的则不进行分割。在这种情况下,用户需要自行进行分割。如果是MNIST数据集,获得验证数据的最简单的方法就是从训练数据中事先分割20%作为验证数据。

    代码实现如下:

    # coding: utf-8
    import os
    import sys
    sys.path.append(os.pardir) 
    from dataset.mnist import load_mnist
    from common.util import shuffle_dataset
     
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
     
    # 打乱训练数据
    x_train, t_train = shuffle_dataset(x_train, t_train)
     
    # 分割验证数据
    validation_rate = 0.20
    validation_num = int(x_train.shape[0] * validation_rate)  # 验证数据集的数量
     
    x_val = x_train[:validation_num]
    t_val = t_train[:validation_num]
    x_train = x_train[validation_num:]
    t_train = t_train[validation_num:]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    6.2超参数的最优化

    进行超参数的最优化时,逐渐缩小超参数的“好值”的存在范围非常重要。

    逐渐缩小范围:是指一开始先大致设定一个范围,从这个范围中随机选出一个超参数(采样),用这个采样到的值进行识别精度的评估;然后,多次重复该操作,观察识别精度的结果,根据这个结果缩小超参数的“好值”的范围。通过重复这一操作,就可以逐渐确定超参数的合适范围。

    【注】在进行神经网络的超参数的最优化时,与网格搜索等有规律的搜索相比,随机采样的搜索方式效果更好。这是因为在多个超参数中,各个超参数对最终的识别精度的影响程度不同。

    超参数的范围只要“大致地指定”就可以了。所谓“大致地指定”,是指像0.001(10^ −3 )到1000(10^ 3 )这样,以“10的阶乘”的尺度指定范围(也表述为“用对数尺度(log scale)指定”)。在Python中可以写成 10 ** np.random.uniform(-3, 3) 。

    在超参数的最优化中,要注意的是深度学习需要很长时间(比如,几天或几周)。因此,在超参数的搜索中,需要尽早放弃那些不符合逻辑的超参数。于是,在超参数的最优化中,减少学习的epoch,缩短一次评估所需的时间是一个不错的办法。

    6.2.1优化步骤

    1、设定超参数的范围。

    2、从设定的超参数范围中随机采样。

    3、使用步骤2中采样到的超参数的值进行学习,通过验证数据评估识别精度(但是要将epoch设置得很小)。

    4、重复步骤2和步骤3(100次等),根据它们的识别精度的结果,缩小超参数的范围。

    反复进行上述操作,不断缩小超参数的范围,在缩小到一定程度时,从该范围中选出一个超参数的值。这就是进行超参数的最优化的一种方法。

    【注】在超参数的最优化中,如果需要更精炼的方法,可以使用贝叶斯最优化(Bayesian optimization)。

    6.3实现

    使用MNIST数据集进行超参数的最优化。这里我们将学习率和控制权值衰减强度的系数(下文称为“权值衰减系数”)这两个超参数的搜索问题作为对象。

    通过从0.001(10 −3 )到1000(10 3 )这样的对数尺度的范围中随机采样进行超参数的验证。这在Python中可以写成 10 ** np.random.uniform(-3, 3) 。在该实验中,权值衰减系数的初始范围为10 −8 到10 −4 ,学习率的初始范围为10 ^−6 到10 ^−2 。此时,超参数的随机采样的代码如下所示。

    weight_decay = 10 ** np.random.uniform(-8, -4)
    lr = 10 ** np.random.uniform(-6, -2)
    
    • 1
    • 2

    6.3.1案例

    例:使用MNIST数据以权值衰减系数为10 −8 到10 −4 、学习率为10 −6 到10 −2 的范围进行实验。

    文件目录如下:

    img

    funtions.py, gradient.py, layers.py, multi_layer_net.py, optimizer.py, util.py)见前面博文

    trainer.py见该博文

    6.3.2代码及结果

    6.3.2.1结果

    运行hyperparameter_optimization.py 结果如下:

    img

    【注】按识别精度从高到低的顺序排列了验证数据的学习的变化。从图中可知,直到“Best-5”左右,学习进行得都很顺利。

    img

    【注】从这个结果可以看出,学习率在0.001到0.01、权值衰减系数在10 −8 到10 −6 之间时,学习可以顺利进行。像这样,观察可以使学习顺利进行的超参数的范围,从而缩小值的范围。然后,在这个缩小的范围中重复相同的操作。这样就能缩小到合适的超参数的存在范围,然后在某个阶段,选择一个最终的超参数的值。

    6.3.2.2代码实现

    hyperparameter_optimization.py 代码实现如下:

    # coding: utf-8
    import sys, os
    sys.path.append(os.pardir)
    import numpy as np
    import matplotlib.pyplot as plt
    from dataset.mnist import load_mnist
    from common.multi_layer_net import MultiLayerNet
    from common.util import shuffle_dataset
    from common.trainer import Trainer
     
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
     
    # 削减学习数据
    x_train = x_train[:500]
    t_train = t_train[:500]
     
    # 训练集与验证集分离
    validation_rate = 0.20
    validation_num = int(x_train.shape[0] * validation_rate)
    x_train, t_train = shuffle_dataset(x_train, t_train)
    x_val = x_train[:validation_num]
    t_val = t_train[:validation_num]
    x_train = x_train[validation_num:]
    t_train = t_train[validation_num:]
     
     
    def __train(lr, weight_decay, epocs=50):
        network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
                                output_size=10, weight_decay_lambda=weight_decay)
        trainer = Trainer(network, x_train, t_train, x_val, t_val,
                          epochs=epocs, mini_batch_size=100,
                          optimizer='sgd', optimizer_param={'lr': lr}, verbose=False)
        trainer.train()
     
        return trainer.test_acc_list, trainer.train_acc_list
     
     
    # 超参数随机搜索
    optimization_trial = 100
    results_val = {}
    results_train = {}
    for _ in range(optimization_trial):
        # 指定搜索超参数的范围===============
        weight_decay = 10 ** np.random.uniform(-8, -4)
        lr = 10 ** np.random.uniform(-6, -2)
        # ================================================
     
        val_acc_list, train_acc_list = __train(lr, weight_decay)
        print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))
        key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)
        results_val[key] = val_acc_list
        results_train[key] = train_acc_list
     
    # 绘制图表========================================================
    print("=========== Hyper-Parameter Optimization Result ===========")
    graph_draw_num = 20
    col_num = 5
    row_num = int(np.ceil(graph_draw_num / col_num))
    i = 0
     
    for key, val_acc_list in sorted(results_val.items(), key=lambda x:x[1][-1], reverse=True):
        print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)
     
        plt.subplot(row_num, col_num, i+1)
        plt.title("Best-" + str(i+1))
        plt.ylim(0.0, 1.0)
        if i % 5: plt.yticks([])
        plt.xticks([])
        x = np.arange(len(val_acc_list))
        plt.plot(x, val_acc_list)
        plt.plot(x, results_train[key], "--")
        i += 1
     
        if i >= graph_draw_num:
            break
     
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
  • 相关阅读:
    windows socket网络编程--事件选择模型
    equals方法与hashCode方法相关
    HarmonyOS脚手架:快捷实现ArkTs中json转对象
    C#__资源访问冲突和死锁问题
    JMH基准测试工具 (一):介绍
    Mac笔记本聚焦SpotLight占用内存太高的 解法
    请问我的html内部打开不了视频是什么原因
    linux环境下安装Nacos
    日期时间存入数据库会差一天?
    Win10,Office2016及以上图标异常解决方案
  • 原文地址:https://blog.csdn.net/segegse/article/details/125438204