Skip to content

Commit bb03347

Browse files
Update to_parquet to handle custom schema (to fix writing partitions with all missing data) (#201)
1 parent 35be560 commit bb03347

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

dask_geopandas/io/arrow.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,13 @@ def _pandas_to_arrow_table(
191191

192192
table = _geopandas_to_arrow(df, index=preserve_index)
193193

194-
# TODO add support for schema
195-
# (but let it already pass if the passed schema would not change the result)
196194
if schema is not None:
197-
if not table.schema.equals(schema) and len(df):
198-
raise NotImplementedError("Passing 'schema' is not yet supported")
195+
if not table.schema.equals(schema):
196+
# table.schema.metadata contains the "geo" metadata, so
197+
# ensure to preserve this in the cast operation
198+
if table.schema.metadata and not schema.metadata:
199+
schema = schema.with_metadata(table.schema.metadata)
200+
table = table.cast(schema)
199201

200202
return table
201203

dask_geopandas/tests/io/test_parquet.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,25 @@ def test_parquet_empty_partitions(tmp_path):
161161
assert result.spatial_partitions is None
162162

163163

164+
@pytest.mark.skipif(
165+
not Version(dask.__version__) >= Version("2022.06.0"),
166+
reason="Only works with dask 2022.06.0 or up",
167+
)
168+
def test_parquet_partitions_with_all_missing_strings(tmp_path):
169+
df = geopandas.GeoDataFrame(
170+
{"col": ["a", "b", None, None]},
171+
geometry=geopandas.points_from_xy([0, 1, 2, 3], [0, 1, 2, 3]),
172+
)
173+
# Creating filtered dask dataframe with at least one empty partition
174+
ddf = dask_geopandas.from_geopandas(df, npartitions=2)
175+
176+
basedir = tmp_path / "dataset"
177+
ddf.to_parquet(basedir)
178+
179+
result = dask_geopandas.read_parquet(basedir)
180+
assert_geodataframe_equal(result.compute(), df)
181+
182+
164183
@pytest.mark.skipif(
165184
Version(dask.__version__) < Version("2021.10.0"),
166185
reason="Only correct error message with dask 2021.10.0 or up",

0 commit comments

Comments (0)