Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 47 additions & 18 deletions src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,19 @@ We should also investigate the impact on memory consumption.
-/
abbrev DefEqCache := PersistentHashMap DefEqCacheKey Bool

/--
A `DefEqTransCache` is a `DefEqCache` that is only valid in the `MetavarContext` in which it was declared.
To keep track of whether this cache is still valid, it stores the `numAssignments` from the original `MetavarContext`.
If the `numAssignments` in the `MetavarContext` has increased, we invalidate this cache.
And when reverting the metavariable context in `checkpointDefEq`, if the `numAssignments`
in the original `MetavarContext` is smaller than in this cache, we also invalidate this cache,
and revert to the original cache.
-/
structure DefEqTransCache where
cache : DefEqCache := {}
numAssignments : Nat := 0
deriving Inhabited

/--
Cache datastructures for type inference, type class resolution, whnf, and definitional equality.
-/
Expand All @@ -384,7 +397,7 @@ structure Cache where
funInfo : FunInfoCache := {}
synthInstance : SynthInstanceCache := {}
whnf : WhnfCache := {}
defEqTrans : DefEqCache := {} -- transient cache for terms containing mvars or using nonstandard configuration options, it is frequently reset.
defEqTrans : DefEqTransCache := {} -- transient cache for terms containing mvars or using nonstandard configuration options, it is valid as long as the count matches `MetavarContext.numAssignments`.
defEqPerm : DefEqCache := {} -- permanent cache for terms not containing mvars and using standard configuration options
deriving Inhabited

Expand Down Expand Up @@ -574,10 +587,15 @@ instance : AddMessageContext MetaM where
protected def saveState : MetaM SavedState :=
return { core := (← Core.saveState), «meta» := (← get) }

/-- Restore backtrackable parts of the state. -/
def SavedState.restore (b : SavedState) : MetaM Unit := do
/-- Restore backtrackable parts of the state.
If `transCache == true`, then also reset the tranient defEq cache -/
def SavedState.restore (b : SavedState) (transCache : Bool := false) : MetaM Unit := do
b.core.restore
modify fun s => { s with mctx := b.meta.mctx, zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds, postponed := b.meta.postponed }
modify fun s => { s with
mctx := b.meta.mctx
zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds
postponed := b.meta.postponed
cache.defEqTrans := if transCache then b.meta.cache.defEqTrans else s.cache.defEqTrans }

@[specialize, inherit_doc Core.withRestoreOrSaveFull]
def withRestoreOrSaveFull (reusableResult? : Option (α × SavedState)) (act : MetaM α) :
Expand Down Expand Up @@ -639,8 +657,12 @@ def resetCache : MetaM Unit :=
@[inline] def modifyInferTypeCache (f : InferTypeCache → InferTypeCache) : MetaM Unit :=
modifyCache fun ⟨ic, c1, c2, c3, c4, c5⟩ => ⟨f ic, c1, c2, c3, c4, c5⟩

@[inline] def modifyDefEqTransientCache (f : DefEqCache → DefEqCache) : MetaM Unit :=
modifyCache fun ⟨c1, c2, c3, c4, defeqTrans, c5⟩ => ⟨c1, c2, c3, c4, f defeqTrans, c5⟩
/-- Modify the defEq transient cache. If it is not valid anymore, reset it before modifying it. -/
@[inline] def modifyDefEqTransientCache (numAssignments : Nat) (f : DefEqCache → DefEqCache) : MetaM Unit :=
modifyCache fun c =>
let ⟨transCache, numAssignmentsOld⟩ := c.defEqTrans
let transCache := if numAssignments == numAssignmentsOld then transCache else {}
{ c with defEqTrans := ⟨f transCache, numAssignments⟩ }

@[inline] def modifyDefEqPermCache (f : DefEqCache → DefEqCache) : MetaM Unit :=
modifyCache fun ⟨c1, c2, c3, c4, c5, defeqPerm⟩ => ⟨c1, c2, c3, c4, c5, f defeqPerm⟩
Expand All @@ -658,6 +680,9 @@ def mkDefEqCacheKey (lhs rhs : Expr) : MetaM DefEqCacheKey := do
def mkInfoCacheKey (expr : Expr) (nargs? : Option Nat) : MetaM InfoCacheKey :=
return { expr, nargs?, configKey := (← read).configKey }

