Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion baselines/memory_addresses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@

MONEY_ADDRESS_1 = 0xD347
MONEY_ADDRESS_2 = 0xD348
MONEY_ADDRESS_3 = 0xD349
MONEY_ADDRESS_3 = 0xD349

# 19 consecutive byte addresses, 0xD30A through 0xD31C inclusive.
# ReaderPyBoy.read_seen_pokemons counts the set bits of each of these bytes.
SEEN_POKEMONS_ADDRESSES = list(range(0xD30A, 0xD31C + 1))
120 changes: 120 additions & 0 deletions baselines/reader_pyboy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import memory_addresses as addresses


class ReaderPyBoy:
    """Read game state values out of a PyBoy emulator's memory.

    Every accessor goes through :meth:`read_m`, which calls
    ``pyboy.get_memory_value(addr)``. The addresses come from the
    ``memory_addresses`` module (imported as ``addresses``).
    """

    # Map id -> human-readable location name, used by get_map_location.
    # Built once at class creation instead of on every call.
    _MAP_LOCATIONS = {
        0: "Pallet Town",
        1: "Viridian City",
        2: "Pewter City",
        3: "Cerulean City",
        12: "Route 1",
        13: "Route 2",
        14: "Route 3",
        15: "Route 4",
        33: "Route 22",
        37: "Red house first",
        38: "Red house second",
        39: "Blues house",
        40: "oaks lab",
        41: "Pokémon Center (Viridian City)",
        42: "Poké Mart (Viridian City)",
        43: "School (Viridian City)",
        44: "House 1 (Viridian City)",
        47: "Gate (Viridian City/Pewter City) (Route 2)",
        49: "Gate (Route 2)",
        50: "Gate (Route 2/Viridian Forest) (Route 2)",
        51: "viridian forest",
        52: "Pewter Museum (floor 1)",
        53: "Pewter Museum (floor 2)",
        54: "Pokémon Gym (Pewter City)",
        55: "House with disobedient Nidoran♂ (Pewter City)",
        56: "Poké Mart (Pewter City)",
        57: "House with two Trainers (Pewter City)",
        58: "Pokémon Center (Pewter City)",
        59: "Mt. Moon (Route 3 entrance)",
        60: "Mt. Moon",
        61: "Mt. Moon",
        68: "Pokémon Center (Route 4)",
        193: "Badges check gate (Route 22)",
    }

    def __init__(self, pyboy):
        """Store the emulator handle; *pyboy* must provide ``get_memory_value(addr)``."""
        self.pyboy = pyboy

    def read_m(self, addr):
        """Return the value the emulator reports at memory address *addr*."""
        return self.pyboy.get_memory_value(addr)

    def read_money(self):
        """Return the money amount decoded from three BCD bytes.

        ``MONEY_ADDRESS_1`` holds the most significant pair of decimal digits
        and ``MONEY_ADDRESS_3`` the least significant pair.
        """
        return (100 * 100 * self.read_bcd(self.read_m(addresses.MONEY_ADDRESS_1)) +
                100 * self.read_bcd(self.read_m(addresses.MONEY_ADDRESS_2)) +
                self.read_bcd(self.read_m(addresses.MONEY_ADDRESS_3)))

    def read_bcd(self, num):
        """Decode one packed-BCD byte: high nibble is tens, low nibble is ones."""
        return 10 * ((num >> 4) & 0x0f) + (num & 0x0f)

    def read_bit(self, addr, bit: int) -> bool:
        """Return whether bit number *bit* (0 = least significant) is set at *addr*.

        Uses shift-and-mask rather than string slicing, so a bit index beyond
        the stored value reads as ``False``. A negative *bit* raises
        ``ValueError`` (negative shift count).
        """
        return bool((self.read_m(addr) >> bit) & 1)

    def read_hp_fraction(self):
        """Return summed current HP divided by summed max HP.

        The denominator is clamped to at least 1 so that an all-zero max HP
        sum does not divide by zero.
        """
        hp_sum = sum(self.read_hp(add) for add in addresses.HP_ADDRESSES)
        max_hp_sum = sum(self.read_hp(add) for add in addresses.MAX_HP_ADDRESSES)
        max_hp_sum = max(max_hp_sum, 1)
        return hp_sum / max_hp_sum

    def read_hp(self, start):
        """Return the two-byte value at *start*, high byte first."""
        return 256 * self.read_m(start) + self.read_m(start + 1)

    # built-in since python 3.10
    def bit_count(self, bits):
        """Return the number of ``1`` digits in the binary form of *bits*."""
        return bin(bits).count('1')

    def read_triple(self, start_add):
        """Return the three-byte value at *start_add*, most significant byte first."""
        return (256 * 256 * self.read_m(start_add)
                + 256 * self.read_m(start_add + 1)
                + self.read_m(start_add + 2))

    def get_badges(self):
        """Return the number of set bits in the byte at ``BADGE_COUNT_ADDRESS``."""
        return self.bit_count(self.read_m(addresses.BADGE_COUNT_ADDRESS))

    def get_opponent_level(self):
        """Return the highest value among ``OPPONENT_LEVELS_ADDRESSES``, minus 5."""
        return max(self.read_m(a) for a in addresses.OPPONENT_LEVELS_ADDRESSES) - 5

    def read_party(self):
        """Return the raw values stored at each of ``PARTY_ADDRESSES``."""
        return [self.read_m(addr) for addr in addresses.PARTY_ADDRESSES]

    def get_levels_sum(self):
        """Return the adjusted sum of party levels.

        Each level is reduced by 2 (clamped at 0) before summing, and the
        total is then reduced by 4 (clamped at 0).
        """
        poke_levels = [max(self.read_m(a) - 2, 0) for a in addresses.LEVELS_ADDRESSES]
        return max(sum(poke_levels) - 4, 0)  # subtract starting pokemon level

    def read_party_size_address(self):
        """Return the value stored at ``PARTY_SIZE_ADDRESS``."""
        return self.read_m(addresses.PARTY_SIZE_ADDRESS)

    def read_x_pos(self):
        """Return the value stored at ``X_POS_ADDRESS``."""
        return self.read_m(addresses.X_POS_ADDRESS)

    def read_y_pos(self):
        """Return the value stored at ``Y_POS_ADDRESS``."""
        return self.read_m(addresses.Y_POS_ADDRESS)

    def read_map_n(self):
        """Return the value stored at ``MAP_N_ADDRESS`` (the current map id)."""
        return self.read_m(addresses.MAP_N_ADDRESS)

    def read_events(self):
        """Return per-byte set-bit counts over the event flag range.

        The range is half-open: ``EVENT_FLAGS_END_ADDRESS`` itself is not read.
        """
        return [
            self.bit_count(self.read_m(i))
            for i in range(addresses.EVENT_FLAGS_START_ADDRESS,
                           addresses.EVENT_FLAGS_END_ADDRESS)
        ]

    def read_museum_tickets(self):
        """Return whether bit 0 of the byte at ``MUSEUM_TICKET_ADDRESS`` is set."""
        return self.read_bit(addresses.MUSEUM_TICKET_ADDRESS, 0)

    def read_levels(self):
        """Return the raw (unadjusted) values at each of ``LEVELS_ADDRESSES``."""
        return [self.read_m(a) for a in addresses.LEVELS_ADDRESSES]

    def read_seen_pokemons(self):
        """Return per-byte set-bit counts over ``SEEN_POKEMONS_ADDRESSES``."""
        return [self.bit_count(self.read_m(a)) for a in addresses.SEEN_POKEMONS_ADDRESSES]

    def get_map_location(self, map_idx=None):
        """Return the display name for a map id, or ``"Unknown Location"``.

        Args:
            map_idx: Map id to look up. When omitted (``None``), the current
                map id is read from memory exactly once via ``read_map_n``.
        """
        if map_idx is None:
            map_idx = self.read_map_n()
        return self._MAP_LOCATIONS.get(map_idx, "Unknown Location")
