Skip to content

Commit

Permalink
feat: offset equalities in grind (#6645)
Browse files Browse the repository at this point in the history
This PR implements support for offset equality constraints in the
`grind` tactic and exhaustive equality propagation for them. The `grind`
tactic can now solve problems such as the following:

```lean
example (f : Nat → Nat) (a b c d e : Nat) :
        f (a + 3) = b →
        f (c + 1) = d →
        c ≤ a + 2 →
        a + 1 ≤ e →
        e < c →
        b = d := by
  grind
```
  • Loading branch information
leodemoura authored Jan 14, 2025
1 parent 3da7f70 commit 563d5e8
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 44 deletions.
8 changes: 8 additions & 0 deletions src/Init/Grind/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,12 @@ theorem Nat.ro_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true →
theorem Nat.lo_eq_false_of_ro (u v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ v + k₁ → (v + k₂ ≤ u) = False := by
simp [isLt]; omega

/-!
Helper theorems for equality propagation
-/

theorem Nat.le_of_eq_1 (u v : Nat) : u = v → u ≤ v := by omega
theorem Nat.le_of_eq_2 (u v : Nat) : u = v → v ≤ u := by omega
theorem Nat.eq_of_le_of_le (u v : Nat) : u ≤ v → v ≤ u → u = v := by omega

end Lean.Grind
3 changes: 3 additions & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ builtin_initialize registerTraceClass `grind.offset.dist
builtin_initialize registerTraceClass `grind.offset.internalize
builtin_initialize registerTraceClass `grind.offset.internalize.term (inherited := true)
builtin_initialize registerTraceClass `grind.offset.propagate
builtin_initialize registerTraceClass `grind.offset.eq
builtin_initialize registerTraceClass `grind.offset.eq.to (inherited := true)
builtin_initialize registerTraceClass `grind.offset.eq.from (inherited := true)

/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Lean.Meta.Tactic.Grind.Arith.Offset

namespace Lean.Meta.Grind.Arith

def internalize (e : Expr) : GoalM Unit := do
Offset.internalizeCnstr e
def internalize (e : Expr) (parent : Expr) : GoalM Unit := do
Offset.internalize e parent

end Lean.Meta.Grind.Arith
9 changes: 8 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Arith/Model.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Basic
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util

namespace Lean.Meta.Grind.Arith.Offset
/-- Construct a model that statisfies all offset constraints -/
Expand Down Expand Up @@ -33,7 +34,13 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Nat)) := do
for u in [:nodes.size] do
let some val := pre[u]! | unreachable!
let val := (val - min).toNat
r := r.push (nodes[u]!, val)
let e := nodes[u]!
/-
We should not include the assignment for auxiliary offset terms since
they do not provide any additional information.
-/
unless isNatOffset? e |>.isSome do
r := r.push (e, val)
return r

end Lean.Meta.Grind.Arith.Offset
53 changes: 47 additions & 6 deletions src/Lean/Meta/Tactic/Grind/Arith/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def mkNode (expr : Expr) : GoalM NodeId := do
targets := s.targets.push {}
proofs := s.proofs.push {}
}
markAsOffsetTerm expr
return nodeId

private def getExpr (u : NodeId) : GoalM Expr := do
Expand All @@ -59,6 +60,11 @@ private def getDist? (u v : NodeId) : GoalM (Option Int) := do
private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do
return (← get').proofs[u]!.find? v

private def getNodeId (e : Expr) : GoalM NodeId := do
let some nodeId := (← get').nodeMap.find? { expr := e }
| throwError "internal `grind` error, term has not been internalized by offset module{indentExpr e}"
return nodeId

/--
Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the
shortest path between `u` and `v`.
Expand Down Expand Up @@ -160,10 +166,24 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → GoalM B
return !(← f c e)
modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }

/-- Equality propagation. -/
private def propagateEq (u v : NodeId) (k : Int) : GoalM Unit := do
if k != 0 then return ()
let some k' ← getDist? v u | return ()
if k' != 0 then return ()
let ue ← getExpr u
let ve ← getExpr v
if (← isEqv ue ve) then return ()
let huv ← mkProofForPath u v
let hvu ← mkProofForPath v u
trace[grind.offset.eq.from] "{ue}, {ve}"
pushEq ue ve <| mkApp4 (mkConst ``Grind.Nat.eq_of_le_of_le) ue ve huv hvu

/-- Performs constraint propagation. -/
private def propagateAll (u v : NodeId) (k : Int) : GoalM Unit := do
updateCnstrsOf u v fun c e => return !(← propagateTrue u v k c e)
updateCnstrsOf v u fun c e => return !(← propagateFalse u v k c e)
propagateEq u v k

/--
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
Expand Down Expand Up @@ -203,8 +223,7 @@ where
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
updateIfShorter i j (k₁+k+k₂) v

def internalizeCnstr (e : Expr) : GoalM Unit := do
let some c := isNatOffsetCnstr? e | return ()
private def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := do
let u ← mkNode c.u
let v ← mkNode c.v
let c := { c with u, v }
Expand All @@ -222,6 +241,29 @@ def internalizeCnstr (e : Expr) : GoalM Unit := do
s.cnstrsOf.insert (u, v) cs
}

def internalize (e : Expr) (parent : Expr) : GoalM Unit := do
if let some c := isNatOffsetCnstr? e then
internalizeCnstr e c
else if let some (b, k) := isNatOffset? e then
if (isNatOffsetCnstr? parent).isSome then return ()
-- `e` is of the form `b + k`
let u ← mkNode e
let v ← mkNode b
-- `u = v + k`. So, we add edges for `u ≤ v + k` and `v + k ≤ u`.
let h := mkApp (mkConst ``Nat.le_refl) e
addEdge u v k h
addEdge v u (-k) h

@[export lean_process_new_offset_eq]
def processNewOffsetEqImpl (a b : Expr) : GoalM Unit := do
unless isSameExpr a b do
trace[grind.offset.eq.to] "{a}, {b}"
let u ← getNodeId a
let v ← getNodeId b
let h ← mkEqProof a b
addEdge u v 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_1) a b h
addEdge v u 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_2) a b h

def traceDists : GoalM Unit := do
let s ← get'
for u in [:s.targets.size], es in s.targets.toArray do
Expand All @@ -231,13 +273,12 @@ def traceDists : GoalM Unit := do
def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let u := (← get').nodes[c.u]!
let v := (← get').nodes[c.v]!
let mk := if c.le then mkNatLE else mkNatEq
if c.k == 0 then
return mk u v
return mkNatLE u v
else if c.k < 0 then
return mk (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
return mkNatLE (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
else
return mk u (mkNatAdd v (Lean.toExpr c.k.toNat))
return mkNatLE u (mkNatAdd v (Lean.toExpr c.k.toNat))

def checkInvariants : GoalM Unit := do
let s ← get'
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def mkOfNegEqFalse (nodes : PArray Expr) (c : Cnstr NodeId) (h : Expr) : Expr :=
let v := nodes[c.v]!
if c.k == 0 then
mkApp3 (mkConst ``Nat.of_le_eq_false) u v h
else if c.k == -1 && c.le then
else if c.k == -1 then
mkApp3 (mkConst ``Nat.of_lo_eq_false_1) u v h
else if c.k < 0 then
mkApp4 (mkConst ``Nat.of_lo_eq_false) u v (toExprN (-c.k)) h
Expand Down
25 changes: 10 additions & 15 deletions src/Lean/Meta/Tactic/Grind/Arith/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,40 +50,35 @@ structure Offset.Cnstr (α : Type) where
u : α
v : α
k : Int := 0
le : Bool := true
deriving Inhabited

def Offset.Cnstr.neg : Cnstr α → Cnstr α
| { u, v, k, le } => { u := v, v := u, le, k := -k - 1 }
| { u, v, k } => { u := v, v := u, k := -k - 1 }

example (c : Offset.Cnstr α) : c.neg.neg = c := by
cases c; simp [Offset.Cnstr.neg]; omega

def Offset.toMessageData [inst : ToMessageData α] (c : Offset.Cnstr α) : MessageData :=
match c.k, c.le with
| .ofNat 0, true => m!"{c.u} ≤ {c.v}"
| .ofNat 0, false => m!"{c.u} = {c.v}"
| .ofNat k, true => m!"{c.u} ≤ {c.v} + {k}"
| .ofNat k, false => m!"{c.u} = {c.v} + {k}"
| .negSucc k, true => m!"{c.u} + {k + 1} ≤ {c.v}"
| .negSucc k, false => m!"{c.u} + {k + 1} = {c.v}"
match c.k with
| .ofNat 0 => m!"{c.u} ≤ {c.v}"
| .ofNat k => m!"{c.u} ≤ {c.v} + {k}"
| .negSucc k => m!"{c.u} + {k + 1} ≤ {c.v}"

instance : ToMessageData (Offset.Cnstr Expr) where
toMessageData c := Offset.toMessageData c

/-- Returns `some cnstr` if `e` is offset constraint. -/
def isNatOffsetCnstr? (e : Expr) : Option (Offset.Cnstr Expr) :=
match_expr e with
| LE.le _ inst a b => if isInstLENat inst then go a b true else none
| Eq α a b => if isNatType α then go a b false else none
| LE.le _ inst a b => if isInstLENat inst then go a b else none
| _ => none
where
go (u v : Expr) (le : Bool) :=
go (u v : Expr) :=
if let some (u, k) := isNatOffset? u then
some { u, k := - k, v, le }
some { u, k := - k, v }
else if let some (v, k) := isNatOffset? v then
some { u, v, k := k, le }
some { u, v, k := k }
else
some { u, v, le }
some { u, v }

end Lean.Meta.Grind.Arith
21 changes: 17 additions & 4 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ private partial def updateMT (root : Expr) : GoalM Unit := do
setENode parent { node with mt := gmt }
updateMT parent

/--
Helper function for combining `ENode.offset?` fields and propagating an equality
to the offset constraint module.
-/
private def propagateOffsetEq (root : Expr) (roofOffset? otherOffset? : Option Expr) : GoalM Unit := do
let some otherOffset := otherOffset? | return ()
if let some rootOffset := roofOffset? then
processNewOffsetEq rootOffset otherOffset
else
let n ← getENode root
setENode root { n with offset? := otherOffset? }

private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
let lhsNode ← getENode lhs
let rhsNode ← getENode rhs
Expand Down Expand Up @@ -146,17 +158,18 @@ where
next := rhsRoot.next
}
setENode rhsNode.root { rhsRoot with
next := lhsRoot.next
size := rhsRoot.size + lhsRoot.size
next := lhsRoot.next
size := rhsRoot.size + lhsRoot.size
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
}
copyParentsTo parents rhsNode.root
unless (← isInconsistent) do
updateMT rhsRoot.self
propagateOffsetEq rhsNode.root rhsRoot.offset? lhsRoot.offset?
unless (← isInconsistent) do
for parent in parents do
propagateUp parent
unless (← isInconsistent) do
updateMT rhsRoot.self

updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
traverseEqc lhs fun n =>
Expand Down
16 changes: 9 additions & 7 deletions src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ private def pushCastHEqs (e : Expr) : GoalM Unit := do
| [email protected] α a motive b h v => pushHEq e v (mkApp6 (mkConst ``Grind.eqRecOn_heq f.constLevels!) α a motive b h v)
| _ => return ()

def noParent := mkBVar 0

mutual
/-- Internalizes the nested ground terms in the given pattern. -/
private partial def internalizePattern (pattern : Expr) (generation : Nat) : GoalM Expr := do
Expand Down Expand Up @@ -146,7 +148,7 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) :
trace_goal[grind.ematch] "reinsert `{thm.origin.key}`"
modify fun s => { s with thmMap := s.thmMap.insert thm }

partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
partial def internalize (e : Expr) (generation : Nat) (parent : Expr := noParent) : GoalM Unit := do
if (← alreadyInternalized e) then return ()
trace_goal[grind.internalize] "{e}"
match e with
Expand All @@ -157,10 +159,10 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
| .forallE _ d b _ =>
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
if (← isProp d <&&> isProp e) then
internalize d generation
internalize d generation e
registerParent e d
unless b.hasLooseBVars do
internalize b generation
internalize b generation e
registerParent e b
propagateUp e
| .lit .. | .const .. =>
Expand All @@ -182,22 +184,22 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
-- We only internalize the proposition. We can skip the proof because of
-- proof irrelevance
let c := args[0]!
internalize c generation
internalize c generation e
registerParent e c
else
if let .const fName _ := f then
activateTheoremPatterns fName generation
else
internalize f generation
internalize f generation e
registerParent e f
for h : i in [: args.size] do
let arg := args[i]
internalize arg generation
internalize arg generation e
registerParent e arg
mkENode e generation
addCongrTable e
updateAppMap e
Arith.internalize e
Arith.internalize e parent
propagateUp e
end

Expand Down
27 changes: 24 additions & 3 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,19 @@ structure ENode where
on heterogeneous equality.
-/
heqProofs : Bool := false
/--
Unique index used for pretty printing and debugging purposes.
-/
/-- Unique index used for pretty printing and debugging purposes. -/
idx : Nat := 0
/-- The generation in which this enode was created. -/
generation : Nat := 0
/-- Modification time -/
mt : Nat := 0
/--
The `offset?` field is used to propagate equalities from the `grind` congruence closure module
to the offset constraints module. When `grind` merges two equivalence classes, and both have
an associated `offset?` set to `some e`, the equality is propagated. This field is
assigned during the internalization of offset terms.
-/
offset? : Option Expr := none
deriving Inhabited, Repr

def ENode.isCongrRoot (n : ENode) :=
Expand Down Expand Up @@ -643,6 +649,21 @@ def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
let interpreted ← isInterpreted e
mkENodeCore e interpreted ctor generation

@[extern "lean_process_new_offset_eq"] -- forward definition
opaque processNewOffsetEq (a b : Expr) : GoalM Unit

/--
Marks `e` as a term of interest to the offset constraint module.
If the root of `e`s equivalence class has already a term of interest,
a new equality is propagated to the offset module.
-/
def markAsOffsetTerm (e : Expr) : GoalM Unit := do
let n ← getRootENode e
if let some e' := n.offset? then
processNewOffsetEq e e'
else
setENode n.self { n with offset? := some e }

/-- Returns `true` is `e` is the root of its congruence class. -/
def isCongrRoot (e : Expr) : GoalM Bool := do
return (← getENode e).isCongrRoot
Expand Down
3 changes: 1 addition & 2 deletions tests/lean/run/grind_offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ info: [grind.assert] foo (c + 1) = a
-/
#guard_msgs (info) in
example : foo (c + 1) = a → c = b + 1 → a = g (foo b) := by
fail_if_success grind
sorry
grind

set_option trace.grind.assert false

Expand Down
21 changes: 21 additions & 0 deletions tests/lean/run/grind_offset_cnstr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,24 @@ example (p r : Prop) (a b : Nat) : (c + 1 ≤ a ↔ p) → (c + 2 ≤ a + 1 ↔
set_option trace.grind.split true in
example (p r : Prop) (a b : Nat) : (c + 5 ≤ a ↔ p) → (c + 4 ≤ a ↔ r) → a ≤ b → b ≤ c + 3 → ¬p ∧ ¬r := by
grind (splits := 0)

example (a b c d: Nat) : a ≤ b → b + 2 = c → c < d → a + 2 < d := by
grind

example (a b c : Nat) : a + 2 = b → b + 3 = c → a + 5 ≤ c := by
grind

example (a b c : Nat) : a + 2 = b → c ≤ a + 2 → a + 2 ≤ c → c = b := by
grind

example (a b c : Nat) : a + 2 = b → b + 3 = c → a + 5 = c := by
grind

example (f : Nat → Nat) (a b c d e : Nat) :
f (a + 3) = b →
f (c + 1) = d →
c ≤ a + 2
a + 1 ≤ e →
e < c →
b = d := by
grind
Loading

0 comments on commit 563d5e8

Please sign in to comment.