diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 2b4f5a4da288..3a437d4dcb7b 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -376,6 +376,19 @@ We should also investigate the impact on memory consumption. -/ abbrev DefEqCache := PersistentHashMap DefEqCacheKey Bool +/-- +A `DefEqTransCache` is a `DefEqCache` that is only valid in the `MetavarContext` in which it was declared. +To keep track of whether this cache is still valid, it stores the `numAssignments` from the original `MetavarContext`. +If the `numAssignments` in the `MetavarContext` has increased, we invalidate this cache. +And when reverting the metavariable context in `checkpointDefEq`, if the `numAssignments` +in the original `MetavarContext` is smaller than in this cache, we also invalidate this cache, +and revert to the original cache. +-/ +structure DefEqTransCache where + cache : DefEqCache := {} + numAssignments : Nat := 0 + deriving Inhabited + /-- Cache datastructures for type inference, type class resolution, whnf, and definitional equality. -/ @@ -384,7 +397,7 @@ structure Cache where funInfo : FunInfoCache := {} synthInstance : SynthInstanceCache := {} whnf : WhnfCache := {} - defEqTrans : DefEqCache := {} -- transient cache for terms containing mvars or using nonstandard configuration options, it is frequently reset. + defEqTrans : DefEqTransCache := {} -- transient cache for terms containing mvars or using nonstandard configuration options, it is valid as long as the count matches `MetavarContext.numAssignments`. defEqPerm : DefEqCache := {} -- permanent cache for terms not containing mvars and using standard configuration options deriving Inhabited @@ -574,10 +587,15 @@ instance : AddMessageContext MetaM where protected def saveState : MetaM SavedState := return { core := (← Core.saveState), «meta» := (← get) } -/-- Restore backtrackable parts of the state. -/ -def SavedState.restore (b : SavedState) : MetaM Unit := do +/-- Restore backtrackable parts of the state. +If `transCache == true`, then also reset the tranient defEq cache -/ +def SavedState.restore (b : SavedState) (transCache : Bool := false) : MetaM Unit := do b.core.restore - modify fun s => { s with mctx := b.meta.mctx, zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds, postponed := b.meta.postponed } + modify fun s => { s with + mctx := b.meta.mctx + zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds + postponed := b.meta.postponed + cache.defEqTrans := if transCache then b.meta.cache.defEqTrans else s.cache.defEqTrans } @[specialize, inherit_doc Core.withRestoreOrSaveFull] def withRestoreOrSaveFull (reusableResult? : Option (α × SavedState)) (act : MetaM α) : @@ -639,8 +657,12 @@ def resetCache : MetaM Unit := @[inline] def modifyInferTypeCache (f : InferTypeCache → InferTypeCache) : MetaM Unit := modifyCache fun ⟨ic, c1, c2, c3, c4, c5⟩ => ⟨f ic, c1, c2, c3, c4, c5⟩ -@[inline] def modifyDefEqTransientCache (f : DefEqCache → DefEqCache) : MetaM Unit := - modifyCache fun ⟨c1, c2, c3, c4, defeqTrans, c5⟩ => ⟨c1, c2, c3, c4, f defeqTrans, c5⟩ +/-- Modify the defEq transient cache. If it is not valid anymore, reset it before modifying it. -/ +@[inline] def modifyDefEqTransientCache (numAssignments : Nat) (f : DefEqCache → DefEqCache) : MetaM Unit := + modifyCache fun c => + let ⟨transCache, numAssignmentsOld⟩ := c.defEqTrans + let transCache := if numAssignments == numAssignmentsOld then transCache else {} + { c with defEqTrans := ⟨f transCache, numAssignments⟩ } @[inline] def modifyDefEqPermCache (f : DefEqCache → DefEqCache) : MetaM Unit := modifyCache fun ⟨c1, c2, c3, c4, c5, defeqPerm⟩ => ⟨c1, c2, c3, c4, c5, f defeqPerm⟩ @@ -658,6 +680,9 @@ def mkDefEqCacheKey (lhs rhs : Expr) : MetaM DefEqCacheKey := do def mkInfoCacheKey (expr : Expr) (nargs? : Option Nat) : MetaM InfoCacheKey := return { expr, nargs?, configKey := (← read).configKey } +@[inline] def resetDefEqTransientCache : MetaM Unit := + modify fun s => { s with cache.defEqTrans := ⟨{}, s.mctx.numAssignments⟩ } + @[inline] def resetDefEqPermCaches : MetaM Unit := modifyDefEqPermCache fun _ => {} @@ -2349,16 +2374,6 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure := -/ @[specialize] def checkpointDefEq (x : MetaM Bool) (mayPostpone : Bool := true) : MetaM Bool := do let s ← saveState - /- - It is not safe to use the `isDefEq` cache between different `isDefEq` calls. - Reason: different configuration settings, and result depends on the state of the `MetavarContext` - We have tried in the past to track when the result was independent of the `MetavarContext` state - but it was not effective. It is more important to cache aggressively inside of a single `isDefEq` - call because some of the heuristics create many similar subproblems. - See issue #1102 for an example that triggers an exponential blowup if we don't use this more - aggressive form of caching. - -/ - modifyDefEqTransientCache fun _ => {} let postponed ← getResetPostponed try if (← x) then @@ -2367,10 +2382,14 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure := setPostponed (postponed ++ newPostponed) return true else - s.restore + -- The transient cache needs to be reverted if it assumes an assignments that is being reverted. + let isInvalidCache := s.meta.mctx.numAssignments != (← get).cache.defEqTrans.numAssignments + s.restore (transCache := isInvalidCache) return false else - s.restore + -- The transient cache needs to be reverted if it assumes an assignments that is being reverted. + let isInvalidCache := s.meta.mctx.numAssignments != (← get).cache.defEqTrans.numAssignments + s.restore (transCache := isInvalidCache) return false catch ex => s.restore @@ -2405,6 +2424,16 @@ def isExprDefEq (t s : Expr) : MetaM Bool := Remark: the kernel does *not* update the type of variables in the local context. -/ resetDefEqPermCaches + /- + It is not safe to use the transient `isDefEq` cache between different `isDefEq` calls. + Reason: different configuration settings, and result depends on the state of the `MetavarContext` + We have tried in the past to track when the result was independent of the `MetavarContext` state + but it was not effective. It is more important to cache aggressively inside of a single `isDefEq` + call because some of the heuristics create many similar subproblems. + See issue #1102 and `tests/lean/run/defEqTransCache.lean` for examples that trigger an exponential blowup + if we don't use this more aggressive form of caching. + -/ + resetDefEqTransientCache checkpointDefEq (mayPostpone := true) <| Meta.isExprDefEqAux t s /-- diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 1c3975c8949c..456d57212bfa 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -80,7 +80,7 @@ where if !isStructureLike (← getEnv) ctorVal.induct then trace[Meta.isDefEq.eta.struct] "failed, type is not a structure{indentExpr b}" return false - else if (← isDefEq (← inferType a) (← inferType b)) then + else if (← checkpointDefEq <| Meta.isExprDefEqAux (← inferType a) (← inferType b)) then checkpointDefEq do let args := b.getAppArgs let params := args[*...ctorVal.numParams].toArray @@ -95,7 +95,7 @@ where -- See comment at `isAbstractedUnassignedMVar`. continue trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{j}], {proj} =?= {args[i]}" - unless (← isDefEq proj args[i]) do + unless (← Meta.isExprDefEqAux proj args[i]) do trace[Meta.isDefEq.eta.struct] "failed, unexpected arg #{i}, projection{indentExpr proj}\nis not defeq to{indentExpr args[i]}" return false return true @@ -1240,7 +1240,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool : ``` -/ private def processAssignment' (mvarApp : Expr) (v : Expr) : MetaM Bool := do - if (← processAssignment mvarApp v) then + if (← checkpointDefEq <| processAssignment mvarApp v) then return true else let vNew ← whnf v @@ -1248,7 +1248,7 @@ private def processAssignment' (mvarApp : Expr) (v : Expr) : MetaM Bool := do if mvarApp == vNew then return true else - processAssignment mvarApp vNew + checkpointDefEq <| processAssignment mvarApp vNew else return false @@ -1629,7 +1629,7 @@ private def isDefEqProofIrrel (t s : Expr) : MetaM LBool := do private def isDefEqMVarSelf (mvar : Expr) (args₁ args₂ : Array Expr) : MetaM Bool := do if args₁.size != args₂.size then pure false - else if (← isDefEqArgs mvar args₁ args₂) then + else if (← checkpointDefEq <| isDefEqArgs mvar args₁ args₂) then pure true else if !(← isAssignable mvar) then pure false @@ -1861,8 +1861,8 @@ end private def isDefEqOnFailure (t s : Expr) : MetaM Bool := do withTraceNodeBefore `Meta.isDefEq.onFailure (return m!"{t} =?= {s}") do - unstuckMVar t (fun t => Meta.isExprDefEqAux t s) <| - unstuckMVar s (fun s => Meta.isExprDefEqAux t s) <| + unstuckMVar t (fun t => checkpointDefEq <| Meta.isExprDefEqAux t s) <| + unstuckMVar s (fun s => checkpointDefEq <| Meta.isExprDefEqAux t s) <| tryUnificationHints t s <||> tryUnificationHints s t /-- @@ -1940,11 +1940,11 @@ private def isDefEqProj : Expr → Expr → MetaM Bool | .proj m i t, .proj n j s => do if (← read).inTypeClassResolution then -- See comment at `inTypeClassResolution` - pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s + pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s) else if !backward.isDefEq.lazyProjDelta.get (← getOptions) then - pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s + pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s) else if i == j && m == n then - isDefEqProjDelta t s i + checkpointDefEq (isDefEqProjDelta t s i) else return false | .proj structName 0 s, v => isDefEqSingleton structName s v @@ -2061,12 +2061,12 @@ private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do isDefEqOnFailure t s inductive DefEqCacheKind where - | transient -- problem has mvars or is using nonstandard configuration, we should use transient cache + | transient (numAssignments : Nat) -- problem has mvars or is using nonstandard configuration, we should use transient cache | permanent -- problem does not have mvars and we are using standard config, we can use one persistent cache. private def getDefEqCacheKind (t s : Expr) : MetaM DefEqCacheKind := do if t.hasMVar || s.hasMVar || (← read).canUnfold?.isSome then - return .transient + return .transient (← getMCtx).numAssignments else return .permanent @@ -2084,7 +2084,12 @@ private def mkCacheKey (t s : Expr) : MetaM DefEqCacheKeyInfo := do private def getCachedResult (keyInfo : DefEqCacheKeyInfo) : MetaM LBool := do let cache ← match keyInfo.kind with - | .transient => pure (← get).cache.defEqTrans + | .transient numAssignments => + let ⟨cache, numAssignmentsCache⟩ := (← get).cache.defEqTrans + if numAssignments == numAssignmentsCache then + pure cache + else + return .undef | .permanent => pure (← get).cache.defEqPerm match cache.find? keyInfo.key with | some val => return val.toLBool @@ -2094,14 +2099,17 @@ private def cacheResult (keyInfo : DefEqCacheKeyInfo) (result : Bool) : MetaM Un let key := keyInfo.key match keyInfo.kind with | .permanent => modifyDefEqPermCache fun c => c.insert key result - | .transient => + | .transient numAssignmentsOld => /- - We must ensure that all assigned metavariables in the key are replaced by their current assignments. - Otherwise, the key is invalid after the assignment is "backtracked". - See issue #1870 for an example. + If the result is `false`, we cache it at `numAssignmentsOld`. + If the result is `true`, we only cache it if the number of assignments hasn't increased. -/ - let key ← mkDefEqCacheKey (← instantiateMVars key.lhs) (← instantiateMVars key.rhs) - modifyDefEqTransientCache fun c => c.insert key result + if !result then + modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result + else + let numAssignmentsNew := (← getMCtx).numAssignments + if numAssignmentsOld == numAssignmentsNew then + modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result private def whnfCoreAtDefEq (e : Expr) : MetaM Expr := do if backward.isDefEq.lazyWhnfCore.get (← getOptions) then diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index d79cae6eda8d..af50f67b3a71 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -349,6 +349,8 @@ structure MetavarContext where lAssignment : PersistentHashMap LMVarId Level := {} /-- Assignment table for expression metavariables.-/ eAssignment : PersistentHashMap MVarId Expr := {} + /-- Number of assignments in `eAssignment` and `lAssignments`. -/ + numAssignments : Nat := 0 /-- Assignment table for delayed abstraction metavariables. For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/ dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {} @@ -507,8 +509,9 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool This is a low-level API, and it is safer to use `isLevelDefEq (mkLevelMVar mvarId) u`. -/ def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit := - modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val } + modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 } +-- `assignLevelMVarExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments` @[export lean_assign_lmvar] def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext := { m with lAssignment := m.lAssignment.insert mvarId val } @@ -520,8 +523,9 @@ a cycle is being introduced, or whether the expression has the right type. This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`. -/ def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit := - modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val } + modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 } +-- `assignExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments` @[export lean_assign_mvar] def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext := { m with eAssignment := m.eAssignment.insert mvarId val } diff --git a/tests/lean/run/defEqTransCache.lean b/tests/lean/run/defEqTransCache.lean new file mode 100644 index 000000000000..40d297c7d4cb --- /dev/null +++ b/tests/lean/run/defEqTransCache.lean @@ -0,0 +1,77 @@ +import Lean +/-! +Previously, unification wouldn't be very careful with the `isDefEq` cache for terms containing metavariables. +- This is mostly problematic because erasing the cache leads to exponential slowdowns (`test1` & `test2`) +- but in some cases it lead to metavariable assignments leaking into places where they shouldn't be, + which either caused unification to fail where it should succeed (`test3`) + or to succeed where it is expected to fail (which happened in one mathlib proof). +-/ + +set_option maxHeartbeats 1000 + +namespace test1 +class A (n : Nat) where + x : Nat + +instance [A n] : A (n+1) where + x := A.x n + +theorem test [A 0] : A.x 100 = 0 := sorry + +-- This rewrite should fail. Previously, it failed exponentially slowly +example [A 1] : A.x 100 = 0 := by + fail_if_success rw [@test] + sorry +end test1 + + +namespace test2 +@[irreducible] def A : Type := Unit + +@[irreducible] def B : Type := Unit + +unseal B in +@[coe] def AtoB (_a : A) : B := () + +instance : Coe A B where coe := AtoB + +def h {α : Type} (a b : α) : Nat → α +| 0 => a +| n + 1 => h b a n + +def f {α : Type} (a b : α) : Nat → Prop +| 0 => a = b +| n + 1 => f (h a b n) (h b a n) n ∧ f (h a a n) (h b b n) n + +axiom foo {p} {α : Type} (a b : α) : f a b p + +variable (x : A) (y : B) +-- Previously, this check was exponentially slow; now it is quadratically slow +#check (foo (↑x) y : f (AtoB x) y 30) +end test2 + + +namespace test3 +structure A (α : Type) where + x : Type + y : α + +structure B (α : Type) extends A α where + z : Nat + +def A.map {α β} (f : α → β) (a : A α) : A β := ⟨a.x, f a.y⟩ + +open Lean Meta in +elab "unfold_head" e:term : term => do + let e ← Elab.Term.elabTerm e none + unfoldDefinition e + +-- use `unfold_head` to get the raw kernel projection `·.1` instead of the projection funtcion `A.x` +def test {α} (i : B α) : unfold_head i.toA.x := sorry + +-- Previously, in this example the unification failed, +-- because some metavariable assignment wasn't reverted properly +-- However, it is clearly the case that `(i.toA.map f).x` is the same as `i.toA.x` +example (i : B α) (f : α → β) : (i.toA.map f).x := by + apply @test +end test3