Skip to content

Commit 40ba2aa

Browse files
authored
Merge pull request #6 from mheinzel/bitwise-and-boolean-operators
Improvements for bitwise, logical and comparison operators
2 parents e78c2e2 + a67cf34 commit 40ba2aa

File tree

10 files changed

+302
-146
lines changed

10 files changed

+302
-146
lines changed

src/Hython/Expression.hs

Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -82,50 +82,30 @@ evalExpr (BinOp (BitOp op) leftExpr rightExpr) = do
8282
return None
8383

8484
evalExpr (BinOp (BoolOp op) leftExpr rightExpr) = do
85-
[lhs, rhs] <- mapM evalExpr [leftExpr, rightExpr]
86-
case (op, lhs, rhs) of
87-
(And, Bool l, Bool r) -> newBool (l && r)
88-
(And, l, r) -> do
89-
left <- isTruthy l
90-
right <- isTruthy r
91-
return $ if left && right
92-
then r
93-
else l
94-
(Or, Bool l, Bool r) -> newBool (l || r)
95-
(Or, l, r) -> do
96-
left <- isTruthy l
97-
return $ if left
98-
then l
99-
else r
100-
101-
evalExpr (BinOp (CompOp op) leftExpr rightExpr) = do
102-
[lhs, rhs] <- mapM evalExpr [leftExpr, rightExpr]
103-
case (op, lhs, rhs) of
104-
(Eq, l, r) -> newBool =<< equal l r
105-
(NotEq, l, r) -> do
106-
b <- equal l r
107-
newBool $ not b
108-
(LessThan, Float l, Float r) -> newBool (l < r)
109-
(LessThan, Int l, Int r) -> newBool (l < r)
110-
(LessThanEq, Float l, Float r) -> newBool (l <= r)
111-
(LessThanEq, Int l, Int r) -> newBool (l <= r)
112-
(GreaterThan, Float l, Float r) -> newBool (l > r)
113-
(GreaterThan, Int l, Int r) -> newBool (l > r)
114-
(GreaterThanEq, Int l, Int r) -> newBool (l >= r)
115-
(GreaterThanEq, Float l, Float r) -> newBool (l >= r)
116-
(Is, l, r) -> newBool $ l `is` r
117-
(IsNot, l, r) -> newBool . not $ l `is` r
118-
(_, Float _, Int r) -> evalExpr (BinOp (CompOp op) leftExpr (constantF r))
119-
(_, Int l, Float _) -> evalExpr (BinOp (CompOp op) (constantF l) rightExpr)
120-
(In, l, r@(Object {})) -> invoke r "__contains__" [l]
121-
(NotIn, _, _) -> do
122-
(Bool b) <- evalExpr (BinOp (CompOp In) leftExpr rightExpr)
123-
newBool (not b)
124-
_ -> do
125-
raise "SystemError" ("unsupported operand type " ++ show op)
126-
return None
85+
lhs <- evalExpr leftExpr
86+
lTruthy <- isTruthy lhs
87+
case (op, lTruthy) of
88+
(And, False) -> return lhs
89+
(And, True) -> evalExpr rightExpr
90+
(Or, False) -> evalExpr rightExpr
91+
(Or, True) -> return lhs
92+
93+
evalExpr (BinOp (CompOp operator) leftExpression rightExpression) = do
94+
lhs <- evalExpr leftExpression
95+
go operator lhs rightExpression
12796
where
128-
constantF i = Constant $ ConstantFloat $ fromIntegral i
97+
go :: (MonadCont m, MonadEnv Object m, MonadIO m, MonadInterpreter m)
98+
=> ComparisonOperator -> Object -> Expression -> m Object
99+
-- for chained comparison, e.g. 1 < 2 < 3
100+
go op lhs (BinOp (CompOp rop) rl rr) = do
101+
rlhs <- evalExpr rl
102+
comp <- compareObjs op lhs rlhs
103+
case comp of
104+
Bool True -> go rop rlhs rr
105+
_ -> newBool False
106+
go op lhs rightExpr = do
107+
rhs <- evalExpr rightExpr
108+
compareObjs op lhs rhs
129109

130110
evalExpr (Call expr argExprs) = do
131111
target <- evalExpr expr
@@ -266,6 +246,44 @@ evalParam (DefaultParam param expr) = do
266246
evalParam (SplatParam param) = return $ SParam param
267247
evalParam (DoubleSplatParam param) = return $ DSParam param
268248

