From a19f2018b0609ecb4a93cab0f832be523c80938d Mon Sep 17 00:00:00 2001
From: Mathieu D
Date: Thu, 29 Jan 2026 15:46:17 +0100
Subject: [PATCH] add pyboy_reader

Signed-off-by: Mathieu D
---
Review notes (this section is ignored by `git am`):
- RedGymEnv.get_levels_sum(): re-applied the per-slot `max(level - 2, 0)`
  clamp on top of reader.read_levels(). Summing the raw levels changed the
  value that update_seen_coords() compares against the `>= 22` threshold.
- RedGymEnv.update_seen_coords(): now uses reader.read_x_pos() /
  read_y_pos() / read_map_n(), matching append_agent_stats().
- Line breaks and indentation were restored from the hunk line counts.
  Whitespace on blank lines is a best guess, so `git apply
  --ignore-whitespace` may be needed. Blob ids on the `index` lines were
  not recomputed.

 baselines/memory_addresses.py |   4 +-
 baselines/reader_pyboy.py     | 120 +++++++++++++++++++++++++++++++
 baselines/red_gym_env.py      | 130 +++++++---------------------------
 3 files changed, 150 insertions(+), 104 deletions(-)
 create mode 100644 baselines/reader_pyboy.py

diff --git a/baselines/memory_addresses.py b/baselines/memory_addresses.py
index be989ee58..b554c0a88 100644
--- a/baselines/memory_addresses.py
+++ b/baselines/memory_addresses.py
@@ -19,4 +19,6 @@
 
 MONEY_ADDRESS_1 = 0xD347
 MONEY_ADDRESS_2 = 0xD348
