Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add proper erasure of type dependencies in LCNF #6678

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/Lean/Compiler/LCNF/MonoTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,23 @@ The type contains only `→` and constants.
-/
partial def toMonoType (type : Expr) : CoreM Expr := do
let type := type.headBeta
if type.isErased then
return erasedExpr
else if isTypeFormerType type then
return erasedExpr
else match type with
| .const .. => visitApp type #[]
| .app .. => type.withApp visitApp
| .forallE _ d b _ => mkArrow (← toMonoType d) (← toMonoType (b.instantiate1 erasedExpr))
| _ => return erasedExpr
match type with
| .const .. => visitApp type #[]
| .app .. => type.withApp visitApp
| .forallE _ d b _ =>
let monoB ← toMonoType (b.instantiate1 anyExpr)
match monoB with
| .const ``lcErased _ => return erasedExpr
| _ => mkArrow (← toMonoType d) monoB
| .sort _ => return erasedExpr
| _ => return anyExpr
where
visitApp (f : Expr) (args : Array Expr) : CoreM Expr := do
match f with
| .const ``lcErased _ => return erasedExpr
| .const ``lcAny _ => return anyExpr
| .const ``Decidable _ => return mkConst ``Bool
| .const declName us =>
if declName == ``Decidable then
return mkConst ``Bool
if let some info ← hasTrivialStructure? declName then
let ctorType ← getOtherDeclBaseType info.ctorName []
toMonoType (getParamTypes (← instantiateForall ctorType args[:info.numParams]))[info.fieldIdx]!
Expand All @@ -96,15 +98,13 @@ where
for arg in args do
let .forallE _ d b _ := type.headBeta | unreachable!
let arg := arg.headBeta
if arg.isErased then
result := mkApp result arg
else if d.isErased || d matches .sort _ then
if d matches .const ``lcErased _ | .sort _ then
result := mkApp result (← toMonoType arg)
else
result := mkApp result erasedExpr
type := b.instantiate1 arg
return result
| _ => return erasedExpr
| _ => return anyExpr

/--
State for the environment extension used to save the LCNF mono phase type for declarations
Expand Down
5 changes: 4 additions & 1 deletion src/Lean/Compiler/LCNF/PrettyPrinter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def ppLetDecl (letDecl : LetDecl) : M Format := do
return f!"let {letDecl.binderName} := {← ppLetValue letDecl.value}"

def getFunType (ps : Array Param) (type : Expr) : CoreM Expr :=
instantiateForall type (ps.map (mkFVar ·.fvarId))
if type.isErased then
pure type
else
instantiateForall type (ps.map (mkFVar ·.fvarId))

mutual
partial def ppFunDecl (funDecl : FunDecl) : M Format := do
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Compiler/LCNF/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ scoped notation:max "◾" => lcErased
namespace LCNF

def erasedExpr := mkConst ``lcErased
def anyExpr := mkConst ``lcAny

