• 使用带KL散度惩罚的PPO(近端策略优化)玩CartPole游戏


     

    其实KL散度在这个游戏里的作用不大:CartPole的动作空间比较简单,不像语言模型(LM)里的action是一个很大的向量,因此可以直接最大化surr1,实验测试也确实如此。另外,KL的系数不能给得太大,否则惩罚力度过大;实际上action model和ref model产生的动作分布差距并不大。

     

    复制代码
    import gym
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    import pygame
    import sys
    from collections import deque
    
    # Policy network
    class PolicyNetwork(nn.Module):
        """Two-layer MLP that maps a state to a probability distribution over actions.

        Args:
            state_dim: Size of the input state vector (default 4, as used by
                the training script in this file).
            hidden_dim: Width of the single hidden layer (default 2).
            action_dim: Number of discrete actions (default 2).

        The defaults reproduce the original fixed architecture, and the layers
        live under ``self.fc`` so existing state-dict keys are unchanged.
        """

        def __init__(self, state_dim=4, hidden_dim=2, action_dim=2):
            super().__init__()
            self.fc = nn.Sequential(
                nn.Linear(state_dim, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, action_dim),
                # Softmax over the last axis so each row sums to 1.
                nn.Softmax(dim=-1),
            )

        def forward(self, x):
            """Return action probabilities for ``x`` (last axis = state_dim)."""
            return self.fc(x)
    
    # Value network
    class ValueNetwork(nn.Module):
        """Two-layer MLP that maps a state to a single scalar value estimate.

        Args:
            state_dim: Size of the input state vector (default 4, as used by
                the training script in this file).
            hidden_dim: Width of the single hidden layer (default 2).

        The defaults reproduce the original fixed architecture, and the layers
        live under ``self.fc`` so existing state-dict keys are unchanged.
        """

        def __init__(self, state_dim=4, hidden_dim=2):
            super().__init__()
            self.fc = nn.Sequential(
                nn.Linear(state_dim, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, 1),
            )

        def forward(self, x):
            """Return value estimates; the last axis of the output has size 1."""
            return self.fc(x)
    
    # Rollout buffer
    class RolloutBuffer:
        """Append-only storage for the transitions collected during a rollout.

        Five parallel lists are kept (states, actions, rewards, dones,
        log_probs); ``get_batch`` converts them to tensors in that order.
        """

        def __init__(self):
            # Single source of truth for the empty state.
            self.clear()

        def store(self, state, action, reward, done, log_prob):
            """Append one transition to the buffer."""
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.dones.append(done)
            self.log_probs.append(log_prob)

        def clear(self):
            """Drop every stored transition."""
            self.states = []
            self.actions = []
            self.rewards = []
            self.dones = []
            self.log_probs = []

        def get_batch(self):
            """Return ``(states, actions, rewards, dones, log_probs)`` as tensors.

            dtypes are float, long, float, bool and float respectively. The
            returned log-probabilities never carry autograd history.
            """
            # Stack the states with NumPy first: torch.tensor() on a list of
            # ndarrays goes through a slow element-by-element path.
            states = torch.tensor(np.asarray(self.states), dtype=torch.float)
            # Log-probs may have been stored as one-element tensors that are
            # still attached to the policy graph; reduce them to plain floats.
            log_probs = [
                lp.detach().item() if torch.is_tensor(lp) else float(lp)
                for lp in self.log_probs
            ]
            return (
                states,
                torch.tensor(self.actions, dtype=torch.long),
                torch.tensor(self.rewards, dtype=torch.float),
                torch.tensor(self.dones, dtype=torch.bool),
                torch.tensor(log_probs, dtype=torch.float),
            )
    
    # PPO update step
    def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer,
                   epochs=100, gamma=0.99, clip_param=0.2, kl_coef=1.0,
                   gae_lambda=1.0, use_clip=False):
        """Run ``epochs`` policy and value updates on the transitions in ``buffer``.

        Args:
            policy_net: Module mapping states to action probabilities.
            value_net: Module mapping states to value estimates (last axis 1).
            optimizer_policy: Optimizer over ``policy_net`` parameters.
            optimizer_value: Optimizer over ``value_net`` parameters.
            buffer: Object whose ``get_batch()`` returns
                ``(states, actions, rewards, dones, old_log_probs)`` tensors.
            epochs: Number of gradient steps taken on the same batch.
            gamma: Discount factor.
            clip_param: Clipping range for the ratio (only when ``use_clip``).
            kl_coef: Weight of the KL penalty (only when ``use_clip`` is false).
            gae_lambda: Lambda of the advantage recursion; 1.0 makes the
                advantage equal to the discounted return minus V(s).
            use_clip: False (default) optimises the KL-penalised surrogate;
                True optimises the clipped surrogate instead.

        Returns:
            None. The networks are updated in place through their optimizers.

        Transitions following the last stored one are treated as worth zero, so
        the buffer is expected to end on an episode boundary.
        """
        states, actions, rewards, dones, old_log_probs = buffer.get_batch()
        num_steps = rewards.shape[0]
        if num_steps == 0:
            return  # nothing to learn from

        old_log_probs = old_log_probs.detach()
        not_done = 1.0 - dones.to(torch.float)
        with torch.no_grad():
            # Flatten (N, 1) -> (N,) so values line up with rewards.
            values = value_net(states).squeeze(-1)

        returns = torch.zeros_like(rewards)
        advantages = torch.zeros_like(rewards)
        running_return = 0.0
        running_adv = 0.0
        next_value = 0.0
        # Walk the trajectory backwards; ``mask`` cuts every recursion at
        # episode ends so nothing leaks across episodes.
        for t in reversed(range(num_steps)):
            mask = not_done[t]
            # Discounted Monte Carlo return.
            running_return = rewards[t] + gamma * mask * running_return
            # TD error bootstraps from the value of the *next* state.
            delta = rewards[t] + gamma * mask * next_value - values[t]
            running_adv = delta + gamma * gae_lambda * mask * running_adv
            returns[t] = running_return
            advantages[t] = running_adv
            next_value = values[t]

        # Normalise advantages; std() of a single sample is NaN, so skip then.
        if num_steps > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(epochs):
            dist = torch.distributions.Categorical(policy_net(states))
            new_log_probs = dist.log_prob(actions)
            log_ratio = new_log_probs - old_log_probs
            ratio = log_ratio.exp()
            surr1 = ratio * advantages

            if use_clip:
                surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
                actor_loss = -torch.min(surr1, surr2).mean()
            else:
                # KL estimate p * log(p / p'), averaged over the batch; the
                # product is parenthesised so mean() covers the whole term.
                kl = (new_log_probs.exp() * log_ratio).mean()
                actor_loss = -(surr1.mean() - kl_coef * kl)

            optimizer_policy.zero_grad()
            actor_loss.backward()
            optimizer_policy.step()

            # squeeze(-1) keeps this an (N,) - (N,) difference; without it the
            # (N,) - (N, 1) subtraction broadcasts to an (N, N) matrix.
            value_loss = (returns - value_net(states).squeeze(-1)).pow(2).mean()

            optimizer_value.zero_grad()
            value_loss.backward()
            optimizer_value.step()
    
    # Set up the environment and the models
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    value_net = ValueNetwork()
    optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
    optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
    buffer = RolloutBuffer()

    # Pygame setup
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    # Rendering stays off until one episode lasts more than 2000 steps.
    draw_on = False
    # Training loop: one PPO update per collected episode
    for episode in range(10000):
        # reset() is unpacked as (observation, info) and step() as a 5-tuple,
        # i.e. the newer gym API is assumed throughout.
        state, _ = env.reset()
        done = False
        step = 0
        while not done:
            step += 1
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            # Rollout uses the current (pre-update) policy; no gradients are
            # needed here, only the sampled action and its log-probability.
            with torch.no_grad():
                action_probs = policy_net(state_tensor)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # NOTE(review): the 4th return value (truncated) is discarded, so
            # an episode only ends on termination; the `step > 2000` trigger
            # below appears to rely on this -- confirm before changing it.
            next_state, reward, done, _, _ = env.step(action.item())
            # Store a plain float so the buffer does not keep tensors around.
            buffer.store(state, action.item(), reward, done, log_prob.item())

            state = next_state

            # Keep the window responsive and allow a clean exit
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    env.close()
                    pygame.quit()
                    sys.exit()

            if draw_on:
                # Clear the screen and redraw the cart and the pole
                screen.fill((0, 0, 0))
                cart_x = int(state[0] * 100 + 300)  # cart position -> screen x
                pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))
                pole_tip = (
                    cart_x + 25 - int(50 * np.sin(state[2])),
                    300 - int(50 * np.cos(state[2])),
                )
                pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), pole_tip, 5)
                pygame.display.flip()
                clock.tick(60)

        if step > 2000:
            draw_on = True
        ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)
        buffer.clear()
        print(f'Episode {episode} completed , reward:  {step}.')

    # Training finished
    env.close()
    pygame.quit()
    复制代码

     

    效果:

     

  • 相关阅读:
    VMOS虚拟机开源,游戏安全面临新挑战
    印刷企业使用数字工厂管理系统前后有什么变化
    Ubuntu 22.04 编译 DPDK 19.11 igb_uio 和 kni 报错解决办法
    FastDFS-02-JavaAPI
    leetcode 50. Pow(x, n)
    数据分析------知识点(六)
    Lua脚本详解
    20.单例模式进阶
    关于使用命令行 cf login 登录 SAP BTP CloudFoundry 环境的问题
    ASON 技术简介
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/18194800