-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restoring from checkpoints given an experiment folder #16
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from os.path import dirname | ||
import glob | ||
import yaml | ||
import os | ||
import logging | ||
|
||
|
||
#path = '/media/toutouh/224001034000DF81/lipi-gan-public-checkpointing/lipizzaner-gan/src/output/lipizzaner_gan/distributed/mnist/2019-11-22_11-49-00/' | ||
|
||
|
||
def create_cell_info(source):
    """Build a cell-info dict from an ``'address:port'`` source string.

    :param source: client endpoint in ``'address:port'`` form.
    :return: dict with keys ``'address'``, ``'port'`` and ``'id'`` (the
        original source string).
    :raises IndexError: if *source* contains no ``':'`` separator.
    """
    # rsplit on the LAST ':' so a host part that itself contains colons
    # (e.g. a bracketed IPv6 literal '[::1]:5000') keeps its full address;
    # a plain split(':') would truncate it at the first colon.
    splitted_source = source.rsplit(':', 1)
    return {'address': splitted_source[0], 'port': splitted_source[1], 'id': source}
|
||
class ExperimentResuming():
    """Recovers the state of a distributed experiment from the checkpoint
    files stored in its output folder so it can be resumed.

    The experiment folder is expected to contain one sub-folder per client,
    each holding a ``checkpoint*.yml`` file that describes the local
    generator/discriminator populations plus ``generator-*.pkl`` and
    ``discriminator-*.pkl`` parameter files.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, experiment_path):
        # One flat dict per client checkpoint found under experiment_path.
        self.checkpoints_storage = self.create_checkpoints_storage(experiment_path)

    def create_checkpoints_storage(self, experiment_path):
        """Parse every checkpoint YAML found under *experiment_path*.

        :param experiment_path: root folder of the experiment to resume.
        :return: list of checkpoint dicts (see :meth:`get_checkpoint_structure`).
        """
        assert os.path.isdir(experiment_path), 'Checkpoint of experiment in folder {} not found.'.format(experiment_path)
        self._logger.info('Recovering checkpoint information from checkpoint of the experiment stored in {}'.format(experiment_path))
        # os.path.join instead of plain string concatenation: the original
        # 'experiment_path + pattern' silently matched nothing when the
        # caller omitted the trailing slash.
        return [self.get_checkpoint_structure(checkpoint_file)
                for checkpoint_file in glob.glob(os.path.join(experiment_path, '*/checkpoint*.yml'))]

    def get_checkpoint_structure(self, checkpoint_file):
        """Load one checkpoint YAML file into a flat checkpoint dict.

        :param checkpoint_file: path to a ``checkpoint*.yml`` file.
        :return: dict with population sources, learning rates, adjacency,
            iteration, id, grid topology and the checkpoint folder path.
        """
        assert os.path.isfile(checkpoint_file)

        def get_local_individuals(individual_type, checkpoint_data):
            # Sources ('address:port') of the individuals trained by this client.
            return [individual['source'] for individual in checkpoint_data[individual_type]['individuals']
                    if individual['is_local']]

        def get_adjacent_individuals(individual_type, checkpoint_data):
            # Cell info of neighbouring (non-local) individuals.
            return [create_cell_info(individual['source'])
                    for individual in checkpoint_data[individual_type]['individuals']
                    if not individual['is_local']]

        # safe_load avoids arbitrary object construction from the YAML file,
        # and the context manager closes the handle (yaml.load(open(...))
        # leaked it). NOTE(review): if the checkpoints contain custom YAML
        # tags, switch to yaml.load(f, Loader=yaml.FullLoader) — confirm.
        with open(checkpoint_file) as yaml_file:
            checkpoint_data = yaml.safe_load(yaml_file)

        checkpoint = {
            'local_generators': get_local_individuals('generators', checkpoint_data),
            'local_discriminators': get_local_individuals('discriminators', checkpoint_data),
            'discriminators_learning_rate': checkpoint_data['discriminators']['learning_rate'],
            'generators_learning_rate': checkpoint_data['generators']['learning_rate'],
            'adjacent_individuals': get_adjacent_individuals('generators', checkpoint_data),
            'iteration': checkpoint_data['iteration'],
            'time': checkpoint_data['time'],
            'id': checkpoint_data['id'],
            'path': dirname(checkpoint_file),
            'grid_size': checkpoint_data['grid_size'],
            'position': (checkpoint_data['position']['x'], checkpoint_data['position']['y']),
        }

        self._logger.info(
            'Recovered checkpoint information from checkpoint file {}.\nCheckpoint information: {}'.format(checkpoint_file, checkpoint))

        return checkpoint

    def get_population_cell_info(self):
        """Return cell-info dicts (address/port/id) for every population source."""
        return [create_cell_info(source) for source in self.get_population_sources()]

    def get_population_sources(self):
        """Return the distinct 'address:port' sources of all local individuals.

        A set is used for de-duplication: the same source appears in both the
        generator and the discriminator population of its checkpoint.
        """
        sources = set()
        for checkpoint in self.checkpoints_storage:
            sources.update(checkpoint['local_generators'])
            sources.update(checkpoint['local_discriminators'])
        return list(sources)

    def _find_checkpoint(self, source, individual_type):
        # First checkpoint whose local individuals of the given type contain
        # *source*; None when no checkpoint matches.
        for checkpoint in self.checkpoints_storage:
            if source in checkpoint[individual_type]:
                return checkpoint
        return None

    def get_local_generators_paths(self, source):
        """Paths of the pickled generators saved by the client *source*."""
        assert source in self.get_population_sources()
        checkpoint = self._find_checkpoint(source, 'local_generators')
        if checkpoint is not None:
            return glob.glob(checkpoint['path'] + '/generator-*.pkl')

    def get_local_discriminators_paths(self, source):
        """Paths of the pickled discriminators saved by the client *source*."""
        assert source in self.get_population_sources()
        checkpoint = self._find_checkpoint(source, 'local_discriminators')
        if checkpoint is not None:
            return glob.glob(checkpoint['path'] + '/discriminator-*.pkl')

    def get_iterations(self, source):
        """Iteration the client *source* had reached when it checkpointed.

        NOTE(review): matches on local_discriminators only, as the original
        code did — confirm that every source also owns a discriminator.
        """
        assert source in self.get_population_sources()
        checkpoint = self._find_checkpoint(source, 'local_discriminators')
        if checkpoint is not None:
            return checkpoint['iteration']

    def get_id(self, source):
        """Checkpoint (cell) id of the client *source* (discriminator match)."""
        assert source in self.get_population_sources()
        checkpoint = self._find_checkpoint(source, 'local_discriminators')
        if checkpoint is not None:
            return checkpoint['id']

    def _checkpoint_by_id(self, id):
        # Shared lookup for all *_id getters; raises IndexError when the id
        # is unknown, like the original [ ... ][0] pattern did.
        # NOTE(review): the upper bound should probably be
        # len(self.checkpoints_storage) - 1 if ids are 0-based — confirm.
        assert 0 <= id <= len(self.checkpoints_storage)
        return [checkpoint for checkpoint in self.checkpoints_storage if checkpoint['id'] == id][0]

    def get_iterations_id(self, id):
        """Iteration stored in the checkpoint with the given *id*."""
        return self._checkpoint_by_id(id)['iteration']

    def get_local_generators_paths_id(self, id):
        """Paths of the pickled generators of the checkpoint with *id*."""
        return glob.glob(self._checkpoint_by_id(id)['path'] + '/generator-*.pkl')

    def get_local_discriminators_paths_id(self, id):
        """Paths of the pickled discriminators of the checkpoint with *id*."""
        return glob.glob(self._checkpoint_by_id(id)['path'] + '/discriminator-*.pkl')

    def get_discriminators_learning_rate_id(self, id):
        """Discriminator learning rate of the checkpoint with *id*."""
        return self._checkpoint_by_id(id)['discriminators_learning_rate']

    def get_generators_learning_rate_id(self, id):
        """Generator learning rate of the checkpoint with *id*."""
        return self._checkpoint_by_id(id)['generators_learning_rate']

    def get_adjacent_cells_id(self, id):
        """Cell info of the non-local neighbours of the checkpoint with *id*."""
        return self._checkpoint_by_id(id)['adjacent_individuals']

    def get_topology_details_id(self, id):
        """Grid size, position and own cell info of the checkpoint with *id*."""
        checkpoint = self._checkpoint_by_id(id)
        return {'grid_size': checkpoint['grid_size'],
                'position': checkpoint['position'],
                'cell_info': create_cell_info(checkpoint['local_generators'][0])}

    def get_checkpoint_data_id(self, id):
        """Aggregate everything a client needs to resume from checkpoint *id*."""
        return {'iteration': self.get_iterations_id(id),
                'generators_path': self.get_local_generators_paths_id(id),
                'discriminators_path': self.get_local_discriminators_paths_id(id),
                'generators_learning_rate': self.get_generators_learning_rate_id(id),
                'discriminators_learning_rate': self.get_discriminators_learning_rate_id(id),
                'adjacent_cells': self.get_adjacent_cells_id(id),
                'topology_details': self.get_topology_details_id(id),
                'cell_number': id}
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,9 +17,11 @@ | |
from helpers.math_helpers import is_square | ||
from helpers.network_helpers import get_network_devices | ||
from helpers.reproducible_helpers import set_random_seed | ||
from helpers.checkpointing import ExperimentResuming | ||
from training.mixture.mixed_generator_dataset import MixedGeneratorDataset | ||
from training.mixture.score_factory import ScoreCalculatorFactory | ||
|
||
|
||
GENERATOR_PREFIX = 'generator-' | ||
DISCRIMINATOR_PREFIX = 'discriminator-' | ||
|
||
|
@@ -32,6 +34,7 @@ def __init__(self): | |
self.heartbeat_event = None | ||
self.heartbeat_thread = None | ||
self.experiment_id = None | ||
self.checkpoint = None | ||
|
||
def run(self): | ||
if os.environ.get('DOCKER', False) == 'True': | ||
|
@@ -43,10 +46,18 @@ def run(self): | |
clients = self._load_available_clients() | ||
self.cc.settings['general']['distribution']['client_nodes'] = clients | ||
self._logger.info('Detected {} clients ({})'.format(len(clients), clients)) | ||
elif self.cc.settings['general']['distribution'].get('resuming', None) is not None: | ||
self.checkpoint = ExperimentResuming(self.cc.settings['general']['distribution']['resuming']) | ||
clients = self.checkpoint.get_population_cell_info() | ||
self.cc.settings['general']['distribution']['client_nodes'] = clients | ||
self._logger.info('Resuming to the following {} clients ({})'.format(len(clients), clients)) | ||
|
||
else: | ||
# Expand port ranges to multiple client entries | ||
self.expand_clients() | ||
clients = self.cc.settings['general']['distribution']['client_nodes'] | ||
self._logger.info('Detected {} clients ({})'.format(len(clients), clients)) | ||
|
||
accessible_clients = self._accessible_clients(clients) | ||
|
||
if len(accessible_clients) == 0 or not is_square(len(accessible_clients)): | ||
|
@@ -126,17 +137,36 @@ def _start_experiments(self): | |
self.experiment_id = db_logger.create_experiment(self.cc.settings) | ||
self.cc.settings['general']['logging']['experiment_id'] = self.experiment_id | ||
|
||
if not self.checkpoint is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be |
||
for client in self.cc.settings['general']['distribution']['client_nodes']: | ||
client_id = self.checkpoint.get_id('{}:{}'.format(client['address'], client['port'])) | ||
address = 'http://{}:{}/experiments'.format(client['address'], client['port']) | ||
self.cc.settings['general']['distribution']['client_id'] = client_id | ||
self.cc.settings['general']['distribution'][ | ||
'client_checkpoint'] = self.checkpoint.get_checkpoint_data_id(client_id) | ||
try: | ||
resp = requests.post(address, json=self.cc.settings) | ||
assert resp.status_code == 200, resp.text | ||
self._logger.info('Successfully started experiment on {}'.format(address)) | ||
except AssertionError as err: | ||
self._logger.critical('Could not start experiment on {}: {}'.format(address, err)) | ||
self._terminate() | ||
|
||
for client_id, client in enumerate(self.cc.settings['general']['distribution']['client_nodes']): | ||
address = 'http://{}:{}/experiments'.format(client['address'], client['port']) | ||
self.cc.settings['general']['distribution']['client_id'] = client_id | ||
try: | ||
resp = requests.post(address, json=self.cc.settings) | ||
assert resp.status_code == 200, resp.text | ||
self._logger.info('Successfully started experiment on {}'.format(address)) | ||
except AssertionError as err: | ||
self._logger.critical('Could not start experiment on {}: {}'.format(address, err)) | ||
self._terminate() | ||
else: | ||
|
||
#TODO: validate the resuming parameters here | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is resuming correct Spanish :-) |
||
|
||
|
||
for client_id, client in enumerate(self.cc.settings['general']['distribution']['client_nodes']): | ||
address = 'http://{}:{}/experiments'.format(client['address'], client['port']) | ||
self.cc.settings['general']['distribution']['client_id'] = client_id | ||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this duplicated and could be moved outside the |
||
resp = requests.post(address, json=self.cc.settings) | ||
assert resp.status_code == 200, resp.text | ||
self._logger.info('Successfully started experiment on {}'.format(address)) | ||
except AssertionError as err: | ||
self._logger.critical('Could not start experiment on {}: {}'.format(address, err)) | ||
self._terminate() | ||
|
||
def _terminate(self, stop_clients=True, return_code=-1): | ||
try: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,6 +54,13 @@ def add_config_file(grp, is_required): | |
action='store_true', | ||
help='Start as long-running client node. Waits for master ' | ||
'to send experiment configuration, and runs them.') | ||
group_train.add_argument( | ||
'--checkpoint-folder', | ||
'-c', | ||
type=str, | ||
required=False, | ||
dest='checkpoint_folder', | ||
help='Folder of the experiment to recover from its checkpoints.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unclear English: is this the folder for restoring from checkpoints, or for saving them? I like to include a usage example in the help text showing how to use the flag, e.g. `--checkpoint-folder output/.../2019-11-22_11-49-00/` |
||
group_distributed = group_train.add_mutually_exclusive_group(required='--distributed' in sys.argv) | ||
group_distributed.add_argument( | ||
'--master', | ||
|
@@ -102,6 +109,9 @@ def add_config_file(grp, is_required): | |
def initialize_settings(args): | ||
cc = ConfigurationContainer.instance() | ||
cc.settings = read_settings(args.configuration_file) | ||
if args.checkpoint_folder is not None: | ||
cc.settings['general']['distribution']['auto_discover'] = False | ||
cc.settings['general']['distribution']['resuming'] = args.checkpoint_folder | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if 'logging' in cc.settings['general'] and cc.settings['general']['logging']['enabled']: | ||
log_dir = os.path.join(cc.settings['general']['output_dir'], 'log') | ||
LogHelper.setup(cc.settings['general']['logging']['log_level'], log_dir) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,6 @@ | |
from helpers.population import Population, TYPE_GENERATOR, TYPE_DISCRIMINATOR | ||
from training.nn_trainer import NeuralNetworkTrainer | ||
|
||
|
||
class EvolutionaryAlgorithmTrainer(NeuralNetworkTrainer, ABC): | ||
_logger = logging.getLogger(__name__) | ||
|
||
|
@@ -94,16 +93,25 @@ def tournament_selection(self, population, population_type, is_logging=False): | |
|
||
return new_population | ||
|
||
def initialize_populations(self): | ||
def initialize_populations(self, checkpoint): | ||
populations = [None] * 2 | ||
populations[TYPE_GENERATOR] = Population(individuals=[], default_fitness=0, population_type=TYPE_GENERATOR) | ||
populations[TYPE_DISCRIMINATOR] = Population(individuals=[], default_fitness=0, | ||
population_type=TYPE_DISCRIMINATOR) | ||
|
||
for i in range(self._population_size): | ||
gen, dis = self.network_factory.create_both() | ||
populations[TYPE_GENERATOR].individuals.append(Individual(genome=gen, fitness=gen.default_fitness)) | ||
populations[TYPE_DISCRIMINATOR].individuals.append(Individual(genome=dis, fitness=dis.default_fitness)) | ||
if checkpoint is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe reverse branches and write |
||
for gen_path, dis_path in zip(checkpoint['generators_path'], checkpoint['discriminators_path']): | ||
gen, dis = self.network_factory.create_both() | ||
gen.net.load_state_dict(torch.load(gen_path)) | ||
dis.net.load_state_dict(torch.load(dis_path)) | ||
gen.net.eval() | ||
dis.net.eval() | ||
populations[TYPE_GENERATOR].individuals.append(Individual(genome=gen, fitness=gen.default_fitness)) | ||
populations[TYPE_DISCRIMINATOR].individuals.append(Individual(genome=dis, fitness=dis.default_fitness)) | ||
else: | ||
for i in range(self._population_size): | ||
gen, dis = self.network_factory.create_both() | ||
populations[TYPE_GENERATOR].individuals.append(Individual(genome=gen, fitness=gen.default_fitness)) | ||
populations[TYPE_DISCRIMINATOR].individuals.append(Individual(genome=dis, fitness=dis.default_fitness)) | ||
|
||
populations[TYPE_GENERATOR].default_fitness = populations[TYPE_GENERATOR].individuals[0].fitness | ||
populations[TYPE_DISCRIMINATOR].default_fitness = populations[TYPE_DISCRIMINATOR].individuals[0].fitness | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do the checkpoints keys have different names from the variables?