Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 69 additions & 37 deletions agents/gpt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List
from api.classes import Agent, AvailableActions, Action, Observation, Rules
import random
import openai
import api.util as util
import ast
import json
from PIL import Image


action_format_instructions_no_openended = """\
Expand All @@ -30,29 +31,38 @@

@dataclass
class OpenAITextAgent(Agent):
openai_model: str
agent_type_id: str
system_message: str = "You are an agent playing a game. Select the action that maximizes your probability of winning."
max_retries: int = 3
transparent_reasoning: bool = False
openai_model : str
agent_type_id : str
system_message : str = "You are an agent playing a game. Select the action that maximizes your probability of winning."
max_retries : int = 3
transparent_reasoning : bool = False
mode: int = 0 # 0 = normal, 1 = chain of thought, 2 = babble and prune

def print(self, *args, **kwargs):
    """Print a debug line prefixed with this agent's ``agent_type_id``.

    Nothing is printed unless ``transparent_reasoning`` is enabled. All
    positional and keyword arguments are forwarded to the builtin ``print``
    after the prefix.
    """
    if not self.transparent_reasoning:
        return
    # Inside the method body the bare name is the builtin, not this method.
    print(self.agent_type_id, *args, **kwargs)

def get_user_message_content(self, text_prompt: str, image: Image) -> List[Dict[str, Any]]:
    """Build the ``content`` list of the user message.

    This base implementation produces a single text part holding
    ``text_prompt``; the ``image`` argument is accepted but not used here.
    """
    text_part = {"type": "text", "text": text_prompt}
    return [text_part]

def take_action(
self,
rules: Rules,
observation: Observation,
available_actions: AvailableActions,
show_state: bool,
):
def get_request_params(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Assemble the keyword arguments for the chat-completions request.

    Returns a dict with the configured model, the given ``messages`` list
    (passed through, not copied) and a ``response_format`` requesting a
    JSON object.
    """
    params: Dict[str, Any] = dict(model=self.openai_model, messages=messages)
    params["response_format"] = {"type": "json_object"}
    return params

def take_action(self, rules: Rules, observation: Observation, available_actions: AvailableActions, show_state : bool) -> Action:
valid_actions = []
prompt = f"You are playing a game called {rules.title}. The rules are as follows:\n{rules.summary}\n"
if rules.additional_details != None:
prompt += "The following are headings with additional information about the rules that you can expand by taking the action Explain(<heading key>).\n"
details_dict = {f"H{i+1}": topic for i, topic in enumerate(rules.additional_details)}
details_dict = {f"H{i+1}": topic + " - " + description for i, (topic, description) in enumerate(rules.additional_details.items())}
prompt += json.dumps(details_dict, indent=4)
valid_actions.extend(f"Explain({h})" for h in list(details_dict.keys()))

Expand Down Expand Up @@ -83,7 +93,10 @@ def take_action(
):
prompt += "Return the action Explain(<action>) to receive additional info about what any of the above actions do.\n"

messages = [{"role": "system", "content": self.system_message}]
messages = [
{"role": "system", "content": self.system_message},
{"role": "user", "content": self.get_user_message_content(prompt, observation.image)},
]

# Chain of Thought
if self.mode == 1:
Expand Down Expand Up @@ -117,7 +130,7 @@ def take_action(
)
messages.append({"role": "assistant", "content": response})
prompt = ""

