@@ -21,6 +21,9 @@ def play_learning_agent(num_episodes=2, show_image=False, load_learning_filename
21
21
save_learning_filename = None , plot_filename = None , csv_filename = None , display_screen = False ,
22
22
state_representation = 'simple' , agent_type = 'subsumption' , exploration = None ,
23
23
distance_metric = None , random_seed = 123 ):
24
+ """
25
+ Let the learning agent play with the specified parameters.
26
+ """
24
27
logging .info ('Plot filename: {}' .format (plot_filename ))
25
28
logging .info ('Agent type: {}' .format (agent_type ))
26
29
logging .info ('Distance metric: {}' .format (distance_metric ))
@@ -68,105 +71,55 @@ def setup_logging(level):
68
71
datefmt = '%d-%m-%Y:%H:%M:%S' ,
69
72
level = LOGGING_LEVELS [level ])
70
73
71
-
72
74
def parse_command_line_arguments ():
73
75
"""
74
76
Parse the command-line arguments provided by the user.
75
77
"""
76
- parser = ArgumentParser (description = 'Reinforcement Learning with Q*bert .' )
78
+ parser = ArgumentParser (description = 'Reinforcement Learning with Qbert .' )
77
79
parser .add_argument ('-l' , '--logging_level' , default = 'info' , choices = LOGGING_LEVELS .keys (),
78
80
help = 'The logging level.' )
79
-
80
- subparsers = parser .add_subparsers ()
81
-
82
- args = parser .parse_args ()
81
+ parser .add_argument ('-e' , '--num_episodes' , default = 100 , type = int , help = 'The number of training episodes.' )
82
+ parser .add_argument ('-o' , '--load_learning_filename' , default = None ,
83
+ help = "The pickle file to load learning data from. To run the agent with pre-trained Q data, set"
84
+ " this parameter to 'data'" )
85
+ parser .add_argument ('-f' , '--save_learning_filename' , default = None ,
86
+ help = 'The pickle file to save learning data to.' )
87
+ parser .add_argument ('-p' , '--plot_filename' , default = None ,
88
+ help = 'The filename to save a score plot to.' )
89
+ parser .add_argument ('-c' , '--csv_filename' , default = None ,
90
+ help = 'The filename to save a score CSV file to.' )
91
+ parser .add_argument ('-d' , '--display_screen' , default = False , type = bool ,
92
+ help = 'Whether to display the ALE screen.' )
93
+ parser .add_argument ('-s' , '--state_representation' , default = 'simple' , choices = ['simple' , 'verbose' ],
94
+ help = 'The state representation to use.' )
95
+ parser .add_argument ('-a' , '--agent_type' , default = 'subsumption' ,
96
+ choices = ['block' , 'enemy' , 'friendly' , 'subsumption' , 'combined_verbose' ],
97
+ help = 'The agent type to use.' )
98
+ parser .add_argument ('-x' , '--exploration' , default = 'combined' , choices = ['random' , 'optimistic' , 'combined' ],
99
+ help = 'The exploration mode to use.' )
100
+ parser .add_argument ('-m' , '--distance_metric' , default = None , choices = ['manhattan' , 'hamming' , 'same_result' ],
101
+ help = 'The distance metric to use.' )
102
+ parser .add_argument ('-r' , '--random_seed' , default = None , type = int ,
103
+ help = 'The random seed to use.' )
104
+ parser .add_argument ('-i' , '--show_image' , default = False , type = bool ,
105
+ help = 'Whether to show a screenshot at the end of every episode.' )
106
+
107
+ args = parser .parse_args ('-help' .split ())
83
108
setup_logging (args .logging_level )
84
- args .func (args )
85
-
86
-
87
- def save_generalization_results ():
88
- distance_metric = 'no_generalization'
89
- play_learning_agent (num_episodes = 100 , plot_filename = distance_metric , csv_filename = distance_metric ,
90
- display_screen = False , agent_type = 'combined_verbose' , exploration = None , distance_metric = None )
91
-
92
- distance_metric = 'manhattan'
93
- play_learning_agent (num_episodes = 100 , plot_filename = distance_metric , csv_filename = distance_metric ,
94
- display_screen = False , agent_type = 'combined_verbose' , exploration = None ,
95
- distance_metric = distance_metric )
96
-
97
- distance_metric = 'hamming'
98
- play_learning_agent (num_episodes = 100 , plot_filename = distance_metric , csv_filename = distance_metric ,
99
- display_screen = False , agent_type = 'combined_verbose' , exploration = None ,
100
- distance_metric = distance_metric )
101
-
102
- distance_metric = 'same_result'
103
- play_learning_agent (num_episodes = 100 , plot_filename = distance_metric , csv_filename = distance_metric ,
104
- display_screen = False , agent_type = 'combined_verbose' , exploration = None ,
105
- distance_metric = distance_metric )
106
-
107
- filename = 'subsumption_generalization'
108
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
109
- display_screen = False , agent_type = 'subsumption' , exploration = None ,
110
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_no_exploration' )
111
-
112
-
113
- def save_exploration_results ():
114
- filename = 'subsumption_random'
115
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
116
- display_screen = False , agent_type = 'subsumption' , exploration = 'random' ,
117
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_random' )
118
-
119
- filename = 'subsumption_optimistic'
120
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
121
- display_screen = False , agent_type = 'subsumption' , exploration = 'optimistic' ,
122
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_optimistic' )
123
-
124
- filename = 'subsumption_combined'
125
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
126
- display_screen = False , agent_type = 'subsumption' , exploration = 'combined' ,
127
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_combined' )
128
-
129
-
130
- def save_performance_results ():
131
- filename = 'seed123'
132
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
133
- display_screen = False , agent_type = 'subsumption' , exploration = 'combined' ,
134
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_combined_123' ,
135
- random_seed = 123 )
136
-
137
- filename = 'seed459'
138
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
139
- display_screen = False , agent_type = 'subsumption' , exploration = 'combined' ,
140
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_combined_459' ,
141
- random_seed = 459 )
142
-
143
- filename = 'seed598'
144
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
145
- display_screen = False , agent_type = 'subsumption' , exploration = 'combined' ,
146
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_combined_598' ,
147
- random_seed = 459 )
148
-
149
-
150
- def continued_learning ():
151
- filename = 'seed459_600'
152
- play_learning_agent (num_episodes = 100 , plot_filename = filename , csv_filename = filename ,
153
- display_screen = False , agent_type = 'subsumption' , exploration = 'combined' ,
154
- distance_metric = None , save_learning_filename = 'subsumption_dangerous_combined_459_600' ,
155
- random_seed = 459 , load_learning_filename = 'subsumption_dangerous_combined_459_500' )
156
-
157
-
158
- def sample_play ():
159
- play_learning_agent (num_episodes = 100 ,
160
- display_screen = True , agent_type = 'subsumption' , exploration = 'combined' ,
161
- distance_metric = None ,
162
- random_seed = 459 , load_learning_filename = 'subsumption_dangerous_combined_459_400' )
109
+ play_learning_agent (num_episodes = args .num_episodes ,
110
+ load_learning_filename = args .load_learning_filename ,
111
+ save_learning_filename = args .save_learning_filename ,
112
+ plot_filename = args .plot_filename ,
113
+ csv_filename = args .csv_filename ,
114
+ display_screen = args .display_screen ,
115
+ state_representation = args .state_representation ,
116
+ agent_type = args .agent_type ,
117
+ exploration = args .exploration ,
118
+ distance_metric = args .distance_metric ,
119
+ random_seed = args .random_seed ,
120
+ show_image = args .show_image )
163
121
164
122
165
123
if __name__ == '__main__' :
166
124
setup_logging ('info' )
167
- # play_learning_agent()
168
- # save_generalization_results()
169
- # save_exploration_results()
170
- # save_performance_results()
171
- continued_learning ()
172
- # sample_play()
125
+ parse_command_line_arguments ()
0 commit comments