Skip to content

Commit 12ab0c8

Browse files
committed
Propagate check_bounds over model transformations
1 parent 2a253b2 commit 12ab0c8

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

pymc/model/fgraph.py

+2
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def fgraph_from_model(
223223
copy_inputs=True,
224224
)
225225
# Copy model meta-info to fgraph
226+
fgraph.check_bounds = model.check_bounds
226227
fgraph._coords = model._coords.copy()
227228
fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()}
228229

@@ -318,6 +319,7 @@ def first_non_model_var(var):
318319
# TODO: Consider representing/extracting them from the fgraph!
319320
_dim_lengths = {k: memo.get(v, v) for k, v in _dim_lengths.items()}
320321

322+
model.check_bounds = getattr(fgraph, "check_bounds", False)
321323
model._coords = _coords
322324
model._dim_lengths = _dim_lengths
323325

tests/model/test_fgraph.py

+10
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,13 @@ def test_multivariate_transform():
397397
new_ip = new_m.initial_point()
398398
np.testing.assert_allclose(ip["x_simplex__"], new_ip["x_simplex__"])
399399
np.testing.assert_allclose(ip["y_cholesky-cov-packed__"], new_ip["y_cholesky-cov-packed__"])
400+
401+
402+
def test_check_bounds_preserved():
403+
with pm.Model(check_bounds=True) as m:
404+
x = pm.HalfNormal("x")
405+
406+
assert clone_model(m).check_bounds
407+
408+
m.check_bounds = False
409+
assert not clone_model(m).check_bounds

0 commit comments

Comments
 (0)