-
Notifications
You must be signed in to change notification settings - Fork 1
/
pilco_runner.py
51 lines (37 loc) · 1.49 KB
/
pilco_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import logging
import sys
import gym
from experiments.util.logger_util import enable_logging
from experiments.util.logger_util import show_cmd_args
from pilco.cost_function.saturated_loss import SaturatedLoss
from pilco.pilco import PILCO
import time
from pilco.util.util import parse_args, evaluate_policy, load_model, get_env
def main():
    """Command-line entry point for PILCO experiments.

    Parses ``sys.argv[1:]`` with ``parse_args``, configures logging and then
    runs one of two modes:

    * ``args.test``: load ``<weight_dir>policy.p`` with ``load_model`` and
      evaluate it on the environment via ``evaluate_policy``.
    * otherwise: build a ``SaturatedLoss`` sized to the environment's
      observation dimension, create ``PILCO`` (optionally restoring weights
      from ``args.weight_dir``) and call ``pilco.run()``.

    The environment created here is always closed, even if loading or
    evaluating the policy raises.
    """
    args = parse_args(sys.argv[1:])
    enable_logging(logging_lvl=logging.DEBUG, save_log=not args.no_log,
                   logfile_prefix="PILCO_" + args.env_name + "_")
    # time.gmtime() with no argument already uses the current time.
    start_time = time.strftime("%m/%d/%Y, %Hh:%Mm:%Ss", time.gmtime())
    logging.info(f'Start experiment for {args.env_name} at {start_time}')

    # show given cmd-parameters
    show_cmd_args(args)

    env = get_env(args.env_name, args.monitor)

    # Make sure that the dir ends with an "/": the policy file below is located
    # by plain string concatenation, and the same string is passed to
    # PILCO.load.
    if args.weight_dir and not args.weight_dir.endswith('/'):
        args.weight_dir += '/'

    try:
        if args.test:
            # NOTE(review): without a weight dir this path becomes
            # "Nonepolicy.p" (weight_dir=None) or "policy.p" (weight_dir="");
            # confirm which default parse_args uses.
            policy = load_model(f"{args.weight_dir}policy.p")
            evaluate_policy(policy, env, max_action=args.max_action,
                            no_render=args.no_render, n_runs=args.test_runs)
            return
        # Training only needs the observation dimension from this env
        # instance; PILCO itself is constructed from `args`.
        state_dim = env.observation_space.shape[0]
    finally:
        # Runs on both the test and training paths, including on error.
        env.close()

    loss = SaturatedLoss(state_dim=state_dim, target_state=args.target_state, weights=args.weights)
    pilco = PILCO(args, loss=loss)

    # load the models if "args.weight_dir" is given
    if args.weight_dir:
        pilco.load(args.weight_dir)

    pilco.run()
# Only start an experiment when executed as a script, not when imported.
if __name__ == "__main__":
    main()