• 用深度强化学习来玩Flappy Bird


    目录

    演示视频

    核心代码


    演示视频

    深度强化学习来玩Flappy Bird

    核心代码（依次为 src/deep_q_network.py、src/flappy_bird.py、测试脚本和训练脚本；训练脚本的清单省略了导入语句）

    1. import torch.nn as nn
    class DeepQNetwork(nn.Module):
        """Convolutional Q-network: 4 stacked input channels in, one Q-value per action (2) out.

        The ``7 * 7 * 64`` input size of ``fc1`` matches 84x84 inputs:
        84 -> 20 (8x8 conv, stride 4) -> 9 (4x4 conv, stride 2) -> 7 (3x3 conv, stride 1).
        """

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Sequential(
                nn.Conv2d(4, 32, kernel_size=8, stride=4),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(32, 64, kernel_size=4, stride=2),
                nn.ReLU(inplace=True),
            )
            self.conv3 = nn.Sequential(
                nn.Conv2d(64, 64, kernel_size=3, stride=1),
                nn.ReLU(inplace=True),
            )
            self.fc1 = nn.Sequential(
                nn.Linear(7 * 7 * 64, 512),
                nn.ReLU(inplace=True),
            )
            self.fc2 = nn.Linear(512, 2)
            self._create_weights()

        def _create_weights(self):
            """Re-initialise every Conv2d/Linear layer: weights ~ U(-0.01, 0.01), biases = 0."""
            trainable_kinds = (nn.Conv2d, nn.Linear)
            for layer in self.modules():
                if not isinstance(layer, trainable_kinds):
                    continue
                nn.init.uniform_(layer.weight, -0.01, 0.01)
                nn.init.constant_(layer.bias, 0)

        def forward(self, input):
            """Return Q-values of shape ``(batch, 2)`` for a ``(batch, 4, H, W)`` input."""
            hidden = input
            for conv_block in (self.conv1, self.conv2, self.conv3):
                hidden = conv_block(hidden)
            # Flatten everything except the batch dimension.
            hidden = hidden.view(hidden.size(0), -1)
            return self.fc2(self.fc1(hidden))
    1. from itertools import cycle
    2. from numpy.random import randint
    3. from pygame import Rect, init, time, display
    4. from pygame.event import pump
    5. from pygame.image import load
    6. from pygame.surfarray import array3d, pixels_alpha
    7. from pygame.transform import rotate
    8. import numpy as np
    class FlappyBird(object):
        """Flappy Bird game rendered with pygame and stepped one frame at a time.

        Each call to ``next_frame(action)`` applies the action, advances the bird
        and the pipes, renders the screen and returns ``(image, reward, terminal)``.

        NOTE: the class body runs when the module is imported: it initialises
        pygame, opens a 288x512 window and loads sprites from ``assets/sprites``
        (a path relative to the current working directory).
        """

        init()
        fps_clock = time.Clock()
        screen_width = 288
        screen_height = 512
        screen = display.set_mode((screen_width, screen_height))
        display.set_caption('Deep Q-Network Flappy Bird')
        base_image = load('assets/sprites/base.png').convert_alpha()
        background_image = load('assets/sprites/background-black.png').convert()
        # Index 0: upper pipe (sprite rotated by 180 degrees); index 1: lower pipe.
        pipe_images = [rotate(load('assets/sprites/pipe-green.png').convert_alpha(), 180),
                       load('assets/sprites/pipe-green.png').convert_alpha()]
        bird_images = [load('assets/sprites/redbird-upflap.png').convert_alpha(),
                       load('assets/sprites/redbird-midflap.png').convert_alpha(),
                       load('assets/sprites/redbird-downflap.png').convert_alpha()]
        # Per-pixel collision masks (True where the sprite's alpha is non-zero),
        # sliced as [x, y] in is_collided.
        bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_images]
        pipe_hitmask = [pixels_alpha(image).astype(bool) for image in pipe_images]
        fps = 30
        pipe_gap_size = 100
        pipe_velocity_x = -4
        # Bird dynamics, in pixels per frame; y grows downwards (0 is the top edge).
        # NOTE: min_velocity_y is not referenced by the game logic below.
        min_velocity_y = -8
        max_velocity_y = 10
        downward_speed = 1
        upward_speed = -9
        # Flap animation order over bird_images: up, mid, down, mid, ...
        bird_index_generator = cycle([0, 1, 2, 1])

        def __init__(self):
            """Start a new episode: reset the score, the bird and the two initial pipes."""
            self.iter = self.bird_index = self.score = 0
            self.bird_width = self.bird_images[0].get_width()
            self.bird_height = self.bird_images[0].get_height()
            self.pipe_width = self.pipe_images[0].get_width()
            self.pipe_height = self.pipe_images[0].get_height()
            self.bird_x = int(self.screen_width / 5)
            self.bird_y = int((self.screen_height - self.bird_height) / 2)
            self.base_x = 0
            # Top edge of the ground; touching it ends the episode.
            self.base_y = self.screen_height * 0.79
            # How far the ground sprite can scroll left before it wraps around.
            self.base_shift = self.base_image.get_width() - self.background_image.get_width()
            pipes = [self.generate_pipe(), self.generate_pipe()]
            # First pipe starts at the right screen edge, the second half a screen further.
            pipes[0]["x_upper"] = pipes[0]["x_lower"] = self.screen_width
            pipes[1]["x_upper"] = pipes[1]["x_lower"] = self.screen_width * 1.5
            self.pipes = pipes
            self.current_velocity_y = 0
            self.is_flapped = False

        def generate_pipe(self):
            """Return a new upper/lower pipe pair placed just beyond the right screen edge.

            The top of the gap is ``int(base_y / 5)`` plus a random 20..90 px offset in
            steps of 10 (numpy's ``randint`` excludes its upper bound); the gap is
            ``pipe_gap_size`` pixels tall.
            """
            x = self.screen_width + 10
            gap_y = randint(2, 10) * 10 + int(self.base_y / 5)
            return {"x_upper": x, "y_upper": gap_y - self.pipe_height,
                    "x_lower": x, "y_lower": gap_y + self.pipe_gap_size}

        def is_collided(self):
            """Return True if the bird touches the ground or (pixel-exactly) any pipe."""
            if self.bird_height + self.bird_y + 1 >= self.base_y:
                return True
            bird_bbox = Rect(self.bird_x, self.bird_y, self.bird_width, self.bird_height)
            bird_mask = self.bird_hitmask[self.bird_index]
            for pipe in self.pipes:
                pipe_boxes = (Rect(pipe["x_upper"], pipe["y_upper"], self.pipe_width, self.pipe_height),
                              Rect(pipe["x_lower"], pipe["y_lower"], self.pipe_width, self.pipe_height))
                # Check every pipe whose bounding box overlaps the bird, not only the first pair.
                for pipe_bbox, pipe_mask in zip(pipe_boxes, self.pipe_hitmask):
                    if not bird_bbox.colliderect(pipe_bbox):
                        continue
                    overlap = bird_bbox.clip(pipe_bbox)
                    # Offsets of the overlap rectangle inside each sprite's own mask.
                    bird_dx, bird_dy = overlap.x - bird_bbox.x, overlap.y - bird_bbox.y
                    pipe_dx, pipe_dy = overlap.x - pipe_bbox.x, overlap.y - pipe_bbox.y
                    if np.any(bird_mask[bird_dx:bird_dx + overlap.width, bird_dy:bird_dy + overlap.height]
                              & pipe_mask[pipe_dx:pipe_dx + overlap.width, pipe_dy:pipe_dy + overlap.height]):
                        return True
            return False

        def next_frame(self, action):
            """Advance the game by one frame.

            Args:
                action: 1 to flap; any other value lets the bird fall.

            Returns:
                ``(image, reward, terminal)``: the rendered screen from ``array3d``;
                reward 0.1 for surviving, 1 for passing a pipe, -1 on collision; and
                whether the bird collided (the game is then reset before rendering).
            """
            pump()
            reward = 0.1
            terminal = False
            if action == 1:
                self.current_velocity_y = self.upward_speed
                self.is_flapped = True

            # Score once the bird's centre has just passed a pipe's centre.
            bird_center_x = self.bird_x + self.bird_width / 2
            for pipe in self.pipes:
                pipe_center_x = pipe["x_upper"] + self.pipe_width / 2
                if pipe_center_x < bird_center_x < pipe_center_x + 5:
                    self.score += 1
                    reward = 1
                    break

            # Flap animation: advance the sprite every third frame. The counter must be
            # incremented, otherwise the branch can never run.
            self.iter += 1
            if self.iter % 3 == 0:
                self.bird_index = next(self.bird_index_generator)
                self.iter = 0
            self.base_x = -((-self.base_x + 100) % self.base_shift)

            # Gravity applies only on frames without a flap.
            if self.current_velocity_y < self.max_velocity_y and not self.is_flapped:
                self.current_velocity_y += self.downward_speed
            self.is_flapped = False
            # Move by the velocity, but never past the top of the ground (base_y).
            self.bird_y += min(self.current_velocity_y, self.base_y - self.bird_y - self.bird_height)
            if self.bird_y < 0:
                self.bird_y = 0

            for pipe in self.pipes:
                pipe["x_upper"] += self.pipe_velocity_x
                pipe["x_lower"] += self.pipe_velocity_x
            # Spawn a pipe when the first one nears the left edge; drop it once off-screen.
            if 0 < self.pipes[0]["x_lower"] < 5:
                self.pipes.append(self.generate_pipe())
            if self.pipes[0]["x_lower"] < -self.pipe_width:
                del self.pipes[0]

            if self.is_collided():
                terminal = True
                reward = -1
                # Reset right away, so the frame returned below shows the new episode.
                self.__init__()

            self.screen.blit(self.background_image, (0, 0))
            self.screen.blit(self.base_image, (self.base_x, self.base_y))
            self.screen.blit(self.bird_images[self.bird_index], (self.bird_x, self.bird_y))
            for pipe in self.pipes:
                self.screen.blit(self.pipe_images[0], (pipe["x_upper"], pipe["y_upper"]))
                self.screen.blit(self.pipe_images[1], (pipe["x_lower"], pipe["y_lower"]))
            image = array3d(display.get_surface())
            display.update()
            self.fps_clock.tick(self.fps)
            return image, reward, terminal
    1. import argparse
    2. import torch
    3. from src.deep_q_network import DeepQNetwork
    4. from src.flappy_bird import FlappyBird
    5. from src.utils import pre_processing
    def get_args():
        """Parse the command-line options for playing with a trained model.

        Returns:
            argparse.Namespace with ``image_size`` (side length of the square
            network input, default 84) and ``saved_path`` (directory holding the
            ``flappy_bird`` model file, default ``trained_models``).
        """
        # The first positional argument of ArgumentParser is ``prog``; pass the text
        # as ``description`` so the usage line does not show it as the program name.
        parser = argparse.ArgumentParser(
            description="""Implementation of Deep Q Network to play Flappy Bird""")
        parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
        parser.add_argument("--saved_path", type=str, default="trained_models",
                            help="Directory containing the trained model")
        args = parser.parse_args()
        return args
    def q_test(opt):
        """Play Flappy Bird indefinitely with a trained model, always taking the greedy action.

        Args:
            opt: parsed options; reads ``opt.saved_path`` (directory holding the
                ``flappy_bird`` model file) and ``opt.image_size``.
        """
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.manual_seed(123)
        else:
            torch.manual_seed(123)
        # SECURITY: torch.load unpickles the stored model object, so only load
        # trusted files. NOTE(review): torch >= 2.6 defaults to weights_only=True,
        # which rejects whole pickled modules; weights_only=False may be needed there.
        model_file = "{}/flappy_bird".format(opt.saved_path)
        if use_cuda:
            model = torch.load(model_file)
        else:
            model = torch.load(model_file, map_location=lambda storage, loc: storage)
        model.eval()
        game_state = FlappyBird()

        def to_frame(raw_image):
            # Keep only the area above the ground (base_y), preprocess it to the
            # network size and move it to the GPU when one is used.
            frame = pre_processing(raw_image[:game_state.screen_width, :int(game_state.base_y)],
                                   opt.image_size, opt.image_size)
            frame = torch.from_numpy(frame)
            return frame.cuda() if use_cuda else frame

        image, _, _ = game_state.next_frame(0)
        if use_cuda:
            model.cuda()
        image = to_frame(image)
        # The first state repeats the first frame 4 times along dim 0.
        state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            while True:
                prediction = model(state)[0]
                # A plain int keeps FlappyBird.next_frame's `action == 1` test simple.
                action = torch.argmax(prediction).item()
                next_image, _, _ = game_state.next_frame(action)
                # Slide the window: drop the oldest frame, append the newest.
                state = torch.cat((state[0, 1:, :, :], to_frame(next_image)))[None, :, :, :]
    if __name__ == "__main__":
        # Play with the trained model using options taken from the command line.
        q_test(get_args())
    def get_args():
        """Parse the command-line options for DQN training.

        Returns:
            argparse.Namespace with image_size, batch_size, optimizer ("sgd" or
            "adam"), lr, gamma, initial_epsilon, final_epsilon, num_iters,
            replay_memory_size, log_path and saved_path.
        """
        # The first positional argument of ArgumentParser is ``prog``; pass the text
        # as ``description`` so the usage line does not show it as the program name.
        parser = argparse.ArgumentParser(
            description="""Implementation of Deep Q Network to play Flappy Bird""")
        parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
        parser.add_argument("--batch_size", type=int, default=32, help="The number of images per batch")
        parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
        parser.add_argument("--lr", type=float, default=1e-6)
        parser.add_argument("--gamma", type=float, default=0.99)
        parser.add_argument("--initial_epsilon", type=float, default=0.1)
        parser.add_argument("--final_epsilon", type=float, default=1e-4)
        parser.add_argument("--num_iters", type=int, default=2000000)
        parser.add_argument("--replay_memory_size", type=int, default=50000,
                            help="Maximum number of transitions kept in the replay memory")
        parser.add_argument("--log_path", type=str, default="tensorboard")
        parser.add_argument("--saved_path", type=str, default="trained_models")
        args = parser.parse_args()
        return args
    def train(opt):
        """Train a DeepQNetwork on Flappy Bird with epsilon-greedy DQN and experience replay.

        Args:
            opt: parsed options; reads image_size, batch_size, optimizer, lr, gamma,
                initial_epsilon, final_epsilon, num_iters, replay_memory_size,
                log_path and saved_path.

        Side effects:
            Deletes and recreates ``opt.log_path`` for the summary writer, saves
            ``flappy_bird_<iteration>`` every 1,000,000 iterations and a final
            ``flappy_bird`` model into ``opt.saved_path``.
        """
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.manual_seed(123)
        else:
            torch.manual_seed(123)
        model = DeepQNetwork()
        if os.path.isdir(opt.log_path):
            shutil.rmtree(opt.log_path)
        os.makedirs(opt.log_path)
        # torch.save does not create missing directories.
        os.makedirs(opt.saved_path, exist_ok=True)
        writer = SummaryWriter(opt.log_path)
        # Honour --optimizer (previously Adam was always used).
        if getattr(opt, "optimizer", "adam") == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
        criterion = nn.MSELoss()
        game_state = FlappyBird()

        def to_frame(raw_image):
            # Keep only the area above the ground (base_y), preprocess it to the
            # network size and move it to the GPU when one is used.
            frame = pre_processing(raw_image[:game_state.screen_width, :int(game_state.base_y)],
                                   opt.image_size, opt.image_size)
            frame = torch.from_numpy(frame)
            return frame.cuda() if use_cuda else frame

        image, reward, terminal = game_state.next_frame(0)
        if use_cuda:
            model.cuda()
        image = to_frame(image)
        # The first state repeats the first frame 4 times along dim 0.
        state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
        replay_memory = []
        iteration = 0
        while iteration < opt.num_iters:
            # Acting only needs Q-values, not gradients.
            with torch.no_grad():
                prediction = model(state)[0]
            # Exploration or exploitation: epsilon decays linearly from
            # initial_epsilon towards final_epsilon over num_iters iterations.
            epsilon = opt.final_epsilon + (
                (opt.num_iters - iteration) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_iters)
            if random() <= epsilon:
                print("Perform a random action")
                # NOTE(review): assumes randint is random.randint (inclusive bounds); confirm imports.
                action = randint(0, 1)
            else:
                action = torch.argmax(prediction).item()
            next_image, reward, terminal = game_state.next_frame(action)
            # Slide the window: drop the oldest frame, append the newest.
            next_state = torch.cat((state[0, 1:, :, :], to_frame(next_image)))[None, :, :, :]
            replay_memory.append([state, action, reward, next_state, terminal])
            if len(replay_memory) > opt.replay_memory_size:
                del replay_memory[0]

            batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
            state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)
            state_batch = torch.cat(state_batch)
            next_state_batch = torch.cat(next_state_batch)
            # One-hot actions, used to select Q(s, a) from the network output.
            action_batch = torch.from_numpy(
                np.array([[1, 0] if a == 0 else [0, 1] for a in action_batch], dtype=np.float32))
            reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32))
            terminal_batch = torch.tensor(terminal_batch, dtype=torch.bool)
            if use_cuda:
                state_batch = state_batch.cuda()
                action_batch = action_batch.cuda()
                reward_batch = reward_batch.cuda()
                next_state_batch = next_state_batch.cuda()
                terminal_batch = terminal_batch.cuda()
            current_prediction_batch = model(state_batch)
            # The TD target is treated as a constant: no gradient flows through
            # max_a' Q(s', a') (previously it was not detached).
            with torch.no_grad():
                next_q = model(next_state_batch).max(dim=1)[0]
            y_batch = torch.where(terminal_batch, reward_batch, reward_batch + opt.gamma * next_q)
            q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
            optimizer.zero_grad()
            loss = criterion(q_value, y_batch)
            loss.backward()
            optimizer.step()
            state = next_state
            iteration += 1

            q_max = torch.max(prediction).item()
            # `iteration` already counts the step just completed (was printed as iter + 1).
            print("Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(
                iteration,
                opt.num_iters,
                action,
                loss.item(),
                epsilon, reward, q_max))
            writer.add_scalar('Train/Loss', loss.item(), iteration)
            writer.add_scalar('Train/Epsilon', epsilon, iteration)
            writer.add_scalar('Train/Reward', reward, iteration)
            writer.add_scalar('Train/Q-value', q_max, iteration)
            if iteration % 1000000 == 0:
                torch.save(model, "{}/flappy_bird_{}".format(opt.saved_path, iteration))
        torch.save(model, "{}/flappy_bird".format(opt.saved_path))
        writer.close()
    if __name__ == "__main__":
        # Train with options taken from the command line.
        train(get_args())

  • 相关阅读:
    Linux 驱动开发 六十六:多点触控(MT)协议
    对SPA的理解、对 vue组件化的理解
    (数据科学学习手札137)orjson:Python中最好用的json库
    mac openssl 版本到底怎么回事 已解决
    公钥加密如何确保数据的完整性
    一座 “数智桥梁”,华为助力“天堑变通途”
    数据分析--数据预处理
    DETRs Beat YOLOs on Real-time Object Detection
    如何使用IDE端通义灵码
    sql:SQL优化知识点记录(十二)
  • 原文地址:https://blog.csdn.net/timberman666/article/details/132590406