249+
250+
compareObjs :: MonadInterpreter m => ComparisonOperator -> Object -> Object -> m Object
251+
compareObjs op lhs rhs = case (op, lhs, rhs) of
252+
-- equality
253+
(Eq, l, r) -> newBool =<< equal l r
254+
(NotEq, l, r) -> newBool . not =<< equal l r
255+
-- "is", "is not", "in", "not in"
256+
(Is, l, r) -> newBool $ l `is` r
257+
(IsNot, l, r) -> newBool . not $ l `is` r
258+
(In, l, r) -> invoke r "__contains__" [l]
259+
(NotIn, l, r) -> do
260+
Bool b <- invoke r "__contains__" [l]
261+
newBool (not b)
262+
-- conversion
263+
(_, Bool l, _) -> compareObjs op (Float (boolToFloat l)) rhs
264+
(_, Int l, _) -> compareObjs op (Float (fromIntegral l)) rhs
265+
(_, _, Bool r) -> compareObjs op lhs (Float (boolToFloat r))
266+
(_, _, Int r) -> compareObjs op lhs (Float (fromIntegral r))
267+
-- comparison
268+
(_, Bytes l, Bytes r) -> applyOp op l r
269+
(_, Float l, Float r) -> applyOp op l r
270+
(_, String l, String r) -> applyOp op l r
271+
-- error cases
272+
_ -> do
273+
raise "TypeError" "unorderable types"
274+
return None
275+
where
276+
boolToFloat True = 1
277+
boolToFloat False = 0
278+
applyOp :: (Ord a, MonadInterpreter m) => ComparisonOperator -> a -> a -> m Object
279+
applyOp c a b = newBool $ compOp c a b
280+
compOp LessThan = (<)
281+
compOp LessThanEq = (<=)
282+
compOp GreaterThan = (>)
283+
compOp GreaterThanEq = (>=)
284+
compOp _ = error "compOp: Eq, NotEq, Is, IsNot, In, NotIn should be handled"
285+
286+
269287
is :: Object -> Object -> Bool
270288
is None None = True
271289
is (Bool l) (Bool r) = l == r

src/Hython/Types.hs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -223,27 +223,34 @@ isNone (None) = True
223223
isNone _ = False
224224

225225
equal :: MonadInterpreter m => Object -> Object -> m Bool
226-
equal (Bool l) (Bool r) = return $ l == r
227-
equal (Bytes l) (Bytes r) = return $ l == r
228-
equal (Float l) (Float r) = return $ l == r
229-
equal (Imaginary l) (Imaginary r) = return $ l == r
230-
equal (Int l) (Int r) = return $ l == r
231-
equal (String l) (String r) = return $ l == r
232-
equal (BuiltinFn l) (BuiltinFn r) = return $ l == r
233-
equal (Dict l) (Dict r) = do
234-
left <- readRef l
235-
right <- readRef r
236-
if IntMap.size left /= IntMap.size right
237-
then return False
238-
else do
239-
results <- zipWithM pairEqual (IntMap.elems left) (IntMap.elems right)
240-
return $ all (== True) results
226+
equal lhs rhs = case (lhs, rhs) of
227+
-- conversion
228+
(Bool l, _) -> equal (Float (boolToFloat l)) rhs
229+
(Int l, _) -> equal (Float (fromIntegral l)) rhs
230+
(_, Bool r) -> equal lhs (Float (boolToFloat r))
231+
(_, Int r) -> equal lhs (Float (fromIntegral r))
232+
(None, None ) -> return True
233+
(Bytes l, Bytes r) -> return $ l == r
234+
(Float l, Float r) -> return $ l == r
235+
(Imaginary l, Imaginary r) -> return $ l == r
236+
(String l, String r) -> return $ l == r
237+
(BuiltinFn l, BuiltinFn r) -> return $ l == r
238+
(Dict l, Dict r) -> do
239+
left <- readRef l
240+
right <- readRef r
241+
if IntMap.size left /= IntMap.size right
242+
then return False
243+
else do
244+
results <- zipWithM pairEqual (IntMap.elems left) (IntMap.elems right)
245+
return $ all (== True) results
246+
(_, _) -> return False
241247
where
248+
boolToFloat False = 0.0
249+
boolToFloat True = 1.0
242250
pairEqual (lk, lv) (rk, rv) = do
243251
k <- equal lk rk
244252
v <- equal lv rv
245253
return $ k && v
246-
equal _ _ = return False
247254

