
Commit ef514e6
Added simple plotting and pickling
1 parent e3840cb

8 files changed: +123 -29 lines changed

actions.py (+1 -1)

@@ -61,7 +61,7 @@
 
 
 def get_valid_action_numbers_from_state(s, state_repr='simple'):
-    if state_repr is 'simple' or state_repr is 'adjacent':
+    if state_repr is 'simple' or state_repr is 'adjacent' or state_repr is 'adjacent_conservative':
         actions = []
         top_left, top_right, bot_left, bot_right = s
         if top_left is not None:

agent.py (+53 -13)

@@ -1,12 +1,9 @@
 import logging
 from abc import ABCMeta, abstractmethod
 
-from actions import action_number_to_name
 from learner import QLearner
 from world import QbertWorld
 
-import matplotlib.pyplot as plt
-
 
 class Agent:
     __metaclass__ = ABCMeta
@@ -18,8 +15,9 @@ def action(self):
 
 class QbertAgent(Agent):
     def __init__(self, agent_type='subsumption', random_seed=123, frame_skip=4, repeat_action_probability=0,
-                 sound=True, display_screen=False, state_repr='adjacent', alpha=0.1, gamma=0.95, epsilon=0.2,
-                 unexplored_threshold=1, unexplored_reward=100, exploration='combined', distance_metric=None):
+                 sound=True, display_screen=False, state_repr='adjacent_conservative', alpha=0.1, gamma=0.95,
+                 epsilon=0.2, unexplored_threshold=1, unexplored_reward=100, exploration='combined',
+                 distance_metric=None):
         if agent_type is 'block':
             self.agent = QbertBlockAgent(random_seed, frame_skip, repeat_action_probability, sound, display_screen,
                                          state_repr, alpha, gamma, epsilon, unexplored_threshold, unexplored_reward,
@@ -41,6 +39,15 @@ def __init__(self, agent_type='subsumption', random_seed=123, frame_skip=4, repe
     def action(self):
         return self.agent.action()
 
+    def q_size(self):
+        return self.agent.q_size()
+
+    def save(self, filename):
+        self.agent.save(filename)
+
+    def load(self, filename):
+        self.agent.load(filename)
+
 
 class QbertBlockAgent(Agent):
     def __init__(self, random_seed, frame_skip, repeat_action_probability, sound, display_screen, state_repr, alpha,
@@ -57,6 +64,15 @@ def action(self):
         self.block_learner.update(s, a, s_next, block_score)
         return block_score + friendly_score + enemy_score
 
+    def q_size(self):
+        return len(self.block_learner.Q)
+
+    def save(self, filename):
+        self.block_learner.save(filename)
+
+    def load(self, filename):
+        self.block_learner.load(filename)
+
 
 class QbertEnemyAgent(Agent):
     def __init__(self, random_seed, frame_skip, repeat_action_probability, sound, display_screen, state_repr, alpha,
@@ -73,6 +89,15 @@ def action(self):
         self.enemy_learner.update(s, a, s_next, enemy_score + enemy_penalty)
         return block_score + friendly_score + enemy_score
 
+    def q_size(self):
+        return len(self.enemy_learner.Q)
+
+    def save(self, filename):
+        self.enemy_learner.save(filename)
+
+    def load(self, filename):
+        self.enemy_learner.load(filename)
+
 
 class QbertFriendlyAgent(Agent):
     def __init__(self, random_seed, frame_skip, repeat_action_probability, sound, display_screen, state_repr, alpha,
@@ -89,6 +114,15 @@ def action(self):
         self.friendly_learner.update(s, a, s_next, friendly_score)
         return block_score + friendly_score + enemy_score
 
+    def q_size(self):
+        return len(self.friendly_learner.Q)
+
+    def save(self, filename):
+        self.friendly_learner.save(filename)
+
+    def load(self, filename):
+        self.friendly_learner.load(filename)
+
 
 class QbertSubsumptionAgent(Agent):
     def __init__(self, random_seed, frame_skip, repeat_action_probability, sound, display_screen, state_repr, alpha,
@@ -130,14 +164,6 @@ def action(self):
             else:
                 logging.debug('Chose block action!')
                 chosen_action = a
-        if chosen_action is None:
-            logging.info('None action!')
-            logging.info('Current row/col : {}/{}'.format(self.world.current_row, self.world.current_col))
-            logging.info('Prev block state: {}'.format(s))
-            logging.info('Prev enemy state: {}'.format(s_enemies))
-            logging.info('Prev friendly state: {}'.format(s_friendlies))
-            plt.imshow(self.world.rgb_screen)
-            plt.show()
         block_score, friendly_score, enemy_score, enemy_penalty = self.world.perform_action(chosen_action)
         if enemy_present:
             s_next_enemies = self.world.to_state_enemies()
@@ -154,4 +180,18 @@ def action(self):
         self.block_learner.update(s, a, s_next, block_score)
         return block_score + friendly_score + enemy_score
 
+    def q_size(self):
+        return len(self.block_learner.Q) + \
+            len(self.friendly_learner.Q) + \
+            len(self.enemy_learner.Q)
+
+    def save(self, filename):
+        self.block_learner.save('{}_{}'.format(filename, 'block'))
+        self.friendly_learner.save('{}_{}'.format(filename, 'friendly'))
+        self.enemy_learner.save('{}_{}'.format(filename, 'enemy'))
+
+    def load(self, filename):
+        self.block_learner.load('{}_{}'.format(filename, 'block'))
+        self.friendly_learner.load('{}_{}'.format(filename, 'friendly'))
+        self.enemy_learner.load('{}_{}'.format(filename, 'enemy'))
 # Human high scores: 15825, 27000
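Note: the new q_size/save/load hooks simply delegate to the underlying QLearner(s), so a driver script only needs to talk to QbertAgent. A minimal usage sketch (the filename 'my_run' is a placeholder, and it assumes the ALE/Q*bert setup from world.py plus pickles written by an earlier save):

    agent = QbertAgent(agent_type='subsumption')
    agent.load('my_run')      # expects pickle/my_run_block_Q.pkl etc. from an earlier save
    # ... run episodes with agent.action(), as in main.play_learning_agent ...
    print('Q-table entries:', agent.q_size())
    agent.save('my_run')      # writes one Q/N pickle pair per learner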

learner.py (+16 -3)

@@ -3,6 +3,7 @@
 from abc import ABCMeta, abstractmethod
 
 from actions import action_number_to_name, get_valid_action_numbers_from_state
+from pickler import save_to_pickle, load_from_pickle
 
 
 class Learner:
@@ -19,16 +20,16 @@ def update(self, s, a, s_next, reward):
 
 class QLearner(Learner):
     def __init__(self, world, alpha, gamma, epsilon, unexplored_threshold, unexplored_reward, exploration,
-                 distance_metric, state_repr):
+                 distance_metric, state_repr, initial_q=None, initial_n=None):
         self.alpha = alpha
         self.gamma = gamma
         self.epsilon = epsilon
         self.unexplored_threshold = unexplored_threshold
         self.unexplored_reward = unexplored_reward
         self.exploration = exploration
         self.distance_metric = distance_metric
-        self.Q = {}
-        self.N = {}
+        self.Q = initial_q if initial_q is not None else {}
+        self.N = initial_n if initial_n is not None else {}
         self.world = world
         self.state_repr = state_repr
 
@@ -46,6 +47,8 @@ def update(self, s, a, s_next, reward):
         self.q_update(s, a, s_next, reward)
 
     def q_update(self, s, a, s_next, reward):
+        if self.exploration is 'combined':
+            self.N[s, a] = self.N.get((s, a), 0) + 1
         old_q = self.get_q(s, a)
         new_q = old_q + self.alpha * (reward + self.gamma * self.get_max_q(s_next) - old_q)
         self.Q[s, a] = new_q
@@ -112,3 +115,13 @@ def update_close(self, a, new_q):
         for s_close, a_close in zip(states_close, actions_close):
             self.Q[s_close, a_close] = new_q
 
+    def save(self, filename):
+        save_to_pickle(self.Q, '{}_{}'.format(filename, 'Q'))
+        save_to_pickle(self.N, '{}_{}'.format(filename, 'N'))
+
+    def load(self, filename):
+        self.Q = load_from_pickle('{}_{}'.format(filename, 'Q'))
+        self.N = load_from_pickle('{}_{}'.format(filename, 'N'))
+        logging.debug('Loaded Q: {}'.format(self.Q))
+        logging.debug('Loaded N: {}'.format(self.N))
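Note: save/load round-trips the two dicts that hold the learner's progress: Q maps (state, action) tuples to values and N holds visit counts for the 'combined' exploration bonus. A rough sketch of the files produced and of resuming via the new initial_q/initial_n arguments ('run1' is a placeholder name):

    learner.save('run1')             # -> pickle/run1_Q.pkl and pickle/run1_N.pkl
    q0 = load_from_pickle('run1_Q')
    n0 = load_from_pickle('run1_N')
    warm = QLearner(world, alpha, gamma, epsilon, unexplored_threshold, unexplored_reward,
                    exploration, distance_metric, state_repr, initial_q=q0, initial_n=n0)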

main.py (+12 -4)

@@ -5,6 +5,7 @@
 from argparse import ArgumentParser
 
 from agent import QbertAgent
+from plotter import plot_scores
 
 LOGGING_LEVELS = {
     'info': logging.INFO,
@@ -15,11 +16,14 @@
 }
 
 
-def play_learning_agent(num_episodes=1000, show_image=False):
-
+def play_learning_agent(num_episodes=1000, show_image=False, load_learning_filename='test_pickle',
+                        save_learning_filename='test_pickle', plot_filename='adjacent_conservative_sub_combined'):
     agent = QbertAgent()
     world = agent.world
     max_score = 0
+    scores = []
+    if load_learning_filename is not None:
+        agent.load(load_learning_filename)
     for episode in range(num_episodes):
         total_reward = 0
         world.reset()
@@ -28,12 +32,16 @@ def play_learning_agent(num_episodes=1000, show_image=False):
         if show_image:
             plt.imshow(world.rgb_screen)
             plt.show()
+        scores.append(total_reward)
         logging.info('Episode {} ended with score: {}'.format(episode + 1, total_reward))
         max_score = max(max_score, total_reward)
         world.ale.reset_game()
+    if plot_filename is not None:
+        plot_scores(scores, plot_filename)
+    if save_learning_filename is not None:
+        agent.save(save_learning_filename)
     logging.info('Maximum reward: {}'.format(max_score))
-    # TODO: plot results here
-
+    logging.info('Total Q size: {}'.format(agent.q_size()))
     # TODO: Exploration very key... getting very high scores early on because of unexplored weighting...
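Note: with the new keyword arguments a run can resume from pickled Q-tables, re-save them afterwards, and write the per-episode score curve to a PDF; passing None for a filename skips that step. The defaults assume pickle/test_pickle_* already exists from an earlier run. A usage sketch (argument values are illustrative only):

    # Fresh run: learn from scratch, save tables and plot at the end.
    play_learning_agent(num_episodes=100, load_learning_filename=None,
                        save_learning_filename='test_pickle', plot_filename='first_run')

    # Resume: load pickle/test_pickle_{block,friendly,enemy}_{Q,N}.pkl and keep training.
    play_learning_agent(num_episodes=100, load_learning_filename='test_pickle',
                        save_learning_filename='test_pickle', plot_filename='second_run')
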
misc_test.py (+11 -1)

@@ -1,5 +1,7 @@
 import numpy as np
 
+from pickler import save_to_pickle, load_from_pickle
+
 INITIAL_PARAMETERS1 = [
     [0],
     [0, 0],
@@ -50,5 +52,13 @@ def test_return_none(param):
     return None
 
 
+def test_pickle():
+    q = {(1, 2): 5, (5, 6): 10}
+    print(q)
+    save_to_pickle(q, 'test')
+    q2 = load_from_pickle('test')
+    print(q2)
+
+
 if __name__ == '__main__':
-    print(test_return_none(5))
+    test_pickle()

pickler.py (+12)

@@ -0,0 +1,12 @@
+import pickle
+
+
+def save_to_pickle(data, filename):
+    with open('pickle/{}.pkl'.format(filename), 'wb') as f:
+        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def load_from_pickle(filename):
+    with open('pickle/{}.pkl'.format(filename), 'rb') as f:
+        data = pickle.load(f)
+    return data
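Note: both helpers take a bare name, not a path, and read/write under a pickle/ directory that must already exist next to the scripts. test_pickle() in misc_test.py above exercises exactly this round trip:

    save_to_pickle({(1, 2): 5, (5, 6): 10}, 'test')   # writes pickle/test.pkl
    restored = load_from_pickle('test')               # reads pickle/test.pkl back
    assert restored == {(1, 2): 5, (5, 6): 10}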

plotter.py (+10)

@@ -1 +1,11 @@
 import matplotlib.pyplot as plt
+from matplotlib.ticker import MaxNLocator
+
+
+def plot_scores(scores, filename):
+    f = plt.figure()
+    ax = f.gca()
+    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
+    plt.plot(scores, label='Score')
+    plt.legend()
+    f.savefig('report/plots/{}.pdf'.format(filename), bbox_inches='tight')
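Note: plot_scores writes the episode-score curve straight to report/plots/<filename>.pdf (that directory must exist) with integer x-axis ticks; nothing is shown on screen. A minimal sketch with made-up scores:

    plot_scores([125, 300, 275, 550], 'adjacent_conservative_sub_combined')
    # -> report/plots/adjacent_conservative_sub_combined.pdf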

world.py (+8 -7)

@@ -99,7 +99,8 @@ def __init__(self, random_seed, frame_skip, repeat_action_probability, sound, di
         ale = ALEInterface()
 
         # Get & Set the desired settings
-        ale.setInt('random_seed', random_seed)
+        if random_seed is not None:
+            ale.setInt('random_seed', random_seed)
         ale.setInt('frame_skip', frame_skip)
         ale.setFloat('repeat_action_probability', repeat_action_probability)
 
@@ -148,7 +149,7 @@ def __init__(self, random_seed, frame_skip, repeat_action_probability, sound, di
     def to_state_blocks(self):
         if self.state_repr is 'simple':
             return self.to_state_blocks_simple()
-        elif self.state_repr is 'adjacent':
+        elif self.state_repr is 'adjacent' or self.state_repr is 'adjacent_conservative':
             return self.to_state_blocks_adjacent()
         elif self.state_repr is 'verbose':
             return self.to_state_blocks_verbose()
@@ -158,13 +159,15 @@ def to_state_enemies(self):
             return self.to_state_enemies_simple()
         elif self.state_repr is 'adjacent':
             return self.to_state_enemies_adjacent()
+        elif self.state_repr is 'adjacent_conservative':
+            return self.to_state_enemies_adjacent_conservative()
         elif self.state_repr is 'verbose':
             return self.to_state_enemies_verbose()
 
     def to_state_friendlies(self):
         if self.state_repr is 'simple':
             return self.to_state_friendlies_simple()
-        elif self.state_repr is 'adjacent':
+        elif self.state_repr is 'adjacent' or self.state_repr is 'adjacent_conservative':
             return self.to_state_friendlies_simple()  # TODO: Make adjacent version of friendlies
         elif self.state_repr is 'verbose':
             return self.to_state_friendlies_verbose()
@@ -346,7 +349,7 @@ def to_state_enemies_adjacent(self):
             bot_right = 0
         return top_left, top_right, bot_left, bot_right
 
-    def to_state_enemies_adjacent_old(self):
+    def to_state_enemies_adjacent_conservative(self):
         """
         Adjacent state representation for enemies around Qbert.
 
@@ -493,7 +496,7 @@ def update_rgb(self):
         if self.screen_not_flashing() \
                 and not np.array_equal(score_color, COLOR_BLACK) \
                 and not np.array_equal(score_color, self.desired_color):
-            logging.info('Identified {} as new desired color'.format(score_color))
+            logging.debug('Identified {} as new desired color'.format(score_color))
             self.desired_color = score_color
 
         self.enemy_present = False
@@ -553,8 +556,6 @@ def reset_position(self):
             reward += self.ale.act(NO_OP)
             self.ale.getRAM(self.ram)
             self.update_rgb()
-        if reward > 0:
-            logging.info('Nonzero reward of {} when resetting position.'.format(reward))
         return reward
 
     def reset(self):
