Table of Contents

Live Demo
Code Implementation
Live Demo

(Demo video: Playing Chrome Dino Run with Deep Reinforcement Learning)
Code Implementation

The implementation has four parts: the Chrome Dino game re-implemented in pygame and wrapped as an RL environment, the deep Q-network itself, a test script that plays and records a video with a trained model, and the training script.
# src/env.py: the Chrome Dino game re-implemented in pygame and wrapped as an
# RL environment (the module path follows the imports in the scripts below).
import os
from random import randrange, choice

import cv2
import numpy as np
import torch
from pygame import Rect, RLEACCEL, Surface, display, init, mixer, time, transform
from pygame.image import load
from pygame.sprite import Group, Sprite, collide_mask
from pygame.surfarray import array3d

mixer.pre_init(44100, -16, 2, 2048)
init()

scr_size = (width, height) = (600, 150)
FPS = 60
gravity = 0.6

black = (0, 0, 0)
white = (255, 255, 255)
background_col = (235, 235, 235)

high_score = 0

screen = display.set_mode(scr_size)
clock = time.Clock()
display.set_caption("T-Rex Rush")


def load_image(name, sizex=-1, sizey=-1, colorkey=None):
    # Load a single sprite; colorkey=-1 means "use the top-left pixel as the
    # transparent color".
    fullname = os.path.join("assets/sprites", name)
    image = load(fullname)
    image = image.convert()
    if colorkey is not None:
        if colorkey == -1:
            colorkey = image.get_at((0, 0))
        image.set_colorkey(colorkey, RLEACCEL)
    if sizex != -1 or sizey != -1:
        image = transform.scale(image, (sizex, sizey))
    return image, image.get_rect()


def load_sprite_sheet(sheetname, nx, ny, scalex=-1, scaley=-1, colorkey=None):
    # Slice a sprite sheet into nx * ny frames. nx may also be a list (one
    # entry per group of differently sized sprites, e.g. the cactus sheet),
    # in which case scalex must be a list of the same length.
    fullname = os.path.join("assets/sprites", sheetname)
    sheet = load(fullname)
    sheet = sheet.convert()
    sheet_rect = sheet.get_rect()
    sprites = []
    sizey = sheet_rect.height // ny
    if isinstance(nx, int):
        sizex = sheet_rect.width // nx
        for i in range(0, ny):
            for j in range(0, nx):
                rect = Rect((j * sizex, i * sizey, sizex, sizey))
                image = Surface(rect.size)
                image = image.convert()
                image.blit(sheet, (0, 0), rect)
                if colorkey is not None:
                    if colorkey == -1:
                        colorkey = image.get_at((0, 0))
                    image.set_colorkey(colorkey, RLEACCEL)
                if scalex != -1 or scaley != -1:
                    image = transform.scale(image, (scalex, scaley))
                sprites.append(image)
    else:  # nx is a list: one slicing pass per sprite group
        sizex_ls = [sheet_rect.width // i_nx for i_nx in nx]
        for i in range(0, ny):
            for i_nx, sizex, i_scalex in zip(nx, sizex_ls, scalex):
                for j in range(0, i_nx):
                    rect = Rect((j * sizex, i * sizey, sizex, sizey))
                    image = Surface(rect.size)
                    image = image.convert()
                    image.blit(sheet, (0, 0), rect)
                    if colorkey is not None:
                        if colorkey == -1:
                            colorkey = image.get_at((0, 0))
                        image.set_colorkey(colorkey, RLEACCEL)
                    if i_scalex != -1 or scaley != -1:
                        image = transform.scale(image, (i_scalex, scaley))
                    sprites.append(image)
    sprite_rect = sprites[0].get_rect()
    return sprites, sprite_rect


def extractDigits(number):
    # Split a non-negative score into a list of at least five decimal digits,
    # most significant first, zero-padded on the left.
    if number > -1:
        digits = []
        while number / 10 != 0:
            digits.append(number % 10)
            number = int(number / 10)
        digits.append(number % 10)
        for i in range(len(digits), 5):
            digits.append(0)
        digits.reverse()
        return digits


def pre_processing(image, w=84, h=84):
    # Crop the scoreboard off (the frame array from array3d is width-major, so
    # [:300] keeps the left half of the 600-pixel-wide screen), resize to
    # 84x84, grayscale, and binarize.
    image = image[:300, :, :]
    image = cv2.cvtColor(cv2.resize(image, (w, h)), cv2.COLOR_BGR2GRAY)
    _, image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
    return image[None, :, :].astype(np.float32)


class Dino:
    def __init__(self, sizex=-1, sizey=-1):
        self.images, self.rect = load_sprite_sheet("dino.png", 5, 1, sizex, sizey, -1)
        self.images1, self.rect1 = load_sprite_sheet("dino_ducking.png", 2, 1, 59, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width / 15
        self.image = self.images[0]
        self.index = 0
        self.counter = 0
        self.score = 0
        self.isJumping = False
        self.isDead = False
        self.isDucking = False
        self.isBlinking = False
        self.movement = [0, 0]
        self.jumpSpeed = 11.5
        self.stand_pos_width = self.rect.width
        self.duck_pos_width = self.rect1.width

    def draw(self):
        screen.blit(self.image, self.rect)

    def checkbounds(self):
        # Clamp the dino to the ground and end the jump on landing.
        if self.rect.bottom > int(0.98 * height):
            self.rect.bottom = int(0.98 * height)
            self.isJumping = False

    def update(self):
        if self.isJumping:
            self.movement[1] = self.movement[1] + gravity  # apply gravity
        if self.isJumping:
            self.index = 0
        elif self.isBlinking:
            if self.index == 0:
                if self.counter % 400 == 399:
                    self.index = (self.index + 1) % 2
            else:
                if self.counter % 20 == 19:
                    self.index = (self.index + 1) % 2
        elif self.isDucking:
            if self.counter % 5 == 0:
                self.index = (self.index + 1) % 2
        else:
            if self.counter % 5 == 0:
                self.index = (self.index + 1) % 2 + 2  # running animation frames
        if self.isDead:
            self.index = 4
        if not self.isDucking:
            self.image = self.images[self.index]
            self.rect.width = self.stand_pos_width
        else:
            self.image = self.images1[self.index % 2]
            self.rect.width = self.duck_pos_width
        self.rect = self.rect.move(self.movement)
        self.checkbounds()
        if not self.isDead and self.counter % 7 == 6 and not self.isBlinking:
            self.score += 1
        self.counter = self.counter + 1


class Cactus(Sprite):
    def __init__(self, speed=5, sizex=-1, sizey=-1):
        Sprite.__init__(self, self.containers)
        # The cactus sheet holds 2 + 3 + 6 = 11 variants, hence randrange(0, 11).
        self.images, self.rect = load_sprite_sheet("cacti-small.png", [2, 3, 6], 1, sizex, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width + self.rect.width
        self.image = self.images[randrange(0, 11)]
        self.movement = [-1 * speed, 0]

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self):
        self.rect = self.rect.move(self.movement)
        if self.rect.right < 0:
            self.kill()


class Ptera(Sprite):
    def __init__(self, speed=5, sizex=-1, sizey=-1):
        Sprite.__init__(self, self.containers)
        self.images, self.rect = load_sprite_sheet("ptera.png", 2, 1, sizex, sizey, -1)
        # The pterodactyl spawns at one of four altitudes.
        self.ptera_height = [height * 0.82, height * 0.75, height * 0.60, height * 0.48]
        self.rect.centery = self.ptera_height[randrange(0, 4)]
        self.rect.left = width + self.rect.width
        self.image = self.images[0]
        self.movement = [-1 * speed, 0]
        self.index = 0
        self.counter = 0

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self):
        if self.counter % 10 == 0:
            self.index = (self.index + 1) % 2  # flap wings
        self.image = self.images[self.index]
        self.rect = self.rect.move(self.movement)
        self.counter = self.counter + 1
        if self.rect.right < 0:
            self.kill()


class Ground:
    def __init__(self, speed=-5):
        # Two copies of the ground image leapfrog each other to scroll forever.
        self.image, self.rect = load_image("ground.png", -1, -1, -1)
        self.image1, self.rect1 = load_image("ground.png", -1, -1, -1)
        self.rect.bottom = height
        self.rect1.bottom = height
        self.rect1.left = self.rect.right
        self.speed = speed

    def draw(self):
        screen.blit(self.image, self.rect)
        screen.blit(self.image1, self.rect1)

    def update(self):
        self.rect.left += self.speed
        self.rect1.left += self.speed
        if self.rect.right < 0:
            self.rect.left = self.rect1.right
        if self.rect1.right < 0:
            self.rect1.left = self.rect.right


class Cloud(Sprite):
    def __init__(self, x, y):
        Sprite.__init__(self, self.containers)
        self.image, self.rect = load_image("cloud.png", int(90 * 30 / 42), 30, -1)
        self.speed = 1
        self.rect.left = x
        self.rect.top = y
        self.movement = [-1 * self.speed, 0]

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self):
        self.rect = self.rect.move(self.movement)
        if self.rect.right < 0:
            self.kill()


class Scoreboard:
    def __init__(self, x=-1, y=-1):
        self.score = 0
        self.tempimages, self.temprect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.image = Surface((55, int(11 * 6 / 5)))
        self.rect = self.image.get_rect()
        if x == -1:
            self.rect.left = width * 0.89
        else:
            self.rect.left = x
        if y == -1:
            self.rect.top = height * 0.1
        else:
            self.rect.top = y

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self, score):
        score_digits = extractDigits(score)
        self.image.fill(background_col)
        if len(score_digits) == 6:
            score_digits = score_digits[1:]  # show at most five digits
        for s in score_digits:
            self.image.blit(self.tempimages[s], self.temprect)
            self.temprect.left += self.temprect.width
        self.temprect.left = 0


class ChromeDino:
    def __init__(self):
        self.gamespeed = 5
        self.gameOver = False
        self.gameQuit = False
        self.playerDino = Dino(44, 47)
        self.new_ground = Ground(-1 * self.gamespeed)
        self.scb = Scoreboard()
        self.highsc = Scoreboard(width * 0.78)
        self.counter = 0
        self.cacti = Group()
        self.pteras = Group()
        self.clouds = Group()
        self.last_obstacle = Group()
        Cactus.containers = self.cacti
        Ptera.containers = self.pteras
        Cloud.containers = self.clouds
        self.retbutton_image, self.retbutton_rect = load_image("replay_button.png", 35, 31, -1)
        self.gameover_image, self.gameover_rect = load_image("game_over.png", 190, 11, -1)
        self.temp_images, self.temp_rect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        # Assemble the "HI" label for the high-score display.
        self.HI_image = Surface((22, int(11 * 6 / 5)))
        self.HI_rect = self.HI_image.get_rect()
        self.HI_image.fill(background_col)
        self.HI_image.blit(self.temp_images[10], self.temp_rect)
        self.temp_rect.left += self.temp_rect.width
        self.HI_image.blit(self.temp_images[11], self.temp_rect)
        self.HI_rect.top = height * 0.1
        self.HI_rect.left = width * 0.73

    def step(self, action, record=False):
        # Actions: 0 = do nothing, 1 = jump, 2 = duck.
        # Reward shaping: +0.1 per surviving frame, an extra +0.01 for doing
        # nothing (to discourage pointless jumping), +1 for clearing an
        # obstacle, and -1 on a collision.
        reward = 0.1
        if action == 0:
            reward += 0.01
            self.playerDino.isDucking = False
        elif action == 1:
            self.playerDino.isDucking = False
            if self.playerDino.rect.bottom == int(0.98 * height):
                self.playerDino.isJumping = True
                self.playerDino.movement[1] = -1 * self.playerDino.jumpSpeed
        elif action == 2:
            if not (self.playerDino.isJumping and self.playerDino.isDead) and self.playerDino.rect.bottom == int(0.98 * height):
                self.playerDino.isDucking = True
        for c in self.cacti:
            c.movement[0] = -1 * self.gamespeed
            if collide_mask(self.playerDino, c):
                self.playerDino.isDead = True
                reward = -1
                break
            else:
                # The dino just passed this cactus within the last frame.
                if c.rect.right < self.playerDino.rect.left < c.rect.right + self.gamespeed + 1:
                    reward = 1
                    break
        for p in self.pteras:
            p.movement[0] = -1 * self.gamespeed
            if collide_mask(self.playerDino, p):
                self.playerDino.isDead = True
                reward = -1
                break
            else:
                if p.rect.right < self.playerDino.rect.left < p.rect.right + self.gamespeed + 1:
                    reward = 1
                    break
        # Randomly spawn new obstacles and clouds.
        if len(self.cacti) < 2:
            if len(self.cacti) == 0 and len(self.pteras) == 0:
                self.last_obstacle.empty()
                self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))
            else:
                for l in self.last_obstacle:
                    if l.rect.right < width * 0.7 and randrange(0, 50) == 10:
                        self.last_obstacle.empty()
                        self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))
        if len(self.pteras) == 0 and len(self.cacti) < 2 and randrange(0, 50) == 10 and self.counter > 500:
            for l in self.last_obstacle:
                if l.rect.right < width * 0.8:
                    self.last_obstacle.empty()
                    self.last_obstacle.add(Ptera(self.gamespeed, 46, 40))
        if len(self.clouds) < 5 and randrange(0, 300) == 10:
            Cloud(width, randrange(height // 5, height // 2))  # randrange needs int bounds
        self.playerDino.update()
        self.cacti.update()
        self.pteras.update()
        self.clouds.update()
        self.new_ground.update()
        self.scb.update(self.playerDino.score)
        # Grab the screen surface; array3d below reads it after this frame is drawn.
        state = display.get_surface()
        screen.fill(background_col)
        self.new_ground.draw()
        self.clouds.draw(screen)
        self.scb.draw()
        self.cacti.draw(screen)
        self.pteras.draw(screen)
        self.playerDino.draw()
        display.update()
        clock.tick(FPS)
        if self.playerDino.isDead:
            self.gameOver = True
        self.counter = self.counter + 1
        if self.gameOver:
            self.__init__()  # auto-reset so training can run continuously
        state = array3d(state)
        if record:
            # Also return the raw frame (converted for OpenCV) for video recording.
            return torch.from_numpy(pre_processing(state)), np.transpose(
                cv2.cvtColor(state, cv2.COLOR_RGB2BGR), (1, 0, 2)), reward, not (reward > 0)
        else:
            # done is True only when the dino has just died (reward == -1).
            return torch.from_numpy(pre_processing(state)), reward, not (reward > 0)
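Before wiring the environment into training, it is worth confirming that the step() interface behaves as expected. The following smoke test is not part of the original code; it is a minimal sketch assuming the file above is saved as src/env.py (the path the later scripts import from) and that the repo's assets/sprites directory is present.

from random import randint

from src.env import ChromeDino

env = ChromeDino()
state, reward, done = env.step(0)
print(state.shape, reward, done)  # torch.Size([1, 84, 84]), reward ~0.11, done False
for _ in range(100):
    state, reward, done = env.step(randint(0, 2))  # 0: run, 1: jump, 2: duck
    if done:  # the dino collided; step() has already reset the environment
        break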
# src/model.py: the deep Q-network.
import torch.nn as nn


class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()
        # Input: a stack of the 4 most recent 84x84 frames, as 4 channels.
        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))
        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, 3)  # one Q-value per action: nothing / jump / duck
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.uniform_(m.weight, -0.01, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)  # flatten to (batch, 7 * 7 * 64)
        output = self.fc1(output)
        output = self.fc2(output)
        return output
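The fc1 input size of 7 * 7 * 64 comes from the convolution arithmetic on 84x84 inputs: (84 - 8)/4 + 1 = 20 after conv1, (20 - 4)/2 + 1 = 9 after conv2, and 9 - 3 + 1 = 7 after conv3. A quick sketch to verify the shapes with a dummy batch (assuming the file above is importable as src.model):

import torch

from src.model import DeepQNetwork

model = DeepQNetwork()
dummy = torch.zeros(2, 4, 84, 84)  # a batch of two 4-frame stacks
q_values = model(dummy)
print(q_values.shape)  # torch.Size([2, 3]): one Q-value per action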
# test.py (file name assumed): play with a trained model and record a video.
import argparse

import cv2
import torch

from src.model import DeepQNetwork
from src.env import ChromeDino


def get_args():
    parser = argparse.ArgumentParser("""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--fps", type=int, default=60, help="frames per second")
    parser.add_argument("--output", type=str, default="output/chrome_dino.mp4", help="the path to output video")
    args = parser.parse_args()
    return args


def q_test(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    model = DeepQNetwork()
    checkpoint_path = "{}/chrome_dino.pth".format(opt.saved_path)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    env = ChromeDino()
    state, raw_state, _, _ = env.step(0, True)
    # Bootstrap the 4-frame stack by repeating the first frame.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :]
    if torch.cuda.is_available():
        model.cuda()
        state = state.cuda()
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps, (600, 150))
    done = False
    while not done:
        prediction = model(state)[0]
        action = torch.argmax(prediction).item()  # greedy policy, no exploration
        next_state, raw_next_state, reward, done = env.step(action, True)
        out.write(raw_next_state)
        if torch.cuda.is_available():
            next_state = next_state.cuda()
        next_state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
        state = next_state


if __name__ == "__main__":
    opt = get_args()
    q_test(opt)
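One detail worth calling out in the loop above: torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :] maintains a sliding window of the four most recent frames as the network input, dropping the oldest channel and appending the newest. A self-contained illustration with dummy frames:

import torch

# Tag each fake 1x84x84 "frame" with a constant so the window is easy to read.
frames = [torch.full((1, 84, 84), float(i)) for i in range(5)]
state = torch.cat(tuple(frames[0] for _ in range(4)))[None, :, :, :]  # (1, 4, 84, 84)
for f in frames[1:]:
    state = torch.cat((state[0, 1:, :, :], f))[None, :, :, :]  # drop oldest, append newest
print([int(state[0, c, 0, 0]) for c in range(4)])  # [1, 2, 3, 4]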
# train.py (file name assumed): train the deep Q-network with experience replay.
import argparse
import os
from random import random, randint, sample
import pickle

import numpy as np
import torch
import torch.nn as nn

from src.model import DeepQNetwork
from src.env import ChromeDino


def get_args():
    parser = argparse.ArgumentParser("""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--batch_size", type=int, default=64, help="The number of images per batch")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=0.1)
    parser.add_argument("--final_epsilon", type=float, default=1e-4)
    parser.add_argument("--num_decay_iters", type=float, default=2000000)
    parser.add_argument("--num_iters", type=int, default=2000000)
    parser.add_argument("--replay_memory_size", type=int, default=50000,
                        help="Maximum number of transitions kept in the replay memory")
    parser.add_argument("--saved_folder", type=str, default="trained_models")
    args = parser.parse_args()
    return args


def train(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    model = DeepQNetwork()
    if torch.cuda.is_available():
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    if not os.path.isdir(opt.saved_folder):
        os.makedirs(opt.saved_folder)
    checkpoint_path = os.path.join(opt.saved_folder, "chrome_dino.pth")
    memory_path = os.path.join(opt.saved_folder, "replay_memory.pkl")
    # Resume from a previous checkpoint and replay memory if they exist.
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        iter = checkpoint["iter"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Load trained model from iteration {}".format(iter))
    else:
        iter = 0
    if os.path.isfile(memory_path):
        with open(memory_path, "rb") as f:
            replay_memory = pickle.load(f)
        print("Load replay memory")
    else:
        replay_memory = []
    criterion = nn.MSELoss()
    env = ChromeDino()
    state, _, _ = env.step(0)
    # Bootstrap the 4-frame stack by repeating the first frame.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :]
    while iter < opt.num_iters:
        if torch.cuda.is_available():
            prediction = model(state.cuda())[0]
        else:
            prediction = model(state)[0]
        # Exploration or exploitation: epsilon decays linearly over num_decay_iters.
        epsilon = opt.final_epsilon + (
            max(opt.num_decay_iters - iter, 0) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_decay_iters)
        u = random()
        random_action = u <= epsilon
        if random_action:
            action = randint(0, 2)
        else:
            action = torch.argmax(prediction).item()
        next_state, reward, done = env.step(action)
        next_state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, done])
        if len(replay_memory) > opt.replay_memory_size:
            del replay_memory[0]  # drop the oldest transition
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)
        state_batch = torch.cat(tuple(state for state in state_batch))
        # One-hot encode the actions so Q(s, a) can be selected by a dot product.
        action_batch = torch.from_numpy(
            np.array([[1, 0, 0] if action == 0 else [0, 1, 0] if action == 1 else [0, 0, 1]
                      for action in action_batch], dtype=np.float32))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.cat(tuple(state for state in next_state_batch))
        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            action_batch = action_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)
        # Bellman target: r for terminal transitions, r + gamma * max_a' Q(s', a') otherwise.
        y_batch = torch.cat(
            tuple(reward if done else reward + opt.gamma * torch.max(prediction)
                  for reward, done, prediction in zip(reward_batch, done_batch, next_prediction_batch)))
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        optimizer.zero_grad()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()
        state = next_state
        iter += 1
        print("Iteration: {}/{}, Loss: {:.5f}, Epsilon: {:.5f}, Reward: {}".format(
            iter, opt.num_iters, loss, epsilon, reward))
        if (iter + 1) % 50000 == 0:
            checkpoint = {"iter": iter,
                          "model_state_dict": model.state_dict(),
                          "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            with open(memory_path, "wb") as f:
                pickle.dump(replay_memory, f, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    opt = get_args()
    train(opt)
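The exploration rate in train() is annealed linearly from initial_epsilon down to final_epsilon over num_decay_iters steps and then held constant. A small sketch of the schedule with the default arguments:

initial_epsilon, final_epsilon, num_decay_iters = 0.1, 1e-4, 2000000

def epsilon_at(it):
    # Identical to the expression in train(): linear decay, clamped after num_decay_iters.
    return final_epsilon + max(num_decay_iters - it, 0) * (initial_epsilon - final_epsilon) / num_decay_iters

for it in (0, 1000000, 2000000, 3000000):
    print(it, epsilon_at(it))
# 0 -> 0.1, 1000000 -> ~0.05, 2000000 and beyond -> 0.0001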