Skip to content

Improvements for bitwise, logical and comparison operators #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 20, 2016
104 changes: 61 additions & 43 deletions src/Hython/Expression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,50 +82,30 @@ evalExpr (BinOp (BitOp op) leftExpr rightExpr) = do
return None

evalExpr (BinOp (BoolOp op) leftExpr rightExpr) = do
[lhs, rhs] <- mapM evalExpr [leftExpr, rightExpr]
case (op, lhs, rhs) of
(And, Bool l, Bool r) -> newBool (l && r)
(And, l, r) -> do
left <- isTruthy l
right <- isTruthy r
return $ if left && right
then r
else l
(Or, Bool l, Bool r) -> newBool (l || r)
(Or, l, r) -> do
left <- isTruthy l
return $ if left
then l
else r

evalExpr (BinOp (CompOp op) leftExpr rightExpr) = do
[lhs, rhs] <- mapM evalExpr [leftExpr, rightExpr]
case (op, lhs, rhs) of
(Eq, l, r) -> newBool =<< equal l r
(NotEq, l, r) -> do
b <- equal l r
newBool $ not b
(LessThan, Float l, Float r) -> newBool (l < r)
(LessThan, Int l, Int r) -> newBool (l < r)
(LessThanEq, Float l, Float r) -> newBool (l <= r)
(LessThanEq, Int l, Int r) -> newBool (l <= r)
(GreaterThan, Float l, Float r) -> newBool (l > r)
(GreaterThan, Int l, Int r) -> newBool (l > r)
(GreaterThanEq, Int l, Int r) -> newBool (l >= r)
(GreaterThanEq, Float l, Float r) -> newBool (l >= r)
(Is, l, r) -> newBool $ l `is` r
(IsNot, l, r) -> newBool . not $ l `is` r
(_, Float _, Int r) -> evalExpr (BinOp (CompOp op) leftExpr (constantF r))
(_, Int l, Float _) -> evalExpr (BinOp (CompOp op) (constantF l) rightExpr)
(In, l, r@(Object {})) -> invoke r "__contains__" [l]
(NotIn, _, _) -> do
(Bool b) <- evalExpr (BinOp (CompOp In) leftExpr rightExpr)
newBool (not b)
_ -> do
raise "SystemError" ("unsupported operand type " ++ show op)
return None
lhs <- evalExpr leftExpr
lTruthy <- isTruthy lhs
case (op, lTruthy) of
(And, False) -> return lhs
(And, True) -> evalExpr rightExpr
(Or, False) -> evalExpr rightExpr
(Or, True) -> return lhs

evalExpr (BinOp (CompOp operator) leftExpression rightExpression) = do
lhs <- evalExpr leftExpression
go operator lhs rightExpression
where
constantF i = Constant $ ConstantFloat $ fromIntegral i
go :: (MonadCont m, MonadEnv Object m, MonadIO m, MonadInterpreter m)
=> ComparisonOperator -> Object -> Expression -> m Object
-- for chained comparison, e.g. 1 < 2 < 3
go op lhs (BinOp (CompOp rop) rl rr) = do
rlhs <- evalExpr rl
comp <- compareObjs op lhs rlhs
case comp of
Bool True -> go rop rlhs rr
_ -> newBool False
go op lhs rightExpr = do
rhs <- evalExpr rightExpr
compareObjs op lhs rhs

evalExpr (Call expr argExprs) = do
target <- evalExpr expr
Expand Down Expand Up @@ -266,6 +246,44 @@ evalParam (DefaultParam param expr) = do
evalParam (SplatParam param) = return $ SParam param
evalParam (DoubleSplatParam param) = return $ DSParam param


compareObjs :: MonadInterpreter m => ComparisonOperator -> Object -> Object -> m Object
compareObjs op lhs rhs = case (op, lhs, rhs) of
-- equality
(Eq, l, r) -> newBool =<< equal l r
(NotEq, l, r) -> newBool . not =<< equal l r
-- "is", "is not", "in", "not in"
(Is, l, r) -> newBool $ l `is` r
(IsNot, l, r) -> newBool . not $ l `is` r
(In, l, r) -> invoke r "__contains__" [l]
(NotIn, l, r) -> do
Bool b <- invoke r "__contains__" [l]
newBool (not b)
-- conversion
(_, Bool l, _) -> compareObjs op (Float (boolToFloat l)) rhs
(_, Int l, _) -> compareObjs op (Float (fromIntegral l)) rhs
(_, _, Bool r) -> compareObjs op lhs (Float (boolToFloat r))
(_, _, Int r) -> compareObjs op lhs (Float (fromIntegral r))
-- comparison
(_, Bytes l, Bytes r) -> applyOp op l r
(_, Float l, Float r) -> applyOp op l r
(_, String l, String r) -> applyOp op l r
-- error cases
_ -> do
raise "TypeError" "unorderable types"
return None
where
boolToFloat True = 1
boolToFloat False = 0
applyOp :: (Ord a, MonadInterpreter m) => ComparisonOperator -> a -> a -> m Object
applyOp c a b = newBool $ compOp c a b
compOp LessThan = (<)
compOp LessThanEq = (<=)
compOp GreaterThan = (>)
compOp GreaterThanEq = (>=)
compOp _ = error "compOp: Eq, NotEq, Is, IsNot, In, NotIn should be handled"


