This repository was archived by the owner on Apr 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnac.py
More file actions
113 lines (90 loc) · 4.14 KB
/
nac.py
File metadata and controls
113 lines (90 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from typing import Callable
import gym
import numpy as np
import torch
from matplotlib import pyplot as plt
from utils.vector_utils import angle_between
def train(env: gym.Env, model: torch.nn.Module, phi: Callable[[np.ndarray], np.ndarray], render: bool,
          gamma: float, lambda_: float, alpha: float, alpha_decay: float, h: int,
          beta: float, eps: float, max_episodes: int) -> None:
    """Perform a gradient-ascent policy optimization using the natural gradient and LSTD-Q.

    Episodic actor-critic loop: at every environment step the critic statistics
    (A, b, z) are updated LSTD-style over the combined feature vector
    [phi(x), grad_theta log pi(u|x)]. Once the natural-gradient estimate w has
    stayed within angle `eps` of the last `h` estimates, the actor parameters
    theta are moved along w and the critic statistics are discounted by `beta`.

    Arguments:
        env {gym.Env} -- Gym environment to train on.
        model {torch.nn.Module} -- Actor/policy model.
        phi {Callable[[np.ndarray], np.ndarray]} -- Maps observations of the environment to basis functions for the critic.
        render {bool} -- If the environment should render during training. Will slow down training a lot.
        gamma {float} -- Gamma value in [0,1).
        lambda_ {float} -- Lambda value in [0,1].
        alpha {float} -- Learning rate in (0,1].
        alpha_decay {float} -- How much learning rate is decreasing over time. Should be in [0,1], 0 to disable completely.
        h {int} -- How many gradient estimations must be within the angle eps to consider the gradient as converged.
        beta {float} -- Forgetting factor, used for resetting the critic after a gradient step. Set to 0 to discard old statistics.
        eps {float} -- Angle in gradients in which the last h gradient estimates must be to consider the gradient as converged.
        max_episodes {int} -- Number of episodes after which the training stops and returns.

    Returns:
        None -- nothing
    """
    # Sizes of the actor parameter vector and of the critic basis features.
    dim_theta = model.theta().shape[0]
    dim_phi = phi(np.zeros(model.state_dim)).shape[0]
    # LSTD-Q statistics over the combined features [phi(x), grad log pi].
    # NOTE: b and z initially alias the *same* zero array; harmless here
    # because both are only ever rebound below, never mutated in place.
    b = z = np.zeros(dim_phi + dim_theta)
    A = np.zeros((dim_phi + dim_theta, dim_phi + dim_theta))
    x = env.reset()
    phi_x = phi(x)
    episodes = 0
    epochs = 0          # counts completed natural-gradient (actor) updates
    done = False
    total_return = 0
    w_history = list()  # recent natural-gradient estimates, reset after each actor step
    while episodes < max_episodes:
        # `done` from the *previous* step: log the finished episode and reset.
        if done:
            model.returns.append(total_return)
            print(str(episodes) + ", R: " + "{:.2E}".format(total_return))
            total_return = 0
            episodes += 1
            x = env.reset()
            phi_x = phi(x)
        if render:
            env.render()
        # Sample an action from the current policy and take one env step.
        policy = model(torch.FloatTensor(x))
        u = policy.sample()
        log_prob = policy.log_prob(u)
        x1, r, done, _ = env.step(u.detach().numpy())
        phi_x1 = phi(x1)
        total_return += r
        # Score function: grad_theta log pi(u|x), flattened across all parameters.
        autograd = torch.autograd.grad(log_prob, model.parameters())
        grad_theta = np.concatenate([grad.numpy().flatten() for grad in autograd])
        # phi_tilde: successor-state features (gradient part zeroed);
        # phi_hat: current-state features augmented with the score function.
        phi_tilde = np.concatenate([phi_x1, np.zeros_like(grad_theta)])
        phi_hat = np.concatenate([phi_x, grad_theta])
        # Eligibility trace and LSTD-Q accumulators (z must be updated first).
        z = lambda_ * z + phi_hat
        A = A + np.multiply.outer(z, (phi_hat - gamma * phi_tilde))
        b = b + z * r
        try:
            # Solve A * update = b; fall back to least squares while A is
            # still rank-deficient (early in training).
            if not np.linalg.matrix_rank(A) == len(b):
                update = np.linalg.lstsq(A, b, rcond=None)[0]
            else:
                update = np.linalg.solve(A, b)
        except np.linalg.LinAlgError:
            break
        # NOTE(review): phi_hat is ordered [phi_x, grad_theta], so the
        # natural-gradient component of the solution should be the *last*
        # dim_theta entries (update[dim_phi:]); this slice takes the first
        # dim_theta and is only equivalent when dim_phi == dim_theta — confirm.
        w = update[:dim_theta]
        # Gradient is "converged" when each of the last h estimates is either
        # within angle eps of w or numerically identical to it.
        if len(w_history) < h:
            natural_gradient_converged = False
        else:
            natural_gradient_converged = True
            for tao in range(1, h + 1):
                angle_converged = angle_between(w, w_history[-tao]) < eps
                approx_same = np.linalg.norm(w - w_history[-tao]) < np.finfo(float).eps
                natural_gradient_converged = (angle_converged or approx_same) and natural_gradient_converged
        w_history.append(w)
        if natural_gradient_converged:
            # Actor step along the natural gradient with a decaying learning rate.
            old_theta = model.theta()
            learning_rate = alpha / (alpha_decay * epochs + 1)
            new_theta = model.theta() + learning_rate * w
            model.set_theta(new_theta)
            model.theta_history.append(new_theta)
            theta_delta = np.linalg.norm(new_theta - old_theta)
            print("-> " + "{:.2E}".format(theta_delta))
            # Partially forget the critic statistics (beta=0 discards them)
            # since they were gathered under the old policy.
            z, A, b = beta * z, beta * A, beta * b
            w_history = list()
            epochs += 1
        x = x1
        phi_x = phi_x1