Restoring from checkpoints given an experiment folder #16
self.grid_position = (checkpoint['topology_details']['position'][0], checkpoint['topology_details']['position'][1])
self.local_node = checkpoint['topology_details']['cell_info']
self.cell_number = checkpoint['cell_number']
self.neighbours = checkpoint['adjacent_cells']
Why do the checkpoint keys have different names from the variables they are assigned to?
splitted_source = source.split(':')
return {'address': splitted_source[0], 'port': splitted_source[1], 'id': source}

class ExperimentResuming():
Rename the file to `checkpoint.py` and the class to `ExperimentRestarter`; that makes it easier to read.
return checkpoint

def get_population_cell_info(self):
client_soueces = self.get_population_sources()
Typo: `client_soueces` should be `client_sources`.
return clients

def get_population_sources(self):
sources = set()
Why is this a set? Is it intentional to get populations of different sizes?
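To illustrate the concern: a set drops duplicate sources and does not keep insertion order, so the returned list can be shorter than the number of entries in the checkpoint and can come back in a different order. If de-duplication is intended but order matters, one option is sketched below. The helper name `unique_sources` and the sample `address:port` strings are assumptions for illustration, not part of the PR.

```python
def unique_sources(local_discriminators):
    """De-duplicate sources while keeping first-seen order.

    Assumes each source is a hashable string such as 'address:port'.
    """
    # dict preserves insertion order, so keys come back in first-seen order.
    return list(dict.fromkeys(local_discriminators))


# Hypothetical sample data: the first source appears twice.
sources = ['10.0.0.1:5000', '10.0.0.2:5000', '10.0.0.1:5000']
print(unique_sources(sources))  # ['10.0.0.1:5000', '10.0.0.2:5000']
```

If duplicates are not expected at all, a plain list would keep the population size equal to the number of checkpoint entries.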
[sources.add(local_discriminator) for local_discriminator in checkpoint['local_discriminators']]
return list(sources)

def get_local_generators_paths(self, source):
This looks like code duplicated across two methods; it could be simplified with a function parameter, e.g. `get_checkpoint_local_values(self, source, key)`.
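A minimal sketch of that parameterised helper. The PR does not show the checkpoint layout, so the structure used here (a dict keyed by source, each holding lists under `local_generators` and `local_discriminators`) is an assumption, as are the `checkpoints` attribute and the `get_local_discriminators_paths` method name.

```python
class ExperimentResuming:
    """Hypothetical stand-in showing only the refactored accessors."""

    def __init__(self, checkpoints):
        # Assumed layout: {source: {'local_generators': [...],
        #                           'local_discriminators': [...]}}
        self.checkpoints = checkpoints

    def get_checkpoint_local_values(self, source, key):
        """Return the values stored under `key` for `source` (empty if absent)."""
        return list(self.checkpoints.get(source, {}).get(key, []))

    def get_local_generators_paths(self, source):
        # Thin wrapper: only the key differs between the two methods.
        return self.get_checkpoint_local_values(source, 'local_generators')

    def get_local_discriminators_paths(self, source):
        return self.get_checkpoint_local_values(source, 'local_discriminators')
```

Keeping the two named methods as thin wrappers preserves the existing call sites while the shared lookup lives in one place.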
@@ -102,6 +109,9 @@ def add_config_file(grp, is_required):
def initialize_settings(args):
cc = ConfigurationContainer.instance()
cc.settings = read_settings(args.configuration_file)
if args.checkpoint_folder is not None:
cc.settings['general']['distribution']['auto_discover'] = False
cc.settings['general']['distribution']['resuming'] = args.checkpoint_folder
Use `resume` instead of `resuming` as the key name; it is easier to read.
gen, dis = self.network_factory.create_both()
populations[TYPE_GENERATOR].individuals.append(Individual(genome=gen, fitness=gen.default_fitness))
populations[TYPE_DISCRIMINATOR].individuals.append(Individual(genome=dis, fitness=dis.default_fitness))
if checkpoint is not None:
Maybe reverse the branches and write `if checkpoint` for readability.
# Resume status from checkpoint
checkpoint = self.cc.settings['general']['distribution'].get('client_checkpoint', None)
if checkpoint is not None:
Maybe reverse the branches and write `if checkpoint` for readability.
n_iterations = self.cc.settings['trainer'].get('n_iterations', 0)
assert 0 <= checkpoint_period <= n_iterations, 'Checkpoint period paramenter (checkpoint_period) should be ' \
'between 0 and the number of iterations (n_iterations).'
self.checkpoint_period = self.cc.settings['general'].get('checkpoint_period', checkpoint_period)

def read_ckeckpoint(self):
checkpoint = self.cc.settings['general']['distribution']['client_checkpoint']
return checkpoint.get('iteration', 0), float(checkpoint.get('generators_learning_rate', None)), float(checkpoint.get('discriminators_learning_rate', None))
Is `None` a good default value? `float(None)` raises a `TypeError`, so a missing learning rate would crash here.
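One possible way to handle the missing-key case is sketched below. It is written as a standalone function taking the checkpoint dict rather than as the method from the PR, and the `default_lr` parameter and the choice to raise `KeyError` are assumptions for illustration.

```python
def read_checkpoint(checkpoint, default_lr=None):
    """Return (iteration, generators_lr, discriminators_lr) from a checkpoint dict.

    Raises KeyError with a clear message when a learning rate is missing
    and no default is supplied, instead of failing inside float(None).
    """
    def as_float(key):
        value = checkpoint.get(key, default_lr)
        if value is None:
            raise KeyError(f"checkpoint is missing '{key}' and no default was given")
        return float(value)

    return (
        checkpoint.get('iteration', 0),
        as_float('generators_learning_rate'),
        as_float('discriminators_learning_rate'),
    )
```

Whether to fall back to a configured learning rate or to fail loudly is a design choice for the authors; the point is only that the failure mode becomes explicit.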
for iteration in range(n_iterations):
while(self.iteration < n_iterations):
iteration = self.iteration
Redundant?