feat(ingestion): add column level description for parquet files #12988

Open · wants to merge 6 commits into master
metadata-ingestion/src/datahub/ingestion/source/schema_inference/parquet.py
@@ -1,5 +1,6 @@
from typing import IO, Any, Callable, Dict, List, Optional, Type

import pandas
import pyarrow
import pyarrow.parquet

@@ -64,6 +65,39 @@
}


def get_column_metadata(schema_dict: dict, column_name: str) -> Optional[Dict]:
    """
    Get metadata for a specific column from the schema dictionary.

    Args:
        schema_dict (dict): The schema dictionary containing column definitions
        column_name (str): The name of the column to get metadata for

    Returns:
        Dict: The metadata for the specified column, or None if the column is not found
    """
    # Iterate through all columns in the schema
    for _, column_info in schema_dict.items():
        if column_info.get("name") == column_name:
            return column_info.get("metadata", {})

    # Return None if the column was not found
    return None


def parse_metadata(schema_metadata: bytes) -> Dict:
    """
    Parse parquet schema metadata into a dictionary of fields.

    Args:
        schema_metadata (bytes): Raw schema metadata from the parquet file

    Returns:
        Dict: Parsed metadata fields dictionary
    """
    return pandas.read_json(schema_metadata.decode("utf-8")).to_dict()["fields"]


def map_pyarrow_type(pyarrow_type: Type) -> Type:
    for checker, mapped_type in pyarrow_type_map.items():
        if checker(pyarrow_type):
@@ -81,14 +115,23 @@ def infer_schema(self, file: IO[bytes]) -> List[SchemaField]:

        fields: List[SchemaField] = []

        meta_data_fields: Dict = {}
        if schema.metadata and b"org.apache.spark.sql.parquet.row.metadata" in schema.metadata:
            # Column descriptions are only available when the file carries the
            # Spark row metadata key; fall back to an empty mapping otherwise.
            meta_data_fields = parse_metadata(
                schema.metadata[b"org.apache.spark.sql.parquet.row.metadata"]
            )

        for name, pyarrow_type in zip(schema.names, schema.types):
            mapped_type = map_pyarrow_type(pyarrow_type)

            column_metadata = get_column_metadata(meta_data_fields, name)
            description = column_metadata.get(name) if column_metadata else None

            field = SchemaField(
                fieldPath=name,
                type=SchemaFieldDataType(mapped_type()),
                nativeDataType=str(pyarrow_type),
                recursive=False,
                description=description,
            )

            fields.append(field)
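For reference, a minimal sketch of how the two new helpers are expected to compose at ingestion time. The payload below is an assumed example of the JSON Spark stores under the org.apache.spark.sql.parquet.row.metadata schema key, with each field's metadata keyed by the column name as this change expects; it is illustrative only and not part of the diff.

from datahub.ingestion.source.schema_inference.parquet import (
    get_column_metadata,
    parse_metadata,
)

# Assumed example of Spark row metadata for a single column.
raw_row_metadata = (
    b'{"type":"struct","fields":['
    b'{"name":"integer_field","type":"integer","nullable":true,'
    b'"metadata":{"integer_field":"A column containing integer values"}}]}'
)

# parse_metadata yields a dict keyed by positional index,
# e.g. {0: {"name": "integer_field", ..., "metadata": {...}}}.
fields_by_index = parse_metadata(raw_row_metadata)

# get_column_metadata scans those entries for a matching name and returns
# that field's metadata dict (or None if the column is absent).
column_metadata = get_column_metadata(fields_by_index, "integer_field")
description = column_metadata.get("integer_field") if column_metadata else None
print(description)  # expected to print "A column containing integer values"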
35 changes: 35 additions & 0 deletions metadata-ingestion/tests/unit/data_lake/test_schema_inference.py
@@ -16,6 +16,7 @@
    StringTypeClass,
)
from tests.unit.test_schema_util import assert_field_paths_match
from datahub.ingestion.source.schema_inference.parquet import get_column_metadata

expected_field_paths = [
"integer_field",
@@ -39,6 +40,31 @@
    }
)

# Add descriptions to columns
test_table["integer_field"].attrs["description"] = "A column containing integer values"
test_table["boolean_field"].attrs["description"] = "A column containing boolean values"
test_table["string_field"].attrs["description"] = "A column containing string values"


expected_field_descriptions = [
    "A column containing integer values",
    "A column containing boolean values",
    "A column containing string values",
]


# Mirrors the shape produced by parse_metadata: a dict keyed by positional index.
test_column_metadata = {
    0: {"name": "integer_field", "metadata": {"integer_field": "A column containing integer values"}},
    1: {"name": "boolean_field", "metadata": {"boolean_field": "A column containing boolean values"}},
    2: {"name": "string_field", "metadata": {"string_field": "A column containing string values"}},
}


def test_get_column_metadata():
    assert get_column_metadata(test_column_metadata, "integer_field") == {"integer_field": "A column containing integer values"}
    assert get_column_metadata(test_column_metadata, "boolean_field") == {"boolean_field": "A column containing boolean values"}
    assert get_column_metadata(test_column_metadata, "string_field") == {"string_field": "A column containing string values"}


def assert_field_types_match(
    fields: List[SchemaField], expected_field_types: List[Type]
@@ -48,6 +74,14 @@ def assert_field_types_match(
        assert isinstance(field.type.type, expected_type)


def assert_field_descriptions_match(
    fields: List[SchemaField], expected_field_descriptions: List[str]
) -> None:
    assert len(fields) == len(expected_field_descriptions)
    for field, expected_description in zip(fields, expected_field_descriptions):
        assert field.description == expected_description


def test_infer_schema_csv():
    with tempfile.TemporaryFile(mode="w+b") as file:
        file.write(bytes(test_table.to_csv(index=False, header=True), encoding="utf-8"))
@@ -106,6 +140,7 @@ def test_infer_schema_parquet():

        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
        assert_field_descriptions_match(fields, expected_field_descriptions)


def test_infer_schema_avro():
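One note on the parquet fixture: pandas' DataFrame.to_parquet does not emit the org.apache.spark.sql.parquet.row.metadata key, so a file written straight from test_table carries no column descriptions for the inferrer to pick up. Below is a minimal sketch, not part of this PR, of how the fixture could attach Spark-style row metadata via pyarrow before writing; the metadata shape mirrors test_column_metadata and is an assumption, as is reusing the module's test_table.

import json
import tempfile

import pyarrow
import pyarrow.parquet

# Build Spark-style row metadata for the three test columns, keyed by column
# name as the implementation expects (assumed shape).
spark_fields = [
    {"name": name, "type": dtype, "nullable": True, "metadata": {name: desc}}
    for name, dtype, desc in [
        ("integer_field", "integer", "A column containing integer values"),
        ("boolean_field", "boolean", "A column containing boolean values"),
        ("string_field", "string", "A column containing string values"),
    ]
]
row_metadata = json.dumps({"type": "struct", "fields": spark_fields})

# Attach the metadata to the arrow schema and write the file.
table = pyarrow.Table.from_pandas(test_table)
table = table.replace_schema_metadata(
    {
        **(table.schema.metadata or {}),
        b"org.apache.spark.sql.parquet.row.metadata": row_metadata.encode("utf-8"),
    }
)

with tempfile.TemporaryFile(mode="w+b") as file:
    pyarrow.parquet.write_table(table, file)
    file.seek(0)
    # `file` can now be passed to the existing parquet inferrer call used by
    # test_infer_schema_parquet, so the description assertions are exercised
    # end to end.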