Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 135 additions & 20 deletions tools/statvar_importer/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@
flags.DEFINE_integer(
'sampler_rows_per_key', 5,
'The maximum number of rows to select for each unique value found.')
flags.DEFINE_integer(
'sampler_uniques_per_column', 10,
'The maximum number of unique values to track per column. '
'If 0 or -1, all unique values are tracked.')
flags.DEFINE_boolean(
'sampler_exhaustive', False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we could have a single combined flag to control both. Do we really need two separate flags for this?

'If True, sets sampler_output_rows and sampler_uniques_per_column to '
'infinity, and sampler_rows_per_key to 1, to capture every unique value.')
flags.DEFINE_float(
'sampler_rate', -1,
'The sampling rate for random row selection (e.g., 0.1 for 10%).')
Expand All @@ -65,6 +73,11 @@
flags.DEFINE_string(
'sampler_unique_columns', '',
'A comma-separated list of column names to use for selecting unique rows.')
flags.DEFINE_list(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also update the comment to describe the default behavior (empty)?

'sampler_column_keys', [],
'A list of "column:file" pairs containing values that MUST be included '
'in the sample if they appear in the input data. '
'Example: "variableMeasured:prominent_svs.txt"')
flags.DEFINE_string('sampler_input_delimiter', ',',
'The delimiter used in the input CSV file.')
flags.DEFINE_string('sampler_input_encoding', 'UTF8',
Expand All @@ -75,6 +88,7 @@
_FLAGS = flags.FLAGS

import file_util
import mcf_file_util

from config_map import ConfigMap
from counters import Counters
Expand Down Expand Up @@ -103,27 +117,37 @@ def __init__(
self,
config_dict: dict = None,
counters: Counters = None,
column_include_values: dict = None,
):
"""Initializes the DataSampler object.

Args:
config_dict: A dictionary of configuration parameters.
counters: A Counters object for tracking statistics.
column_include_values: a dictionary of column-name to set of values
in the column to be included in the sample
"""
self._config = ConfigMap(config_dict=get_default_config())
if config_dict:
self._config.add_configs(config_dict)
self._counters = counters if counters is not None else Counters()
self._column_include_values = column_include_values
self.reset()

def reset(self) -> None:
"""Resets the state of the DataSampler.

This method resets the internal state of the DataSampler, including the
counts of unique column values and the number of selected rows. This is
useful when you want to reuse the same DataSampler instance for sampling
multiple files.
counts of unique column values and the number of selected rows. If
sampler_exhaustive is set in the configuration, it applies overrides
to other configuration parameters to capture all unique values.
"""
if self._config.get('sampler_exhaustive'):
# Exhaustive mode overrides limits to capture all unique values.
self._config.set_config('sampler_output_rows', -1)
self._config.set_config('sampler_uniques_per_column', -1)
self._config.set_config('sampler_rows_per_key', 1)

# Dictionary of unique values: count per column
self._column_counts = {}
# Dictionary of column index: list of header strings
Expand All @@ -144,6 +168,16 @@ def reset(self) -> None:
if col.strip()
]

# Must include values: dict of column_name -> set of values
self._must_include_values = load_column_keys(
self._config.get('sampler_column_keys', []))
if self._column_include_values:
for col, vals in self._column_include_values.items():
self._must_include_values.setdefault(col, set()).update(vals)

# Map of column index -> set of values
self._must_include_indices = {}

def __del__(self) -> None:
"""Logs the column headers and counts upon object deletion."""
logging.log(2, f'Sampler column headers: {self._column_headers}')
Expand Down Expand Up @@ -179,22 +213,35 @@ def _get_column_count(self, column_index: int, value: str) -> int:
return 0
return col_values.get(value, 0)

def _should_track_column(self, column_index: int) -> bool:
"""Determines if a column should be tracked for unique values.
def _is_unique_column(self, column_index: int) -> bool:
"""Determines if a column is specified for unique value sampling.

Args:
column_index: The index of the column.

