35 | 35 | from keras.layers import Dense, Activation, Flatten
36 | 36 | from keras.optimizers import Adam
37 | 37 |
   | 38 | +from rl.agents.cem import CEMAgent
38 | 39 | from rl.agents.dqn import DQNAgent
   | 40 | +from rl.agents.sarsa import SarsaAgent
39 | 41 | from rl.policy import BoltzmannQPolicy, GreedyQPolicy, EpsGreedyQPolicy
40 |    | -from rl.memory import SequentialMemory
   | 42 | +from rl.memory import SequentialMemory, EpisodeParameterMemory
41 | 43 | from rl.callbacks import FileLogger, ModelIntervalCheckpoint
42 | 44 |
43 | 45 | ENV_NAME = 'planar_crane-v0'
44 | 46 |
45 | 47 | LAYER_SIZE = 1024
46 | 48 | NUM_HIDDEN_LAYERS = 4
47 |    | -NUM_STEPS = 100000
48 |    | -DUEL_DQN = True
   | 49 | +NUM_STEPS = 50000
   | 50 | +METHOD = 'DUEL_DQN'    # can be DQN, DUEL_DQN, SARSA, or CEM
49 | 51 | TRIAL_ID = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
50 | 52 |
   | 53 | +# Define the filenames to use for this session
   | 54 | +WEIGHT_FILENAME = 'weights/{}_{}_weights_{}_{}_{}_{}.h5f'.format(METHOD, ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
   | 55 | +CHECKPOINT_WEIGHTS_FILENAME = 'weights/{}_{}_checkpointWeights_{{step}}_{}_{}_{}_{}.h5f'.format(METHOD, ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
   | 56 | +LOG_FILENAME = 'logs/{}_{}_log_{}_{}_{}_{}.json'.format(METHOD, ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
   | 57 | +MONITOR_FILENAME = 'example_data/{}_{}_monitor_{}_{}_{}_{}'.format(METHOD, ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
   | 58 | +
   | 59 | +
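Note: the doubled braces in CHECKPOINT_WEIGHTS_FILENAME are str.format escapes, so a literal {step} placeholder survives the first format() call for keras-rl's ModelIntervalCheckpoint (commented out further down) to fill in at save time. A minimal sketch with made-up values:

    >>> 'weights/{}_checkpointWeights_{{step}}_{}.h5f'.format('DUEL_DQN', 50000)
    'weights/DUEL_DQN_checkpointWeights_{step}_50000.h5f'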
51 | 60 | # Get the environment and extract the number of actions.
52 | 61 | env = gym.make(ENV_NAME)
53 | 62 |
54 | 63 | # Record episode data?
55 | 64 | env.SAVE_DATA = False
56 | 65 |
57 |    | -# uncomment to record data about the training session, including video if visualize is true
58 |    | -
59 |    | -# uncomment to record data about the training session, including video if visualize is true
60 |    | -if DUEL_DQN:
61 |    | -    MONITOR_FILENAME = 'example_data/duel_dqn_{}_monitor_{}_{}_{}_{}'.format(ENV_NAME,
62 |    | -                                                                              LAYER_SIZE,
63 |    | -                                                                              NUM_HIDDEN_LAYERS,
64 |    | -                                                                              NUM_STEPS,
65 |    | -                                                                              TRIAL_ID)
66 |    | -else:
67 |    | -    MONITOR_FILENAME = 'example_data/dqn_{}_monitor_{}_{}_{}_{}'.format(ENV_NAME,
68 |    | -                                                                        LAYER_SIZE,
69 |    | -                                                                        NUM_HIDDEN_LAYERS,
70 |    | -                                                                        NUM_STEPS,
71 |    | -                                                                        TRIAL_ID)
   | 66 | +# Wrap the env in a Monitor to record data about the training session, including video if video_callable allows it
72 | 67 | env = gym.wrappers.Monitor(env, MONITOR_FILENAME, video_callable=False, force=True)
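Note: with video_callable=False the Monitor wrapper records episode statistics but no video. Gym's Monitor also accepts a callable that decides per episode, so filming could be enabled with something like (sketch, not part of this commit):

    env = gym.wrappers.Monitor(env, MONITOR_FILENAME, force=True,
                               video_callable=lambda episode_id: episode_id % 10 == 0)  # record every 10th episode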
73 | 68 |
74 | 69 | np.random.seed(123)

91 | 86 | model.add(Activation('linear'))
92 | 87 | print(model.summary())
93 | 88 |
   | 89 | +
94 | 90 | # Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
95 | 91 | # even the metrics!
96 |    | -memory = SequentialMemory(limit=NUM_STEPS, window_length=1)
   | 92 | +
97 | 93 | # train_policy = BoltzmannQPolicy(tau=0.05)
98 | 94 | train_policy = EpsGreedyQPolicy()
99 | 95 | test_policy = GreedyQPolicy()
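Note: the epsilon-greedy policy keeps exploring during training, while the greedy policy is used for evaluation. Assuming keras-rl's default, EpsGreedyQPolicy picks a random action 10% of the time; the rate can also be set explicitly (sketch):

    train_policy = EpsGreedyQPolicy(eps=0.1)    # eps = probability of a random action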
100 |  96 |
101 |     | -if DUEL_DQN:
102 |     | -    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=100,
    |  97 | +# Compile the agent based on the method specified. We use .upper() to convert to
    |  98 | +# upper case for comparison
    |  99 | +if METHOD.upper() == 'DUEL_DQN':
    | 100 | +    memory = SequentialMemory(limit=NUM_STEPS, window_length=1)
    | 101 | +    agent = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=100,
103 | 102 |                      enable_dueling_network=True, dueling_type='avg', target_model_update=1e-2,
104 | 103 |                      policy=train_policy, test_policy=test_policy)
105 |     | -
106 |     | -    filename = 'weights/duel_dqn_{}_weights_{}_{}_{}_{}.h5f'.format(ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
107 |     | -    checkpoint_weights_filename = 'logs/duel_dqn_{}_checkpointWeights_{{step}}_{}_{}_{}_{}.h5f'.format(ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
108 |     | -    log_filename = 'logs/duel_dqn_{}_log_{}_{}_{}_{}.json'.format(ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
    | 104 | +    agent.compile(Adam(lr=1e-3), metrics=['mae'])
109 | 105 |
110 |     | -else:
111 |     | -    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=100,
    | 106 | +elif METHOD.upper() == 'DQN':
    | 107 | +    memory = SequentialMemory(limit=NUM_STEPS, window_length=1)
    | 108 | +    agent = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=100,
112 | 109 |                      target_model_update=1e-2, policy=train_policy, test_policy=test_policy)
113 |     | -
114 |     | -    filename = 'weights/dqn_{}_weights_{}_{}_{}_{}.h5f'.format(ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
115 |     | -    checkpoint_weights_filename = 'weights/dqn_{}_checkpointWeights_{{step}}_{}_{}_{}_{}.h5f'.format(ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
116 |     | -    log_filename = 'logs/dqn_{}_log_{}_{}_{}_{}.json'.format(ENV_NAME, LAYER_SIZE, NUM_HIDDEN_LAYERS, NUM_STEPS, TRIAL_ID)
    | 110 | +    agent.compile(Adam(lr=1e-3), metrics=['mae'])
117 | 111 |
    | 112 | +elif METHOD.upper() == 'SARSA':
    | 113 | +    # SARSA does not require a memory.
    | 114 | +    agent = SarsaAgent(model=model, nb_actions=nb_actions, nb_steps_warmup=10, policy=train_policy)
    | 115 | +    agent.compile(Adam(lr=1e-3), metrics=['mae'])
    | 116 | +
    | 117 | +elif METHOD.upper() == 'CEM':
    | 118 | +    memory = EpisodeParameterMemory(limit=1000, window_length=1)
    | 119 | +    agent = CEMAgent(model=model, nb_actions=nb_actions, memory=memory,
    | 120 | +                     batch_size=50, nb_steps_warmup=2000, train_interval=50, elite_frac=0.05)
    | 121 | +    agent.compile()
    | 122 | +
    | 123 | +else:
    | 124 | +    raise ValueError('Please select DQN, DUEL_DQN, SARSA, or CEM for your method type.')
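Note: CEMAgent.compile() takes no optimizer because the cross-entropy method is gradient-free; the DQN, dueling DQN, and SARSA branches accept any built-in Keras optimizer. Swapping one in would look like (sketch, not part of this commit):

    from keras.optimizers import RMSprop
    agent.compile(RMSprop(lr=1e-3), metrics=['mae'])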
118 | 125 |
119 | 126 |
120 |     | -dqn.compile(Adam(lr=1e-3), metrics=['mae'])
121 | 127 |
122 | 128 | callbacks = []
123 |     | -callbacks += [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=10000)]
124 |     | -callbacks += [FileLogger(log_filename, interval=100)]
    | 129 | +# callbacks += [ModelIntervalCheckpoint(CHECKPOINT_WEIGHTS_FILENAME, interval=10000)]
    | 130 | +callbacks += [FileLogger(LOG_FILENAME, interval=100)]
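Note: the FileLogger callback periodically dumps training metrics to LOG_FILENAME as JSON, so a run can be inspected after the fact, e.g. (sketch):

    import json
    with open(LOG_FILENAME) as f:
        log = json.load(f)
    print(sorted(log.keys()))   # per-episode metrics recorded by keras-rl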
125 | 131 |
126 | 132 | # Optionally, we can reload a previous model's weights and continue training from there
127 |     | -# WEIGHTS_FILENAME = 'weights/duel_dqn_planar_crane-v0_weights_1024_4_50000_2017-07-12_160853.h5f'
    | 133 | +# LOAD_WEIGHTS_FILENAME = 'weights/duel_dqn_planar_crane-v0_weights_1024_4_50000_2017-07-12_160853.h5f'
128 | 134 | # # # Load the model weights
129 |     | -# dqn.load_weights(WEIGHTS_FILENAME)
    | 135 | +# agent.load_weights(LOAD_WEIGHTS_FILENAME)
130 | 136 |
131 | 137 | # Okay, now it's time to learn something! We visualize the training here for show, but this
132 | 138 | # slows down training quite a lot. You can always safely abort the training prematurely using
133 | 139 | # Ctrl + C.
134 |     | -dqn.fit(env, nb_steps=NUM_STEPS, callbacks=callbacks, visualize=False, verbose=1, nb_max_episode_steps=500)
    | 140 | +agent.fit(env, nb_steps=NUM_STEPS, callbacks=callbacks, visualize=False, verbose=1, nb_max_episode_steps=500)
135 | 141 |
136 | 142 | # After training is done, we save the final weights.
137 |     | -dqn.save_weights(filename, overwrite=True)
    | 143 | +agent.save_weights(WEIGHT_FILENAME, overwrite=True)
138 | 144 |
139 | 145 | # Finally, evaluate our algorithm for 5 episodes.
140 |     | -# dqn.test(env, nb_episodes=5, nb_max_episode_steps=500, visualize=True)
    | 146 | +agent.test(env, nb_episodes=5, nb_max_episode_steps=500, visualize=True)
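Note: once the run has finished, the saved weights can be loaded back into an identically constructed agent for later playback (sketch reusing the names above):

    agent.load_weights(WEIGHT_FILENAME)
    agent.test(env, nb_episodes=5, nb_max_episode_steps=500, visualize=True)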