@[inline] def resetDefEqTransientCache : MetaM Unit :=
modify fun s => { s with cache.defEqTrans := ⟨{}, s.mctx.numAssignments⟩ }

@[inline] def resetDefEqPermCaches : MetaM Unit :=
modifyDefEqPermCache fun _ => {}

Expand Down Expand Up @@ -2349,16 +2374,6 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure :=
-/
@[specialize] def checkpointDefEq (x : MetaM Bool) (mayPostpone : Bool := true) : MetaM Bool := do
let s ← saveState
/-
It is not safe to use the `isDefEq` cache between different `isDefEq` calls.
Reason: different configuration settings, and result depends on the state of the `MetavarContext`
We have tried in the past to track when the result was independent of the `MetavarContext` state
but it was not effective. It is more important to cache aggressively inside of a single `isDefEq`
call because some of the heuristics create many similar subproblems.
See issue #1102 for an example that triggers an exponential blowup if we don't use this more
aggressive form of caching.
-/
modifyDefEqTransientCache fun _ => {}
let postponed ← getResetPostponed
try
if (← x) then
Expand All @@ -2367,10 +2382,14 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure :=
setPostponed (postponed ++ newPostponed)
return true
else
s.restore
-- The transient cache needs to be reverted if it assumes an assignments that is being reverted.
let isInvalidCache := s.meta.mctx.numAssignments != (← get).cache.defEqTrans.numAssignments
s.restore (transCache := isInvalidCache)
return false
else
s.restore
-- The transient cache needs to be reverted if it assumes an assignments that is being reverted.
let isInvalidCache := s.meta.mctx.numAssignments != (← get).cache.defEqTrans.numAssignments
s.restore (transCache := isInvalidCache)
return false
catch ex =>
s.restore
Expand Down Expand Up @@ -2405,6 +2424,16 @@ def isExprDefEq (t s : Expr) : MetaM Bool :=
Remark: the kernel does *not* update the type of variables in the local context.
-/
resetDefEqPermCaches
/-
It is not safe to use the transient `isDefEq` cache between different `isDefEq` calls.
Reason: different configuration settings, and result depends on the state of the `MetavarContext`
We have tried in the past to track when the result was independent of the `MetavarContext` state
but it was not effective. It is more important to cache aggressively inside of a single `isDefEq`
call because some of the heuristics create many similar subproblems.
See issue #1102 and `tests/lean/run/defEqTransCache.lean` for examples that trigger an exponential blowup
if we don't use this more aggressive form of caching.
-/
resetDefEqTransientCache
checkpointDefEq (mayPostpone := true) <| Meta.isExprDefEqAux t s

