From e0e686c40d23af2c7d222507a16e1713917c11f5 Mon Sep 17 00:00:00 2001
From: Lenz Fiedler
Date: Thu, 24 Oct 2024 15:03:43 +0200
Subject: [PATCH 1/4] Incorporated all changes from Kyle Miller. All code
 originally authored by Kyle

Co-authored-by: Kyle Miller
---
 mala/common/parameters.py              |  95 ++++++---
 mala/common/physical_data.py           |  53 ++++-
 mala/datahandling/data_handler.py      | 264 +++++++++++++++++++++++--
 mala/datahandling/data_handler_base.py |  55 +++++-
 mala/datahandling/snapshot.py          |  51 ++++-
 5 files changed, 454 insertions(+), 64 deletions(-)

diff --git a/mala/common/parameters.py b/mala/common/parameters.py
index 28840ebec..c9dfabe7b 100644
--- a/mala/common/parameters.py
+++ b/mala/common/parameters.py
@@ -1556,7 +1556,7 @@ def load_from_file(

         Parameters
         ----------
-        file : string or ZipExtFile
-            File to which the parameters will be saved to.
+        file : string or ZipExtFile or dict
+            File or dictionary from which the parameters will be loaded.

         save_format : string
@@ -1585,37 +1585,51 @@ def load_from_file(
                 json_dict = json.load(open(file, encoding="utf-8"))
             else:
                 json_dict = json.load(file)
+            loaded_parameters = cls.load_from_dict(
+                json_dict, no_snapshots, force_no_ddp
+            )
+
+        elif save_format == "dict":
+            loaded_parameters = cls.load_from_dict(
+                file, no_snapshots, force_no_ddp
+            )

-            loaded_parameters = cls()
-            for key in json_dict:
-                if (
-                    isinstance(json_dict[key], dict)
-                    and key != "openpmd_configuration"
-                ):
-                    # These are the other parameter classes.
-                    sub_parameters = globals()[
-                        json_dict[key]["_parameters_type"]
-                    ].from_json(json_dict[key])
-                    setattr(loaded_parameters, key, sub_parameters)
-
-            # We iterate a second time, to set global values, so that they
-            # are properly forwarded.
-            for key in json_dict:
-                if (
-                    not isinstance(json_dict[key], dict)
-                    or key == "openpmd_configuration"
-                ):
-                    if key == "use_ddp" and force_no_ddp is True:
-                        setattr(loaded_parameters, key, False)
-                    else:
-                        setattr(loaded_parameters, key, json_dict[key])
-            if no_snapshots is True:
-                loaded_parameters.data.snapshot_directories_list = []
         else:
             raise Exception("Unsupported parameter save format.")

         return loaded_parameters

+    @classmethod
+    def _process_loaded_parameters(
+        cls, parameters_dict, no_snapshots, force_no_ddp
+    ):
+        parameters_object = cls()
+        for key in parameters_dict:
+            if (
+                isinstance(parameters_dict[key], dict)
+                and key != "openpmd_configuration"
+            ):
+                # These are the other parameter classes.
+                sub_parameters = globals()[
+                    parameters_dict[key]["_parameters_type"]
+                ].from_json(parameters_dict[key])
+                setattr(parameters_object, key, sub_parameters)
+
+        # We iterate a second time, to set global values, so that they
+        # are properly forwarded.
+        for key in parameters_dict:
+            if (
+                not isinstance(parameters_dict[key], dict)
+                or key == "openpmd_configuration"
+            ):
+                if key == "use_ddp" and force_no_ddp is True:
+                    setattr(parameters_object, key, False)
+                else:
+                    setattr(parameters_object, key, parameters_dict[key])
+        if no_snapshots is True:
+            parameters_object.data.snapshot_directories_list = []
+        return parameters_object
+
     @classmethod
     def load_from_pickle(cls, file, no_snapshots=False):
         """
@@ -1666,3 +1680,32 @@ def load_from_json(cls, file, no_snapshots=False, force_no_ddp=False):
             no_snapshots=no_snapshots,
             force_no_ddp=force_no_ddp,
         )
+
+    @classmethod
+    def load_from_dict(
+        cls, param_dict, no_snapshots=False, force_no_ddp=False
+    ):
+        """
+        Load a Parameters object from a dictionary.
+
+        Parameters
+        ----------
+        param_dict : dictionary
+            Dictionary containing parameters to be loaded.
+
+        no_snapshots : bool
+            If True, then the snapshot list will be emptied. Useful when
+            performing inference/testing after training a network.
+
+        force_no_ddp : bool
+            If True, the use_ddp setting contained in the dictionary is
+            ignored and DDP is disabled in the loaded Parameters object.
+
+        Returns
+        -------
+        loaded_parameters : Parameters
+            The loaded Parameters object.
+
+        """
+        return Parameters.load_from_file(
+            param_dict,
+            save_format="dict",
+            no_snapshots=no_snapshots,
+            force_no_ddp=force_no_ddp,
+        )
diff --git a/mala/common/physical_data.py b/mala/common/physical_data.py
index 7ec85623d..039dfd9ba 100644
--- a/mala/common/physical_data.py
+++ b/mala/common/physical_data.py
@@ -5,7 +5,7 @@
 import json

 import numpy as np
-from mala.common.parallelizer import get_comm, get_rank
+from mala.common.parallelizer import get_comm, get_rank, printout

 from mala.version import __version__ as mala_version

@@ -69,7 +69,7 @@ def si_unit_conversion(self):
     ##############################

     def read_from_numpy_file(
-        self, path, units=None, array=None, reshape=False
+        self, path, units=None, array=None, reshape=False, selection_mask=None
     ):
         """
         Read the data from a numpy file.
@@ -86,6 +86,11 @@ def read_from_numpy_file(
             If not None, the array to save the data into.
             The array has to be 4-dimensional.

+        selection_mask : None or [boolean]
+            If None, entire snapshot is loaded, else it is used as a
+            mask to select which examples are loaded.
+
         Returns
         -------
         data : numpy.ndarray or None
@@ -97,17 +102,44 @@ def read_from_numpy_file(
         if array is None:
             loaded_array = np.load(path)[:, :, :, self._feature_mask() :]
             self._process_loaded_array(loaded_array, units=units)
-            return loaded_array
+
+            # Select portion of array if mask provided
+            if selection_mask is not None:
+                original_dims = loaded_array.shape
+
+                # Pseudo-flatten to apply mask without causing
+                # dimensionality mismatch later on
+                loaded_array = loaded_array.reshape(
+                    (-1, 1, 1, original_dims[-1])
+                )[selection_mask]
+                print(f"post-mask array_dims = {loaded_array.shape}")
+                return loaded_array
+            else:
+                return loaded_array
         else:
             if reshape:
                 array_dims = np.shape(array)
-                array[:, :] = np.load(path)[
-                    :, :, :, self._feature_mask() :
-                ].reshape(array_dims)
+                if selection_mask is not None:
+                    array[:, :] = np.load(path)[
+                        :, :, :, self._feature_mask() :
+                    ].reshape((len(selection_mask), -1))[selection_mask]
+                else:
+                    array[:, :] = np.load(path)[
+                        :, :, :, self._feature_mask() :
+                    ].reshape(array_dims)
             else:
+                array_dims = np.shape(array)
                 array[:, :, :, :] = np.load(path)[
                     :, :, :, self._feature_mask() :
                 ]
+
+                # Select portion of array if mask provided
+                if selection_mask is not None:
+                    # Pseudo-flatten to apply mask without causing
+                    # dimensionality mismatch later on
+                    array = array.reshape((-1, 1, 1, array_dims[-1]))[
+                        selection_mask
+                    ]
+                    print(f"post-mask array_dims = {array.shape}")
         self._process_loaded_array(array, units=units)

     def read_from_openpmd_file(self, path, units=None, array=None):
@@ -252,7 +284,9 @@ def read_from_openpmd_file(self, path, units=None, array=None):
         else:
             self._process_loaded_array(array, units=units)

-    def read_dimensions_from_numpy_file(self, path, read_dtype=False):
+    def read_dimensions_from_numpy_file(
+        self, path, read_dtype=False, selection_mask=None
+    ):
         """
         Read only the dimensions from a numpy file.

@@ -265,6 +299,11 @@ def read_dimensions_from_numpy_file(self, path, read_dtype=False):
         read_dtype : bool
             If True, the dtype is read alongside the dimensions.
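+
+        selection_mask : None or [boolean]
+            If None, the dimensions of the full array are read. If a
+            boolean list, the mask is applied to a pseudo-flattened view
+            of the array first, so the returned dimensions describe only
+            the selected examples.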
""" loaded_array = np.load(path, mmap_mode="r") + if selection_mask is not None: + original_dims = loaded_array.shape + loaded_array = loaded_array.reshape((-1, 1, 1, original_dims[-1]))[ + selection_mask + ] if read_dtype: return ( self._process_loaded_dimensions(np.shape(loaded_array)), diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index 7b8fc2a43..eb38ada50 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -135,7 +135,7 @@ def clear_data(self): # Preparing data ###################### - def prepare_data(self, reparametrize_scaler=True): + def prepare_data(self, reparametrize_scaler=True, from_arrays_dict=None): """ Prepare the data to be used in a training process. @@ -151,6 +151,15 @@ def prepare_data(self, reparametrize_scaler=True): If True (default), the DataScalers are parametrized based on the training data. + from_arrays_dict : dict or None + (Allows user to provide data directly from memory) + Dictionary which assigns an array (values) to each snapshot, e.g., + {(0,'inputs') : fp_array, (0, 'outputs') : ldos_array, ...} where 0 + is the index of the snapshot (absolute, not relative to data + partition) and inputs/outputs indicates the nature of the array. + None value indicates the data should be pulled from disk according + to the snapshot objects. + """ # During data loading, there is no need to save target data to # calculators. @@ -166,7 +175,7 @@ def prepare_data(self, reparametrize_scaler=True): "Checking the snapshots and your inputs for consistency.", min_verbosity=1, ) - self._check_snapshots() + self._check_snapshots(from_arrays_dict=from_arrays_dict) printout("Consistency check successful.", min_verbosity=0) # If the DataHandler is used for inference, i.e. no training or @@ -188,7 +197,7 @@ def prepare_data(self, reparametrize_scaler=True): # Parametrize the scalers, if needed. if reparametrize_scaler: printout("Initializing the data scalers.", min_verbosity=1) - self.__parametrize_scalers() + self.__parametrize_scalers(from_arrays_dict=from_arrays_dict) printout("Data scalers initialized.", min_verbosity=0) elif ( self.parameters.use_lazy_loading is False @@ -198,12 +207,16 @@ def prepare_data(self, reparametrize_scaler=True): "Data scalers already initilized, loading data to RAM.", min_verbosity=0, ) - self.__load_data("training", "inputs") - self.__load_data("training", "outputs") + self.__load_data( + "training", "inputs", from_arrays_dict=from_arrays_dict + ) + self.__load_data( + "training", "outputs", from_arrays_dict=from_arrays_dict + ) # Build Datasets. printout("Build datasets.", min_verbosity=1) - self.__build_datasets() + self.__build_datasets(from_arrays_dict=from_arrays_dict) printout("Build dataset: Done.", min_verbosity=0) # After the loading is done, target data can safely be saved again. @@ -216,6 +229,150 @@ def prepare_data(self, reparametrize_scaler=True): # allows for parallel I/O. barrier() + def refresh_data( + self, from_arrays_dict=None, partitions=["tr", "va", "te"] + ): + """ + Replace tr, va, te data for next generation of active learning. + + + Parameters + ---------- + + from_arrays_dict : dict or None + (Allows user to provide data directly from memory) + Dictionary which assigns an array (values) to each snapshot, e.g., + {(0,'inputs') : fp_array, (0, 'outputs') : ldos_array, ...} where 0 + is the index of the snapshot (absolute, not relative to data + partition) and inputs/outputs indicates the nature of the array. 
+            None value indicates the data should be pulled from disk
+            according to the snapshot objects.
+
+        partitions : list
+            Specifies the partitions for which to reload data.
+        """
+        # During data loading, there is no need to save target data to
+        # calculators.
+        # Technically, this would be no issue, but due to technical reasons
+        # (i.e. float64 to float32 conversion) saving the data this way
+        # may create copies in memory.
+        self.target_calculator.save_target_data = False
+
+        printout(
+            "Checking the snapshots and your inputs for consistency.",
+            min_verbosity=1,
+        )
+        self.__check_snapshots(from_arrays_dict=from_arrays_dict)
+        printout("Consistency check successful.", min_verbosity=0)
+
+        # Reallocate arrays for data storage
+        if self.parameters.data_splitting_type == "by_snapshot":
+            (
+                self.nr_training_snapshots,
+                self.nr_training_data,
+                self.nr_test_snapshots,
+                self.nr_test_data,
+                self.nr_validation_snapshots,
+                self.nr_validation_data,
+            ) = (0, 0, 0, 0, 0, 0)
+            snapshot: Snapshot
+            # As we are not actually interested in the number of snapshots,
+            # but in the number of datasets, we also need to multiply by that.
+
+            for i, snapshot in enumerate(
+                self.parameters.snapshot_directories_list
+            ):
+                printout(
+                    f"Snapshot {i}: {snapshot.grid_size}", min_verbosity=3
+                )
+                if snapshot.snapshot_function == "tr":
+                    self.nr_training_snapshots += 1
+                    self.nr_training_data += snapshot.grid_size
+                elif snapshot.snapshot_function == "te":
+                    self.nr_test_snapshots += 1
+                    self.nr_test_data += snapshot.grid_size
+                elif snapshot.snapshot_function == "va":
+                    self.nr_validation_snapshots += 1
+                    self.nr_validation_data += snapshot.grid_size
+                else:
+                    raise Exception(
+                        "Unknown option for snapshot splitting "
+                        "selected."
+                    )
+
+            # Now we need to check whether or not this input is believable.
+            nr_of_snapshots = len(self.parameters.snapshot_directories_list)
+            if nr_of_snapshots != (
+                self.nr_training_snapshots
+                + self.nr_test_snapshots
+                + self.nr_validation_snapshots
+            ):
+                raise Exception(
+                    "Cannot split snapshots with specified "
+                    "splitting scheme, "
+                    "too few or too many options selected: "
+                    f"[{nr_of_snapshots} != {self.nr_training_snapshots} + {self.nr_test_snapshots} + {self.nr_validation_snapshots}]"
+                )
+
+            # MALA can either be run in training or test-only mode.
+            # But it has to be run in either of those!
+            # So either training AND validation snapshots can be provided
+            # OR only test snapshots.
+            if self.nr_test_snapshots != 0:
+                if self.nr_training_snapshots == 0:
+                    printout(
+                        "DataHandler prepared for inference. No training "
+                        "possible with this setup. If this is not what "
+                        "you wanted, please revise the input script. "
+                        "Validation snapshots you may have entered will "
+                        "be ignored.",
+                        min_verbosity=0,
+                    )
+            else:
+                if self.nr_training_snapshots == 0:
+                    raise Exception("No training snapshots provided.")
+                if self.nr_validation_snapshots == 0:
+                    raise Exception("No validation snapshots provided.")
+        else:
+            raise Exception("Wrong parameter for data splitting provided.")
+
+        self.__allocate_arrays()
+
+        # Load updated data
+        expand_partition_name = {
+            "tr": "training",
+            "va": "validation",
+            "te": "test",
+        }
+        for partition in partitions:
+            self.__load_data(
+                expand_partition_name[partition],
+                "inputs",
+                from_arrays_dict=from_arrays_dict,
+            )
+            self.__load_data(
+                expand_partition_name[partition],
+                "outputs",
+                from_arrays_dict=from_arrays_dict,
+            )
+
+        # After the loading is done, target data can safely be saved again.
+        self.target_calculator.save_target_data = True
+
+        printout("Build datasets.", min_verbosity=1)
+        self.__build_datasets(from_arrays_dict=from_arrays_dict)
+        printout("Build dataset: Done.", min_verbosity=0)
+
+        # Wait until all ranks are finished with data preparation.
+        # It is not uncommon that ranks might be asynchronous in their
+        # data preparation by a small amount of minutes. If you notice
+        # an elongated wait time at this barrier, check that your file system
+        # allows for parallel I/O.
+        barrier()
+
     def prepare_for_testing(self):
         """
         Prepare DataHandler for usage within Tester class.
@@ -409,9 +566,9 @@ def resize_snapshots_for_debugging(
     # Loading data
     ######################

-    def _check_snapshots(self):
+    def __check_snapshots(self, from_arrays_dict=None):
         """Check the snapshots for consistency."""
-        super(DataHandler, self)._check_snapshots()
+        super(DataHandler, self)._check_snapshots(from_arrays_dict)

         # Now we need to confirm that the snapshot list has some inner
         # consistency.
@@ -444,7 +601,8 @@ def _check_snapshots(self):
             raise Exception(
                 "Cannot split snapshots with specified "
                 "splitting scheme, "
-                "too few or too many options selected"
+                "too few or too many options selected: "
+                f"[{nr_of_snapshots} != {self.nr_training_snapshots} + {self.nr_test_snapshots} + {self.nr_validation_snapshots}]"
             )
         # MALA can either be run in training or test-only mode.
         # But it has to be run in either of those!
@@ -508,7 +666,7 @@ def __allocate_arrays(self):
             dtype=DEFAULT_NP_DATA_DTYPE,
         )

-    def __load_data(self, function, data_type):
+    def __load_data(self, function, data_type, from_arrays_dict=None):
         """
         Load data into the appropriate arrays.

@@ -545,8 +703,9 @@ def __load_data(self, function, data_type):
         snapshot_counter = 0
         gs_old = 0
-        for snapshot in self.parameters.snapshot_directories_list:
-            # get the snapshot grid size
+        for i, snapshot in enumerate(
+            self.parameters.snapshot_directories_list
+        ):
+            # get the snapshot grid size
             gs_new = snapshot.grid_size

             # Data scaling is only performed on the training data sets.
@@ -563,7 +722,56 @@ def __load_data(self, function, data_type): ) units = snapshot.output_units - if snapshot.snapshot_type == "numpy": + # Pull from existing array rather than file + if from_arrays_dict is not None: + if snapshot._selection_mask is not None: + gs_new = sum(snapshot._selection_mask) + # TODO streamline this + if snapshot._selection_mask is not None: + # Update data already in tensor form + if torch.is_tensor(getattr(self, array)): + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = torch.from_numpy( + from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ][snapshot._selection_mask] + ) + + # Update a fresh numpy array + else: + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ][ + snapshot._selection_mask + ] + else: + # Update data already in tensor form + if torch.is_tensor(getattr(self, array)): + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = torch.from_numpy( + from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ] + ) + # Update a fresh numpy array + else: + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ] + + calculator._process_loaded_array( + getattr(self, array)[gs_old : gs_old + gs_new, :], + units=units, + ) + + # Pull directly from file + elif snapshot.snapshot_type == "numpy": calculator.read_from_numpy_file( file, units=units, @@ -571,8 +779,13 @@ def __load_data(self, function, data_type): gs_old : gs_old + gs_new, : ], reshape=True, + selection_mask=snapshot._selection_mask, ) elif snapshot.snapshot_type == "openpmd": + if snapshot._selection_mask is not None: + raise NotImplementedError( + "Selection mask is not implemented for openpmd" + ) getattr(self, array)[gs_old : gs_old + gs_new] = ( calculator.read_from_openpmd_file( file, units=units @@ -623,7 +836,7 @@ def __load_data(self, function, data_type): self.test_data_outputs ).float() - def __build_datasets(self): + def __build_datasets(self, from_arrays_dict=None): """Build the DataSets that are used during training.""" if ( self.parameters.use_lazy_loading @@ -764,10 +977,14 @@ def __build_datasets(self): ) if self.nr_validation_data != 0: - self.__load_data("validation", "inputs") + self.__load_data( + "validation", "inputs", from_arrays_dict=from_arrays_dict + ) self.input_data_scaler.transform(self.validation_data_inputs) - self.__load_data("validation", "outputs") + self.__load_data( + "test", "outputs", from_arrays_dict=from_arrays_dict + ) self.output_data_scaler.transform(self.validation_data_outputs) if self.parameters.use_fast_tensor_data_set: printout("Using FastTensorDataset.", min_verbosity=2) @@ -802,7 +1019,7 @@ def __build_datasets(self): # Scaling ###################### - def __parametrize_scalers(self): + def __parametrize_scalers(self, from_arrays_dict=None): """Use the training data to parametrize the DataScalers.""" ################## # Inputs. 
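The selection-mask path exercised by the scaler code above begins at DataHandlerBase.add_snapshot (see the data_handler_base.py diff below). A minimal sketch of registering a masked snapshot; file names, directories, and the grid size are hypothetical, and lazy loading must remain disabled, as the NotImplementedError guards in this series enforce (it is off by default):

    import numpy as np
    import mala

    parameters = mala.Parameters()
    data_handler = mala.DataHandler(parameters)

    grid_size = 8000  # hypothetical; must equal the snapshot's grid size
    mask = (np.random.rand(grid_size) < 0.1).tolist()  # keep ~10% of points

    data_handler.add_snapshot(
        "Be_snapshot0.in.npy",  # hypothetical file names and paths
        "/path/to/data",
        "Be_snapshot0.out.npy",
        "/path/to/data",
        "tr",
        snapshot_type="numpy",
        selection_mask=mask,
    )
    data_handler.prepare_data()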
@@ -828,6 +1045,7 @@ def __parametrize_scalers(self): snapshot.input_npy_file, ), units=snapshot.input_units, + selection_mask=snapshot._selection_mask, ) elif snapshot.snapshot_type == "openpmd": tmp = ( @@ -858,7 +1076,9 @@ def __parametrize_scalers(self): self.input_data_scaler.finish_incremental_fitting() else: - self.__load_data("training", "inputs") + self.__load_data( + "training", "inputs", from_arrays_dict=from_arrays_dict + ) self.input_data_scaler.fit(self.training_data_inputs) printout("Input scaler parametrized.", min_verbosity=1) @@ -880,6 +1100,10 @@ def __parametrize_scalers(self): # We need to perform the data scaling over the entirety of the # training data. for snapshot in self.parameters.snapshot_directories_list: + if snapshot._selection_mask is not None: + raise NotImplementedError( + "Example selection hasn't been implemented for lazy loading yet." + ) # Data scaling is only performed on the training data sets. if snapshot.snapshot_function == "tr": if snapshot.snapshot_type == "numpy": @@ -917,7 +1141,9 @@ def __parametrize_scalers(self): self.output_data_scaler.finish_incremental_fitting() else: - self.__load_data("training", "outputs") + self.__load_data( + "training", "outputs", from_arrays_dict=from_arrays_dict + ) self.output_data_scaler.fit(self.training_data_outputs) printout("Output scaler parametrized.", min_verbosity=1) diff --git a/mala/datahandling/data_handler_base.py b/mala/datahandling/data_handler_base.py index 54e27e959..938a2db12 100644 --- a/mala/datahandling/data_handler_base.py +++ b/mala/datahandling/data_handler_base.py @@ -92,6 +92,7 @@ def add_snapshot( input_units="None", calculation_output_file="", snapshot_type="numpy", + selection_mask=None, ): """ Add a snapshot to the data pipeline. @@ -130,7 +131,16 @@ def add_snapshot( snapshot_type : string Either "numpy" or "openpmd" based on what kind of files you want to operate on. + + selection_mask : None or [boolean] + If None, entire snapshot is loaded, if [boolean], it is used as a + mask to select which examples are loaded """ + if selection_mask is not None and self.parameters.use_lazy_loading: + raise NotImplementedError( + "Example selection hasn't been " + "implemented for lazy loading yet." + ) snapshot = Snapshot( input_file, input_directory, @@ -159,13 +169,15 @@ def clear_data(self): # Loading data ###################### - def _check_snapshots(self, comm=None): + def _check_snapshots(self, from_arrays_dict=None, comm=None): """Check the snapshots for consistency.""" self.nr_snapshots = len(self.parameters.snapshot_directories_list) # Read the snapshots using a memorymap to see if there is consistency. firstsnapshot = True - for snapshot in self.parameters.snapshot_directories_list: + for i, snapshot in enumerate( + self.parameters.snapshot_directories_list + ): #################### # Descriptors. 
#################### @@ -177,7 +189,28 @@ def _check_snapshots(self, comm=None): snapshot.input_npy_directory, min_verbosity=1, ) - if snapshot.snapshot_type == "numpy": + if from_arrays_dict is not None: + printout( + f'arrdim: {from_arrays_dict[(i, "inputs")].shape}', + min_verbosity=2, + ) + printout( + f"featmask: {self.descriptor_calculator._feature_mask()}", + min_verbosity=2, + ) + tmp_dimension = from_arrays_dict[(i, "inputs")][ + :, self.descriptor_calculator._feature_mask() : + ].shape + # We don't need any reference to full grid dim at this point + # so this is just for compatibility w other code + if len(tmp_dimension) > 2: + raise ValueError("Flatten the data pool arrays.") + tmp_dimension = (tmp_dimension[0], 1, 1, tmp_dimension[-1]) + printout( + f"from_arrays_dict dim {i}: {from_arrays_dict[(i, 'inputs')].shape}", + min_verbosity=2, + ) + elif snapshot.snapshot_type == "numpy": tmp_dimension = ( self.descriptor_calculator.read_dimensions_from_numpy_file( os.path.join( @@ -200,6 +233,11 @@ def _check_snapshots(self, comm=None): # for flexible grid sizes only this need be consistent tmp_input_dimension = tmp_dimension[-1] tmp_grid_dim = tmp_dimension[0:3] + + # If using selection_mask, apply to dimensions + if snapshot._selection_mask is not None: + tmp_grid_dim = (sum(snapshot._selection_mask), 1, 1) + snapshot.grid_dimension = tmp_grid_dim snapshot.grid_size = int(np.prod(snapshot.grid_dimension)) if firstsnapshot: @@ -221,7 +259,16 @@ def _check_snapshots(self, comm=None): snapshot.output_npy_directory, min_verbosity=1, ) - if snapshot.snapshot_type == "numpy": + if from_arrays_dict is not None: + tmp_dimension = from_arrays_dict[(i, "outputs")][ + :, self.target_calculator._feature_mask() : + ].shape + # We don't need any reference to full grid dim at this point + # so this is just for compatibility w other code + if len(tmp_dimension) > 2: + raise ValueError("Flatten the data pool arrays.") + tmp_dimension = (tmp_dimension[0], 1, 1, tmp_dimension[-1]) + elif snapshot.snapshot_type == "numpy": tmp_dimension = ( self.target_calculator.read_dimensions_from_numpy_file( os.path.join( diff --git a/mala/datahandling/snapshot.py b/mala/datahandling/snapshot.py index 8f6bc4666..6aea6cf03 100644 --- a/mala/datahandling/snapshot.py +++ b/mala/datahandling/snapshot.py @@ -1,5 +1,7 @@ """Represents an entire atomic snapshot (including descriptor/target data).""" +import numpy as np + from mala.common.json_serializable import JSONSerializable @@ -45,6 +47,10 @@ class Snapshot(JSONSerializable): Replaces the old approach of MALA to have a separate list. Default is None. 
+
+    selection_mask : None or [boolean]
+        If None, entire snapshot is loaded, if [boolean], it is used as a
+        mask to select which examples are loaded.
     """

     def __init__(
@@ -58,6 +64,7 @@ def __init__(
         output_units="",
         calculation_output="",
         snapshot_type="openpmd",
+        selection_mask=None,
     ):
         super(Snapshot, self).__init__()

@@ -87,6 +94,21 @@ def __init__(
         self.input_dimension = None
         self.output_dimension = None

+        # Mask determining which examples from the snapshot to use
+        if isinstance(selection_mask, np.ndarray):
+            self._selection_mask = selection_mask.tolist()
+        else:
+            self._selection_mask = selection_mask
+
+    def set_selection_mask(self, selection_mask):
+        if isinstance(selection_mask, np.ndarray):
+            self._selection_mask = selection_mask.tolist()
+        else:
+            self._selection_mask = selection_mask
+        if selection_mask is not None:
+            self.grid_size = sum(self._selection_mask)
+            # TODO: also adjust other dimension params
+
     @classmethod
     def from_json(cls, json_dict):
         """
@@ -104,14 +126,27 @@ def from_json(cls, json_dict):
             The object as read from the JSON file.

         """
-        deserialized_object = cls(
-            json_dict["input_npy_file"],
-            json_dict["input_npy_directory"],
-            json_dict["output_npy_file"],
-            json_dict["output_npy_directory"],
-            json_dict["snapshot_function"],
-            json_dict["snapshot_type"],
-        )
+        # Temporary try/except for compatibility with
+        # pre-selection_mask parameter dicts. TODO: remove
+        try:
+            deserialized_object = cls(
+                json_dict["input_npy_file"],
+                json_dict["input_npy_directory"],
+                json_dict["output_npy_file"],
+                json_dict["output_npy_directory"],
+                json_dict["snapshot_function"],
+                json_dict["snapshot_type"],
+                json_dict["selection_mask"],
+            )
+        except KeyError:
+            deserialized_object = cls(
+                json_dict["input_npy_file"],
+                json_dict["input_npy_directory"],
+                json_dict["output_npy_file"],
+                json_dict["output_npy_directory"],
+                json_dict["snapshot_function"],
+                json_dict["snapshot_type"],
+            )
         for key in json_dict:
             setattr(deserialized_object, key, json_dict[key])
         return deserialized_object

From cb4ab02ce20db9f45a10ac9d1382eecbc443c3bd Mon Sep 17 00:00:00 2001
From: Lenz Fiedler
Date: Thu, 24 Oct 2024 15:40:24 +0200
Subject: [PATCH 2/4] Fixed some merging errors

---
 mala/datahandling/data_handler.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py
index eb38ada50..cf49bbf56 100644
--- a/mala/datahandling/data_handler.py
+++ b/mala/datahandling/data_handler.py
@@ -262,7 +262,7 @@ def refresh_data(
             "Checking the snapshots and your inputs for consistency.",
             min_verbosity=1,
         )
-        self.__check_snapshots(from_arrays_dict=from_arrays_dict)
+        self._check_snapshots(from_arrays_dict=from_arrays_dict)
         printout("Consistency check successful.", min_verbosity=0)

         # Reallocate arrays for data storage
@@ -566,7 +566,7 @@ def resize_snapshots_for_debugging(
     # Loading data
     ######################

-    def __check_snapshots(self, from_arrays_dict=None):
+    def _check_snapshots(self, from_arrays_dict=None):
         """Check the snapshots for consistency."""
         super(DataHandler, self)._check_snapshots(from_arrays_dict)

@@ -983,7 +983,7 @@ def __build_datasets(self, from_arrays_dict=None):
             self.input_data_scaler.transform(self.validation_data_inputs)

             self.__load_data(
-                "test", "outputs", from_arrays_dict=from_arrays_dict
+                "validation", "outputs", from_arrays_dict=from_arrays_dict
             )
             self.output_data_scaler.transform(self.validation_data_outputs)
             if self.parameters.use_fast_tensor_data_set:
From a52a7f3576dc03c63ecdbaef2a7efff4152f2005 Mon
Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Thu, 24 Oct 2024 16:32:24 +0200 Subject: [PATCH 3/4] Fixed all the tests and CI in general --- mala/common/parameters.py | 8 +++----- mala/common/physical_data.py | 2 -- mala/datahandling/data_handler.py | 2 +- mala/datahandling/snapshot.py | 1 + mala/interfaces/ase_calculator.py | 4 ++-- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/mala/common/parameters.py b/mala/common/parameters.py index c9dfabe7b..77185f40f 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1585,12 +1585,12 @@ def load_from_file( json_dict = json.load(open(file, encoding="utf-8")) else: json_dict = json.load(file) - loaded_parameters = cls.load_from_dict( + loaded_parameters = cls._process_loaded_dict( json_dict, no_snapshots, force_no_ddp ) elif save_format == "dict": - loaded_parameters = cls.load_from_dict( + loaded_parameters = cls._process_loaded_dict( file, no_snapshots, force_no_ddp ) @@ -1600,9 +1600,7 @@ def load_from_file( return loaded_parameters @classmethod - def _process_loaded_parameters( - cls, parameters_dict, no_snapshots, force_no_ddp - ): + def _process_loaded_dict(cls, parameters_dict, no_snapshots, force_no_ddp): parameters_object = cls() for key in parameters_dict: if ( diff --git a/mala/common/physical_data.py b/mala/common/physical_data.py index 039dfd9ba..218fafe55 100644 --- a/mala/common/physical_data.py +++ b/mala/common/physical_data.py @@ -111,7 +111,6 @@ def read_from_numpy_file( loaded_array = loaded_array.reshape( (-1, 1, 1, original_dims[-1]) )[selection_mask] - print(f"post-mask array_dims = {loaded_array.shape}") return loaded_array else: return loaded_array @@ -139,7 +138,6 @@ def read_from_numpy_file( array = array.reshape((-1, 1, 1, array_dims[-1]))[ selection_mask ] - print(f"post-mask array_dims = {array.shape}") self._process_loaded_array(array, units=units) def read_from_openpmd_file(self, path, units=None, array=None): diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index cf49bbf56..660785e8e 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -235,10 +235,10 @@ def refresh_data( """ Replace tr, va, te data for next generation of active learning. + Internally replicates prepare_data function. Parameters ---------- - from_arrays_dict : dict or None (Allows user to provide data directly from memory) Dictionary which assigns an array (values) to each snapshot, e.g., diff --git a/mala/datahandling/snapshot.py b/mala/datahandling/snapshot.py index 6aea6cf03..27d16cfc7 100644 --- a/mala/datahandling/snapshot.py +++ b/mala/datahandling/snapshot.py @@ -101,6 +101,7 @@ def __init__( self._selection_mask = selection_mask def set_selection_mask(self, selection_mask): + """Set the selection mask for snapshot loading.""" if isinstance(selection_mask, np.ndarray): self._selection_mask = selection_mask.tolist() else: diff --git a/mala/interfaces/ase_calculator.py b/mala/interfaces/ase_calculator.py index 484395122..2336e47f7 100644 --- a/mala/interfaces/ase_calculator.py +++ b/mala/interfaces/ase_calculator.py @@ -79,7 +79,7 @@ def __init__( @classmethod def load_model(cls, run_name, path="./"): """ - DEPRECATED: Load a model to use for the calculator. + Load a model to use for the calculator (DEPRECATED). MALA.load_model() will be deprecated in MALA v1.4.0. Please use MALA.load_run() instead. 
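For reference, a sketch of the recommended replacement mentioned in the docstring above, with a hypothetical run name:

    import mala

    # Will be deprecated in MALA v1.4.0:
    #     calculator = mala.MALA.load_model("be_run")
    # Preferred:
    calculator = mala.MALA.load_run("be_run", path="./")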
@@ -239,5 +239,5 @@ def save_calculator(self, filename, path="./"):

         """
         self.predictor.save_run(
-            filename, path=save_path, additional_calculation_data=True
+            filename, path=path, additional_calculation_data=True
         )

From e2ffcadd52c03c7b7d334b41fc5fa1078c8d8084 Mon Sep 17 00:00:00 2001
From: Lenz Fiedler
Date: Mon, 6 Jan 2025 14:32:59 +0100
Subject: [PATCH 4/4] Reformatting

---
 mala/datahandling/data_handler.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py
index e37134ddd..b903d1571 100644
--- a/mala/datahandling/data_handler.py
+++ b/mala/datahandling/data_handler.py
@@ -1016,10 +1016,14 @@ def __build_datasets(self, from_arrays_dict=None):
             )

         if self.nr_validation_data != 0:
-            self.__load_data("validation", "inputs", from_arrays_dict=from_arrays_dict)
+            self.__load_data(
+                "validation", "inputs", from_arrays_dict=from_arrays_dict
+            )
             self.input_data_scaler.transform(self._validation_data_inputs)

-            self.__load_data("validation", "outputs", from_arrays_dict=from_arrays_dict)
+            self.__load_data(
+                "validation", "outputs", from_arrays_dict=from_arrays_dict
+            )
             self.output_data_scaler.transform(
                 self._validation_data_outputs
             )
@@ -1110,7 +1114,9 @@ def __parametrize_scalers(self, from_arrays_dict=None):
                     self.input_data_scaler.partial_fit(tmp)

         else:
-            self.__load_data("training", "inputs", from_arrays_dict=from_arrays_dict)
+            self.__load_data(
+                "training", "inputs", from_arrays_dict=from_arrays_dict
+            )
             self.input_data_scaler.fit(self._training_data_inputs)

         printout("Input scaler parametrized.", min_verbosity=1)
@@ -1171,8 +1177,9 @@ def __parametrize_scalers(self, from_arrays_dict=None):
                     i += 1

         else:
-            self.__load_data("training", "outputs", from_arrays_dict=from_arrays_dict
-            )
+            self.__load_data(
+                "training", "outputs", from_arrays_dict=from_arrays_dict
+            )
             self.output_data_scaler.fit(self._training_data_outputs)

         printout("Output scaler parametrized.", min_verbosity=1)
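Taken together, the series enables an in-memory active-learning loop around refresh_data. A sketch of one refresh generation, assuming two snapshots were registered (index 0 as "tr", index 1 as "va") and an initial training pass has already happened; all sizes are invented, and random arrays stand in for freshly computed descriptor/LDOS data:

    import numpy as np
    import mala

    parameters = mala.Parameters()
    data_handler = mala.DataHandler(parameters)
    # ... add_snapshot() for snapshots 0 ("tr") and 1 ("va"),
    # data_handler.prepare_data(), network training ...

    n_points, n_in, n_out = 512, 91, 11  # hypothetical sizes
    new_arrays = {}
    for i in (0, 1):
        new_arrays[(i, "inputs")] = np.random.rand(n_points, n_in).astype(
            np.float32
        )
        new_arrays[(i, "outputs")] = np.random.rand(n_points, n_out).astype(
            np.float32
        )

    # Swap in the new generation's data without touching the file system.
    data_handler.refresh_data(
        from_arrays_dict=new_arrays, partitions=["tr", "va"]
    )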