• 原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列8


    在这里插入图片描述


    前言

    这是该系列原型网络的最后一段代码及其详细解释,感谢各位的阅读!


    一、原始代码

    if __name__ == '__main__':
        ##载入数据
        labels_trainData, labels_testData = load_data()  # labels_trainData是字典,是key:value形式
        class_number_train = max(list(labels_trainData.keys())) #963
        class_number_test = max(list(labels_testData.keys())) #658
    
        wide = labels_trainData[0][0].shape[0]  # 105      #二维张量,shape[0]代表行数,shape[1]代表列数
        length = labels_trainData[0][0].shape[1]  # 105
    
        for label in labels_trainData.keys():
            labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])
    
        for label in labels_testData.keys():
            labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])
    
        ##初始化模型
        protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)  # '''根据需求修改类的初始化参数,参数含义见protonets_net.py'''
    
        ##训练prototypical_network
        for n in range(100):  ##随机选取x个类进行一个episode的训练
            protonets.train(labels_trainData, class_number_train)
            if n % 2 == 0 and n != 0:  # 每5次存储一次模型,并测试模型的准确率,训练集的准确率和测试集的准确率被存储在model_step_eval.txt中
                torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
                protonets.save_center('./log/model_center_' + str(n) + '.csv')
                test_accury = protonets.evaluation_model(labels_testData, class_number_test)
                print(test_accury)
                str_data = str(n) + ',' + str('       test_accury     ') + str(test_accury) + '\n'
                with open('./log/model_step_eval.txt', "a") as f:
                    f.write(str_data)
            print(n)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    二、对每一行代码的解释:

    1. if __name__ == '__main__':
      这是一个Python的惯用写法,表示当脚本直接被运行时(而不是被作为模块导入时),才会执行下面的代码块。

    2. labels_trainData, labels_testData = load_data()
      调用 load_data() 函数加载数据,并将返回的标签训练数据和标签测试数据保存到 labels_trainDatalabels_testData 变量中。

    3. class_number_train = max(list(labels_trainData.keys()))
      获取标签训练数据中的最大键(即最大类别数),并将其保存到 class_number_train 变量中。

    4. class_number_test = max(list(labels_testData.keys()))
      获取标签测试数据中的最大键(即最大类别数),并将其保存到 class_number_test 变量中。

    5. wide = labels_trainData[0][0].shape[0]
      获取标签训练数据中第一个样本的宽度,并将其保存到 wide 变量中。

    6. length = labels_trainData[0][0].shape[1]
      获取标签训练数据中第一个样本的长度,并将其保存到 length 变量中。

    7. for label in labels_trainData.keys():
      遍历标签训练数据中的所有键。

    8. labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])
      对每个标签训练数据进行重塑,将其形状改为 [-1, 1, wide, length],其中 -1 表示自动计算维度大小。

    9. for label in labels_testData.keys():
      遍历标签测试数据中的所有键。

    10. labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])
      对每个标签测试数据进行重塑,将其形状改为 [-1, 1, wide, length]

    11. protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)
      创建一个 Protonets 类的实例,传入模型的初始化参数。

    12. for n in range(100):
      从0到99的循环中,执行以下代码块。

    13. protonets.train(labels_trainData, class_number_train)
      调用 protonets 实例的 train() 方法进行模型训练,传入标签训练数据和类别数。

    14. if n % 2 == 0 and n != 0:
      如果 n 是偶数且不为0,则执行以下代码块。

    15. torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
      保存模型到 './log/model_net_' + str(n) + '.pkl' 的文件路径。

    16. protonets.save_center('./log/model_center_' + str(n) + '.csv')
      调用 protonets 实例的 save_center() 方法,将模型的中心点保存到 './log/model_center_' + str(n) + '.csv'

    17. test_accury = protonets.evaluation_model(labels_testData, class_number_test)
      调用 protonets 实例的 evaluation_model() 方法,对模型进行评估并返回测试准确率,将其保存到 test_accury 变量中。

    18. print(test_accury)
      打印测试准确率。

    19. str_data = str(n) + ',' + str(' test_accury ') + str(test_accury) + '\n'
      构建一个字符串以保存到文件中。

    20. with open('./log/model_step_eval.txt', "a") as f:
      打开一个文件,以追加模式写入。


    总结

    原型网络(Prototypical Network)是一种用于小样本学习的模型,由Jake Snell等人于2017年提出。它是一种基于元学习(meta-learning)的方法,主要用于解决在具有少量标记样本的情况下进行分类任务的问题。

    传统的深度学习模型在处理小样本学习时通常表现不佳,因为它们需要大量的标记样本来进行训练。然而,在现实世界中,我们往往只有少量标记样本可用。原型网络通过引入一个用于表示类别的中心向量(原型)的概念,解决了这个问题。

    原型网络的功能和优势如下:

    1. 小样本学习:原型网络适用于具有少量标记样本的分类任务,可以在只有几个样本可用时进行准确的分类。

    2. 元学习能力:原型网络通过学习类别的原型向量,能够在遇到新类别时进行快速学习,从而实现元学习的目标。

    3. 欧氏距离度量:原型网络使用欧氏距离来度量样本与原型之间的相似性,从而进行分类推断。这种度量方式非常直观和可解释,使得模型更易于理解

  • 相关阅读:
    ReentrantLock源码解析
    Python3 升级urllib3所遇到问题以及解决
    Leetcode 剑指 Offer II 045. 找树左下角的值
    Spring整合RabbitMQ——生产者
    机器学习 | SVD奇异值分解
    awtk用C语言开发串口通信示例
    车间调度|基于帝王蝶优化算法的车间调度(Matlab代码实现)
    【云原生】如何快速部署Kubernetes
    PT_数字特征_矩协方差相关系数
    多项目的.net core解决方案(项目间引用)如何使用Docker部署
  • 原文地址:https://blog.csdn.net/qlkaicx/article/details/134487281