def _root_.Lean.Expr.isErased (e : Expr) :=
e.isAppOf ``lcErased
Expand Down
18 changes: 8 additions & 10 deletions tests/lean/lcnfTypes.lean.expected.out
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,14 @@ weird1 : Bool → ◾
lamAny₁ : Bool → Monad ◾
lamAny₂ : Bool → Monad ◾
Term.constFold : List Ty → Ty → _root_.Term lcErased lcErased → _root_.Term lcErased lcErased
Term.denote : List Ty → Ty → _root_.Term lcErased lcErased → HList Ty lcErased lcErased → lcErased
HList.get : lcErased →
lcErased → List lcErased → lcErased → HList lcErased lcErased lcErased → Member lcErased lcErased lcErased → lcErased
Member.head : lcErased → lcErased → List lcErased → Member lcErased lcErased lcErased
Term.denote : lcErased
HList.get : lcErased → lcErased → List lcAny → lcAny → HList lcAny lcErased lcErased → Member lcAny lcErased lcErased → lcAny
Member.head : lcErased → lcAny → List lcAny → Member lcAny lcErased lcErased
Ty.denote : lcErased
MonadControl.liftWith : lcErased →
lcErased → MonadControl lcErased lcErased → lcErased → ((lcErased → lcErased → lcErased) → lcErased) → lcErased
MonadControl.restoreM : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → lcErased → lcErased
Decidable.casesOn : lcErased → lcErased → Bool → (lcErased → lcErased) → (lcErased → lcErased) → lcErased
Lean.getConstInfo : lcErased → Monad lcErased → MonadEnv lcErased → MonadError lcErased → Name → lcErased
MonadControl.liftWith : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → ((lcErased → lcAny → lcAny) → lcAny) → lcAny
MonadControl.restoreM : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → lcAny → lcAny
Decidable.casesOn : lcErased → lcErased → Bool → (lcErased → lcAny) → (lcErased → lcAny) → lcAny
Lean.getConstInfo : lcErased → Monad lcErased → MonadEnv lcErased → MonadError lcErased → Name → lcAny
Lean.Meta.instMonadMetaM : Monad lcErased
Lean.Meta.inferType : Expr → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr
Lean.Elab.Term.elabTerm : Syntax →
Expand All @@ -54,4 +52,4 @@ Lean.Elab.Term.elabTerm : Syntax →
lcErased → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr
Nat.add : Nat → Nat → Nat
Fin.add : Nat → Nat → Nat → Nat
Lean.HashSetBucket.update : lcErased → Array (List lcErased) → USize → List lcErased → lcErased → Array (List lcErased)
Lean.HashSetBucket.update : lcErased → Array (List lcAny) → USize → List lcAny → lcErased → Array (List lcAny)
2 changes: 1 addition & 1 deletion tests/lean/run/erased.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ info: [Compiler.result] size: 1
let _x.1 : PSigma lcErased lcErased := PSigma.mk lcErased ◾ ◾ ◾;
return _x.1
[Compiler.result] size: 1
def Erased.mk (α : lcErased) (a : lcErased) : PSigma lcErased lcErased :=
def Erased.mk (α : lcErased) (a : lcAny) : PSigma lcErased lcErased :=
let _x.1 : PSigma lcErased lcErased := PSigma.mk lcErased ◾ ◾ ◾;
return _x.1
-/
Expand Down
264 changes: 264 additions & 0 deletions tests/lean/run/lcnfErasure.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
import Lean
import Lean.Compiler.LCNF.MonoTypes
import Lean.Compiler.LCNF.Types

open Lean Meta
open Compiler.LCNF (toLCNFType toMonoType)

def toMonoLCNFType (type : Expr) : MetaM Expr := do
toMonoType (← toLCNFType type)

def checkMonoType! (type₁ type₂ : Expr) : MetaM Unit := do
let monoType ← toMonoLCNFType type₁
if monoType != type₂ then
throwError f!"mono type for {type₁} is {monoType}, expected {type₂}"
let monoMonoType ← toMonoType monoType
if monoMonoType != monoType then
throwError f!"toMonoType is not idempotent: toMonoType of {monoType} is {monoMonoType}"

-- Nat

#eval checkMonoType!
(.const ``Nat [])
(.const ``Nat [])

-- Decidable

#eval checkMonoType!
(.const ``Decidable [])
(.const ``Bool [])

-- Prop

#eval checkMonoType!
(.sort .zero)
(.const ``lcErased [])

-- Type

#eval checkMonoType!
(.sort (.succ .zero))
(.const ``lcErased [])

-- Sort u

#eval checkMonoType!
(.sort (.param `u))
(.const ``lcErased [])

-- List Nat

#eval checkMonoType!
(.app (.const ``List [.succ .zero]) (.const ``Nat []))
(.app (.const ``List []) (.const ``Nat []))

-- List Type

#eval checkMonoType!
(.app (.const ``List [.succ (.succ .zero)]) (.sort (.succ .zero)))
(.app (.const ``List []) (.const ``lcErased []))

-- Inductive type with trivial structure

inductive TrivialInductive : Type where
| constructor (a : Nat) : TrivialInductive

#eval checkMonoType!
(.const ``TrivialInductive [])
(.const ``Nat [])

-- Inductive type with trivial structure and irrelevant fields

inductive TrivialInductivePropFields : Type where
| constructor (p₁ : Prop) (a : Nat) (p₂ : Prop) : TrivialInductivePropFields

#eval checkMonoType!
(.const ``TrivialInductivePropFields [])
(.const ``Nat [])

-- Structure type with trivial structure

structure TrivialStructure : Type where
a : Nat

#eval checkMonoType!
(.const ``TrivialStructure [])
(.const ``Nat [])

-- Structure type with trivial structure and irrelevant fields

structure TrivialStructurePropFields : Type where
p₁ : Prop
a : Nat
p₂ : Prop

#eval checkMonoType!
(.const ``TrivialStructurePropFields [])
(.const ``Nat [])

-- Nat → Nat

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.const ``Nat []) .default)
(.forallE `a (.const ``Nat []) (.const ``Nat []) .default)

-- Nat → List Nat

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.app (.const ``List [.succ .zero]) (.const ``Nat [])) .default)
(.forallE `a (.const ``Nat []) (.app (.const ``List []) (.const ``Nat [])) .default)

-- Nat → Prop

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.sort .zero) .default)
(.const ``lcErased [])

-- Nat → Type

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.sort (.succ .zero)) .default)
(.const ``lcErased [])

-- Nat → Bool → Type

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `a (.const ``Bool []) (.sort (.succ .zero)) .default)
.default)
(.const ``lcErased [])

