-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrl.py
39 lines (29 loc) · 1.08 KB
/
rl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
def get_space_info(space):
    """Describe a gym space by its size and whether it is discrete.

    A space is treated as discrete when it exposes an ``n`` attribute; its
    size is then ``space.n``.  Otherwise the size is the length of one
    sample drawn from the space via ``space.sample()``.

    Args:
        space: Space object providing either an ``n`` attribute or a
            ``sample()`` method whose result supports ``len()``.

    Returns:
        dict: ``{'n': <size>, 'discrete': <bool>}``.
    """
    is_discrete = hasattr(space, 'n')
    # Only draw a sample when the space has no explicit ``n``.
    size = space.n if is_discrete else len(space.sample())
    return {'n': size, 'discrete': is_discrete}
def get_policy(obs_space_info, action_space_info):
    """Pick the policy function that matches the observation/action spaces.

    Args:
        obs_space_info: dict with a boolean ``'discrete'`` entry (and an
            ``'n'`` entry) describing the observation space, in the format
            returned by ``get_space_info``.
        action_space_info: dict in the same format describing the action
            space.

    Returns:
        A callable ``policy(w, observation)``:

        * continuous observations, continuous actions -> ``linear_policy``
        * continuous observations, discrete actions   -> ``sigmoid_policy``
        * discrete observations (any action space)    -> ``one_hot_sigmoid_policy``

    Note:
        All policies assume ``w`` is laid out with one row per observation
        dimension (or discrete state) and one column per action dimension
        (or discrete action) -- this is the layout ``sigmoid_policy``'s
        ``np.dot(observation, w)`` requires.  Confirm against the caller.
    """

    def linear_policy(w, observation):
        """Continuous observations and actions: return ``observation . w``."""
        # A linear map is the dot product; np.outer would flatten ``w`` and
        # return an (obs_dim, w.size) matrix instead of an action vector.
        return np.dot(observation, w)

    def sigmoid_policy(w, observation):
        """Continuous observations, discrete actions: greedy action index.

        Despite the name no sigmoid is applied; the index of the largest
        linear score is returned.
        """
        scores = np.dot(observation, w)
        return np.argmax(scores)

    def one_hot_sigmoid_policy(w, observation):
        """Discrete observations: greedy action index for state ``observation``."""
        # One-hot(observation) . w is exactly row ``observation`` of ``w``,
        # so index the row directly.  (np.outer over the one-hot vector
        # produced a flat index into an (n_obs, w.size) matrix, which is
        # not a valid action.)
        scores = np.asarray(w)[observation]
        return np.argmax(scores)

    if obs_space_info['discrete']:
        return one_hot_sigmoid_policy
    if action_space_info['discrete']:
        return sigmoid_policy
    return linear_policy