• 动手学深度学习(Pytorch版)代码实践 -深度学习基础-06Softmax回归简洁版


    06Softmax回归简洁版

    import torch
    from torch import nn
    from d2l import torch as d2l
    import liliPytorch as lp
    
    batch_size = 256
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    
    #初始化
    #Pytorch 不会隐式地调整输入的形状
    net = nn.Sequential(nn.Flatten(), nn.Linear(784,10))
    """
    nn.Sequential是PyTorch中的一个容器,
    它将多个层(modules)按照它们在传入的顺序组合在一起。
    数据按顺序通过这些层进行传递。
    
    nn.Flatten():这是一个将输入张量(tensor)展平的层。
    它会将多维的输入张量展平成一维。
    
    nn.Linear(784, 10):这是一个全连接层(或线性层)。
    它将输入张量从大小为784的向量变换为大小为10的向量。
    这个操作相当于进行一个矩阵乘法,再加上一个偏置向量。
    通常用于分类任务中将展平后的图像数据映射到10个类别
    """
    
    #这个参数是神经网络中的一个层(module)
    def init_weights(m):
        #检查参数m是否是一个全连接层(nn.Linear)。
        #只有当m是nn.Linear类型时,才会对其进行权重初始化。
        if type(m) == nn.Linear:
            """
            nn.init.normal_函数对其权重进行初始化。
            nn.init.normal_函数将权重初始化为服从均值为0,标准差为0.01的正态分布的值。
            注意这里使用的是原地操作(in-place operation),即直接修改了m.weight的值。
            """
            nn.init.normal_(m.weight, std=0.01)
    
    """
    apply方法会递归地遍历net中的所有子模块,
    并将init_weights函数应用到每一个模块上。
    这样,如果net中有多个全连接层(nn.Linear),
    init_weights函数就会对每一个全连接层的权重进行初始化。
    """
    net.apply(init_weights)
    
    #损失函数
    """
    这个损失函数结合了nn.LogSoftmax和nn.NLLLoss,
    它先计算每个类别的预测概率的对数(通过LogSoftmax),
    然后计算真实类别的负对数似然(Negative Log Likelihood)
    """
    loss = nn.CrossEntropyLoss(reduction='none')
    
    #小批量随机梯度下降作为优化算法
    #net.parameters()返回神经网络net中所有需要优化的参数
    trainer = torch.optim.SGD(net.parameters(), lr=0.1)
    
    #训练模型
    #d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer) #报错
    #将softmax基础班中的代码,封装到liliPytorch中,调用
    lp.train_ch3(net, train_iter, test_iter, loss, num_epochs=5, updater=trainer)
    d2l.plt.show() #可视化
    

    运行结果:

    <Figure size 350x250 with 1 Axes>
    epoch: 1,train_loss: 0.7846037483851115,train_acc: 0.7511833333333333,test_acc: 0.7936
    <Figure size 350x250 with 1 Axes>
    epoch: 2,train_loss: 0.5698513298034668,train_acc: 0.8127833333333333,test_acc: 0.8021
    <Figure size 350x250 with 1 Axes>
    epoch: 3,train_loss: 0.5255562342961629,train_acc: 0.8256,test_acc: 0.8002
    <Figure size 350x250 with 1 Axes>
    epoch: 4,train_loss: 0.5013835444132487,train_acc: 0.83245,test_acc: 0.8235
    <Figure size 350x250 with 1 Axes>
    epoch: 5,train_loss: 0.4861805295308431,train_acc: 0.8363666666666667,test_acc: 0.8167
    
  • 相关阅读:
    牛客网verilog刷题知识点盘点(75道题的版本)
    CentOS 7.6环境下Nginx1.23.3下载安装配置使用教程
    【视觉算法系列1】使用 KerasCV YOLOv8 进行红绿灯检测(上)
    并发之Synchronized说明
    C语言葵花宝典之——文件操作
    Hive环境搭建_远程部署
    共享内存和信号量的配合机制
    Day01_《MySQL索引与性能优化》摘要
    华为要用MateBook E Go系列开辟一个新市场
    java计算机毕业设计计算机office课程平台MyBatis+系统+LW文档+源码+调试部署
  • 原文地址:https://blog.csdn.net/weixin_46560570/article/details/139778211