Switched spark column metadata fetching to using pyspark data frame A… #2124

Draft: wants to merge 4 commits into base: main
@@ -0,0 +1,68 @@
import pytest
from contracts.helpers.test_data_source import TestDataSource
from helpers.test_table import TestTable
from soda.execution.data_type import DataType

from soda.contracts.contract import ContractResult

contracts_spark_partitioning_test_table = TestTable(
name="contracts_spark_partitioning",
columns=[
("id", DataType.TEXT),
],
# fmt: off
values=[ ('1',), ('2',), ('3',), ('4',), ('5',), ('6',) ]
# fmt: on
)


@pytest.mark.skip("Takes too long to be part of the local development test suite")
def test_spark_partitioning_columns(test_data_source: TestDataSource):
    table_name: str = test_data_source.ensure_test_table(contracts_spark_partitioning_test_table)

    spark_session = test_data_source.sodacl_data_source.spark_session

    spark_session.sql(
        f"""
        DROP TABLE IF EXISTS customer;
        """
    )

    spark_session.sql(
        f"""
        CREATE TABLE customer(
            cust_id INT,
            state VARCHAR(20),
            name STRING COMMENT 'Short name'
        )
        USING PARQUET
        PARTITIONED BY (state);
        """
    )

    cols_df = spark_session.sql(
        f"""
        DESCRIBE TABLE customer
        """
    )
    cols_df.show()

    data_df = spark_session.sql(f"SELECT * FROM customer;")
    data_df.show()

    table_df = spark_session.table("customer")
    for field in table_df.schema.fields:
        print(field.dataType.simpleString())
        print(field.name)

    contract_result: ContractResult = test_data_source.assert_contract_pass(
        contract_yaml_str=f"""
            dataset: customer
            columns:
            - name: cust_id
            - name: state
            - name: name
        """
    )
    contract_result_str = str(contract_result)
    print(contract_result_str)
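
Note on the approach (the motivation is not spelled out in the diff, so this is an inference): on a partitioned table, DESCRIBE TABLE appends a "# Partition Information" section that lists the partition columns a second time, so parsing its rows as column metadata can produce duplicate or bogus entries, whereas the DataFrame schema exposes exactly one StructField per column. A minimal standalone sketch against a local session:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

spark.sql("DROP TABLE IF EXISTS customer")
spark.sql(
    """
    CREATE TABLE customer (cust_id INT, state STRING, name STRING)
    USING PARQUET
    PARTITIONED BY (state)
    """
)

# DESCRIBE TABLE repeats the partition column under '# Partition Information',
# so its rows are not a clean one-row-per-column listing.
spark.sql("DESCRIBE TABLE customer").show(truncate=False)

# The DataFrame schema, by contrast, has exactly one StructField per column.
for field in spark.table("customer").schema.fields:
    print(field.name, field.dataType.simpleString())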
@@ -73,7 +73,12 @@ def test_contract_freshness_fail(test_data_source: TestDataSource, environ: dict

assert "Expected freshness(created) < 3h" in contract_result_str
assert "Actual freshness(created) was 3:19:50" in contract_result_str
assert "Max value in column was ...... 2021-01-01 10:10:10+00:00" in contract_result_str

if test_data_source.data_source_type.startswith("spark"):
assert "Max value in column was ...... 2021-01-01 10:10:10" in contract_result_str
else:
assert "Max value in column was ...... 2021-01-01 10:10:10+00:00" in contract_result_str

assert "Max value in column in UTC was 2021-01-01 10:10:10+00:00" in contract_result_str
assert "Now was ...................... 2021-01-01 13:30" in contract_result_str
assert "Now in UTC was ............... 2021-01-01 13:30:00+00:00" in contract_result_str
@@ -239,4 +239,6 @@ def test_contract_schema_data_type_mismatch(test_data_source: TestDataSource):
    assert data_type_mismatch.expected_data_type == "WRONG_VARCHAR"
    assert data_type_mismatch.actual_data_type == test_data_source.data_type_text()

    assert "Column 'id': Expected type 'WRONG_VARCHAR', but was 'character varying'" in str(contract_result)
    assert f"Column 'id': Expected type 'WRONG_VARCHAR', but was '{test_data_source.data_type_text()}'" in str(
        contract_result
    )
7 changes: 7 additions & 0 deletions soda/spark/soda/data_sources/spark_data_source.py
@@ -10,9 +10,13 @@
from soda.__version__ import SODA_CORE_VERSION
from soda.common.exceptions import DataSourceConnectionError
from soda.common.logs import Logs
from soda.data_sources.spark_table_columns_query import SparkTableColumnsQuery
from soda.execution.data_source import DataSource
from soda.execution.data_type import DataType
from soda.execution.metric.schema_metric import SchemaMetric
from soda.execution.partition import Partition
from soda.execution.query.query import Query
from soda.execution.query.schema_query import TableColumnsQuery

logger = logging.getLogger(__name__)
ColumnMetadata = namedtuple("ColumnMetadata", ["name", "data_type", "is_nullable"])
@@ -199,6 +203,9 @@ class SparkSQLBase(DataSource):
    def __init__(self, logs: Logs, data_source_name: str, data_source_properties: dict):
        super().__init__(logs, data_source_name, data_source_properties)

    def create_table_columns_query(self, partition: Partition, schema_metric: SchemaMetric) -> TableColumnsQuery:
        return SparkTableColumnsQuery(partition, schema_metric)

    def get_table_columns(
        self,
        table_name: str,
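
For readers unfamiliar with the hook being overridden here: create_table_columns_query appears to act as a factory method that each data source can override to supply its own schema-query implementation, and the Spark data source now returns a query object that reads the DataFrame schema instead of running a metadata SQL statement. A generic sketch of that pattern, with illustrative names rather than the actual soda-core classes:

class ColumnsQueryBase:
    def _initialize_column_rows(self) -> None:
        raise NotImplementedError

class SqlColumnsQuery(ColumnsQueryBase):
    def _initialize_column_rows(self) -> None:
        ...  # run an information_schema-style SQL query and collect (name, type) rows

class DataFrameColumnsQuery(ColumnsQueryBase):
    def _initialize_column_rows(self) -> None:
        ...  # read (name, type) pairs from a Spark DataFrame schema instead

class DataSourceBase:
    def create_table_columns_query(self) -> ColumnsQueryBase:
        return SqlColumnsQuery()  # generic SQL-based default

class SparkDataSource(DataSourceBase):
    def create_table_columns_query(self) -> ColumnsQueryBase:
        return DataFrameColumnsQuery()  # Spark-specific override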
19 changes: 19 additions & 0 deletions soda/spark/soda/data_sources/spark_table_columns_query.py
@@ -0,0 +1,19 @@
from soda.execution.query.schema_query import TableColumnsQuery


class SparkTableColumnsQuery(TableColumnsQuery):
    def __init__(self, partition: "Partition", schema_metric: "SchemaMetric"):
        super().__init__(partition=partition, schema_metric=schema_metric)
        self.metric = schema_metric

    def _initialize_column_rows(self):
        """
        Initializes the member self.rows as a list (or tuple) of rows, where each row represents a column description.
        A column description is a list (or tuple) with the column name at index 0 and the column data type (str) at
        index 1, e.g. [["col_name_one", "data_type_of_col_name_one"], ...]
        """
        data_source = self.data_source_scan.data_source
        table_df = data_source.spark_session.table(self.table.table_name)
        self.rows = tuple([field.name, field.dataType.simpleString()] for field in table_df.schema.fields)
        self.row_count = len(self.rows)
        self.description = (("col_name", "StringType"), ("data_type", "StringType"))