Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ef7cccb
perf: better cache sharing in `isDefEq`
JovanGerb May 12, 2025
f6881cf
use `(shareCache := true)` everywhere in `ExprDefEq.lean`
JovanGerb May 12, 2025
f43191b
better management of `Lean.Meta.Cache.defEqTrans`
JovanGerb May 12, 2025
19c4a79
retry CI
JovanGerb May 12, 2025
9964565
retry CI again
JovanGerb May 12, 2025
676b0fc
avoid `@[extern]` to resolve segfault
JovanGerb May 12, 2025
0422dcb
Revert "avoid `@[extern]` to resolve segfault"
JovanGerb May 12, 2025
be3f48e
fix: bug that was already present in `isDefEqProj`: missing `checkPoi…
JovanGerb May 13, 2025
34a76d3
retry CI
JovanGerb May 13, 2025
2fd2cde
retry CI
JovanGerb May 13, 2025
aacf58d
move `Term.addTermInfo id` one line up in `withRWRulesSeq`
JovanGerb May 14, 2025
4eee7f9
Revert "move `Term.addTermInfo id` one line up in `withRWRulesSeq`"
JovanGerb May 14, 2025
983efd6
fix: oh, `Meta.SavedState.restore` doesn't restore the cache, so I ha…
JovanGerb May 14, 2025
e1d19a5
change how to `modifyDefEqTransientCache`, and add missing `checkpoin…
JovanGerb May 15, 2025
2930737
perf: instead of resetting the transient cache to be empty, reset it …
JovanGerb Jun 2, 2025
e05d9f9
perf: don't increment `numAssignments` in `instantiateMVars`
JovanGerb Jun 3, 2025
85f2f78
empty commit
JovanGerb Jun 3, 2025
244866d
add a test file
JovanGerb Jun 3, 2025
53a4d87
`set_option maxHeartbeats 1000` in test
JovanGerb Jun 3, 2025
cc74f96
slightly improve caching
JovanGerb Jun 3, 2025
0e96d2c
undo caching unifications that instantiate
JovanGerb Jun 4, 2025
8c514b4
avoid `isDefEq` in `ExprDefEq.lean`
JovanGerb Jun 4, 2025
5f05587
empty commit
JovanGerb Jun 4, 2025
ebb5309
Merge branch 'nightly-with-mathlib' into Jovan-defEq-cache
JovanGerb Jun 7, 2025
9c0429d
empty commit
JovanGerb Jun 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ structure Cache where
funInfo : FunInfoCache := {}
synthInstance : SynthInstanceCache := {}
whnf : WhnfCache := {}
defEqTrans : DefEqCache := {} -- transient cache for terms containing mvars or using nonstandard configuration options, it is frequently reset.
defEqTrans : DefEqCache × Nat := ({}, 0) -- transient cache for terms containing mvars or using nonstandard configuration options, it is valid as long as the count matches `MetavarContext.numAssignments`.
defEqPerm : DefEqCache := {} -- permanent cache for terms not containing mvars and using standard configuration options
deriving Inhabited

Expand Down Expand Up @@ -558,9 +558,13 @@ protected def saveState : MetaM SavedState :=
return { core := (← Core.saveState), «meta» := (← get) }

/-- Restore backtrackable parts of the state. -/
def SavedState.restore (b : SavedState) : MetaM Unit := do
def SavedState.restore (b : SavedState) (transCache : Bool := false) : MetaM Unit := do
b.core.restore
modify fun s => { s with mctx := b.meta.mctx, zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds, postponed := b.meta.postponed }
modify fun s => { s with
mctx := b.meta.mctx
zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds
postponed := b.meta.postponed
cache.defEqTrans := if transCache then b.meta.cache.defEqTrans else s.cache.defEqTrans }

@[specialize, inherit_doc Core.withRestoreOrSaveFull]
def withRestoreOrSaveFull (reusableResult? : Option (α × SavedState)) (act : MetaM α) :
Expand Down Expand Up @@ -622,8 +626,11 @@ def resetCache : MetaM Unit :=
@[inline] def modifyInferTypeCache (f : InferTypeCache → InferTypeCache) : MetaM Unit :=
modifyCache fun ⟨ic, c1, c2, c3, c4, c5⟩ => ⟨f ic, c1, c2, c3, c4, c5⟩