is :: Object -> Object -> Bool
is None None = True
is (Bool l) (Bool r) = l == r
Expand Down
39 changes: 23 additions & 16 deletions src/Hython/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,34 @@ isNone (None) = True
isNone _ = False

equal :: MonadInterpreter m => Object -> Object -> m Bool
equal (Bool l) (Bool r) = return $ l == r
equal (Bytes l) (Bytes r) = return $ l == r
equal (Float l) (Float r) = return $ l == r
equal (Imaginary l) (Imaginary r) = return $ l == r
equal (Int l) (Int r) = return $ l == r
equal (String l) (String r) = return $ l == r
equal (BuiltinFn l) (BuiltinFn r) = return $ l == r
equal (Dict l) (Dict r) = do
left <- readRef l
right <- readRef r
if IntMap.size left /= IntMap.size right
then return False
else do
results <- zipWithM pairEqual (IntMap.elems left) (IntMap.elems right)
return $ all (== True) results
equal lhs rhs = case (lhs, rhs) of
-- conversion
(Bool l, _) -> equal (Float (boolToFloat l)) rhs
(Int l, _) -> equal (Float (fromIntegral l)) rhs
(_, Bool r) -> equal lhs (Float (boolToFloat r))
(_, Int r) -> equal lhs (Float (fromIntegral r))
(None, None ) -> return True
(Bytes l, Bytes r) -> return $ l == r
(Float l, Float r) -> return $ l == r
(Imaginary l, Imaginary r) -> return $ l == r
(String l, String r) -> return $ l == r
(BuiltinFn l, BuiltinFn r) -> return $ l == r
(Dict l, Dict r) -> do
left <- readRef l
right <- readRef r
if IntMap.size left /= IntMap.size right
then return False
else do
results <- zipWithM pairEqual (IntMap.elems left) (IntMap.elems right)
return $ all (== True) results
(_, _) -> return False
where
boolToFloat False = 0.0
boolToFloat True = 1.0
pairEqual (lk, lv) (rk, rv) = do
k <- equal lk rk
v <- equal lv rv
return $ k && v
equal _ _ = return False

isTruthy :: MonadInterpreter m => Object -> m Bool
isTruthy (None) = return False
Expand Down
25 changes: 9 additions & 16 deletions src/Language/Python/Parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -457,28 +457,25 @@ lambdef_nocond
: LAMBDA varargslist ':' test_nocond { LambdaExpr $2 $4 }

-- or_test: and_test ('or' and_test)*
-- TODO: implement 0-n clauses
or_test
: and_test { $1 }
| or_test OR and_test { BinOp (BoolOp Or) $1 $3 }
| or_test OR and_test { BinOp (BoolOp Or) $1 $3 }

-- and_test: not_test ('and' not_test)*
-- TODO: implement 0-n clauses
and_test
: not_test { $1 }
| and_test AND not_test { BinOp (BoolOp And) $1 $3 }

-- not_test: 'not' not_test | comparison
-- TODO: implement 0-n clauses
not_test
: NOT not_test { UnaryOp Not $2 }
| comparison { $1 }

-- comparison: expr (comp_op expr)*
-- TODO: implement 0-n clauses
-- NOTE: right-recursive, because of 1 < 2 < 3 etc.
comparison
: expr { $1 }
| expr comp_op expr { BinOp (CompOp $2) $1 $3 }
: expr { $1 }
| expr comp_op comparison { BinOp (CompOp $2) $1 $3 }

-- comp_op: '<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not'
comp_op
Expand All @@ -499,29 +496,25 @@ star_expr
: '*' expr { undefined }

-- expr: xor_expr ('|' xor_expr)*
-- TODO: implement 0-n handling
expr
: xor_expr { $1 }
| xor_expr '|' xor_expr { BinOp (BitOp BitOr) $1 $3 }
| expr '|' xor_expr { BinOp (BitOp BitOr) $1 $3 }

-- xor_expr: and_expr ('^' and_expr)*
-- TODO: implement 0-n handling
xor_expr
: and_expr { $1 }
| and_expr '^' and_expr { BinOp (BitOp BitXor) $1 $3 }
| xor_expr '^' and_expr { BinOp (BitOp BitXor) $1 $3 }