self.print(
f"GPT listed the following actions as possibilities: {response}"
)
Expand All @@ -126,24 +139,20 @@ def take_action(
prompt += str(list(available_actions.predefined))
prompt += "\nOr if you choose an openended action, you must return json with an 'action' key which contains one of the following valid actions and an 'openeded_response' key which contains your reponse to the prompt:\n"
prompt += str(list(available_actions.openended))
#prompt += "\nMake sure to return ONLY a JSON. It should contain nothing outside the curly braces of the JSON."
messages.append({"role": "user", "content": prompt})



#print(prompt)
result = None

for _ in range(self.max_retries):
response = (
openai_client.chat.completions.create(
model=self.openai_model,
response_format={"type": "json_object"},
messages=messages,
)
.choices[0]
.message.content
)
response = openai_client.chat.completions.create(**self.get_request_params(messages)).choices[0].message.content
messages.append({"role": "assistant", "content": response})
self.print("GPT responded with", response)

try:
action = ast.literal_eval(response)
action = ast.literal_eval(util.extract_json(response))
except:
self.print("GPT returned invalid JSON")
continue
Expand All @@ -154,27 +163,23 @@ def take_action(
messages.append({"role": "user", "content": error_message})
continue


if action["action"] in valid_actions:
self.print("GPT chose valid action", action)
result = action
break

self.print("GPT returned invalid action", action)
error_message = f"{action['action']} is not one of the valid actions. "
error_message += "As a reminder, the valid actions are as follows:\n"
error_message += f"{str(list(valid_actions))}\n"
error_message += "Please return a json with the key 'action' with the action you choose and (optionally) the key 'openended_response' if you select openended response action."
messages.append({"role": "user", "content": error_message})

if result == None:
self.print(
f"WARNING: GPT returned an a random action after {self.max_retries} tries"
)
self.print(f"WARNING: GPT returned an a random action after {self.max_retries} tries")
return Action(action_id=None)
return Action(
action_id=result["action"],
openended_response=result.get("openended_response"),
)

return Action(action_id=result["action"], openended_response=result.get("openended_response"))

@dataclass
class ChatGPTText(OpenAITextAgent):
Expand All @@ -199,3 +204,30 @@ class BabbleAndPrune(OpenAITextAgent):
openai_model: str = "gpt-4-1106-preview"
agent_type_id: str = "b&p"
mode: int = 2

@dataclass
class GPT4Vision(OpenAITextAgent):
    """OpenAI chat agent that also sends the observation image to the model.

    Extends ``OpenAITextAgent`` by appending the image (when one is present)
    to the user message content and by building its own request parameters.
    """

    openai_model: str = "gpt-4-vision-preview"
    agent_type_id: str = "gpt-4-vision"
    is_vision_agent: bool = True

    def get_user_message_content(self, text_prompt: str, image: Image) -> List[Dict[str, Any]]:
        """Return the user message content: the text part plus an optional image part.

        The text part comes from the parent implementation. When ``image`` is
        not ``None`` it is appended as a base64 PNG data URL with ``detail``
        set to ``"low"``.
        """
        content = super().get_user_message_content(text_prompt, image)
        if image is not None:
            # The image part nests a dict, hence Dict[str, Any] in the return type.
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{util.base64_encode_image(image)}",
                    "detail": "low"
                }
            })
        return content

    def get_request_params(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Return request keyword arguments for the vision model.

        Unlike the parent implementation, no ``response_format`` is sent.
        NOTE(review): presumably because the vision preview model does not
        accept JSON mode — confirm.
        """
        return {
            "model": self.openai_model,
            "messages": messages,
            # As vision models have a low(but undocumented?) default value for below parameter
            # https://community.openai.com/t/documented-max-token-default-is-incorrect-for-gpt-4-vision-preview/507329
            "max_tokens": 600,
        }
2 changes: 1 addition & 1 deletion agents/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
class RandomAgent(Agent):
agent_type_id : str = "random"

def take_action(self, rules : Rules, observation: Observation, available_actions: AvailableActions, show_state : bool) -> Action:
    """Return an ``Action`` whose id is a randomly chosen predefined action key.

    ``rules``, ``observation`` and ``show_state`` are ignored; only the keys
    of ``available_actions.predefined`` are considered.
    """
    candidate_ids = [*available_actions.predefined.keys()]
    chosen_id = random.choice(candidate_ids)
    return Action(action_id=chosen_id)
24 changes: 20 additions & 4 deletions api/classes.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
from typing import List, Dict, Optional, Tuple
from typing import List, Dict, Optional, Tuple, Type
from dataclasses import dataclass, field
from abc import abstractmethod
from PIL import Image


@dataclass
class Observation:
    """What an agent observes: a text description and/or an image.

    Both fields are optional; ``text`` defaults to the empty string and
    ``image`` to ``None``.
    """

    text: str = ""
    # ``Image`` is the PIL module; the image class itself is ``Image.Image``.
    image: Optional["Image.Image"] = None

    def __eq__(self, other):
        """Return True when ``other`` is an Observation with equal text and image.

        Two images are equal when their mode, size and raw bytes all match.
        Comparing raw bytes alone is not enough: a 2x3 and a 3x2 image (or
        images in different modes) can share the same byte string.
        """
        if not isinstance(other, Observation):
            return False

        if self.text != other.text:
            return False

        mine, theirs = self.image, other.image
        if mine is None or theirs is None:
            # Equal only when both observations lack an image.
            return mine is None and theirs is None

        return (
            mine.mode == theirs.mode
            and mine.size == theirs.size
            and mine.tobytes() == theirs.tobytes()
        )

@dataclass
class AvailableActions:
Expand All @@ -30,7 +46,7 @@ class Agent:
agent_type_id : str

@abstractmethod
def take_action(self, rules : dict, observation: Observation, available_actions : AvailableActions, show_state : bool) -> Action:
    """Choose the agent's next action; concrete agents must override this.

    Receives the game rules, the current observation, the currently
    available actions and a ``show_state`` flag, and is expected to return
    an ``Action``. This base implementation does nothing and returns None.
    NOTE(review): the exact meaning of ``show_state`` is not visible here —
    confirm against callers.
    """

@dataclass
Expand All @@ -51,7 +67,7 @@ class Game:
agent_2_kwargs : dict = field(default_factory=dict) # kwargs to pass to the agent 2 class when initializing.

@abstractmethod
def init_game(self, agent_1: Type[Agent], agent_2: Type[Agent]):
    """Abstract hook that concrete games must override.

    Per the annotations it receives the two agent classes (not instances).
    This base implementation does nothing and returns None.
    NOTE(review): presumably implementations instantiate the agents and
    reset the game state — confirm against a concrete game.
    """

@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion api/play_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,10 @@ def play_game(agent_1_path, agent_2_path, game_path, num_matches = 1, save_resul
util.save_json(matches, "matches.json")
print("Saved match information")

agent_1_rating = agent_1_rating + K * (player_1_score - agent_1_expected_score)
agent_1_rating = agent_1_rating + K * (player_1_score - agent_1_expected_score)
agent_2_rating = agent_2_rating + K * (player_2_score - agent_2_expected_score)
# Without below line, we get a KeyError: '<game_class.id>'
all_ratings.setdefault(game_class.id, {})
all_ratings[game_class.id][agent_1_id] = agent_1_rating
all_ratings[game_class.id][agent_2_id] = agent_2_rating
print("Updated elos:")
Expand Down
23 changes: 21 additions & 2 deletions api/util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import importlib
import os
import json
from PIL import Image
from io import BytesIO
import base64
import re

def save_json(data, file_path):
    """Write ``data`` to ``file_path`` as JSON indented by four spaces.

    Missing parent directories are created first. An existing file is
    overwritten.
    """
    parent_dir = os.path.dirname(file_path)
    # A bare filename such as "matches.json" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError, so only create real paths.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


def load_json(file_path):
if not os.path.exists(file_path):
raise ValueError(f"File {file_path} does not exist")
Expand All @@ -18,4 +21,20 @@ def load_json(file_path):
def import_class(class_path):
    """Import and return the attribute named by a dotted path.

    ``class_path`` has the form ``"package.module.Name"``: everything before
    the last dot is imported as a module and the final component is looked
    up on it.

    Raises:
        ValueError: if ``class_path`` contains no dot.
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    if "." not in class_path:
        # Without this check the unpacking below fails with an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {class_path!r}"
        )
    module_path, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

