Skip to content

Commit 349e603

Browse files
sythamdementrock
authored and committed Apr 14, 2017
Run hyperparameter search using hyperopt on EC2 (ryanjulian#110)
* Initial commit * Some additional explanatory comments
1 parent 2d08055 commit 349e603

File tree

8 files changed

+2191
-0
lines changed

8 files changed

+2191
-0
lines changed
 

‎contrib/rllab_hyperopt/__init__.py

Whitespace-only changes.

‎contrib/rllab_hyperopt/core.py

+232
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import os
2+
import sys
3+
sys.path.append('.')
4+
import threading
5+
import time
6+
import warnings
7+
import multiprocessing
8+
import importlib
9+
10+
from rllab import config
11+
from rllab.misc.instrument import run_experiment_lite
12+
13+
import polling
14+
from hyperopt import fmin, tpe, STATUS_OK, STATUS_FAIL
15+
from hyperopt.mongoexp import MongoTrials
16+
17+
class S3SyncThread(threading.Thread):
    '''
    Thread to periodically sync results from S3 in the background.

    Uses same dirs as ./scripts/sync_s3.py.

    :param sync_interval: number of seconds to wait between two consecutive sync runs (default 60)
    '''
    def __init__(self, sync_interval=60):
        super(S3SyncThread, self).__init__()
        self.sync_interval = sync_interval
        # Set by stop(); checked after every sync and also used to cut the wait between syncs short.
        self._stop_event = threading.Event()

    def stop(self):
        '''Request the sync loop to end. Does not block; use join() to wait for the thread.'''
        self._stop_event.set()

    def stopped(self):
        '''Return True once stop() has been called.'''
        return self._stop_event.is_set()

    def run(self):
        '''
        Run `aws s3 sync` from config.AWS_S3_PATH into <config.LOG_DIR>/s3 until stop() is called.

        A failing sync only emits a warning; the loop keeps going.
        '''
        remote_dir = config.AWS_S3_PATH
        local_dir = os.path.join(config.LOG_DIR, "s3")
        command = ("""
            aws s3 sync {remote_dir} {local_dir} --exclude '*stdout.log' --exclude '*stdouterr.log' --content-type "UTF-8"
        """.format(local_dir=local_dir, remote_dir=remote_dir))
        while True:
            fail = os.system(command)
            if fail:
                warnings.warn("Problem running the s3 sync command. You might want to run ./scripts/sync_s3.py manually in a shell to inspect.")
            if self.stopped():
                break
            # Wait on the event instead of sleeping so that stop() wakes the thread immediately
            # rather than after up to sync_interval seconds. After waking, one more sync runs
            # before the stop flag is checked again.
            self._stop_event.wait(self.sync_interval)
47+
48+
def _launch_workers(exp_key, n_workers, host, port, result_db_name):
    """
    Start n_workers local processes that each run _launch_worker, pausing one second after each start.

    :param exp_key: experiment key forwarded to every worker
    :param n_workers: int, number of worker processes to start
    :param host: MongoDB host forwarded to every worker
    :param port: MongoDB port forwarded to every worker
    :param result_db_name: MongoDB database name forwarded to every worker
    :return list of the started multiprocessing.Process objects, in start order
    """
    processes = []
    for worker_id in range(n_workers):
        worker_args = (exp_key, worker_id, host, port, result_db_name)
        process = multiprocessing.Process(target=_launch_worker, args=worker_args)
        processes.append(process)
        process.start()
        time.sleep(1)
    return processes
56+
57+
def _launch_worker(exp_key, worker_id, host, port, result_db_name):
    """
    Run a hyperopt-mongo-worker shell command and block until it exits.

    The worker's stdout and stderr are redirected to hyperopt_worker<worker_id>.log in the
    current working directory.

    :param exp_key: value for the worker's --exp-key option
    :param worker_id: identifier used in the log file name
    :param host: MongoDB host
    :param port: MongoDB port
    :param result_db_name: MongoDB database name
    :raises RuntimeError: if the shell command returns a non-zero status
    """
    template = "hyperopt-mongo-worker --mongo={h}:{p}/{db} --poll-interval=10 --exp-key={key} > hyperopt_worker{id}.log 2>&1"
    exit_status = os.system(
        template.format(h=host, p=port, db=result_db_name, key=exp_key, id=worker_id)
    )
    if exit_status != 0:
        raise RuntimeError("Problem starting hyperopt-mongo-worker.")
63+
64+
def _wait_result(exp_prefix, exp_name, timeout):
    """
    Wait until params.pkl (currently hardcoded) for the given run has been synced from S3.

    The file is looked for under <config.LOG_DIR>/s3/<exp_prefix>/<exp_name>/ and an attempt to
    open it is made every 60 seconds.

    :param exp_prefix: str, experiment name prefix (dir where results are expected to be stored)
    :param exp_name: str, experiment name. Name of dir below exp_prefix where result files of individual run are
        expected to be stored
    :param timeout: int, polling timeout in seconds
    :return bool. True as soon as the file could be opened, False if the polling times out.
    """
    result_path = os.path.join(config.LOG_DIR, "s3", exp_prefix, exp_name, 'params.pkl')
    print("Polling for results in",result_path)

    def _try_open():
        # Raises IOError while the file is absent; polling.poll ignores that and tries again.
        return open(result_path)

    try:
        handle = polling.poll(
            _try_open,
            ignore_exceptions=(IOError,),
            timeout=timeout,
            step=60)
    except polling.TimeoutException:
        return False

    # The handle only served as proof that the file exists.
    handle.close()
    return True
86+
87+
def _launch_ec2(func, exp_prefix, exp_name, params, run_experiment_kwargs):
    """
    Submit a single task through run_experiment_lite.

    Keyword arguments are assembled in three layers: built-in defaults (n_parallel=1,
    snapshot_mode="last", seed taken from params, mode="ec2"), then run_experiment_kwargs,
    then exp_prefix / exp_name / variant / confirm_remote, which always take precedence.

    :param func: the stubbed task method handed to run_experiment_lite
    :param exp_prefix: str, experiment name prefix
    :param exp_name: str, experiment name
    :param params: dict of parameters, passed on as the variant; an optional "seed" entry is used as seed
    :param run_experiment_kwargs: dict with further kwargs for run_experiment_lite
    """
    print("Launching task", exp_name)

    launch_kwargs = {
        "n_parallel": 1,
        "snapshot_mode": "last",
        "seed": params.get("seed"),
        "mode": "ec2",
    }
    # Caller-supplied settings override the defaults above ...
    launch_kwargs.update(run_experiment_kwargs)
    # ... but these four are always controlled here.
    launch_kwargs["exp_prefix"] = exp_prefix
    launch_kwargs["exp_name"] = exp_name
    launch_kwargs["variant"] = params
    launch_kwargs["confirm_remote"] = False

    run_experiment_lite(func, **launch_kwargs)
103+
104+
def _get_stubs(params):
    """
    Resolve the task and evaluation callables named in params.

    Removes the entries 'task_module', 'task_function', 'eval_module' and 'eval_function' from
    params (the dict is modified in place), imports the two modules and looks up the functions.

    :param params: dict containing the four entries above
    :return tuple (task function, evaluation function)
    """
    stub_keys = ('task_module', 'task_function', 'eval_module', 'eval_function')
    # Pop all four names first so a missing key fails before anything is imported.
    task_mod_name, task_fn_name, eval_mod_name, eval_fn_name = [params.pop(key) for key in stub_keys]

    task_fn = getattr(importlib.import_module(task_mod_name), task_fn_name)
    eval_fn = getattr(importlib.import_module(eval_mod_name), eval_fn_name)
    return task_fn, eval_fn
116+
117+
task_id = 1
118+
def objective_fun(params):
    """
    Objective evaluated by a hyperopt worker: launch the task on EC2, wait for its results, score them.

    Bookkeeping entries (exp_prefix, max_retries, result_timeout, run_experiment_kwargs and the
    task/eval stub names) are popped from params; what remains is passed to the task as variant.
    The module-level task_id counter is incremented for every launch attempt. All attempts of one
    call share the same experiment name.

    :param params: dict with the bookkeeping entries above plus the sampled parameter values
    :return dict with 'status' STATUS_FAIL when no results arrived within the allowed attempts,
        otherwise the dict returned by the eval function extended with 'status' STATUS_OK and 'params'
    """
    global task_id
    exp_prefix = params.pop("exp_prefix")
    exp_name = "{exp}_{pid}_{id}".format(exp=exp_prefix, pid=os.getpid(), id=task_id)
    # One initial attempt plus the requested number of retries.
    attempts_left = params.pop('max_retries', 0) + 1
    result_timeout = params.pop('result_timeout')
    run_experiment_kwargs = params.pop('run_experiment_kwargs', {})

    func, eval_func = _get_stubs(params)

    got_result = False
    while not got_result and attempts_left > 0:
        _launch_ec2(func, exp_prefix, exp_name, params, run_experiment_kwargs)
        task_id += 1
        attempts_left -= 1
        got_result = _wait_result(exp_prefix, exp_name, result_timeout)
        if not got_result and attempts_left > 0:
            print("Timed out waiting for results. Retrying...")

    if not got_result:
        print("Reached max retries, no results. Giving up.")
        return {'status':STATUS_FAIL}

    print("Results in! Processing.")
    result_dict = eval_func(exp_prefix, exp_name)
    result_dict['status'] = STATUS_OK
    result_dict['params'] = params
    return result_dict
147+
148+
149+
def launch_hyperopt_search(
        task_method,
        eval_method,
        param_space,
        hyperopt_experiment_key,
        hyperopt_db_host="localhost",
        hyperopt_db_port=1234,
        hyperopt_db_name="rllab",
        n_hyperopt_workers=1,
        hyperopt_max_evals=100,
        result_timeout=1200,
        max_retries=0,
        run_experiment_kwargs=None):
    """
    Launch a hyperopt search using EC2.

    This uses the hyperopt parallel processing functionality based on MongoDB. The MongoDB server at the specified host
    and port is assumed to be already running. Downloading and running MongoDB is pretty straightforward, see
    https://github.com/hyperopt/hyperopt/wiki/Parallelizing-Evaluations-During-Search-via-MongoDB for instructions.

    The parameter space to be searched over is specified in param_space. See https://github.com/hyperopt/hyperopt/wiki/FMin,
    section "Defining a search space" for further info. Also see the (very basic) example in contrib.rllab_hyperopt.example.main.py.

    NOTE: While the argument n_hyperopt_workers specifies the number of (local) parallel hyperopt workers to start, an equal
    number of EC2 instances will be started in parallel!
    NOTE2: Rllab currently terminates / starts a new EC2 instance for every task. This means what you'll pay amounts to
    hyperopt_max_evals * instance_hourly_rate. So you might want to be conservative with hyperopt_max_evals.

    The background S3 sync thread is stopped and the local worker processes are terminated when the search ends,
    whether it finishes normally or is aborted by an exception.

    :param task_method: the stubbed method call that runs the actual task. Should take a single dict as argument, with
        the params to evaluate. See e.g. contrib.rllab_hyperopt.example.task.py
    :param eval_method: the stubbed method call that reads in results returned from S3 and produces a score. Should take
        the exp_prefix and exp_name as arguments (this is where S3 results will be synced to). See e.g.
        contrib.rllab_hyperopt.example.score.py
    :param param_space: dict specifying the param space to search. See https://github.com/hyperopt/hyperopt/wiki/FMin,
        section "Defining a search space" for further info
    :param hyperopt_experiment_key: str, the key hyperopt will use to store results in the DB
    :param hyperopt_db_host: str, optional (default "localhost"). The host where mongodb runs
    :param hyperopt_db_port: int, optional (default 1234), the port where mongodb is listening for connections
    :param hyperopt_db_name: str, optional (default "rllab"), the DB name where hyperopt will store results
    :param n_hyperopt_workers: int, optional (default 1). The nr of parallel workers to start. NOTE: an equal number of
        EC2 instances will be started in parallel.
    :param hyperopt_max_evals: int, optional (default 100). Number of parameterset evaluations hyperopt should try.
        NOTE: Rllab currently terminates / starts a new EC2 instance for every task. This means what you'll pay amounts to
        hyperopt_max_evals * instance_hourly_rate. So you might want to be conservative with hyperopt_max_evals.
    :param result_timeout: int, optional (default 1200). Nr of seconds to wait for results from S3 for a given task. If
        results are not in within this time frame, <max_retries> new attempts will be made. A new attempt entails launching
        the task again on a new EC2 instance.
    :param max_retries: int, optional (default 0). Number of times to retry launching a task when results don't come in from S3
    :param run_experiment_kwargs: dict, optional (default None). Further kwargs to pass to run_experiment_lite. Note that
        specified values for exp_prefix, exp_name, variant, and confirm_remote will be ignored.
    :return the best result as found by hyperopt.fmin
    """
    exp_key = hyperopt_experiment_key

    # Bookkeeping entries consumed by objective_fun on the worker side. The task and eval
    # functions are passed as module / function names and looked up again by the worker.
    worker_args = {'exp_prefix':exp_key,
                   'task_module':task_method.__module__,
                   'task_function':task_method.__name__,
                   'eval_module':eval_method.__module__,
                   'eval_function':eval_method.__name__,
                   'result_timeout':result_timeout,
                   'max_retries':max_retries}

    worker_args.update(param_space)
    if run_experiment_kwargs is not None:
        worker_args['run_experiment_kwargs'] = run_experiment_kwargs

    trials = MongoTrials('mongo://{0}:{1:d}/{2}/jobs'.format(hyperopt_db_host, hyperopt_db_port, hyperopt_db_name),
                         exp_key=exp_key)

    workers = _launch_workers(exp_key, n_hyperopt_workers, hyperopt_db_host, hyperopt_db_port, hyperopt_db_name)
    try:
        s3sync = S3SyncThread()
        s3sync.start()
        try:
            print("Starting hyperopt")
            best = fmin(objective_fun, worker_args, trials=trials, algo=tpe.suggest, max_evals=hyperopt_max_evals)
        finally:
            # The sync thread is not a daemon thread: without this, an exception (or Ctrl-C)
            # during fmin would leave it running and keep the process alive indefinitely.
            s3sync.stop()
            s3sync.join()
    finally:
        # Always tear down the local worker processes, also when the search is aborted.
        for worker in workers:
            worker.terminate()

    return best

‎contrib/rllab_hyperopt/example/__init__.py

Whitespace-only changes.
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
'''
2+
Main module to launch an example hyperopt search on EC2.
3+
4+
Launch this from outside the rllab main dir. Otherwise, rllab will try to ship the logfiles being written by this process,
5+
which will fail because tar doesn't want to tar files that are being written to. Alternatively, disable the packaging of
6+
log files by rllab, but I couldn't quickly find how to do this.
7+
8+
You can use Jupyter notebook visualize_hyperopt_results.ipynb to inspect results.
9+
'''
10+
from hyperopt import hp
11+
12+
from contrib.rllab_hyperopt.core import launch_hyperopt_search
13+
# the functions to run the task and process result do not need to be in separate files. They do need to be separate from
14+
# the main file though. Also, anything you import in the module that contains run_task needs to be on the Rllab AMI.
15+
# Therefore, since I use pandas to process results, I have put them in separate files here.
16+
from contrib.rllab_hyperopt.example.score import process_result
17+
from contrib.rllab_hyperopt.example.task import run_task
18+
19+
# Search space handed to hyperopt. See https://github.com/hyperopt/hyperopt/wiki/FMin, sect 2 for more detail
param_space = dict(
    step_size=hp.uniform('step_size', 0.01, 0.1),
    seed=hp.choice('seed', [0, 1, 2]),
)

# Purely as an illustration: pass a different config to run_experiment_lite
run_experiment_kwargs = {
    'n_parallel': 16,
    'aws_config': {'instance_type': "c4.4xlarge", 'spot_price': '0.7'},
}

launch_hyperopt_search(
    task_method=run_task,                         # the task to run
    eval_method=process_result,                   # the function that will process results and return a score
    param_space=param_space,                      # param search space
    hyperopt_experiment_key='test12',             # key for hyperopt DB, and also exp_prefix for run_experiment_lite
    n_hyperopt_workers=3,                         # nr of local workers AND nr of EC2 instances that will be started in parallel
    hyperopt_max_evals=5,                         # nr of parameter values to eval
    result_timeout=600,                           # wait this long for results from S3 before timing out
    run_experiment_kwargs=run_experiment_kwargs,  # additional kwargs to pass to run_experiment_lite
)
+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
import pandas as pd
3+
4+
from rllab import config
5+
6+
def process_result(exp_prefix, exp_name):
    """
    Score a finished run from its progress.csv.

    Reads <config.LOG_DIR>/s3/<exp_prefix>/<exp_name>/progress.csv (the default rllab path for
    storing results) and smooths the 'AverageReturn' column with a centered rolling mean whose
    window is 5% of the number of rows (at least 1).

    :param exp_prefix: str, experiment name prefix
    :param exp_name: str, experiment name
    :return dict with max_score (maximum of the smoothed curve), max_iter (index of that maximum),
        scores (the smoothed curve) and loss (negated max_score)
    """
    result_path = os.path.join(config.LOG_DIR, "s3", exp_prefix, exp_name, 'progress.csv')
    print("Processing result from",result_path)

    # pandas makes it easy to read in the results and build a simple smoothed learning curve
    progress = pd.read_csv(result_path)
    window = max(1, int(0.05 * progress.shape[0]))
    smoothed = progress['AverageReturn'].rolling(window=window, min_periods=1, center=True).mean()
    curve = smoothed.values.flatten()
    best_index = curve.argmax()
    best_score = curve.max()

    # The result dict can contain arbitrary values, but ALWAYS needs to have a "loss" entry.
    return {
        'max_score': best_score,
        'max_iter': best_index,
        'scores': curve,  # returning the curve allows you to plot best, worst etc curve later
        'loss': -best_score,
    }
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from rllab.algos.trpo import TRPO
2+
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
3+
from rllab.envs.box2d.cartpole_env import CartpoleEnv
4+
from rllab.envs.normalized_env import normalize
5+
from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy
6+
7+
def run_task(v):
    """
    Train TRPO on the normalized cartpole environment.

    :param v: dict of parameters to evaluate; only v["step_size"] is read here
    """
    env = normalize(CartpoleEnv())
    spec = env.spec

    # Neural network policy with two hidden layers of 32 hidden units each.
    policy = GaussianMLPPolicy(env_spec=spec, hidden_sizes=(32, 32))
    baseline = LinearFeatureBaseline(env_spec=spec)

    trpo_settings = dict(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=4000,
        max_path_length=100,
        n_itr=40,
        discount=0.99,
        step_size=v["step_size"],
        # Uncomment to enable plotting
        # plot=True,
    )
    TRPO(**trpo_settings).train()

‎contrib/rllab_hyperopt/visualize_hyperopt_results.ipynb

+1,866
Large diffs are not rendered by default.

‎environment.yml

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies:
2626
- pytorch==0.1.9
2727
- torchvision==0.1.6
2828
- mpi4py
29+
- pandas
2930
- pip:
3031
- Pillow
3132
- atari-py
@@ -59,3 +60,5 @@ dependencies:
5960
- numpy-stl==2.2.0
6061
- nibabel==2.1.0
6162
- pylru==1.0.9
63+
- hyperopt
64+
- polling

0 commit comments

Comments
 (0)
Please sign in to comment.