From 98fa8d9967d57c9fc9373e5d63db3a16794eec9e Mon Sep 17 00:00:00 2001 From: rhiannon morris Date: Sun, 8 Jan 2023 14:59:25 +0100 Subject: [PATCH] mode eq mode into a reader --- lib/Quox/Equal.idr | 220 ++++++++++++++++++++++++--------------------- 1 file changed, 118 insertions(+), 102 deletions(-) diff --git a/lib/Quox/Equal.idr b/lib/Quox/Equal.idr index a74fa6e..8a70149 100644 --- a/lib/Quox/Equal.idr +++ b/lib/Quox/Equal.idr @@ -12,24 +12,31 @@ private %inline ClashE : EqMode -> Elim d n -> Elim d n -> Error ClashE mode = ClashT mode `on` E + public export -CanEq : (Type -> Type) -> Type -CanEq m = (MonadError Error m, MonadReader Definitions m) +record Env where + constructor MakeEnv + defs : Definitions + mode : EqMode -parameters {auto _ : CanEq m} +parameters {auto _ : MonadError Error m} {auto _ : MonadReader Env m} private %inline - clashT : EqMode -> Term d n -> Term d n -> m a - clashT mode = throwError .: ClashT mode + mode : m EqMode + mode = asks mode private %inline - clashE : EqMode -> Elim d n -> Elim d n -> m a - clashE mode = throwError .: ClashE mode + clashT : Term d n -> Term d n -> m a + clashT s t = throwError $ ClashT !mode s t + + private %inline + clashE : Elim d n -> Elim d n -> m a + clashE e f = throwError $ ClashE !mode e f private %inline defE : Name -> m (Maybe (Elim d n)) - defE x = asks $ \defs => do - g <- lookup x defs + defE x = asks $ \env => do + g <- lookup x env.defs pure $ (!g.term).def :# g.type.def private %inline @@ -37,150 +44,159 @@ parameters {auto _ : CanEq m} defT x = map E <$> defE x export %inline - compareU' : EqMode -> Universe -> Universe -> Bool - compareU' = \case Equal => (==); Sub => (<=) + compareU' : Universe -> Universe -> m Bool + compareU' i j = pure $ + case !mode of Equal => i == j; Sub => i <= j export %inline - compareU : EqMode -> Universe -> Universe -> m () - compareU mode k l = unless (compareU' mode k l) $ - throwError $ ClashU mode k l + compareU : Universe -> Universe -> m () + compareU k l = unless !(compareU' k l) $ + throwError $ ClashU !mode k l mutual private covering - compareTN' : EqMode -> - (s, t : Term 0 n) -> + compareTN' : (s, t : Term 0 n) -> (0 _ : NotRedexT s) -> (0 _ : NotRedexT t) -> m () - compareTN' mode (E e) (E f) ps pt = compareE0 mode e f + compareTN' (E e) (E f) ps pt = compareE0 e f -- if either term is a def, try to unfold it - compareTN' mode s@(E (F x)) t ps pt = do - Just s' <- defT x | Nothing => clashT mode s t - compareT0 mode s' t - compareTN' mode s t@(E (F y)) ps pt = do - Just t' <- defT y | Nothing => clashT mode s t - compareT0 mode s t' - compareTN' mode s@(E _) t _ _ = clashT mode s t + compareTN' s@(E (F x)) t ps pt = do + Just s' <- defT x | Nothing => clashT s t + compareT0 s' t + compareTN' s t@(E (F y)) ps pt = do + Just t' <- defT y | Nothing => clashT s t + compareT0 s t' + compareTN' s@(E _) t _ _ = clashT s t - compareTN' mode (TYPE k) (TYPE l) _ _ = compareU mode k l - compareTN' mode s@(TYPE _) t _ _ = clashT mode s t + compareTN' (TYPE k) (TYPE l) _ _ = compareU k l + compareTN' s@(TYPE _) t _ _ = clashT s t - compareTN' mode (Pi qty1 _ arg1 res1) (Pi qty2 _ arg2 res2) _ _ = do + compareTN' (Pi qty1 _ arg1 res1) (Pi qty2 _ arg2 res2) _ _ = do unless (qty1 == qty2) $ throwError $ ClashQ qty1 qty2 - compareT0 mode arg2 arg1 -- reversed for contravariant domain - compareTS0 mode res1 res2 - compareTN' mode s@(Pi {}) t _ _ = clashT mode s t + compareT0 arg2 arg1 -- reversed for contravariant domain + compareTS0 res1 res2 + compareTN' s@(Pi {}) t _ _ = clashT s t -- [todo] eta - compareTN' _ (Lam _ body1) (Lam _ body2) _ _ = - compareTS0 Equal body1 body2 - compareTN' mode s@(Lam {}) t _ _ = clashT mode s t + compareTN' (Lam _ body1) (Lam _ body2) _ _ = + local {mode := Equal} $ compareTS0 body1 body2 + compareTN' s@(Lam {}) t _ _ = clashT s t - compareTN' _ (CloT {}) _ ps _ = void $ ps IsCloT - compareTN' _ (DCloT {}) _ ps _ = void $ ps IsDCloT + compareTN' (CloT {}) _ ps _ = void $ ps IsCloT + compareTN' (DCloT {}) _ ps _ = void $ ps IsDCloT private covering - compareEN' : EqMode -> - (e, f : Elim 0 n) -> + compareEN' : (e, f : Elim 0 n) -> (0 _ : NotRedexE e) -> (0 _ : NotRedexE f) -> m () - compareEN' mode e@(F x) f@(F y) _ _ = + compareEN' e@(F x) f@(F y) _ _ = if x == y then pure () else case (!(defE x), !(defE y)) of - (Nothing, Nothing) => clashE mode e f - (s', t') => compareE0 mode (fromMaybe e s') (fromMaybe f t') - compareEN' mode e@(F x) f _ _ = do - Just e' <- defE x | Nothing => clashE mode e f - compareE0 mode e' f - compareEN' mode e f@(F y) _ _ = do - Just f' <- defE y | Nothing => clashE mode e f - compareE0 mode e f' + (Nothing, Nothing) => clashE e f + (s', t') => compareE0 (fromMaybe e s') (fromMaybe f t') + compareEN' e@(F x) f _ _ = do + Just e' <- defE x | Nothing => clashE e f + compareE0 e' f + compareEN' e f@(F y) _ _ = do + Just f' <- defE y | Nothing => clashE e f + compareE0 e f' - compareEN' mode e@(B i) f@(B j) _ _ = - unless (i == j) $ clashE mode e f - compareEN' mode e@(B _) f _ _ = clashE mode e f + compareEN' e@(B i) f@(B j) _ _ = + unless (i == j) $ clashE e f + compareEN' e@(B _) f _ _ = clashE e f -- [todo] tracking variance of functions? maybe??? -- probably not - compareEN' _ (fun1 :@ arg1) (fun2 :@ arg2) _ _ = do - compareE0 Equal fun1 fun2 - compareT0 Equal arg1 arg2 - compareEN' mode e@(_ :@ _) f _ _ = clashE mode e f + compareEN' (fun1 :@ arg1) (fun2 :@ arg2) _ _ = + local {mode := Equal} $ do + compareE0 fun1 fun2 + compareT0 arg1 arg2 + compareEN' e@(_ :@ _) f _ _ = clashE e f -- [todo] is always checking the types are equal correct? - compareEN' mode (tm1 :# ty1) (tm2 :# ty2) _ _ = do - compareT0 mode tm1 tm2 - compareT0 Equal ty1 ty2 - compareEN' mode e@(_ :# _) f _ _ = clashE mode e f + compareEN' (tm1 :# ty1) (tm2 :# ty2) _ _ = do + compareT0 tm1 tm2 + local {mode := Equal} $ compareT0 ty1 ty2 + compareEN' e@(_ :# _) f _ _ = clashE e f - compareEN' _ (CloE {}) _ pe _ = void $ pe IsCloE - compareEN' _ (DCloE {}) _ pe _ = void $ pe IsDCloE + compareEN' (CloE {}) _ pe _ = void $ pe IsCloE + compareEN' (DCloE {}) _ pe _ = void $ pe IsDCloE private covering %inline - compareTN : EqMode -> NonRedexTerm 0 n -> NonRedexTerm 0 n -> m () - compareTN mode s t = compareTN' mode s.fst t.fst s.snd t.snd + compareTN : NonRedexTerm 0 n -> NonRedexTerm 0 n -> m () + compareTN s t = compareTN' s.fst t.fst s.snd t.snd private covering %inline - compareEN : EqMode -> NonRedexElim 0 n -> NonRedexElim 0 n -> m () - compareEN mode e f = compareEN' mode e.fst f.fst e.snd f.snd + compareEN : NonRedexElim 0 n -> NonRedexElim 0 n -> m () + compareEN e f = compareEN' e.fst f.fst e.snd f.snd export covering %inline - compareT : EqMode -> DimEq d -> Term d n -> Term d n -> m () - compareT mode eqs s t = - for_ (splits eqs) $ \th => compareT0 mode (s /// th) (t /// th) + compareT : DimEq d -> Term d n -> Term d n -> m () + compareT eqs s t = + for_ (splits eqs) $ \th => compareT0 (s /// th) (t /// th) export covering %inline - compareE : EqMode -> DimEq d -> Elim d n -> Elim d n -> m () - compareE mode eqs e f = - for_ (splits eqs) $ \th => compareE0 mode (e /// th) (f /// th) + compareE : DimEq d -> Elim d n -> Elim d n -> m () + compareE eqs e f = + for_ (splits eqs) $ \th => compareE0 (e /// th) (f /// th) export covering %inline - compareT0 : EqMode -> Term 0 n -> Term 0 n -> m () - compareT0 mode s t = compareTN mode (whnfT s) (whnfT t) + compareT0 : Term 0 n -> Term 0 n -> m () + compareT0 s t = compareTN (whnfT s) (whnfT t) export covering %inline - compareE0 : EqMode -> Elim 0 n -> Elim 0 n -> m () - compareE0 mode e f = compareEN mode (whnfE e) (whnfE f) + compareE0 : Elim 0 n -> Elim 0 n -> m () + compareE0 e f = compareEN (whnfE e) (whnfE f) export covering %inline - compareTS0 : EqMode -> ScopeTerm 0 n -> ScopeTerm 0 n -> m () - compareTS0 mode (TUnused body1) (TUnused body2) = - compareT0 mode body1 body2 - compareTS0 mode body1 body2 = - compareT0 mode (fromScopeTerm body1) (fromScopeTerm body2) + compareTS0 : ScopeTerm 0 n -> ScopeTerm 0 n -> m () + compareTS0 (TUnused body1) (TUnused body2) = + compareT0 body1 body2 + compareTS0 body1 body2 = + compareT0 (fromScopeTerm body1) (fromScopeTerm body2) - export covering %inline - equalTWith : DimEq d -> Term d n -> Term d n -> m () - equalTWith = compareT Equal +parameters {auto _ : MonadError Error m} {auto _ : MonadReader Definitions m} + private %inline + into : EqMode -> + (forall n. MonadError Error n => MonadReader Env n => + DimEq d -> a -> a -> n ()) -> + DimEq d -> a -> a -> m () + into mode f eqs a b = + runReaderT {m} (MakeEnv {mode, defs = !ask}) $ f eqs a b - export covering %inline - equalEWith : DimEq d -> Elim d n -> Elim d n -> m () - equalEWith = compareE Equal + export covering %inline + equalTWith : DimEq d -> Term d n -> Term d n -> m () + equalTWith = into Equal compareT - export covering %inline - subTWith : DimEq d -> Term d n -> Term d n -> m () - subTWith = compareT Sub + export covering %inline + equalEWith : DimEq d -> Elim d n -> Elim d n -> m () + equalEWith = into Equal compareE - export covering %inline - subEWith : DimEq d -> Elim d n -> Elim d n -> m () - subEWith = compareE Sub + export covering %inline + subTWith : DimEq d -> Term d n -> Term d n -> m () + subTWith = into Sub compareT + + export covering %inline + subEWith : DimEq d -> Elim d n -> Elim d n -> m () + subEWith = into Sub compareE - export covering %inline - equalT : {d : Nat} -> Term d n -> Term d n -> m () - equalT = equalTWith DimEq.new + export covering %inline + equalT : {d : Nat} -> Term d n -> Term d n -> m () + equalT = equalTWith DimEq.new - export covering %inline - equalE : {d : Nat} -> Elim d n -> Elim d n -> m () - equalE = equalEWith DimEq.new + export covering %inline + equalE : {d : Nat} -> Elim d n -> Elim d n -> m () + equalE = equalEWith DimEq.new - export covering %inline - subT : {d : Nat} -> Term d n -> Term d n -> m () - subT = subTWith DimEq.new + export covering %inline + subT : {d : Nat} -> Term d n -> Term d n -> m () + subT = subTWith DimEq.new - export covering %inline - subE : {d : Nat} -> Elim d n -> Elim d n -> m () - subE = subEWith DimEq.new + export covering %inline + subE : {d : Nat} -> Elim d n -> Elim d n -> m () + subE = subEWith DimEq.new