• Reinforcement Learning: The REINFORCE Policy Gradient Algorithm, Using the CartPole Environment as an Example


    The complete code is as follows:

    # REINFORCE on CartPole-v1 (uses the gym >= 0.26 API: reset() returns (obs, info), step() returns 5 values)
    import gym
    import numpy as np
    import torch
    import matplotlib.pyplot as plt
    from tqdm import tqdm

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    def moving_average(a, window_size):
        # Smooth the return curve; the edges use a growing/shrinking window.
        cumulative_sum = np.cumsum(np.insert(a, 0, 0))
        middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
        r = np.arange(1, window_size - 1, 2)
        begin = np.cumsum(a[:window_size - 1])[::2] / r
        end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
        return np.concatenate((begin, middle, end))

    class PolicyNetwork(torch.nn.Module):
        def __init__(self, statedim, hiddendim, actiondim):
            super(PolicyNetwork, self).__init__()
            self.cf1 = torch.nn.Linear(statedim, hiddendim)
            self.cf2 = torch.nn.Linear(hiddendim, actiondim)

        def forward(self, x):
            x = torch.nn.functional.relu(self.cf1(x))
            return torch.nn.functional.softmax(self.cf2(x), dim=1)

    class REINFORCE:
        def __init__(self, statedim, hiddendim, actiondim, learningrate, gamma, device):
            self.policynet = PolicyNetwork(statedim, hiddendim, actiondim).to(device)
            self.gamma = gamma
            self.device = device
            self.optimizer = torch.optim.Adam(self.policynet.parameters(), lr=learningrate)

        def takeaction(self, state):
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            probs = self.policynet(state)
            # torch.distributions.Categorical represents a categorical distribution;
            # actiondist.sample() draws one action index from it (mini-example after the listing).
            actiondist = torch.distributions.Categorical(probs)
            action = actiondist.sample()
            return action.item()

        def update(self, transitiondist):
            statelist = transitiondist['states']
            rewardlist = transitiondist['rewards']
            actionlist = transitiondist['actions']
            G = 0
            self.optimizer.zero_grad()
            for i in reversed(range(len(rewardlist))):  # iterate from the last step backwards
                reward = rewardlist[i]
                state = statelist[i]
                action = actionlist[i]
                state = torch.tensor([state], dtype=torch.float).to(self.device)
                action = torch.tensor([action]).view(-1, 1).to(self.device)
                # .gather(1, action) picks the probability of the taken action along dim 1
                # (the action dimension); see the mini-example after the listing.
                logprob = torch.log(self.policynet(state).gather(1, action))
                G = self.gamma * G + reward  # discounted return G_t, accumulated backwards
                loss = -logprob * G  # per-step loss (see the note after the listing)
                loss.backward()  # backpropagate; per-step gradients accumulate
            self.optimizer.step()  # one gradient-descent update per episode

    learningrate = 4e-3
    episodesnum = 1000
    hiddendim = 128
    gamma = 0.99
    pbarnum = 10
    printreturnnum = 10
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    env = gym.make('CartPole-v1')
    env.reset(seed=880)
    torch.manual_seed(880)
    statedim = env.observation_space.shape[0]
    actiondim = env.action_space.n
    agent = REINFORCE(statedim=statedim, hiddendim=hiddendim, actiondim=actiondim,
                      learningrate=learningrate, gamma=gamma, device=device)

    returnlist = []
    for k in range(pbarnum):
        with tqdm(total=int(episodesnum / pbarnum), desc='Iteration %d' % k) as pbar:
            for episode in range(int(episodesnum / pbarnum)):
                g = 0
                transitiondist = {'states': [], 'actions': [], 'nextstates': [], 'rewards': []}
                state, _ = env.reset(seed=880)
                done = False
                while not done:
                    action = agent.takeaction(state)
                    nextstate, reward, done, truncated, _ = env.step(action)
                    done = done or truncated
                    transitiondist['states'].append(state)
                    transitiondist['actions'].append(action)
                    transitiondist['nextstates'].append(nextstate)
                    transitiondist['rewards'].append(reward)
                    state = nextstate
                    g = g + reward
                returnlist.append(g)
                agent.update(transitiondist)
                if (episode + 1) % printreturnnum == 0:
                    pbar.set_postfix({'Episode': '%d' % (episodesnum // pbarnum * k + episode + 1),
                                      'Return': '%.3f' % np.mean(returnlist[-printreturnnum:])})
                pbar.update(1)

    episodelist = list(range(len(returnlist)))
    plt.plot(episodelist, returnlist)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('REINFORCE on {}'.format(env.spec.name))
    plt.show()

    mvreturn = moving_average(returnlist, 9)
    plt.plot(episodelist, mvreturn)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('REINFORCE on {}'.format(env.spec.name))
    plt.show()
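
    A note on the update rule: the per-step loss loss = -logprob * G is the standard Monte-Carlo policy-gradient (REINFORCE) estimator. With the discounted return accumulated backwards from the end of the episode,

        G_t = \sum_{k=t}^{T} \gamma^{\,k-t} r_k, \qquad \nabla_\theta J(\theta) \approx \sum_{t=0}^{T} G_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)

    so minimizing L(\theta) = -\sum_t G_t \log \pi_\theta(a_t \mid s_t) by gradient descent follows this gradient estimate. Because loss.backward() is called once per time step inside the loop, the per-step gradients accumulate, and optimizer.step() then applies a single parameter update per episode.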
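
    As a quick, self-contained illustration of the two PyTorch calls commented in takeaction and update (a small sketch of my own, not part of the original script): torch.distributions.Categorical builds a categorical distribution from the softmax output and sample() draws an action index, while gather(1, index) picks out the probability of that action along the action dimension.

        import torch

        probs = torch.tensor([[0.1, 0.9]])             # softmax output of the policy for one state
        dist = torch.distributions.Categorical(probs)  # categorical distribution over 2 actions
        action = dist.sample()                         # sample an action index (usually 1 here)
        print(action.item())

        index = action.view(-1, 1)                     # shape (1, 1), as in update()
        print(probs.gather(1, index))                  # probability of the sampled action, e.g. tensor([[0.9000]])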

    Results: (the episode-return curve and its moving-average curve produced by the plotting code above)

  • Original article: https://blog.csdn.net/m0_56497861/article/details/141094744