Skip to content

Commit

Permalink
fix: Identification binding for polymorphic rigid head
Browse files Browse the repository at this point in the history
  • Loading branch information
PratherConid committed Oct 10, 2023
1 parent b885c1d commit a398140
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 157 deletions.
109 changes: 77 additions & 32 deletions Duper/DUnif/Bindings.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Duper.Util.Misc
import Duper.Util.OccursCheck
import Duper.Util.LazyList
import Duper.DUnif.UnifProblem
import Duper.DUnif.Utils
open Lean

-- Note:
Expand All @@ -14,8 +15,6 @@ open Lean

-- TODO: Use 'withLocalDeclD'

open Duper

namespace DUnif

def withoutModifyingMCtx (x : MetaM α) : MetaM α := do
Expand Down Expand Up @@ -79,7 +78,7 @@ def iteration (F : Expr) (p : UnifProblem) (eq : UnifEq) (funcArgOnly : Bool) :
return LazyList.interleaveN iterAtIArr
)

-- `F` is a metavariable
/-- `F` is a metavariable -/
def jpProjection (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (Array UnifProblem) := do
setMCtx p.mctx
let Fty ← Meta.inferType F
Expand All @@ -103,7 +102,7 @@ def jpProjection (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (Array UnifP
setMCtx s₀
return ret)

-- `F` is a metavariable
/-- `F` is a metavariable -/
def huetProjection (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (Array UnifProblem) := do
setMCtx p.mctx
let Fty ← Meta.inferType F
Expand Down Expand Up @@ -131,7 +130,7 @@ def huetProjection (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (Array Uni
ret := ret.append newProblem
return ret)

-- `F` is a metavariable
/-- `F` is a metavariable -/
def imitForall (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (Array UnifProblem) := do
setMCtx p.mctx
let Fty ← Meta.inferType F
Expand All @@ -146,7 +145,7 @@ def imitForall (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (Array UnifPro
MVarId.assign F.mvarId! mt
return #[{(← p.pushParentRuleIfDbgOn (.ImitForall eq F mt)) with checked := false, mctx := ← getMCtx}]

-- `F` is a metavariable
/-- `F` is a metavariable -/
def imitProj (F : Expr) (nLam : Nat) (iTy : Expr) (oTy : Expr) (name : Name) (idx : Nat) (p : UnifProblem) (eq : UnifEq) : MetaM (Array UnifProblem) := do
setMCtx p.mctx
let Fty ← Meta.inferType F
Expand All @@ -166,23 +165,67 @@ def imitProj (F : Expr) (nLam : Nat) (iTy : Expr) (oTy : Expr) (name : Name) (id
MVarId.assign F.mvarId! mt
return #[{(← p.pushParentRuleIfDbgOn (.ImitProj eq F idx mt)) with checked := false, mctx := ← getMCtx}]

-- `F` is a metavariable, and `g` is a constant
/--
`F` is a metavariable, and `g` is a constant
Suppose
· The unification equation is
`(fun bin_F => F t₁ t₂ ⋯ tₙ) = (fun bin_g => g s₁ s₂ ⋯ sₘ)`
· `F` is of type `∀ (x₁ : α₁) ⋯ (xₖ : αₖ), β`
· `g` is of type `∀ (y₁ : γ₁) ⋯ (yₗ : γₗ), δ`
Then the binding for `F` should be
`binding := fun (x₁ : α₁) ⋯ (xₖ : αₖ) => g (?H₁ ⋯) (?H₂ ⋯) ⋯ (?Hₕ ⋯)`
Since the unification equation is eta-expanded, we have
· `k ≤ n, l ≤ m`
If we plug the binding into the original unification equation and headbeta
the left-hand side, we will see that `h + n - k - len(bin_F) = m - len(bin_g)`, i.e.
· `h = m + k + len(bin_F) - n - len(bin_g)`
The above equation can be used to determine the value of `h`.
Now we specify the types of fresh metavariables and the resulting equations
· The type of `?Hᵢ (1 ≤ i ≤ min (l, h))` is taken care of by `forallMetaTelescopeReducing`
· If `h ≤ l`, the type of `binding` should be unified with the type of `F`. This
unification equation should be prioritized
· If `h > l`, the type of `fun (x₁ : α₁) ⋯ (xₖ : αₖ) => g (?H₁ ⋯) (?H₂ ⋯) ⋯ (?Hₗ ⋯)`
should be unified with `∀ (z₁ : ?η₁) ⋯ (zₙ : ?ηₕ₋ₗ), β`. This unification equation
should be prioritized. Moreover, the type of `?Hᵢ` should be obtained by calling
`forallMetaTelescope` on `∀ (z₁ : ?η₁) ⋯ (zₙ : ?ηₕ₋ₗ), β`
-/
def imitation (F : Expr) (g : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (Array UnifProblem) := do
setMCtx p.mctx
let Fty ← Meta.inferType F
let gty ← Meta.inferType g
let (_, si_F) ← structInfo p eq.lhs
let (bin_F, _) := si_F.getLambdaForall; let nargs_F := si_F.getNArgs
let (_, si_g) ← structInfo p eq.rhs
let (bin_g, _) := si_g.getLambdaForall; let nargs_g := si_g.getNArgs
Meta.forallTelescopeReducing Fty fun xs β => do
let (ys, _, β') ← Meta.forallMetaTelescopeReducing gty
let h := nargs_g + xs.size + bin_F - nargs_F - bin_g
let mut p := p
if β' != β then
-- We need to unify their types first
let β ← Meta.mkLambdaFVars xs β
let β' ← Meta.mkLambdaFVars xs β'
p := p.pushPrioritized (UnifEq.fromExprPair β β')
-- Apply the binding to `F`
let mt ← Meta.mkLambdaFVars xs (mkAppN g ys)
MVarId.assign F.mvarId! mt
return #[{(← p.pushParentRuleIfDbgOn (.Imitation eq F g mt)) with checked := false, mctx := ← getMCtx}]
if h ≤ ys.size then
-- Override `β'`
let β' ← Meta.instantiateForall gty (ys.toSubarray 0 h)
if β' != β then
-- We need to unify their types first
let β ← Meta.mkLambdaFVars xs β
let β' ← Meta.mkLambdaFVars xs β'
p := p.pushPrioritized (UnifEq.fromExprPair β β')
-- Apply the binding to `F`
let mt ← Meta.mkLambdaFVars xs (mkAppN g ys)
MVarId.assign F.mvarId! mt
return #[{(← p.pushParentRuleIfDbgOn (.Imitation eq F g mt)) with checked := false, mctx := ← getMCtx}]
else
let βAbst ← mkGeneralFnTy (h - ys.size) β
-- Put them in a block so as not to affect `βAbst` and `β` on the outside
if true then
let βAbst ← Meta.mkLambdaFVars xs βAbst
let β' ← Meta.mkLambdaFVars xs β'
p := p.pushPrioritized (UnifEq.fromExprPair βAbst β')
let (ysExtra, _, _) ← Meta.forallMetaBoundedTelescope βAbst (h - ys.size)
-- Apply the binding to `F`
let mt ← Meta.mkLambdaFVars xs (mkAppN g (ys ++ ysExtra))
MVarId.assign F.mvarId! mt
return #[{(← p.pushParentRuleIfDbgOn (.Imitation eq F g mt)) with checked := false, mctx := ← getMCtx}]

def elimination (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (LazyList <| MetaM (Array UnifProblem)) := do
setMCtx p.mctx
Expand Down Expand Up @@ -221,22 +264,24 @@ def elimination (F : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM (LazyList <|
return #[res]
return indsubseqs.map nats2binding

-- Both `F` and `G` are metavariables
-- Proposal
-- Premises
-- F : (x₁ : α₁) → (x₂ : α₂) → ⋯ → (xₙ : αₙ) → β x₁ x₂ ⋯ xₙ (F : ∀ [x], β [x])
-- G : (y₁ : γ₁) → (y₂ : γ₂) → ⋯ → (yₘ : γₘ) → δ y₁ y₂ ⋯ yₙ (G : ∀ [y], δ [y])
---------------------------------------------------------------
-- Binding
-- η : ∀ [x] [y], Type ?u
-- H : ∀ [x] [y], η [x] [y]
-- F ↦ λ [x]. H [x] (F₁ [x]) ⋯ (Fₘ [x])
-- G ↦ λ [y]. H (G₁ [y]) ⋯ (Gₙ [y]) [y]
-- Extra Unification Problems:
-- λ[x]. η [x] (F₁ [x]) ⋯ (Fₘ [x]) =? λ[x]. β [x]
-- λ[y]. η (G₁ [y]) ⋯ (Gₙ [y]) [y] =? λ[y]. δ [y]
-- Side condition: `F` cannot depend on `G`, and `G` cannot depend on `F`.
-- If any of `F` or `G` depends on another, switch to `elimination`
/--
Both `F` and `G` are metavariables
Proposal
Premises
F : (x₁ : α₁) → (x₂ : α₂) → ⋯ → (xₙ : αₙ) → β x₁ x₂ ⋯ xₙ (F : ∀ [x], β [x])
G : (y₁ : γ₁) → (y₂ : γ₂) → ⋯ → (yₘ : γₘ) → δ y₁ y₂ ⋯ yₙ (G : ∀ [y], δ [y])
-------------------------------------------------------------
Binding
η : ∀ [x] [y], Type ?u
H : ∀ [x] [y], η [x] [y]
F ↦ λ [x]. H [x] (F₁ [x]) ⋯ (Fₘ [x])
G ↦ λ [y]. H (G₁ [y]) ⋯ (Gₙ [y]) [y]
Extra Unification Problems:
λ[x]. η [x] (F₁ [x]) ⋯ (Fₘ [x]) =? λ[x]. β [x]
λ[y]. η (G₁ [y]) ⋯ (Gₙ [y]) [y] =? λ[y]. δ [y]
Side condition: `F` cannot depend on `G`, and `G` cannot depend on `F`.
If any of `F` or `G` depends on another, switch to `elimination`
-/
def identification (F : Expr) (G : Expr) (p : UnifProblem) (eq : UnifEq) : MetaM UnifRuleResult := do
setMCtx p.mctx
let Fty ← Meta.inferType F
Expand Down
3 changes: 2 additions & 1 deletion Duper/DUnif/DRefl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def MVarId.refl (mvarId : MVarId) (nAttempt : Nat) (nUnif : Nat) (cont : Nat) (i

syntax (name := drefl) "drefl" " attempt " num "unifier " num "contains" num ("iteron")? : tactic


@[tactic drefl] def evalRefl : Elab.Tactic.Tactic := fun stx =>
match stx with
| `(tactic| drefl attempt $nAttempt unifier $nunif contains $cont iteron) =>
Expand All @@ -39,3 +38,5 @@ syntax (name := drefl) "drefl" " attempt " num "unifier " num "contains" num ("i
let ids ← DUnif.MVarId.refl mvarId nAttempt.getNat nunif.getNat cont.getNat false
return ids.data
| _ => Elab.throwUnsupportedSyntax

end DUnif
11 changes: 8 additions & 3 deletions Duper/DUnif/Oracles.lean
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import Lean
import Duper.DUnif.Utils
import Duper.DUnif.UnifProblem
import Duper.Util.OccursCheck

open Lean
namespace Duper

namespace DUnif

initialize
registerTraceClass `DUnif.oracles

register_option oracleInstOn : Bool := {
defValue := true
Expand All @@ -21,11 +26,11 @@ def oracleInst (p : UnifProblem) (eq : UnifEq) : MetaM (Option UnifProblem) := d
if ¬ (← getOracleInstOn) then
return none
let mut eq := eq
if let .mvar id := eq.rhs.eta then
if let .some id ← metaEta eq.rhs then
eq := eq.swapSide
mvarId := id
else
if let .mvar id := eq.lhs.eta then
if let .some id ← metaEta eq.lhs.eta then
mvarId := id
else
return none
Expand Down
101 changes: 100 additions & 1 deletion Duper/DUnif/UnifProblem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Duper.Util.MessageData
import Duper.Util.LazyList
open Lean

namespace Duper
namespace DUnif

initialize Lean.registerTraceClass `DUnif.debug
initialize Lean.registerTraceClass `DUnif.result
Expand Down Expand Up @@ -296,6 +296,105 @@ def UnifProblem.instantiateTrackedExpr (p : UnifProblem) : MetaM UnifProblem :=
let trackedExpr ← p.trackedExpr.mapM instantiateMVars
return {p with trackedExpr := trackedExpr}

inductive StructType where
-- Things considered as `const`:
-- 1. constants
-- 2. free variables
-- 3. metavariables not of current depth
-- 4. literals
-- The first `Nat` is the number of `lambda`s
-- The second `Nat` is the number of `forall`s
-- The third `Nat` is the number of arguments
| Const : Nat → Nat → Nat → StructType
-- `proj _ · idx` is viewed as a function, with type
-- `innerTy → outerTy` (with variables abstracted).
-- Irreducible `proj`s are viewed as rigid
| Proj : Nat → Nat → Nat → (innerTy : Expr) → (outerTy : Expr) → (name : Name) → (idx : Nat) → StructType
| Bound : Nat → Nat → Nat → StructType
| MVar : Nat → Nat → Nat → StructType
-- Currently, `mdata`, `forall`, `let`
| Other : Nat → Nat → Nat → StructType
deriving Hashable, Inhabited, BEq, Repr

instance : ToString StructType where
toString (ht : StructType) : String :=
match ht with
| .Const l f a => s!"StructType.Const {l} {f} {a}"
| .Proj l f a iTy oTy _ idx => s!"StructType.Proj {l} {f} {a} iTy = {iTy} oTy = {oTy} idx = {idx}"
| .Bound l f a => s!"StructType.Bound {l} {f}"
| .MVar l f a => s!"StructType.MVar {l} {f} {a}"
| .Other l f a => s!"StructType.Other {l} {f} {a}"

def StructType.getLambdaForall : StructType → Nat × Nat
| Const a b _ => (a, b)
| Proj a b _ _ _ _ _ => (a, b)
| Bound a b _ => (a, b)
| MVar a b _ => (a, b)
| Other a b _ => (a, b)

def StructType.getNArgs : StructType → Nat
| Const _ _ a => a
| Proj _ _ a _ _ _ _ => a
| Bound _ _ a => a
| MVar _ _ a => a
| Other _ _ a => a

def StructType.isFlex : StructType → Bool
| Const _ _ _ => false
| Proj _ _ _ _ _ _ _ => false
| Bound _ _ _ => false
| MVar _ _ _ => true
| Other _ _ _ => false

def StructType.isRigid : StructType → Bool
| Const _ _ _ => true
| Proj _ _ _ _ _ _ _ => true
| Bound _ _ _ => true
| MVar _ _ _ => false
-- If headType is `other`, then we assume that the head is rigid
| Other _ _ _ => true

def projName! : Expr → Name
| .proj n _ _ => n
| _ => panic! "proj expression expected"

def structInfo (p : UnifProblem) (e : Expr) : MetaM (Expr × StructType) := do
setMCtx p.mctx
Meta.lambdaTelescope e fun xs t => Meta.forallTelescope t fun ys b => do
let h := Expr.getAppFn b
let args := Expr.getAppArgs b
if h.isFVar then
let mut bound := false
for x in xs ++ ys do
if x == h then
bound := true
if bound then
return (h, .Bound xs.size ys.size args.size)
else
return (h, .Const xs.size ys.size args.size)
else if h.isConst ∨ h.isSort ∨ h.isLit then
return (h, .Const xs.size ys.size args.size)
else if h.consumeMData.isMVar then
let decl := (← getMCtx).getDecl h.mvarId!
if decl.depth != (← getMCtx).depth then
return (h, .Const xs.size ys.size args.size)
else
return (h, .MVar xs.size ys.size args.size)
else if h.isProj ∧ ys.size == 0 then
let idx := h.projIdx!
let expr := h.projExpr!
let name := projName! h
let innerTy ← Meta.inferType expr
let outerTy ← Meta.inferType h
let innerTyAbst ← Meta.mkForallFVars xs innerTy
let outerTyAbst ← Meta.mkForallFVars xs outerTy
return (.lit (.strVal "You shouldn't see me. I'm in `structInfo`"),
.Proj xs.size ys.size args.size innerTyAbst outerTyAbst name idx)
else
-- If the type is `other`, then free variables might
-- occur inside the head, so we must abstract them
return (← Meta.mkLambdaFVars xs (← Meta.mkForallFVars ys h), .Other xs.size ys.size args.size)

-- MetaM : mvar assignments
-- LazyList UnifProblem : unification problems being generated
-- Bool : True -> Succeed, False -> Fail
Expand Down
Loading

0 comments on commit a398140

Please sign in to comment.