Skip to content

Commit

Permalink
Fix new test
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Aug 27, 2024
1 parent d34bc3a commit 0cee8c9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
6 changes: 4 additions & 2 deletions python/examples/test_dual_write.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import tempfile

#tag::test[]
# tag::test[]
from sparktestingbase.sqltestcase import SQLTestCase
from pyspark.sql.functions import current_timestamp
from pyspark.sql.types import Row
Expand All @@ -22,4 +22,6 @@ def test_actual_dual_write(self):
df1 = self.sqlCtx.read.format("parquet").load(p1)
df2 = self.sqlCtx.read.format("parquet").load(p2)
self.assertDataFrameEqual(df2.select("times"), df1, 0.1)
#end::test[]


# end::test[]
16 changes: 10 additions & 6 deletions python/examples/test_dual_write_new.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import tempfile

#tag::test[]
# tag::test[]
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.functions import current_timestamp
Expand All @@ -13,7 +13,9 @@
class DualWriteTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
cls.spark = SparkSession.builder.appName(
"Testing PySpark Example"
).getOrCreate()

@classmethod
def tearDownClass(cls):
Expand All @@ -26,10 +28,12 @@ def test_actual_dual_write(self):
tempdir = tempfile.mkdtemp()
p1 = os.path.join(tempdir, "data1")
p2 = os.path.join(tempdir, "data2")
df = self.sqlCtx.createDataFrame([Row("timbit"), Row("farted")], ["names"])
df = self.spark.createDataFrame([Row("timbit"), Row("farted")], ["names"])
combined = df.withColumn("times", current_timestamp())
DualWriteExample().do_write(combined, p1, p2)
df1 = self.sqlCtx.read.format("parquet").load(p1)
df2 = self.sqlCtx.read.format("parquet").load(p2)
df1 = self.spark.read.format("parquet").load(p1)
df2 = self.spark.read.format("parquet").load(p2)
assertDataFrameEqual(df2.select("times"), df1, 0.1)
#end::test[]


# end::test[]

0 comments on commit 0cee8c9

Please sign in to comment.