diff --git a/dco/Jakiur Rahman.dco b/dco/Jakiur Rahman.dco new file mode 100644 index 00000000..d1b91ca2 --- /dev/null +++ b/dco/Jakiur Rahman.dco @@ -0,0 +1,9 @@ +1) I, Jakiur Rahman, certify that all work committed with the commit message +"covered by: .dco" is my original work and I own the copyright +to this work. I agree to contribute this code under the Apache 2.0 license. + +2) I understand and agree all contribution including all personal +information I submit with it is maintained indefinitely and may be +redistributed consistent with the open source license(s) involved. + +This certification is effective for all code contributed from 2025-10-02 to 9999-01-01. diff --git a/gs_quant/data/dataset.py b/gs_quant/data/dataset.py index dc2e393e..ac300c5d 100644 --- a/gs_quant/data/dataset.py +++ b/gs_quant/data/dataset.py @@ -133,14 +133,50 @@ def _build_data_query( return self.provider.build_query(start=start, end=end, as_of=as_of, since=since, fields=field_names, empty_intervals=empty_intervals, **kwargs), schema_varies - def _build_data_frame(self, data, schema_varies, standard_fields) -> pd.DataFrame: + def _filter_df_by_fields(self, df: pd.DataFrame, field_names=None) -> pd.DataFrame: + """ + If caller requested specific fields, reduce to only those fields present in df. + field_names: original requested fields (list or None) + """ + # If no explicit requested fields, return as-is + if not field_names: + return df + + # Use inflection if available, otherwise fallback to a simple underscore converter + def _underscore(name: str) -> str: + if inflection: + return inflection.underscore(name) + # lightweight fallback: replace CamelCase and spaced chars + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + s2 = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1) + return s2.replace('-', '_').replace(' ', '_').lower() + + sanitized = [] + for f in field_names: + # sanitize function-style names like diff(foo) -> diff_foo + fname = f.replace('(', '_').replace(')', '') + # add underscore between letter+digit if present + fname = re.sub(r'([a-zA-Z])(\d)', r'\1_\2', fname) + sanitized.append(_underscore(fname)) + + # Only select columns that actually exist in the DataFrame + selected = [c for c in sanitized if c in df.columns] + if selected: + return df.loc[:, selected] + + # If nothing matches, preserve original df (makes the change non-breaking) + return df + + def _build_data_frame(self, data, schema_varies, standard_fields, field_names=None) -> pd.DataFrame: if type(data) is tuple: df = self.provider.construct_dataframe_with_types(self.id, data[0], schema_varies, standard_fields=standard_fields) + df = self._filter_df_by_fields(df, field_names) return df.groupby(data[1], group_keys=True).apply(lambda x: x) else: - return self.provider.construct_dataframe_with_types(self.id, data, schema_varies, - standard_fields=standard_fields) + df = self.provider.construct_dataframe_with_types(self.id, data, schema_varies, + standard_fields=standard_fields) + return self._filter_df_by_fields(df, field_names) def get_data( self, diff --git a/gs_quant/test/data/test_dataset.py b/gs_quant/test/data/test_dataset.py index fcaa5578..f6cf1e92 100644 --- a/gs_quant/test/data/test_dataset.py +++ b/gs_quant/test/data/test_dataset.py @@ -418,6 +418,42 @@ def handler(data_frame): assert mock.call_count == 1 assert df.equals(df2) +def test_get_data_filters_to_requested_fields(): + class DummyProvider: + def build_query(self, *args, **kwargs): + return object() + + def query_data(self, query, dataset_id, asset_id_type=None): + return [ + {'date': dt.date(2020, 1, 1), 'city': 'Boston', 'maxTemperature': 10.0, 'minTemperature': 2.0, 'extraField': None}, + {'date': dt.date(2020, 1, 2), 'city': 'Boston', 'maxTemperature': 12.0, 'minTemperature': 3.0, 'extraField': None}, + ] + + def construct_dataframe_with_types(self, dataset_id, data, schema_varies=False, standard_fields=False): + df = pd.DataFrame(data) + rename_map = {} + if 'maxTemperature' in df.columns: + rename_map['maxTemperature'] = 'max_temperature' + if 'minTemperature' in df.columns: + rename_map['minTemperature'] = 'min_temperature' + df = df.rename(columns=rename_map) + # Keep only the requested fields plus date which is required + df = df[['date', 'max_temperature']] + return df + + provider = DummyProvider() + ds = Dataset('DUMMY', provider=provider) + + # Request only 'maxTemperature' (user-facing name); provider returns snake_cased column + res = ds.get_data(start=dt.date(2020, 1, 1), end=dt.date(2020, 1, 2), fields=['maxTemperature']) + + # Columns should contain only the requested snake_cased field + assert 'max_temperature' in res.columns + assert 'min_temperature' not in res.columns + assert 'extraField' not in res.columns and 'extra_field' not in res.columns + + # Verify values preserved + assert res['max_temperature'].iloc[0] == 10.0 if __name__ == "__main__": pytest.main(args=["test_dataset.py"])