diff --git a/metadata-ingestion/src/datahub/ingestion/source/schema_inference/parquet.py b/metadata-ingestion/src/datahub/ingestion/source/schema_inference/parquet.py
index efc605e0df8cab..ae0b5e13e459b4 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/schema_inference/parquet.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/schema_inference/parquet.py
@@ -1,5 +1,6 @@
-from typing import IO, Any, Callable, Dict, List, Type
+import json
+from typing import IO, Any, Callable, Dict, List, Optional, Type
 
 import pyarrow
 import pyarrow.parquet
 
@@ -64,6 +65,41 @@
 }
 
 
+def get_column_metadata(schema_fields: List[Dict], column_name: str) -> Optional[str]:
+    """
+    Get the description for a specific column from the parsed schema fields.
+
+    Args:
+        schema_fields (List[Dict]): The parsed schema fields, one dict per column
+        column_name (str): The name of the column to look up
+
+    Returns:
+        Optional[str]: The description for the specified column, or None if the
+        column is not found or carries no description
+    """
+    # Scan the schema fields for the requested column
+    for column_info in schema_fields:
+        if column_info.get("name") == column_name:
+            return column_info.get("metadata", {}).get(column_name)
+
+    # Column not found
+    return None
+
+
+def parse_metadata(schema_metadata: bytes) -> List[Dict]:
+    """
+    Parse Spark parquet row metadata into a list of field dictionaries.
+
+    Args:
+        schema_metadata (bytes): Raw value of the
+            org.apache.spark.sql.parquet.row.metadata key of the parquet schema
+
+    Returns:
+        List[Dict]: The "fields" entries of the parsed schema JSON
+    """
+    return json.loads(schema_metadata.decode("utf-8"))["fields"]
+
+
 def map_pyarrow_type(pyarrow_type: Type) -> Type:
     for checker, mapped_type in pyarrow_type_map.items():
         if checker(pyarrow_type):
@@ -81,15 +117,25 @@
     def infer_schema(self, file: IO[bytes]) -> List[SchemaField]:
         schema = pyarrow.parquet.read_schema(file)
 
         fields: List[SchemaField] = []
 
+        # Not every parquet file is written by Spark, so the row metadata key
+        # may be absent; fall back to empty descriptions in that case
+        schema_metadata = (schema.metadata or {}).get(
+            b"org.apache.spark.sql.parquet.row.metadata"
+        )
+        meta_data_fields = parse_metadata(schema_metadata) if schema_metadata else []
+
         for name, pyarrow_type in zip(schema.names, schema.types):
             mapped_type = map_pyarrow_type(pyarrow_type)
 
+            description = get_column_metadata(meta_data_fields, name)
+
             field = SchemaField(
                 fieldPath=name,
                 type=SchemaFieldDataType(mapped_type()),
                 nativeDataType=str(pyarrow_type),
                 recursive=False,
+                description=description,
             )
             fields.append(field)
diff --git a/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py b/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py
index a1ef02c27ea540..8aab1ba906cfd6 100644
--- a/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py
+++ b/metadata-ingestion/tests/unit/data_lake/test_schema_inference.py
@@ -16,6 +16,7 @@
     StringTypeClass,
 )
+from datahub.ingestion.source.schema_inference.parquet import get_column_metadata
 from tests.unit.test_schema_util import assert_field_paths_match
 
 expected_field_paths = [
     "integer_field",
@@ -39,6 +40,47 @@
     }
 )
 
+# Add descriptions to the columns
+test_table["integer_field"].attrs["description"] = "A column containing integer values"
+test_table["boolean_field"].attrs["description"] = "A column containing boolean values"
+test_table["string_field"].attrs["description"] = "A column containing string values"
+
+expected_field_descriptions = [
+    "A column containing integer values",
+    "A column containing boolean values",
+    "A column containing string values",
+]
+
+test_column_metadata = [
+    {
+        "name": "integer_field",
+        "metadata": {"integer_field": "A column containing integer values"},
+    },
+    {
+        "name": "boolean_field",
+        "metadata": {"boolean_field": "A column containing boolean values"},
+    },
+    {
+        "name": "string_field",
+        "metadata": {"string_field": "A column containing string values"},
+    },
+]
+
+
+def test_get_column_metadata():
+    assert (
+        get_column_metadata(test_column_metadata, "integer_field")
+        == "A column containing integer values"
+    )
+    assert (
+        get_column_metadata(test_column_metadata, "boolean_field")
+        == "A column containing boolean values"
+    )
+    assert (
+        get_column_metadata(test_column_metadata, "string_field")
+        == "A column containing string values"
+    )
+
 
 def assert_field_types_match(
     fields: List[SchemaField], expected_field_types: List[Type]
@@ -48,6 +90,14 @@
         assert isinstance(field.type.type, expected_type)
 
 
+def assert_field_descriptions_match(
+    fields: List[SchemaField], expected_field_descriptions: List[str]
+) -> None:
+    assert len(fields) == len(expected_field_descriptions)
+    for field, expected_description in zip(fields, expected_field_descriptions):
+        assert field.description == expected_description
+
+
 def test_infer_schema_csv():
     with tempfile.TemporaryFile(mode="w+b") as file:
         file.write(bytes(test_table.to_csv(index=False, header=True), encoding="utf-8"))
@@ -106,6 +156,7 @@ def test_infer_schema_parquet():
 
     assert_field_paths_match(fields, expected_field_paths)
     assert_field_types_match(fields, expected_field_types)
+    assert_field_descriptions_match(fields, expected_field_descriptions)
 
 
 def test_infer_schema_avro():