diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 666fd2c1cc5..af495ae4c2f 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -412,6 +412,10 @@ cdef class Dataset(_Weakrefable): n_legs: [[2,4,4,100]] animal: [["Parrot","Dog","Horse","Centipede"]] """ + # Apply column projection from rename_columns() if present + if columns is None and 'columns' in self._scan_options: + columns = self._scan_options['columns'] + return Scanner.from_dataset( self, columns=columns, @@ -990,6 +994,81 @@ cdef class Dataset(_Weakrefable): right_dataset, right_on, right_by, tolerance, output_type=InMemoryDataset) + def rename_columns(self, names): + """ + Apply logical column renaming on the Dataset. + + The rename is applied lazily when data is scanned. Column names in the + files are not changed; the rename is a logical transformation applied + during reads. + + Parameters + ---------- + names : list, tuple, or dict + If a list or tuple, the new names for all columns (must match the + number of columns). If a dict, maps old column names to new names. + + Returns + ------- + Dataset + The existing dataset with column projection applied. + + Examples + -------- + Rename all columns by position: + + >>> import pyarrow as pa + >>> table = pa.table({'year': [2020, 2022, 2021, 2022, 2019, 2021], + ... 'n_legs': [2, 2, 4, 4, 5, 100], + ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", + ... "Brittle stars", "Centipede"]}) + + >>> import pyarrow.dataset as ds + >>> dataset = ds.InMemoryDataset([table]) + >>> dataset.rename_columns(['time', 'number_of_legs', 'name']).to_table() + pyarrow.Table + time: int64 + number_of_legs: int64 + name: string + ---- + time: [[2020,2022,2021,2022,2019,2021]] + number_of_legs: [[2,2,4,4,5,100]] + name: [["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"]] + + Rename specific columns: + + >>> dataset.rename_columns({'n_legs': 'number_of_legs'}).to_table() + pyarrow.Table + year: int64 + number_of_legs: int64 + animal: string + ---- + year: [[2020,2022,2021,2022,2019,2021]] + number_of_legs: [[2,2,4,4,5,100]] + animal: [["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"]] + """ + import pyarrow.dataset as ds + + schema = self.schema + + if isinstance(names, (list, tuple)): + if len(names) != len(schema): + raise ValueError( + f"Expected {len(schema)} names, got {len(names)}") + name_mapping = {schema.field(i).name: names[i] + for i in range(len(names))} + elif isinstance(names, dict): + name_mapping = {field.name: names.get(field.name, field.name) + for field in schema} + else: + raise TypeError(f"names must be list, tuple, or dict, not {type(names)!r}") + + projection = {new_name: ds.field(old_name) + for old_name, new_name in name_mapping.items()} + + self._scan_options['columns'] = projection + + return self cdef class InMemoryDataset(Dataset): """ diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 32bcebb28de..3e7e58fdb4e 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -5932,3 +5932,36 @@ def test_scanner_from_substrait(dataset): filter=ps.BoundExpressions.from_substrait(filtering) ).to_table() assert result.to_pydict() == {'str': ['4', '4']} + + +@pytest.mark.parametrize("names, expected_schema", [ + (["new-index", "new-color"], + pa.schema([pa.field("new-index", pa.int64()), + pa.field("new-color", pa.string())])), + (("new-index", "new-color"), + pa.schema([pa.field("new-index", pa.int64()), + pa.field("new-color", pa.string())])), + ({"index": "new-index", "color": "new-color"}, + pa.schema([pa.field("new-index", pa.int64()), + pa.field("new-color", pa.string())])), + ({"index": "new-index"}, + pa.schema([pa.field("new-index", pa.int64()), + pa.field("color", pa.string())])), +] +) +def test_rename_columns(names, expected_schema): + original_schema = pa.schema([ + pa.field('index', pa.int64()), + pa.field('color', pa.string()), + ] + ) + + dataset = ds.InMemoryDataset( + pa.RecordBatch.from_pylist( + [{"index": 1, "color": "green"}, {"index": 2, "color": "blue"}]), + schema=original_schema + ) + + dataset.rename_columns(names) + + assert dataset.to_table().schema.equals(expected_schema)