@@ -628,7 +628,9 @@ def sample_prior_predictive(
628
628
629
629
return prior_predictive_samples
630
630
631
- def sample_posterior_predictive (self , X_pred , extend_idata , combined , predictions = True , ** kwargs ):
631
+ def sample_posterior_predictive (
632
+ self , X_pred , extend_idata , combined , predictions = True , ** kwargs
633
+ ):
632
634
"""
633
635
Sample from the model's posterior predictive distribution.
634
636
@@ -652,15 +654,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, prediction
652
654
self ._data_setter (X_pred )
653
655
654
656
with self .model : # sample with new input data
655
- post_pred = pm .sample_posterior_predictive (self .idata , predictions = predictions , ** kwargs )
657
+ post_pred = pm .sample_posterior_predictive (
658
+ self .idata , predictions = predictions , ** kwargs
659
+ )
656
660
if extend_idata :
657
661
self .idata .extend (post_pred , join = "right" )
658
662
659
663
group_name = "predictions" if predictions else "posterior_predictive"
660
664
661
- posterior_predictive_samples = az .extract (
662
- post_pred , group_name , combined = combined
663
- )
665
+ posterior_predictive_samples = az .extract (post_pred , group_name , combined = combined )
664
666
665
667
return posterior_predictive_samples
666
668
0 commit comments