-
Notifications
You must be signed in to change notification settings - Fork 135
Enhancements to data sampler #1916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,6 +54,14 @@ | |
| flags.DEFINE_integer( | ||
| 'sampler_rows_per_key', 5, | ||
| 'The maximum number of rows to select for each unique value found.') | ||
| flags.DEFINE_integer( | ||
| 'sampler_uniques_per_column', 10, | ||
| 'The maximum number of unique values to track per column. ' | ||
| 'If 0 or -1, all unique values are tracked.') | ||
| flags.DEFINE_boolean( | ||
| 'sampler_exhaustive', False, | ||
| 'If True, sets sampler_output_rows and sampler_uniques_per_column to ' | ||
| 'infinity, and sampler_rows_per_key to 1, to capture every unique value.') | ||
| flags.DEFINE_float( | ||
| 'sampler_rate', -1, | ||
| 'The sampling rate for random row selection (e.g., 0.1 for 10%).') | ||
|
|
@@ -65,6 +73,11 @@ | |
| flags.DEFINE_string( | ||
| 'sampler_unique_columns', '', | ||
| 'A comma-separated list of column names to use for selecting unique rows.') | ||
| flags.DEFINE_list( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you also update the comment to describe the default behavior (empty)? |
||
| 'sampler_column_keys', [], | ||
| 'A list of "column:file" pairs containing values that MUST be included ' | ||
| 'in the sample if they appear in the input data. ' | ||
| 'Example: "variableMeasured:prominent_svs.txt"') | ||
| flags.DEFINE_string('sampler_input_delimiter', ',', | ||
| 'The delimiter used in the input CSV file.') | ||
| flags.DEFINE_string('sampler_input_encoding', 'UTF8', | ||
|
|
@@ -75,6 +88,7 @@ | |
| _FLAGS = flags.FLAGS | ||
|
|
||
| import file_util | ||
| import mcf_file_util | ||
|
|
||
| from config_map import ConfigMap | ||
| from counters import Counters | ||
|
|
@@ -103,27 +117,37 @@ def __init__( | |
| self, | ||
| config_dict: dict = None, | ||
| counters: Counters = None, | ||
| column_include_values: dict = None, | ||
| ): | ||
| """Initializes the DataSampler object. | ||
|
|
||
| Args: | ||
| config_dict: A dictionary of configuration parameters. | ||
| counters: A Counters object for tracking statistics. | ||
| column_include_values: a dictionary of column-name to set of values | ||
| in the column to be included in the sample | ||
| """ | ||
| self._config = ConfigMap(config_dict=get_default_config()) | ||
| if config_dict: | ||
| self._config.add_configs(config_dict) | ||
| self._counters = counters if counters is not None else Counters() | ||
| self._column_include_values = column_include_values | ||
| self.reset() | ||
|
|
||
| def reset(self) -> None: | ||
| """Resets the state of the DataSampler. | ||
|
|
||
| This method resets the internal state of the DataSampler, including the | ||
| counts of unique column values and the number of selected rows. This is | ||
| useful when you want to reuse the same DataSampler instance for sampling | ||
| multiple files. | ||
| counts of unique column values and the number of selected rows. If | ||
| sampler_exhaustive is set in the configuration, it applies overrides | ||
| to other configuration parameters to capture all unique values. | ||
| """ | ||
| if self._config.get('sampler_exhaustive'): | ||
| # Exhaustive mode overrides limits to capture all unique values. | ||
| self._config.set_config('sampler_output_rows', -1) | ||
| self._config.set_config('sampler_uniques_per_column', -1) | ||
| self._config.set_config('sampler_rows_per_key', 1) | ||
|
|
||
| # Dictionary of unique values: count per column | ||
| self._column_counts = {} | ||
| # Dictionary of column index: list of header strings | ||
|
|
@@ -144,6 +168,16 @@ def reset(self) -> None: | |
| if col.strip() | ||
| ] | ||
|
|
||
| # Must include values: dict of column_name -> set of values | ||
| self._must_include_values = load_column_keys( | ||
| self._config.get('sampler_column_keys', [])) | ||
| if self._column_include_values: | ||
| for col, vals in self._column_include_values.items(): | ||
| self._must_include_values.setdefault(col, set()).update(vals) | ||
|
|
||
| # Map of column index -> set of values | ||
| self._must_include_indices = {} | ||
|
|
||
| def __del__(self) -> None: | ||
| """Logs the column headers and counts upon object deletion.""" | ||
| logging.log(2, f'Sampler column headers: {self._column_headers}') | ||
|
|
@@ -179,22 +213,35 @@ def _get_column_count(self, column_index: int, value: str) -> int: | |
| return 0 | ||
| return col_values.get(value, 0) | ||
|
|
||
| def _should_track_column(self, column_index: int) -> bool: | ||
| """Determines if a column should be tracked for unique values. | ||
| def _is_unique_column(self, column_index: int) -> bool: | ||
| """Determines if a column is specified for unique value sampling. | ||
|
|
||
| Args: | ||
| column_index: The index of the column. | ||
|
|
||
| Returns: | ||
| True if the column should be tracked (either no unique columns | ||
| specified or this column is in the unique columns list). | ||
| True if the column should be sampled for unique values. | ||
| """ | ||
| if not self._unique_column_names: | ||
| # No specific columns specified, track all | ||
| # No specific columns specified, track all for unique sampling | ||
| return True | ||
| # Check if this column is in our unique columns | ||
| return column_index in self._unique_column_indices.values() | ||
|
|
||
| def _should_track_column(self, column_index: int) -> bool: | ||
| """Determines if a column should be tracked for counts. | ||
|
|
||
| Args: | ||
| column_index: The index of the column. | ||
|
|
||
| Returns: | ||
| True if the column should be tracked for unique values or is a | ||
| must-include column. | ||
| """ | ||
| if self._is_unique_column(column_index): | ||
| return True | ||
| return column_index in self._must_include_indices | ||
|
|
||
| def _process_header_row(self, row: list[str]) -> None: | ||
| """Process a header row to build column name to index mapping. | ||
|
|
||
|
|
@@ -206,15 +253,27 @@ def _process_header_row(self, row: list[str]) -> None: | |
| Args: | ||
| row: A header row containing column names. | ||
| """ | ||
| if not self._unique_column_names: | ||
| return | ||
|
|
||
| for index, column_name in enumerate(row): | ||
| if column_name in self._unique_column_names: | ||
| if self._unique_column_names and column_name in self._unique_column_names: | ||
| self._unique_column_indices[column_name] = index | ||
| logging.level_debug() and logging.debug( | ||
| f'Mapped unique column "{column_name}" to index {index}') | ||
|
|
||
| if self._must_include_values and column_name in self._must_include_values: | ||
| self._must_include_indices[index] = self._must_include_values[ | ||
| column_name] | ||
| logging.info( | ||
| f'Mapped must-include column "{column_name}" to index {index}' | ||
| ) | ||
|
|
||
| def _is_must_include(self, column_index: int, value: str) -> bool: | ||
| """Checks if a column value is in the must-include list.""" | ||
| if column_index not in self._must_include_indices: | ||
| return False | ||
| # Normalize the input value before checking against the set | ||
| return mcf_file_util.strip_namespace( | ||
| value) in self._must_include_indices[column_index] | ||
|
|
||
| def _add_column_header(self, column_index: int, value: str) -> str: | ||
| """Adds the first non-empty value of a column as its header. | ||
|
|
||
|
|
@@ -282,13 +341,26 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool: | |
| # Too many rows already selected. Drop it. | ||
| return False | ||
| max_count = self._config.get('sampler_rows_per_key', 3) | ||
| if max_count <= 0: | ||
| max_count = sys.maxsize | ||
| max_uniques_per_col = self._config.get('sampler_uniques_per_column', 10) | ||
| if max_uniques_per_col <= 0: | ||
| max_uniques_per_col = sys.maxsize | ||
|
|
||
| for index in range(len(row)): | ||
| # Skip columns not in unique_columns list | ||
| if not self._should_track_column(index): | ||
| continue | ||
| value = row[index] | ||
| value_count = self._get_column_count(index, value) | ||
|
|
||
| # Rule 1: Always include if it's a must-include value and | ||
| # we haven't reached per-key limit. | ||
| if value_count < max_count and self._is_must_include(index, value): | ||
| self._counters.add_counter('sampler-selected-must-include', 1) | ||
| return True | ||
|
|
||
| # Skip columns not in unique_columns list for general unique sampling | ||
| if not self._is_unique_column(index): | ||
| continue | ||
|
|
||
| if value_count == 0 or value_count < max_count: | ||
| # This is a new value for this column. | ||
| col_counts = self._column_counts.get(index, {}) | ||
|
|
@@ -301,7 +373,7 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool: | |
| # No new unique value for the row. | ||
| # Check random sampler. | ||
| if sample_rate < 0: | ||
| sample_rate = self._config.get('sampler_rate') | ||
| sample_rate = self._config.get('sampler_rate', -1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If sample_rate is negative, we are setting it default to negative again. Isn't that contradictory? |
||
| if random.random() <= sample_rate: | ||
| self._counters.add_counter('sampler-sampled-rows', 1) | ||
| return True | ||
|
|
@@ -390,6 +462,7 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str: | |
| # Process and write header rows from the first input file. | ||
| if row_index <= header_rows and input_index == 0: | ||
| self._process_header_row(row) | ||
| self._add_row_counts(row) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Curious where is this defined? |
||
| csv_writer.writerow(row) | ||
| self._counters.add_counter('sampler-header-rows', 1) | ||
| # After processing all header rows, validate that all | ||
|
|
@@ -425,9 +498,40 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str: | |
| return output_file | ||
|
|
||
|
|
||
| def load_column_keys(column_keys: list) -> dict: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need to strip namespace here like it's done in _is_must_include? |
||
| """Returns a dictionary of column name to set of keys loaded from a file. | ||
| The set of keys for a column are used as filter when sampling. | ||
|
|
||
| Args: | ||
| column_keys: comma separated list of column_name:<csv file> with | ||
| first column as the keys to be loaded. | ||
|
|
||
| Returns: | ||
| dictionary of column name to a set of keys for that column | ||
| { <column-name1>: { key1, key2, ...}, <column-name2>: { k1, k2...} ...} | ||
| """ | ||
| column_map = {} | ||
| if not isinstance(column_keys, list): | ||
| column_keys = column_keys.split(',') | ||
|
|
||
| for col_file in column_keys: | ||
| column_name, file_name = col_file.split(':', 1) | ||
| if not file_name: | ||
| logging.error(f'No file for column {column_name} in {column_keys}') | ||
| continue | ||
|
|
||
| col_items = file_util.file_load_csv_dict(file_name) | ||
| column_map[column_name] = set(col_items.keys()) | ||
| logging.info( | ||
| f'Loaded {len(col_items)} for column {column_name} from {file_name}' | ||
| ) | ||
| return column_map | ||
|
|
||
|
|
||
| def sample_csv_file(input_file: str, | ||
| output_file: str = '', | ||
| config: dict = None) -> str: | ||
| config: dict = None, | ||
| counters: Counters = None) -> str: | ||
| """Samples a CSV file and returns the path to the sampled file. | ||
|
|
||
| This function provides a convenient way to sample a CSV file with a single | ||
|
|
@@ -443,6 +547,8 @@ def sample_csv_file(input_file: str, | |
| - sampler_output_rows: The maximum number of rows to include in the | ||
| sample. | ||
| - sampler_rate: The sampling rate to use for random selection. | ||
| - sampler_exhaustive: If True, overrides limits to capture all unique | ||
| values. | ||
| - header_rows: The number of header rows to copy from the input file | ||
| and search for sampler_unique_columns. Increase this if column names | ||
| appear in later header rows (e.g., after a title row). | ||
|
|
@@ -455,6 +561,7 @@ def sample_csv_file(input_file: str, | |
| - input_delimiter: The delimiter used in the input file. | ||
| - output_delimiter: The delimiter to use in the output file. | ||
| - input_encoding: The encoding of the input file. | ||
| counters: optional Counters object to get counts of sampling. | ||
|
|
||
| Returns: | ||
| The path to the output file with the sampled rows. | ||
|
|
@@ -484,7 +591,7 @@ def sample_csv_file(input_file: str, | |
| """ | ||
| if config is None: | ||
| config = {} | ||
| data_sampler = DataSampler(config_dict=config) | ||
| data_sampler = DataSampler(config_dict=config, counters=counters) | ||
| return data_sampler.sample_csv_file(input_file, output_file) | ||
|
|
||
|
|
||
|
|
@@ -500,23 +607,31 @@ def get_default_config() -> dict: | |
| # Use default values of flags for tests | ||
| if not _FLAGS.is_parsed(): | ||
| _FLAGS.mark_as_parsed() | ||
| return { | ||
|
|
||
| config = { | ||
| 'sampler_rate': _FLAGS.sampler_rate, | ||
| 'sampler_input': _FLAGS.sampler_input, | ||
| 'sampler_output': _FLAGS.sampler_output, | ||
| 'sampler_output_rows': _FLAGS.sampler_output_rows, | ||
| 'header_rows': _FLAGS.sampler_header_rows, | ||
| 'sampler_rows_per_key': _FLAGS.sampler_rows_per_key, | ||
| 'sampler_uniques_per_column': _FLAGS.sampler_uniques_per_column, | ||
| 'sampler_column_regex': _FLAGS.sampler_column_regex, | ||
| 'sampler_unique_columns': _FLAGS.sampler_unique_columns, | ||
| 'sampler_column_keys': _FLAGS.sampler_column_keys, | ||
| 'input_delimiter': _FLAGS.sampler_input_delimiter, | ||
| 'output_delimiter': _FLAGS.sampler_output_delimiter, | ||
| 'input_encoding': _FLAGS.sampler_input_encoding, | ||
| 'sampler_exhaustive': _FLAGS.sampler_exhaustive, | ||
| } | ||
|
|
||
| return config | ||
ajaits marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def main(_): | ||
| sample_csv_file(_FLAGS.sampler_input, _FLAGS.sampler_output) | ||
| counters = Counters() | ||
| sample_csv_file(_FLAGS.sampler_input, _FLAGS.sampler_output, counters=counters) | ||
| counters.print_counters() | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems we can have a combined flag to track both..do we necessarily need two separate flags for this?