• Deep Reinforcement Learning: The TD3 Algorithm


    Paper: https://arxiv.org/pdf/1802.09477.pdf

            TD3 (Twin Delayed Deep Deterministic policy gradient) is designed for high-dimensional continuous action spaces. It is an improved version of DDPG, built to address DDPG's tendency to overestimate Q-values during training.

    Improvements over DDPG:

    1. Twin Critic networks. Two Critics estimate the action-value function, and the smaller of the two Q-values is used as the target estimate during training (clipped double Q-learning), which keeps overestimation error from accumulating.

    2. Delayed policy updates. The Critic is updated more frequently than the Actor (similar in spirit to GANs: only a well-trained Critic can give the Actor meaningful feedback).

    3. Clipping. The smoothing noise added to the target action is clipped to a fixed range, and the resulting target action is clamped to the valid action bounds.

    4. Training noise (target policy smoothing). Random noise is added to the target action when the Critic is updated, which smooths the value estimate over nearby actions and makes the Critic less sensitive to small fluctuations (see the target formula after this list).
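    Putting points 1, 3 and 4 together, the Critic's regression target is (notation: μ′ is the target actor, Q′₁ and Q′₂ the target critics, σ the smoothing-noise scale, c the noise clip range, a_max the action bound):

        y = r + γ · min( Q′₁(s′, ã), Q′₂(s′, ã) ),   where  ã = clip( μ′(s′) + ε, −a_max, a_max ),  ε ~ clip( N(0, σ), −c, c )

    This is exactly the target computed in TD3.train() below, with σ = 0.2 · max_action and c = 0.5 · max_action.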

    Algorithm flow:

            (Figure: pseudocode of the TD3 algorithm)

    Code implementation:

            actor:

    import copy

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class Actor(nn.Module):
        def __init__(self, state_dim, action_dim, net_width, maxaction):
            super(Actor, self).__init__()
            self.l1 = nn.Linear(state_dim, net_width)
            self.l2 = nn.Linear(net_width, net_width)
            self.l3 = nn.Linear(net_width, action_dim)
            self.maxaction = maxaction

        def forward(self, state):
            a = torch.tanh(self.l1(state))
            a = torch.tanh(self.l2(a))
            # the final tanh bounds the output to [-1, 1]; scaling maps it to the action range
            a = torch.tanh(self.l3(a)) * self.maxaction
            return a
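    As a quick sanity check (the dimensions and the action bound below are made up purely for illustration), the Actor maps a batch of states to actions bounded by ±maxaction:

    actor = Actor(state_dim=3, action_dim=1, net_width=128, maxaction=2.0)  # hypothetical sizes
    s = torch.randn(5, 3)        # batch of 5 illustrative states
    print(actor(s).shape)        # torch.Size([5, 1]); every entry lies in [-2.0, 2.0]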

             critic:

    class Q_Critic(nn.Module):
        def __init__(self, state_dim, action_dim, net_width):
            super(Q_Critic, self).__init__()
            # Q1 architecture
            self.l1 = nn.Linear(state_dim + action_dim, net_width)
            self.l2 = nn.Linear(net_width, net_width)
            self.l3 = nn.Linear(net_width, 1)
            # Q2 architecture
            self.l4 = nn.Linear(state_dim + action_dim, net_width)
            self.l5 = nn.Linear(net_width, net_width)
            self.l6 = nn.Linear(net_width, 1)

        def forward(self, state, action):
            # both critic heads take the concatenated (state, action) pair as input
            sa = torch.cat([state, action], 1)
            q1 = F.relu(self.l1(sa))
            q1 = F.relu(self.l2(q1))
            q1 = self.l3(q1)
            q2 = F.relu(self.l4(sa))
            q2 = F.relu(self.l5(q2))
            q2 = self.l6(q2)
            return q1, q2

        def Q1(self, state, action):
            # Q1 alone is used for the actor (policy) loss
            sa = torch.cat([state, action], 1)
            q1 = F.relu(self.l1(sa))
            q1 = F.relu(self.l2(q1))
            q1 = self.l3(q1)
            return q1
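    The two entry points are used differently during training: forward() feeds both critic heads for the TD loss, while Q1() is what the actor loss calls. A small illustration (shapes again hypothetical):

    critic = Q_Critic(state_dim=3, action_dim=1, net_width=128)   # hypothetical sizes
    s, a = torch.randn(5, 3), torch.randn(5, 1)
    q1, q2 = critic(s, a)        # two independent value estimates, each of shape [5, 1]
    q1_only = critic.Q1(s, a)    # single head, used when updating the actor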

             The full TD3 implementation:

    class TD3(object):
        def __init__(
            self,
            env_with_Dead,
            state_dim,
            action_dim,
            max_action,
            gamma=0.99,
            net_width=128,
            a_lr=1e-4,
            c_lr=1e-4,
            Q_batchsize=256
        ):
            self.actor = Actor(state_dim, action_dim, net_width, max_action).to(device)
            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=a_lr)
            self.actor_target = copy.deepcopy(self.actor)

            self.q_critic = Q_Critic(state_dim, action_dim, net_width).to(device)
            self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=c_lr)
            self.q_critic_target = copy.deepcopy(self.q_critic)

            self.env_with_Dead = env_with_Dead
            self.action_dim = action_dim
            self.max_action = max_action
            self.gamma = gamma
            self.policy_noise = 0.2 * max_action  # std of the target-smoothing noise
            self.noise_clip = 0.5 * max_action    # clipping range of that noise
            self.tau = 0.005                      # soft-update rate for the target networks
            self.Q_batchsize = Q_batchsize
            self.delay_counter = -1
            self.delay_freq = 1                   # actor/targets update once per (delay_freq + 1) critic updates

        def select_action(self, state):  # only used when interacting with the env
            with torch.no_grad():
                state = torch.FloatTensor(state.reshape(1, -1)).to(device)
                a = self.actor(state)
            return a.cpu().numpy().flatten()

        def train(self, replay_buffer):
            self.delay_counter += 1
            with torch.no_grad():
                s, a, r, s_prime, dead_mask = replay_buffer.sample(self.Q_batchsize)
                # Target policy smoothing: add clipped noise to the target action
                noise = (torch.randn_like(a) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
                smoothed_target_a = (
                    self.actor_target(s_prime) + noise
                ).clamp(-self.max_action, self.max_action)

                # Compute the target Q value with clipped double Q-learning
                target_Q1, target_Q2 = self.q_critic_target(s_prime, smoothed_target_a)
                target_Q = torch.min(target_Q1, target_Q2)
                # DEAD OR NOT: mask out the bootstrap term at terminal states
                if self.env_with_Dead:
                    target_Q = r + (1 - dead_mask) * self.gamma * target_Q  # env with dead states
                else:
                    target_Q = r + self.gamma * target_Q  # env without dead states

            # Get current Q estimates
            current_Q1, current_Q2 = self.q_critic(s, a)

            # Compute critic loss (both heads regress onto the same target)
            q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

            # Optimize the q_critic
            self.q_critic_optimizer.zero_grad()
            q_loss.backward()
            self.q_critic_optimizer.step()

            if self.delay_counter == self.delay_freq:
                # Delayed actor update: maximize Q1 of the current critic
                a_loss = -self.q_critic.Q1(s, self.actor(s)).mean()
                self.actor_optimizer.zero_grad()
                a_loss.backward()
                self.actor_optimizer.step()

                # Soft-update the frozen target models
                for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

                self.delay_counter = -1

        def save(self, episode):
            torch.save(self.actor.state_dict(), "ppo_actor{}.pth".format(episode))
            torch.save(self.q_critic.state_dict(), "ppo_q_critic{}.pth".format(episode))

        def load(self, episode):
            self.actor.load_state_dict(torch.load("ppo_actor{}.pth".format(episode)))
            self.q_critic.load_state_dict(torch.load("ppo_q_critic{}.pth".format(episode)))
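    The post calls replay_buffer.sample(...) but never shows the buffer or the interaction loop. Below is a minimal sketch of both, assuming the Gymnasium API and the Pendulum-v1 environment; torch, device and the classes above come from the preceding blocks, while the ReplayBuffer class, the exploration-noise scale, the step count, and the warm-up threshold are illustrative choices and not part of the original code.

    import numpy as np
    import gymnasium as gym

    class ReplayBuffer(object):
        # Minimal FIFO buffer whose sample() returns (s, a, r, s', dead_mask), as TD3.train expects
        def __init__(self, state_dim, action_dim, max_size=int(1e6)):
            self.ptr, self.size, self.max_size = 0, 0, max_size
            self.s = np.zeros((max_size, state_dim), dtype=np.float32)
            self.a = np.zeros((max_size, action_dim), dtype=np.float32)
            self.r = np.zeros((max_size, 1), dtype=np.float32)
            self.s_prime = np.zeros((max_size, state_dim), dtype=np.float32)
            self.dead = np.zeros((max_size, 1), dtype=np.float32)

        def add(self, s, a, r, s_prime, dead):
            self.s[self.ptr], self.a[self.ptr], self.r[self.ptr] = s, a, r
            self.s_prime[self.ptr], self.dead[self.ptr] = s_prime, dead
            self.ptr = (self.ptr + 1) % self.max_size
            self.size = min(self.size + 1, self.max_size)

        def sample(self, batch_size):
            idx = np.random.randint(0, self.size, size=batch_size)
            to_tensor = lambda x: torch.FloatTensor(x[idx]).to(device)
            return (to_tensor(self.s), to_tensor(self.a), to_tensor(self.r),
                    to_tensor(self.s_prime), to_tensor(self.dead))

    env = gym.make("Pendulum-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    agent = TD3(env_with_Dead=False, state_dim=state_dim, action_dim=action_dim, max_action=max_action)
    buffer = ReplayBuffer(state_dim, action_dim)

    s, _ = env.reset()
    for t in range(50000):  # total environment steps (illustrative)
        # deterministic policy plus Gaussian exploration noise, clipped to the valid range
        a = agent.select_action(s)
        a = (a + np.random.normal(0, 0.1 * max_action, size=action_dim)).clip(-max_action, max_action)
        s_prime, r, terminated, truncated, _ = env.step(a)
        buffer.add(s, a, r, s_prime, float(terminated))
        s = s_prime if not (terminated or truncated) else env.reset()[0]

        if buffer.size >= 2000:  # warm up the buffer before learning (illustrative)
            agent.train(buffer)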

     Network structure diagram:

             In the diagram, the parameters of the actor and of the target networks are updated with a delay: critic1 and critic2 are updated at every training step, and only after the critics have been trained sufficiently can they meaningfully judge whether the actor's actions are good.
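             Concretely, each delayed target update in the code above uses Polyak averaging with τ = 0.005:

                 θ′ ← τ · θ + (1 − τ) · θ′

             so every target network is a slowly moving average of its online counterpart, which keeps the regression target stable.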

     

  • Original article: https://blog.csdn.net/athrunsunny/article/details/126653018