130 changes: 27 additions & 103 deletions baselines/red_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hnswlib
import mediapy as media
import pandas as pd
from reader_pyboy import ReaderPyBoy

from gymnasium import Env, spaces
from pyboy.utils import WindowEvent
Expand Down Expand Up @@ -107,6 +108,7 @@ def __init__(
window_type=head,
hide_window='--quiet' in sys.argv,
)
self.reader = ReaderPyBoy(self.pyboy)

self.screen = self.pyboy.botsupport_manager().screen()

Expand Down Expand Up @@ -210,11 +212,11 @@ def step(self, action):
self.update_seen_coords()

self.update_heal_reward()
self.party_size = self.read_m(PARTY_SIZE_ADDRESS)
self.party_size = self.reader.read_party_size_address()

new_reward, new_prog = self.update_reward()

self.last_health = self.read_hp_fraction()
self.last_health = self.reader.read_hp_fraction()

# shift over short term reward memory
self.recent_memory = np.roll(self.recent_memory, 3)
Expand Down Expand Up @@ -260,25 +262,25 @@ def add_video_frame(self):
self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False))

def append_agent_stats(self, action):
x_pos = self.read_m(X_POS_ADDRESS)
y_pos = self.read_m(Y_POS_ADDRESS)
map_n = self.read_m(MAP_N_ADDRESS)
levels = [self.read_m(a) for a in LEVELS_ADDRESSES]
x_pos = self.reader.read_x_pos()
y_pos = self.reader.read_y_pos()
map_n = self.reader.read_map_n()
levels = self.reader.read_levels()
if self.use_screen_explore:
expl = ('frames', self.knn_index.get_current_count())
else:
expl = ('coord_count', len(self.seen_coords))
self.agent_stats.append({
'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
'map_location': self.get_map_location(map_n),
'map_location': self.reader.get_map_location(),
'last_action': action,
'pcount': self.read_m(PARTY_SIZE_ADDRESS),
'pcount': self.reader.read_party_size_address(),
'levels': levels,
'levels_sum': sum(levels),
'ptypes': self.read_party(),
'hp': self.read_hp_fraction(),
'ptypes': self.reader.read_party(),
'hp': self.reader.read_hp_fraction(),
expl[0]: expl[1],
'deaths': self.died_count, 'badge': self.get_badges(),
'deaths': self.died_count, 'badge': self.reader.get_badges(),
'event': self.progress_reward['event'], 'healr': self.total_healing_rew
})