def base64_encode_image(image: "Image.Image") -> str:
    """Return ``image`` saved as PNG and base64-encoded, as a text string.

    The image is written to an in-memory buffer through its ``save`` method
    with ``format="PNG"``; the resulting bytes are base64-encoded and decoded
    to ``str``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")

def extract_json(input: str) -> str:
    """Return the substring of ``input`` from the first ``{`` to the last ``}``.

    Used to strip any text a model puts around a JSON object. The extracted
    text is returned unparsed; parsing is left to the caller.

    Raises:
        ValueError: if ``input`` contains no ``{ ... }`` span.
    """
    # Greedy match with DOTALL: spans newlines and keeps nested braces intact.
    json_match = re.search(r"{.*}", input, re.DOTALL)
    if json_match is None:
        raise ValueError(f"Could not find JSON in input: {input}")
    return json_match.group(0)
64 changes: 64 additions & 0 deletions games/atari/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Multiplayer Atari Games via PettingZoo
This folder contains ports of the Atari games made available in [PettingZoo](https://pettingzoo.farama.org/), a multi-agent game environment library with an API similar to [OpenAI Gym](https://gymnasium.farama.org/).


## Installation

Running these Atari games requires the [AtariARI](https://github.com/mila-iqia/atari-representation-learning) and [PettingZoo](https://pettingzoo.farama.org/environments/atari/boxing/) libraries.

### 0. Install pip3

If pip3 is not already installed (i.e. the `pip3` command is not found), install it:
> sudo apt-get install python3-pip

Then, upgrade pip

>python3 -m pip install --upgrade pip

### 1. Install AtariARI

Run the following two commands and make sure both complete successfully:

>pip3 install 'gym[atari]'

>pip3 install git+https://github.com/mila-iqia/atari-representation-learning.git

### 2. Install PettingZoo

Run
>pip3 install 'pettingzoo[atari]'

### 3. Install misc libraries

> pip3 install matplotlib

> pip3 install autorom

> AutoROM

## PettingZoo implementation of realtime games

Atari games were supposed to appear realtime to humans, but under the hood they are programmed as turn-based games with tens of turns per second.

To a human, a game running at full speed still appears realtime.

PettingZoo models these games as [Agent Environment Cycle](https://pettingzoo.farama.org/api/aec/) environments.

![Alt text](image.png)

At each step, a player (depending on whose turn it is) is queried for their next move.

## GameBench implementation of PettingZoo games

Agents are run in background threads. The agent loop is:

1. Get current game state
2. Query agent on what action should be done
3. Store this action in a variable Act

At every turn, we query the stored action Act for that player and execute it.

## Current list of games

1. [Boxing](https://pettingzoo.farama.org/environments/atari/boxing/)
Loading