|
3 | 3 |
|
4 | 4 | import pandas as pd |
5 | 5 | import pyarrow as pa |
| 6 | +import pyarrow.parquet as pq |
6 | 7 | import pytest |
7 | 8 | from nested_pandas import read_parquet |
8 | 9 | from nested_pandas.datasets import generate_data |
@@ -101,6 +102,51 @@ def test_read_parquet_catch_failed_cast(): |
101 | 102 | read_parquet("tests/test_data/not_nestable.parquet") |
102 | 103 |
|
103 | 104 |
|
def test_read_parquet_test_mixed_struct():
    """Check which struct columns read_parquet reports as nested.

    A table holding three struct columns is written to a temporary parquet
    file: one struct whose fields are all lists, one whose fields are all
    scalars, and one that mixes scalar and list fields. The file is then read
    back in full and partially (selecting a single list sub-field of the mixed
    struct), and the resulting column names and nested columns are asserted.
    """
    # Struct in which every field is list-valued
    all_lists = pa.StructArray.from_arrays(
        [
            pa.array([[1, 2], [3, 4], [5, 6]]),
            pa.array([["a", "b"], ["b", "c"], ["c", "d"]]),
            pa.array([[True, False], [True, False], [True, False]]),
        ],
        names=["list1", "list2", "list3"],
    )

    # Struct in which every field is a plain scalar value
    # NOTE(review): "va12" looks like a typo for "val2"; kept as-is since no
    # assertion depends on it -- confirm before renaming.
    all_values = pa.StructArray.from_arrays(
        [
            pa.array([1, 2, 3]),
            pa.array(["a", "b", "c"]),
            pa.array([True, False, True]),
        ],
        names=["val1", "va12", "val3"],
    )

    # Struct combining two scalar fields with one list-valued field
    mixed = pa.StructArray.from_arrays(
        [
            pa.array([1, 2, 3]),
            pa.array(["a", "b", "c"]),
            pa.array([[True, False], [True, False], [True, False]]),
        ],
        names=["val1", "va12", "list3"],
    )

    # Assemble the table: one ordinary column followed by the three structs
    column_data = {
        "id": pa.array([100, 101, 102]),
        "struct_list": all_lists,
        "struct_value": all_values,
        "struct_mix": mixed,
    }
    table = pa.table(column_data)

    # Round-trip through a parquet file inside a throwaway directory
    with tempfile.TemporaryDirectory() as tmpdir:
        parquet_path = os.path.join(tmpdir, "structs.parquet")
        pq.write_table(table, parquet_path)

        # Full read: all four columns come back, only the all-list struct is nested
        full_frame = read_parquet(parquet_path)
        assert full_frame.columns.tolist() == ["id", "struct_list", "struct_value", "struct_mix"]
        assert full_frame.nested_columns == ["struct_list"]

        # Partial read: selecting only the list sub-field makes the mixed struct nested
        partial_frame = read_parquet(parquet_path, columns=["id", "struct_mix.list3"])
        assert partial_frame.columns.tolist() == ["id", "struct_mix"]
        assert partial_frame.nested_columns == ["struct_mix"]
| 148 | + |
| 149 | + |
104 | 150 | def test_to_parquet(): |
105 | 151 | """Test writing a parquet file with no columns specified""" |
106 | 152 | # Load in the example file |
|
0 commit comments