/--
Expand Down
46 changes: 27 additions & 19 deletions src/Lean/Meta/ExprDefEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ where
if !isStructureLike (← getEnv) ctorVal.induct then
trace[Meta.isDefEq.eta.struct] "failed, type is not a structure{indentExpr b}"
return false
else if (← isDefEq (← inferType a) (← inferType b)) then
else if (← checkpointDefEq <| Meta.isExprDefEqAux (← inferType a) (← inferType b)) then
checkpointDefEq do
let args := b.getAppArgs
let params := args[*...ctorVal.numParams].toArray
Expand All @@ -95,7 +95,7 @@ where
-- See comment at `isAbstractedUnassignedMVar`.
continue
trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{j}], {proj} =?= {args[i]}"
unless (← isDefEq proj args[i]) do
unless (← Meta.isExprDefEqAux proj args[i]) do
trace[Meta.isDefEq.eta.struct] "failed, unexpected arg #{i}, projection{indentExpr proj}\nis not defeq to{indentExpr args[i]}"
return false
return true
Expand Down Expand Up @@ -1240,15 +1240,15 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool :
```
-/
private def processAssignment' (mvarApp : Expr) (v : Expr) : MetaM Bool := do
if (← processAssignment mvarApp v) then
if (← checkpointDefEq <| processAssignment mvarApp v) then
return true
else
let vNew ← whnf v
if vNew != v then
if mvarApp == vNew then
return true
else
processAssignment mvarApp vNew
checkpointDefEq <| processAssignment mvarApp vNew
else
return false

Expand Down Expand Up @@ -1629,7 +1629,7 @@ private def isDefEqProofIrrel (t s : Expr) : MetaM LBool := do
private def isDefEqMVarSelf (mvar : Expr) (args₁ args₂ : Array Expr) : MetaM Bool := do
if args₁.size != args₂.size then
pure false
else if (← isDefEqArgs mvar args₁ args₂) then
else if (← checkpointDefEq <| isDefEqArgs mvar args₁ args₂) then
pure true
else if !(← isAssignable mvar) then
pure false
Expand Down Expand Up @@ -1861,8 +1861,8 @@ end

private def isDefEqOnFailure (t s : Expr) : MetaM Bool := do
withTraceNodeBefore `Meta.isDefEq.onFailure (return m!"{t} =?= {s}") do
unstuckMVar t (fun t => Meta.isExprDefEqAux t s) <|
unstuckMVar s (fun s => Meta.isExprDefEqAux t s) <|
unstuckMVar t (fun t => checkpointDefEq <| Meta.isExprDefEqAux t s) <|
unstuckMVar s (fun s => checkpointDefEq <| Meta.isExprDefEqAux t s) <|
tryUnificationHints t s <||> tryUnificationHints s t

/--
Expand Down Expand Up @@ -1940,11 +1940,11 @@ private def isDefEqProj : Expr → Expr → MetaM Bool
| .proj m i t, .proj n j s => do
if (← read).inTypeClassResolution then
-- See comment at `inTypeClassResolution`
pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s
pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s)
else if !backward.isDefEq.lazyProjDelta.get (← getOptions) then
pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s
pure (i == j && m == n) <&&> checkpointDefEq (Meta.isExprDefEqAux t s)
else if i == j && m == n then
isDefEqProjDelta t s i
checkpointDefEq (isDefEqProjDelta t s i)
else
return false
| .proj structName 0 s, v => isDefEqSingleton structName s v
Expand Down Expand Up @@ -2061,12 +2061,12 @@ private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do
isDefEqOnFailure t s

inductive DefEqCacheKind where
| transient -- problem has mvars or is using nonstandard configuration, we should use transient cache
| transient (numAssignments : Nat) -- problem has mvars or is using nonstandard configuration, we should use transient cache
| permanent -- problem does not have mvars and we are using standard config, we can use one persistent cache.

private def getDefEqCacheKind (t s : Expr) : MetaM DefEqCacheKind := do
if t.hasMVar || s.hasMVar || (← read).canUnfold?.isSome then
return .transient
return .transient (← getMCtx).numAssignments
else
return .permanent

Expand All @@ -2084,7 +2084,12 @@ private def mkCacheKey (t s : Expr) : MetaM DefEqCacheKeyInfo := do

private def getCachedResult (keyInfo : DefEqCacheKeyInfo) : MetaM LBool := do
let cache ← match keyInfo.kind with
| .transient => pure (← get).cache.defEqTrans
| .transient numAssignments =>
let ⟨cache, numAssignmentsCache⟩ := (← get).cache.defEqTrans
if numAssignments == numAssignmentsCache then
pure cache
else
return .undef
| .permanent => pure (← get).cache.defEqPerm
match cache.find? keyInfo.key with
| some val => return val.toLBool
Expand All @@ -2094,14 +2099,17 @@ private def cacheResult (keyInfo : DefEqCacheKeyInfo) (result : Bool) : MetaM Un
let key := keyInfo.key
match keyInfo.kind with
| .permanent => modifyDefEqPermCache fun c => c.insert key result
| .transient =>
| .transient numAssignmentsOld =>
/-
We must ensure that all assigned metavariables in the key are replaced by their current assignments.
Otherwise, the key is invalid after the assignment is "backtracked".
See issue #1870 for an example.
If the result is `false`, we cache it at `numAssignmentsOld`.
If the result is `true`, we only cache it if the number of assignments hasn't increased.
-/
let key ← mkDefEqCacheKey (← instantiateMVars key.lhs) (← instantiateMVars key.rhs)
modifyDefEqTransientCache fun c => c.insert key result
if !result then
modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result
else
let numAssignmentsNew := (← getMCtx).numAssignments
if numAssignmentsOld == numAssignmentsNew then
modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result

