Skip to content
Merged
104 changes: 61 additions & 43 deletions src/Hython/Expression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,50 +82,30 @@ evalExpr (BinOp (BitOp op) leftExpr rightExpr) = do
return None

evalExpr (BinOp (BoolOp op) leftExpr rightExpr) = do
[lhs, rhs] <- mapM evalExpr [leftExpr, rightExpr]
case (op, lhs, rhs) of
(And, Bool l, Bool r) -> newBool (l && r)
(And, l, r) -> do
left <- isTruthy l
right <- isTruthy r
return $ if left && right
then r
else l
(Or, Bool l, Bool r) -> newBool (l || r)
(Or, l, r) -> do
left <- isTruthy l
return $ if left
then l
else r

evalExpr (BinOp (CompOp op) leftExpr rightExpr) = do
[lhs, rhs] <- mapM evalExpr [leftExpr, rightExpr]
case (op, lhs, rhs) of
(Eq, l, r) -> newBool =<< equal l r
(NotEq, l, r) -> do
b <- equal l r
newBool $ not b
(LessThan, Float l, Float r) -> newBool (l < r)
(LessThan, Int l, Int r) -> newBool (l < r)
(LessThanEq, Float l, Float r) -> newBool (l <= r)
(LessThanEq, Int l, Int r) -> newBool (l <= r)
(GreaterThan, Float l, Float r) -> newBool (l > r)
(GreaterThan, Int l, Int r) -> newBool (l > r)
(GreaterThanEq, Int l, Int r) -> newBool (l >= r)
(GreaterThanEq, Float l, Float r) -> newBool (l >= r)
(Is, l, r) -> newBool $ l `is` r
(IsNot, l, r) -> newBool . not $ l `is` r
(_, Float _, Int r) -> evalExpr (BinOp (CompOp op) leftExpr (constantF r))
(_, Int l, Float _) -> evalExpr (BinOp (CompOp op) (constantF l) rightExpr)
(In, l, r@(Object {})) -> invoke r "__contains__" [l]
(NotIn, _, _) -> do
(Bool b) <- evalExpr (BinOp (CompOp In) leftExpr rightExpr)
newBool (not b)
_ -> do
raise "SystemError" ("unsupported operand type " ++ show op)
return None
lhs <- evalExpr leftExpr
lTruthy <- isTruthy lhs
case (op, lTruthy) of
(And, False) -> return lhs
(And, True) -> evalExpr rightExpr
(Or, False) -> evalExpr rightExpr
(Or, True) -> return lhs

evalExpr (BinOp (CompOp operator) leftExpression rightExpression) = do
lhs <- evalExpr leftExpression
go operator lhs rightExpression
where
constantF i = Constant $ ConstantFloat $ fromIntegral i
go :: (MonadCont m, MonadEnv Object m, MonadIO m, MonadInterpreter m)
=> ComparisonOperator -> Object -> Expression -> m Object
-- for chained comparison, e.g. 1 < 2 < 3
go op lhs (BinOp (CompOp rop) rl rr) = do
rlhs <- evalExpr rl
comp <- compareObjs op lhs rlhs
case comp of
Bool True -> go rop rlhs rr
_ -> newBool False
go op lhs rightExpr = do
rhs <- evalExpr rightExpr
compareObjs op lhs rhs

evalExpr (Call expr argExprs) = do
target <- evalExpr expr
Expand Down Expand Up @@ -266,6 +246,44 @@ evalParam (DefaultParam param expr) = do
evalParam (SplatParam param) = return $ SParam param
evalParam (DoubleSplatParam param) = return $ DSParam param


compareObjs :: MonadInterpreter m => ComparisonOperator -> Object -> Object -> m Object
compareObjs op lhs rhs = case (op, lhs, rhs) of
-- equality
(Eq, l, r) -> newBool =<< equal l r
(NotEq, l, r) -> newBool . not =<< equal l r
-- "is", "is not", "in", "not in"
(Is, l, r) -> newBool $ l `is` r
(IsNot, l, r) -> newBool . not $ l `is` r
(In, l, r) -> invoke r "__contains__" [l]
(NotIn, l, r) -> do
Bool b <- invoke r "__contains__" [l]
newBool (not b)
-- conversion
(_, Bool l, _) -> compareObjs op (Float (boolToFloat l)) rhs
(_, Int l, _) -> compareObjs op (Float (fromIntegral l)) rhs
(_, _, Bool r) -> compareObjs op lhs (Float (boolToFloat r))
(_, _, Int r) -> compareObjs op lhs (Float (fromIntegral r))
-- comparison
(_, Bytes l, Bytes r) -> applyOp op l r
(_, Float l, Float r) -> applyOp op l r
(_, String l, String r) -> applyOp op l r
-- error cases
_ -> do
raise "TypeError" "unorderable types"
return None
where
boolToFloat True = 1
boolToFloat False = 0
applyOp :: (Ord a, MonadInterpreter m) => ComparisonOperator -> a -> a -> m Object
applyOp c a b = newBool $ compOp c a b
compOp LessThan = (<)
compOp LessThanEq = (<=)
compOp GreaterThan = (>)
compOp GreaterThanEq = (>=)
compOp _ = error "compOp: Eq, NotEq, Is, IsNot, In, NotIn should be handled"


