Skip to content

Commit 31a8608

Browse files
committed
Use model context in logp_dlogp_function to respect check_bounds
1 parent 12ab0c8 commit 31a8608

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

pymc/model/core.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -569,15 +569,16 @@ def logp_dlogp_function(
569569
for var in self.value_vars
570570
if var in input_vars and var not in grad_vars
571571
}
572-
return ValueGradFunction(
573-
costs,
574-
grad_vars,
575-
extra_vars_and_values,
576-
model=self,
577-
initial_point=initial_point,
578-
ravel_inputs=ravel_inputs,
579-
**kwargs,
580-
)
572+
with self:
573+
return ValueGradFunction(
574+
costs,
575+
grad_vars,
576+
extra_vars_and_values,
577+
model=self,
578+
initial_point=initial_point,
579+
ravel_inputs=ravel_inputs,
580+
**kwargs,
581+
)
581582

582583
def compile_logp(
583584
self,

tests/model/test_core.py

+9
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,15 @@ def test_missing_data(self):
443443
# Assert that all the elements of res are equal
444444
assert res[1:] == res[:-1]
445445

446+
def test_check_bounds_out_of_model_context(self):
447+
with pm.Model(check_bounds=False) as m:
448+
x = pm.Normal("x")
449+
y = pm.Normal("y", sigma=x)
450+
fn = m.logp_dlogp_function(ravel_inputs=True)
451+
fn.set_extra_values({})
452+
# When there are no bounds check logp turns into `nan`
453+
assert np.isnan(fn(np.array([-1.0, -1.0]))[0])
454+
446455

447456
class TestPytensorRelatedLogpBugs:
448457
def test_pytensor_switch_broadcast_edge_cases_1(self):

0 commit comments

Comments
 (0)