-MONEY_ADDRESS_3 = 0xD349
\ No newline at end of file
+MONEY_ADDRESS_3 = 0xD349
+
+SEEN_POKEMONS_ADDRESSES = [0xD30A, 0xD30B, 0xD30C, 0xD30D, 0xD30E, 0xD30F, 0xD310, 0xD311, 0xD312, 0xD313, 0xD314, 0xD315, 0xD316, 0xD317, 0xD318, 0xD319, 0xD31A, 0xD31B, 0xD31C]
diff --git a/baselines/reader_pyboy.py b/baselines/reader_pyboy.py
new file mode 100644
index 000000000..427a4068c
--- /dev/null
+++ b/baselines/reader_pyboy.py
@@ -0,0 +1,120 @@
+import memory_addresses as addresses
+
+
+class ReaderPyBoy:
+
+    def __init__(self, pyboy):
+        self.pyboy = pyboy
+
+    def read_m(self, addr):
+        return self.pyboy.get_memory_value(addr)
+
+    def read_money(self):
+        return (100 * 100 * self.read_bcd(self.read_m(addresses.MONEY_ADDRESS_1)) +
+                100 * self.read_bcd(self.read_m(addresses.MONEY_ADDRESS_2)) +
+                self.read_bcd(self.read_m(addresses.MONEY_ADDRESS_3)))
+
+    def read_bcd(self, num):
+        return 10 * ((num >> 4) & 0x0f) + (num & 0x0f)
+
+    def read_bit(self, addr, bit: int) -> bool:
+        # add padding so zero will read '0b100000000' instead of '0b0'
+        return bin(256 + self.read_m(addr))[-bit-1] == '1'
+
+    def read_hp_fraction(self):
+        hp_sum = sum([self.read_hp(add) for add in addresses.HP_ADDRESSES])
+        max_hp_sum = sum([self.read_hp(add) for add in addresses.MAX_HP_ADDRESSES])
+        max_hp_sum = max(max_hp_sum, 1)
+        return hp_sum / max_hp_sum
+
+    def read_hp(self, start):
+        return 256 * self.read_m(start) + self.read_m(start+1)
+
+    # built-in since python 3.10
+    def bit_count(self, bits):
+        return bin(bits).count('1')
+
+    def read_triple(self, start_add):
+        return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2)
+
+    def get_badges(self):
+        return self.bit_count(self.read_m(addresses.BADGE_COUNT_ADDRESS))
+
+    def get_opponent_level(self):
+        return max([self.read_m(a) for a in addresses.OPPONENT_LEVELS_ADDRESSES]) - 5
+
+    def read_party(self):
+        return [self.read_m(addr) for addr in addresses.PARTY_ADDRESSES]
+
+    def get_levels_sum(self):
+        poke_levels = [max(self.read_m(a) - 2, 0) for a in addresses.LEVELS_ADDRESSES]
+        return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level
+
+    def read_party_size_address(self):
+        return self.read_m(addresses.PARTY_SIZE_ADDRESS)
+
+    def read_x_pos(self):
+        return self.read_m(addresses.X_POS_ADDRESS)
+
+    def read_y_pos(self):
+        return self.read_m(addresses.Y_POS_ADDRESS)
+
+    def read_map_n(self):
+        return self.read_m(addresses.MAP_N_ADDRESS)
+
+    def read_events(self):
+        return [
+            self.bit_count(self.read_m(i))
+            for i in range(addresses.EVENT_FLAGS_START_ADDRESS, addresses.EVENT_FLAGS_END_ADDRESS)
+        ]
+
+    def read_museum_tickets(self):
+        museum_ticket = (addresses.MUSEUM_TICKET_ADDRESS, 0)
+        return self.read_bit(museum_ticket[0], museum_ticket[1])
+
+    def read_levels(self):
+        return [self.read_m(a) for a in addresses.LEVELS_ADDRESSES]
+
+    def read_seen_pokemons(self):
+        return [self.bit_count(self.read_m(a)) for a in addresses.SEEN_POKEMONS_ADDRESSES]
+
+    def get_map_location(self):
+        map_locations = {
+            0: "Pallet Town",
+            1: "Viridian City",
+            2: "Pewter City",
+            3: "Cerulean City",
+            12: "Route 1",
+            13: "Route 2",
+            14: "Route 3",
+            15: "Route 4",
+            33: "Route 22",
+            37: "Red house first",
+            38: "Red house second",
+            39: "Blues house",
+            40: "oaks lab",
+            41: "Pokémon Center (Viridian City)",
+            42: "Poké Mart (Viridian City)",
+            43: "School (Viridian City)",
+            44: "House 1 (Viridian City)",
+            47: "Gate (Viridian City/Pewter City) (Route 2)",
+            49: "Gate (Route 2)",
+            50: "Gate (Route 2/Viridian Forest) (Route 2)",
+            51: "viridian forest",
+            52: "Pewter Museum (floor 1)",
+            53: "Pewter Museum (floor 2)",
+            54: "Pokémon Gym (Pewter City)",
+            55: "House with disobedient Nidoran♂ (Pewter City)",
+            56: "Poké Mart (Pewter City)",
+            57: "House with two Trainers (Pewter City)",
+            58: "Pokémon Center (Pewter City)",
+            59: "Mt. Moon (Route 3 entrance)",
+            60: "Mt. Moon",
+            61: "Mt. Moon",
+            68: "Pokémon Center (Route 4)",
+            193: "Badges check gate (Route 22)"
+        }
+        if self.read_map_n() in map_locations.keys():
+            return map_locations[self.read_map_n()]
+        else:
+            return "Unknown Location"
diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py
index e1133f731..6ca4a0f86 100644
--- a/baselines/red_gym_env.py
+++ b/baselines/red_gym_env.py
@@ -15,6 +15,7 @@
 import hnswlib
 import mediapy as media
 import pandas as pd
+from reader_pyboy import ReaderPyBoy
 
 from gymnasium import Env, spaces
 from pyboy.utils import WindowEvent
@@ -107,6 +108,7 @@ def __init__(
                 window_type=head,
                 hide_window='--quiet' in sys.argv,
             )
+        self.reader = ReaderPyBoy(self.pyboy)
 
         self.screen = self.pyboy.botsupport_manager().screen()
 
@@ -210,11 +212,11 @@ def step(self, action):
             self.update_seen_coords()
 
         self.update_heal_reward()
-        self.party_size = self.read_m(PARTY_SIZE_ADDRESS)
+        self.party_size = self.reader.read_party_size_address()
 
         new_reward, new_prog = self.update_reward()
 
-        self.last_health = self.read_hp_fraction()
+        self.last_health = self.reader.read_hp_fraction()
 
         # shift over short term reward memory
         self.recent_memory = np.roll(self.recent_memory, 3)
@@ -260,25 +262,25 @@ def add_video_frame(self):
         self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False))
 
     def append_agent_stats(self, action):
