feat(ingestion): add column level description for parquet files #12988


Open · wants to merge 6 commits into base: master
metadata-ingestion/src/datahub/ingestion/source/schema_inference/parquet.py
@@ -1,5 +1,6 @@
from typing import IO, Any, Callable, Dict, List, Optional, Type

import pandas
import pyarrow
import pyarrow.parquet

@@ -64,6 +65,39 @@
}


def get_column_metadata(schema_dict: dict, column_name: str) -> Optional[Dict]:
    """
    Get metadata for a specific column from the schema dictionary.

    Args:
        schema_dict (dict): The schema dictionary containing column definitions
        column_name (str): The name of the column to get metadata for

    Returns:
        dict: The metadata for the specified column, or None if the column is not found
    """
    # Iterate through all columns in the schema
    for _, column_info in schema_dict.items():
        if column_info.get("name") == column_name:
            return column_info.get("metadata", {})

    # Return None if column not found
    return None


def parse_metadata(schema_metadata: bytes) -> Dict:
    """
    Parse parquet schema metadata into a dictionary of fields.

    Args:
        schema_metadata (bytes): Raw schema metadata from a parquet file

    Returns:
        Dict: Parsed metadata fields dictionary
    """
    return pandas.read_json(schema_metadata.decode("utf-8")).to_dict()["fields"]
Contributor:

This is the same, right?

Suggested change:
-    return pandas.read_json(schema_metadata.decode("utf-8")).to_dict()["fields"]
+    return json.loads(schema_metadata.decode("utf-8"))["fields"]

Unless necessary, I would avoid depending on pandas for this.

For resilience, we should also account for the possibility that the "fields" key might be missing; see the sketch below.
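A minimal sketch of the suggested change, also guarding against a missing "fields" key. The re-keying step is an assumption: json.loads yields "fields" as a list, whereas the pandas version yields an index-keyed dict, so this variant re-keys by position to keep get_column_metadata working unchanged.

import json
from typing import Dict


def parse_metadata(schema_metadata: bytes) -> Dict:
    """Parse parquet schema metadata without depending on pandas."""
    parsed = json.loads(schema_metadata.decode("utf-8"))
    # "fields" is a list in the raw JSON; re-key it by position so the
    # result matches the shape of pandas' to_dict()["fields"], and fall
    # back to an empty dict when the key is absent.
    return dict(enumerate(parsed.get("fields", [])))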



def map_pyarrow_type(pyarrow_type: Type) -> Type:
    for checker, mapped_type in pyarrow_type_map.items():
        if checker(pyarrow_type):
@@ -81,14 +115,23 @@ def infer_schema(self, file: IO[bytes]) -> List[SchemaField]:

    fields: List[SchemaField] = []

    meta_data_fields = parse_metadata(
        schema.metadata[b"org.apache.spark.sql.parquet.row.metadata"]
    )
Comment on lines +118 to +120
Contributor:

Is there a guarantee that this metadata field will always exist? We should consider treating it as optional.
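A minimal sketch of treating the key as optional, in the context of infer_schema; the constant name and the empty-dict fallback are illustrative, not part of the PR:

SPARK_ROW_METADATA_KEY = b"org.apache.spark.sql.parquet.row.metadata"

# pyarrow leaves schema.metadata as None when the file carries no
# key/value metadata, so guard both the mapping and the specific key.
meta_data_fields = {}
if schema.metadata and SPARK_ROW_METADATA_KEY in schema.metadata:
    meta_data_fields = parse_metadata(schema.metadata[SPARK_ROW_METADATA_KEY])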


    for name, pyarrow_type in zip(schema.names, schema.types):
        mapped_type = map_pyarrow_type(pyarrow_type)

        description = get_column_metadata(meta_data_fields, name)
Contributor:

Instead of traversing meta_data_fields for every column, you could have parse_metadata build a dictionary indexed by column name (sketched below).
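A sketch of that refactor, assuming the json.loads variant above (names illustrative): with parse_metadata keyed by column name, the per-column lookup collapses to a single dict access and get_column_metadata becomes unnecessary.

import json
from typing import Dict


def parse_metadata(schema_metadata: bytes) -> Dict[str, dict]:
    """Index each field's metadata by its column name for O(1) lookups."""
    parsed = json.loads(schema_metadata.decode("utf-8"))
    return {
        field["name"]: field.get("metadata", {})
        for field in parsed.get("fields", [])
        if "name" in field
    }

The loop body would then read description = meta_data_fields.get(name, {}).get(name).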


        description = description.get(name) if description else None

        field = SchemaField(
            fieldPath=name,
            type=SchemaFieldDataType(mapped_type()),
            nativeDataType=str(pyarrow_type),
            recursive=False,
            description=description,
        )

        fields.append(field)
35 changes: 35 additions & 0 deletions metadata-ingestion/tests/unit/data_lake/test_schema_inference.py
@@ -16,6 +16,7 @@
    StringTypeClass,
)
from tests.unit.test_schema_util import assert_field_paths_match
from datahub.ingestion.source.schema_inference.parquet import get_column_metadata

expected_field_paths = [
    "integer_field",
@@ -39,6 +40,31 @@
    }
)

# Add descriptions to columns
test_table["integer_field"].attrs["description"] = "A column containing integer values"
test_table["boolean_field"].attrs["description"] = "A column containing boolean values"
test_table["string_field"].attrs["description"] = "A column containing string values"


expected_field_descriptions = [
    "A column containing integer values",
    "A column containing boolean values",
    "A column containing string values",
]


# Mirrors the index-keyed shape that parse_metadata returns.
test_column_metadata = {
    0: {"name": "integer_field", "metadata": {"integer_field": "A column containing integer values"}},
    1: {"name": "boolean_field", "metadata": {"boolean_field": "A column containing boolean values"}},
    2: {"name": "string_field", "metadata": {"string_field": "A column containing string values"}},
}


def test_get_column_metadata():
    assert {"integer_field": "A column containing integer values"} == get_column_metadata(test_column_metadata, "integer_field")
    assert {"boolean_field": "A column containing boolean values"} == get_column_metadata(test_column_metadata, "boolean_field")
    assert {"string_field": "A column containing string values"} == get_column_metadata(test_column_metadata, "string_field")


def assert_field_types_match(
    fields: List[SchemaField], expected_field_types: List[Type]
@@ -48,6 +74,14 @@ def assert_field_types_match(
        assert isinstance(field.type.type, expected_type)


def assert_field_descriptions_match(
    fields: List[SchemaField], expected_field_descriptions: List[str]
) -> None:
    assert len(fields) == len(expected_field_descriptions)
    for field, expected_description in zip(fields, expected_field_descriptions):
        assert field.description == expected_description


def test_infer_schema_csv():
    with tempfile.TemporaryFile(mode="w+b") as file:
        file.write(bytes(test_table.to_csv(index=False, header=True), encoding="utf-8"))
@@ -106,6 +140,7 @@ def test_infer_schema_parquet():

    assert_field_paths_match(fields, expected_field_paths)
    assert_field_types_match(fields, expected_field_types)
    assert_field_descriptions_match(fields, expected_field_descriptions)


def test_infer_schema_avro():