248255
isTruthy :: MonadInterpreter m => Object -> m Bool
249256
isTruthy (None) = return False

src/Language/Python/Parser.y

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -458,28 +458,25 @@ lambdef_nocond
458458
: LAMBDA varargslist ':' test_nocond { LambdaExpr $2 $4 }
459459

460460
-- or_test: and_test ('or' and_test)*
461-
-- TODO: implement 0-n clauses
462461
or_test
463462
: and_test { $1 }
464-
| or_test OR and_test { BinOp (BoolOp Or) $1 $3 }
463+
| or_test OR and_test { BinOp (BoolOp Or) $1 $3 }
465464

466465
-- and_test: not_test ('and' not_test)*
467-
-- TODO: implement 0-n clauses
468466
and_test
469467
: not_test { $1 }
470468
| and_test AND not_test { BinOp (BoolOp And) $1 $3 }
471469

472470
-- not_test: 'not' not_test | comparison
473-
-- TODO: implement 0-n clauses
474471
not_test
475472
: NOT not_test { UnaryOp Not $2 }
476473
| comparison { $1 }
477474

478475
-- comparison: expr (comp_op expr)*
479-
-- TODO: implement 0-n clauses
476+
-- NOTE: right-recursive, because of 1 < 2 < 3 etc.
480477
comparison
481-
: expr { $1 }
482-
| expr comp_op expr { BinOp (CompOp $2) $1 $3 }
478+
: expr { $1 }
479+
| expr comp_op comparison { BinOp (CompOp $2) $1 $3 }
483480

484481
-- comp_op: '<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not'
485482
comp_op
@@ -500,29 +497,25 @@ star_expr
500497
: '*' expr { undefined }
501498

502499
-- expr: xor_expr ('|' xor_expr)*
503-
-- TODO: implement 0-n handling
504500
expr
505501
: xor_expr { $1 }
506-
| xor_expr '|' xor_expr { BinOp (BitOp BitOr) $1 $3 }
502+
| expr '|' xor_expr { BinOp (BitOp BitOr) $1 $3 }
507503

508504
-- xor_expr: and_expr ('^' and_expr)*
509-
-- TODO: implement 0-n handling
510505
xor_expr
511506
: and_expr { $1 }
512-
| and_expr '^' and_expr { BinOp (BitOp BitXor) $1 $3 }
507+
| xor_expr '^' and_expr { BinOp (BitOp BitXor) $1 $3 }
513508

514509
-- and_expr: shift_expr ('&' shift_expr)*
515-
-- TODO: implement 0-n handling
516510
and_expr
517511
: shift_expr { $1 }
518-
| shift_expr '&' shift_expr { BinOp (BitOp BitAnd) $1 $3 }
512+
| and_expr '&' shift_expr { BinOp (BitOp BitAnd) $1 $3 }
519513

520514
-- shift_expr: arith_expr (('<<'|'>>') arith_expr)*
521-
-- TODO: implement 0-n handling
522515
shift_expr
523516
: arith_expr { $1 }
524-
| arith_expr '<<' arith_expr { BinOp (BitOp LShift) $1 $3 }
525-
| arith_expr '>>' arith_expr { BinOp (BitOp RShift) $1 $3 }
517+
| shift_expr '<<' arith_expr { BinOp (BitOp LShift) $1 $3 }
518+
| shift_expr '>>' arith_expr { BinOp (BitOp RShift) $1 $3 }
526519

527520
-- arith_expr: term (('+'|'-') term)*
528521
arith_expr

