Skip to content

Commit ebccc8c

Browse files
committed
Use Set instead of Map for underlying type of Sigma
A minor shortcoming is that indeed the resulting rows are ordered in a less pleasing way: previously, each row can be interpreted as a little-endian number, and the rows are in descending order.
1 parent e98208f commit ebccc8c

File tree

6 files changed

+34
-24
lines changed

6 files changed

+34
-24
lines changed

package.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- mtl
1212
- transformers
1313
- errors
14-
- containers
14+
- containers >= 0.5.11
1515
- array
1616
- reflection
1717
- megaparsec

src/Prob/CoreAST.hs

+9-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ module Prob.CoreAST
99
, Stmt(..)
1010
, Prog(..)
1111
, Sigma
12+
, sigmaInsert
1213
) where
1314

14-
import qualified Data.Map.Strict as M
15+
import qualified Data.Set as Set
1516
import Data.String
1617

1718
-- | The syntax of expressions, parametrized by the representation of variables
@@ -59,5 +60,10 @@ deriving instance Foldable (Prog r)
5960
instance IsString (Expr String) where
6061
fromString = Var
6162

62-
-- | Sigma is just the set of all variables assignments.
63-
type Sigma vt = M.Map vt Bool
63+
-- | Sigma is just the set of all variables assignments. Since our language only
64+
-- ever deals with Bool variables, we use a 'Set.Set' and the presence/absence
65+
-- indicates their values.
66+
type Sigma vt = Set.Set vt
67+
68+
sigmaInsert :: Ord vt => vt -> Bool -> Sigma vt -> Sigma vt
69+
sigmaInsert x v = (if v then Set.insert else Set.delete) x

src/Prob/Den.hs

+7-8
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ module Prob.Den
1515
, runDenStmt
1616
) where
1717

18-
import Control.Error
1918
import Control.Monad.Reader
2019
import Control.Monad.State
2120
import Data.Bifunctor
@@ -30,11 +29,11 @@ import qualified Prob.LinearEq as L
3029
-- Denotational Semantics
3130
--------------------------------------------------------------------------------
3231

33-
allPossibleStates :: (Ord k, Foldable t) => t k -> [M.Map k Bool]
34-
allPossibleStates = foldr (\var -> concatMap (\st -> [M.insert var True st, M.insert var False st])) [M.empty]
32+
allPossibleStates :: (Ord k) => Set.Set k -> [Sigma k]
33+
allPossibleStates = Set.toList . Set.powerSet
3534

3635
denExpr :: (Show vt, Ord vt) => Expr vt -> Sigma vt -> Bool
37-
denExpr (Var x) sigma = fromMaybe (error $ "undefined variable " ++ show x) $ M.lookup x sigma
36+
denExpr (Var x) sigma = Set.member x sigma
3837
denExpr (Constant d) _ = d
3938
denExpr (Or a b) sigma = denExpr a sigma || denExpr b sigma
4039
denExpr (And a b) sigma = denExpr a sigma && denExpr b sigma
@@ -52,10 +51,10 @@ type Den vt = StateT (Maybe (CurrentLoop vt)) (Reader (Sigma vt))
5251

5352
denStmt :: (Show vt, Ord vt) => [Stmt vt] -> Sigma vt -> Den vt (L.RHS (Sigma vt))
5453
denStmt [] sigma = lift $ ReaderT $ \sigma' -> if sigma' == sigma then pure (L.RHS 1 []) else pure (L.RHS 0 [])
55-
denStmt ((x := e):next) sigma = denStmt next (M.insert x (denExpr e sigma) sigma)
54+
denStmt ((x := e):next) sigma = denStmt next (sigmaInsert x (denExpr e sigma) sigma)
5655
denStmt ((x :~ Bernoulli theta):next) sigma = do
57-
dTrue <- denStmt next (M.insert x True sigma)
58-
dFalse <- denStmt next (M.insert x False sigma)
56+
dTrue <- denStmt next (sigmaInsert x True sigma)
57+
dFalse <- denStmt next (sigmaInsert x False sigma)
5958
pure ((theta `mult` dTrue) `plus` ((1 - theta) `mult` dFalse))
6059
where mult :: Rational -> L.RHS x -> L.RHS x
6160
mult k (L.RHS c tms) = L.RHS (k * c) (map (\(L.Term b y) -> L.Term (k*b) y) tms)
@@ -98,7 +97,7 @@ findDenProg :: (Ord vt) => [vt] -> (Set.Set vt -> Sigma vt -> r) -> r
9897
findDenProg p g = g vars initialState
9998
where
10099
vars = Set.fromList p
101-
initialState = M.fromSet (const False) vars
100+
initialState = Set.empty
102101
-- initial state: all variables initialized to False
103102

