|
2 | 2 | {-# LANGUAGE RankNTypes #-}
|
3 | 3 | {-# LANGUAGE RecordWildCards #-}
|
4 | 4 | {-# OPTIONS_GHC -Wno-deprecations #-}
|
| 5 | +{-# LANGUAGE KindSignatures #-} |
| 6 | +{-# LANGUAGE GeneralizedNewtypeDeriving #-} |
5 | 7 |
|
6 | 8 | -- |
|
7 | 9 | -- Module : Control.Monad.Bayes.Class
|
@@ -105,6 +107,7 @@ import Statistics.Distribution.Geometric (geometric0)
|
105 | 107 | import Statistics.Distribution.Normal (normalDistr)
|
106 | 108 | import Statistics.Distribution.Poisson qualified as Poisson
|
107 | 109 | import Statistics.Distribution.Uniform (uniformDistr)
|
| 110 | +import Control.Monad.Trans (MonadTrans) |
108 | 111 |
|
109 | 112 | -- | Monads that can draw random variables.
|
110 | 113 | class Monad m => MonadDistribution m where
|
@@ -407,3 +410,25 @@ instance MonadFactor m => MonadFactor (ContT r m) where
|
407 | 410 | score = lift . score
|
408 | 411 |
|
409 | 412 | instance MonadMeasure m => MonadMeasure (ContT r m)
|
| 413 | + |
| 414 | +-- * Utility for deriving MonadDistribution, MonadFactor and MonadMeasure |
| 415 | + |
| 416 | +newtype MonadMeasureTrans (t :: (* -> *) -> * -> *) (m :: * -> *) a = MonadMeasureTrans { getMonadMeasureTrans :: t m a } |
| 417 | + deriving (Functor, Applicative, Monad) |
| 418 | + |
| 419 | +instance MonadTrans t => MonadTrans (MonadMeasureTrans t) where |
| 420 | + lift = MonadMeasureTrans . lift |
| 421 | + |
| 422 | +instance (MonadTrans t, MonadDistribution m, Monad (t m)) => MonadDistribution (MonadMeasureTrans t m) where |
| 423 | + random = lift random |
| 424 | + uniform = (lift .) . uniform |
| 425 | + normal = (lift .) . normal |
| 426 | + gamma = (lift .) . gamma |
| 427 | + beta = (lift .) . beta |
| 428 | + bernoulli = lift . bernoulli |
| 429 | + categorical = lift . categorical |
| 430 | + logCategorical = lift . logCategorical |
| 431 | + uniformD = lift . uniformD |
| 432 | + geometric = lift . geometric |
| 433 | + poisson = lift . poisson |
| 434 | + dirichlet = lift . dirichlet |
0 commit comments