Returns:
True if the column should be tracked (either no unique columns
specified or this column is in the unique columns list).
True if the column should be sampled for unique values.
"""
if not self._unique_column_names:
# No specific columns specified, track all
# No specific columns specified, track all for unique sampling
return True
# Check if this column is in our unique columns
return column_index in self._unique_column_indices.values()

def _should_track_column(self, column_index: int) -> bool:
"""Determines if a column should be tracked for counts.

Args:
column_index: The index of the column.

Returns:
True if the column should be tracked for unique values or is a
must-include column.
"""
if self._is_unique_column(column_index):
return True
return column_index in self._must_include_indices

def _process_header_row(self, row: list[str]) -> None:
"""Process a header row to build column name to index mapping.

Expand All @@ -206,15 +253,27 @@ def _process_header_row(self, row: list[str]) -> None:
Args:
row: A header row containing column names.
"""
if not self._unique_column_names:
return

for index, column_name in enumerate(row):
if column_name in self._unique_column_names:
if self._unique_column_names and column_name in self._unique_column_names:
self._unique_column_indices[column_name] = index
logging.level_debug() and logging.debug(
f'Mapped unique column "{column_name}" to index {index}')

if self._must_include_values and column_name in self._must_include_values:
self._must_include_indices[index] = self._must_include_values[
column_name]
logging.info(
f'Mapped must-include column "{column_name}" to index {index}'
)

def _is_must_include(self, column_index: int, value: str) -> bool:
    """Returns True if the (column, value) pair is flagged as must-include."""
    required_values = self._must_include_indices.get(column_index)
    if required_values is None:
        # Column has no must-include key set configured.
        return False
    # Strip any namespace prefix (e.g. 'dcid:') so the lookup matches the
    # plain keys stored in the must-include set.
    return mcf_file_util.strip_namespace(value) in required_values

def _add_column_header(self, column_index: int, value: str) -> str:
"""Adds the first non-empty value of a column as its header.

Expand Down Expand Up @@ -282,13 +341,26 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
# Too many rows already selected. Drop it.
return False
max_count = self._config.get('sampler_rows_per_key', 3)
if max_count <= 0:
max_count = sys.maxsize
max_uniques_per_col = self._config.get('sampler_uniques_per_column', 10)
if max_uniques_per_col <= 0:
max_uniques_per_col = sys.maxsize

for index in range(len(row)):
# Skip columns not in unique_columns list
if not self._should_track_column(index):
continue
value = row[index]
value_count = self._get_column_count(index, value)

# Rule 1: Always include if it's a must-include value and
# we haven't reached per-key limit.
if value_count < max_count and self._is_must_include(index, value):
self._counters.add_counter('sampler-selected-must-include', 1)
return True

# Skip columns not in unique_columns list for general unique sampling
if not self._is_unique_column(index):
continue

