Elim.compare0 infers the type

instead of calling computeElimType over and over. now there's just one
at the start
This commit is contained in:
rhiannon morris 2023-08-27 19:05:25 +02:00
parent 3e3bf1b67f
commit 72609bc12f

View file

@ -173,7 +173,7 @@ parameters (defs : Definitions)
(E e, Lam b {}) => eta s.loc e b (E e, Lam b {}) => eta s.loc e b
(Lam b {}, E e) => eta s.loc e b (Lam b {}, E e) => eta s.loc e b
(E e, E f) => Elim.compare0 ctx e f (E e, E f) => ignore $ Elim.compare0 ctx e f
(Lam {}, t) => wrongType t.loc ctx ty t (Lam {}, t) => wrongType t.loc ctx ty t
(E _, t) => wrongType t.loc ctx ty t (E _, t) => wrongType t.loc ctx ty t
@ -197,7 +197,7 @@ parameters (defs : Definitions)
compare0 ctx fst sFst tFst compare0 ctx fst sFst tFst
compare0 ctx (sub1 snd (Ann sFst fst fst.loc)) sSnd tSnd compare0 ctx (sub1 snd (Ann sFst fst fst.loc)) sSnd tSnd
(E e, E f) => Elim.compare0 ctx e f (E e, E f) => ignore $ Elim.compare0 ctx e f
(Pair {}, E _) => clashT s.loc ctx ty s t (Pair {}, E _) => clashT s.loc ctx ty s t
(E _, Pair {}) => clashT s.loc ctx ty s t (E _, Pair {}) => clashT s.loc ctx ty s t
@ -214,7 +214,7 @@ parameters (defs : Definitions)
-- t ∈ ts is in the typechecker, not here, ofc -- t ∈ ts is in the typechecker, not here, ofc
(Tag t1 {}, Tag t2 {}) => (Tag t1 {}, Tag t2 {}) =>
unless (t1 == t2) $ clashT s.loc ctx ty s t unless (t1 == t2) $ clashT s.loc ctx ty s t
(E e, E f) => Elim.compare0 ctx e f (E e, E f) => ignore $ Elim.compare0 ctx e f
(Tag {}, E _) => clashT s.loc ctx ty s t (Tag {}, E _) => clashT s.loc ctx ty s t
(E _, Tag {}) => clashT s.loc ctx ty s t (E _, Tag {}) => clashT s.loc ctx ty s t
@ -241,7 +241,7 @@ parameters (defs : Definitions)
-- Γ ⊢ succ s = succ t : -- Γ ⊢ succ s = succ t :
(Succ s' {}, Succ t' {}) => compare0 ctx nat s' t' (Succ s' {}, Succ t' {}) => compare0 ctx nat s' t'
(E e, E f) => Elim.compare0 ctx e f (E e, E f) => ignore $ Elim.compare0 ctx e f
(Zero {}, Succ {}) => clashT s.loc ctx nat s t (Zero {}, Succ {}) => clashT s.loc ctx nat s t
(Zero {}, E _) => clashT s.loc ctx nat s t (Zero {}, E _) => clashT s.loc ctx nat s t
@ -262,7 +262,7 @@ parameters (defs : Definitions)
-- Γ ⊢ [s] = [t] : [π.A] -- Γ ⊢ [s] = [t] : [π.A]
(Box s' {}, Box t' {}) => compare0 ctx ty' s' t' (Box s' {}, Box t' {}) => compare0 ctx ty' s' t'
(E e, E f) => Elim.compare0 ctx e f (E e, E f) => ignore $ Elim.compare0 ctx e f
(Box {}, t) => wrongType t.loc ctx ty t (Box {}, t) => wrongType t.loc ctx ty t
(E _, t) => wrongType t.loc ctx ty t (E _, t) => wrongType t.loc ctx ty t
@ -273,7 +273,7 @@ parameters (defs : Definitions)
-- e.g. an abstract value in an abstract type, bound variables, … -- e.g. an abstract value in an abstract type, bound variables, …
let E e = s | _ => wrongType s.loc ctx ty s let E e = s | _ => wrongType s.loc ctx ty s
E f = t | _ => wrongType t.loc ctx ty t E f = t | _ => wrongType t.loc ctx ty t
Elim.compare0 ctx e f ignore $ Elim.compare0 ctx e f
||| compares two types, using the current variance `mode` for universes. ||| compares two types, using the current variance `mode` for universes.
||| fails if they are not types, even if they would happen to be equal. ||| fails if they are not types, even if they would happen to be equal.
@ -353,39 +353,48 @@ parameters (defs : Definitions)
compareType' ctx (E e) (E f) = do compareType' ctx (E e) (E f) = do
-- no fanciness needed here cos anything other than a neutral -- no fanciness needed here cos anything other than a neutral
-- has been inlined by whnf -- has been inlined by whnf
Elim.compare0 ctx e f ignore $ Elim.compare0 ctx e f
namespace Elim namespace Elim
-- [fixme] the following code ends up repeating a lot of work in the
-- computeElimType calls. the results should be shared better
||| compare two eliminations according to the given variance `mode`. ||| compare two eliminations according to the given variance `mode`.
||| |||
||| ⚠ **assumes that they have both been typechecked, and have ||| ⚠ **assumes that they have both been typechecked, and have
||| equal types.** ⚠ ||| equal types.** ⚠
export covering %inline export covering %inline
compare0 : EqContext n -> (e, f : Elim 0 n) -> Eff EqualInner () compare0 : EqContext n -> (e, f : Elim 0 n) -> Eff EqualInner (Term 0 n)
compare0 ctx e f = compare0 ctx e f =
wrapErr (WhileComparingE ctx !mode e f) $ do wrapErr (WhileComparingE ctx !mode e f) $ do
let Val n = ctx.termLen let Val n = ctx.termLen
Element e' ne <- whnf defs ctx e.loc e Element e' ne <- whnf defs ctx e.loc e
Element f' nf <- whnf defs ctx f.loc f Element f' nf <- whnf defs ctx f.loc f
unless !(isSubSing defs ctx =<< computeElimTypeE defs ctx e') $
compare0' ctx e' f' ne nf -- [todo] share the work of this computeElimTypeE and the return value
-- of compare0' somehow?????
ty <- computeElimTypeE defs ctx e'
if !(isSubSing defs ctx ty)
then pure ty
else compare0' ctx e' f' ne nf
private covering private covering
compare0' : EqContext n -> compare0' : EqContext n ->
(e, f : Elim 0 n) -> (e, f : Elim 0 n) ->
(0 ne : NotRedex defs e) -> (0 nf : NotRedex defs f) -> (0 ne : NotRedex defs e) -> (0 nf : NotRedex defs f) ->
Eff EqualInner () Eff EqualInner (Term 0 n)
compare0' ctx e@(F x u _) f@(F y v _) _ _ = compare0' ctx e@(F x u loc) f@(F y v _) _ _ =
unless (x == y && u == v) $ clashE e.loc ctx e f if x == y && u == v
then do let Val n = ctx.termLen
let Just def = lookup x defs
| Nothing => throw $ NotInScope loc x
pure def.type
else clashE e.loc ctx e f
compare0' ctx e@(F {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(F {}) f _ _ = clashE e.loc ctx e f
compare0' ctx e@(B i _) f@(B j _) _ _ = compare0' ctx e@(B i _) f@(B j _) _ _ =
unless (i == j) $ clashE e.loc ctx e f if i == j
then pure $ ctx.tctx !! i
else clashE e.loc ctx e f
compare0' ctx e@(B {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(B {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ π.(x : A) → B -- Ψ | Γ ⊢ e = f ⇒ π.(x : A) → B
@ -394,10 +403,10 @@ parameters (defs : Definitions)
-- Ψ | Γ ⊢ e s = f t ⇒ B[s∷A/x] -- Ψ | Γ ⊢ e s = f t ⇒ B[s∷A/x]
compare0' ctx (App e s eloc) (App f t floc) ne nf = compare0' ctx (App e s eloc) (App f t floc) ne nf =
local_ Equal $ do local_ Equal $ do
compare0 ctx e f ety <- compare0 ctx e f
(_, arg, _) <- expectPi defs ctx eloc =<< (_, arg, res) <- expectPi defs ctx eloc ety
computeElimTypeE defs ctx e @{noOr1 ne}
Term.compare0 ctx arg s t Term.compare0 ctx arg s t
pure $ sub1 res (Ann s arg s.loc)
compare0' ctx e@(App {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(App {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ (x : A) × B -- Ψ | Γ ⊢ e = f ⇒ (x : A) × B
@ -409,8 +418,7 @@ parameters (defs : Definitions)
compare0' ctx (CasePair epi e eret ebody eloc) compare0' ctx (CasePair epi e eret ebody eloc)
(CasePair fpi f fret fbody {}) ne nf = (CasePair fpi f fret fbody {}) ne nf =
local_ Equal $ do local_ Equal $ do
compare0 ctx e f ety <- compare0 ctx e f
ety <- computeElimTypeE defs ctx e @{noOr1 ne}
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
(fst, snd) <- expectSig defs ctx eloc ety (fst, snd) <- expectSig defs ctx eloc ety
let [< x, y] = ebody.names let [< x, y] = ebody.names
@ -418,6 +426,7 @@ parameters (defs : Definitions)
(substCasePairRet ebody.names ety eret) (substCasePairRet ebody.names ety eret)
ebody.term fbody.term ebody.term fbody.term
expectEqualQ e.loc epi fpi expectEqualQ e.loc epi fpi
pure $ sub1 eret e
compare0' ctx e@(CasePair {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(CasePair {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ {𝐚s} -- Ψ | Γ ⊢ e = f ⇒ {𝐚s}
@ -429,14 +438,14 @@ parameters (defs : Definitions)
compare0' ctx (CaseEnum epi e eret earms eloc) compare0' ctx (CaseEnum epi e eret earms eloc)
(CaseEnum fpi f fret farms floc) ne nf = (CaseEnum fpi f fret farms floc) ne nf =
local_ Equal $ do local_ Equal $ do
compare0 ctx e f ety <- compare0 ctx e f
ety <- computeElimTypeE defs ctx e @{noOr1 ne}
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
for_ !(expectEnum defs ctx eloc ety) $ \t => do for_ !(expectEnum defs ctx eloc ety) $ \t => do
l <- lookupArm eloc t earms l <- lookupArm eloc t earms
r <- lookupArm floc t farms r <- lookupArm floc t farms
compare0 ctx (sub1 eret $ Ann (Tag t l.loc) ety l.loc) l r compare0 ctx (sub1 eret $ Ann (Tag t l.loc) ety l.loc) l r
expectEqualQ eloc epi fpi expectEqualQ eloc epi fpi
pure $ sub1 eret e
where where
lookupArm : Loc -> TagVal -> CaseEnumArms d n -> lookupArm : Loc -> TagVal -> CaseEnumArms d n ->
Eff EqualInner (Term d n) Eff EqualInner (Term d n)
@ -456,8 +465,7 @@ parameters (defs : Definitions)
compare0' ctx (CaseNat epi epi' e eret ezer esuc eloc) compare0' ctx (CaseNat epi epi' e eret ezer esuc eloc)
(CaseNat fpi fpi' f fret fzer fsuc floc) ne nf = (CaseNat fpi fpi' f fret fzer fsuc floc) ne nf =
local_ Equal $ do local_ Equal $ do
compare0 ctx e f ety <- compare0 ctx e f
ety <- computeElimTypeE defs ctx e @{noOr1 ne}
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
compare0 ctx compare0 ctx
(sub1 eret (Ann (Zero ezer.loc) (Nat ezer.loc) ezer.loc)) (sub1 eret (Ann (Zero ezer.loc) (Nat ezer.loc) ezer.loc))
@ -468,6 +476,7 @@ parameters (defs : Definitions)
(substCaseSuccRet esuc.names eret) esuc.term fsuc.term (substCaseSuccRet esuc.names eret) esuc.term fsuc.term
expectEqualQ e.loc epi fpi expectEqualQ e.loc epi fpi
expectEqualQ e.loc epi' fpi' expectEqualQ e.loc epi' fpi'
pure $ sub1 eret e
compare0' ctx e@(CaseNat {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(CaseNat {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ [ρ. A] -- Ψ | Γ ⊢ e = f ⇒ [ρ. A]
@ -479,14 +488,14 @@ parameters (defs : Definitions)
compare0' ctx (CaseBox epi e eret ebody eloc) compare0' ctx (CaseBox epi e eret ebody eloc)
(CaseBox fpi f fret fbody floc) ne nf = (CaseBox fpi f fret fbody floc) ne nf =
local_ Equal $ do local_ Equal $ do
compare0 ctx e f ety <- compare0 ctx e f
ety <- computeElimTypeE defs ctx e @{noOr1 ne}
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
(q, ty) <- expectBOX defs ctx eloc ety (q, ty) <- expectBOX defs ctx eloc ety
compare0 (extendTy (epi * q) ebody.name ty ctx) compare0 (extendTy (epi * q) ebody.name ty ctx)
(substCaseBoxRet ebody.name ety eret) (substCaseBoxRet ebody.name ety eret)
ebody.term fbody.term ebody.term fbody.term
expectEqualQ eloc epi fpi expectEqualQ eloc epi fpi
pure $ sub1 eret e
compare0' ctx e@(CaseBox {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(CaseBox {}) f _ _ = clashE e.loc ctx e f
-- all dim apps replaced with ends by whnf -- all dim apps replaced with ends by whnf
@ -501,6 +510,7 @@ parameters (defs : Definitions)
compare0' ctx (Ann s a _) (Ann t b _) _ _ = do compare0' ctx (Ann s a _) (Ann t b _) _ _ = do
ty <- bigger a b ty <- bigger a b
Term.compare0 ctx ty s t Term.compare0 ctx ty s t
pure ty
-- Ψ | Γ ⊢ Ap₁/𝑖 <: Bp₂/𝑖 -- Ψ | Γ ⊢ Ap₁/𝑖 <: Bp₂/𝑖
-- Ψ | Γ ⊢ Aq₁/𝑖 <: Bq₂/𝑖 -- Ψ | Γ ⊢ Aq₁/𝑖 <: Bq₂/𝑖
@ -514,8 +524,9 @@ parameters (defs : Definitions)
ty1q = dsub1 ty1 q1; ty2q = dsub1 ty2 q2 ty1q = dsub1 ty1 q1; ty2q = dsub1 ty2 q2
compareType ctx ty1p ty2p compareType ctx ty1p ty2p
compareType ctx ty1q ty2q compareType ctx ty1q ty2q
ty_p <- bigger ty1p ty2p (ty_p, ty_q) <- bigger (ty1p, ty1q) (ty2p, ty2q)
Term.compare0 ctx ty_p val1 val2 Term.compare0 ctx ty_p val1 val2
pure ty_q
compare0' ctx e@(Coe {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(Coe {}) f _ _ = clashE e.loc ctx e f
-- (no neutral compositions in a closed dctx) -- (no neutral compositions in a closed dctx)
@ -527,14 +538,14 @@ parameters (defs : Definitions)
compare0' ctx (TypeCase ty1 ret1 arms1 def1 eloc) compare0' ctx (TypeCase ty1 ret1 arms1 def1 eloc)
(TypeCase ty2 ret2 arms2 def2 floc) ne _ = (TypeCase ty2 ret2 arms2 def2 floc) ne _ =
local_ Equal $ do local_ Equal $ do
compare0 ctx ty1 ty2 ety <- compare0 ctx ty1 ty2
u <- expectTYPE defs ctx eloc =<< u <- expectTYPE defs ctx eloc ety
computeElimTypeE defs ctx ty1 @{noOr1 ne}
compareType ctx ret1 ret2 compareType ctx ret1 ret2
compareType ctx def1 def2 compareType ctx def1 def2
for_ allKinds $ \k => for_ allKinds $ \k =>
compareArm ctx k ret1 u compareArm ctx k ret1 u
(lookupPrecise k arms1) (lookupPrecise k arms2) def1 (lookupPrecise k arms1) (lookupPrecise k arms2) def1
pure ret1
compare0' ctx e@(TypeCase {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(TypeCase {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ s <: f ⇐ A -- Ψ | Γ ⊢ s <: f ⇐ A
@ -542,9 +553,12 @@ parameters (defs : Definitions)
-- Ψ | Γ ⊢ (s ∷ A) <: f ⇒ A -- Ψ | Γ ⊢ (s ∷ A) <: f ⇒ A
-- --
-- and vice versa -- and vice versa
compare0' ctx (Ann s a _) f _ _ = Term.compare0 ctx a s (E f) compare0' ctx (Ann s a _) f _ _ =
compare0' ctx e (Ann t b _) _ _ = Term.compare0 ctx b (E e) t Term.compare0 ctx a s (E f) $> a
compare0' ctx e@(Ann {}) f _ _ = clashE e.loc ctx e f compare0' ctx e (Ann t b _) _ _ =
Term.compare0 ctx b (E e) t $> b
compare0' ctx e@(Ann {}) f _ _ =
clashE e.loc ctx e f
||| compare two type-case branches, which came from the arms of the given ||| compare two type-case branches, which came from the arms of the given
||| kind. `ret` is the return type of the case expression, and `u` is the ||| kind. `ret` is the return type of the case expression, and `u` is the
@ -644,7 +658,7 @@ parameters (loc : Loc) (ctx : TyContext d n)
export covering export covering
compare : (e, f : Elim d n) -> Eff Equal () compare : (e, f : Elim d n) -> Eff Equal ()
compare e f = runCompare $ \defs, ectx, th => compare e f = runCompare $ \defs, ectx, th =>
compare0 defs ectx (e // th) (f // th) ignore $ compare0 defs ectx (e // th) (f // th)
namespace Term namespace Term
export covering %inline export covering %inline