-- (α : Type) → List α

#eval checkMonoType!
(.forallE `α (.sort (.succ .zero)) (.app (.const ``List [.succ .zero]) (.bvar 0)) .default)
(.forallE `α (.const ``lcErased []) (.app (.const ``List []) (.const ``lcAny [])) .default)

-- List Nat → List Bool

#eval checkMonoType!
(.forallE `a
(.app (.const ``List [.succ .zero]) (.const ``Nat []))
(.app (.const ``List [.succ .zero]) (.const ``Bool []))
.default)
(.forallE `a
(.app (.const ``List []) (.const ``Nat []))
(.app (.const ``List []) (.const ``Bool []))
.default)

-- List Nat → List Prop

#eval checkMonoType!
(.forallE `a
(.app (.const ``List [.succ .zero]) (.const ``Nat []))
(.app (.const ``List [.succ .zero]) (.sort .zero))
.default)
(.forallE `a
(.app (.const ``List []) (.const ``Nat []))
(.app (.const ``List []) (.const ``lcErased []))
.default)

-- (α : Type) → α → α

#eval checkMonoType!
(.forallE `α
(.sort (.succ .zero))
(.forallE `a (.bvar 0) (.bvar 1) .default)
.default)
(.forallE `α
(.const ``lcErased [])
(.forallE `a (.const ``lcAny []) (.const ``lcAny []) .default)
.default)

-- Nat → (α : Type) → α → Bool

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `α
(.sort (.succ .zero))
(.forallE `a (.bvar 0) (.const ``Bool []) .default)
.default)
.default)
(.forallE `a
(.const ``Nat [])
(.forallE `α
(.const ``lcErased [])
(.forallE `a (.const ``lcAny []) (.const ``Bool []) .default)
.default)
.default)

-- Nat → Bool → Type

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `b (.const ``Bool []) (.sort (.succ .zero)) .default)
.default)
(.const ``lcErased [])

-- Nat → Bool → (Nat → Type)

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `b (.const ``Bool []) (.sort (.succ .zero)) .default)
.default)
(.const ``lcErased [])

-- Nat → (Nat → Type) → Bool

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `b
(.forallE `c (.const ``lcErased []) (.sort (.succ .zero)) .default)
(.const ``Bool [])
.default)
.default)
(.forallE `a
(.const ``Nat [])
(.forallE `b
(.const ``lcErased [])
(.const ``Bool [])
.default)
.default)

-- (α : Sort u) → (β : α → Sort v) → (a : α) → ((x : α) → β x) → β a

#eval checkMonoType!
(.forallE
(.sort (.param `u))
(.forallE
(.forallE `f1 (.bvar 0) (.sort (.param `v)) .default)
(.forallE
`a
(.bvar 1)
(.forallE
`f2
(.forallE `x (.bvar 2) (.app (.bvar 2) (.bvar 0)) .default)
(.app (.bvar 2) (.bvar 1))
.default)
.default)
.default)
.default)
(.forallE
(.const ``lcErased [])
(.forallE
(.const ``lcErased [])
(.forallE
`a
(.const ``lcAny [])
(.forallE
`f2
(.forallE `x (.const ``lcAny []) (.const ``lcAny []) .default)
(.const ``lcAny [])
.default)
.default)
.default)
.default)
Loading