• 【深度学习】实验06 使用TensorFlow完成线性回归


    使用TensorFlow完成线性回归

    TensorFlow是由Google开发的一个开源的机器学习框架。它可以让开发者更加轻松地构建和训练深度学习模型,从而解决各种自然语言处理、计算机视觉、语音识别、推荐系统等领域的问题。

    TensorFlow的主要特点是灵活性和可伸缩性。它实现了一种基于数据流图的计算模型,使得用户可以定义自己的计算图,控制模型的计算过程。同时,TensorFlow支持分布式计算,使得用户可以在多台机器上运行大规模计算任务,从而提高计算效率。

    TensorFlow包含了许多高级API,例如Keras和Estimator,使得用户可以更加轻松地构建和训练深度学习模型。Keras提供了一个易于使用的高级API,使得用户可以在不需要深入了解TensorFlow的情况下,构建和训练深度学习模型。Estimator则提供了一种更加低级的API,使得用户可以更加灵活地定义模型的结构和训练过程。

    TensorFlow还提供了一个交互式开发环境,称为TensorBoard,可以帮助用户可视化模型的计算图、训练过程和性能指标,从而更加直观地理解和调试深度学习模型。

    由于TensorFlow的灵活性和可伸缩性,它已经被广泛应用于各个领域,包括自然语言处理、计算机视觉、语音识别、推荐系统等。例如,在自然语言处理领域,TensorFlow被用于构建和训练各种强大的模型,例如机器翻译模型、文本分类模型、语言生成模型等。

    总的来说,TensorFlow是一个强大的机器学习框架,可以帮助用户更加轻松地构建和训练深度学习模型。随着深度学习技术的不断发展,TensorFlow将继续发挥重要的作用,推动各个领域的发展和创新。

    1. 导入TensorFlow库

    # 导入相关库
    %matplotlib inline
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    • 1
    • 2
    • 3
    • 4
    • 5

    2. 构造数据集

    # 产出样本点个数
    n_observations = 100
    # 产出-3~3之间的样本点
    xs = np.linspace(-3, 3, n_observations) 
    # sin扰动
    ys = np.sin(xs) + np.random.uniform(-0.5, 0.5, n_observations) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    xs
    
    • 1
       array([-3.        , -2.93939394, -2.87878788, -2.81818182, -2.75757576,
              -2.6969697 , -2.63636364, -2.57575758, -2.51515152, -2.45454545,
              -2.39393939, -2.33333333, -2.27272727, -2.21212121, -2.15151515,
              -2.09090909, -2.03030303, -1.96969697, -1.90909091, -1.84848485,
              -1.78787879, -1.72727273, -1.66666667, -1.60606061, -1.54545455,
              -1.48484848, -1.42424242, -1.36363636, -1.3030303 , -1.24242424,
              -1.18181818, -1.12121212, -1.06060606, -1.        , -0.93939394,
              -0.87878788, -0.81818182, -0.75757576, -0.6969697 , -0.63636364,
              -0.57575758, -0.51515152, -0.45454545, -0.39393939, -0.33333333,
              -0.27272727, -0.21212121, -0.15151515, -0.09090909, -0.03030303,
               0.03030303,  0.09090909,  0.15151515,  0.21212121,  0.27272727,
               0.33333333,  0.39393939,  0.45454545,  0.51515152,  0.57575758,
               0.63636364,  0.6969697 ,  0.75757576,  0.81818182,  0.87878788,
               0.93939394,  1.        ,  1.06060606,  1.12121212,  1.18181818,
               1.24242424,  1.3030303 ,  1.36363636,  1.42424242,  1.48484848,
               1.54545455,  1.60606061,  1.66666667,  1.72727273,  1.78787879,
               1.84848485,  1.90909091,  1.96969697,  2.03030303,  2.09090909,
               2.15151515,  2.21212121,  2.27272727,  2.33333333,  2.39393939,
               2.45454545,  2.51515152,  2.57575758,  2.63636364,  2.6969697 ,
               2.75757576,  2.81818182,  2.87878788,  2.93939394,  3.        ])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    ys
    
    • 1
       array([-0.62568008,  0.01486274, -0.29232541, -0.05271084,
    -0.53407957,
              -0.37199581, -0.40235236, -0.80005504, -0.2280913 , -0.96111433,
              -0.58732159, -0.71310851, -1.19817878, -0.93036437, -1.02682804,
              -1.33669261, -1.36873043, -0.44500172, -1.38769079, -0.52899793,
              -0.78090929, -1.1470421 , -0.79274726, -0.95139505, -1.3536293 ,
              -1.15097615, -1.04909201, -0.89071026, -0.81181765, -0.70292996,
              -0.49732344, -1.22800179, -1.21280414, -0.59583172, -1.05027515,
              -0.56369191, -0.68680323, -0.20454038, -0.32429566, -0.84640122,
              -0.08175012, -0.76910728, -0.59206189, -0.09984673, -0.52465978,
              -0.30498277,  0.08593627, -0.29488864,  0.24698113, -0.07324925,
               0.12773032,  0.55508531,  0.14794648,  0.40155342,  0.31717698,
               0.63213964,  0.35736413,  0.05264068,  0.39858619,  1.00710311,
               0.73844747,  1.12858026,  0.59779567,  1.22131999,  0.80849061,
               0.72796849,  1.0990044 ,  0.45447096,  1.15217952,  1.31846002,
               1.27140258,  0.65264777,  1.15205186,  0.90705463,  0.82489198,
               0.50572125,  1.47115594,  0.98209434,  0.95763951,  0.50225094,
               1.40415029,  0.74618984,  0.90620692,  0.40593222,  0.62737999,
               1.05236579,  1.20041249,  1.14784273,  0.54798933,  0.18167682,
               0.50830766,  0.92498585,  0.9778136 ,  0.42331405,  0.88163729,
               0.67235809, -0.00539421, -0.06219493,  0.26436412,  0.51978602])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    # 可视化图长和宽
    plt.rcParams["figure.figsize"] = (6,4)
    # 绘制散点图
    plt.scatter(xs, ys) 
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5

    1

    3. 定义基本模型

    # 占位
    X = tf.placeholder(tf.float32, name='X')
    Y = tf.placeholder(tf.float32, name='Y')
    
    • 1
    • 2
    • 3
    # 随机采样出变量
    W = tf.Variable(tf.random_normal([1]), name='weight') 
    b = tf.Variable(tf.random_normal([1]), name='bias')
    
    • 1
    • 2
    • 3
    # 手写y = wx+b
    Y_pred = tf.add(tf.multiply(X, W), b) 
    
    • 1
    • 2
    # 定义损失函数mse
    loss = tf.square(Y - Y_pred, name='loss') 
    
    • 1
    • 2
    # 学习率
    learning_rate = 0.01
    # 优化器,就是tensorflow中梯度下降的策略
    # 定义梯度下降,申明学习率和针对那个loss求最小化
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss) 
    
    • 1
    • 2
    • 3
    • 4
    • 5

    4. 训练模型

    # 去样本数量
    n_samples = xs.shape[0]
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        # 记得初始化所有变量
        sess.run(init) 
        writer = tf.summary.FileWriter('../graphs/linear_reg', sess.graph)
        # 训练模型
        for i in range(50):
            #初始化损失函数
            total_loss = 0
            for x, y in zip(xs, ys):
                # 通过feed_dic把数据灌进去
                _, l = sess.run([optimizer, loss], feed_dict={X: x, Y:y}) #_是optimizer的返回,在这没有用就省略
                total_loss += l #统计每轮样本的损失
            print('Epoch {0}: {1}'.format(i, total_loss/n_samples)) #求损失平均
    
        # 关闭writer
        writer.close() 
        # 取出w和b的值
        W, b = sess.run([W, b]) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    Epoch 0: [0.48447946]
    Epoch 1: [0.20947962]
    Epoch 2: [0.19649307]
    Epoch 3: [0.19527708]
    Epoch 4: [0.19514856]
    Epoch 5: [0.19513479]
    Epoch 6: [0.19513334]
    Epoch 7: [0.19513316]
    Epoch 8: [0.19513315]
    Epoch 9: [0.19513315]
    Epoch 10: [0.19513315]
    Epoch 11: [0.19513315]
    Epoch 12: [0.19513315]
    Epoch 13: [0.19513315]
    Epoch 14: [0.19513315]
    Epoch 15: [0.19513315]
    Epoch 16: [0.19513315]
    Epoch 17: [0.19513315]
    Epoch 18: [0.19513315]
    Epoch 19: [0.19513315]
    Epoch 20: [0.19513315]
    Epoch 21: [0.19513315]
    Epoch 22: [0.19513315]
    Epoch 23: [0.19513315]
    Epoch 24: [0.19513315]
    Epoch 25: [0.19513315]
    Epoch 26: [0.19513315]
    Epoch 27: [0.19513315]
    Epoch 28: [0.19513315]
    Epoch 29: [0.19513315]
    Epoch 30: [0.19513315]
    Epoch 31: [0.19513315]
    Epoch 32: [0.19513315]
    Epoch 33: [0.19513315]
    Epoch 34: [0.19513315]
    Epoch 35: [0.19513315]
    Epoch 36: [0.19513315]
    Epoch 37: [0.19513315]
    Epoch 38: [0.19513315]
    Epoch 39: [0.19513315]
    Epoch 40: [0.19513315]
    Epoch 41: [0.19513315]
    Epoch 42: [0.19513315]
    Epoch 43: [0.19513315]
    Epoch 44: [0.19513315]
    Epoch 45: [0.19513315]
    Epoch 46: [0.19513315]
    Epoch 47: [0.19513315]
    Epoch 48: [0.19513315]
    Epoch 49: [0.19513315]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    print(W,b)
    print("W:"+str(W[0]))
    print("b:"+str(b[0]))
    
    • 1
    • 2
    • 3
    [0.23069778] [-0.12590201]
    W:0.23069778
    b:-0.12590201
    
    • 1
    • 2
    • 3

    5. 线性回归图

    # 线性回归图
    plt.plot(xs, ys, 'bo', label='Real data')
    plt.plot(xs, xs * W + b, 'r', label='Predicted data')
    plt.legend()
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5

    2

    附:系列文章

    序号文章目录直达链接
    1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
    2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
    3特征处理https://want595.blog.csdn.net/article/details/132182165
    4交叉验证https://want595.blog.csdn.net/article/details/132182238
    5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
    6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
    7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
    8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
    9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
    10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
    11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
    12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
    13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
    14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
    15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
    16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
    17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
    18自然语言处理https://want595.blog.csdn.net/article/details/132276591
  • 相关阅读:
    亚马逊自养号测评安全吗?
    Jenkins 安装
    nacos应用——占用内存过多问题解决(JVM调优初步)
    为什么创建 Redis 集群时会自动错开主从节点?
    杂谈 跟编程无关的事情21
    【vue】ant-design弹出框无法关闭和runtimecore提示isFucntion is not function的问题修复
    保洁实业如何使用虚拟机器人提高安全性
    kubernetes资源管理
    电路综合-基于简化实频的集总参数电路匹配2-得出解析解并综合
    原生js实现扫雷
  • 原文地址:https://blog.csdn.net/m0_68111267/article/details/132182417