module Quox.Equal

import public Quox.Syntax
import public Quox.Definition
import public Quox.Typing
import public Control.Monad.Either
import public Control.Monad.Reader
import Data.Maybe


||| Builds a `ClashT` error out of two eliminations by wrapping each of
||| them in `E` first.
private %inline
ClashE : EqMode -> Elim q d n -> Elim q d n -> Error q
ClashE mode = ClashT mode `on` E


||| Read-only environment carried through a comparison.
public export
record Env where
  constructor MakeEnv
  ||| The comparison mode currently in force. Several clauses below
  ||| override it to `Equal` for their subterms via `local`.
  mode : EqMode

||| Monads from which the comparison `Env` can be read.
public export
0 HasEnv : (Type -> Type) -> Type
HasEnv = MonadReader Env

||| Everything a comparison needs from its monad: the error capability
||| `HasErr q m` together with access to the `Env`.
public export
0 CanEqual : (q : Type) -> (Type -> Type) -> Type
CanEqual q m = (HasErr q m, HasEnv m)


||| Reads the current comparison mode out of the environment.
private %inline
mode : HasEnv m => m EqMode
mode = asks mode

||| Aborts with a `ClashT` error for the two given terms, recording the
||| current mode.
private %inline
clashT : CanEqual q m => Term q d n -> Term q d n -> m a
clashT s t = throwError $ ClashT !mode s t

||| Aborts with a `ClashE` error for the two given eliminations, recording
||| the current mode.
private %inline
clashE : CanEqual q m => Elim q d n -> Elim q d n -> m a
clashE e f = throwError $ ClashE !mode e f


