diff --git a/Clear/PrimOps.lean b/Clear/PrimOps.lean index 12ff0d4..ef601b2 100644 --- a/Clear/PrimOps.lean +++ b/Clear/PrimOps.lean @@ -12,15 +12,20 @@ abbrev fromBool := Bool.toUInt256 def evmAddMod (a b c : UInt256) : UInt256 := if c = 0 then 0 else - Fin.mod (a + b) c + -- "All intermediate calculations of this operation are **not** subject to the 2^256 modulo." + Fin.mod (a.val + b.val) c def evmMulMod (a b c : UInt256) : UInt256 := if c = 0 then 0 else - Fin.mod (a * b) c + -- "All intermediate calculations of this operation are **not** subject to the 2^256 modulo." + Fin.mod (a.val * b.val) c def evmExp (a b : UInt256) : UInt256 := a ^ b.val +def evmMod (x y : UInt256) : UInt256 := + if y == 0 then 0 else x % y + set_option linter.unusedVariables false in def primCall (s : State) : PrimOp → List Literal → State × List Literal | .Add, [a,b] => (s, [a + b]) @@ -28,7 +33,7 @@ def primCall (s : State) : PrimOp → List Literal → State × List Literal | .Mul, [a,b] => (s, [a * b]) | .Div, [a,b] => (s, [a / b]) | .Sdiv, [a,b] => (s, [UInt256.sdiv a b]) - | .Mod, [a,b] => (s, [Fin.mod a b]) + | .Mod, [a,b] => (s, [evmMod a b]) | .Smod, [a,b] => (s, [UInt256.smod a b]) | .Addmod, [a,b,c] => (s, [evmAddMod a b c]) | .Mulmod, [a,b,c] => (s, [evmMulMod a b c]) @@ -109,7 +114,7 @@ lemma EVMSub' : primCall s .Sub [a,b] = (s, [a - lemma EVMMul' : primCall s .Mul [a,b] = (s, [a * b]) := rfl lemma EVMDiv' : primCall s .Div [a,b] = (s, [a / b]) := rfl lemma EVMSdiv' : primCall s .Sdiv [a,b] = (s, [UInt256.sdiv a b]) := rfl -lemma EVMMod' : primCall s .Mod [a,b] = (s, [Fin.mod a b]) := rfl +lemma EVMMod' : primCall s .Mod [a,b] = (s, [evmMod a b]) := rfl lemma EVMSmod' : primCall s .Smod [a,b] = (s, [UInt256.smod a b]) := rfl lemma EVMAddmod' : primCall s .Addmod [a,b,c] = (s, [evmAddMod a b c]) := rfl lemma EVMMulmod' : primCall s .Mulmod [a,b,c] = (s, [evmMulMod a b c]) := rfl diff --git a/Clear/UInt256.lean b/Clear/UInt256.lean index 390a6de..7a6126c 100644 --- a/Clear/UInt256.lean +++ b/Clear/UInt256.lean @@ -54,7 +54,7 @@ def eq0 (a : UInt256) : Bool := a = 0 def lnot (a : UInt256) : UInt256 := (UInt256.size - 1) - a def byteAt (a b : UInt256) : UInt256 := - b >>> (31 - a) * 8 <<< 248 + b >>> (.ofNat ((31 - a.val) * 8)) &&& 0xFF def sgn (a : UInt256) : UInt256 := if a ≥ 2 ^ 255 then -1 @@ -76,14 +76,15 @@ def sdiv (a b : UInt256) : UInt256 := else a / b def smod (a b : UInt256) : UInt256 := - if a ≥ 2 ^ 255 then - if b ≥ 2 ^ 255 then - Fin.mod (abs a) (abs b) - else (-1) * Fin.mod (abs a) b + if b == 0 then 0 else - if b ≥ 2 ^ 255 then - (-1) * Fin.mod a (abs b) - else Fin.mod a b + let sgnA := if 2 ^ 255 <= a then -1 else 1 + let sgnB := if 2 ^ 255 <= b then -1 else 1 + let mask : UInt256 := .ofNat (2 ^ 256 - 1 : ℕ) + let absA := if sgnA == 1 then a else - (.xor a mask + 1) + let absB := if sgnB == 1 then b else - (.xor b mask + 1) + sgnA * (absA % absB) + def slt (a b : UInt256) : Bool := if a ≥ 2 ^ 255 then