From 0d2e0c70deafc88bd3e33d4ff3660f97cc5fd9ae Mon Sep 17 00:00:00 2001 From: rhiannon morris Date: Tue, 24 Oct 2023 23:50:28 +0200 Subject: [PATCH] more erasure --- lib/Quox/Untyped/Erase.idr | 185 +++++++++++++++++++++++++----------- lib/Quox/Untyped/Syntax.idr | 93 +++++++++++++----- 2 files changed, 198 insertions(+), 80 deletions(-) diff --git a/lib/Quox/Untyped/Erase.idr b/lib/Quox/Untyped/Erase.idr index a79e297..89905f8 100644 --- a/lib/Quox/Untyped/Erase.idr +++ b/lib/Quox/Untyped/Erase.idr @@ -1,14 +1,15 @@ module Quox.Untyped.Erase import Quox.Definition as Q +import Quox.Pretty import Quox.Syntax.Term.Base as Q import Quox.Syntax.Term.Subst -import Quox.Untyped.Syntax as U import Quox.Typing +import Quox.Untyped.Syntax as U import Quox.Whnf -import Quox.Pretty import Quox.EffExtra +import Data.List1 import Data.Singleton import Data.SnocVect import Language.Reflection @@ -28,12 +29,6 @@ isErased Zero = Erased isErased One = Kept isErased Any = Kept -public export -ifErased : Qty -> Lazy a -> Lazy a -> a -ifErased pi x y = case isErased pi of - Erased => x - Kept => y - public export ErasureContext : Nat -> Nat -> Type @@ -98,23 +93,18 @@ export covering computeElimType : ErasureContext d n -> SQty -> Elim d n -> Eff Erase (Term d n) computeElimType ctx sg e = do defs <- askAt DEFS + let ctx = toWhnfContext ctx liftWhnf $ do - let ctx = toWhnfContext ctx - Element e enf <- whnf defs ctx sg e + Element e _ <- whnf defs ctx sg e computeElimType defs ctx sg e private %macro wrapExpect : TTImp -> Elab (TyContext d n -> Loc -> Term d n -> Eff Erase a) -wrapExpect f {a} = do - f' <- check `(\x => ~(f) x) - pure $ \ctx, loc, s => wrapExpect' f' ctx loc s -where - wrapExpect' : (Q.Definitions -> TyContext d n -> SQty -> Loc -> Term d n -> - Eff [Except TypeError, NameGen] a) -> - TyContext d n -> Loc -> Term d n -> Eff Erase a - wrapExpect' f ctx loc s = liftWhnf $ f !(askAt DEFS) ctx SZero loc s +wrapExpect f_ = do + f <- check `(\x => ~(f_) x) + pure $ \ctx, loc, s => liftWhnf $ f !(askAt DEFS) ctx SZero loc s public export @@ -263,22 +253,42 @@ eraseElim ctx (App fun arg loc) = do Kept => do arg <- eraseTerm ctx targ arg pure $ EraRes ty $ App efun.term arg loc --- e ⤋ e' ⇒ (x : A) × B +-- e ⇒ (x : A) × B -- x : A, y : B | ρ.x, ρ.y ⊢ s ⤋ s' ⇐ R[((x,y) ∷ (x : A) × B)/z] --- x̃ ≔ if ρ = 0 then ⌷ else fst e' ỹ ≔ if ρ = 0 then ⌷ else snd e' -- ------------------------------------------------------------------- --- (caseρ e return z ⇒ R of {(x, y) ⇒ s}) ⤋ s'[x̃/x, ỹ/y] ⇒ R[e/z] +-- (case0 e return z ⇒ R of {(x, y) ⇒ s}) ⤋ s'[⌷/x, ⌷/y] ⇒ R[e/z] +-- +-- e ⤋ e' ⇒ (x : A) × B ρ ≠ 0 +-- x : A, y : B | ρ.x, ρ.y ⊢ s ⤋ s' ⇐ R[((x,y) ∷ (x : A) × B)/z] +-- ---------------------------------------------------------------------------- +-- (caseρ e return z ⇒ R of {(x, y) ⇒ s}) ⤋ +-- ⤋ +-- let xy = e' in let x = fst xy in let y = snd xy in s' ⇒ R[e/z] eraseElim ctx (CasePair qty pair ret body loc) = do - epair <- eraseElim ctx pair - let ty = sub1 (ret // shift 2) $ - Ann (Pair (BVT 0 loc) (BVT 1 loc) loc) (weakT 2 epair.type) loc - (tfst, tsnd) <- wrapExpect `(expectSig) ctx loc epair.type let [< x, y] = body.names - let ctx' = extendTyN [< (qty, x, tfst), (qty, y, tsnd.term)] ctx - body' <- eraseTerm ctx' ty body.term - let x' = ifErased qty (Erased loc) (Fst epair.term loc) - y' = ifErased qty (Erased loc) (Snd epair.term loc) - pure $ EraRes (sub1 ret pair) $ body' // fromSnocVect [< x', y'] + case isErased qty of + Kept => do + EraRes ety eterm <- eraseElim ctx pair + let ty = sub1 (ret // shift 2) $ + Ann (Pair (BVT 0 loc) (BVT 1 loc) loc) (weakT 2 ety) loc + (tfst, tsnd) <- wrapExpect `(expectSig) ctx loc ety + let ctx' = extendTyN [< (qty, x, tfst), (qty, y, tsnd.term)] ctx + body' <- eraseTerm ctx' ty body.term + p <- mnb "p" loc + pure $ EraRes (sub1 ret pair) $ + Let p eterm + (Let x (Fst (B VZ loc) loc) + (Let y (Snd (B (VS VZ) loc) loc) + (body' // (B VZ loc ::: B (VS VZ) loc ::: shift 3)) + loc) loc) loc + Erased => do + ety <- computeElimType ctx SOne pair + let ty = sub1 (ret // shift 2) $ + Ann (Pair (BVT 0 loc) (BVT 1 loc) loc) (weakT 2 ety) loc + (tfst, tsnd) <- wrapExpect `(expectSig) ctx loc ety + let ctx' = extendTyN0 [< (x, tfst), (y, tsnd.term)] ctx + body' <- eraseTerm ctx' ty body.term + pure $ EraRes (sub1 ret pair) $ subN [< Erased loc, Erased loc] body' -- e ⤋ e' ⇒ (x : A) × B -- ---------------------- @@ -296,33 +306,34 @@ eraseElim ctx (Snd pair loc) = do b <- snd <$> wrapExpect `(expectSig) ctx loc epair.type pure $ EraRes (sub1 b (Fst pair loc)) $ Snd epair.term loc --- case0 e return z ⇒ R of {} ⤋ absurd ⇒ R[e/z] +-- caseρ e return z ⇒ R of {} ⤋ absurd ⇒ R[e/z] -- -- s ⤋ s' ⇐ R[𝐚∷{𝐚}/z] -- ----------------------------------------------- -- case0 e return z ⇒ R of {𝐚 ⇒ s} ⤋ s' ⇒ R[e/z] -- --- e ⤋ e' ⇒ A ρ ≠ 0 sᵢ ⤋ s'ᵢ ⇐ R[𝐚ᵢ/z] +-- e ⤋ e' ⇒ A sᵢ ⤋ s'ᵢ ⇐ R[𝐚ᵢ/z] ρ ≠ 0 i ≠ 0 -- ------------------------------------------------------------------- -- caseρ e return z ⇒ R of {𝐚ᵢ ⇒ sᵢ} ⤋ case e of {𝐚ᵢ ⇒ s'ᵢ} ⇒ R[e/z] -eraseElim ctx e@(CaseEnum qty tag ret arms loc) = +eraseElim ctx e@(CaseEnum qty tag ret arms loc) = do + let ty = sub1 ret tag case isErased qty of Erased => case SortedMap.toList arms of - [] => pure $ EraRes (sub1 ret tag) $ Absurd loc - [(t, arm)] => do - let ty = sub1 ret tag - ty' = sub1 ret (Ann (Tag t loc) (enum [t] loc) loc) - arm' <- eraseTerm ctx ty' arm - pure $ EraRes ty arm' - _ => throw $ CompileTimeOnly ctx $ E e - Kept => do - let ty = sub1 ret tag - etag <- eraseElim ctx tag - arms <- for (SortedMap.toList arms) $ \(t, rhs) => do - let ty' = sub1 ret (Ann (Tag t loc) etag.type loc) + [] => pure $ EraRes ty $ Absurd loc + [(t, rhs)] => do + let ty' = sub1 ret (Ann (Tag t loc) (enum [t] loc) loc) rhs' <- eraseTerm ctx ty' rhs - pure (t, rhs') - pure $ EraRes ty $ CaseEnum etag.term arms loc + pure $ EraRes ty rhs' + _ => throw $ CompileTimeOnly ctx $ E e + Kept => case List1.fromList $ SortedMap.toList arms of + Nothing => pure $ EraRes ty $ Absurd loc + Just arms => do + etag <- eraseElim ctx tag + arms <- for arms $ \(t, rhs) => do + let ty' = sub1 ret (Ann (Tag t loc) etag.type loc) + rhs' <- eraseTerm ctx ty' rhs + pure (t, rhs') + pure $ EraRes ty $ CaseEnum etag.term arms loc -- n ⤋ n' ⇒ ℕ z ⤋ z' ⇐ R[zero∷ℕ/z] ς ≠ 0 -- m : ℕ, ih : R[m/z] | ρ.m, ς.ih ⊢ s ⤋ s' ⇐ R[(succ m)∷ℕ/z] @@ -352,25 +363,24 @@ eraseElim ctx (CaseNat qty qtyIH nat ret zero succ loc) = do Erased => NSNonrec p (sub1 (Erased loc) succ') pure $ EraRes ty $ CaseNat enat.term zero succ loc --- b ⤋ b' ⇒ [π.A] π ≠ 0 --- x : A | πρ.x ⊢ s ⤋ s' ⇐ R[[x]∷[π.A]/z] --- ------------------------------------------------------- --- caseρ b return z ⇒ R of {[x] ⇒ s} ⤋ s'[b'/x] ⇒ R[b/z] +-- b ⤋ b' ⇒ [π.A] πρ ≠ 0 x : A | πρ.x ⊢ s ⤋ s' ⇐ R[[x]∷[π.A]/z] +-- ------------------------------------------------------------------ +-- caseρ b return z ⇒ R of {[x] ⇒ s} ⤋ let x = b' in s' ⇒ R[b/z] -- --- b ⇒ [0.A] x : A | 0.x ⊢ s ⤋ s' ⇐ R[[x]∷[0.A]/z] --- ------------------------------------------------------- --- caseρ b return z ⇒ R of {[x] ⇒ s} ⤋ s'[⌷/x] ⇒ R[b/z] +-- b ⇒ [π.A] x : A | 0.x ⊢ s ⤋ s' ⇐ R[[x]∷[0.A]/z] πρ = 0 +-- ------------------------------------------------------------- +-- caseρ b return z ⇒ R of {[x] ⇒ s} ⤋ s'[⌷/x] ⇒ R[b/z] eraseElim ctx (CaseBox qty box ret body loc) = do tbox <- computeElimType ctx SOne box -- [fixme] is there any way to avoid this? (pi, tinner) <- wrapExpect `(expectBOX) ctx loc tbox let ctx' = extendTy (pi * qty) body.name tinner ctx bty = sub1 (ret // shift 1) $ Ann (Box (BVT 0 loc) loc) (weakT 1 tbox) loc - case isErased pi of + case isErased $ qty * pi of Kept => do ebox <- eraseElim ctx box ebody <- eraseTerm ctx' bty body.term - pure $ EraRes (sub1 ret box) $ ebody // one ebox.term + pure $ EraRes (sub1 ret box) $ Let body.name ebox.term ebody loc Erased => do body' <- eraseTerm ctx' bty body.term pure $ EraRes (sub1 ret box) $ body' // one (Erased loc) @@ -412,6 +422,67 @@ eraseElim ctx (DCloE (Sub term th)) = eraseElim ctx $ pushSubstsWith' th id term +export +uses : Var n -> Term n -> Nat +uses i (F x _) = 0 +uses i (B j _) = if i == j then 1 else 0 +uses i (Lam x body _) = uses (VS i) body +uses i (App fun arg _) = uses i fun + uses i arg +uses i (Pair fst snd _) = uses i fst + uses i snd +uses i (Fst pair _) = uses i pair +uses i (Snd pair _) = uses i pair +uses i (Tag tag _) = 0 +uses i (CaseEnum tag cases _) = + uses i tag + foldl max 0 (map (assert_total uses i . snd) cases) +uses i (Absurd _) = 0 +uses i (Zero _) = 0 +uses i (Succ nat _) = uses i nat +uses i (CaseNat nat zer suc _) = uses i nat + max (uses i zer) (uses' suc) + where uses' : CaseNatSuc n -> Nat + uses' (NSRec _ _ s) = uses (VS (VS i)) s + uses' (NSNonrec _ s) = uses (VS i) s +uses i (Let x rhs body _) = uses i rhs + uses (VS i) body +uses i (Erased _) = 0 + +export +inlineable : Term n -> Bool +inlineable (F {}) = True +inlineable (B {}) = True +inlineable (Tag {}) = True +inlineable (Absurd {}) = True +inlineable (Erased {}) = True +inlineable _ = False + +export +trimLets : Term n -> Term n +trimLets (F x loc) = F x loc +trimLets (B i loc) = B i loc +trimLets (Lam x body loc) = Lam x (trimLets body) loc +trimLets (App fun arg loc) = App (trimLets fun) (trimLets arg) loc +trimLets (Pair fst snd loc) = Pair (trimLets fst) (trimLets snd) loc +trimLets (Fst pair loc) = Fst (trimLets pair) loc +trimLets (Snd pair loc) = Snd (trimLets pair) loc +trimLets (Tag tag loc) = Tag tag loc +trimLets (CaseEnum tag cases loc) = + CaseEnum (trimLets tag) + (map (map $ \c => trimLets $ assert_smaller cases c) cases) loc +trimLets (Absurd loc) = Absurd loc +trimLets (Zero loc) = Zero loc +trimLets (Succ nat loc) = Succ (trimLets nat) loc +trimLets (CaseNat nat zer suc loc) = + CaseNat (trimLets nat) (trimLets zer) (trimLets' suc) loc + where trimLets' : CaseNatSuc n -> CaseNatSuc n + trimLets' (NSRec x ih s) = NSRec x ih $ trimLets s + trimLets' (NSNonrec x s) = NSNonrec x $ trimLets s +trimLets (Let x rhs body loc) = + let rhs' = trimLets rhs + body' = trimLets body in + if inlineable rhs' || uses VZ body' == 1 + then sub1 rhs' body' + else Let x rhs' body' loc +trimLets (Erased loc) = Erased loc + + export covering eraseDef : Name -> Q.Definition -> Eff Erase U.Definition eraseDef name def@(MkDef qty type body loc) = @@ -420,4 +491,4 @@ eraseDef name def@(MkDef qty type body loc) = Erased => pure ErasedDef Kept => case body of Postulate => throw $ Postulate loc name - Concrete body => KeptDef <$> eraseTerm empty type body + Concrete body => KeptDef . trimLets <$> eraseTerm empty type body diff --git a/lib/Quox/Untyped/Syntax.idr b/lib/Quox/Untyped/Syntax.idr index 36bbedd..a9b6955 100644 --- a/lib/Quox/Untyped/Syntax.idr +++ b/lib/Quox/Untyped/Syntax.idr @@ -35,8 +35,8 @@ data Term where Snd : (pair : Term n) -> Loc -> Term n Tag : (tag : String) -> Loc -> Term n - CaseEnum : (tag : Term n) -> (cases : List (String, Term n)) -> Loc -> Term n - ||| empty match with an erased head + CaseEnum : (tag : Term n) -> (cases : List1 (String, Term n)) -> Loc -> Term n + ||| empty match Absurd : Loc -> Term n Zero : Loc -> Term n @@ -47,6 +47,9 @@ data Term where Loc -> Term n + Let : (x : BindName) -> (rhs : Term n) -> (body : Term (S n)) -> Loc -> + Term n + Erased : Loc -> Term n %name Term s, t, u @@ -61,20 +64,21 @@ data CaseNatSuc where export Located (Term n) where - (F x loc).loc = loc - (B i loc).loc = loc - (Lam x body loc).loc = loc - (App fun arg loc).loc = loc - (Pair fst snd loc).loc = loc - (Fst pair loc).loc = loc - (Snd pair loc).loc = loc - (Tag tag loc).loc = loc - (CaseEnum tag cases loc).loc = loc - (Absurd loc).loc = loc - (Zero loc).loc = loc - (Succ nat loc).loc = loc - (CaseNat nat zer suc loc).loc = loc - (Erased loc).loc = loc + (F _ loc).loc = loc + (B _ loc).loc = loc + (Lam _ _ loc).loc = loc + (App _ _ loc).loc = loc + (Pair _ _ loc).loc = loc + (Fst _ loc).loc = loc + (Snd _ loc).loc = loc + (Tag _ loc).loc = loc + (CaseEnum _ _ loc).loc = loc + (Absurd loc).loc = loc + (Zero loc).loc = loc + (Succ _ loc).loc = loc + (CaseNat _ _ _ loc).loc = loc + (Let _ _ _ loc).loc = loc + (Erased loc).loc = loc public export @@ -85,6 +89,11 @@ public export Definitions = SortedMap Name Definition +export +letD, inD : {opts : LayoutOpts} -> Eff Pretty (Doc opts) +letD = hl Syntax "let" +inD = hl Syntax "in" + export prettyTerm : {opts : LayoutOpts} -> BContext n -> Term n -> Eff Pretty (Doc opts) @@ -136,6 +145,33 @@ private sucPat : {opts : LayoutOpts} -> BindName -> Eff Pretty (Doc opts) sucPat x = pure $ !succD <++> !(prettyTBind x) +private +splitLam : Telescope' BindName a b -> Term b -> + Exists $ \c => (Telescope' BindName a c, Term c) +splitLam ys (Lam x body _) = splitLam (ys :< x) body +splitLam ys t = Evidence _ (ys, t) + +private +splitLet : Telescope (\i => (BindName, Term i)) a b -> Term b -> + Exists $ \c => (Telescope (\i => (BindName, Term i)) a c, Term c) +splitLet ys (Let x rhs body _) = splitLet (ys :< (x, rhs)) body +splitLet ys t = Evidence _ (ys, t) + +private +prettyLets : {opts : LayoutOpts} -> + BContext a -> Telescope (\i => (BindName, Term i)) a b -> + Eff Pretty (SnocList (Doc opts)) +prettyLets xs lets = sequence $ snd $ go lets where + go : forall b. Telescope (\i => (BindName, Term i)) a b -> + (BContext b, SnocList (Eff Pretty (Doc opts))) + go [<] = (xs, [<]) + go (lets :< (x, rhs)) = + let (ys, docs) = go lets + doc = hsep <$> sequence + [letD, prettyTBind x, cstD, assert_total prettyTerm ys rhs, inD] + in + (ys :< x, docs :< doc) + private sucCaseArm : {opts : LayoutOpts} -> CaseNatSuc n -> Eff Pretty (PrettyCaseArm (Doc opts) n) @@ -148,9 +184,10 @@ prettyTerm _ (F x _) = prettyFree x prettyTerm xs (B i _) = prettyTBind $ xs !!! i prettyTerm xs (Lam x body _) = parensIfM Outer =<< do - header <- hsep <$> sequence [lamD, prettyTBind x, darrowD] - body <- withPrec Outer $ prettyTerm (xs :< x) body - hangDSingle header body + let Evidence n' (ys, body) = splitLam [< x] body + vars <- hsep . toList' <$> traverse prettyTBind ys + body <- withPrec Outer $ assert_total prettyTerm (xs . ys) body + hangDSingle (hsep [!lamD, vars, !darrowD]) body prettyTerm xs (App fun arg _) = prettyApp xs fun arg prettyTerm xs (Pair fst snd _) = parens =<< separateTight !commaD <$> @@ -161,21 +198,29 @@ prettyTerm xs (Tag tag _) = prettyTag tag prettyTerm xs (CaseEnum tag cases _) = assert_total prettyCase xs prettyTag tag $ - map (\(t, rhs) => MkPrettyCaseArm t [] rhs) cases + map (\(t, rhs) => MkPrettyCaseArm t [] rhs) $ toList cases prettyTerm xs (Absurd _) = hl Syntax "absurd" prettyTerm xs (Zero _) = zeroD prettyTerm xs (Succ nat _) = prettyApp' xs !succD nat prettyTerm xs (CaseNat nat zer suc _) = assert_total prettyCase xs pure nat [MkPrettyCaseArm !zeroD [] zer, !(sucCaseArm suc)] +prettyTerm xs (Let x rhs body _) = + parensIfM Outer =<< do + let Evidence n' (lets, body) = splitLet [< (x, rhs)] body + heads <- prettyLets xs lets + body <- withPrec Outer $ assert_total prettyTerm (xs . map fst lets) body + let lines = toList $ heads :< body + pure $ ifMultiline (hsep lines) (vsep lines) prettyTerm _ (Erased _) = hl Syntax =<< ifUnicode "⌷" "[]" export prettyDef : {opts : LayoutOpts} -> Name -> - Definition -> Eff Pretty (Maybe (Doc opts)) -prettyDef _ ErasedDef = [|Nothing|] -prettyDef name (KeptDef rhs) = map Just $ do + Definition -> Eff Pretty (Doc opts) +prettyDef name ErasedDef = + pure $ hsep [!(prettyFree name), !cstD, !(prettyTerm [<] $ Erased noLoc)] +prettyDef name (KeptDef rhs) = do name <- prettyFree name eq <- cstD rhs <- prettyTerm [<] rhs @@ -220,6 +265,8 @@ CanSubstSelf Term where CaseNat nat zer suc loc => CaseNat (nat // th) (zer // th) (assert_total substSuc suc th) loc + Let x rhs body loc => + Let x (rhs // th) (assert_total $ body // push th) loc Erased loc => Erased loc where