|
| 1 | +from pathlib import Path |
| 2 | +from typing import TYPE_CHECKING, Literal |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pytest |
| 6 | +import zarr |
| 7 | +import zarr.storage |
| 8 | + |
| 9 | +from geff.core_io import write_arrays |
| 10 | +from geff.core_io._base_read import read_to_memory |
| 11 | +from geff.metadata._schema import GeffMetadata |
| 12 | +from geff.testing.data import create_simple_3d_geff |
| 13 | +from geff.validate.structure import validate_structure |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + from geff._typing import PropDictNpArray |
| 17 | + |
| 18 | + |
| 19 | +from geff.core_io._base_write import dict_props_to_arr |
| 20 | + |
| 21 | + |
def _tmp_metadata():
    """Build the small GeffMetadata shared by the dtype tests (version "0.0.1", directed)."""
    minimal_fields = {"geff_version": "0.0.1", "directed": True}
    return GeffMetadata(**minimal_fields)
| 25 | + |
| 26 | + |
@pytest.fixture
def dict_data():
    """Provide (id, property dict) pairs whose property keys differ per entry.

    Every entry carries "num"; "str", "str_arr" and "num_arr" each appear on
    exactly one entry, so they are absent from the other two.
    """
    with_str = (0, {"num": 1, "str": "category"})
    with_str_arr = (127, {"num": 5, "str_arr": ["test", "string"]})
    with_num_arr = (1, {"num": 6, "num_arr": [1, 2]})
    return [with_str, with_str_arr, with_num_arr]
| 35 | + |
| 36 | + |
class TestWriteArrays:
    """Tests for ``write_arrays``: written ids/metadata and dtype warnings."""

    @pytest.mark.parametrize("zarr_format", [2, 3])
    def test_write_arrays_basic(self, tmp_path: Path, zarr_format: Literal[2, 3]) -> None:
        """Test basic functionality of write_arrays with minimal data."""
        # Create test data
        geff_path = tmp_path / "test.geff"
        node_ids = np.array([1, 2, 3], dtype=np.int32)
        edge_ids = np.array([[1, 2], [2, 3]], dtype=np.int32)
        metadata = GeffMetadata(geff_version="0.0.1", directed=True)

        # Call write_arrays
        write_arrays(
            geff_store=geff_path,
            node_ids=node_ids,
            node_props=None,
            edge_ids=edge_ids,
            edge_props=None,
            metadata=metadata,
            zarr_format=zarr_format,
        )

        # Verify the zarr group was created
        assert geff_path.exists()

        # Verify node and edge IDs were written correctly
        root = zarr.open_group(str(geff_path))
        assert "nodes/ids" in root
        assert "edges/ids" in root

        # Check the data matches
        np.testing.assert_array_equal(root["nodes/ids"][:], node_ids)
        np.testing.assert_array_equal(root["edges/ids"][:], edge_ids)

        # Check the data types match
        assert root["nodes/ids"].dtype == node_ids.dtype
        assert root["edges/ids"].dtype == edge_ids.dtype

        # Verify metadata was written
        assert "geff" in root.attrs
        assert root.attrs["geff"]["geff_version"] == "0.0.1"
        assert root.attrs["geff"]["directed"] is True

    # TODO: test properties helper. It's covered by networkx tests now, so I'm okay merging,
    # but we should do it when we have time.

    def test_write_in_mem_geff(self) -> None:
        """Write an in-memory geff into a fresh MemoryStore and validate its structure."""
        # Only the store is needed here; the second return value is unused.
        store, _ = create_simple_3d_geff()
        in_mem_geff = read_to_memory(store)

        # Test writing
        new_store = zarr.storage.MemoryStore()
        write_arrays(new_store, **in_mem_geff)

        validate_structure(new_store)

    def test_write_arrays_rejects_disallowed_id_dtype(self, tmp_path: Path) -> None:
        """write_arrays must emit a UserWarning for node/edge ids with a disallowed dtype.

        Despite the "rejects" in the test name, this only asserts that a warning
        is emitted and that the call completes without raising.
        """
        geff_path = tmp_path / "invalid_ids.geff"

        # float16 is currently not allowed by Java Zarr
        node_ids = np.array([1, 2, 3], dtype=np.float16)
        edge_ids = np.array([[1, 2], [2, 3]], dtype=np.float16)

        with pytest.warns(UserWarning):
            write_arrays(
                geff_store=geff_path,
                node_ids=node_ids,
                node_props=None,
                edge_ids=edge_ids,
                edge_props=None,
                metadata=_tmp_metadata(),
            )

    def test_write_arrays_rejects_disallowed_property_dtype(self, tmp_path: Path) -> None:
        """write_arrays must emit a UserWarning if a property array has a disallowed dtype.

        Despite the "rejects" in the test name, this only asserts that a warning
        is emitted and that the call completes without raising.
        """
        geff_path = tmp_path / "invalid_prop.geff"

        # ids are fine (int32)
        node_ids = np.array([1, 2, 3], dtype=np.int32)
        edge_ids = np.array([[1, 2], [2, 3]], dtype=np.int32)

        # property with disallowed dtype (float16)
        bad_prop_values = np.array([0.1, 0.2, 0.3], dtype=np.float16)
        node_props: dict[str, PropDictNpArray] = {
            "score": {"values": bad_prop_values, "missing": None}
        }

        with pytest.warns(UserWarning):
            write_arrays(
                geff_store=geff_path,
                node_ids=node_ids,
                node_props=node_props,
                edge_ids=edge_ids,
                edge_props=None,
                metadata=_tmp_metadata(),
            )
| 133 | + |
| 134 | + |
@pytest.mark.parametrize(
    ("data_type", "expected"),
    [
        ("num", ([1, 5, 6], None)),
        ("str", (["category", "", ""], [0, 1, 1])),
        ("num_arr", ([[1, 2], [1, 2], [1, 2]], [1, 1, 0])),
        ("str_arr", ([["test", "string"], ["test", "string"], ["test", "string"]], [1, 0, 1])),
    ],
)
def test_dict_prop_to_arr(dict_data, data_type, expected) -> None:
    """Check the "values" and "missing" arrays dict_props_to_arr builds for one property.

    ``expected`` is an ``(values, missing)`` pair; ``missing`` is ``None`` when no
    missing array is expected, otherwise a 0/1 list that is compared as booleans.
    """
    props_dict = dict_props_to_arr(dict_data, [data_type])
    values = props_dict[data_type]["values"]
    missing = props_dict[data_type]["missing"]

    ex_values, ex_missing = expected
    ex_values = np.array(ex_values)
    # Keep None as None so the "no missing array" case is compared against None.
    ex_missing = np.array(ex_missing, dtype=bool) if ex_missing is not None else None

    np.testing.assert_array_equal(missing, ex_missing)
    np.testing.assert_array_equal(values, ex_values)
| 155 | + |
| 156 | + |
| 157 | +# TODO: test write_dicts (it is pretty solidly covered by networkx and write_array tests, |
| 158 | +# so I'm okay merging without, but we should do it when we have time) |
0 commit comments