parameters {0 isGlobal : _} (defs : Definitions' q isGlobal)
  mutual
    namespace Term
      ||| Compares two terms, each accompanied by an erased proof that it
      ||| is not a redex with respect to `defs`.
      |||
      ||| Terms with matching head constructors are compared component by
      ||| component; any other combination is reported through `clashT`.
      export covering
      compareN' : CanEqual q m => Eq q =>
                  (s, t : Term q 0 n) ->
                  (0 _ : NotRedex defs s) -> (0 _ : NotRedex defs t) ->
                  m ()

      -- universe levels are checked by `expectModeU` under the current mode
      compareN' (TYPE k) (TYPE l) _ _ = expectModeU !mode k l
      compareN' s@(TYPE _) t _ _ = clashT s t

      compareN' (Pi qty1 _ arg1 res1) (Pi qty2 _ arg2 res2) _ _ = do
        expectEqualQ qty1 qty2
        compare0 arg2 arg1 -- reversed for contravariant domain
        compare0 res1 res2
      compareN' s@(Pi {}) t _ _ = clashT s t

      -- [todo] eta
      compareN' (Lam _ body1) (Lam _ body2) _ _ =
        local {mode := Equal} $ compare0 body1 body2
      compareN' s@(Lam {}) t _ _ = clashT s t

      compareN' (Sig _ fst1 snd1) (Sig _ fst2 snd2) _ _ = do
        compare0 fst1 fst2
        compare0 snd1 snd2
      compareN' s@(Sig {}) t _ _ = clashT s t

      compareN' (Pair fst1 snd1) (Pair fst2 snd2) _ _ =
        local {mode := Equal} $ do
          compare0 fst1 fst2
          compare0 snd1 snd2
      compareN' s@(Pair {}) t _ _ = clashT s t

      -- the type component keeps the current mode; the two endpoints are
      -- always compared in `Equal` mode
      compareN' (Eq _ ty1 l1 r1) (Eq _ ty2 l2 r2) _ _ = do
        compare0 ty1 ty2
        local {mode := Equal} $ do
          compare0 l1 l2
          compare0 r1 r2
      compareN' s@(Eq {}) t _ _ = clashT s t

      compareN' (DLam _ body1) (DLam _ body2) _ _ =
        local {mode := Equal} $ compare0 body1 body2
      compareN' s@(DLam {}) t _ _ = clashT s t

      -- embedded eliminations are handed to the `Elim` version of
      -- `compareN'`, with the non-redex proofs converted by `noOr2`
      compareN' (E e) (E f) ne nf = compareN' e f (noOr2 ne) (noOr2 nf)
      compareN' s@(E _) t _ _ = clashT s t

    namespace Elim
      ||| Compares two eliminations, each accompanied by an erased proof
      ||| that it is not a redex with respect to `defs`.
      |||
      ||| Eliminations with matching head constructors are compared
      ||| component by component; any other combination is reported
      ||| through `clashE`.
      export covering
      compareN' : CanEqual q m => Eq q =>
                  (e, f : Elim q 0 n) ->
                  (0 _ : NotRedex defs e) -> (0 _ : NotRedex defs f) ->
                  m ()

      compareN' e@(F x) f@(F y) _ _ =
        unless (x == y) $ clashE e f
      compareN' e@(F _) f _ _ = clashE e f

      compareN' e@(B i) f@(B j) _ _ =
        unless (i == j) $ clashE e f
      compareN' e@(B _) f _ _ = clashE e f

      -- [todo] tracking variance of functions? maybe???
      --        probably not
      compareN' (fun1 :@ arg1) (fun2 :@ arg2) _ _ =
        local {mode := Equal} $ do
          compare0 fun1 fun2
          compare0 arg1 arg2
      compareN' e@(_ :@ _) f _ _ = clashE e f

      compareN' (CasePair pi1 pair1 _ ret1 _ _ body1)
                (CasePair pi2 pair2 _ ret2 _ _ body2) _ _ =
        local {mode := Equal} $ do
          expectEqualQ pi1 pi2
          compare0 pair1 pair2
          compare0 ret1 ret2
          compare0 body1 body2
      compareN' e@(CasePair {}) f _ _ = clashE e f

      -- retain the mode unlike above because dimensions can't do
      -- anything that would mess up variance
      compareN' (fun1 :% dim1) (fun2 :% dim2) _ _ = do
        compare0 fun1 fun2
        expectEqualD dim1 dim2
      compareN' e@(_ :% _) f _ _ = clashE e f

      -- using the same mode for the type allows, e.g.
      --   A : ★₁ ≔ ★₀, B : ★₃ ≔ ★₂ ⊢ A <: B
      -- which, since A : ★₁ implies A : ★₃, should be fine
      compareN' (tm1 :# ty1) (tm2 :# ty2) _ _ = do
        compare0 tm1 tm2
        compare0 ty1 ty2
      compareN' e@(_ :# _) f _ _ = clashE e f


    namespace Term
      ||| Compares two terms that are packaged with their non-redex
      ||| proofs, by unpacking both and calling `compareN'`.
      export covering %inline
      compareN : CanEqual q m => Eq q =>
                 NonRedexTerm q 0 n defs -> NonRedexTerm q 0 n defs -> m ()
      compareN s t = compareN' s.fst t.fst s.snd t.snd

      ||| Compares two terms under a dimension context: for every element
      ||| `th` of `splits eqs`, applies `/// th` to both terms and compares
      ||| the results with `compare0`.
      export covering %inline
      compare : CanEqual q m => Eq q =>
                DimEq d -> Term q d n -> Term q d n -> m ()
      compare eqs s t =
        for_ (splits eqs) $ \th => compare0 (s /// th) (t /// th)

      ||| Compares two terms with no dimension variables in scope: both
      ||| are passed through `whnf defs` and then compared with `compareN`.
      export covering %inline
      compare0 : CanEqual q m => Eq q => Term q 0 n -> Term q 0 n -> m ()
      compare0 s t = compareN (whnf defs s) (whnf defs t)

    namespace Elim
      ||| Compares two eliminations that are packaged with their non-redex
      ||| proofs, by unpacking both and calling `compareN'`.
      -- NOTE(review): unlike `Term.compareN` this one carries no `export`
      -- modifier; confirm whether that is deliberate.
      covering %inline
      compareN : CanEqual q m => Eq q =>
                 NonRedexElim q 0 n defs -> NonRedexElim q 0 n defs -> m ()
      compareN e f = compareN' e.fst f.fst e.snd f.snd

      ||| Compares two eliminations under a dimension context: for every
      ||| element `th` of `splits eqs`, applies `/// th` to both sides and
      ||| compares the results with `compare0`.
      export covering %inline
      compare : CanEqual q m => Eq q =>
                DimEq d -> Elim q d n -> Elim q d n -> m ()
      compare eqs e f =
        for_ (splits eqs) $ \th => compare0 (e /// th) (f /// th)

      ||| Compares two eliminations with no dimension variables in scope:
      ||| both are passed through `whnf defs` and then compared with
      ||| `compareN`.
      export covering %inline
      compare0 : CanEqual q m => Eq q => Elim q 0 n -> Elim q 0 n -> m ()
      compare0 e f = compareN (whnf defs e) (whnf defs f)

    namespace ScopeTermN
      ||| Compares two scoped terms. When both are `TUnused` their bodies
      ||| are compared directly; otherwise the `.term` projections of the
      ||| two scopes are compared.
      export covering %inline
      compare0 : {s : Nat} -> CanEqual q m => Eq q =>
                 ScopeTermN s q 0 n -> ScopeTermN s q 0 n -> m ()
      compare0 (TUnused body0) (TUnused body1) = compare0 body0 body1
      compare0 body0 body1 = compare0 body0.term body1.term

    -- [todo] extend to multi-var scopes
    namespace DScopeTerm
      ||| Compares two dimension-scoped terms. When both are `DUnused`
      ||| their bodies are compared directly; otherwise the `.zero`
      ||| projections are compared, followed by the `.one` projections.
      export covering %inline
      compare0 : CanEqual q m => Eq q =>
                 DScopeTerm q 0 n -> DScopeTerm q 0 n -> m ()
      compare0 (DUnused body0) (DUnused body1) = compare0 body0 body1
      compare0 body0 body1 = do
        compare0 body0.zero body1.zero
        compare0 body0.one  body1.one


  namespace Term
    ||| Runs `compare` on two terms with the environment's mode set to
    ||| `Equal`. Failures are raised through `HasErr q m`.
    export covering %inline
    equal : HasErr q m => Eq q =>
            DimEq d -> Term q d n -> Term q d n -> m ()
    equal eqs s t {m} = runReaderT {m} (MakeEnv Equal) $ compare eqs s t

    ||| Runs `compare` on two terms with the environment's mode set to
    ||| `Sub`. Failures are raised through `HasErr q m`.
    -- NOTE(review): `equal` has no `HasDefs'` constraint and `defs` is
    -- already a parameter of this block, so the constraint looks redundant
    -- here; kept to leave the interface unchanged. Confirm before removing.
    export covering %inline
    sub : HasErr q m => HasDefs' q _ m => Eq q =>
          DimEq d -> Term q d n -> Term q d n -> m ()
    sub eqs s t {m} = runReaderT {m} (MakeEnv Sub) $ compare eqs s t

  namespace Elim
    ||| Runs `compare` on two eliminations with the environment's mode set
    ||| to `Equal`. Failures are raised through `HasErr q m`.
    export covering %inline
    equal : HasErr q m => Eq q =>
            DimEq d -> Elim q d n -> Elim q d n -> m ()
    equal eqs e f {m} = runReaderT {m} (MakeEnv Equal) $ compare eqs e f

    ||| Runs `compare` on two eliminations with the environment's mode set
    ||| to `Sub`. Failures are raised through `HasErr q m`.
    -- NOTE(review): same apparently redundant `HasDefs'` constraint as on
    -- the `Term` version of `sub`; kept to leave the interface unchanged.
    export covering %inline
    sub : HasErr q m => HasDefs' q _ m => Eq q =>
          DimEq d -> Elim q d n -> Elim q d n -> m ()
    sub eqs e f {m} = runReaderT {m} (MakeEnv Sub) $ compare eqs e f