is :: Object -> Object -> Bool
is None None = True
is (Bool l) (Bool r) = l == r
Expand Down
39 changes: 23 additions & 16 deletions src/Hython/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,34 @@ isNone (None) = True
isNone _ = False

equal :: MonadInterpreter m => Object -> Object -> m Bool
equal (Bool l) (Bool r) = return $ l == r
equal (Bytes l) (Bytes r) = return $ l == r
equal (Float l) (Float r) = return $ l == r
equal (Imaginary l) (Imaginary r) = return $ l == r
equal (Int l) (Int r) = return $ l == r
equal (String l) (String r) = return $ l == r
equal (BuiltinFn l) (BuiltinFn r) = return $ l == r
equal (Dict l) (Dict r) = do
left <- readRef l
right <- readRef r
if IntMap.size left /= IntMap.size right
then return False
else do
results <- zipWithM pairEqual (IntMap.elems left) (IntMap.elems right)
return $ all (== True) results
equal lhs rhs = case (lhs, rhs) of
-- conversion
(Bool l, _) -> equal (Float (boolToFloat l)) rhs
(Int l, _) -> equal (Float (fromIntegral l)) rhs
(_, Bool r) -> equal lhs (Float (boolToFloat r))
(_, Int r) -> equal lhs (Float (fromIntegral r))
(None, None ) -> return True
(Bytes l, Bytes r) -> return $ l == r
(Float l, Float r) -> return $ l == r
(Imaginary l, Imaginary r) -> return $ l == r
(String l, String r) -> return $ l == r
(BuiltinFn l, BuiltinFn r) -> return $ l == r
(Dict l, Dict r) -> do
left <- readRef l
right <- readRef r
if IntMap.size left /= IntMap.size right
then return False
else do
results <- zipWithM pairEqual (IntMap.elems left) (IntMap.elems right)
return $ all (== True) results
(_, _) -> return False
where
boolToFloat False = 0.0
boolToFloat True = 1.0
pairEqual (lk, lv) (rk, rv) = do
k <- equal lk rk
v <- equal lv rv
return $ k && v
equal _ _ = return False

isTruthy :: MonadInterpreter m => Object -> m Bool
isTruthy (None) = return False
Expand Down
25 changes: 9 additions & 16 deletions src/Language/Python/Parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -457,28 +457,25 @@ lambdef_nocond
: LAMBDA varargslist ':' test_nocond { LambdaExpr $2 $4 }

-- or_test: and_test ('or' and_test)*
-- TODO: implement 0-n clauses
or_test
: and_test { $1 }
| or_test OR and_test { BinOp (BoolOp Or) $1 $3 }
| or_test OR and_test { BinOp (BoolOp Or) $1 $3 }

-- and_test: not_test ('and' not_test)*
-- TODO: implement 0-n clauses
and_test
: not_test { $1 }
| and_test AND not_test { BinOp (BoolOp And) $1 $3 }

-- not_test: 'not' not_test | comparison
-- TODO: implement 0-n clauses
not_test
: NOT not_test { UnaryOp Not $2 }
| comparison { $1 }

-- comparison: expr (comp_op expr)*
-- TODO: implement 0-n clauses
-- NOTE: right-recursive, because of 1 < 2 < 3 etc.
comparison
: expr { $1 }
| expr comp_op expr { BinOp (CompOp $2) $1 $3 }
: expr { $1 }
| expr comp_op comparison { BinOp (CompOp $2) $1 $3 }

-- comp_op: '<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not'
comp_op
Expand All @@ -499,29 +496,25 @@ star_expr
: '*' expr { undefined }

-- expr: xor_expr ('|' xor_expr)*
-- TODO: implement 0-n handling
expr
: xor_expr { $1 }
| xor_expr '|' xor_expr { BinOp (BitOp BitOr) $1 $3 }
| expr '|' xor_expr { BinOp (BitOp BitOr) $1 $3 }

-- xor_expr: and_expr ('^' and_expr)*
-- TODO: implement 0-n handling
xor_expr
: and_expr { $1 }
| and_expr '^' and_expr { BinOp (BitOp BitXor) $1 $3 }
| xor_expr '^' and_expr { BinOp (BitOp BitXor) $1 $3 }

