# learner.py
import logging
import random
from abc import ABCMeta, abstractmethod
from actions import action_number_to_name, get_valid_action_numbers_from_state
from pickler import save_to_pickle, load_from_pickle
class Learner:
__metaclass__ = ABCMeta
@abstractmethod
def get_best_actions(self, s):
"""
Get the best actions from the given state.
"""
raise NotImplementedError
@abstractmethod
def update(self, s, a, s_next, reward):
"""
Update the learner parameters.
"""
raise NotImplementedError
class QLearner(Learner):
def __init__(self, world, alpha, gamma, epsilon, unexplored_threshold, unexplored_reward, exploration,
distance_metric, state_repr, initial_q=None, initial_n=None, tag=None,
exploration_function_type='simple'):
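        """
        Tabular Q-learning agent.

        alpha is the learning rate, gamma the discount factor and epsilon the
        probability of taking a random action (used by the 'random' and
        'combined' exploration strategies). unexplored_threshold and
        unexplored_reward parameterize the optimistic exploration function:
        a (state, action) pair tried fewer than unexplored_threshold times is
        valued at unexplored_reward. exploration selects the strategy
        ('optimistic', 'random', 'combined'; anything else is purely greedy).
        distance_metric is forwarded to world.get_close_states_actions when
        propagating Q-values to nearby state-action pairs, state_repr is the
        state representation used to enumerate valid actions, and initial_q /
        initial_n optionally seed the Q-value and visit-count tables.
        """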
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
self.unexplored_threshold = unexplored_threshold
self.unexplored_reward = unexplored_reward
self.exploration = exploration
self.distance_metric = distance_metric
self.Q = initial_q if initial_q is not None else {}
self.N = initial_n if initial_n is not None else {}
self.world = world
self.state_repr = state_repr
self.tag = tag
self.exploration_function_type = exploration_function_type
def get_best_actions(self, s):
        if self.exploration == 'optimistic':
            return self.get_best_actions_optimistic(s)
        elif self.exploration == 'random':
            return self.get_best_actions_random(s)
        elif self.exploration == 'combined':
            return self.get_best_actions_combined(s)
        else:
            return self.get_best_actions_no_exploration(s)
def update(self, s, a, s_next, reward):
self.q_update(s, a, s_next, reward)
def get_best_actions_random(self, s):
"""
Get the best actions from the given state using epsilon-greedy.
"""
if random.random() < self.epsilon:
actions = get_valid_action_numbers_from_state(s, self.state_repr)
action = random.choice(actions)
logging.debug('Randomly chose {}'.format(action_number_to_name(action)))
return [action]
else:
return self.get_best_actions_no_exploration(s)
def get_best_actions_optimistic(self, s):
"""
Get the best actions from the given state using an optimistic prior.
"""
actions = get_valid_action_numbers_from_state(s, self.state_repr)
logging.debug('Valid actions: {}'.format([action_number_to_name(a) for a in actions]))
max_q = float('-inf')
max_actions = []
for a in actions:
q = self.exploration_function(s, a)
if q > max_q:
max_q = q
max_actions = [a]
elif q == max_q:
max_actions.append(a)
return max_actions
def get_best_actions_combined(self, s):
"""
Get the best actions from the given state using epsilon-greedy and an optimistic prior.
"""
if random.random() < self.epsilon:
actions = get_valid_action_numbers_from_state(s, self.state_repr)
action = random.choice(actions)
logging.debug('Randomly chose {}'.format(action_number_to_name(action)))
return [action]
else:
return self.get_best_actions_optimistic(s)
def q_update(self, s, a, s_next, reward):
"""
Q-learning update.
"""
        if self.exploration == 'combined':
self.N[s, a] = self.N.get((s, a), 0) + 1
old_q = self.get_q(s, a)
new_q = old_q + self.alpha * (reward + self.gamma * self.get_max_q(s_next) - old_q)
if new_q == float('inf'):
logging.info('Infinite Q saved!')
if new_q == float('-inf'):
logging.info('-Infinite Q saved!')
self.Q[s, a] = new_q
self.update_close(a, new_q)
def save(self, filename):
"""
        Save the current learning parameters (Q and N) to pickle files.
"""
save_to_pickle(self.Q, '{}_{}'.format(filename, 'Q'))
save_to_pickle(self.N, '{}_{}'.format(filename, 'N'))
def load(self, filename):
"""
        Load learning parameters (Q and N) from pickle files.
"""
self.Q = load_from_pickle('{}_{}'.format(filename, 'Q'))
self.N = load_from_pickle('{}_{}'.format(filename, 'N'))
logging.debug('Loaded Q: {}'.format(self.Q))
logging.debug('Loaded N: {}'.format(self.N))
def get_best_single_action(self, s):
        if self.exploration == 'optimistic':
            actions = self.get_best_actions_optimistic(s)
        elif self.exploration == 'random':
            actions = self.get_best_actions_random(s)
        elif self.exploration == 'combined':
            actions = self.get_best_actions_combined(s)
        else:
            actions = self.get_best_actions_no_exploration(s)
return random.choice(actions)
def get_q(self, s, a):
return self.Q.get((s, a), 0)
def get_best_actions_no_exploration(self, s):
actions = get_valid_action_numbers_from_state(s, self.state_repr)
max_q = float('-inf')
max_actions = []
for a in actions:
q = self.get_q(s, a)
if q > max_q:
max_q = q
max_actions = [a]
elif q == max_q:
max_actions.append(a)
return max_actions
def exploration_function(self, s, a):
        if self.exploration_function_type == 'simple':
return self.exploration_function_simple(s, a)
else:
return None
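    # Optimistic exploration function: value a (state, action) pair at
    # unexplored_reward until it has been tried unexplored_threshold times,
    # then fall back to its learned Q-value.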
def exploration_function_simple(self, s, a):
return self.unexplored_reward if self.N.get((s, a), 0) < self.unexplored_threshold else self.get_q(s, a)
def get_best_action(self, s, actions):
best_action = None
max_q = float('-inf')
for a in actions:
q = self.get_q(s, a)
if q > max_q:
max_q = q
best_action = a
return best_action
def get_max_q(self, s):
max_q = float('-inf')
        for a in get_valid_action_numbers_from_state(s, self.state_repr):
max_q = max(max_q, self.Q.get((s, a), 0))
return max_q if max_q != float('-inf') else 0
def update_close(self, a, new_q):
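        """
        Propagate new_q to the state-action pairs the world reports as close
        to action a under the configured distance metric.
        """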
states_close, actions_close = self.world.get_close_states_actions(a, distance_metric=self.distance_metric)
for s_close, a_close in zip(states_close, actions_close):
self.Q[s_close, a_close] = new_q
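

# Illustrative usage sketch. `env`, `num_steps`, `some_metric` and `state_repr`
# are placeholders: this assumes an environment with reset() and
# step(a) -> (next_state, reward), and a `world` exposing
# get_close_states_actions() as used by update_close above.
#
#     learner = QLearner(world, alpha=0.1, gamma=0.9, epsilon=0.1,
#                        unexplored_threshold=5, unexplored_reward=100,
#                        exploration='combined', distance_metric=some_metric,
#                        state_repr=state_repr)
#     s = env.reset()
#     for _ in range(num_steps):
#         a = learner.get_best_single_action(s)
#         s_next, reward = env.step(a)
#         learner.update(s, a, s_next, reward)
#         s = s_next
#     learner.save('qlearner_checkpoint')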