Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dco/Jakiur Rahman.dco
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
1) I, Jakiur Rahman, certify that all work committed with the commit message
"covered by: Jakiur Rahman.dco" is my original work and I own the copyright
to this work. I agree to contribute this code under the Apache 2.0 license.

2) I understand and agree that all contributions, including all personal
information I submit with them, are maintained indefinitely and may be
redistributed consistent with the open source license(s) involved.

This certification is effective for all code contributed from 2025-10-02 to 9999-01-01.
42 changes: 39 additions & 3 deletions gs_quant/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,50 @@ def _build_data_query(
return self.provider.build_query(start=start, end=end, as_of=as_of, since=since, fields=field_names,
empty_intervals=empty_intervals, **kwargs), schema_varies

def _build_data_frame(self, data, schema_varies, standard_fields) -> pd.DataFrame:
def _filter_df_by_fields(self, df: pd.DataFrame, field_names=None) -> pd.DataFrame:
    """Restrict ``df`` to the columns that correspond to the requested fields.

    Each requested name is normalised before it is compared with the frame's
    columns: function-style names such as ``diff(foo)`` become ``diff_foo``,
    an underscore is inserted between a letter and a following digit, and the
    result is converted to snake_case.

    :param df: frame to narrow down
    :param field_names: requested field names; a single string is treated as
        one field, and objects exposing a ``value`` attribute are unwrapped
    :return: ``df`` itself when no fields were requested or when none of the
        normalised names is a column of ``df`` (keeps the change non-breaking);
        otherwise a frame holding only the matching columns, in request order
        and without duplicates
    """
    if not field_names:
        return df
    if isinstance(field_names, str):
        # A bare string would otherwise be iterated character by character
        field_names = [field_names]

    # Resolve inflection defensively so the regex fallback below is actually
    # reachable when the package is not installed
    try:
        import inflection as _inflection
    except ImportError:
        _inflection = None

    def _underscore(name: str) -> str:
        if _inflection is not None:
            return _inflection.underscore(name)
        # Lightweight fallback: split CamelCase humps, then normalise separators
        partial = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        partial = re.sub('([a-z0-9])([A-Z])', r'\1_\2', partial)
        return partial.replace('-', '_').replace(' ', '_').lower()

    # A dict is used as an ordered set: duplicate requests must not yield
    # duplicate columns in the result
    wanted = {}
    for field in field_names:
        raw = str(getattr(field, 'value', field))
        # Function-style names like diff(foo) -> diff_foo
        raw = raw.replace('(', '_').replace(')', '')
        # Underscore between a letter and the digit that follows it
        raw = re.sub(r'([a-zA-Z])(\d)', r'\1_\2', raw)
        wanted[_underscore(raw)] = None

    # Only select columns that actually exist in the DataFrame
    selected = [column for column in wanted if column in df.columns]
    return df.loc[:, selected] if selected else df

def _build_data_frame(self, data, schema_varies, standard_fields, field_names=None) -> pd.DataFrame:
    """Turn raw provider output into a DataFrame limited to the requested fields.

    :param data: either the raw rows, or a tuple whose first element is the raw
        rows and whose second element is passed straight to ``DataFrame.groupby``
    :param schema_varies: forwarded to the provider's
        ``construct_dataframe_with_types``
    :param standard_fields: forwarded to the provider's
        ``construct_dataframe_with_types``
    :param field_names: fields originally requested by the caller, or None to
        keep every column
    :return: the constructed (and, for tuple input, grouped) frame after
        ``_filter_df_by_fields`` has been applied
    """
    is_grouped = isinstance(data, tuple)
    rows = data[0] if is_grouped else data
    df = self.provider.construct_dataframe_with_types(self.id, rows, schema_varies,
                                                      standard_fields=standard_fields)
    if is_grouped:
        # Group before filtering: the grouping keys may not be among the
        # requested fields, and dropping them first would make groupby fail
        df = df.groupby(data[1], group_keys=True).apply(lambda x: x)
    return self._filter_df_by_fields(df, field_names)

def get_data(
self,
Expand Down
36 changes: 36 additions & 0 deletions gs_quant/test/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,42 @@ def handler(data_frame):
assert mock.call_count == 1
assert df.equals(df2)

def test_get_data_filters_to_requested_fields():
    """Dataset.get_data should return only the columns for the requested fields.

    The stub provider deliberately hands back every column, so any narrowing
    seen in the result is performed by Dataset and not by the stub itself.
    """

    class DummyProvider:
        """Minimal in-memory stand-in for a data provider."""

        def build_query(self, *args, **kwargs):
            return object()

        def query_data(self, query, dataset_id, asset_id_type=None):
            return [
                {
                    'date': dt.date(2020, 1, 1),
                    'city': 'Boston',
                    'maxTemperature': 10.0,
                    'minTemperature': 2.0,
                    'extraField': None,
                },
                {
                    'date': dt.date(2020, 1, 2),
                    'city': 'Boston',
                    'maxTemperature': 12.0,
                    'minTemperature': 3.0,
                    'extraField': None,
                },
            ]

        def construct_dataframe_with_types(self, dataset_id, data, schema_varies=False, standard_fields=False):
            # Snake-case the temperature columns but keep every column in place;
            # trimming here would make the assertions below pass vacuously
            return pd.DataFrame(data).rename(columns={
                'maxTemperature': 'max_temperature',
                'minTemperature': 'min_temperature',
            })

    provider = DummyProvider()
    ds = Dataset('DUMMY', provider=provider)

    # Request only 'maxTemperature' (user-facing name); provider returns snake_cased column
    res = ds.get_data(start=dt.date(2020, 1, 1), end=dt.date(2020, 1, 2), fields=['maxTemperature'])

    # Columns should contain only the requested snake_cased field
    assert 'max_temperature' in res.columns
    assert 'min_temperature' not in res.columns
    assert 'extraField' not in res.columns and 'extra_field' not in res.columns

    # Verify values preserved
    assert res['max_temperature'].iloc[0] == 10.0

# Allow running this test module directly as a script
if __name__ == '__main__':
    pytest.main(['test_dataset.py'])