Expand All @@ -304,9 +306,9 @@ def update_frame_knn_index(self, frame_vec):
)

def update_seen_coords(self):
x_pos = self.read_m(X_POS_ADDRESS)
y_pos = self.read_m(Y_POS_ADDRESS)
map_n = self.read_m(MAP_N_ADDRESS)
x_pos = self.reader.read_m(X_POS_ADDRESS)
y_pos = self.reader.read_m(Y_POS_ADDRESS)
map_n = self.reader.read_m(MAP_N_ADDRESS)
coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}"
if self.get_levels_sum() >= 22 and not self.levels_satisfied:
self.levels_satisfied = True
Expand All @@ -322,7 +324,7 @@ def update_reward(self):
new_prog = self.group_rewards()
new_total = sum([val for _, val in self.progress_reward.items()]) #sqrt(self.explore_reward * self.progress_reward)
new_step = new_total - self.total_reward
if new_step < 0 and self.read_hp_fraction() > 0:
if new_step < 0 and self.reader.read_hp_fraction() > 0:
#print(f'\n\nreward went down! {self.progress_reward}\n\n')
self.save_screenshot('neg_reward')

Expand All @@ -337,7 +339,7 @@ def group_rewards(self):
prog = self.progress_reward
# these values are only used by memory
return (prog['level'] * 100 / self.reward_scale,
self.read_hp_fraction()*2000,
self.reader.read_hp_fraction()*2000,
prog['explore'] * 150 / (self.explore_weight * self.reward_scale))
#(prog['events'],
# prog['levels'] + prog['party_xp'],
Expand Down Expand Up @@ -371,7 +373,7 @@ def make_reward_channel(r_val):
make_reward_channel(explore)
), axis=-1)

if self.get_badges() > 0:
if self.reader.get_badges() > 0:
full_memory[:, -1, :] = 255

return full_memory
Expand All @@ -389,7 +391,7 @@ def check_if_done(self):
done = True
else:
done = self.step_count >= self.max_steps
#done = self.read_hp_fraction() == 0
#done = self.reader.read_hp_fraction() == 0
return done

def save_and_print_info(self, done, obs_memory):
Expand Down Expand Up @@ -427,16 +429,9 @@ def save_and_print_info(self, done, obs_memory):
json.dump(self.all_runs, f)
pd.DataFrame(self.agent_stats).to_csv(
self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a')

def read_m(self, addr):
return self.pyboy.get_memory_value(addr)

def read_bit(self, addr, bit: int) -> bool:
# add padding so zero will read '0b100000000' instead of '0b0'
return bin(256 + self.read_m(addr))[-bit-1] == '1'

def get_levels_sum(self):
    """Return the adjusted party level sum used for level-based rewards.

    Delegates to ``self.reader.get_levels_sum()``, which reduces each level
    by 2 (clamped at 0) before summing and then subtracts 4 (clamped at 0).
    Summing ``self.reader.read_levels()`` directly would use the raw levels
    and skip the per-pokemon adjustment.
    """
    return self.reader.get_levels_sum()

def get_levels_reward(self):
Expand All @@ -458,18 +453,12 @@ def get_knn_reward(self):
base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew
post = (cur_size if self.levels_satisfied else 0) * post_rew
return base + post

def get_badges(self):
return self.bit_count(self.read_m(BADGE_COUNT_ADDRESS))

def read_party(self):
return [self.read_m(addr) for addr in PARTY_ADDRESSES]

def update_heal_reward(self):
cur_health = self.read_hp_fraction()
cur_health = self.reader.read_hp_fraction()
# if health increased and party size did not change
if (cur_health > self.last_health and
self.read_m(PARTY_SIZE_ADDRESS) == self.party_size):
self.reader.read_party_size_address() == self.party_size):
if self.last_health > 0:
heal_amount = cur_health - self.last_health
if heal_amount > 0.5:
Expand All @@ -488,12 +477,12 @@ def get_all_events_reward(self):
return max(
sum(
[
self.bit_count(self.read_m(i))
self.reader.bit_count(self.reader.read_m(i))
for i in range(event_flags_start, event_flags_end)
]
)
- base_event_flags
- int(self.read_bit(museum_ticket[0], museum_ticket[1])),
- int(self.reader.read_bit(museum_ticket[0], museum_ticket[1])),
0,
)