private def whnfCoreAtDefEq (e : Expr) : MetaM Expr := do
if backward.isDefEq.lazyWhnfCore.get (← getOptions) then
Expand Down
8 changes: 6 additions & 2 deletions src/Lean/MetavarContext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ structure MetavarContext where
lAssignment : PersistentHashMap LMVarId Level := {}
/-- Assignment table for expression metavariables.-/
eAssignment : PersistentHashMap MVarId Expr := {}
/-- Number of assignments in `eAssignment` and `lAssignments`. -/
numAssignments : Nat := 0
/-- Assignment table for delayed abstraction metavariables.
For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
Expand Down Expand Up @@ -507,8 +509,9 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool
This is a low-level API, and it is safer to use `isLevelDefEq (mkLevelMVar mvarId) u`.
-/
def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit :=
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val }
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 }

-- `assignLevelMVarExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments`
@[export lean_assign_lmvar]
def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext :=
{ m with lAssignment := m.lAssignment.insert mvarId val }
Expand All @@ -520,8 +523,9 @@ a cycle is being introduced, or whether the expression has the right type.
This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`.
-/
def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit :=
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val }
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val, numAssignments := m.numAssignments + 1 }

-- `assignExp` is only used in `instantiateMVars`, so it doesn't need to increment `numAssignments`
@[export lean_assign_mvar]
def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
{ m with eAssignment := m.eAssignment.insert mvarId val }
Expand Down
77 changes: 77 additions & 0 deletions tests/lean/run/defEqTransCache.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import Lean
/-!
Previously, unification wouldn't be very careful with the `isDefEq` cache for terms containing metavariables.
- This is mostly problematic because erasing the cache leads to exponential slowdowns (`test1` & `test2`)
- but in some cases it lead to metavariable assignments leaking into places where they shouldn't be,
which either caused unification to fail where it should succeed (`test3`)
or to succeed where it is expected to fail (which happened in one mathlib proof).
-/

set_option maxHeartbeats 1000

namespace test1
class A (n : Nat) where
x : Nat

instance [A n] : A (n+1) where
x := A.x n

theorem test [A 0] : A.x 100 = 0 := sorry

-- This rewrite should fail. Previously, it failed exponentially slowly
example [A 1] : A.x 100 = 0 := by
fail_if_success rw [@test]
sorry
end test1


namespace test2
@[irreducible] def A : Type := Unit

@[irreducible] def B : Type := Unit

unseal B in
@[coe] def AtoB (_a : A) : B := ()

instance : Coe A B where coe := AtoB

def h {α : Type} (a b : α) : Nat → α
| 0 => a
| n + 1 => h b a n

def f {α : Type} (a b : α) : Nat → Prop
| 0 => a = b
| n + 1 => f (h a b n) (h b a n) n ∧ f (h a a n) (h b b n) n

axiom foo {p} {α : Type} (a b : α) : f a b p

variable (x : A) (y : B)
-- Previously, this check was exponentially slow; now it is quadratically slow
#check (foo (↑x) y : f (AtoB x) y 30)
end test2


namespace test3
structure A (α : Type) where
x : Type
y : α

structure B (α : Type) extends A α where
z : Nat

def A.map {α β} (f : α → β) (a : A α) : A β := ⟨a.x, f a.y⟩

open Lean Meta in
elab "unfold_head" e:term : term => do
let e ← Elab.Term.elabTerm e none
unfoldDefinition e

-- use `unfold_head` to get the raw kernel projection `·.1` instead of the projection funtcion `A.x`
def test {α} (i : B α) : unfold_head i.toA.x := sorry

-- Previously, in this example the unification failed,
-- because some metavariable assignment wasn't reverted properly
-- However, it is clearly the case that `(i.toA.map f).x` is the same as `i.toA.x`
example (i : B α) (f : α → β) : (i.toA.map f).x := by
apply @test
end test3
Loading