-- and_expr: shift_expr ('&' shift_expr)*
-- TODO: implement 0-n handling
and_expr
: shift_expr { $1 }
| shift_expr '&' shift_expr { BinOp (BitOp BitAnd) $1 $3 }
| and_expr '&' shift_expr { BinOp (BitOp BitAnd) $1 $3 }

-- shift_expr: arith_expr (('<<'|'>>') arith_expr)*
-- TODO: implement 0-n handling
shift_expr
: arith_expr { $1 }
| arith_expr '<<' arith_expr { BinOp (BitOp LShift) $1 $3 }
| arith_expr '>>' arith_expr { BinOp (BitOp RShift) $1 $3 }
| shift_expr '<<' arith_expr { BinOp (BitOp LShift) $1 $3 }
| shift_expr '>>' arith_expr { BinOp (BitOp RShift) $1 $3 }

-- arith_expr: term (('+'|'-') term)*
arith_expr
Expand Down
48 changes: 48 additions & 0 deletions test/operators/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Integral arithmetic
print(+5)
print(-5)

print(1+1)
print(3-5)
print(2*4)
print(3/5)
print(5%2)
print(5//2)
print(5**3)
print(0**0)

# Floating point arithmetic
print(+5.0)
print(-5.0)
print(1.0 + 2.0)
print(2.0 - 0.5)
print(1.0 * 1.6)
print(1.0 / 4.8)
print(5.0 % 2.0)
print(5.0 // 2.0)
print(5.0 ** 3.4)
print(0.0 ** 0.0)

# Mixed integral / floating point
print(1 + 2.0)
print(1.7 - 2)
print(17.24 * 4)
print(2 / 0.3)
print(5.4 % 2)
print(7 // 1.7)
print(9.1 ** 3)
print(2 ** 3.14159)

# Precedence
print(2**-1)

print(3 - 2 + 1)
print(1 + 2.2 * 3)
print((1 + 2.2) * 3)
print(3 + 5 % 2)
print(10 / 3 // 2)
print(4 + 2 ** 7)


# TODO: imaginary number arithmetic
#print(3.14j + 1.0)
8 changes: 8 additions & 0 deletions test/operators/bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@
print(22 ^ 5)
print(2 << 4)
print(16 >> 2)

# multiple
print(~~19)
print(197 & 92 & 345)
print(197 | 92 | 345)
print(197 ^ 92 ^ 345)
print(1 << 2 << 3)
print(256 >> 4 >> 3)
98 changes: 98 additions & 0 deletions test/operators/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Boolean operations
print(42 == 42)
print(42 != 18)

# Integer boolean operations
print(5 > 1)
print(5 >= 5)
print(4 < 1)
print(4 <= 4)

# Float boolean operators
print(5.0 > 1.0)
print(4.0 < 1.0)
print(5.0 >= 5.0)
print(4.0 <= 4.0)

# Mixed int/float with operators
print(5 == 1.0)
print(4 != 1.0)
print(5 > 1.0)
print(4 < 1.0)
print(5 >= 5.0)
print(4 <= 4.0)

print(5.0 == 1)
print(4.0 != 1)
print(5.0 > 1)
print(4.0 < 1)
print(5.0 >= 5)
print(4.0 <= 4)


# On strings
print("test" == "test")
print("" == "test")
print("a" < "b")
print("b" < "a")
print("abc" < "abcde")


# On bytes
print(b'test' == b'test')
print(b'' == b'test')
print(b'a' < b'b')
print(b'b' < b'a')
print(b'abc' < b'abcde')


# Mixed
print("mixed!")
elements = [-3, -0.00001, 0, 0.0, 1, 1.0, 32, "", "foo", b'fo', b'foo', False, True, None]
for a in elements:
for b in elements:
try:
print("comparison")
print(a == b)
print(a != b)
print(a < b)
print(a > b)
print(a <= b)
print(a >= b)
except TypeError:
print("TypeError")


# Chaining comparison operators
print("chained")
print(3 == 3 != 3)
print(4 == 4 > 2)
print(1 < 2 < 3 < 4)
print(2 > 3 > -1)
print(3 <= 10 > 17)
print(-3 < 10 >= 10)

# Mixed with other operators
print(3 < 2 * 2 < 5)
print(4 // 2 == 2 < 3)
print(7 > 6 > 7 - 1 and (True and False) < True)

def p(n):
print(n)
return n

# Only evaluate each element once
print(p(1) < p(2) < p(3) < p(4))

# Stop after first wrong comparison
print(p(7) < p(6) < p(8) < p(9))



# TODO: is
# TODO: Chained comparison operators
# TODO: Comparison of lists, tuples
# TODO: Equality of sets, dicts
# TODO: Imaginary
# TODO: objects with __eq__() etc.

Loading