-        x_pos = self.read_m(X_POS_ADDRESS)
-        y_pos = self.read_m(Y_POS_ADDRESS)
-        map_n = self.read_m(MAP_N_ADDRESS)
-        levels = [self.read_m(a) for a in LEVELS_ADDRESSES]
+        x_pos = self.reader.read_x_pos()
+        y_pos = self.reader.read_y_pos()
+        map_n = self.reader.read_map_n()
+        levels = self.reader.read_levels()
         if self.use_screen_explore:
             expl = ('frames', self.knn_index.get_current_count())
         else:
             expl = ('coord_count', len(self.seen_coords))
         self.agent_stats.append({
             'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
-            'map_location': self.get_map_location(map_n),
+            'map_location': self.reader.get_map_location(),
             'last_action': action,
-            'pcount': self.read_m(PARTY_SIZE_ADDRESS),
+            'pcount': self.reader.read_party_size_address(),
             'levels': levels,
             'levels_sum': sum(levels),
-            'ptypes': self.read_party(),
-            'hp': self.read_hp_fraction(),
+            'ptypes': self.reader.read_party(),
+            'hp': self.reader.read_hp_fraction(),
             expl[0]: expl[1],
-            'deaths': self.died_count, 'badge': self.get_badges(),
+            'deaths': self.died_count, 'badge': self.reader.get_badges(),
             'event': self.progress_reward['event'], 'healr': self.total_healing_rew
         })
 
@@ -304,9 +306,9 @@ def update_frame_knn_index(self, frame_vec):
                 )
 
     def update_seen_coords(self):
-        x_pos = self.read_m(X_POS_ADDRESS)
-        y_pos = self.read_m(Y_POS_ADDRESS)
-        map_n = self.read_m(MAP_N_ADDRESS)
+        x_pos = self.reader.read_x_pos()
+        y_pos = self.reader.read_y_pos()
+        map_n = self.reader.read_map_n()
         coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}"
         if self.get_levels_sum() >= 22 and not self.levels_satisfied:
             self.levels_satisfied = True
@@ -322,7 +324,7 @@ def update_reward(self):
         new_prog = self.group_rewards()
         new_total = sum([val for _, val in self.progress_reward.items()]) #sqrt(self.explore_reward * self.progress_reward)
         new_step = new_total - self.total_reward
-        if new_step < 0 and self.read_hp_fraction() > 0:
+        if new_step < 0 and self.reader.read_hp_fraction() > 0:
             #print(f'\n\nreward went down! {self.progress_reward}\n\n')
             self.save_screenshot('neg_reward')
 
@@ -337,7 +339,7 @@ def group_rewards(self):
         prog = self.progress_reward
         # these values are only used by memory
         return (prog['level'] * 100 / self.reward_scale,
-                self.read_hp_fraction()*2000,
+                self.reader.read_hp_fraction()*2000,
                 prog['explore'] * 150 / (self.explore_weight * self.reward_scale))
                #(prog['events'],
                # prog['levels'] + prog['party_xp'],
@@ -371,7 +373,7 @@ def make_reward_channel(r_val):
             make_reward_channel(explore)
         ), axis=-1)
 
-        if self.get_badges() > 0:
+        if self.reader.get_badges() > 0:
             full_memory[:, -1, :] = 255
 
         return full_memory
@@ -389,7 +391,7 @@ def check_if_done(self):
             done = True
         else:
             done = self.step_count >= self.max_steps
-        #done = self.read_hp_fraction() == 0
+        #done = self.reader.read_hp_fraction() == 0
         return done
 
     def save_and_print_info(self, done, obs_memory):
@@ -427,16 +429,9 @@ def save_and_print_info(self, done, obs_memory):
                 json.dump(self.all_runs, f)
             pd.DataFrame(self.agent_stats).to_csv(
                 self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a')
-    
-    def read_m(self, addr):
-        return self.pyboy.get_memory_value(addr)
 
-    def read_bit(self, addr, bit: int) -> bool:
-        # add padding so zero will read '0b100000000' instead of '0b0'
-        return bin(256 + self.read_m(addr))[-bit-1] == '1'
-    
     def get_levels_sum(self):
-        poke_levels = [max(self.read_m(a) - 2, 0) for a in LEVELS_ADDRESSES]
+        poke_levels = [max(level - 2, 0) for level in self.reader.read_levels()]
         return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level
 
     def get_levels_reward(self):
@@ -458,18 +453,12 @@ def get_knn_reward(self):
         base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew
         post = (cur_size if self.levels_satisfied else 0) * post_rew
         return base + post
-    
-    def get_badges(self):
-        return self.bit_count(self.read_m(BADGE_COUNT_ADDRESS))
 
-    def read_party(self):
-        return [self.read_m(addr) for addr in PARTY_ADDRESSES]
-    
     def update_heal_reward(self):
-        cur_health = self.read_hp_fraction()
+        cur_health = self.reader.read_hp_fraction()
         # if health increased and party size did not change
         if (cur_health > self.last_health and
-                self.read_m(PARTY_SIZE_ADDRESS) == self.party_size):
+                self.reader.read_party_size_address() == self.party_size):
             if self.last_health > 0:
                 heal_amount = cur_health - self.last_health
                 if heal_amount > 0.5:
