Skip to content

Commit 7ddcc30

Browse files
committed
Merge remote-tracking branch 'origin/main' into validator-framework
2 parents be92065 + bc6d924 commit 7ddcc30

25 files changed

+278
-287
lines changed

tests/test_agnostic/test_dtype_validation.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

tests/test_agnostic/test_valid_values.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

tests/test_agnostic/test_write_arrays.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

tests/test_agnostic/test_write_dicts.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

tests/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from geff import GeffMetadata
1010
from geff._cli import app
1111
from geff.testing.data import create_simple_temporal_geff
12-
from tests.test_interops.test_ctc import create_mock_data
12+
from tests.test_convert.test_ctc import create_mock_data
1313

1414
if TYPE_CHECKING:
1515
from pathlib import Path
Lines changed: 158 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,158 @@
1+
from pathlib import Path
2+
from typing import TYPE_CHECKING, Literal
3+
4+
import numpy as np
5+
import pytest
6+
import zarr
7+
import zarr.storage
8+
9+
from geff.core_io import write_arrays
10+
from geff.core_io._base_read import read_to_memory
11+
from geff.metadata._schema import GeffMetadata
12+
from geff.testing.data import create_simple_3d_geff
13+
from geff.validate.structure import validate_structure
14+
15+
if TYPE_CHECKING:
16+
from geff._typing import PropDictNpArray
17+
18+
19+
from geff.core_io._base_write import dict_props_to_arr
20+
21+
22+
def _tmp_metadata():
    """Build the smallest GeffMetadata used by these tests: a version and a directed flag."""
    minimal_fields = {"geff_version": "0.0.1", "directed": True}
    return GeffMetadata(**minimal_fields)
25+
26+
27+
@pytest.fixture
def dict_data():
    """Provide ``(id, property dict)`` tuples where only ``num`` is set on every entry.

    The other keys (``str``, ``str_arr``, ``num_arr``) each appear on exactly one
    entry, so they are absent from the remaining two.
    """
    return [
        (0, {"num": 1, "str": "category"}),
        (127, {"num": 5, "str_arr": ["test", "string"]}),
        (1, {"num": 6, "num_arr": [1, 2]}),
    ]
