diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7c1bdbd --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.venv/ +__pycache__/ +.vscode/ +logs/ +src/snake_ai_pytorch.egg-info/ +.pytest_cache/ +model/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8cec85c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ +repos: + - repo: https://github.com/asottile/pyupgrade + rev: v3.20.0 + hooks: + - id: pyupgrade + + - repo: https://github.com/MarcoGorelli/absolufy-imports + rev: v0.3.1 + hooks: + - id: absolufy-imports + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + exclude: ^src/pattern_identifier\.egg-info/ + - id: end-of-file-fixer + exclude: ^src/pattern_identifier\.egg-info/ + - id: check-yaml + + - repo: https://github.com/PyCQA/bandit + rev: "1.8.3" + hooks: + - id: bandit + args: ["--exclude", ".tox,.eggs,tests"] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.11.13" # Use the latest version + hooks: + - id: ruff + args: ["--fix", "--exit-non-zero-on-fix"] + - id: ruff-format + # args: ["--check"] # Use --check in CI if you only want to verify + + + # - repo: local + # hooks: + # - id: pytest-check + # name: pytest-check + # entry: pytest + # language: system + # pass_filenames: false + # always_run: true diff --git a/README.md b/README.md index 3e358a1..1ca714a 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,36 @@ You can find all tutorials on my channel: [Playlist](https://www.youtube.com/pla - Part 2: Learn how to setup the environment and implement the Snake game. - Part 3: Implement the agent that controls the game. - Part 4: Implement the neural network to predict the moves and train it. + +## How to install + +```bash +# Create virtual environment +python -m venv .venv + +# Activate virtual environment +source ./.venv/bin/activate + +# Install dependencies +pip install -e . +``` + +## Evaluating the Neural Network + +### 1. Running the training + +```bash +python src/snake_ai_pytorch/controllers/train.py +``` + +### 2. 
Seeing the trained agent play + +```bash +python src/snake_ai_pytorch/controllers/enjoy.py +``` + +> ## To play yourself + +```bash +python src/snake_ai_pytorch/controllers/play.py +``` diff --git a/agent.py b/agent.py deleted file mode 100755 index e4bd89e..0000000 --- a/agent.py +++ /dev/null @@ -1,147 +0,0 @@ -import torch -import random -import numpy as np -from collections import deque -from game import SnakeGameAI, Direction, Point -from model import Linear_QNet, QTrainer -from helper import plot - -MAX_MEMORY = 100_000 -BATCH_SIZE = 1000 -LR = 0.001 - -class Agent: - - def __init__(self): - self.n_games = 0 - self.epsilon = 0 # randomness - self.gamma = 0.9 # discount rate - self.memory = deque(maxlen=MAX_MEMORY) # popleft() - self.model = Linear_QNet(11, 256, 3) - self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma) - - - def get_state(self, game): - head = game.snake[0] - point_l = Point(head.x - 20, head.y) - point_r = Point(head.x + 20, head.y) - point_u = Point(head.x, head.y - 20) - point_d = Point(head.x, head.y + 20) - - dir_l = game.direction == Direction.LEFT - dir_r = game.direction == Direction.RIGHT - dir_u = game.direction == Direction.UP - dir_d = game.direction == Direction.DOWN - - state = [ - # Danger straight - (dir_r and game.is_collision(point_r)) or - (dir_l and game.is_collision(point_l)) or - (dir_u and game.is_collision(point_u)) or - (dir_d and game.is_collision(point_d)), - - # Danger right - (dir_u and game.is_collision(point_r)) or - (dir_d and game.is_collision(point_l)) or - (dir_l and game.is_collision(point_u)) or - (dir_r and game.is_collision(point_d)), - - # Danger left - (dir_d and game.is_collision(point_r)) or - (dir_u and game.is_collision(point_l)) or - (dir_r and game.is_collision(point_u)) or - (dir_l and game.is_collision(point_d)), - - # Move direction - dir_l, - dir_r, - dir_u, - dir_d, - - # Food location - game.food.x < game.head.x, # food left - game.food.x > game.head.x, # food right - game.food.y < game.head.y, # food up - game.food.y > game.head.y # food down - ] - - return np.array(state, dtype=int) - - def remember(self, state, action, reward, next_state, done): - self.memory.append((state, action, reward, next_state, done)) # popleft if MAX_MEMORY is reached - - def train_long_memory(self): - if len(self.memory) > BATCH_SIZE: - mini_sample = random.sample(self.memory, BATCH_SIZE) # list of tuples - else: - mini_sample = self.memory - - states, actions, rewards, next_states, dones = zip(*mini_sample) - self.trainer.train_step(states, actions, rewards, next_states, dones) - #for state, action, reward, nexrt_state, done in mini_sample: - # self.trainer.train_step(state, action, reward, next_state, done) - - def train_short_memory(self, state, action, reward, next_state, done): - self.trainer.train_step(state, action, reward, next_state, done) - - def get_action(self, state): - # random moves: tradeoff exploration / exploitation - self.epsilon = 80 - self.n_games - final_move = [0,0,0] - if random.randint(0, 200) < self.epsilon: - move = random.randint(0, 2) - final_move[move] = 1 - else: - state0 = torch.tensor(state, dtype=torch.float) - prediction = self.model(state0) - move = torch.argmax(prediction).item() - final_move[move] = 1 - - return final_move - - -def train(): - plot_scores = [] - plot_mean_scores = [] - total_score = 0 - record = 0 - agent = Agent() - game = SnakeGameAI() - while True: - # get old state - state_old = agent.get_state(game) - - # get move - final_move = agent.get_action(state_old) - - # perform 
move and get new state - reward, done, score = game.play_step(final_move) - state_new = agent.get_state(game) - - # train short memory - agent.train_short_memory(state_old, final_move, reward, state_new, done) - - # remember - agent.remember(state_old, final_move, reward, state_new, done) - - if done: - # train long memory, plot result - game.reset() - agent.n_games += 1 - agent.train_long_memory() - - if score > record: - record = score - agent.model.save() - - print('Game', agent.n_games, 'Score', score, 'Record:', record) - - plot_scores.append(score) - total_score += score - mean_score = total_score / agent.n_games - plot_mean_scores.append(mean_score) - plot(plot_scores, plot_mean_scores) - - -if __name__ == '__main__': - train() \ No newline at end of file diff --git a/game.py b/game.py deleted file mode 100755 index 7c34da7..0000000 --- a/game.py +++ /dev/null @@ -1,154 +0,0 @@ -import pygame -import random -from enum import Enum -from collections import namedtuple -import numpy as np - -pygame.init() -font = pygame.font.Font('arial.ttf', 25) -#font = pygame.font.SysFont('arial', 25) - -class Direction(Enum): - RIGHT = 1 - LEFT = 2 - UP = 3 - DOWN = 4 - -Point = namedtuple('Point', 'x, y') - -# rgb colors -WHITE = (255, 255, 255) -RED = (200,0,0) -BLUE1 = (0, 0, 255) -BLUE2 = (0, 100, 255) -BLACK = (0,0,0) - -BLOCK_SIZE = 20 -SPEED = 40 - -class SnakeGameAI: - - def __init__(self, w=640, h=480): - self.w = w - self.h = h - # init display - self.display = pygame.display.set_mode((self.w, self.h)) - pygame.display.set_caption('Snake') - self.clock = pygame.time.Clock() - self.reset() - - - def reset(self): - # init game state - self.direction = Direction.RIGHT - - self.head = Point(self.w/2, self.h/2) - self.snake = [self.head, - Point(self.head.x-BLOCK_SIZE, self.head.y), - Point(self.head.x-(2*BLOCK_SIZE), self.head.y)] - - self.score = 0 - self.food = None - self._place_food() - self.frame_iteration = 0 - - - def _place_food(self): - x = random.randint(0, (self.w-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE - y = random.randint(0, (self.h-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE - self.food = Point(x, y) - if self.food in self.snake: - self._place_food() - - - def play_step(self, action): - self.frame_iteration += 1 - # 1. collect user input - for event in pygame.event.get(): - if event.type == pygame.QUIT: - pygame.quit() - quit() - - # 2. move - self._move(action) # update the head - self.snake.insert(0, self.head) - - # 3. check if game over - reward = 0 - game_over = False - if self.is_collision() or self.frame_iteration > 100*len(self.snake): - game_over = True - reward = -10 - return reward, game_over, self.score - - # 4. place new food or just move - if self.head == self.food: - self.score += 1 - reward = 10 - self._place_food() - else: - self.snake.pop() - - # 5. update ui and clock - self._update_ui() - self.clock.tick(SPEED) - # 6. 
return game over and score - return reward, game_over, self.score - - - def is_collision(self, pt=None): - if pt is None: - pt = self.head - # hits boundary - if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0: - return True - # hits itself - if pt in self.snake[1:]: - return True - - return False - - - def _update_ui(self): - self.display.fill(BLACK) - - for pt in self.snake: - pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE)) - pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x+4, pt.y+4, 12, 12)) - - pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE)) - - text = font.render("Score: " + str(self.score), True, WHITE) - self.display.blit(text, [0, 0]) - pygame.display.flip() - - - def _move(self, action): - # [straight, right, left] - - clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP] - idx = clock_wise.index(self.direction) - - if np.array_equal(action, [1, 0, 0]): - new_dir = clock_wise[idx] # no change - elif np.array_equal(action, [0, 1, 0]): - next_idx = (idx + 1) % 4 - new_dir = clock_wise[next_idx] # right turn r -> d -> l -> u - else: # [0, 0, 1] - next_idx = (idx - 1) % 4 - new_dir = clock_wise[next_idx] # left turn r -> u -> l -> d - - self.direction = new_dir - - x = self.head.x - y = self.head.y - if self.direction == Direction.RIGHT: - x += BLOCK_SIZE - elif self.direction == Direction.LEFT: - x -= BLOCK_SIZE - elif self.direction == Direction.DOWN: - y += BLOCK_SIZE - elif self.direction == Direction.UP: - y -= BLOCK_SIZE - - self.head = Point(x, y) \ No newline at end of file diff --git a/helper.py b/helper.py deleted file mode 100755 index 3a0a979..0000000 --- a/helper.py +++ /dev/null @@ -1,19 +0,0 @@ -import matplotlib.pyplot as plt -from IPython import display - -plt.ion() - -def plot(scores, mean_scores): - display.clear_output(wait=True) - display.display(plt.gcf()) - plt.clf() - plt.title('Training...') - plt.xlabel('Number of Games') - plt.ylabel('Score') - plt.plot(scores) - plt.plot(mean_scores) - plt.ylim(ymin=0) - plt.text(len(scores)-1, scores[-1], str(scores[-1])) - plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1])) - plt.show(block=False) - plt.pause(.1) diff --git a/model.py b/model.py deleted file mode 100755 index 94b7001..0000000 --- a/model.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -import os - -class Linear_QNet(nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super().__init__() - self.linear1 = nn.Linear(input_size, hidden_size) - self.linear2 = nn.Linear(hidden_size, output_size) - - def forward(self, x): - x = F.relu(self.linear1(x)) - x = self.linear2(x) - return x - - def save(self, file_name='model.pth'): - model_folder_path = './model' - if not os.path.exists(model_folder_path): - os.makedirs(model_folder_path) - - file_name = os.path.join(model_folder_path, file_name) - torch.save(self.state_dict(), file_name) - - -class QTrainer: - def __init__(self, model, lr, gamma): - self.lr = lr - self.gamma = gamma - self.model = model - self.optimizer = optim.Adam(model.parameters(), lr=self.lr) - self.criterion = nn.MSELoss() - - def train_step(self, state, action, reward, next_state, done): - state = torch.tensor(state, dtype=torch.float) - next_state = torch.tensor(next_state, dtype=torch.float) - action = torch.tensor(action, dtype=torch.long) - reward = torch.tensor(reward, 
dtype=torch.float) - # (n, x) - - if len(state.shape) == 1: - # (1, x) - state = torch.unsqueeze(state, 0) - next_state = torch.unsqueeze(next_state, 0) - action = torch.unsqueeze(action, 0) - reward = torch.unsqueeze(reward, 0) - done = (done, ) - - # 1: predicted Q values with current state - pred = self.model(state) - - target = pred.clone() - for idx in range(len(done)): - Q_new = reward[idx] - if not done[idx]: - Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx])) - - target[idx][torch.argmax(action[idx]).item()] = Q_new - - # 2: Q_new = r + y * max(next_predicted Q value) -> only do this if not done - # pred.clone() - # preds[argmax(action)] = Q_new - self.optimizer.zero_grad() - loss = self.criterion(target, pred) - loss.backward() - - self.optimizer.step() - - - diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a093e32 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,74 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "snake-ai-pytorch" +version = "1.0" +dependencies = [ + "pygame", + "gymnasium", + "rl-zoo3", + "IPython", + "torch==2.5.1", # For testing with RL algorithms ONLY_OLDER_CUDA +] + +# # Remove all the lines marked with ONLY_OLDER_CUDA if you want to use the latest CUDA version +[tool.pip] # ONLY_OLDER_CUDA +extra-index-url = "https://download.pytorch.org/whl/cu124" # ONLY_OLDER_CUDA + +[project.optional-dependencies] +test = [ + "pytest" +] + +dev = [ + "snake-ai-pytorch[test]", # Includes all dependencies from the 'test' group, + "pre-commit", # For managing pre-commit hooks +] + +[tool.setuptools] +package-dir = {"" = "src"} + + +[tool.ruff] +line-length = 149 +target-version = "py311" # Matches your black config and python support + +[tool.ruff.lint] +# Enable Pyflakes (F), pycodestyle (E, W), isort (I) +# Enable many common Pylint (PL), flake8-bugbear (B), flake8-comprehensions (C4), etc. 
rules +select = [ + "F", "E", "W", "I", "N", "D", # Core flake8, isort, pep8-naming, pydocstyle + "UP", # pyupgrade + "B", # flake8-bugbear + "A", # flake8-builtins + "C4", # flake8-comprehensions + "T20", # flake8-print (T201 for print, T203 for pprint) + "SIM", # flake8-simplify + "PTH", # flake8-use-pathlib + # "PL", # Pylint + # "TRY", # tryceratops +] +ignore = [ + # Ignore common missing docstring errors for now + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D107", # Missing docstring in `__init__` + "D203", # 1 blank line required before class docstring + "D213", # Multi-line docstring closing quotes should be on a separate line + "D401", # First line of docstring should be in imperative mood + "N803", # Function name should be lowercase + "N806", # Variable in function should be lowercase +] +# mccabe (cyclomatic complexity) +# mccabe.max-complexity = 10 # Default is 10, adjust as needed + +[tool.ruff.format] +quote-style = "double" # Black default +indent-style = "space" +skip-magic-trailing-comma = false # Black default +line-ending = "auto" diff --git a/snake_game_human.py b/snake_game_human.py deleted file mode 100644 index f51e9c0..0000000 --- a/snake_game_human.py +++ /dev/null @@ -1,147 +0,0 @@ -import pygame -import random -from enum import Enum -from collections import namedtuple - -pygame.init() -font = pygame.font.Font('arial.ttf', 25) -#font = pygame.font.SysFont('arial', 25) - -class Direction(Enum): - RIGHT = 1 - LEFT = 2 - UP = 3 - DOWN = 4 - -Point = namedtuple('Point', 'x, y') - -# rgb colors -WHITE = (255, 255, 255) -RED = (200,0,0) -BLUE1 = (0, 0, 255) -BLUE2 = (0, 100, 255) -BLACK = (0,0,0) - -BLOCK_SIZE = 20 -SPEED = 20 - -class SnakeGame: - - def __init__(self, w=640, h=480): - self.w = w - self.h = h - # init display - self.display = pygame.display.set_mode((self.w, self.h)) - pygame.display.set_caption('Snake') - self.clock = pygame.time.Clock() - - # init game state - self.direction = Direction.RIGHT - - self.head = Point(self.w/2, self.h/2) - self.snake = [self.head, - Point(self.head.x-BLOCK_SIZE, self.head.y), - Point(self.head.x-(2*BLOCK_SIZE), self.head.y)] - - self.score = 0 - self.food = None - self._place_food() - - def _place_food(self): - x = random.randint(0, (self.w-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE - y = random.randint(0, (self.h-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE - self.food = Point(x, y) - if self.food in self.snake: - self._place_food() - - def play_step(self): - # 1. collect user input - for event in pygame.event.get(): - if event.type == pygame.QUIT: - pygame.quit() - quit() - if event.type == pygame.KEYDOWN: - if event.key == pygame.K_LEFT: - self.direction = Direction.LEFT - elif event.key == pygame.K_RIGHT: - self.direction = Direction.RIGHT - elif event.key == pygame.K_UP: - self.direction = Direction.UP - elif event.key == pygame.K_DOWN: - self.direction = Direction.DOWN - - # 2. move - self._move(self.direction) # update the head - self.snake.insert(0, self.head) - - # 3. check if game over - game_over = False - if self._is_collision(): - game_over = True - return game_over, self.score - - # 4. place new food or just move - if self.head == self.food: - self.score += 1 - self._place_food() - else: - self.snake.pop() - - # 5. update ui and clock - self._update_ui() - self.clock.tick(SPEED) - # 6. 
return game over and score - return game_over, self.score - - def _is_collision(self): - # hits boundary - if self.head.x > self.w - BLOCK_SIZE or self.head.x < 0 or self.head.y > self.h - BLOCK_SIZE or self.head.y < 0: - return True - # hits itself - if self.head in self.snake[1:]: - return True - - return False - - def _update_ui(self): - self.display.fill(BLACK) - - for pt in self.snake: - pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE)) - pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x+4, pt.y+4, 12, 12)) - - pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE)) - - text = font.render("Score: " + str(self.score), True, WHITE) - self.display.blit(text, [0, 0]) - pygame.display.flip() - - def _move(self, direction): - x = self.head.x - y = self.head.y - if direction == Direction.RIGHT: - x += BLOCK_SIZE - elif direction == Direction.LEFT: - x -= BLOCK_SIZE - elif direction == Direction.DOWN: - y += BLOCK_SIZE - elif direction == Direction.UP: - y -= BLOCK_SIZE - - self.head = Point(x, y) - - -if __name__ == '__main__': - game = SnakeGame() - - # game loop - while True: - game_over, score = game.play_step() - - if game_over == True: - break - - print('Final Score', score) - - - pygame.quit() \ No newline at end of file diff --git a/arial.ttf b/src/snake_ai_pytorch/assets/arial.ttf similarity index 100% rename from arial.ttf rename to src/snake_ai_pytorch/assets/arial.ttf diff --git a/src/snake_ai_pytorch/controllers/enjoy.py b/src/snake_ai_pytorch/controllers/enjoy.py new file mode 100644 index 0000000..7b03cae --- /dev/null +++ b/src/snake_ai_pytorch/controllers/enjoy.py @@ -0,0 +1,5 @@ +from snake_ai_pytorch.models import Agent + +if __name__ == "__main__": + agent = Agent(render_mode="human") + agent.play() diff --git a/src/snake_ai_pytorch/controllers/play.py b/src/snake_ai_pytorch/controllers/play.py new file mode 100644 index 0000000..1e8f3fa --- /dev/null +++ b/src/snake_ai_pytorch/controllers/play.py @@ -0,0 +1,5 @@ +from snake_ai_pytorch.models import SnakeGameHuman + +if __name__ == "__main__": + env = SnakeGameHuman() + env.run() diff --git a/src/snake_ai_pytorch/controllers/train.py b/src/snake_ai_pytorch/controllers/train.py new file mode 100644 index 0000000..7f7a5ca --- /dev/null +++ b/src/snake_ai_pytorch/controllers/train.py @@ -0,0 +1,7 @@ +from snake_ai_pytorch.models import Agent + +if __name__ == "__main__": + # agent = Agent() + # agent = Agent(render_mode="fast_training") + agent = Agent(render_mode="human") + agent.train() diff --git a/src/snake_ai_pytorch/models/__init__.py b/src/snake_ai_pytorch/models/__init__.py new file mode 100644 index 0000000..0bd6e73 --- /dev/null +++ b/src/snake_ai_pytorch/models/__init__.py @@ -0,0 +1,2 @@ +from snake_ai_pytorch.models.agent import Agent # noqa: F401 +from snake_ai_pytorch.models.snake_game_human import SnakeGameHuman # noqa: F401 diff --git a/src/snake_ai_pytorch/models/agent.py b/src/snake_ai_pytorch/models/agent.py new file mode 100755 index 0000000..a1ac2d7 --- /dev/null +++ b/src/snake_ai_pytorch/models/agent.py @@ -0,0 +1,197 @@ +import atexit +import logging +import random +from collections import deque +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch import Tensor + +from snake_ai_pytorch.models.dueling_qnet import DuelingQNet +from snake_ai_pytorch.models.snake_env import SnakeEnv +from snake_ai_pytorch.views import Plotting + +# Configure logging 
to show info-level messages +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class Agent: + MAX_MEMORY = 100_000 + BATCH_SIZE = 1000 + LR = 0.001 + + def __init__(self, render_mode="human"): + self.n_games = 0 + # Epsilon is now calculated dynamically based on the number of games + # self.epsilon = 0 # randomness + self.gamma = 0.9 # discount rate + self.memory = deque(maxlen=self.MAX_MEMORY) # popleft() + self.render_mode = render_mode + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = DuelingQNet(input_size=11, hidden_size=256, output_size=3).to(self.device) + self.optimizer = optim.Adam(self.model.parameters(), lr=self.LR) + self.criterion = nn.MSELoss() + + def remember(self, state, action, reward, next_state, done): + self.memory.append((state, action, reward, next_state, done)) # popleft if MAX_MEMORY is reached + + def train_long_memory(self): + mini_sample = random.sample(self.memory, self.BATCH_SIZE) if len(self.memory) > self.BATCH_SIZE else self.memory + + states, actions, rewards, next_states, dones = zip(*mini_sample, strict=False) + self._train_step(states, actions, rewards, next_states, dones) + + def train_short_memory(self, state, action, reward, next_state, done): + self._train_step([state], [action], [reward], [next_state], [done]) + + def _train_step(self, states, actions, rewards, next_states, dones): + # 1. Convert to tensors and move to the correct device + states = torch.tensor(np.array(states), dtype=torch.float).to(self.device) + # Actions are now integers, convert to long tensor for indexing + actions = torch.tensor(actions, dtype=torch.long).to(self.device) + rewards = torch.tensor(np.array(rewards), dtype=torch.float).to(self.device) + next_states = torch.tensor(np.array(next_states), dtype=torch.float).to(self.device) + dones = torch.tensor(np.array(dones), dtype=torch.bool).to(self.device) + + # 2. Get predicted Q-values for current state + pred: Tensor = self.model(states) + + # 3. Get Q-values for next state and calculate max + target = pred.clone() + next_pred: Tensor = self.model(next_states) + next_q_values = next_pred.detach() + max_next_q = next_q_values.max(dim=1)[0] + + # 4. Calculate target Q-value (Bellman equation) + # For terminal states (done=True), the future reward is 0 + Q_new = rewards + (self.gamma * max_next_q * (~dones)) + + # 5. Create target tensor by cloning predictions and updating with new Q-values + target[torch.arange(len(dones)), actions] = Q_new + + # 6. Calculate loss + loss = self.criterion(target, pred) + + # 7. 
Manually perform the optimization + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + def get_action(self, state): + # random moves: tradeoff exploration / exploitation + epsilon = 80 - self.n_games + if random.randint(0, 200) < epsilon: # nosec + move = random.randint(0, 2) # nosec + else: + # Add a batch dimension for the model, which expects a 2D tensor + state0 = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self.device) + prediction: Tensor = self.model(state0) + move = torch.argmax(prediction).item() + + return move + + def load_existing_model(self): + # Create model directory if it doesn't exist + model_folder_path = Path("model") + model_folder_path.mkdir(parents=True, exist_ok=True) + model_path = model_folder_path / "model.pth" + + # Load existing model if it exists + if model_path.exists(): + logging.info("Loading existing state from model.pth...") + # The saved checkpoint is a dictionary, not just weights. + # `weights_only=False` (the default) is needed. + return torch.load(model_path, weights_only=True), model_path + return None, model_path + + def train(self): + self.play(train=True) + + def play(self, train=False): + total_score = 0 + record = 0 + # Only create a plot if we are in a visual render mode + plotting = None + if self.render_mode == "human": + plotting = Plotting(train) + plotting.start() + # Register a cleanup function to stop the plotting process on exit + atexit.register(plotting.stop) + # Use the new Gymnasium environment + env = SnakeEnv(render_mode=self.render_mode) + + checkpoint, model_path = self.load_existing_model() + + # Load existing model if it exists + if checkpoint: + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + self.n_games = checkpoint["n_games"] + record = checkpoint["record"] + total_score = checkpoint["total_score"] + # To ensure the agent's exploration continues from where it left off + logging.info(f"Resuming from game {self.n_games} with record {record}") + if plotting: + plot_scores = checkpoint.get("plot_scores", []) + plot_mean_scores = checkpoint.get("plot_mean_scores", []) + plotting.load_data(plot_scores, plot_mean_scores) + # If we are in 'play' mode (not training) and there's historical data, add the marker. 
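+ # (The marker is the red dashed "End of Training" vertical line drawn by Plotting.add_training_marker,
+ # so scores from this play session are visually separated from the earlier training games.)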
+ if not train and len(plot_scores) > 0: + plotting.add_training_marker(len(plot_scores) - 1) + + # Initial state from the environment + state_old, info = env.reset() + + while True: + # get move (action is now an integer: 0, 1, or 2) + action = self.get_action(state_old) + + # perform move and get new state from environment + state_new, reward, terminated, truncated, info = env.step(action) + score = info["score"] + env.render() + + if train: + # train short memory + self.train_short_memory(state_old, action, reward, state_new, terminated) + + # remember + self.remember(state_old, action, reward, state_new, terminated) + + # The current step is done, update the state for the next iteration + state_old = state_new + + if terminated: + # train long memory, plot result + self.n_games += 1 + if train: + self.train_long_memory() + + if score > record: + record = score + if train: + # Save model checkpoint + checkpoint = { + "n_games": self.n_games, + "record": record, + "total_score": total_score, + "plot_scores": plotting.scores if plotting else [], + "plot_mean_scores": plotting.mean_scores if plotting else [], + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + } + torch.save(checkpoint, model_path) + + logging.info(f"Game: {self.n_games}, Score: {score}, Record: {record}") + + total_score += score + mean_score = total_score / self.n_games + if plotting: + plotting.plot(score, mean_score) + + # Reset the environment and get the new initial state + state_old, info = env.reset() diff --git a/src/snake_ai_pytorch/models/direction.py b/src/snake_ai_pytorch/models/direction.py new file mode 100644 index 0000000..2bd19a5 --- /dev/null +++ b/src/snake_ai_pytorch/models/direction.py @@ -0,0 +1,20 @@ +from enum import Enum + + +class Direction(Enum): + LEFT = 0 + UP = 1 + RIGHT = 2 + DOWN = 3 + + @property + def opposite(self) -> "Direction": + return Direction((self.value + 2) % 4) + + @property + def cw(self) -> "Direction": + return Direction((self.value + 1) % 4) + + @property + def ccw(self) -> "Direction": + return Direction((self.value - 1) % 4) diff --git a/src/snake_ai_pytorch/models/dueling_qnet.py b/src/snake_ai_pytorch/models/dueling_qnet.py new file mode 100644 index 0000000..553e38b --- /dev/null +++ b/src/snake_ai_pytorch/models/dueling_qnet.py @@ -0,0 +1,29 @@ +import torch.nn as nn +import torch.nn.functional as functional +from torch import Tensor + + +class DuelingQNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + # Shared feature learning layer + self.linear1 = nn.Linear(input_size, hidden_size) + + # Value stream + self.value_stream = nn.Linear(hidden_size, 1) + + # Advantage stream + self.advantage_stream = nn.Linear(hidden_size, output_size) + + def forward(self, x): + # Pass input through the shared layer + x = functional.relu(self.linear1(x)) + + # Calculate value and advantage + values: Tensor = self.value_stream(x) + advantages: Tensor = self.advantage_stream(x) + + # Combine value and advantage streams to get Q-values + # Q(s, a) = V(s) + (A(s, a) - mean(A(s, a))) + q_values = values + (advantages - advantages.mean(dim=1, keepdim=True)) + return q_values diff --git a/src/snake_ai_pytorch/models/linear_qnet.py b/src/snake_ai_pytorch/models/linear_qnet.py new file mode 100644 index 0000000..e3319ae --- /dev/null +++ b/src/snake_ai_pytorch/models/linear_qnet.py @@ -0,0 +1,14 @@ +import torch.nn as nn +import torch.nn.functional as functional + + +class LinearQNet(nn.Module): 
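+ # Retained from the original model.py's two-layer network; the Agent now builds a DuelingQNet instead.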
+ def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.linear1 = nn.Linear(input_size, hidden_size) + self.linear2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + x = functional.relu(self.linear1(x)) + x = self.linear2(x) + return x diff --git a/src/snake_ai_pytorch/models/point.py b/src/snake_ai_pytorch/models/point.py new file mode 100644 index 0000000..bbac0c8 --- /dev/null +++ b/src/snake_ai_pytorch/models/point.py @@ -0,0 +1,19 @@ +from collections import namedtuple + +import numpy as np + + +class Point(namedtuple("Point", "x, y")): + def to_array(self): + return np.array(self) + + def __sub__(self, other: "Point") -> "Point": + """Subtract one from another.""" + return Point(self.x - other.x, self.y - other.y) + + def __add__(self, other: "Point") -> "Point": + """Add one to another.""" + return Point(self.x + other.x, self.y + self.y) + + def distance_to(self, other: "Point"): + return np.linalg.norm((self - other).to_array()) diff --git a/src/snake_ai_pytorch/models/renderer.py b/src/snake_ai_pytorch/models/renderer.py new file mode 100644 index 0000000..2725c2f --- /dev/null +++ b/src/snake_ai_pytorch/models/renderer.py @@ -0,0 +1,33 @@ +from typing import TYPE_CHECKING + +import pygame + +from snake_ai_pytorch.views.visual_configuration import BLOCK_SIZE, SPEED, FontConfig, GameColors + +if TYPE_CHECKING: + from snake_ai_pytorch.models.snake_game import SnakeGame + + +class Renderer: + def __init__(self, game: "SnakeGame"): + pygame.init() + self.game = game + self.font = pygame.font.Font(FontConfig.path, FontConfig.size) + # init display + self.display = pygame.display.set_mode((self.game.w, self.game.h)) + pygame.display.set_caption("Snake") + self.clock = pygame.time.Clock() + + def render(self, render_fps=SPEED): + self.display.fill(GameColors.BLACK) + + for pt in self.game.snake: + pygame.draw.rect(self.display, GameColors.BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE)) + pygame.draw.rect(self.display, GameColors.BLUE2, pygame.Rect(pt.x + 4, pt.y + 4, 12, 12)) + + pygame.draw.rect(self.display, GameColors.RED, pygame.Rect(self.game.food.x, self.game.food.y, BLOCK_SIZE, BLOCK_SIZE)) + + text = self.font.render("Score: " + str(self.game.score), True, GameColors.WHITE) + self.display.blit(text, [0, 0]) + pygame.display.flip() + self.clock.tick(render_fps) diff --git a/src/snake_ai_pytorch/models/snake_env.py b/src/snake_ai_pytorch/models/snake_env.py new file mode 100644 index 0000000..b9c7a28 --- /dev/null +++ b/src/snake_ai_pytorch/models/snake_env.py @@ -0,0 +1,92 @@ +import gymnasium as gym +import numpy as np +from gymnasium import spaces + +from snake_ai_pytorch.models.direction import Direction +from snake_ai_pytorch.models.point import Point +from snake_ai_pytorch.models.snake_game_ai import SnakeGameAI +from snake_ai_pytorch.views.visual_configuration import BLOCK_SIZE + + +class SnakeEnv(gym.Env): + """A custom Gymnasium environment for the Snake game.""" + + metadata = {"render_modes": ["human"], "render_fps": 400} + + def __init__(self, w=640, h=480, render_mode="human"): + super().__init__() + self.game = SnakeGameAI(w, h, render_mode=render_mode) + self.render_mode = render_mode + + # Action space: 0: straight, 1: right turn, 2: left turn + self.action_space = spaces.Discrete(3) + + # Observation space: 11 boolean values + self.observation_space = spaces.Box(low=0, high=1, shape=(11,), dtype=np.int32) + + def _get_obs(self): + """Generates the observation from the current game state. 
+ + This logic was previously in Agent.get_state(). + """ + head = self.game.snake[0] + point_l = Point(head.x - BLOCK_SIZE, head.y) + point_r = Point(head.x + BLOCK_SIZE, head.y) + point_u = Point(head.x, head.y - BLOCK_SIZE) + point_d = Point(head.x, head.y + BLOCK_SIZE) + + dir_l = self.game.direction == Direction.LEFT + dir_r = self.game.direction == Direction.RIGHT + dir_u = self.game.direction == Direction.UP + dir_d = self.game.direction == Direction.DOWN + + state = [ + # Danger straight + (dir_r and self.game.is_collision(point_r)) + or (dir_l and self.game.is_collision(point_l)) + or (dir_u and self.game.is_collision(point_u)) + or (dir_d and self.game.is_collision(point_d)), + # Danger right + (dir_u and self.game.is_collision(point_r)) + or (dir_d and self.game.is_collision(point_l)) + or (dir_l and self.game.is_collision(point_u)) + or (dir_r and self.game.is_collision(point_d)), + # Danger left + (dir_d and self.game.is_collision(point_r)) + or (dir_u and self.game.is_collision(point_l)) + or (dir_r and self.game.is_collision(point_u)) + or (dir_l and self.game.is_collision(point_d)), + # Move direction + dir_l, + dir_r, + dir_u, + dir_d, + # Food location + self.game.food.x < self.game.head.x, # food left + self.game.food.x > self.game.head.x, # food right + self.game.food.y < self.game.head.y, # food up + self.game.food.y > self.game.head.y, # food down + ] + return np.array(state, dtype=np.int32) + + def _get_info(self): + return {"score": self.game.score} + + def reset(self, seed=None, options=None): + super().reset(seed=seed) + self.game.reset() + observation = self._get_obs() + info = self._get_info() + return observation, info + + def step(self, action): + action_array = [0, 0, 0] + action_array[action] = 1 + + reward, terminated, score = self.game.play_step(action_array) + observation = self._get_obs() + info = self._get_info() + return observation, reward, terminated, False, info # truncated is always False + + def render(self): + self.game.render(self.metadata["render_fps"]) diff --git a/src/snake_ai_pytorch/models/snake_game.py b/src/snake_ai_pytorch/models/snake_game.py new file mode 100644 index 0000000..b9bb6c2 --- /dev/null +++ b/src/snake_ai_pytorch/models/snake_game.py @@ -0,0 +1,92 @@ +import random + +import pygame + +from snake_ai_pytorch.models.direction import Direction +from snake_ai_pytorch.models.point import Point +from snake_ai_pytorch.models.renderer import Renderer +from snake_ai_pytorch.views.visual_configuration import BLOCK_SIZE, SPEED + + +class SnakeGame: + renderer: Renderer = None + + def __init__(self, w=640, h=480, render_mode="human"): + self.w = w + self.h = h + if render_mode == "human": + self.renderer = Renderer(self) + self.reset() + + def reset(self): + # init game state + self.direction = Direction.RIGHT + + self.head = Point(self.w / 2, self.h / 2) + self.snake = [self.head, Point(self.head.x - BLOCK_SIZE, self.head.y), Point(self.head.x - (2 * BLOCK_SIZE), self.head.y)] + + self.score = 0 + self.food = None + self._place_food() + + def _place_food(self): + x = random.randint(0, (self.w - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE # nosec + y = random.randint(0, (self.h - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE # nosec + self.food = Point(x, y) + if self.food in self.snake: + self._place_food() + + def play_step(self, direction): + # 2. move + self._move(direction) # update the head + self.snake.insert(0, self.head) + + # 3. check if game over + game_over = False + if self.is_collision(): + game_over = True + else: + # 4. 
place new food or just move + if self.head == self.food: + self.score += 1 + self._place_food() + else: + self.snake.pop() + + # 6. return game over and score + return game_over, self.score + + def is_collision(self, pt=None): + if pt is None: + pt = self.head + # hits boundary + if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0: + return True + # hits itself + return pt in self.snake[1:] + + def _move(self, direction): + x = self.head.x + y = self.head.y + if direction == Direction.RIGHT: + x += BLOCK_SIZE + elif direction == Direction.LEFT: + x -= BLOCK_SIZE + elif direction == Direction.DOWN: + y += BLOCK_SIZE + elif direction == Direction.UP: + y -= BLOCK_SIZE + + self.head = Point(x, y) + + def render(self, render_fps=SPEED): + if self.renderer is None: + return + + # This is the standard way to handle rendering and events in a Pygame-based gym env + for event in pygame.event.get(): + if event.type == pygame.QUIT: + pygame.quit() + quit() + + self.renderer.render(render_fps) diff --git a/src/snake_ai_pytorch/models/snake_game_ai.py b/src/snake_ai_pytorch/models/snake_game_ai.py new file mode 100755 index 0000000..5d30e60 --- /dev/null +++ b/src/snake_ai_pytorch/models/snake_game_ai.py @@ -0,0 +1,66 @@ +import numpy as np + +from snake_ai_pytorch.models.snake_game import SnakeGame + + +class SnakeGameAI(SnakeGame): + """A subclass of SnakeGame that is designed for AI agents to interact with. + + It provides methods to get the current state of the game and to perform actions. + """ + + REWARD_FOOD = 10 + REWARD_GAMEOVER = -10 + REWARD_MOVE = -0.01 + REWARD_CHANGE_DIRECTION = -0.02 + REWARD_DIRECTION_KEPT = 0 + + def __init__(self, w=640, h=480, render_mode="human"): + super().__init__(w, h, render_mode) + + def reset(self): + """Reset the game state to the initial conditions.""" + super().reset() + self.frame_iteration = 0 + + def play_step(self, action): + self.frame_iteration += 1 + + # 1. determine new direction from action + reward = 0 + reward += self._determine_direction(action) + + # 2. move + self._move(self.direction) # update the head + self.snake.insert(0, self.head) + + # 3. check if game over + game_over = False + if self.is_collision() or self.frame_iteration > 100 * len(self.snake): + game_over = True + reward += self.REWARD_GAMEOVER + return reward, game_over, self.score + + # 4. place new food or just move + if self.head == self.food: + self.score += 1 + reward += self.REWARD_FOOD + self._place_food() + else: + reward += self.REWARD_MOVE + self.snake.pop() + + return reward, game_over, self.score + + def _determine_direction(self, action): + """Determines the new direction based on the AI's action. 
+ + Action is a 3-element list: [straight, right_turn, left_turn] + """ + if np.array_equal(action, [0, 1, 0]): # right turn + self.direction = self.direction.cw + return self.REWARD_CHANGE_DIRECTION + elif np.array_equal(action, [0, 0, 1]): # left turn + self.direction = self.direction.ccw + return self.REWARD_CHANGE_DIRECTION + return self.REWARD_DIRECTION_KEPT diff --git a/src/snake_ai_pytorch/models/snake_game_human.py b/src/snake_ai_pytorch/models/snake_game_human.py new file mode 100644 index 0000000..0f2ca3f --- /dev/null +++ b/src/snake_ai_pytorch/models/snake_game_human.py @@ -0,0 +1,60 @@ +import logging + +import pygame + +from snake_ai_pytorch.models.direction import Direction +from snake_ai_pytorch.models.snake_game import SnakeGame + + +class SnakeGameHuman: + """An environment to run the human-playable version of the Snake game. + + This class is responsible for the game loop, handling user input, + and rendering the game state. It separates the game logic from the + application's main loop. + """ + + def __init__(self, w=640, h=480): + self.game = SnakeGame(w, h, render_mode="human") + self.speed = 15 # Set a comfortable speed for human play + self.key_direction_map = { + pygame.K_LEFT: Direction.LEFT, + pygame.K_RIGHT: Direction.RIGHT, + pygame.K_UP: Direction.UP, + pygame.K_DOWN: Direction.DOWN, + } + self.allowed_keys = [pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN] + + def run(self): + """Starts and manages the main game loop.""" + running = True + while running: + # 1. Handle user input + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + if event.type == pygame.KEYDOWN: + # Prevent the snake from reversing on itself + key_direction = self.key_direction_map[event.key] + if event.key in self.allowed_keys and self.game.direction != key_direction.opposite: + self.game.direction = key_direction + + # 2. Advance the game state + game_over, score = self.game.play_step(self.game.direction) + + # 3. Render the game + self.game.render(self.speed) + + # 4. Check for game over + if game_over: + logging.info(f"Final Score: {score}") + # A short delay to see the final score before the window closes + pygame.time.wait(1500) + running = False + + pygame.quit() + + +if __name__ == "__main__": + env = SnakeGameHuman() + env.run() diff --git a/src/snake_ai_pytorch/views/__init__.py b/src/snake_ai_pytorch/views/__init__.py new file mode 100644 index 0000000..247c641 --- /dev/null +++ b/src/snake_ai_pytorch/views/__init__.py @@ -0,0 +1,2 @@ +from snake_ai_pytorch.views.progress_graph import Plotting # noqa: F401 +from snake_ai_pytorch.views.visual_configuration import BLOCK_SIZE, SPEED, FontConfig, GameColors # noqa: F401 diff --git a/src/snake_ai_pytorch/views/progress_graph.py b/src/snake_ai_pytorch/views/progress_graph.py new file mode 100755 index 0000000..39f4772 --- /dev/null +++ b/src/snake_ai_pytorch/views/progress_graph.py @@ -0,0 +1,110 @@ +import multiprocessing as mp + + +class Plotting: + def __init__(self, train): + # We import pyplot in the child process to avoid issues with forking on some OSes. 
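+ # Importing matplotlib only inside the worker keeps the GUI backend and its event loop out of the
+ # parent process, which also works with the "spawn" start method used by default on Windows/macOS.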
+ def plotter_process(queue, train): + """Runs in a separate process to handle plotting.""" + import matplotlib.pyplot as plt + + plt.ion() + fig, ax = plt.subplots() + ax.set_title(f"{'Training' if train else 'Playing'}...") + ax.set_xlabel("Number of Games") + ax.set_ylabel("Score") + (score_line,) = ax.plot([], [], label="Score") + (mean_score_line,) = ax.plot([], [], label="Mean Score") + score_text = ax.text(0, 0, "") + mean_score_text = ax.text(0, 0, "") + ax.legend(loc="upper left") + + scores = [] + mean_scores = [] + training_end_marker = None + vline = None + + def update_plot_data(): + nonlocal vline + + x_data = range(len(scores)) + score_line.set_data(x_data, scores) + mean_score_line.set_data(x_data, mean_scores) + + if scores: + last_game_idx = len(scores) - 1 + score_text.set_position((last_game_idx, scores[-1])) + score_text.set_text(str(scores[-1])) + mean_score_text.set_position((last_game_idx, mean_scores[-1])) + mean_score_text.set_text(f"{mean_scores[-1]:.2f}") + + if training_end_marker is not None and vline is None: + vline = ax.axvline(x=training_end_marker, color="r", linestyle="--", label="End of Training") + ax.legend(loc="upper left") + + ax.relim() + ax.autoscale_view() + ax.set_ylim(bottom=0, top=max(scores)) + fig.canvas.draw() + fig.canvas.flush_events() + + while True: + try: + if not queue.empty(): + data = queue.get() + if data is None: # Sentinel for stopping + break + + command, values = data + if command == "load": + scores, mean_scores = values + elif command == "plot": + score, mean_score = values + scores.append(score) + mean_scores.append(mean_score) + elif command == "add_marker": + training_end_marker = values + update_plot_data() + + plt.pause(0.1) + except (KeyboardInterrupt, BrokenPipeError): + break + except Exception: # Catches exceptions if the window is closed manually + break + + plt.ioff() + plt.close(fig) + + # Use a multiprocessing queue for safe data exchange + self.queue = mp.Queue() + self.process = mp.Process(target=plotter_process, args=(self.queue, train), daemon=True) + # The agent still needs to track scores for saving checkpoints + self.scores = [] + self.mean_scores = [] + + def start(self): + self.process.start() + + def stop(self): + """Send a signal to stop the plotting process.""" + if self.process.is_alive(): + self.queue.put(None) # Send sentinel + self.process.join(timeout=1) + if self.process.is_alive(): + self.process.terminate() # Forcefully stop if it doesn't close + + def load_data(self, scores, mean_scores): + """Load existing data and send it to the plotting process.""" + self.scores = scores + self.mean_scores = mean_scores + self.queue.put(("load", (scores, mean_scores))) + + def plot(self, score, mean_score): + """Append new data points and send them to the plotting process.""" + self.scores.append(score) + self.mean_scores.append(mean_score) + self.queue.put(("plot", (score, mean_score))) + + def add_training_marker(self, game_number): + """Sends a command to draw the training/playing delimiter.""" + self.queue.put(("add_marker", game_number)) diff --git a/src/snake_ai_pytorch/views/visual_configuration.py b/src/snake_ai_pytorch/views/visual_configuration.py new file mode 100644 index 0000000..5129240 --- /dev/null +++ b/src/snake_ai_pytorch/views/visual_configuration.py @@ -0,0 +1,22 @@ +from pathlib import Path + + +class GameColors: + WHITE = (255, 255, 255) + RED = (200, 0, 0) + BLUE1 = (0, 0, 255) + BLUE2 = (0, 100, 255) + BLACK = (0, 0, 0) + + +BASE_DIR = Path(__file__).resolve().parent.parent + + 
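+# BASE_DIR resolves to the snake_ai_pytorch package directory, so the bundled font is found
+# regardless of the current working directory (the previous code loaded a bare "arial.ttf").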
+class FontConfig: + # Construct a path to the font file relative to this script's location + path = BASE_DIR / "assets" / "arial.ttf" + size = 25 + + +BLOCK_SIZE = 20 +SPEED = 75
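
For quick experiments outside the provided entry points (`train.py`, `enjoy.py`, `play.py`), the new Gymnasium-style environment can also be driven directly. Below is a minimal sketch, assuming the package layout added in this diff and an editable install (`pip install -e .`); the random policy is only a stand-in for the trained agent.

```python
from snake_ai_pytorch.models.snake_env import SnakeEnv

# Create the environment; "human" opens the pygame window via the Renderer.
env = SnakeEnv(render_mode="human")
obs, info = env.reset()

for _ in range(500):
    # Actions are integers: 0 = keep direction, 1 = turn right, 2 = turn left.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated:
        print(f"Game over, score: {info['score']}")
        obs, info = env.reset()
```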