Skip to content

Commit ad6b81d

Browse files
authored
Add grid.rebuild_graphs() (#177)
* Add grid.rebuild_graphs() Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> * Update test Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> * remove internal .from_arrays usage Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> * fix tests Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> * ._from_grid -> .from_grid Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> * update docstrings Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> * format Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> --------- Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com>
1 parent bbb63f6 commit ad6b81d

8 files changed

Lines changed: 57 additions & 43 deletions

File tree

src/power_grid_model_ds/_core/model/graphs/container.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""Stores the GraphContainer class"""
66

77
import dataclasses
8+
import warnings
89
from dataclasses import dataclass
910
from typing import TYPE_CHECKING, Generator
1011

@@ -36,7 +37,7 @@ def __repr__(self) -> str:
3637
return f"{self.__class__.__name__}({', '.join(graph_infos)})"
3738

3839
@property
39-
def graph_attributes(self) -> Generator:
40+
def graph_attributes(self) -> Generator[dataclasses.Field, None, None]:
4041
"""Get all graph attributes of the container.
4142
4243
Yield:
@@ -119,13 +120,24 @@ def make_inactive(self, branch: BranchArray) -> None:
119120

120121
@classmethod
121122
def from_arrays(cls, arrays: "Grid") -> "GraphContainer":
122-
"""Build from arrays"""
123-
cls._validate_branches(arrays=arrays)
123+
"""Build from arrays. DEPRECATED: Use .from_grid instead."""
124+
warnings.warn(
125+
f"{cls.__name__}.from_arrays is deprecated and will be removed in a future release. "
126+
f"Use grid.rebuild_graphs() or {cls.__name__}.from_grid(grid) instead.",
127+
DeprecationWarning,
128+
stacklevel=2,
129+
)
130+
return cls.from_grid(arrays)
131+
132+
@classmethod
133+
def from_grid(cls, grid: "Grid") -> "GraphContainer":
134+
"""Build from grid"""
135+
cls._validate_branches(arrays=grid)
124136

125137
new_container = cls.empty()
126138
for graph_field in new_container.graph_attributes:
127-
graph = getattr(new_container, graph_field.name)
128-
new_graph = graph.from_arrays(arrays, active_only=graph.active_only)
139+
graph: BaseGraphModel = getattr(new_container, graph_field.name)
140+
new_graph = graph.from_grid(grid, active_only=graph.active_only)
129141
setattr(new_container, graph_field.name, new_graph)
130142

131143
return new_container

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Contributors to the Power Grid Model project <powergridmodel@lfenergy.org>
22
#
33
# SPDX-License-Identifier: MPL-2.0
4-
4+
import warnings
55
from abc import ABC, abstractmethod
66
from contextlib import contextmanager
77
from typing import TYPE_CHECKING, Counter, Generator
@@ -348,13 +348,22 @@ def find_fundamental_cycles(self) -> list[list[int]]:
348348

349349
@classmethod
350350
def from_arrays(cls, arrays: "Grid", active_only=False) -> "BaseGraphModel":
351-
"""Build from arrays"""
352-
new_graph = cls(active_only=active_only)
353-
354-
new_graph.add_node_array(node_array=arrays.node, raise_on_fail=False)
355-
new_graph.add_branch_array(arrays.branches)
356-
new_graph.add_branch3_array(arrays.three_winding_transformer)
351+
"""Build from arrays. DEPRECATED: Use .from_grid instead."""
352+
warnings.warn(
353+
f"{cls.__name__}.from_arrays is deprecated and will be removed in a future release. "
354+
f"Use {cls.__name__}.from_grid instead.",
355+
DeprecationWarning,
356+
stacklevel=2,
357+
)
358+
return cls.from_grid(arrays, active_only=active_only)
357359

360+
@classmethod
361+
def from_grid(cls, grid: "Grid", active_only=False) -> "BaseGraphModel":
362+
"""Build from grid."""
363+
new_graph = cls(active_only=active_only)
364+
new_graph.add_node_array(node_array=grid.node, raise_on_fail=False)
365+
new_graph.add_branch_array(grid.branches)
366+
new_graph.add_branch3_array(grid.three_winding_transformer)
358367
return new_graph
359368

360369
def _internals_to_externals(self, internal_nodes: list[int]) -> list[int]:

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ def deserialize(cls: Type[Self], path: Path) -> Self:
461461
"""Deserialize the grid."""
462462
return deserialize_from_json(path=path, target_grid_class=cls)
463463

464+
def rebuild_graphs(self) -> None:
465+
"""(Re)build the graphs in the grid."""
466+
self.graphs = GraphContainer.from_grid(self)
467+
464468
def diff(self, other_grid: Self) -> None:
465469
"""Print the differences between two grids
466470

src/power_grid_model_ds/_core/model/grids/serialization/json.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ def deserialize_from_json(path: Path, target_grid_class: type[G]) -> G:
7676

7777
grid = target_grid_class.empty()
7878
_restore_grid_values(grid, json_data["data"])
79-
graph_class = grid.graphs.__class__
80-
grid.graphs = graph_class.from_arrays(grid)
79+
grid.rebuild_graphs()
8180
return grid
8281

8382

src/power_grid_model_ds/_core/model/grids/serialization/pickle.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pathlib import Path
77
from typing import TYPE_CHECKING
88

9-
from power_grid_model_ds._core.model.graphs.container import GraphContainer
109
from power_grid_model_ds._core.utils.pickle import get_pickle_path, load_from_pickle, save_to_pickle
1110
from power_grid_model_ds._core.utils.zip import file2gzip
1211

@@ -22,7 +21,7 @@ def load_grid_from_pickle(grid_class: type["Grid"], cache_path: Path, load_graph
2221
raise TypeError(f"{pickle_path.name} is not a valid {grid_class.__name__} cache.")
2322

2423
if load_graphs:
25-
grid.graphs = GraphContainer.from_arrays(grid)
24+
grid.rebuild_graphs()
2625
return grid
2726

2827

tests/performance/graph_performance_tests.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def perftest_delete_node():
3333
do_performance_test(code_to_test, GRAPH_SIZES, 100, setup_codes=GRAPH_SETUP_CODES)
3434

3535

36-
def perftest_from_arrays():
37-
code_to_test = "grid.graphs.complete_graph.__class__.from_arrays(grid);"
36+
def perftest_rebuild_graphs():
37+
code_to_test = "grid.rebuild_graphs()"
3838
do_performance_test(code_to_test, GRAPH_SIZES, 100, setup_codes=GRAPH_SETUP_CODES)
3939

4040

@@ -53,4 +53,4 @@ def perftest_add_node():
5353
perftest_get_components()
5454
perftest_delete_node()
5555
perftest_add_node()
56-
perftest_from_arrays()
56+
perftest_rebuild_graphs()

tests/unit/model/graphs/test_container.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Contributors to the Power Grid Model project <powergridmodel@lfenergy.org>
22
#
33
# SPDX-License-Identifier: MPL-2.0
4+
from copy import deepcopy
45

56
import pytest
67

@@ -12,18 +13,13 @@
1213
# pylint: disable=missing-function-docstring
1314

1415

15-
def test_from_arrays(basic_grid: Grid):
16-
graphs = GraphContainer.from_arrays(basic_grid)
16+
def test_rebuild_graphs(basic_grid: Grid):
17+
orig_graphs = deepcopy(basic_grid.graphs)
1718

18-
assert isinstance(graphs, GraphContainer)
19-
assert basic_grid.graphs.complete_graph.nr_nodes == graphs.complete_graph.nr_nodes
20-
assert basic_grid.graphs.complete_graph.nr_branches == 6
21-
22-
assert basic_grid.graphs.active_graph.nr_nodes == graphs.active_graph.nr_nodes
23-
assert basic_grid.graphs.active_graph.nr_branches == 5
24-
25-
assert set(basic_grid.node.id) == set(graphs.active_graph.external_ids)
26-
assert set(basic_grid.node.id) == set(graphs.complete_graph.external_ids)
19+
basic_grid.graphs = GraphContainer.empty()
20+
assert basic_grid.graphs != orig_graphs
21+
basic_grid.rebuild_graphs()
22+
assert orig_graphs == basic_grid.graphs
2723

2824

2925
@pytest.fixture
@@ -81,7 +77,7 @@ def test_delete_branch3(
8177
assert not graph_container_with_5_nodes.complete_graph.has_branch(from_node, to_node)
8278

8379

84-
def test_from_arrays_active_three_winding(basic_grid: Grid):
80+
def test_rebuild_graphs_active_three_winding(basic_grid: Grid):
8581
nodes = NodeArray.zeros(3)
8682
nodes.id = [1000, 1001, 1002]
8783
basic_grid.append(nodes)
@@ -95,15 +91,16 @@ def test_from_arrays_active_three_winding(basic_grid: Grid):
9591
three_winding_transformer.status_3 = 1
9692
basic_grid.append(three_winding_transformer)
9793

98-
graphs = GraphContainer.from_arrays(basic_grid)
94+
basic_grid.rebuild_graphs()
95+
graphs = basic_grid.graphs
9996
assert basic_grid.graphs.complete_graph.nr_nodes == graphs.complete_graph.nr_nodes
10097
assert basic_grid.graphs.complete_graph.nr_branches == 6 + 3
10198

10299
assert basic_grid.graphs.active_graph.nr_nodes == graphs.active_graph.nr_nodes
103100
assert basic_grid.graphs.active_graph.nr_branches == 5 + 3
104101

105102

106-
def test_from_arrays_partially_active_three_winding(basic_grid: Grid):
103+
def test_rebuild_graphs_partially_active_three_winding(basic_grid: Grid):
107104
nodes = NodeArray.zeros(3)
108105
nodes.id = [1000, 1001, 1002]
109106
basic_grid.append(nodes)
@@ -117,7 +114,8 @@ def test_from_arrays_partially_active_three_winding(basic_grid: Grid):
117114
three_winding_transformer.status_3 = 0
118115
basic_grid.append(three_winding_transformer)
119116

120-
graphs = GraphContainer.from_arrays(basic_grid)
117+
basic_grid.rebuild_graphs()
118+
graphs = basic_grid.graphs
121119
assert basic_grid.graphs.complete_graph.nr_nodes == graphs.complete_graph.nr_nodes
122120
assert basic_grid.graphs.complete_graph.nr_branches == 6 + 3
123121

@@ -132,8 +130,8 @@ def test_from_arrays_partially_active_three_winding(basic_grid: Grid):
132130
assert not basic_grid.graphs.active_graph.has_branch(1001, 1002)
133131

134132

135-
def test_from_arrays_invalid_arrays(basic_grid: Grid):
133+
def test_rebuild_graphs_invalid_arrays(basic_grid: Grid):
136134
basic_grid.node = basic_grid.node.exclude(id=106)
137135

138136
with pytest.raises(RecordDoesNotExist):
139-
GraphContainer.from_arrays(basic_grid)
137+
basic_grid.rebuild_graphs()

tests/unit/model/graphs/test_graph_model.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99

1010
import numpy as np
1111
import pytest
12-
from numpy.testing import assert_array_equal
1312

1413
from power_grid_model_ds._core.model.graphs.errors import GraphError
1514
from power_grid_model_ds._core.model.graphs.models.base import BaseGraphModel
16-
from power_grid_model_ds._core.model.grids.base import Grid
1715
from power_grid_model_ds.errors import MissingBranchError, MissingNodeError, NoPathBetweenNodes
1816

1917
# pylint: disable=missing-function-docstring,missing-class-docstring
@@ -208,11 +206,6 @@ def test_get_components_with_tmp_removed_substation_nodes(graph_with_2_routes):
208206
assert set(components[2]) == {99}
209207

210208

211-
def test_from_arrays(basic_grid: Grid):
212-
new_graph = basic_grid.graphs.complete_graph.__class__.from_arrays(basic_grid)
213-
assert_array_equal(new_graph.external_ids, basic_grid.node.id)
214-
215-
216209
class TestPathMethods:
217210
def test_get_shortest_path(self, graph_with_2_routes: BaseGraphModel):
218211
graph = graph_with_2_routes

0 commit comments

Comments
 (0)