104103
extractRHS :: L.RHS vt -> Rational

src/Prob/Eval.hs

+6-6
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ import Control.Monad.ST
1515
import Control.Monad.State
1616
import Data.Bifunctor
1717
import Data.List
18-
import qualified Data.Map.Strict as M
1918
import Data.Ratio
19+
import qualified Data.Set as Set
2020
import Prob.CoreAST
2121
import System.Random.MWC
2222
import System.Random.MWC.Distributions
@@ -32,13 +32,13 @@ type ProgState vt s = (Sigma vt, Gen s)
3232
type Eval vt s = MaybeT (StateT (ProgState vt s) (ST s))
3333

3434
runE :: Eval vt s a -> IO (Maybe a)
35-
runE e = withSystemRandom . asGenST $ (\rng -> evalStateT (runMaybeT e) (M.empty, rng))
35+
runE e = withSystemRandom . asGenST $ (\rng -> evalStateT (runMaybeT e) (Set.empty, rng))
3636

3737
runEs :: Int -> Eval vt s a -> IO [a]
38-
runEs t e = withSystemRandom . asGenST $ (\rng -> catMaybes <$> evalStateT (replicateM t (runMaybeT e)) (M.empty, rng))
38+
runEs t e = withSystemRandom . asGenST $ (\rng -> catMaybes <$> evalStateT (replicateM t (runMaybeT e)) (Set.empty, rng))
3939

4040
evalExpr :: (Show vt, Ord vt) => Expr vt -> Eval vt s Bool
41-
evalExpr (Var x) = fromMaybe (error $ "undefined variable " ++ show x) <$> gets (M.lookup x . fst)
41+
evalExpr (Var x) = gets (Set.member x . fst)
4242
evalExpr (Constant d) = pure d
4343
evalExpr (Or a b) = liftA2 (||) (evalExpr a) (evalExpr b)
4444
evalExpr (And a b) = liftA2 (&&) (evalExpr a) (evalExpr b)
@@ -54,11 +54,11 @@ evalStmt :: (Show vt, Ord vt) => [Stmt vt] -> Eval vt s ()
5454
evalStmt [] = pure ()
5555
evalStmt ((x := a):next) = do
5656
v <- evalExpr a
57-
modify (first (M.insert x v))
57+
modify (first (sigmaInsert x v))
5858
evalStmt next
5959
evalStmt ((x :~ d):next) = do
6060
v <- drawDist d
61-
modify (first (M.insert x v))
61+
modify (first (sigmaInsert x v))
6262
evalStmt next
6363
evalStmt (Observe e:next) = do
6464
e' <- evalExpr e

src/Prob/Pretty.hs

+10-5
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ module Prob.Pretty
77
) where
88

99
import Data.Bifunctor
10-
import qualified Data.Map.Strict as M
10+
import Data.Foldable
1111
import Data.Ratio
12+
import qualified Data.Set as Set
1213
import Prob.CoreAST
1314
import Prob.Den (denProg)
1415
import Prob.Eval (sampled)
@@ -18,24 +19,28 @@ data Mode = ModeDen | ModeEval Int
1819
handleProgPretty :: forall vt r. (Show vt, Ord vt) => Prog r vt -> Mode -> IO ShowS
1920
handleProgPretty p m = formatResult <$> r
2021
where
22+
allVars :: Set.Set vt
23+
allVars =
24+
Set.fromList $ case p of Return s e -> concatMap toList s ++ toList e; ReturnAll s -> concatMap toList s
2125
r :: IO [(String, String)]
2226
r =
2327
case p of
2428
Return {} -> map (bimap (`shows` " ") (($ []) . formatRational)) <$> (case m of ModeDen -> pure (denProg p); ModeEval t -> sampled t p)
2529
ReturnAll {} -> map (bimap pprMap (($ []) . formatRational)) <$> (case m of ModeDen -> pure (denProg p); ModeEval t -> sampled t p)
2630
where
2731
pprMap :: Sigma vt -> String
28-
pprMap =
29-
M.foldrWithKey
30-
(\var val s ->
32+
pprMap sigma =
33+
foldr
34+
(\var s ->
3135
shows var .
3236
showString " ->" .
3337
showString
34-
(if val
38+
(if Set.member var sigma
3539
then " true "
3640
else " false ") $
3741
s)
3842
" "
43+
allVars
3944
formatRational rat = shows (numerator rat) . showChar '/' . shows (denominator rat)
4045
formatResult :: [(String, String)] -> ShowS
4146
formatResult [] = showString "No results produced.\n"

stack.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
resolver: lts-11.8
1+
resolver: lts-12.26
22
packages:
33
- .
44
extra-deps: []

0 commit comments

Comments
 (0)