-- and_expr: shift_expr ('&' shift_expr)*
-- TODO: implement 0-n handling
and_expr
: shift_expr { $1 }
| shift_expr '&' shift_expr { BinOp (BitOp BitAnd) $1 $3 }
| and_expr '&' shift_expr { BinOp (BitOp BitAnd) $1 $3 }

-- shift_expr: arith_expr (('<<'|'>>') arith_expr)*
-- TODO: implement 0-n handling
shift_expr
: arith_expr { $1 }
| arith_expr '<<' arith_expr { BinOp (BitOp LShift) $1 $3 }
| arith_expr '>>' arith_expr { BinOp (BitOp RShift) $1 $3 }
| shift_expr '<<' arith_expr { BinOp (BitOp LShift) $1 $3 }
| shift_expr '>>' arith_expr { BinOp (BitOp RShift) $1 $3 }

-- arith_expr: term (('+'|'-') term)*
arith_expr
Expand Down
48 changes: 48 additions & 0 deletions test/operators/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Integral arithmetic
print(+5)
print(-5)

print(1+1)
print(3-5)
print(2*4)
print(3/5)
print(5%2)
print(5//2)
print(5**3)
print(0**0)

# Floating point arithmetic
print(+5.0)
print(-5.0)
print(1.0 + 2.0)
print(2.0 - 0.5)
print(1.0 * 1.6)
print(1.0 / 4.8)
print(5.0 % 2.0)
print(5.0 // 2.0)
print(5.0 ** 3.4)
print(0.0 ** 0.0)

# Mixed integral / floating point
print(1 + 2.0)
print(1.7 - 2)
print(17.24 * 4)
print(2 / 0.3)
print(5.4 % 2)
print(7 // 1.7)
print(9.1 ** 3)
print(2 ** 3.14159)

# Precedence
print(2**-1)

print(3 - 2 + 1)
print(1 + 2.2 * 3)
print((1 + 2.2) * 3)
print(3 + 5 % 2)
print(10 / 3 // 2)
print(4 + 2 ** 7)


# TODO: imaginary number arithmetic
#print(3.14j + 1.0)
8 changes: 8 additions & 0 deletions test/operators/bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@
print(22 ^ 5)
print(2 << 4)
print(16 >> 2)

# multiple
print(~~19)
print(197 & 92 & 345)
print(197 | 92 | 345)
print(197 ^ 92 ^ 345)
print(1 << 2 << 3)
print(256 >> 4 >> 3)
98 changes: 98 additions & 0 deletions test/operators/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Boolean operations
print(42 == 42)
print(42 != 18)

# Integer boolean operations
print(5 > 1)
print(5 >= 5)
print(4 < 1)
print(4 <= 4)

# Float boolean operators
print(5.0 > 1.0)
print(4.0 < 1.0)
print(5.0 >= 5.0)
print(4.0 <= 4.0)

# Mixed int/float with operators
print(5 == 1.0)
print(4 != 1.0)
print(5 > 1.0)
print(4 < 1.0)
print(5 >= 5.0)
print(4 <= 4.0)

print(5.0 == 1)
print(4.0 != 1)
print(5.0 > 1)
print(4.0 < 1)
print(5.0 >= 5)
print(4.0 <= 4)


# On strings
print("test" == "test")
print("" == "test")
print("a" < "b")
print("b" < "a")
print("abc" < "abcde")


# On bytes
print(b'test' == b'test')
print(b'' == b'test')
print(b'a' < b'b')
print(b'b' < b'a')
print(b'abc' < b'abcde')


# Mixed
print("mixed!")
elements = [-3, -0.00001, 0, 0.0, 1, 1.0, 32, "", "foo", b'fo', b'foo', False, True, None]
for a in elements:
for b in elements:
try:
print("comparison")
print(a == b)
print(a != b)
print(a < b)
print(a > b)
print(a <= b)
print(a >= b)
except TypeError:
print("TypeError")


# Chaining comparison operators
print("chained")
print(3 == 3 != 3)
print(4 == 4 > 2)
print(1 < 2 < 3 < 4)
print(2 > 3 > -1)
print(3 <= 10 > 17)
print(-3 < 10 >= 10)

# Mixed with other operators
print(3 < 2 * 2 < 5)
print(4 // 2 == 2 < 3)
print(7 > 6 > 7 - 1 and (True and False) < True)

def p(n):
print(n)
return n

# Only evaluate each element once
print(p(1) < p(2) < p(3) < p(4))

# Stop after first wrong comparison
print(p(7) < p(6) < p(8) < p(9))



# TODO: is
# TODO: Chained comparison operators
# TODO: Comparison of lists, tuples
# TODO: Equality of sets, dicts
# TODO: Imaginary
# TODO: objects with __eq__() etc.

Loading