@[inline] def modifyDefEqTransientCache (f : DefEqCache → DefEqCache) : MetaM Unit :=
modifyCache fun ⟨c1, c2, c3, c4, defeqTrans, c5⟩ => ⟨c1, c2, c3, c4, f defeqTrans, c5⟩
@[inline] def modifyDefEqTransientCache (numAssignments : Nat) (f : DefEqCache → DefEqCache) : MetaM Unit :=
modifyCache fun c =>
let (transCache, numAssignmentsOld) := c.defEqTrans
let transCache := if numAssignments == numAssignmentsOld then transCache else {}
{ c with defEqTrans := (f transCache, numAssignments) }

@[inline] def modifyDefEqPermCache (f : DefEqCache → DefEqCache) : MetaM Unit :=
modifyCache fun ⟨c1, c2, c3, c4, c5, defeqPerm⟩ => ⟨c1, c2, c3, c4, c5, f defeqPerm⟩
Expand All @@ -641,6 +648,9 @@ def mkDefEqCacheKey (lhs rhs : Expr) : MetaM DefEqCacheKey := do
def mkInfoCacheKey (expr : Expr) (nargs? : Option Nat) : MetaM InfoCacheKey :=
return { expr, nargs?, configKey := (← read).configKey }

@[inline] def resetDefEqTransientCache : MetaM Unit :=
modify fun s => { s with cache.defEqTrans := ({}, s.mctx.numAssignments) }

@[inline] def resetDefEqPermCaches : MetaM Unit :=
modifyDefEqPermCache fun _ => {}

Expand Down Expand Up @@ -2190,16 +2200,6 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure :=
-/
@[specialize] def checkpointDefEq (x : MetaM Bool) (mayPostpone : Bool := true) : MetaM Bool := do
let s ← saveState
/-
It is not safe to use the `isDefEq` cache between different `isDefEq` calls.
Reason: different configuration settings, and result depends on the state of the `MetavarContext`
We have tried in the past to track when the result was independent of the `MetavarContext` state
but it was not effective. It is more important to cache aggressively inside of a single `isDefEq`
call because some of the heuristics create many similar subproblems.
See issue #1102 for an example that triggers an exponential blowup if we don't use this more
aggressive form of caching.
-/
modifyDefEqTransientCache fun _ => {}
let postponed ← getResetPostponed
try
if (← x) then
Expand All @@ -2208,10 +2208,14 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure :=
setPostponed (postponed ++ newPostponed)
return true
else
s.restore
-- The transient cache needs to be reverted if it assumes some assignments that are being reverted as well.
let invalidCache := s.meta.mctx.numAssignments < (← get).cache.defEqTrans.2
s.restore (transCache := invalidCache)
return false
else
s.restore
-- The transient cache needs to be reverted if it assumes some assignments that are being reverted as well.
let invalidCache := s.meta.mctx.numAssignments < (← get).cache.defEqTrans.2
s.restore (transCache := invalidCache)
return false
catch ex =>
s.restore
Expand Down Expand Up @@ -2246,6 +2250,16 @@ def isExprDefEq (t s : Expr) : MetaM Bool :=
Remark: the kernel does *not* update the type of variables in the local context.
-/
resetDefEqPermCaches
/-
It is not safe to use the transient `isDefEq` cache between different `isDefEq` calls.
Reason: different configuration settings, and result depends on the state of the `MetavarContext`
We have tried in the past to track when the result was independent of the `MetavarContext` state
but it was not effective. It is more important to cache aggressively inside of a single `isDefEq`
call because some of the heuristics create many similar subproblems.
See issue #1102 for an example that triggers an exponential blowup if we don't use this more
aggressive form of caching.
-/
resetDefEqTransientCache
checkpointDefEq (mayPostpone := true) <| Meta.isExprDefEqAux t s

