Skip to content

Commit d582111

Browse files
committed
add serialization tests for Sequential and Residual
1 parent 0b09487 commit d582111

File tree

6 files changed

+78
-0
lines changed

6 files changed

+78
-0
lines changed

tests/test_networks/test_residual/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from bayesflow.networks.residual import Residual
4+
5+
6+
@pytest.fixture()
7+
def residual():
8+
import keras
9+
10+
return Residual(keras.layers.Flatten(), keras.layers.Dense(2))
11+
12+
13+
@pytest.fixture()
14+
def build_shapes():
15+
return {"input_shape": (32, 2)}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import keras
2+
3+
from bayesflow.utils.serialization import deserialize, serialize
4+
5+
from ...utils import assert_layers_equal
6+
7+
8+
def test_serialize_deserialize(residual, build_shapes):
9+
residual.build(**build_shapes)
10+
11+
serialized = serialize(residual)
12+
deserialized = deserialize(serialized)
13+
reserialized = serialize(deserialized)
14+
15+
assert reserialized == serialized
16+
17+
18+
def test_save_and_load(tmp_path, residual, build_shapes):
19+
residual.build(**build_shapes)
20+
21+
keras.saving.save_model(residual, tmp_path / "model.keras")
22+
loaded = keras.saving.load_model(tmp_path / "model.keras")
23+
24+
assert_layers_equal(residual, loaded)

tests/test_networks/test_sequential/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from bayesflow.networks import Sequential
4+
5+
6+
@pytest.fixture()
7+
def sequential():
8+
import keras
9+
10+
return Sequential(keras.layers.Flatten(), keras.layers.Dense(2))
11+
12+
13+
@pytest.fixture()
14+
def build_shapes():
15+
return {"input_shape": (32, 2)}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import keras
2+
3+
from bayesflow.utils.serialization import deserialize, serialize
4+
5+
from ...utils import assert_layers_equal
6+
7+
8+
def test_serialize_deserialize(sequential, build_shapes):
9+
sequential.build(**build_shapes)
10+
11+
serialized = serialize(sequential)
12+
deserialized = deserialize(serialized)
13+
reserialized = serialize(deserialized)
14+
15+
assert reserialized == serialized
16+
17+
18+
def test_save_and_load(tmp_path, sequential, build_shapes):
19+
sequential.build(**build_shapes)
20+
21+
keras.saving.save_model(sequential, tmp_path / "model.keras")
22+
loaded = keras.saving.load_model(tmp_path / "model.keras")
23+
24+
assert_layers_equal(sequential, loaded)

0 commit comments

Comments
 (0)