Expand Down Expand Up @@ -529,7 +518,7 @@ def get_game_state_reward(self, print_stats=False):
'heal': self.reward_scale*self.total_healing_rew,
'op_lvl': self.reward_scale*self.update_max_op_level(),
'dead': self.reward_scale*-0.1*self.died_count,
'badge': self.reward_scale*self.get_badges() * 5,
'badge': self.reward_scale*self.reader.get_badges() * 5,
#'op_poke': self.reward_scale*self.max_opponent_poke * 800,
#'money': self.reward_scale* money * 3,
#'seen_poke': self.reward_scale * seen_poke_count * 400,
Expand All @@ -547,7 +536,7 @@ def save_screenshot(self, name):

def update_max_op_level(self):
#opponent_level = self.read_m(0xCFE8) - 5 # base level
opponent_level = max([self.read_m(a) for a in OPPONENT_LEVELS_ADDRESSES]) - 5
opponent_level = self.reader.get_opponent_level()
#if opponent_level >= 7:
# self.save_screenshot('highlevelop')
self.max_opponent_level = max(self.max_opponent_level, opponent_level)
Expand All @@ -558,68 +547,3 @@ def update_max_event_rew(self):
self.max_event_rew = max(cur_rew, self.max_event_rew)
return self.max_event_rew

def read_hp_fraction(self):
hp_sum = sum([self.read_hp(add) for add in HP_ADDRESSES])
max_hp_sum = sum([self.read_hp(add) for add in MAX_HP_ADDRESSES])
max_hp_sum = max(max_hp_sum, 1)
return hp_sum / max_hp_sum

def read_hp(self, start):
return 256 * self.read_m(start) + self.read_m(start+1)

# built-in since python 3.10
def bit_count(self, bits):
return bin(bits).count('1')

def read_triple(self, start_add):
return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2)

def read_bcd(self, num):
return 10 * ((num >> 4) & 0x0f) + (num & 0x0f)

def read_money(self):
return (100 * 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_1)) +
100 * self.read_bcd(self.read_m(MONEY_ADDRESS_2)) +
self.read_bcd(self.read_m(MONEY_ADDRESS_3)))

def get_map_location(self, map_idx):
map_locations = {
0: "Pallet Town",
1: "Viridian City",
2: "Pewter City",
3: "Cerulean City",
12: "Route 1",
13: "Route 2",
14: "Route 3",
15: "Route 4",
33: "Route 22",
37: "Red house first",
38: "Red house second",
39: "Blues house",
40: "oaks lab",
41: "Pokémon Center (Viridian City)",
42: "Poké Mart (Viridian City)",
43: "School (Viridian City)",
44: "House 1 (Viridian City)",
47: "Gate (Viridian City/Pewter City) (Route 2)",
49: "Gate (Route 2)",
50: "Gate (Route 2/Viridian Forest) (Route 2)",
51: "viridian forest",
52: "Pewter Museum (floor 1)",
53: "Pewter Museum (floor 2)",
54: "Pokémon Gym (Pewter City)",
55: "House with disobedient Nidoran♂ (Pewter City)",
56: "Poké Mart (Pewter City)",
57: "House with two Trainers (Pewter City)",
58: "Pokémon Center (Pewter City)",
59: "Mt. Moon (Route 3 entrance)",
60: "Mt. Moon",
61: "Mt. Moon",
68: "Pokémon Center (Route 4)",
193: "Badges check gate (Route 22)"
}
if map_idx in map_locations.keys():
return map_locations[map_idx]
else:
return "Unknown Location"