make Elim.compare0 able to pass a type to isSubSing

it now recovers from (most) errors and always returns a type, so that
isSubSing doesn't have to recalculate it

it already assumed the inputs had the same type. now it just leans on
that assumption harder
This commit is contained in:
rhiannon morris 2023-08-28 20:00:54 +02:00
parent 6f9d31aa0a
commit add2eb400c
2 changed files with 114 additions and 81 deletions

View file

@ -52,6 +52,7 @@ public export
data Length : List a -> Type where data Length : List a -> Type where
Z : Length [] Z : Length []
S : Length xs -> Length (x :: xs) S : Length xs -> Length (x :: xs)
%builtin Natural Length
export export
subsetWith : Length xs => (forall z. Has z xs -> Has z ys) -> subsetWith : Length xs => (forall z. Has z xs -> Has z ys) ->
@ -64,20 +65,28 @@ subsetSelf : Length xs => Subset xs xs
subsetSelf = subsetWith id subsetSelf = subsetWith id
export export
subsetTail : Length xs => Subset xs (x :: xs) subsetTail : Length xs => (0 x : a) -> Subset xs (x :: xs)
subsetTail = subsetWith S subsetTail _ = subsetWith S
-- [fixme] allow the error to be anywhere in the effect list
export export
wrapErrAt : Length fs => (0 lbl : tag) -> (e -> e') -> catchMaybeAt : (0 lbl : tag) -> (Has (ExceptL lbl e) fs, Length fs) =>
Eff (ExceptL lbl e :: fs) a -> Eff (ExceptL lbl e' :: fs) a (e -> Eff fs a) -> Eff fs a -> Eff fs a
wrapErrAt lbl f act = catchMaybeAt lbl hnd act =
rethrowAt lbl . mapFst f =<< lift @{subsetTail} (runExceptAt lbl act) catchAt lbl hnd $ lift @{subsetTail $ ExceptL lbl e} act
export %inline export %inline
wrapErr : Length fs => (e -> e') -> catchMaybe : (Has (Except e) fs, Length fs) =>
Eff (Except e :: fs) a -> Eff (Except e' :: fs) a (e -> Eff fs a) -> Eff fs a -> Eff fs a
catchMaybe = catchMaybeAt ()
export
wrapErrAt : (0 lbl : tag) -> (Has (ExceptL lbl e) fs, Length fs) =>
(e -> e) -> Eff fs a -> Eff fs a
wrapErrAt lbl wrap = catchMaybeAt lbl (\ex => throwAt lbl $ wrap ex)
export %inline
wrapErr : (Has (Except e) fs, Length fs) => (e -> e) -> Eff fs a -> Eff fs a
wrapErr = wrapErrAt () wrapErr = wrapErrAt ()

View file

@ -363,38 +363,49 @@ parameters (defs : Definitions)
||| equal types.** ⚠ ||| equal types.** ⚠
export covering %inline export covering %inline
compare0 : EqContext n -> (e, f : Elim 0 n) -> Eff EqualInner (Term 0 n) compare0 : EqContext n -> (e, f : Elim 0 n) -> Eff EqualInner (Term 0 n)
compare0 ctx e f = compare0 ctx e f = do
(err, ty) <- compare0Inner ctx e f
maybe (pure ty) throw err
private covering
compare0Inner : EqContext n -> (e, f : Elim 0 n) ->
Eff EqualInner (Maybe Error, Term 0 n)
compare0Inner ctx e f =
wrapErr (WhileComparingE ctx !mode e f) $ do wrapErr (WhileComparingE ctx !mode e f) $ do
let Val n = ctx.termLen let Val n = ctx.termLen
Element e' ne <- whnf defs ctx e.loc e Element e ne <- whnf defs ctx e.loc e
Element f' nf <- whnf defs ctx f.loc f Element f nf <- whnf defs ctx f.loc f
(err, ty) <- compare0' ctx e f ne nf
-- [todo] share the work of this computeElimTypeE and the return value
-- of compare0' somehow?????
ty <- computeElimTypeE defs ctx e'
if !(isSubSing defs ctx ty) if !(isSubSing defs ctx ty)
then pure ty then pure (Nothing, ty)
else compare0' ctx e' f' ne nf else pure (err, ty)
private
try_ : Eff EqualInner () -> Eff EqualInner (Maybe Error)
try_ act = lift $ catch (pure . Just) $ act $> Nothing
private
lookupFree : EqContext n -> Name -> Universe -> Loc ->
Eff EqualInner (Term 0 n)
lookupFree ctx x u loc =
let Val n = ctx.termLen in
maybe (throw $ NotInScope loc x) (\d => pure $ d.typeAt u) $
lookup x defs
private covering private covering
compare0' : EqContext n -> compare0' : EqContext n ->
(e, f : Elim 0 n) -> (e, f : Elim 0 n) ->
(0 ne : NotRedex defs e) -> (0 nf : NotRedex defs f) -> (0 ne : NotRedex defs e) -> (0 nf : NotRedex defs f) ->
Eff EqualInner (Term 0 n) Eff EqualInner (Maybe Error, Term 0 n)
compare0' ctx e@(F x u loc) f@(F y v _) _ _ = compare0' ctx e@(F x u loc) f@(F y v _) _ _ = do
if x == y && u == v pure (guard (x /= y || u /= v) $> ClashE loc ctx !mode e f,
then do let Val n = ctx.termLen !(lookupFree ctx x u loc))
let Just def = lookup x defs
| Nothing => throw $ NotInScope loc x
pure def.type
else clashE e.loc ctx e f
compare0' ctx e@(F {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(F {}) f _ _ = clashE e.loc ctx e f
compare0' ctx e@(B i _) f@(B j _) _ _ = compare0' ctx e@(B i loc) f@(B j _) _ _ =
if i == j pure (guard (i /= j) $> ClashE loc ctx !mode e f,
then pure $ ctx.tctx !! i ctx.tctx !! i)
else clashE e.loc ctx e f
compare0' ctx e@(B {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(B {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ π.(x : A) → B -- Ψ | Γ ⊢ e = f ⇒ π.(x : A) → B
@ -403,10 +414,10 @@ parameters (defs : Definitions)
-- Ψ | Γ ⊢ e s = f t ⇒ B[s∷A/x] -- Ψ | Γ ⊢ e s = f t ⇒ B[s∷A/x]
compare0' ctx (App e s eloc) (App f t floc) ne nf = compare0' ctx (App e s eloc) (App f t floc) ne nf =
local_ Equal $ do local_ Equal $ do
ety <- compare0 ctx e f (err1, ety) <- compare0Inner ctx e f
(_, arg, res) <- expectPi defs ctx eloc ety (_, arg, res) <- expectPi defs ctx eloc ety
Term.compare0 ctx arg s t err2 <- try_ $ Term.compare0 ctx arg s t
pure $ sub1 res (Ann s arg s.loc) pure (err1 <|> err2, sub1 res $ Ann s arg s.loc)
compare0' ctx e@(App {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(App {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ (x : A) × B -- Ψ | Γ ⊢ e = f ⇒ (x : A) × B
@ -418,15 +429,16 @@ parameters (defs : Definitions)
compare0' ctx (CasePair epi e eret ebody eloc) compare0' ctx (CasePair epi e eret ebody eloc)
(CasePair fpi f fret fbody {}) ne nf = (CasePair fpi f fret fbody {}) ne nf =
local_ Equal $ do local_ Equal $ do
ety <- compare0 ctx e f (err1, ety) <- compare0Inner ctx e f
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
(fst, snd) <- expectSig defs ctx eloc ety (fst, snd) <- expectSig defs ctx eloc ety
let [< x, y] = ebody.names let [< x, y] = ebody.names
Term.compare0 (extendTyN [< (epi, x, fst), (epi, y, snd.term)] ctx) err2 <- try_ $
(substCasePairRet ebody.names ety eret) Term.compare0 (extendTyN [< (epi, x, fst), (epi, y, snd.term)] ctx)
ebody.term fbody.term (substCasePairRet ebody.names ety eret)
expectEqualQ e.loc epi fpi ebody.term fbody.term
pure $ sub1 eret e err3 <- try_ $ expectEqualQ e.loc epi fpi
pure (concat [err1, err2, err3], sub1 eret e)
compare0' ctx e@(CasePair {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(CasePair {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ {𝐚s} -- Ψ | Γ ⊢ e = f ⇒ {𝐚s}
@ -438,14 +450,17 @@ parameters (defs : Definitions)
compare0' ctx (CaseEnum epi e eret earms eloc) compare0' ctx (CaseEnum epi e eret earms eloc)
(CaseEnum fpi f fret farms floc) ne nf = (CaseEnum fpi f fret farms floc) ne nf =
local_ Equal $ do local_ Equal $ do
ety <- compare0 ctx e f (err1, ety) <- compare0Inner ctx e f
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term err2 <- try_ $
for_ !(expectEnum defs ctx eloc ety) $ \t => do compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
cases <- SortedSet.toList <$> expectEnum defs ctx eloc ety
exs <- for cases $ \t => do
l <- lookupArm eloc t earms l <- lookupArm eloc t earms
r <- lookupArm floc t farms r <- lookupArm floc t farms
compare0 ctx (sub1 eret $ Ann (Tag t l.loc) ety l.loc) l r try_ $
expectEqualQ eloc epi fpi Term.compare0 ctx (sub1 eret $ Ann (Tag t l.loc) ety l.loc) l r
pure $ sub1 eret e err3 <- try_ $ expectEqualQ eloc epi fpi
pure (concat $ [err1, err2, err3] ++ exs, sub1 eret e)
where where
lookupArm : Loc -> TagVal -> CaseEnumArms d n -> lookupArm : Loc -> TagVal -> CaseEnumArms d n ->
Eff EqualInner (Term d n) Eff EqualInner (Term d n)
@ -465,18 +480,21 @@ parameters (defs : Definitions)
compare0' ctx (CaseNat epi epi' e eret ezer esuc eloc) compare0' ctx (CaseNat epi epi' e eret ezer esuc eloc)
(CaseNat fpi fpi' f fret fzer fsuc floc) ne nf = (CaseNat fpi fpi' f fret fzer fsuc floc) ne nf =
local_ Equal $ do local_ Equal $ do
ety <- compare0 ctx e f (err1, ety) <- compare0Inner ctx e f
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term err2 <- try_ $
compare0 ctx compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
(sub1 eret (Ann (Zero ezer.loc) (Nat ezer.loc) ezer.loc)) err3 <- try_ $
ezer fzer Term.compare0 ctx
(sub1 eret (Ann (Zero ezer.loc) (Nat ezer.loc) ezer.loc))
ezer fzer
let [< p, ih] = esuc.names let [< p, ih] = esuc.names
compare0 err4 <- try_ $
(extendTyN [< (epi, p, Nat p.loc), (epi', ih, eret.term)] ctx) Term.compare0
(substCaseSuccRet esuc.names eret) esuc.term fsuc.term (extendTyN [< (epi, p, Nat p.loc), (epi', ih, eret.term)] ctx)
expectEqualQ e.loc epi fpi (substCaseSuccRet esuc.names eret) esuc.term fsuc.term
expectEqualQ e.loc epi' fpi' err5 <- try_ $ expectEqualQ e.loc epi fpi
pure $ sub1 eret e err6 <- try_ $ expectEqualQ e.loc epi' fpi'
pure (concat [err1, err2, err3, err4, err5, err6], sub1 eret e)
compare0' ctx e@(CaseNat {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(CaseNat {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ e = f ⇒ [ρ. A] -- Ψ | Γ ⊢ e = f ⇒ [ρ. A]
@ -488,14 +506,16 @@ parameters (defs : Definitions)
compare0' ctx (CaseBox epi e eret ebody eloc) compare0' ctx (CaseBox epi e eret ebody eloc)
(CaseBox fpi f fret fbody floc) ne nf = (CaseBox fpi f fret fbody floc) ne nf =
local_ Equal $ do local_ Equal $ do
ety <- compare0 ctx e f (err1, ety) <- compare0Inner ctx e f
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term err2 <- try_ $
compareType (extendTy Zero eret.name ety ctx) eret.term fret.term
(q, ty) <- expectBOX defs ctx eloc ety (q, ty) <- expectBOX defs ctx eloc ety
compare0 (extendTy (epi * q) ebody.name ty ctx) err3 <- try_ $
(substCaseBoxRet ebody.name ety eret) Term.compare0 (extendTy (epi * q) ebody.name ty ctx)
ebody.term fbody.term (substCaseBoxRet ebody.name ety eret)
expectEqualQ eloc epi fpi ebody.term fbody.term
pure $ sub1 eret e err4 <- try_ $ expectEqualQ eloc epi fpi
pure (concat [err1, err2, err3, err4], sub1 eret e)
compare0' ctx e@(CaseBox {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(CaseBox {}) f _ _ = clashE e.loc ctx e f
-- all dim apps replaced with ends by whnf -- all dim apps replaced with ends by whnf
@ -509,8 +529,8 @@ parameters (defs : Definitions)
-- and similar for :> and A -- and similar for :> and A
compare0' ctx (Ann s a _) (Ann t b _) _ _ = do compare0' ctx (Ann s a _) (Ann t b _) _ _ = do
ty <- bigger a b ty <- bigger a b
Term.compare0 ctx ty s t err <- try_ $ Term.compare0 ctx ty s t
pure ty pure (err, ty)
-- Ψ | Γ ⊢ Ap₁/𝑖 <: Bp₂/𝑖 -- Ψ | Γ ⊢ Ap₁/𝑖 <: Bp₂/𝑖
-- Ψ | Γ ⊢ Aq₁/𝑖 <: Bq₂/𝑖 -- Ψ | Γ ⊢ Aq₁/𝑖 <: Bq₂/𝑖
@ -522,11 +542,11 @@ parameters (defs : Definitions)
(Coe ty2 p2 q2 val2 _) ne nf = do (Coe ty2 p2 q2 val2 _) ne nf = do
let ty1p = dsub1 ty1 p1; ty2p = dsub1 ty2 p2 let ty1p = dsub1 ty1 p1; ty2p = dsub1 ty2 p2
ty1q = dsub1 ty1 q1; ty2q = dsub1 ty2 q2 ty1q = dsub1 ty1 q1; ty2q = dsub1 ty2 q2
compareType ctx ty1p ty2p err1 <- try_ $ compareType ctx ty1p ty2p
compareType ctx ty1q ty2q err2 <- try_ $ compareType ctx ty1q ty2q
(ty_p, ty_q) <- bigger (ty1p, ty1q) (ty2p, ty2q) (ty_p, ty_q) <- bigger (ty1p, ty1q) (ty2p, ty2q)
Term.compare0 ctx ty_p val1 val2 err3 <- try_ $ Term.compare0 ctx ty_p val1 val2
pure ty_q pure (concat [err1, err2, err3], ty_q)
compare0' ctx e@(Coe {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(Coe {}) f _ _ = clashE e.loc ctx e f
-- (no neutral compositions in a closed dctx) -- (no neutral compositions in a closed dctx)
@ -538,14 +558,16 @@ parameters (defs : Definitions)
compare0' ctx (TypeCase ty1 ret1 arms1 def1 eloc) compare0' ctx (TypeCase ty1 ret1 arms1 def1 eloc)
(TypeCase ty2 ret2 arms2 def2 floc) ne _ = (TypeCase ty2 ret2 arms2 def2 floc) ne _ =
local_ Equal $ do local_ Equal $ do
ety <- compare0 ctx ty1 ty2 -- try
u <- expectTYPE defs ctx eloc ety (err1, ety) <- compare0Inner ctx ty1 ty2
compareType ctx ret1 ret2 u <- expectTYPE defs ctx eloc ety
compareType ctx def1 def2 err2 <- try_ $ compareType ctx ret1 ret2
for_ allKinds $ \k => err3 <- try_ $ compareType ctx def1 def2
compareArm ctx k ret1 u exs <- for allKinds $ \k =>
(lookupPrecise k arms1) (lookupPrecise k arms2) def1 try_ $
pure ret1 compareArm ctx k ret1 u
(lookupPrecise k arms1) (lookupPrecise k arms2) def1
pure (concat $ [err1, err2, err3] ++ exs, ret1)
compare0' ctx e@(TypeCase {}) f _ _ = clashE e.loc ctx e f compare0' ctx e@(TypeCase {}) f _ _ = clashE e.loc ctx e f
-- Ψ | Γ ⊢ s <: f ⇐ A -- Ψ | Γ ⊢ s <: f ⇐ A
@ -553,10 +575,12 @@ parameters (defs : Definitions)
-- Ψ | Γ ⊢ (s ∷ A) <: f ⇒ A -- Ψ | Γ ⊢ (s ∷ A) <: f ⇒ A
-- --
-- and vice versa -- and vice versa
compare0' ctx (Ann s a _) f _ _ = compare0' ctx (Ann s a _) f _ _ = do
Term.compare0 ctx a s (E f) $> a err <- try_ $ Term.compare0 ctx a s (E f)
compare0' ctx e (Ann t b _) _ _ = pure (err, a)
Term.compare0 ctx b (E e) t $> b compare0' ctx e (Ann t b _) _ _ = do
err <- try_ $ Term.compare0 ctx b (E e) t
pure (err, b)
compare0' ctx e@(Ann {}) f _ _ = compare0' ctx e@(Ann {}) f _ _ =
clashE e.loc ctx e f clashE e.loc ctx e f