• 机器学习实战(11)——初识人工神经网络


    目录

    1 感知器

    2 多层感知器和反向传播

    3 用TensorFlow的高级API训练MLP

    4 使用纯TensorFlow训练DNN

    4.1 构建阶段

    4.2 执行阶段

    5 使用神经网络


    1 感知器

    感知器是最简单的ANN架构之一。它基于一个稍微不同的被称为线性阈值单元(LTU)的人工神经元:输入和输出都是数字,每个输入的连接都有一个对应的权重。LTU会加权求和所有的输入(z=\omega _{1}x_{1}+\omega _{2}x_{2}+\cdots +\omega _{n}x_{n}=W^{T}\cdot x),然后对求值结果应用一个阶跃函数并产生最后的输出,如下图:

    感知器中最常见的阶跃函数是Heaviside阶跃函数,如下:

    单个LTU可以用来做简单的线性二值分类。它计算输入的线性组合,如果结果超出了阈值,输出就是正,反之则为负。训练LTU的意思是寻找 w_{0},w_{1} 和 w_{2} 的正确值。

    感知器就是个单层的LTU,每个神经元都与所有输入相连。这些连接通常使用称为输入神经元的特殊传递神经元来表示:输入什么就输出什么。此外,还会加上一个额外的偏差特征(x_{0}=1)。偏差特征通常用偏差神经元来表示,永远都只输出1。

    图中感知器可以将实例同时分为三个不同的二进制类,因此它被称为多输出分类器。

    感知器怎么训练呢?当两个神经元有相同的输出时,它们之间的连接权重就会增强。感知器就是使用这个规则的变体进行训练。

    感知器学习规则

    因为每个输出神经元的决策边界是线性的,所以感知器无法学习复杂的模式。

    Scikit-Learn提供了一个实现单一LTU网络的Perceptron类。基本可以在鸢尾花数据集上应用:

    首先我们导入常规模块和可视化设置:

    1. # Common imports
    2. import numpy as np
    3. import os
    4. # 以下代码可以确保程序再次运行时结果保持不变
    5. def reset_graph(seed=42):
    6. tf.compat.v1.reset_default_graph()
    7. tf.compat.v1.set_random_seed(seed)
    8. np.random.seed(seed)
    9. # To plot pretty figures
    10. import matplotlib
    11. import matplotlib.pyplot as plt
    12. plt.rcParams['axes.labelsize'] = 14
    13. plt.rcParams['xtick.labelsize'] = 12
    14. plt.rcParams['ytick.labelsize'] = 12
    1. import numpy as np
    2. from sklearn.datasets import load_iris
    3. from sklearn.linear_model import Perceptron
    4. iris = load_iris()
    5. X = iris.data[:, (2, 3)] # 花瓣长度和宽度特征
    6. y = (iris.target == 0).astype(np.int) #y = (iris.target == 0)返回的是布尔类型数组,astype(np.int)将True转化为1,将False转化为0
    7. per_clf = Perceptron(max_iter=100, tol=-np.infty, random_state=42)
    8. per_clf.fit(X, y)
    9. y_pred = per_clf.predict([[2, 0.5]])
    10. y_pred

    运行结果如下:

    array([1])
    1. a = -per_clf.coef_[0][0] / per_clf.coef_[0][1]
    2. b = -per_clf.intercept_ / per_clf.coef_[0][1]
    3. axes = [0, 5, 0, 2]
    4. x0, x1 = np.meshgrid(
    5. np.linspace(axes[0], axes[1], 500).reshape(-1, 1),
    6. np.linspace(axes[2], axes[3], 200).reshape(-1, 1),
    7. )
    8. #X, Y = np.meshgrid(x, y) 代表的是将x中每一个数据和y中每一个数据组合生成很多点,然后将这些点的x坐标放入到X中,y坐标放入Y中,并且相应位置是对应的
    9. X_new = np.c_[x0.ravel(), x1.ravel()] #扁平化作用
    10. y_predict = per_clf.predict(X_new)
    11. zz = y_predict.reshape(x0.shape)
    12. plt.figure(figsize=(10, 4))
    13. #X[y==0, 0]代表标签值即y==0时的坐标的x值,X[y==0, 1]代表标签值即y==0时的坐标的y值
    14. plt.plot(X[y==0, 0], X[y==0, 1], "bs", label="Not Iris-Setosa")
    15. plt.plot(X[y==1, 0], X[y==1, 1], "yo", label="Iris-Setosa")
    16. plt.plot([axes[0], axes[1]], [a * axes[0] + b, a * axes[1] + b], "k-", linewidth=3)
    17. from matplotlib.colors import ListedColormap
    18. custom_cmap = ListedColormap(['#9898ff', '#fafab0'])
    19. plt.contourf(x0, x1, zz, cmap=custom_cmap) #用来画出不同分类的边界线,也常常用来绘制等高线
    20. plt.xlabel("Petal length", fontsize=14)
    21. plt.ylabel("Petal width", fontsize=14)
    22. plt.legend(loc="lower right", fontsize=14)
    23. plt.axis(axes)
    24. plt.show()

    运行结果如下:

    简洁版如下(掌握):

    a = -per_clf.coef_[0][0] / per_clf.
  • 相关阅读:
    read系统调用源码分析
    P1280 尼克的任务
    怎样避免执行走样
    【QT学习】7.事件,把文本显示在页面中(文本可变),鼠标指针切换,鼠标左键右键按下,qt设置背景样式
    DVWA安装教程(懂你的不懂·详细)
    我在京东的第417天:陷入了情绪的泥沼
    前端远程调试方案 Chii 的使用经验分享
    [AndroidStudio]_[初级]_[修改虚拟设备镜像文件的存放位置]
    浅谈测试需求分析
    SpringBoot加载静态资源
  • 原文地址:https://blog.csdn.net/WHJ226/article/details/127038527