From 563d5e8bcf8dbfd0fe56e80bfe927b1bb6fc544b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 14 Jan 2025 15:45:46 -0800 Subject: [PATCH] feat: offset equalities in `grind` (#6645) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements support for offset equality constraints in the `grind` tactic and exhaustive equality propagation for them. The `grind` tactic can now solve problems such as the following: ```lean example (f : Nat → Nat) (a b c d e : Nat) : f (a + 3) = b → f (c + 1) = d → c ≤ a + 2 → a + 1 ≤ e → e < c → b = d := by grind ``` --- src/Init/Grind/Offset.lean | 8 +++ src/Lean/Meta/Tactic/Grind.lean | 3 ++ .../Meta/Tactic/Grind/Arith/Internalize.lean | 4 +- src/Lean/Meta/Tactic/Grind/Arith/Model.lean | 9 +++- src/Lean/Meta/Tactic/Grind/Arith/Offset.lean | 53 ++++++++++++++++--- .../Meta/Tactic/Grind/Arith/ProofUtil.lean | 2 +- src/Lean/Meta/Tactic/Grind/Arith/Util.lean | 25 ++++----- src/Lean/Meta/Tactic/Grind/Core.lean | 21 ++++++-- src/Lean/Meta/Tactic/Grind/Internalize.lean | 16 +++--- src/Lean/Meta/Tactic/Grind/Types.lean | 27 ++++++++-- tests/lean/run/grind_offset.lean | 3 +- tests/lean/run/grind_offset_cnstr.lean | 21 ++++++++ tests/lean/run/grind_pre.lean | 5 +- 13 files changed, 153 insertions(+), 44 deletions(-) diff --git a/src/Init/Grind/Offset.lean b/src/Init/Grind/Offset.lean index 1326275b04ac..49e0f7b560bb 100644 --- a/src/Init/Grind/Offset.lean +++ b/src/Init/Grind/Offset.lean @@ -80,4 +80,12 @@ theorem Nat.ro_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → theorem Nat.lo_eq_false_of_ro (u v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ v + k₁ → (v + k₂ ≤ u) = False := by simp [isLt]; omega +/-! +Helper theorems for equality propagation +-/ + +theorem Nat.le_of_eq_1 (u v : Nat) : u = v → u ≤ v := by omega +theorem Nat.le_of_eq_2 (u v : Nat) : u = v → v ≤ u := by omega +theorem Nat.eq_of_le_of_le (u v : Nat) : u ≤ v → v ≤ u → u = v := by omega + end Lean.Grind diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index fe32d443e1eb..7033f61b6af8 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -48,6 +48,9 @@ builtin_initialize registerTraceClass `grind.offset.dist builtin_initialize registerTraceClass `grind.offset.internalize builtin_initialize registerTraceClass `grind.offset.internalize.term (inherited := true) builtin_initialize registerTraceClass `grind.offset.propagate +builtin_initialize registerTraceClass `grind.offset.eq +builtin_initialize registerTraceClass `grind.offset.eq.to (inherited := true) +builtin_initialize registerTraceClass `grind.offset.eq.from (inherited := true) /-! Trace options for `grind` developers -/ builtin_initialize registerTraceClass `grind.debug diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean index a0bfc66d9de5..e9790a986eb4 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean @@ -8,7 +8,7 @@ import Lean.Meta.Tactic.Grind.Arith.Offset namespace Lean.Meta.Grind.Arith -def internalize (e : Expr) : GoalM Unit := do - Offset.internalizeCnstr e +def internalize (e : Expr) (parent : Expr) : GoalM Unit := do + Offset.internalize e parent end Lean.Meta.Grind.Arith diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Model.lean b/src/Lean/Meta/Tactic/Grind/Arith/Model.lean index cb8d797478b9..7a3f4e74d6d1 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Model.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Model.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Lean.Meta.Basic import Lean.Meta.Tactic.Grind.Types +import Lean.Meta.Tactic.Grind.Util namespace Lean.Meta.Grind.Arith.Offset /-- Construct a model that statisfies all offset constraints -/ @@ -33,7 +34,13 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Nat)) := do for u in [:nodes.size] do let some val := pre[u]! | unreachable! let val := (val - min).toNat - r := r.push (nodes[u]!, val) + let e := nodes[u]! + /- + We should not include the assignment for auxiliary offset terms since + they do not provide any additional information. + -/ + unless isNatOffset? e |>.isSome do + r := r.push (e, val) return r end Lean.Meta.Grind.Arith.Offset diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean b/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean index 31b07311e80b..cfac5021c179 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean @@ -48,6 +48,7 @@ def mkNode (expr : Expr) : GoalM NodeId := do targets := s.targets.push {} proofs := s.proofs.push {} } + markAsOffsetTerm expr return nodeId private def getExpr (u : NodeId) : GoalM Expr := do @@ -59,6 +60,11 @@ private def getDist? (u v : NodeId) : GoalM (Option Int) := do private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do return (← get').proofs[u]!.find? v +private def getNodeId (e : Expr) : GoalM NodeId := do + let some nodeId := (← get').nodeMap.find? { expr := e } + | throwError "internal `grind` error, term has not been internalized by offset module{indentExpr e}" + return nodeId + /-- Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the shortest path between `u` and `v`. @@ -160,10 +166,24 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → GoalM B return !(← f c e) modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' } +/-- Equality propagation. -/ +private def propagateEq (u v : NodeId) (k : Int) : GoalM Unit := do + if k != 0 then return () + let some k' ← getDist? v u | return () + if k' != 0 then return () + let ue ← getExpr u + let ve ← getExpr v + if (← isEqv ue ve) then return () + let huv ← mkProofForPath u v + let hvu ← mkProofForPath v u + trace[grind.offset.eq.from] "{ue}, {ve}" + pushEq ue ve <| mkApp4 (mkConst ``Grind.Nat.eq_of_le_of_le) ue ve huv hvu + /-- Performs constraint propagation. -/ private def propagateAll (u v : NodeId) (k : Int) : GoalM Unit := do updateCnstrsOf u v fun c e => return !(← propagateTrue u v k c e) updateCnstrsOf v u fun c e => return !(← propagateFalse u v k c e) + propagateEq u v k /-- If `isShorter u v k`, updates the shortest distance between `u` and `v`. @@ -203,8 +223,7 @@ where /- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/ updateIfShorter i j (k₁+k+k₂) v -def internalizeCnstr (e : Expr) : GoalM Unit := do - let some c := isNatOffsetCnstr? e | return () +private def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := do let u ← mkNode c.u let v ← mkNode c.v let c := { c with u, v } @@ -222,6 +241,29 @@ def internalizeCnstr (e : Expr) : GoalM Unit := do s.cnstrsOf.insert (u, v) cs } +def internalize (e : Expr) (parent : Expr) : GoalM Unit := do + if let some c := isNatOffsetCnstr? e then + internalizeCnstr e c + else if let some (b, k) := isNatOffset? e then + if (isNatOffsetCnstr? parent).isSome then return () + -- `e` is of the form `b + k` + let u ← mkNode e + let v ← mkNode b + -- `u = v + k`. So, we add edges for `u ≤ v + k` and `v + k ≤ u`. + let h := mkApp (mkConst ``Nat.le_refl) e + addEdge u v k h + addEdge v u (-k) h + +@[export lean_process_new_offset_eq] +def processNewOffsetEqImpl (a b : Expr) : GoalM Unit := do + unless isSameExpr a b do + trace[grind.offset.eq.to] "{a}, {b}" + let u ← getNodeId a + let v ← getNodeId b + let h ← mkEqProof a b + addEdge u v 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_1) a b h + addEdge v u 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_2) a b h + def traceDists : GoalM Unit := do let s ← get' for u in [:s.targets.size], es in s.targets.toArray do @@ -231,13 +273,12 @@ def traceDists : GoalM Unit := do def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do let u := (← get').nodes[c.u]! let v := (← get').nodes[c.v]! - let mk := if c.le then mkNatLE else mkNatEq if c.k == 0 then - return mk u v + return mkNatLE u v else if c.k < 0 then - return mk (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v + return mkNatLE (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v else - return mk u (mkNatAdd v (Lean.toExpr c.k.toNat)) + return mkNatLE u (mkNatAdd v (Lean.toExpr c.k.toNat)) def checkInvariants : GoalM Unit := do let s ← get' diff --git a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean b/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean index d45b266df2b8..c697a401d0c1 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean @@ -82,7 +82,7 @@ def mkOfNegEqFalse (nodes : PArray Expr) (c : Cnstr NodeId) (h : Expr) : Expr := let v := nodes[c.v]! if c.k == 0 then mkApp3 (mkConst ``Nat.of_le_eq_false) u v h - else if c.k == -1 && c.le then + else if c.k == -1 then mkApp3 (mkConst ``Nat.of_lo_eq_false_1) u v h else if c.k < 0 then mkApp4 (mkConst ``Nat.of_lo_eq_false) u v (toExprN (-c.k)) h diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Util.lean b/src/Lean/Meta/Tactic/Grind/Arith/Util.lean index f3da57f4c750..1c8804f0eec5 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Util.lean @@ -50,23 +50,19 @@ structure Offset.Cnstr (α : Type) where u : α v : α k : Int := 0 - le : Bool := true deriving Inhabited def Offset.Cnstr.neg : Cnstr α → Cnstr α - | { u, v, k, le } => { u := v, v := u, le, k := -k - 1 } + | { u, v, k } => { u := v, v := u, k := -k - 1 } example (c : Offset.Cnstr α) : c.neg.neg = c := by cases c; simp [Offset.Cnstr.neg]; omega def Offset.toMessageData [inst : ToMessageData α] (c : Offset.Cnstr α) : MessageData := - match c.k, c.le with - | .ofNat 0, true => m!"{c.u} ≤ {c.v}" - | .ofNat 0, false => m!"{c.u} = {c.v}" - | .ofNat k, true => m!"{c.u} ≤ {c.v} + {k}" - | .ofNat k, false => m!"{c.u} = {c.v} + {k}" - | .negSucc k, true => m!"{c.u} + {k + 1} ≤ {c.v}" - | .negSucc k, false => m!"{c.u} + {k + 1} = {c.v}" + match c.k with + | .ofNat 0 => m!"{c.u} ≤ {c.v}" + | .ofNat k => m!"{c.u} ≤ {c.v} + {k}" + | .negSucc k => m!"{c.u} + {k + 1} ≤ {c.v}" instance : ToMessageData (Offset.Cnstr Expr) where toMessageData c := Offset.toMessageData c @@ -74,16 +70,15 @@ instance : ToMessageData (Offset.Cnstr Expr) where /-- Returns `some cnstr` if `e` is offset constraint. -/ def isNatOffsetCnstr? (e : Expr) : Option (Offset.Cnstr Expr) := match_expr e with - | LE.le _ inst a b => if isInstLENat inst then go a b true else none - | Eq α a b => if isNatType α then go a b false else none + | LE.le _ inst a b => if isInstLENat inst then go a b else none | _ => none where - go (u v : Expr) (le : Bool) := + go (u v : Expr) := if let some (u, k) := isNatOffset? u then - some { u, k := - k, v, le } + some { u, k := - k, v } else if let some (v, k) := isNatOffset? v then - some { u, v, k := k, le } + some { u, v, k := k } else - some { u, v, le } + some { u, v } end Lean.Meta.Grind.Arith diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index 9afcfcc38456..77c91b817044 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -86,6 +86,18 @@ private partial def updateMT (root : Expr) : GoalM Unit := do setENode parent { node with mt := gmt } updateMT parent +/-- +Helper function for combining `ENode.offset?` fields and propagating an equality +to the offset constraint module. +-/ +private def propagateOffsetEq (root : Expr) (roofOffset? otherOffset? : Option Expr) : GoalM Unit := do + let some otherOffset := otherOffset? | return () + if let some rootOffset := roofOffset? then + processNewOffsetEq rootOffset otherOffset + else + let n ← getENode root + setENode root { n with offset? := otherOffset? } + private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do let lhsNode ← getENode lhs let rhsNode ← getENode rhs @@ -146,17 +158,18 @@ where next := rhsRoot.next } setENode rhsNode.root { rhsRoot with - next := lhsRoot.next - size := rhsRoot.size + lhsRoot.size + next := lhsRoot.next + size := rhsRoot.size + lhsRoot.size hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs } copyParentsTo parents rhsNode.root + unless (← isInconsistent) do + updateMT rhsRoot.self + propagateOffsetEq rhsNode.root rhsRoot.offset? lhsRoot.offset? unless (← isInconsistent) do for parent in parents do propagateUp parent - unless (← isInconsistent) do - updateMT rhsRoot.self updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do traverseEqc lhs fun n => diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index b29e3ad7574c..a867fe0eb36a 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -98,6 +98,8 @@ private def pushCastHEqs (e : Expr) : GoalM Unit := do | f@Eq.recOn α a motive b h v => pushHEq e v (mkApp6 (mkConst ``Grind.eqRecOn_heq f.constLevels!) α a motive b h v) | _ => return () +def noParent := mkBVar 0 + mutual /-- Internalizes the nested ground terms in the given pattern. -/ private partial def internalizePattern (pattern : Expr) (generation : Nat) : GoalM Expr := do @@ -146,7 +148,7 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) : trace_goal[grind.ematch] "reinsert `{thm.origin.key}`" modify fun s => { s with thmMap := s.thmMap.insert thm } -partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do +partial def internalize (e : Expr) (generation : Nat) (parent : Expr := noParent) : GoalM Unit := do if (← alreadyInternalized e) then return () trace_goal[grind.internalize] "{e}" match e with @@ -157,10 +159,10 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do | .forallE _ d b _ => mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) if (← isProp d <&&> isProp e) then - internalize d generation + internalize d generation e registerParent e d unless b.hasLooseBVars do - internalize b generation + internalize b generation e registerParent e b propagateUp e | .lit .. | .const .. => @@ -182,22 +184,22 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do -- We only internalize the proposition. We can skip the proof because of -- proof irrelevance let c := args[0]! - internalize c generation + internalize c generation e registerParent e c else if let .const fName _ := f then activateTheoremPatterns fName generation else - internalize f generation + internalize f generation e registerParent e f for h : i in [: args.size] do let arg := args[i] - internalize arg generation + internalize arg generation e registerParent e arg mkENode e generation addCongrTable e updateAppMap e - Arith.internalize e + Arith.internalize e parent propagateUp e end diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 123a92412779..5f536b967625 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -202,13 +202,19 @@ structure ENode where on heterogeneous equality. -/ heqProofs : Bool := false - /-- - Unique index used for pretty printing and debugging purposes. - -/ + /-- Unique index used for pretty printing and debugging purposes. -/ idx : Nat := 0 + /-- The generation in which this enode was created. -/ generation : Nat := 0 /-- Modification time -/ mt : Nat := 0 + /-- + The `offset?` field is used to propagate equalities from the `grind` congruence closure module + to the offset constraints module. When `grind` merges two equivalence classes, and both have + an associated `offset?` set to `some e`, the equality is propagated. This field is + assigned during the internalization of offset terms. + -/ + offset? : Option Expr := none deriving Inhabited, Repr def ENode.isCongrRoot (n : ENode) := @@ -643,6 +649,21 @@ def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do let interpreted ← isInterpreted e mkENodeCore e interpreted ctor generation +@[extern "lean_process_new_offset_eq"] -- forward definition +opaque processNewOffsetEq (a b : Expr) : GoalM Unit + +/-- +Marks `e` as a term of interest to the offset constraint module. +If the root of `e`s equivalence class has already a term of interest, +a new equality is propagated to the offset module. +-/ +def markAsOffsetTerm (e : Expr) : GoalM Unit := do + let n ← getRootENode e + if let some e' := n.offset? then + processNewOffsetEq e e' + else + setENode n.self { n with offset? := some e } + /-- Returns `true` is `e` is the root of its congruence class. -/ def isCongrRoot (e : Expr) : GoalM Bool := do return (← getENode e).isCongrRoot diff --git a/tests/lean/run/grind_offset.lean b/tests/lean/run/grind_offset.lean index 946f4759a867..a817abe1fbaa 100644 --- a/tests/lean/run/grind_offset.lean +++ b/tests/lean/run/grind_offset.lean @@ -83,8 +83,7 @@ info: [grind.assert] foo (c + 1) = a -/ #guard_msgs (info) in example : foo (c + 1) = a → c = b + 1 → a = g (foo b) := by - fail_if_success grind - sorry + grind set_option trace.grind.assert false diff --git a/tests/lean/run/grind_offset_cnstr.lean b/tests/lean/run/grind_offset_cnstr.lean index 0e4d960ca984..0fd24a2871c5 100644 --- a/tests/lean/run/grind_offset_cnstr.lean +++ b/tests/lean/run/grind_offset_cnstr.lean @@ -352,3 +352,24 @@ example (p r : Prop) (a b : Nat) : (c + 1 ≤ a ↔ p) → (c + 2 ≤ a + 1 ↔ set_option trace.grind.split true in example (p r : Prop) (a b : Nat) : (c + 5 ≤ a ↔ p) → (c + 4 ≤ a ↔ r) → a ≤ b → b ≤ c + 3 → ¬p ∧ ¬r := by grind (splits := 0) + +example (a b c d: Nat) : a ≤ b → b + 2 = c → c < d → a + 2 < d := by + grind + +example (a b c : Nat) : a + 2 = b → b + 3 = c → a + 5 ≤ c := by + grind + +example (a b c : Nat) : a + 2 = b → c ≤ a + 2 → a + 2 ≤ c → c = b := by + grind + +example (a b c : Nat) : a + 2 = b → b + 3 = c → a + 5 = c := by + grind + +example (f : Nat → Nat) (a b c d e : Nat) : + f (a + 3) = b → + f (c + 1) = d → + c ≤ a + 2 → + a + 1 ≤ e → + e < c → + b = d := by + grind diff --git a/tests/lean/run/grind_pre.lean b/tests/lean/run/grind_pre.lean index beac5df04b5c..e0abdee2e646 100644 --- a/tests/lean/run/grind_pre.lean +++ b/tests/lean/run/grind_pre.lean @@ -75,9 +75,8 @@ x✝ : ¬g (i + 1) j ⋯ = i + j + 1 [prop] ¬g (i + 1) j ⋯ = i + j + 1[eqc] True propositions [prop] j + 1 ≤ i[eqc] False propositions [prop] g (i + 1) j ⋯ = i + j + 1[offset] Assignment satisfying offset contraints - [assign] j := 0 - [assign] i := 1 - [assign] g (i + 1) j ⋯ := 0 + [assign] j := 1 + [assign] i := 2 [assign] i + j := 0 -/ #guard_msgs (error) in