test/operators/arithmetic.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Integral arithmetic
2+
print(+5)
3+
print(-5)
4+
5+
print(1+1)
6+
print(3-5)
7+
print(2*4)
8+
print(3/5)
9+
print(5%2)
10+
print(5//2)
11+
print(5**3)
12+
print(0**0)
13+
14+
# Floating point arithmetic
15+
print(+5.0)
16+
print(-5.0)
17+
print(1.0 + 2.0)
18+
print(2.0 - 0.5)
19+
print(1.0 * 1.6)
20+
print(1.0 / 4.8)
21+
print(5.0 % 2.0)
22+
print(5.0 // 2.0)
23+
print(5.0 ** 3.4)
24+
print(0.0 ** 0.0)
25+
26+
# Mixed integral / floating point
27+
print(1 + 2.0)
28+
print(1.7 - 2)
29+
print(17.24 * 4)
30+
print(2 / 0.3)
31+
print(5.4 % 2)
32+
print(7 // 1.7)
33+
print(9.1 ** 3)
34+
print(2 ** 3.14159)
35+
36+
# Precedence
37+
print(2**-1)
38+
39+
print(3 - 2 + 1)
40+
print(1 + 2.2 * 3)
41+
print((1 + 2.2) * 3)
42+
print(3 + 5 % 2)
43+
print(10 / 3 // 2)
44+
print(4 + 2 ** 7)
45+
46+
47+
# TODO: imaginary number arithmetic
48+
#print(3.14j + 1.0)

test/operators/bitwise.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,11 @@
44
print(22 ^ 5)
55
print(2 << 4)
66
print(16 >> 2)
7+
8+
# multiple
9+
print(~~19)
10+
print(197 & 92 & 345)
11+
print(197 | 92 | 345)
12+
print(197 ^ 92 ^ 345)
13+
print(1 << 2 << 3)
14+
print(256 >> 4 >> 3)

test/operators/comparison.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Boolean operations
2+
print(42 == 42)
3+
print(42 != 18)
4+
5+
# Integer boolean operations
6+
print(5 > 1)
7+
print(5 >= 5)
8+
print(4 < 1)
9+
print(4 <= 4)
10+
11+
# Float boolean operators
12+
print(5.0 > 1.0)
13+
print(4.0 < 1.0)
14+
print(5.0 >= 5.0)
15+
print(4.0 <= 4.0)
16+
17+
# Mixed int/float with operators
18+
print(5 == 1.0)
19+
print(4 != 1.0)
20+
print(5 > 1.0)
21+
print(4 < 1.0)
22+
print(5 >= 5.0)
23+
print(4 <= 4.0)
24+
25+
print(5.0 == 1)
26+
print(4.0 != 1)
27+
print(5.0 > 1)
28+
print(4.0 < 1)
29+
print(5.0 >= 5)
30+
print(4.0 <= 4)
31+
32+
33+
# On strings
34+
print("test" == "test")
35+
print("" == "test")
36+
print("a" < "b")
37+
print("b" < "a")
38+
print("abc" < "abcde")
39+
40+
41+
# On bytes
42+
print(b'test' == b'test')
43+
print(b'' == b'test')
44+
print(b'a' < b'b')
45+
print(b'b' < b'a')
46+
print(b'abc' < b'abcde')
47+
48+
49+
# Mixed
50+
print("mixed!")
51+
elements = [-3, -0.00001, 0, 0.0, 1, 1.0, 32, "", "foo", b'fo', b'foo', False, True, None]
52+
for a in elements:
53+
for b in elements:
54+
try:
55+
print("comparison")
56+
print(a == b)
57+
print(a != b)
58+
print(a < b)
59+
print(a > b)
60+
print(a <= b)
61+
print(a >= b)
62+
except TypeError:
63+
print("TypeError")
64+
65+
66+
# Chaining comparison operators
67+
print("chained")
68+
print(3 == 3 != 3)
69+
print(4 == 4 > 2)
70+
print(1 < 2 < 3 < 4)
71+
print(2 > 3 > -1)
72+
print(3 <= 10 > 17)
73+
print(-3 < 10 >= 10)
74+
75+
# Mixed with other operators
76+
print(3 < 2 * 2 < 5)
77+
print(4 // 2 == 2 < 3)
78+
print(7 > 6 > 7 - 1 and (True and False) < True)
79+
80+
def p(n):
81+
print(n)
82+
return n
83+
84+
# Only evaluate each element once
85+
print(p(1) < p(2) < p(3) < p(4))
86+
87+
# Stop after first wrong comparison
88+
print(p(7) < p(6) < p(8) < p(9))
89+
90+
91+
92+
# TODO: is
93+
# TODO: Chained comparison operators
94+
# TODO: Comparison of lists, tuples
95+
# TODO: Equality of sets, dicts
96+
# TODO: Imaginary
97+
# TODO: objects with __eq__() etc.
98+

0 commit comments

Comments
 (0)