Skip to content
Merged
1 change: 1 addition & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ defaults:
jobs:
build_sdist:
name: Build an sdist and determine versions
if: ${{ github.ref != 'refs/heads/main' }}
uses: ./.github/workflows/packaging_sdist.yml
with:
testsuite: all
Expand Down
31 changes: 21 additions & 10 deletions duckdb/experimental/spark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,8 +1067,7 @@ def union(self, other: "DataFrame") -> "DataFrame":
unionAll = union

def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame":
"""Returns a new :class:`DataFrame` containing union of rows in this and another
:class:`DataFrame`.
"""Returns a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`.

This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set
union (that does deduplication of elements), use this function followed by :func:`distinct`.
Expand Down Expand Up @@ -1121,15 +1120,27 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) ->
| 1| 2| 3|NULL|
|NULL| 4| 5| 6|
+----+----+----+----+
""" # noqa: D205
"""
if allowMissingColumns:
cols = []
for col in self.relation.columns:
if col in other.relation.columns:
cols.append(col)
else:
cols.append(spark_sql_functions.lit(None))
other = other.select(*cols)
df1 = self.select(
*self.relation.columns,
*[
spark_sql_functions.lit(None).alias(c)
for c in other.relation.columns
if c not in self.relation.columns
],
)

df2 = other.select(
*[
spark_sql_functions.col(c)
if c in other.relation.columns
else spark_sql_functions.lit(None).alias(c)
for c in df1.relation.columns
]
)

return df1.unionByName(df2, allowMissingColumns=False)
else:
other = other.select(*self.relation.columns)

Expand Down
15 changes: 15 additions & 0 deletions tests/fast/spark/test_spark_union_by_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,18 @@ def test_union_by_name_allow_missing_cols(self, df1, df2):
Row(name="Jeff", id=None),
]
assert res == expected

def test_union_by_name_allow_missing_cols_rev(self, df1, df2):
rel = df2.drop("id").unionByName(df1, allowMissingColumns=True)
res = rel.collect()
expected = [
Row(name="James", id=None),
Row(name="Maria", id=None),
Row(name="Jen", id=None),
Row(name="Jeff", id=None),
Row(name="James", id=34),
Row(name="Michael", id=56),
Row(name="Robert", id=30),
Row(name="Maria", id=24),
]
assert res == expected