/--
Expand Down
48 changes: 28 additions & 20 deletions src/Lean/Meta/ExprDefEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ where
if !isStructureLike (← getEnv) ctorVal.induct then
trace[Meta.isDefEq.eta.struct] "failed, type is not a structure{indentExpr b}"
return false
else if (← isDefEq (← inferType a) (← inferType b)) then
else if (← checkpointDefEq <| Meta.isExprDefEqAux (← inferType a) (← inferType b)) then
checkpointDefEq do
let args := b.getAppArgs
let params := args[:ctorVal.numParams].toArray
Expand All @@ -91,7 +91,7 @@ where
-- See comment at `isAbstractedUnassignedMVar`.
continue
trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{j}], {proj} =?= {args[i]}"
unless (← isDefEq proj args[i]) do
unless (← Meta.isExprDefEqAux proj args[i]) do
trace[Meta.isDefEq.eta.struct] "failed, unexpected arg #{i}, projection{indentExpr proj}\nis not defeq to{indentExpr args[i]}"
return false
return true
Expand Down Expand Up @@ -312,7 +312,7 @@ private partial def isDefEqArgs (f : Expr) (args₁ args₂ : Array Expr) : Meta
let info := finfo.paramInfo[i]!
if info.isInstImplicit then
unless (← withInferTypeConfig <| Meta.isExprDefEqAux a₁ a₂) do
return false
return false
else
unless (← Meta.isExprDefEqAux a₁ a₂) do
return false
Expand Down Expand Up @@ -1219,15 +1219,15 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool :
```
-/
private def processAssignment' (mvarApp : Expr) (v : Expr) : MetaM Bool := do
if (← processAssignment mvarApp v) then
if (← checkpointDefEq <| processAssignment mvarApp v) then
return true
else
let vNew ← whnf v
if vNew != v then
if mvarApp == vNew then
return true
else
processAssignment mvarApp vNew
checkpointDefEq <| processAssignment mvarApp vNew
else
return false

Expand Down Expand Up @@ -1608,7 +1608,7 @@ private def isDefEqProofIrrel (t s : Expr) : MetaM LBool := do
private def isDefEqMVarSelf (mvar : Expr) (args₁ args₂ : Array Expr) : MetaM Bool := do
if args₁.size != args₂.size then
pure false
else if (← isDefEqArgs mvar args₁ args₂) then
else if (← checkpointDefEq <| isDefEqArgs mvar args₁ args₂) then
pure true
else if !(← isAssignable mvar) then
pure false
Expand Down Expand Up @@ -1846,8 +1846,8 @@ end

private def isDefEqOnFailure (t s : Expr) : MetaM Bool := do
withTraceNodeBefore `Meta.isDefEq.onFailure (return m!"{t} =?= {s}") do
unstuckMVar t (fun t => Meta.isExprDefEqAux t s) <|
unstuckMVar s (fun s => Meta.isExprDefEqAux t s) <|
unstuckMVar t (fun t => checkpointDefEq <| Meta.isExprDefEqAux t s) <|
unstuckMVar s (fun s => checkpointDefEq <| Meta.isExprDefEqAux t s) <|
tryUnificationHints t s <||> tryUnificationHints s t

/--
Expand Down Expand Up @@ -1925,11 +1925,11 @@ private def isDefEqProj : Expr → Expr → MetaM Bool
| .proj m i t, .proj n j s => do
if (← read).inTypeClassResolution then
-- See comment at `inTypeClassResolution`
pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s
pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious about the effect of just these changes in isDefEqProj on Mathlib.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just doing this change may cause a slight slowdown, because checkPointDefEq resets the transitive defEq cache. So the cache will be reset more often. I think it will work better as a fix done simultaneously with this the fix from this PR.

But I might be wrong, so feel free to try it.

else if !backward.isDefEq.lazyProjDelta.get (← getOptions) then
pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s
pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s)
else if i == j && m == n then
isDefEqProjDelta t s i
checkpointDefEq (isDefEqProjDelta t s i)
else
return false
| .proj structName 0 s, v => isDefEqSingleton structName s v
Expand Down Expand Up @@ -2046,12 +2046,12 @@ private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do
isDefEqOnFailure t s

inductive DefEqCacheKind where
| transient -- problem has mvars or is using nonstandard configuration, we should use transient cache
| transient (numAssignments : Nat) -- problem has mvars or is using nonstandard configuration, we should use transient cache
| permanent -- problem does not have mvars and we are using standard config, we can use one persistent cache.

private def getDefEqCacheKind (t s : Expr) : MetaM DefEqCacheKind := do
if t.hasMVar || s.hasMVar || (← read).canUnfold?.isSome then
return .transient
return .transient (← getMCtx).numAssignments
else
return .permanent

Expand All @@ -2069,7 +2069,12 @@ private def mkCacheKey (t s : Expr) : MetaM DefEqCacheKeyInfo := do

