diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 029246dab11f..09d79138e35a 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -370,7 +370,7 @@ structure Cache where funInfo : FunInfoCache := {} synthInstance : SynthInstanceCache := {} whnf : WhnfCache := {} - defEqTrans : DefEqCache := {} -- transient cache for terms containing mvars or using nonstandard configuration options, it is frequently reset. + defEqTrans : DefEqCache × Nat := ({}, 0) -- transient cache for terms containing mvars or using nonstandard configuration options, it is valid as long as the count matches `MetavarContext.numAssignments`. defEqPerm : DefEqCache := {} -- permanent cache for terms not containing mvars and using standard configuration options deriving Inhabited @@ -558,9 +558,13 @@ protected def saveState : MetaM SavedState := return { core := (← Core.saveState), «meta» := (← get) } /-- Restore backtrackable parts of the state. -/ -def SavedState.restore (b : SavedState) : MetaM Unit := do +def SavedState.restore (b : SavedState) (transCache : Bool := false) : MetaM Unit := do b.core.restore - modify fun s => { s with mctx := b.meta.mctx, zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds, postponed := b.meta.postponed } + modify fun s => { s with + mctx := b.meta.mctx + zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds + postponed := b.meta.postponed + cache.defEqTrans := if transCache then b.meta.cache.defEqTrans else s.cache.defEqTrans } @[specialize, inherit_doc Core.withRestoreOrSaveFull] def withRestoreOrSaveFull (reusableResult? : Option (α × SavedState)) (act : MetaM α) : @@ -622,8 +626,11 @@ def resetCache : MetaM Unit := @[inline] def modifyInferTypeCache (f : InferTypeCache → InferTypeCache) : MetaM Unit := modifyCache fun ⟨ic, c1, c2, c3, c4, c5⟩ => ⟨f ic, c1, c2, c3, c4, c5⟩ -@[inline] def modifyDefEqTransientCache (f : DefEqCache → DefEqCache) : MetaM Unit := - modifyCache fun ⟨c1, c2, c3, c4, defeqTrans, c5⟩ => ⟨c1, c2, c3, c4, f defeqTrans, c5⟩ +@[inline] def modifyDefEqTransientCache (numAssignments : Nat) (f : DefEqCache → DefEqCache) : MetaM Unit := + modifyCache fun c => + let (transCache, numAssignmentsOld) := c.defEqTrans + let transCache := if numAssignments == numAssignmentsOld then transCache else {} + { c with defEqTrans := (f transCache, numAssignments) } @[inline] def modifyDefEqPermCache (f : DefEqCache → DefEqCache) : MetaM Unit := modifyCache fun ⟨c1, c2, c3, c4, c5, defeqPerm⟩ => ⟨c1, c2, c3, c4, c5, f defeqPerm⟩ @@ -641,6 +648,9 @@ def mkDefEqCacheKey (lhs rhs : Expr) : MetaM DefEqCacheKey := do def mkInfoCacheKey (expr : Expr) (nargs? : Option Nat) : MetaM InfoCacheKey := return { expr, nargs?, configKey := (← read).configKey } +@[inline] def resetDefEqTransientCache : MetaM Unit := + modify fun s => { s with cache.defEqTrans := ({}, s.mctx.numAssignments) } + @[inline] def resetDefEqPermCaches : MetaM Unit := modifyDefEqPermCache fun _ => {} @@ -2190,16 +2200,6 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure := -/ @[specialize] def checkpointDefEq (x : MetaM Bool) (mayPostpone : Bool := true) : MetaM Bool := do let s ← saveState - /- - It is not safe to use the `isDefEq` cache between different `isDefEq` calls. - Reason: different configuration settings, and result depends on the state of the `MetavarContext` - We have tried in the past to track when the result was independent of the `MetavarContext` state - but it was not effective. It is more important to cache aggressively inside of a single `isDefEq` - call because some of the heuristics create many similar subproblems. - See issue #1102 for an example that triggers an exponential blowup if we don't use this more - aggressive form of caching. - -/ - modifyDefEqTransientCache fun _ => {} let postponed ← getResetPostponed try if (← x) then @@ -2208,10 +2208,14 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure := setPostponed (postponed ++ newPostponed) return true else - s.restore + -- The transient cache needs to be reverted if it assumes some assignments that are being reverted as well. + let invalidCache := s.meta.mctx.numAssignments < (← get).cache.defEqTrans.2 + s.restore (transCache := invalidCache) return false else - s.restore + -- The transient cache needs to be reverted if it assumes some assignments that are being reverted as well. + let invalidCache := s.meta.mctx.numAssignments < (← get).cache.defEqTrans.2 + s.restore (transCache := invalidCache) return false catch ex => s.restore @@ -2246,6 +2250,16 @@ def isExprDefEq (t s : Expr) : MetaM Bool := Remark: the kernel does *not* update the type of variables in the local context. -/ resetDefEqPermCaches + /- + It is not safe to use the transient `isDefEq` cache between different `isDefEq` calls. + Reason: different configuration settings, and result depends on the state of the `MetavarContext` + We have tried in the past to track when the result was independent of the `MetavarContext` state + but it was not effective. It is more important to cache aggressively inside of a single `isDefEq` + call because some of the heuristics create many similar subproblems. + See issue #1102 for an example that triggers an exponential blowup if we don't use this more + aggressive form of caching. + -/ + resetDefEqTransientCache checkpointDefEq (mayPostpone := true) <| Meta.isExprDefEqAux t s /-- diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 2a705e7827b4..24bc554c806a 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -76,7 +76,7 @@ where if !isStructureLike (← getEnv) ctorVal.induct then trace[Meta.isDefEq.eta.struct] "failed, type is not a structure{indentExpr b}" return false - else if (← isDefEq (← inferType a) (← inferType b)) then + else if (← checkpointDefEq <| Meta.isExprDefEqAux (← inferType a) (← inferType b)) then checkpointDefEq do let args := b.getAppArgs let params := args[:ctorVal.numParams].toArray @@ -91,7 +91,7 @@ where -- See comment at `isAbstractedUnassignedMVar`. continue trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{j}], {proj} =?= {args[i]}" - unless (← isDefEq proj args[i]) do + unless (← Meta.isExprDefEqAux proj args[i]) do trace[Meta.isDefEq.eta.struct] "failed, unexpected arg #{i}, projection{indentExpr proj}\nis not defeq to{indentExpr args[i]}" return false return true @@ -312,7 +312,7 @@ private partial def isDefEqArgs (f : Expr) (args₁ args₂ : Array Expr) : Meta let info := finfo.paramInfo[i]! if info.isInstImplicit then unless (← withInferTypeConfig <| Meta.isExprDefEqAux a₁ a₂) do - return false + return false else unless (← Meta.isExprDefEqAux a₁ a₂) do return false @@ -1219,7 +1219,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool : ``` -/ private def processAssignment' (mvarApp : Expr) (v : Expr) : MetaM Bool := do - if (← processAssignment mvarApp v) then + if (← checkpointDefEq <| processAssignment mvarApp v) then return true else let vNew ← whnf v @@ -1227,7 +1227,7 @@ private def processAssignment' (mvarApp : Expr) (v : Expr) : MetaM Bool := do if mvarApp == vNew then return true else - processAssignment mvarApp vNew + checkpointDefEq <| processAssignment mvarApp vNew else return false @@ -1608,7 +1608,7 @@ private def isDefEqProofIrrel (t s : Expr) : MetaM LBool := do private def isDefEqMVarSelf (mvar : Expr) (args₁ args₂ : Array Expr) : MetaM Bool := do if args₁.size != args₂.size then pure false - else if (← isDefEqArgs mvar args₁ args₂) then + else if (← checkpointDefEq <| isDefEqArgs mvar args₁ args₂) then pure true else if !(← isAssignable mvar) then pure false @@ -1846,8 +1846,8 @@ end private def isDefEqOnFailure (t s : Expr) : MetaM Bool := do withTraceNodeBefore `Meta.isDefEq.onFailure (return m!"{t} =?= {s}") do - unstuckMVar t (fun t => Meta.isExprDefEqAux t s) <| - unstuckMVar s (fun s => Meta.isExprDefEqAux t s) <| + unstuckMVar t (fun t => checkpointDefEq <| Meta.isExprDefEqAux t s) <| + unstuckMVar s (fun s => checkpointDefEq <| Meta.isExprDefEqAux t s) <| tryUnificationHints t s <||> tryUnificationHints s t /-- @@ -1925,11 +1925,11 @@ private def isDefEqProj : Expr → Expr → MetaM Bool | .proj m i t, .proj n j s => do if (← read).inTypeClassResolution then -- See comment at `inTypeClassResolution` - pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s + pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s) else if !backward.isDefEq.lazyProjDelta.get (← getOptions) then - pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s + pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s) else if i == j && m == n then - isDefEqProjDelta t s i + checkpointDefEq (isDefEqProjDelta t s i) else return false | .proj structName 0 s, v => isDefEqSingleton structName s v @@ -2046,12 +2046,12 @@ private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do isDefEqOnFailure t s inductive DefEqCacheKind where - | transient -- problem has mvars or is using nonstandard configuration, we should use transient cache + | transient (numAssignments : Nat) -- problem has mvars or is using nonstandard configuration, we should use transient cache | permanent -- problem does not have mvars and we are using standard config, we can use one persistent cache. private def getDefEqCacheKind (t s : Expr) : MetaM DefEqCacheKind := do if t.hasMVar || s.hasMVar || (← read).canUnfold?.isSome then - return .transient + return .transient (← getMCtx).numAssignments else return .permanent @@ -2069,7 +2069,12 @@ private def mkCacheKey (t s : Expr) : MetaM DefEqCacheKeyInfo := do private def getCachedResult (keyInfo : DefEqCacheKeyInfo) : MetaM LBool := do let cache ← match keyInfo.kind with - | .transient => pure (← get).cache.defEqTrans + | .transient numAssignments => + let (cache, numAssignments') := (← get).cache.defEqTrans + if numAssignments == numAssignments' then + pure cache + else + return .undef | .permanent => pure (← get).cache.defEqPerm match cache.find? keyInfo.key with | some val => return val.toLBool @@ -2079,14 +2084,17 @@ private def cacheResult (keyInfo : DefEqCacheKeyInfo) (result : Bool) : MetaM Un let key := keyInfo.key match keyInfo.kind with | .permanent => modifyDefEqPermCache fun c => c.insert key result - | .transient => + | .transient numAssignmentsOld => /- - We must ensure that all assigned metavariables in the key are replaced by their current assignments. - Otherwise, the key is invalid after the assignment is "backtracked". - See issue #1870 for an example. + If the result is `false`, we cache it at `numAssignmentsOld`. + If the result is `true`, we check that the number of assignments hasn't increased. -/ - let key ← mkDefEqCacheKey (← instantiateMVars key.lhs) (← instantiateMVars key.rhs) - modifyDefEqTransientCache fun c => c.insert key result + if !result then + modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result + else + let numAssignmentsNew := (← getMCtx).numAssignments + if numAssignmentsOld == numAssignmentsNew then + modifyDefEqTransientCache numAssignmentsNew fun c => c.insert key result private def whnfCoreAtDefEq (e : Expr) : MetaM Expr := do if backward.isDefEq.lazyWhnfCore.get (← getOptions) then diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index 3ba07b91e05b..426566cb537e 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -331,6 +331,8 @@ structure MetavarContext where lAssignment : PersistentHashMap LMVarId Level := {} /-- Assignment table for expression metavariables.-/ eAssignment : PersistentHashMap MVarId Expr := {} + /-- Number of assignments in `eAssignment` and `lAssignments`. -/ + numAssignments : Nat := 0 /-- Assignment table for delayed abstraction metavariables. For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/ dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {} @@ -489,8 +491,9 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool This is a low-level API, and it is safer to use `isLevelDefEq (mkLevelMVar mvarId) u`. -/ def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit := - modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val } + modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 } +-- `assignLevelMVarExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments` @[export lean_assign_lmvar] def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext := { m with lAssignment := m.lAssignment.insert mvarId val } @@ -502,8 +505,9 @@ a cycle is being introduced, or whether the expression has the right type. This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`. -/ def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit := - modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val } + modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 } +-- `assignExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments` @[export lean_assign_mvar] def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext := { m with eAssignment := m.eAssignment.insert mvarId val } diff --git a/tests/lean/run/defEqTransCache.lean b/tests/lean/run/defEqTransCache.lean new file mode 100644 index 000000000000..63eb49d18438 --- /dev/null +++ b/tests/lean/run/defEqTransCache.lean @@ -0,0 +1,77 @@ +import Lean +/-! +Previously, unification wouldn't be very careful with the `isDefEq` cache for terms containing metavariables. +- This is mostly problematic because erasing the cache leads to exponential slowdowns (`test1` & `test2`) +- but in some cases it leads to metavariable assignments leaking into places where they shouldn't be, + which either causes unification to fail where it should succeed (`test3`) + or to succeed where it is expected to fail. + +-/ +set_option maxHeartbeats 1000 + +namespace test1 +class A (n : Nat) where + x : Nat + +instance [A n] : A (n+1) where + x := A.x n + +theorem test [A 0] : A.x 100 = sorry := sorry + +-- Previously, this example was exponentially slow +example [A 1] : A.x 100 = sorry := by + fail_if_success rw [@test] + sorry +end test1 + + +namespace test2 +@[irreducible] def A : Type := Unit + +@[irreducible] def B : Type := Unit + +unseal B in +@[coe] def AtoB (_a : A) : B := () + +instance : Coe A B where coe := AtoB + +def h {α : Type} (a b : α) : Nat → α +| 0 => a +| n + 1 => h b a n + +def f {α : Type} (a b : α) : Nat → Prop +| 0 => a = b +| n + 1 => f (h a b n) (h b a n) n ∧ f (h a a n) (h b b n) n + +axiom foo {p} {α : Type} (a b : α) : f a b p + +variable (x : A) (y : B) +-- Previously, this check was exponentially slow; now it is quadratically slow +#check (foo (↑x) y : f (AtoB x) y 30) +end test2 + + +namespace test3 +structure A (α : Type) where + x : Type + y : α + +structure B (α : Type) extends A α where + z : Nat + +def A.map {α β} (f : α → β) (a : A α) : A β := ⟨a.x, f a.y⟩ + +open Lean Meta in +elab "unfold_head" e:term : term => do + let e ← Elab.Term.elabTerm e none + unfoldDefinition e + +-- we use `unfold_head` in order to get the raw kernel projection `·.1` instead of the projection funtcion `A.x`. +def test {α} (i : B α) : unfold_head i.toA.x := sorry + +-- Previously, in this example the unification failed, +-- because some metavariable assignment wasn't reverted properly +-- However, it is clearly the case that `(i.toA.map f).x` is the same as `i.toA.x` +example (i : B α) (f : α → β) : (i.toA.map f).x := by + apply @test +end test3