• pytorch直线拟合


    目录

    1、数据分析

    2、pytorch直线拟合


    1、数据分析

    直线拟合的前提条件通常包括以下几点:

    存在线性关系:这是进行直线拟合的基础,数据点之间应该存在一种线性关系,即数据的分布可以用直线来近似描述。这种线性关系可以是数据点在直角坐标系上的分布趋势,也可以是通过实验或观测得到的数据点之间的关系。

    数据点之间的误差是随机的:误差应该是随机的,没有任何系统性的偏差,并且符合随机误差的统计规律。这意味着数据点在拟合直线周围的分布应该是随机的,而不是受到某种特定的规律或趋势的影响。

    直线应符合数据点的总体趋势:在拟合直线时,应该尽可能地符合数据点的总体趋势,而不是被一些异常值所影响。如果存在一些异常值,它们不应该对拟合结果产生过大的影响。

    数据点的数量应该足够多:在进行直线拟合时,需要足够多的数据点来保证拟合结果的准确性和可靠性。通常来说,数据点的数量应该足够多,以便涵盖各种情况,并且能够反映出数据的真实分布情况。

    数据的观测或实验过程是可靠的:数据的观测或实验过程应该是可靠的,这意味着数据的测量值应该是准确的,并且没有受到某些特定因素的影响。如果数据的观测或实验过程存在偏差或误差,那么直线拟合的结果也可能受到影响。

    从散点图看出,数据具有明显的线性关系​,本例不过多讨论数据是满足直线拟合的其它条件。

    1. import torch
    2. import matplotlib.pyplot as plt
    3. x=torch.Tensor([1.4,5,11,16,21])
    4. y=torch.Tensor([14.4,29.6,62,85,113.4])
    5. plt.scatter(x.numpy(),y.numpy())
    6. plt.show()

    2、pytorch直线拟合

    基于梯度下降法实现直线拟合。训练过程实际上是一种批量梯度下降(Batch Gradient Descent),这是因为每次更新参数时都使用了所有的数据。另外,学习率 learning_rate 和训练轮数 epochs 是可以调整的超参数,对模型的训练效果有很大影响。

    1. import torch
    2. import matplotlib.pyplot as plt
    3. def Produce_X(x):
    4. x0=torch.ones(x.numpy().size)
    5. X=torch.stack((x,x0),dim=1)
    6. return X
    7. def train(epochs=1,learning_rate=0.01):
    8. for epoch in range(epochs):
    9. output=inputs.mv(w)
    10. loss=(output-target).pow(2).sum()
    11. loss.backward()
    12. w.data-=learning_rate*w.grad
    13. w.grad.zero_()
    14. if epoch%80==0:
    15. draw(output,loss)
    16. return w,loss
    17. def draw(output,loss):
    18. plt.cla()
    19. plt.scatter(x.numpy(), y.numpy())
    20. plt.plot(x.numpy(),output.data.numpy(),'r-',lw=5)
    21. plt.text(5,20,'loss=%s' % (loss.item()),fontdict={'size':20,'color':'red'})
    22. plt.pause(0.005)
    23. if __name__ == "__main__":
    24. x = torch.Tensor([1.4, 5, 11, 16, 21])
    25. y = torch.Tensor([14.4, 29.6, 62, 85.5, 113.4])
    26. X = Produce_X(x)
    27. inputs = X
    28. target = y
    29. w = torch.rand(2, requires_grad=True)
    30. w,loss=train(10000,learning_rate=1e-4)
    31. print("final loss:",loss.item())
    32. print("weigths:",w.data)
    33. plt.show()

    final loss: 8.216197967529297

    weigths: tensor([5.0817, 5.6201])

  • 相关阅读:
    一分钟带你搞定前端”防抖节流“
    QTday3
    abp(net core)+easyui+efcore实现仓储管理系统——ABP升级7.3上(五十八)
    DDD之模块(Module)
    GIS工具maptalks开发手册(二)01-01之Geometry转化为GeoJSON——渲染点
    RabbitMQ--延迟队列--使用/原理
    (附源码)spring boot教学管理平台 毕业设计 281454
    【算法与数据结构】--高级算法和数据结构--排序和搜索
    【网络通信 -- WebRTC】项目实战记录 -- mediasoup android 适配 webrtc m94
    企业网络会议室解决方案-VIP会议室解决方案
  • 原文地址:https://blog.csdn.net/T20151470/article/details/134255417