private def getCachedResult (keyInfo : DefEqCacheKeyInfo) : MetaM LBool := do
let cache ← match keyInfo.kind with
| .transient => pure (← get).cache.defEqTrans
| .transient numAssignments =>
let (cache, numAssignments') := (← get).cache.defEqTrans
if numAssignments == numAssignments' then
pure cache
else
return .undef
| .permanent => pure (← get).cache.defEqPerm
match cache.find? keyInfo.key with
| some val => return val.toLBool
Expand All @@ -2079,14 +2084,17 @@ private def cacheResult (keyInfo : DefEqCacheKeyInfo) (result : Bool) : MetaM Un
let key := keyInfo.key
match keyInfo.kind with
| .permanent => modifyDefEqPermCache fun c => c.insert key result
| .transient =>
| .transient numAssignmentsOld =>
/-
We must ensure that all assigned metavariables in the key are replaced by their current assignments.
Otherwise, the key is invalid after the assignment is "backtracked".
See issue #1870 for an example.
If the result is `false`, we cache it at `numAssignmentsOld`.
If the result is `true`, we check that the number of assignments hasn't increased.
-/
let key ← mkDefEqCacheKey (← instantiateMVars key.lhs) (← instantiateMVars key.rhs)
modifyDefEqTransientCache fun c => c.insert key result
if !result then
modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result
else
let numAssignmentsNew := (← getMCtx).numAssignments
if numAssignmentsOld == numAssignmentsNew then
modifyDefEqTransientCache numAssignmentsNew fun c => c.insert key result

private def whnfCoreAtDefEq (e : Expr) : MetaM Expr := do
if backward.isDefEq.lazyWhnfCore.get (← getOptions) then
Expand Down
8 changes: 6 additions & 2 deletions src/Lean/MetavarContext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ structure MetavarContext where
lAssignment : PersistentHashMap LMVarId Level := {}
/-- Assignment table for expression metavariables.-/
eAssignment : PersistentHashMap MVarId Expr := {}
/-- Number of assignments in `eAssignment` and `lAssignments`. -/
numAssignments : Nat := 0
/-- Assignment table for delayed abstraction metavariables.
For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
Expand Down Expand Up @@ -489,8 +491,9 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool
This is a low-level API, and it is safer to use `isLevelDefEq (mkLevelMVar mvarId) u`.
-/
def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit :=
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val }
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 }

-- `assignLevelMVarExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments`
@[export lean_assign_lmvar]
def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext :=
{ m with lAssignment := m.lAssignment.insert mvarId val }
Expand All @@ -502,8 +505,9 @@ a cycle is being introduced, or whether the expression has the right type.
This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`.
-/
def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit :=
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val }
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 }

-- `assignExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments`
@[export lean_assign_mvar]
def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
{ m with eAssignment := m.eAssignment.insert mvarId val }
Expand Down
77 changes: 77 additions & 0 deletions tests/lean/run/defEqTransCache.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import Lean
/-!
Previously, unification wouldn't be very careful with the `isDefEq` cache for terms containing metavariables.
- This is mostly problematic because erasing the cache leads to exponential slowdowns (`test1` & `test2`)
- but in some cases it leads to metavariable assignments leaking into places where they shouldn't be,
which either causes unification to fail where it should succeed (`test3`)
or to succeed where it is expected to fail.

-/
set_option maxHeartbeats 1000

namespace test1
class A (n : Nat) where
x : Nat

instance [A n] : A (n+1) where
x := A.x n

theorem test [A 0] : A.x 100 = sorry := sorry

-- Previously, this example was exponentially slow
example [A 1] : A.x 100 = sorry := by
fail_if_success rw [@test]
sorry
end test1


namespace test2
@[irreducible] def A : Type := Unit

@[irreducible] def B : Type := Unit

unseal B in
@[coe] def AtoB (_a : A) : B := ()

instance : Coe A B where coe := AtoB

def h {α : Type} (a b : α) : Nat → α
| 0 => a
| n + 1 => h b a n

def f {α : Type} (a b : α) : Nat → Prop
| 0 => a = b
| n + 1 => f (h a b n) (h b a n) n ∧ f (h a a n) (h b b n) n

axiom foo {p} {α : Type} (a b : α) : f a b p

variable (x : A) (y : B)
-- Previously, this check was exponentially slow; now it is quadratically slow
#check (foo (↑x) y : f (AtoB x) y 30)
end test2


namespace test3
structure A (α : Type) where
x : Type
y : α

structure B (α : Type) extends A α where
z : Nat

def A.map {α β} (f : α → β) (a : A α) : A β := ⟨a.x, f a.y⟩

open Lean Meta in
elab "unfold_head" e:term : term => do
let e ← Elab.Term.elabTerm e none
unfoldDefinition e

-- we use `unfold_head` in order to get the raw kernel projection `·.1` instead of the projection funtcion `A.x`.
def test {α} (i : B α) : unfold_head i.toA.x := sorry

-- Previously, in this example the unification failed,
-- because some metavariable assignment wasn't reverted properly
-- However, it is clearly the case that `(i.toA.map f).x` is the same as `i.toA.x`
example (i : B α) (f : α → β) : (i.toA.map f).x := by
apply @test
end test3
Loading