• PyTorch实现Logistic回归对多元高斯分布进行分类实战(附源码)


    需要源码请点赞关注收藏后评论区留言~~~

    Logistic常用于解决二分类问题,为了便于描述,我们分别从两个多元高斯分布中生成数据X1,X2.这两个多元高斯分布分别表示两个类别,分别设置其标签为y1,y2.

    注意 后面要打乱样本和标签的顺序,将数据重新随机排列是十分重要的步骤,否则算法的每次迭代只会学习到同一个类别的信息,容易造成模型过拟合

    优化算法

    Logistic回归通常采用梯度下降法优化目标函数,PyTorch的torch.optim包实现了大多数常用的优化算法,使用起来非常简单,首先构建一个优化器,在构建时,首先需要将待学习的参数传入,然后传入优化器需要的参数,比如学习率等等

    构造完优化器,就可以迭代的对模型进行训练,有两个步骤,其一是调用损失函数的backward()方法计算模型的梯度,然后再调用优化器的step()方法更新模型的参数,需要注意的是,首先应当调用优化器的zero_grad()方法清空参数的梯度

    效果如下 

    可以明显的看出多元高斯分布生成的样本聚成了两个簇,并且簇的中心分别处于不同的位置,右上方簇的样本分别更加稀疏,而左下方的样本分别紧凑,读者可以自行调整代码中第5-6行的参数 观察其变化

     部分源码如下

    1. import self as self
    2. import torch
    3. from cv2.ml import LogisticRegression
    4. from torch import nn
    5. from matplotlib import pyplot as plt
    6. import numpy as np
    7. from torch.distributions import MultivariateNormal
    8. mu1=-3*torch.ones(2)
    9. mu2=3*torch.ones(2)
    10. sigma1=torch.eye(2)*0.5
    11. sigma2=torch.eye(2)*2
    12. x1=m1.sample((100,))
    13. x2=m2.sample((100,))
    14. y=torch.zeros((200,1))
    15. y[100:]=1
    16. x=torch.cat([x1,x2],dim=0)
    17. idx=np.random.permutation(len(x))
    18. x=x[idx]
    19. y=y[idx]
    20. plt.scatter(x1.numpy()[:,0],x1.numpy()[:,1])
    21. plt.scatter(x2.numpy()[:,0],x2.numpy()[:,1])
    22. plt.show()
    23. D_in,D_out=2,1
    24. linear=nn.Linear(D_in,D_out,bias=True)
    25. output=linear(x)
    26. print(x.shape,linear.weight.shape,linear.bias.shape,output.shape)
    27. def my_linear(x,w,b):
    28. return torch.mm(x,w.t())+b
    29. print(torch.sum((output-my_linear(x,linear.weight,linear.bias))))
    30. sigmoid=nn.Sigmoid()
    31. scores=sigmoid(output)
    32. def my_sigmoid(x):
    33. x=1/(1+torch.exp(-x))
    34. return x
    35. loss=nn.BCELoss()
    36. loss(sigmoid(output),y)
    37. def my_loss(x,y):
    38. loss=-torch.mean(torch.log(x)*y+torch.log(1-x)*(1-y))
    39. return loss
    40. from torch import optim
    41. import torch.nn as nn
    42. class LogisticRegression(nn.Module):
    43. super(LogisticRegression,self).__init__()
    44. self.linear=nn.Linear()
    45. optimizer=optim.SGD(lr=0.03)
    46. batch_size=10
    47. iters=10
    48. for _ in range(iters):
    49. for i in range(int(len(x)/batch_size)):
    50. input=x[i*batch_size:(i+1)*batch_size]
    51. target=y[i*batch_size:(i+1)*batch_size]
    52. optimizer.zero_grad()
    53. output=lr_model(input)

  • 相关阅读:
    小杨哥陷入打假风波,会变成下一个辛巴吗?
    算力被“卡脖子”,光子时代“换道超车”
    【AGC】使用云调试优惠扣费、华为设备上触发崩溃、无法下载华为应用市场问题小结
    hive怎么设置元数据库为mysql
    关于Transfomer的思考
    【面试题 - mysql】进阶篇 - 索引
    51单片机项目(13)——基于51单片机的智能台灯protues仿真
    【每日一题】实现 Trie (前缀树)
    数学建模学习(101):车辆路线规划问题
    STC8H开发(十五): GPIO驱动Ci24R1无线模块
  • 原文地址:https://blog.csdn.net/jiebaoshayebuhui/article/details/127777571