Skip to content

Commit

Permalink
Record local variables' kinds during lambda lifting
Browse files Browse the repository at this point in the history
Previously, `singletons-th` made no effort to track the kinds of local
variables when generating lambda-lifted code, instead generating local variable
binders with no kind annotations. As a result, GHC would generalize the kinds
of these lambda-lifted definitions to things that are way more polymorphic than
intended. While this technically works in today's GHC, it won't in a future
version of GHC that implements
[GHC#23515](https://gitlab.haskell.org/ghc/ghc/-/issues/23515).

In general, generating kinds for every local variable would require
`singletons-th` to implement something akin to full-blown type inference over
the Template Haskell AST, which is not something I am eager to implement.

Fortunately, there is a relatively simple approach we can do to alleviate this
problem that doesn't require full type inference. In situations where we know
the kind of a local variable (e.g., when there is a top-level signature or
there is a pattern signature), we record the variable's kind and use it when
generating binders for any lambda-lifted definitions that close over the
variable. For the full story on how this works, see `Note [Local variables and
kind information]` `D.S.TH.Promote.Syntax.LocalVar`.

This is not a perfect solution, as there will still be examples of the original
problem that won't be covered by this simple approach (see the Note). This
approach is still much better than what `singletons-th` was doing before, and I
think it's worth using this simple approach even if it doesn't fix 100% of all
cases.

This patch mostly resolves the "Overly polymorphic lambda-lifting, part 2"
section of #601.
  • Loading branch information
RyanGlScott committed Jun 25, 2024
1 parent a351105 commit 71d1877
Show file tree
Hide file tree
Showing 36 changed files with 607 additions and 333 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,11 @@ GradingClient/Database.hs:(0,0)-(0,0): Splicing declarations
type SchSym1 :: [Attribute] -> Schema
type family SchSym1 (a0123456789876543210 :: [Attribute]) :: Schema where
SchSym1 a0123456789876543210 = Sch a0123456789876543210
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 name0123456789876543210 name'0123456789876543210 u0123456789876543210 attrs0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 (name0123456789876543210 :: [AChar]) name'0123456789876543210 u0123456789876543210 attrs0123456789876543210 where
Let0123456789876543210Scrutinee_0123456789876543210Sym0 name0123456789876543210 name'0123456789876543210 u0123456789876543210 attrs0123456789876543210 = Let0123456789876543210Scrutinee_0123456789876543210 name0123456789876543210 name'0123456789876543210 u0123456789876543210 attrs0123456789876543210
type family Let0123456789876543210Scrutinee_0123456789876543210 name0123456789876543210 name'0123456789876543210 u0123456789876543210 attrs0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210 (name0123456789876543210 :: [AChar]) name'0123456789876543210 u0123456789876543210 attrs0123456789876543210 where
Let0123456789876543210Scrutinee_0123456789876543210 name name' u attrs = Apply (Apply (==@#@$) name) name'
type family Case_0123456789876543210 name0123456789876543210 name'0123456789876543210 u0123456789876543210 attrs0123456789876543210 t where
type family Case_0123456789876543210 (name0123456789876543210 :: [AChar]) name'0123456789876543210 u0123456789876543210 attrs0123456789876543210 t where
Case_0123456789876543210 name name' u attrs 'True = u
Case_0123456789876543210 name name' u attrs 'False = Apply (Apply LookupSym0 name) (Apply SchSym0 attrs)
type LookupSym0 :: (~>) [AChar] ((~>) Schema U)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ InsertionSort/InsertionSortImp.hs:(0,0)-(0,0): Splicing declarations
insertionSort :: [Nat] -> [Nat]
insertionSort [] = []
insertionSort (h : t) = insert h (insertionSort t)
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 n0123456789876543210 h0123456789876543210 t0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 (n0123456789876543210 :: Nat) h0123456789876543210 t0123456789876543210 where
Let0123456789876543210Scrutinee_0123456789876543210Sym0 n0123456789876543210 h0123456789876543210 t0123456789876543210 = Let0123456789876543210Scrutinee_0123456789876543210 n0123456789876543210 h0123456789876543210 t0123456789876543210
type family Let0123456789876543210Scrutinee_0123456789876543210 n0123456789876543210 h0123456789876543210 t0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210 (n0123456789876543210 :: Nat) h0123456789876543210 t0123456789876543210 where
Let0123456789876543210Scrutinee_0123456789876543210 n h t = Apply (Apply LeqSym0 n) h
type family Case_0123456789876543210 n0123456789876543210 h0123456789876543210 t0123456789876543210 t where
type family Case_0123456789876543210 (n0123456789876543210 :: Nat) h0123456789876543210 t0123456789876543210 t where
Case_0123456789876543210 n h t 'True = Apply (Apply (:@#@$) n) (Apply (Apply (:@#@$) h) t)
Case_0123456789876543210 n h t 'False = Apply (Apply (:@#@$) h) (Apply (Apply InsertSym0 n) t)
type InsertionSortSym0 :: (~>) [Nat] [Nat]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,41 +37,41 @@ Singletons/CaseExpressions.hs:(0,0)-(0,0): Splicing declarations
in z
foo5 :: a -> a
foo5 x = case x of y -> (\ _ -> x) y
type family Case_0123456789876543210 arg_01234567898765432100123456789876543210 y0123456789876543210 x0123456789876543210 t where
type family Case_0123456789876543210 arg_01234567898765432100123456789876543210 y0123456789876543210 (x0123456789876543210 :: a0123456789876543210) t where
Case_0123456789876543210 arg_0123456789876543210 y x _ = x
type family Lambda_0123456789876543210 y0123456789876543210 x0123456789876543210 arg_0123456789876543210 where
type family Lambda_0123456789876543210 y0123456789876543210 (x0123456789876543210 :: a0123456789876543210) arg_0123456789876543210 where
Lambda_0123456789876543210 y x arg_0123456789876543210 = Case_0123456789876543210 arg_0123456789876543210 y x arg_0123456789876543210
data Lambda_0123456789876543210Sym0 y0123456789876543210 x0123456789876543210 arg_01234567898765432100123456789876543210
data Lambda_0123456789876543210Sym0 y0123456789876543210 (x0123456789876543210 :: a0123456789876543210) arg_01234567898765432100123456789876543210
where
Lambda_0123456789876543210Sym0KindInference :: SameKind (Apply (Lambda_0123456789876543210Sym0 y0123456789876543210 x0123456789876543210) arg) (Lambda_0123456789876543210Sym1 y0123456789876543210 x0123456789876543210 arg) =>
Lambda_0123456789876543210Sym0 y0123456789876543210 x0123456789876543210 arg_01234567898765432100123456789876543210
type instance Apply @_ @_ (Lambda_0123456789876543210Sym0 y0123456789876543210 x0123456789876543210) arg_01234567898765432100123456789876543210 = Lambda_0123456789876543210 y0123456789876543210 x0123456789876543210 arg_01234567898765432100123456789876543210
instance SuppressUnusedWarnings (Lambda_0123456789876543210Sym0 y0123456789876543210 x0123456789876543210) where
suppressUnusedWarnings
= snd ((,) Lambda_0123456789876543210Sym0KindInference ())
type family Lambda_0123456789876543210Sym1 y0123456789876543210 x0123456789876543210 arg_01234567898765432100123456789876543210 where
type family Lambda_0123456789876543210Sym1 y0123456789876543210 (x0123456789876543210 :: a0123456789876543210) arg_01234567898765432100123456789876543210 where
Lambda_0123456789876543210Sym1 y0123456789876543210 x0123456789876543210 arg_01234567898765432100123456789876543210 = Lambda_0123456789876543210 y0123456789876543210 x0123456789876543210 arg_01234567898765432100123456789876543210
type family Case_0123456789876543210 x0123456789876543210 t where
type family Case_0123456789876543210 (x0123456789876543210 :: a0123456789876543210) t where
Case_0123456789876543210 x y = Apply (Lambda_0123456789876543210Sym0 y x) y
type family Let0123456789876543210ZSym0 a0123456789876543210 y0123456789876543210 x0123456789876543210 :: a0123456789876543210 where
type family Let0123456789876543210ZSym0 a0123456789876543210 y0123456789876543210 (x0123456789876543210 :: a0123456789876543210) :: a0123456789876543210 where
Let0123456789876543210ZSym0 a0123456789876543210 y0123456789876543210 x0123456789876543210 = Let0123456789876543210Z a0123456789876543210 y0123456789876543210 x0123456789876543210
type family Let0123456789876543210Z a0123456789876543210 y0123456789876543210 x0123456789876543210 :: a0123456789876543210 where
type family Let0123456789876543210Z a0123456789876543210 y0123456789876543210 (x0123456789876543210 :: a0123456789876543210) :: a0123456789876543210 where
Let0123456789876543210Z a y x = y
type family Case_0123456789876543210 a0123456789876543210 x0123456789876543210 t where
type family Case_0123456789876543210 a0123456789876543210 (x0123456789876543210 :: a0123456789876543210) t where
Case_0123456789876543210 a x y = Let0123456789876543210ZSym0 a y x
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 a0123456789876543210 b0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 (a0123456789876543210 :: a0123456789876543210) (b0123456789876543210 :: b0123456789876543210) where
Let0123456789876543210Scrutinee_0123456789876543210Sym0 a0123456789876543210 b0123456789876543210 = Let0123456789876543210Scrutinee_0123456789876543210 a0123456789876543210 b0123456789876543210
type family Let0123456789876543210Scrutinee_0123456789876543210 a0123456789876543210 b0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210 (a0123456789876543210 :: a0123456789876543210) (b0123456789876543210 :: b0123456789876543210) where
Let0123456789876543210Scrutinee_0123456789876543210 a b = Apply (Apply Tuple2Sym0 a) b
type family Case_0123456789876543210 a0123456789876543210 b0123456789876543210 t where
type family Case_0123456789876543210 (a0123456789876543210 :: a0123456789876543210) (b0123456789876543210 :: b0123456789876543210) t where
Case_0123456789876543210 a b '(p, _) = p
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 d0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210Sym0 (d0123456789876543210 :: a0123456789876543210) where
Let0123456789876543210Scrutinee_0123456789876543210Sym0 d0123456789876543210 = Let0123456789876543210Scrutinee_0123456789876543210 d0123456789876543210
type family Let0123456789876543210Scrutinee_0123456789876543210 d0123456789876543210 where
type family Let0123456789876543210Scrutinee_0123456789876543210 (d0123456789876543210 :: a0123456789876543210) where
Let0123456789876543210Scrutinee_0123456789876543210 d = Apply JustSym0 d
type family Case_0123456789876543210 d0123456789876543210 t where
type family Case_0123456789876543210 (d0123456789876543210 :: a0123456789876543210) t where
Case_0123456789876543210 d ('Just y) = y
type family Case_0123456789876543210 d0123456789876543210 x0123456789876543210 t where
type family Case_0123456789876543210 (d0123456789876543210 :: a0123456789876543210) (x0123456789876543210 :: Maybe a0123456789876543210) t where
Case_0123456789876543210 d x ('Just y) = y
Case_0123456789876543210 d x 'Nothing = d
type Foo5Sym0 :: (~>) a a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Singletons/EmptyShowDeriving.hs:(0,0)-(0,0): Splicing declarations
======>
data Foo
deriving instance Show Foo
type family Case_0123456789876543210 v_01234567898765432100123456789876543210 a_01234567898765432100123456789876543210 t where
type family Case_0123456789876543210 (v_01234567898765432100123456789876543210 :: Foo) (a_01234567898765432100123456789876543210 :: GHC.Types.Symbol) t where
type ShowsPrec_0123456789876543210 :: GHC.Num.Natural.Natural
-> Foo -> GHC.Types.Symbol -> GHC.Types.Symbol
type family ShowsPrec_0123456789876543210 (a :: GHC.Num.Natural.Natural) (a :: Foo) (a :: GHC.Types.Symbol) :: GHC.Types.Symbol where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ Singletons/EnumDeriving.hs:(0,0)-(0,0): Splicing declarations
type Q2Sym0 :: Quux
type family Q2Sym0 :: Quux where
Q2Sym0 = Q2
type family Case_0123456789876543210 n0123456789876543210 t where
type family Case_0123456789876543210 (n0123456789876543210 :: GHC.Num.Natural.Natural) t where
Case_0123456789876543210 n 'True = BumSym0
Case_0123456789876543210 n 'False = Apply ErrorSym0 "toEnum: bad argument"
type family Case_0123456789876543210 n0123456789876543210 t where
type family Case_0123456789876543210 (n0123456789876543210 :: GHC.Num.Natural.Natural) t where
Case_0123456789876543210 n 'True = BazSym0
Case_0123456789876543210 n 'False = Case_0123456789876543210 n (Apply (Apply (==@#@$) n) (FromInteger 2))
type family Case_0123456789876543210 n0123456789876543210 t where
type family Case_0123456789876543210 (n0123456789876543210 :: GHC.Num.Natural.Natural) t where
Case_0123456789876543210 n 'True = BarSym0
Case_0123456789876543210 n 'False = Case_0123456789876543210 n (Apply (Apply (==@#@$) n) (FromInteger 1))
type ToEnum_0123456789876543210 :: GHC.Num.Natural.Natural -> Foo
Expand Down Expand Up @@ -117,10 +117,10 @@ Singletons/EnumDeriving.hs:(0,0)-(0,0): Splicing declarations
Singletons/EnumDeriving.hs:0:0:: Splicing declarations
singEnumInstance ''Quux
======>
type family Case_0123456789876543210 n0123456789876543210 t where
type family Case_0123456789876543210 (n0123456789876543210 :: GHC.Num.Natural.Natural) t where
Case_0123456789876543210 n 'True = Q2Sym0
Case_0123456789876543210 n 'False = Apply ErrorSym0 "toEnum: bad argument"
type family Case_0123456789876543210 n0123456789876543210 t where
type family Case_0123456789876543210 (n0123456789876543210 :: GHC.Num.Natural.Natural) t where
Case_0123456789876543210 n 'True = Q1Sym0
Case_0123456789876543210 n 'False = Case_0123456789876543210 n (Apply (Apply (==@#@$) n) (FromInteger 1))
type ToEnum_0123456789876543210 :: GHC.Num.Natural.Natural -> Quux
Expand Down
Loading

0 comments on commit 71d1877

Please sign in to comment.