Skip to content

Commit dee5c23

Browse files
committed
WIP
1 parent 840f8bf commit dee5c23

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

src/Control/Monad/Bayes/Class.hs

+25
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
{-# LANGUAGE RankNTypes #-}
33
{-# LANGUAGE RecordWildCards #-}
44
{-# OPTIONS_GHC -Wno-deprecations #-}
5+
{-# LANGUAGE KindSignatures #-}
6+
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
57

68
-- |
79
-- Module : Control.Monad.Bayes.Class
@@ -105,6 +107,7 @@ import Statistics.Distribution.Geometric (geometric0)
105107
import Statistics.Distribution.Normal (normalDistr)
106108
import Statistics.Distribution.Poisson qualified as Poisson
107109
import Statistics.Distribution.Uniform (uniformDistr)
110+
import Control.Monad.Trans (MonadTrans)
108111

109112
-- | Monads that can draw random variables.
110113
class Monad m => MonadDistribution m where
@@ -407,3 +410,25 @@ instance MonadFactor m => MonadFactor (ContT r m) where
407410
score = lift . score
408411

409412
instance MonadMeasure m => MonadMeasure (ContT r m)
413+
414+
-- * Utility for deriving MonadDistribution, MonadFactor and MonadMeasure
415+
416+
newtype MonadMeasureTrans (t :: (* -> *) -> * -> *) (m :: * -> *) a = MonadMeasureTrans { getMonadMeasureTrans :: t m a }
417+
deriving (Functor, Applicative, Monad)
418+
419+
instance MonadTrans t => MonadTrans (MonadMeasureTrans t) where
420+
lift = MonadMeasureTrans . lift
421+
422+
instance (MonadTrans t, MonadDistribution m, Monad (t m)) => MonadDistribution (MonadMeasureTrans t m) where
423+
random = lift random
424+
uniform = (lift .) . uniform
425+
normal = (lift .) . normal
426+
gamma = (lift .) . gamma
427+
beta = (lift .) . beta
428+
bernoulli = lift . bernoulli
429+
categorical = lift . categorical
430+
logCategorical = lift . logCategorical
431+
uniformD = lift . uniformD
432+
geometric = lift . geometric
433+
poisson = lift . poisson
434+
dirichlet = lift . dirichlet

0 commit comments

Comments
 (0)