[SPARK-46753][PYTHON][TESTS] Fix pypy3 python test
### What changes were proposed in this pull request?
This PR fixes the `pypy3` Python tests.
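
Concretely, tests that cannot run under PyPy, or that need optional dependencies (pandas, numpy, pyarrow, Spark Connect), are now guarded with `unittest.skipIf`. A minimal, self-contained sketch of the pattern follows; the probe, class, and test names are illustrative only and not part of this change:

```python
import platform
import unittest

# Optional-dependency probe, mirroring flags such as have_pandas / have_pyarrow
# in pyspark.testing.sqlutils. The name `have_pyarrow` here is illustrative.
try:
    import pyarrow  # noqa: F401

    have_pyarrow = True
except ImportError:
    have_pyarrow = False


class ExampleSuite(unittest.TestCase):
    @unittest.skipIf(
        "pypy" in platform.python_implementation().lower(),
        "cannot run in environment pypy",
    )
    def test_cpython_only(self):
        # e.g. the faulthandler/segfault test, which PyPy cannot run
        self.assertTrue(True)

    @unittest.skipIf(not have_pyarrow, "pyarrow is required for this test")
    def test_needs_pyarrow(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```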

### Why are the changes needed?
The scheduled job that runs the tests with PyPy3 currently fails; fixing it improves test coverage.

### Does this PR introduce _any_ user-facing change?
No, test-only.

### How was this patch tested?
- Passed GitHub Actions (GA).
- Tested manually.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #44778 from panbingkun/SPARK-46753.

Lead-authored-by: panbingkun <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
panbingkun and HyukjinKwon committed Jan 31, 2024
1 parent 5d87ac6 commit 0871a6f
Showing 6 changed files with 96 additions and 64 deletions.
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests/test_session.py
@@ -24,6 +24,10 @@
from pyspark.errors import PySparkRuntimeError
from pyspark.sql import SparkSession, SQLContext, Row
from pyspark.sql.functions import col
from pyspark.testing.connectutils import (
should_test_connect,
connect_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import PySparkTestCase, PySparkErrorTestUtils

@@ -213,6 +217,7 @@ def test_active_session_with_None_and_not_None_context(self):
if sc is not None:
sc.stop()

@unittest.skipIf(not should_test_connect, connect_requirement_message)
def test_session_with_spark_connect_mode_enabled(self):
with unittest.mock.patch.dict(os.environ, {"SPARK_CONNECT_MODE_ENABLED": "1"}):
with self.assertRaisesRegex(RuntimeError, "Cannot create a Spark Connect session"):
@@ -454,6 +459,7 @@ def test_master_remote_conflicts(self):
del os.environ["SPARK_REMOTE"]
del os.environ["SPARK_LOCAL_REMOTE"]

@unittest.skipIf(not should_test_connect, connect_requirement_message)
def test_invalid_create(self):
with self.assertRaises(PySparkRuntimeError) as pe2:
SparkSession.builder.config("spark.remote", "local").create()
10 changes: 9 additions & 1 deletion python/pyspark/sql/tests/test_udf.py
@@ -16,6 +16,7 @@
#

import functools
import platform
import pydoc
import shutil
import tempfile
@@ -39,7 +40,11 @@
DayTimeIntervalType,
)
from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
test_compiled,
test_not_compiled_message,
)
from pyspark.testing.utils import QuietTest, assertDataFrameEqual


@@ -1039,6 +1044,9 @@ def test_udf(a):
with self.assertRaisesRegex(PythonException, "StopIteration"):
self.spark.range(10).select(test_udf(col("id"))).show()

@unittest.skipIf(
"pypy" in platform.python_implementation().lower(), "cannot run in environment pypy"
)
def test_python_udf_segfault(self):
with self.sql_conf({"spark.sql.execution.pyspark.udf.faulthandler.enabled": True}):
with self.assertRaisesRegex(Exception, "Segmentation fault"):
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/test_udf_profiler.py
@@ -135,6 +135,7 @@ def map(pdfs: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
df = self.spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))
df.mapInPandas(map, schema=df.schema).collect()

@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
def test_unsupported(self):
with warnings.catch_warnings(record=True) as warns:
warnings.simplefilter("always")
35 changes: 25 additions & 10 deletions python/pyspark/sql/tests/test_utils.py
@@ -45,7 +45,7 @@
IntegerType,
BooleanType,
)
from pyspark.testing.sqlutils import have_pandas
from pyspark.testing.sqlutils import have_pandas, have_pyarrow


class UtilsTestsMixin:
@@ -745,7 +745,10 @@ def test_assert_unequal_null_expected(self):
},
)

@unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency")
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_equal_exact_pandas_df(self):
import pandas as pd
import numpy as np
@@ -760,7 +763,10 @@ def test_assert_equal_exact_pandas_df(self):
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)

@unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency")
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_approx_equal_pandas_df(self):
import pandas as pd
import numpy as np
@@ -776,7 +782,10 @@ def test_assert_approx_equal_pandas_df(self):
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)

@unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency")
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_approx_equal_fail_exact_pandas_df(self):
import pandas as pd
import numpy as np
@@ -817,7 +826,10 @@ def test_assert_approx_equal_fail_exact_pandas_df(self):
},
)

@unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency")
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_unequal_pandas_df(self):
import pandas as pd
import numpy as np
@@ -857,7 +869,10 @@ def test_assert_unequal_pandas_df(self):
},
)

@unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency")
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_type_error_pandas_df(self):
import pyspark.pandas as ps
import pandas as pd
@@ -896,7 +911,7 @@ def test_assert_type_error_pandas_df(self):
},
)

@unittest.skipIf(not have_pandas, "no pandas dependency")
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
def test_assert_equal_exact_pandas_on_spark_df(self):
import pyspark.pandas as ps

@@ -906,7 +921,7 @@ def test_assert_equal_exact_pandas_on_spark_df(self):
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)

@unittest.skipIf(not have_pandas, "no pandas dependency")
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
def test_assert_equal_exact_pandas_on_spark_df(self):
import pyspark.pandas as ps

@@ -915,7 +930,7 @@ def test_assert_equal_exact_pandas_on_spark_df(self):

assertDataFrameEqual(df1, df2)

@unittest.skipIf(not have_pandas, "no pandas dependency")
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
def test_assert_equal_approx_pandas_on_spark_df(self):
import pyspark.pandas as ps

@@ -925,7 +940,7 @@ def test_assert_equal_approx_pandas_on_spark_df(self):
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)

@unittest.skipIf(not have_pandas, "no pandas dependency")
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
def test_assert_error_pandas_pyspark_df(self):
import pyspark.pandas as ps
import pandas as pd
99 changes: 50 additions & 49 deletions python/pyspark/testing/connectutils.py
@@ -65,6 +65,7 @@
should_test_connect: str = typing.cast(str, connect_requirement_message is None)

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
from pyspark.sql.connect.session import SparkSession

@@ -89,68 +90,68 @@ def __getattr__(self, item):

@unittest.skipIf(not should_test_connect, connect_requirement_message)
class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
from pyspark.sql.connect.dataframe import DataFrame
if should_test_connect:

class MockDF(DataFrame):
"""Helper class that must only be used for the mock plan tests."""
class MockDF(DataFrame):
"""Helper class that must only be used for the mock plan tests."""

def __init__(self, plan: LogicalPlan, session: SparkSession):
super().__init__(plan, session)
def __init__(self, plan: LogicalPlan, session: SparkSession):
super().__init__(plan, session)

def __getattr__(self, name):
"""All attributes are resolved to columns, because none really exist in the
mocked DataFrame."""
return self[name]
def __getattr__(self, name):
"""All attributes are resolved to columns, because none really exist in the
mocked DataFrame."""
return self[name]

@classmethod
def _read_table(cls, table_name):
return cls._df_mock(Read(table_name))
@classmethod
def _read_table(cls, table_name):
return cls._df_mock(Read(table_name))

@classmethod
def _udf_mock(cls, *args, **kwargs):
return "internal_name"
@classmethod
def _udf_mock(cls, *args, **kwargs):
return "internal_name"

@classmethod
def _df_mock(cls, plan: LogicalPlan) -> MockDF:
return PlanOnlyTestFixture.MockDF(plan, cls.connect)
@classmethod
def _df_mock(cls, plan: LogicalPlan) -> MockDF:
return PlanOnlyTestFixture.MockDF(plan, cls.connect)

@classmethod
def _session_range(
cls,
start,
end,
step=1,
num_partitions=None,
):
return cls._df_mock(Range(start, end, step, num_partitions))
@classmethod
def _session_range(
cls,
start,
end,
step=1,
num_partitions=None,
):
return cls._df_mock(Range(start, end, step, num_partitions))

@classmethod
def _session_sql(cls, query):
return cls._df_mock(SQL(query))
@classmethod
def _session_sql(cls, query):
return cls._df_mock(SQL(query))

if have_pandas:
if have_pandas:

@classmethod
def _with_plan(cls, plan):
return cls._df_mock(plan)
@classmethod
def _with_plan(cls, plan):
return cls._df_mock(plan)

@classmethod
def setUpClass(cls):
cls.connect = MockRemoteSession()
cls.session = SparkSession.builder.remote().getOrCreate()
cls.tbl_name = "test_connect_plan_only_table_1"
@classmethod
def setUpClass(cls):
cls.connect = MockRemoteSession()
cls.session = SparkSession.builder.remote().getOrCreate()
cls.tbl_name = "test_connect_plan_only_table_1"

cls.connect.set_hook("readTable", cls._read_table)
cls.connect.set_hook("range", cls._session_range)
cls.connect.set_hook("sql", cls._session_sql)
cls.connect.set_hook("with_plan", cls._with_plan)
cls.connect.set_hook("readTable", cls._read_table)
cls.connect.set_hook("range", cls._session_range)
cls.connect.set_hook("sql", cls._session_sql)
cls.connect.set_hook("with_plan", cls._with_plan)

@classmethod
def tearDownClass(cls):
cls.connect.drop_hook("readTable")
cls.connect.drop_hook("range")
cls.connect.drop_hook("sql")
cls.connect.drop_hook("with_plan")
@classmethod
def tearDownClass(cls):
cls.connect.drop_hook("readTable")
cls.connect.drop_hook("range")
cls.connect.drop_hook("sql")
cls.connect.drop_hook("with_plan")


@unittest.skipIf(not should_test_connect, connect_requirement_message)
9 changes: 5 additions & 4 deletions python/pyspark/testing/utils.py
@@ -605,10 +605,11 @@ def assertDataFrameEqual(
>>> list_of_rows = [Row(1, 1000), Row(2, 3000)]
>>> assertDataFrameEqual(df1, list_of_rows) # pass, actual and expected data are equal
>>> import pyspark.pandas as ps
>>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
>>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
>>> assertDataFrameEqual(df1, df2) # pass, pandas-on-Spark DataFrames are equal
>>> import pyspark.pandas as ps # doctest: +SKIP
>>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) # doctest: +SKIP
>>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) # doctest: +SKIP
>>> # pass, pandas-on-Spark DataFrames are equal
>>> assertDataFrameEqual(df1, df2) # doctest: +SKIP
>>> df1 = spark.createDataFrame(
... data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