if value_count == 0 or value_count < max_count:
# This is a new value for this column.
col_counts = self._column_counts.get(index, {})
Expand All @@ -301,7 +373,7 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
# No new unique value for the row.
# Check random sampler.
if sample_rate < 0:
sample_rate = self._config.get('sampler_rate')
sample_rate = self._config.get('sampler_rate', -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If sample_rate is negative, we are setting its default to a negative value again. Isn't that contradictory?

if random.random() <= sample_rate:
self._counters.add_counter('sampler-sampled-rows', 1)
return True
Expand Down Expand Up @@ -390,6 +462,7 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str:
# Process and write header rows from the first input file.
if row_index <= header_rows and input_index == 0:
self._process_header_row(row)
self._add_row_counts(row)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious where is this defined?

csv_writer.writerow(row)
self._counters.add_counter('sampler-header-rows', 1)
# After processing all header rows, validate that all
Expand Down Expand Up @@ -425,9 +498,40 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str:
return output_file


def load_column_keys(column_keys: list) -> dict:
    """Returns a dictionary of column name to set of keys loaded from files.

    The set of keys for a column is used as a filter when sampling: rows
    whose value in that column matches a key must be included in the sample.

    Args:
        column_keys: a list (or comma-separated string) of
            '<column-name>:<csv-file>' entries, where the first column of
            each CSV file contains the keys to be loaded.

    Returns:
        dictionary of column name to a set of keys for that column:
        { <column-name1>: { key1, key2, ...}, <column-name2>: { k1, k2...} ...}
        Malformed entries (no ':' separator or empty file name) are skipped
        with an error log instead of raising.
    """
    column_map = {}
    # Accept a single comma-separated string as well as a list of entries.
    if not isinstance(column_keys, list):
        column_keys = column_keys.split(',')

    for col_file in column_keys:
        # partition() never raises, unlike split(':', 1) which would raise
        # ValueError on an entry with no ':' separator.
        column_name, sep, file_name = col_file.partition(':')
        if not sep or not file_name:
            logging.error(f'No file for column {column_name} in {column_keys}')
            continue

        # file_load_csv_dict returns a dict keyed by the file's first column;
        # only the keys are needed for the must-include filter.
        col_items = file_util.file_load_csv_dict(file_name)
        column_map[column_name] = set(col_items.keys())
        logging.info(
            f'Loaded {len(col_items)} for column {column_name} from {file_name}'
        )
    return column_map


def sample_csv_file(input_file: str,
output_file: str = '',
config: dict = None) -> str:
config: dict = None,
counters: Counters = None) -> str:
"""Samples a CSV file and returns the path to the sampled file.

This function provides a convenient way to sample a CSV file with a single
Expand All @@ -443,6 +547,8 @@ def sample_csv_file(input_file: str,
- sampler_output_rows: The maximum number of rows to include in the
sample.
- sampler_rate: The sampling rate to use for random selection.
- sampler_exhaustive: If True, overrides limits to capture all unique
values.
- header_rows: The number of header rows to copy from the input file
and search for sampler_unique_columns. Increase this if column names
appear in later header rows (e.g., after a title row).
Expand All @@ -455,6 +561,7 @@ def sample_csv_file(input_file: str,
- input_delimiter: The delimiter used in the input file.
- output_delimiter: The delimiter to use in the output file.
- input_encoding: The encoding of the input file.
counters: optional Counters object to get counts of sampling.

Returns:
The path to the output file with the sampled rows.
Expand Down Expand Up @@ -484,7 +591,7 @@ def sample_csv_file(input_file: str,
"""
if config is None:
config = {}
data_sampler = DataSampler(config_dict=config)
data_sampler = DataSampler(config_dict=config, counters=counters)
return data_sampler.sample_csv_file(input_file, output_file)


Expand All @@ -500,23 +607,31 @@ def get_default_config() -> dict:
# Use default values of flags for tests
if not _FLAGS.is_parsed():
_FLAGS.mark_as_parsed()
return {

config = {
'sampler_rate': _FLAGS.sampler_rate,
'sampler_input': _FLAGS.sampler_input,
'sampler_output': _FLAGS.sampler_output,
'sampler_output_rows': _FLAGS.sampler_output_rows,
'header_rows': _FLAGS.sampler_header_rows,
'sampler_rows_per_key': _FLAGS.sampler_rows_per_key,
'sampler_uniques_per_column': _FLAGS.sampler_uniques_per_column,
'sampler_column_regex': _FLAGS.sampler_column_regex,
'sampler_unique_columns': _FLAGS.sampler_unique_columns,
'sampler_column_keys': _FLAGS.sampler_column_keys,
'input_delimiter': _FLAGS.sampler_input_delimiter,
'output_delimiter': _FLAGS.sampler_output_delimiter,
'input_encoding': _FLAGS.sampler_input_encoding,
'sampler_exhaustive': _FLAGS.sampler_exhaustive,
}

return config


def main(_):
    """Entry point: samples the flag-specified input CSV and prints counters."""
    run_counters = Counters()
    sample_csv_file(_FLAGS.sampler_input,
                    _FLAGS.sampler_output,
                    counters=run_counters)
    run_counters.print_counters()
counters.print_counters()


if __name__ == '__main__':
Expand Down
Loading
Loading