• torch 神经网络模型构建


    点赞收藏关注!
    如需转载,请注明出处!

    torch 的模型搭建,做一下简要的介绍

    神经网络由对数据进行操作的层/模块(layers/modules)组成。
    torch.nn提供构建网络的所有blocks,在PyTorch中的每个modules都继承了nn.Module,可以构建各种复杂的网络结构。通过nn.Module定义神经网络,使用init初始化,对数据的所有操作都在forward()中实现

    class NeuralNetwork(nn.Module):
    def __init__(self):
    	super(NeuralNetwork, self).__init__()
    	self.flatten = nn.Flatten()
    	self.linear_relu_stack = nn.Sequential(
    	nn.Linear(28*28, 512),
    	nn.ReLU(),
    	nn.Linear(512, 512),
    	nn.ReLU(),
    	nn.Linear(512, 10),
    	nn.ReLU()
    	)
    
    #前向传播
    def forward(self, x)
    	x = self.flatten(x)
    	logits = self.linear_relu_stack(x)
    	return logits
    
    
    ##使用示例
    
    #检测是否有GPU可用,若有可以在GPU上训练模型
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Using {} device'.format(device))
    model = NeuralNetwork().to(device)
    print(model)
    X = torch.rand(1, 28, 28, device=device)
    logits = model(X)
    pred_probab = nn.Softmax(dim=1)(logits)
    y_pred = pred_probab.argmax(1)
    print(f"Predicted class: {y_pred}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    为了方便理解,整理了代码中一些函数的意义

    • nn.Flatten()将连续的维度范围展平为张量。一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。
    • nn.Sequential()是PyTorch中的一个类,它允许用户将多个计算层按照顺序组合成一个模型。在深度学习中,模型可以是由各种不同类型的层组成的,例如卷积层、池化层、全连接层等。nn.Sequential()方法可以将这些层组合在一起,形成一个整体模型。
    • nn.Linear定义一个神经网络的线性层.,Linear其实就是对输入 X执行了一个线性变换的操作。
    • nn.ReLU()模型的激活函数,nn.relu函数是神经网络中常用的激活函数之一,即修正线性单元(Rectified Linear Unit)。ReLU函数的数学表示为f(x) = max(0, x),即输出值等于输入值和0中的较大者。ReLU函数的特点是在输入值大于0时,输出为输入值本身;而在输入值小于等于0时,输出为0。这意味着ReLU会将负值归零,而对正值不做修改

    如有帮助点赞收藏关注!
    如需转载,请注明出处!

  • 相关阅读:
    java毕业设计企业资产管理系统mybatis+源码+调试部署+系统+数据库+lw
    企业数据挖掘平台产品特色及合作案例介绍
    Vuex 和 Redux 的区别?
    「学习笔记」vector
    【MindSpore易点通】如何使用溢出检测工具定位精度问题
    我的master节点是好的,现在是第一个node1节点利用master节点给出的加入的代码结果报错,如何解决?(标签-kubernetes)
    Spring命名空间
    【Android笔记25】Android中的动画效果之逐帧动画
    一个好用的多方隐私求交算法库JasonCeng/MultipartyPSI-Pro
    如何设计一张数据库表
  • 原文地址:https://blog.csdn.net/weixin_42362399/article/details/134530131