File tree 2 files changed +19
-9
lines changed
2 files changed +19
-9
lines changed Original file line number Diff line number Diff line change @@ -569,15 +569,16 @@ def logp_dlogp_function(
569
569
for var in self .value_vars
570
570
if var in input_vars and var not in grad_vars
571
571
}
572
- return ValueGradFunction (
573
- costs ,
574
- grad_vars ,
575
- extra_vars_and_values ,
576
- model = self ,
577
- initial_point = initial_point ,
578
- ravel_inputs = ravel_inputs ,
579
- ** kwargs ,
580
- )
572
+ with self :
573
+ return ValueGradFunction (
574
+ costs ,
575
+ grad_vars ,
576
+ extra_vars_and_values ,
577
+ model = self ,
578
+ initial_point = initial_point ,
579
+ ravel_inputs = ravel_inputs ,
580
+ ** kwargs ,
581
+ )
581
582
582
583
def compile_logp (
583
584
self ,
Original file line number Diff line number Diff line change @@ -443,6 +443,15 @@ def test_missing_data(self):
443
443
# Assert that all the elements of res are equal
444
444
assert res [1 :] == res [:- 1 ]
445
445
446
+ def test_check_bounds_out_of_model_context (self ):
447
+ with pm .Model (check_bounds = False ) as m :
448
+ x = pm .Normal ("x" )
449
+ y = pm .Normal ("y" , sigma = x )
450
+ fn = m .logp_dlogp_function (ravel_inputs = True )
451
+ fn .set_extra_values ({})
452
+ # When there are no bounds check logp turns into `nan`
453
+ assert np .isnan (fn (np .array ([- 1.0 , - 1.0 ]))[0 ])
454
+
446
455
447
456
class TestPytensorRelatedLogpBugs :
448
457
def test_pytensor_switch_broadcast_edge_cases_1 (self ):
You can’t perform that action at this time.
0 commit comments