@@ -488,12 +477,12 @@ def get_all_events_reward(self):
         return max(
             sum(
                 [
-                    self.bit_count(self.read_m(i))
+                    self.reader.bit_count(self.reader.read_m(i))
                     for i in range(event_flags_start, event_flags_end)
                 ]
             )
             - base_event_flags
-            - int(self.read_bit(museum_ticket[0], museum_ticket[1])),
+            - int(self.reader.read_bit(museum_ticket[0], museum_ticket[1])),
             0,
         )
 
@@ -529,7 +518,7 @@ def get_game_state_reward(self, print_stats=False):
             'heal': self.reward_scale*self.total_healing_rew,
             'op_lvl': self.reward_scale*self.update_max_op_level(),
             'dead': self.reward_scale*-0.1*self.died_count,
-            'badge': self.reward_scale*self.get_badges() * 5,
+            'badge': self.reward_scale*self.reader.get_badges() * 5,
             #'op_poke': self.reward_scale*self.max_opponent_poke * 800,
             #'money': self.reward_scale* money * 3,
             #'seen_poke': self.reward_scale * seen_poke_count * 400,
@@ -547,7 +536,7 @@ def save_screenshot(self, name):
 
     def update_max_op_level(self):
         #opponent_level = self.read_m(0xCFE8) - 5 # base level
-        opponent_level = max([self.read_m(a) for a in OPPONENT_LEVELS_ADDRESSES]) - 5
+        opponent_level = self.reader.get_opponent_level()
         #if opponent_level >= 7:
         #    self.save_screenshot('highlevelop')
         self.max_opponent_level = max(self.max_opponent_level, opponent_level)
@@ -558,68 +547,3 @@ def update_max_event_rew(self):
         self.max_event_rew = max(cur_rew, self.max_event_rew)
         return self.max_event_rew
 
-    def read_hp_fraction(self):
-        hp_sum = sum([self.read_hp(add) for add in HP_ADDRESSES])
-        max_hp_sum = sum([self.read_hp(add) for add in MAX_HP_ADDRESSES])
-        max_hp_sum = max(max_hp_sum, 1)
-        return hp_sum / max_hp_sum
-
-    def read_hp(self, start):
-        return 256 * self.read_m(start) + self.read_m(start+1)
-
-    # built-in since python 3.10
-    def bit_count(self, bits):
-        return bin(bits).count('1')
-
-    def read_triple(self, start_add):
-        return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2)
-
-    def read_bcd(self, num):
-        return 10 * ((num >> 4) & 0x0f) + (num & 0x0f)
-
-    def read_money(self):
-        return (100 * 100 * self.read_bcd(self.read_m(MONEY_ADDRESS_1)) +
-                100 * self.read_bcd(self.read_m(MONEY_ADDRESS_2)) +
-                self.read_bcd(self.read_m(MONEY_ADDRESS_3)))
-
-    def get_map_location(self, map_idx):
-        map_locations = {
-            0: "Pallet Town",
-            1: "Viridian City",
-            2: "Pewter City",
-            3: "Cerulean City",
-            12: "Route 1",
-            13: "Route 2",
-            14: "Route 3",
-            15: "Route 4",
-            33: "Route 22",
-            37: "Red house first",
-            38: "Red house second",
-            39: "Blues house",
-            40: "oaks lab",
-            41: "Pokémon Center (Viridian City)",
-            42: "Poké Mart (Viridian City)",
-            43: "School (Viridian City)",
-            44: "House 1 (Viridian City)",
-            47: "Gate (Viridian City/Pewter City) (Route 2)",
-            49: "Gate (Route 2)",
-            50: "Gate (Route 2/Viridian Forest) (Route 2)",
-            51: "viridian forest",
-            52: "Pewter Museum (floor 1)",
-            53: "Pewter Museum (floor 2)",
-            54: "Pokémon Gym (Pewter City)",
-            55: "House with disobedient Nidoran♂ (Pewter City)",
-            56: "Poké Mart (Pewter City)",
-            57: "House with two Trainers (Pewter City)",
-            58: "Pokémon Center (Pewter City)",
-            59: "Mt. Moon (Route 3 entrance)",
-            60: "Mt. Moon",
-            61: "Mt. Moon",
-            68: "Pokémon Center (Route 4)",
-            193: "Badges check gate (Route 22)"
-        }
-        if map_idx in map_locations.keys():
-            return map_locations[map_idx]
-        else:
-            return "Unknown Location"
-