35+
36+
37+
class TestWriteArrays:
    """Tests for ``write_arrays``: id/metadata round trip and unsupported-dtype warnings."""

    @pytest.mark.parametrize("zarr_format", [2, 3])
    def test_write_arrays_basic(self, tmp_path: Path, zarr_format: Literal[2, 3]) -> None:
        """Write ids without properties and read them back, for zarr formats 2 and 3."""
        # Create test data
        geff_path = tmp_path / "test.geff"
        node_ids = np.array([1, 2, 3], dtype=np.int32)
        edge_ids = np.array([[1, 2], [2, 3]], dtype=np.int32)
        # Built explicitly (not via a helper) because the assertions below pin these values.
        metadata = GeffMetadata(geff_version="0.0.1", directed=True)

        # Call write_arrays
        write_arrays(
            geff_store=geff_path,
            node_ids=node_ids,
            node_props=None,
            edge_ids=edge_ids,
            edge_props=None,
            metadata=metadata,
            zarr_format=zarr_format,
        )

        # Verify the zarr group was created
        assert geff_path.exists()

        # Verify node and edge IDs were written correctly
        root = zarr.open_group(str(geff_path))
        assert "nodes/ids" in root
        assert "edges/ids" in root

        # Check the data matches
        np.testing.assert_array_equal(root["nodes/ids"][:], node_ids)
        np.testing.assert_array_equal(root["edges/ids"][:], edge_ids)

        # Check the data types match
        assert root["nodes/ids"].dtype == node_ids.dtype
        assert root["edges/ids"].dtype == edge_ids.dtype

        # Verify metadata was written
        assert "geff" in root.attrs
        assert root.attrs["geff"]["geff_version"] == "0.0.1"
        assert root.attrs["geff"]["directed"] is True

    # TODO: test properties helper. It's covered by networkx tests now, so I'm okay merging,
    # but we should do it when we have time.

    def test_write_in_mem_geff(self) -> None:
        """Read a sample geff into memory, write it to a new store, and validate the result."""
        # Only the store is needed; the second return value is ignored.
        store, _ = create_simple_3d_geff()
        in_mem_geff = read_to_memory(store)

        # Test writing
        new_store = zarr.storage.MemoryStore()
        write_arrays(new_store, **in_mem_geff)

        validate_structure(new_store)

    def test_write_arrays_rejects_disallowed_id_dtype(self, tmp_path: Path) -> None:
        """write_arrays must emit a UserWarning for node/edge ids with unsupported dtype.

        ``pytest.warns`` only passes if the call returns normally, so this asserts a
        warning rather than an exception.
        """
        geff_path = tmp_path / "invalid_ids.geff"

        # float16 is currently not allowed by Java Zarr
        node_ids = np.array([1, 2, 3], dtype=np.float16)
        edge_ids = np.array([[1, 2], [2, 3]], dtype=np.float16)

        with pytest.warns(UserWarning):
            write_arrays(
                geff_store=geff_path,
                node_ids=node_ids,
                node_props=None,
                edge_ids=edge_ids,
                edge_props=None,
                metadata=_tmp_metadata(),
            )

    def test_write_arrays_rejects_disallowed_property_dtype(self, tmp_path: Path) -> None:
        """write_arrays must emit a UserWarning if a property array has an unsupported dtype.

        ``pytest.warns`` only passes if the call returns normally, so this asserts a
        warning rather than an exception.
        """
        geff_path = tmp_path / "invalid_prop.geff"

        # ids are fine (int32)
        node_ids = np.array([1, 2, 3], dtype=np.int32)
        edge_ids = np.array([[1, 2], [2, 3]], dtype=np.int32)

        # property with disallowed dtype (float16)
        bad_prop_values = np.array([0.1, 0.2, 0.3], dtype=np.float16)
        node_props: dict[str, PropDictNpArray] = {
            "score": {"values": bad_prop_values, "missing": None}
        }

        with pytest.warns(UserWarning):
            write_arrays(
                geff_store=geff_path,
                node_ids=node_ids,
                node_props=node_props,
                edge_ids=edge_ids,
                edge_props=None,
                metadata=_tmp_metadata(),
            )
133+
134+
135+
@pytest.mark.parametrize(
    ("data_type", "expected"),
    [
        ("num", ([1, 5, 6], None)),
        ("str", (["category", "", ""], [0, 1, 1])),
        ("num_arr", ([[1, 2], [1, 2], [1, 2]], [1, 1, 0])),
        ("str_arr", ([["test", "string"], ["test", "string"], ["test", "string"]], [1, 0, 1])),
    ],
)
def test_dict_prop_to_arr(dict_data, data_type, expected) -> None:
    """Check ``dict_props_to_arr`` for a single property of the ``dict_data`` fixture.

    Each case pairs a property name with ``(expected values, expected missing mask)``.
    When the expected mask is ``None``, the returned ``missing`` entry is compared
    against ``None`` as well.
    """
    props_dict = dict_props_to_arr(dict_data, [data_type])
    values = props_dict[data_type]["values"]
    missing = props_dict[data_type]["missing"]

    ex_values, ex_missing = expected
    ex_values = np.array(ex_values)
    # Masks are written as 0/1 in the parametrization; compare them as booleans.
    ex_missing = np.array(ex_missing, dtype=bool) if ex_missing is not None else None

    np.testing.assert_array_equal(missing, ex_missing)
    np.testing.assert_array_equal(values, ex_values)
155+
156+
157+
# TODO: test write_dicts. It is fairly well covered by the networkx and write_arrays tests,
# so merging without a dedicated test is acceptable, but add one when time permits.

0 commit comments

Comments
 (0)