unfold definitions in equality checking, plus cleanup

This commit is contained in:
rhiannon morris 2022-08-23 05:43:23 +02:00
parent 44778825c2
commit 8a55cc9581
4 changed files with 137 additions and 81 deletions

View file

@ -2,52 +2,71 @@ module Quox.Equal
import public Quox.Syntax import public Quox.Syntax
import public Quox.Definition import public Quox.Definition
import Control.Monad.Either import public Quox.Typing
import Generics.Derive import public Control.Monad.Either
import public Control.Monad.Reader
import Data.Maybe
%default total
%language ElabReflection
public export
data Mode = Equal | Sub
%runElab derive "Mode" [Generic, Meta, Eq, Ord, DecEq, Show]
public export
data Error
= ClashT Mode (Term d n) (Term d n)
| ClashU Mode Universe Universe
| ClashQ Qty Qty
private %inline private %inline
ClashE : Mode -> Elim d n -> Elim d n -> Error ClashE : EqMode -> Elim d n -> Elim d n -> Error
ClashE mode = ClashT mode `on` E ClashE mode = ClashT mode `on` E
parameters {auto _ : MonadError Error m} public export
CanEq : (Type -> Type) -> Type
CanEq m = (MonadError Error m, MonadReader Definitions m)
parameters {auto _ : CanEq m}
private %inline private %inline
clashT : Mode -> Term d n -> Term d n -> m a clashT : EqMode -> Term d n -> Term d n -> m a
clashT mode = throwError .: ClashT mode clashT mode = throwError .: ClashT mode
private %inline private %inline
clashE : Mode -> Elim d n -> Elim d n -> m a clashE : EqMode -> Elim d n -> Elim d n -> m a
clashE mode = throwError .: ClashE mode clashE mode = throwError .: ClashE mode
private %inline
defE : Name -> m (Maybe (Elim d n))
defE x = asks $ \defs => do
g <- lookup x defs
pure $ (!g.term).def :# g.type.def
private %inline
defT : Name -> m (Maybe (Term d n))
defT x = map E <$> defE x
export %inline
compareU' : EqMode -> Universe -> Universe -> Bool
compareU' = \case Equal => (==); Sub => (<=)
export %inline
compareU : EqMode -> Universe -> Universe -> m ()
compareU mode k l = unless (compareU' mode k l) $
throwError $ ClashU mode k l
mutual mutual
private covering private covering
compareTN' : Mode -> compareTN' : EqMode ->
(s, t : Term 0 n) -> (s, t : Term 0 n) ->
(0 _ : NotRedexT s) -> (0 _ : NotRedexT t) -> m () (0 _ : NotRedexT s) -> (0 _ : NotRedexT t) -> m ()
compareTN' mode (TYPE k) (TYPE l) _ _ = compareTN' mode (E e) (E f) ps pt = compareE0 mode e f
case mode of -- if either term is a def, try to unfold it
Equal => unless (k == l) $ throwError $ ClashU Equal k l compareTN' mode s@(E (F x)) t ps pt = do
Sub => unless (k <= l) $ throwError $ ClashU Sub k l Just s' <- defT x | Nothing => clashT mode s t
compareT0 mode s' t
compareTN' mode s t@(E (F y)) ps pt = do
Just t' <- defT y | Nothing => clashT mode s t
compareT0 mode s t'
compareTN' mode s@(E _) t _ _ = clashT mode s t
compareTN' mode (TYPE k) (TYPE l) _ _ = compareU mode k l
compareTN' mode s@(TYPE _) t _ _ = clashT mode s t compareTN' mode s@(TYPE _) t _ _ = clashT mode s t
compareTN' mode (Pi qty1 _ arg1 res1) (Pi qty2 _ arg2 res2) _ _ = do compareTN' mode (Pi qty1 _ arg1 res1) (Pi qty2 _ arg2 res2) _ _ = do
-- [todo] this should probably always be ==, right..?
unless (qty1 == qty2) $ throwError $ ClashQ qty1 qty2 unless (qty1 == qty2) $ throwError $ ClashQ qty1 qty2
compareT0 mode arg2 arg1 -- reversed for contravariant Sub compareT0 mode arg2 arg1 -- reversed for contravariant domain
compareTS0 mode res1 res2 compareTS0 mode res1 res2
compareTN' mode s@(Pi {}) t _ _ = clashT mode s t compareTN' mode s@(Pi {}) t _ _ = clashT mode s t
@ -56,20 +75,25 @@ parameters {auto _ : MonadError Error m}
compareTS0 Equal body1 body2 compareTS0 Equal body1 body2
compareTN' mode s@(Lam {}) t _ _ = clashT mode s t compareTN' mode s@(Lam {}) t _ _ = clashT mode s t
compareTN' mode (E e) (E f) ps pt = compareE0 mode e f
compareTN' mode s@(E _) t _ _ = clashT mode s t
compareTN' _ (CloT {}) _ ps _ = void $ ps IsCloT compareTN' _ (CloT {}) _ ps _ = void $ ps IsCloT
compareTN' _ (DCloT {}) _ ps _ = void $ ps IsDCloT compareTN' _ (DCloT {}) _ ps _ = void $ ps IsDCloT
private covering private covering
compareEN' : Mode -> compareEN' : EqMode ->
(e, f : Elim 0 n) -> (e, f : Elim 0 n) ->
(0 _ : NotRedexE e) -> (0 _ : NotRedexE f) -> m () (0 _ : NotRedexE e) -> (0 _ : NotRedexE f) -> m ()
compareEN' mode e@(F x) f@(F y) _ _ = compareEN' mode e@(F x) f@(F y) _ _ =
unless (x == y) $ clashE mode e f if x == y then pure () else
compareEN' mode e@(F _) f _ _ = clashE mode e f case (!(defE x), !(defE y)) of
(Nothing, Nothing) => clashE mode e f
(s', t') => compareE0 mode (fromMaybe e s') (fromMaybe f t')
compareEN' mode e@(F x) f _ _ = do
Just e' <- defE x | Nothing => clashE mode e f
compareE0 mode e' f
compareEN' mode e f@(F y) _ _ = do
Just f' <- defE y | Nothing => clashE mode e f
compareE0 mode e f'
compareEN' mode e@(B i) f@(B j) _ _ = compareEN' mode e@(B i) f@(B j) _ _ =
unless (i == j) $ clashE mode e f unless (i == j) $ clashE mode e f
@ -93,35 +117,35 @@ parameters {auto _ : MonadError Error m}
private covering %inline private covering %inline
compareTN : Mode -> NonRedexTerm 0 n -> NonRedexTerm 0 n -> m () compareTN : EqMode -> NonRedexTerm 0 n -> NonRedexTerm 0 n -> m ()
compareTN mode s t = compareTN' mode s.fst t.fst s.snd t.snd compareTN mode s t = compareTN' mode s.fst t.fst s.snd t.snd
private covering %inline private covering %inline
compareEN : Mode -> NonRedexElim 0 n -> NonRedexElim 0 n -> m () compareEN : EqMode -> NonRedexElim 0 n -> NonRedexElim 0 n -> m ()
compareEN mode e f = compareEN' mode e.fst f.fst e.snd f.snd compareEN mode e f = compareEN' mode e.fst f.fst e.snd f.snd
export covering %inline export covering %inline
compareT : Mode -> DimEq d -> Term d n -> Term d n -> m () compareT : EqMode -> DimEq d -> Term d n -> Term d n -> m ()
compareT mode eqs s t = compareT mode eqs s t =
for_ (splits eqs) $ \th => compareT0 mode (s /// th) (t /// th) for_ (splits eqs) $ \th => compareT0 mode (s /// th) (t /// th)
export covering %inline export covering %inline
compareE : Mode -> DimEq d -> Elim d n -> Elim d n -> m () compareE : EqMode -> DimEq d -> Elim d n -> Elim d n -> m ()
compareE mode eqs e f = compareE mode eqs e f =
for_ (splits eqs) $ \th => compareE0 mode (e /// th) (f /// th) for_ (splits eqs) $ \th => compareE0 mode (e /// th) (f /// th)
export covering %inline export covering %inline
compareT0 : Mode -> Term 0 n -> Term 0 n -> m () compareT0 : EqMode -> Term 0 n -> Term 0 n -> m ()
compareT0 mode s t = compareTN mode (whnfT s) (whnfT t) compareT0 mode s t = compareTN mode (whnfT s) (whnfT t)
export covering %inline export covering %inline
compareE0 : Mode -> Elim 0 n -> Elim 0 n -> m () compareE0 : EqMode -> Elim 0 n -> Elim 0 n -> m ()
compareE0 mode e f = compareEN mode (whnfE e) (whnfE f) compareE0 mode e f = compareEN mode (whnfE e) (whnfE f)
export covering %inline export covering %inline
compareTS0 : Mode -> ScopeTerm 0 n -> ScopeTerm 0 n -> m () compareTS0 : EqMode -> ScopeTerm 0 n -> ScopeTerm 0 n -> m ()
compareTS0 mode (TUnused body1) (TUnused body2) = compareTS0 mode (TUnused body1) (TUnused body2) =
compareT0 mode body1 body2 compareT0 mode body1 body2
compareTS0 mode body1 body2 = compareTS0 mode body1 body2 =

View file

@ -2,9 +2,9 @@ module Quox.Typechecker
import public Quox.Syntax import public Quox.Syntax
import public Quox.Typing import public Quox.Typing
import Control.Monad.Either import public Quox.Equal
import public Control.Monad.Either
%hide Equal.Error
%default total %default total
@ -27,7 +27,7 @@ private %inline
expectEqualQ : MonadError Error m => expectEqualQ : MonadError Error m =>
(expect, actual : Qty) -> m () (expect, actual : Qty) -> m ()
expectEqualQ pi rh = expectEqualQ pi rh =
unless (pi == rh) $ throwError $ EqualError $ ClashQ pi rh unless (pi == rh) $ throwError $ ClashQ pi rh
private %inline private %inline
@ -40,12 +40,6 @@ tail : TyContext d (S n) -> TyContext d n
tail = {tctx $= tail, qctx $= tail} tail = {tctx $= tail, qctx $= tail}
private %inline
globalSubjQty : Global -> Qty
globalSubjQty (MkGlobal {qty = Zero, _}) = Zero
globalSubjQty (MkGlobal {qty = Any, _}) = One
private %inline private %inline
weakI : InferResult d n -> InferResult d (S n) weakI : InferResult d n -> InferResult d (S n)
weakI = {type $= weakT, qout $= (:< zero)} weakI = {type $= weakT, qout $= (:< zero)}
@ -66,30 +60,35 @@ subjMult sg qty =
else Element One %search else Element One %search
mutual public export
CanTC : (Type -> Type) -> Type
CanTC m = (MonadError Error m, MonadReader Definitions m)
parameters {auto _ : CanTC m}
mutual
-- [todo] it seems like the options here for dealing with substitutions are -- [todo] it seems like the options here for dealing with substitutions are
-- to either push them or parametrise the whole typechecker over ambient -- to either push them or parametrise the whole typechecker over ambient
-- substitutions. both of them seem like the same amount of work for the -- substitutions. both of them seem like the same amount of work for the
-- computer but pushing is less work for the me -- computer but pushing is less work for the me
export covering %inline export covering %inline
check : MonadError Error m => {d, n : Nat} -> check : {d, n : Nat} ->
(ctx : TyContext d n) -> (sg : Qty) -> {auto 0 sgs : IsSubj sg} -> (ctx : TyContext d n) -> (sg : Qty) -> (0 _ : IsSubj sg) =>
(subj : Term d n) -> (ty : Term d n) -> (subj : Term d n) -> (ty : Term d n) ->
m (CheckResult n) m (CheckResult n)
check ctx sg subj ty = check' ctx sg (pushSubstsT subj) ty check ctx sg subj ty = check' ctx sg (pushSubstsT subj) ty
export covering %inline export covering %inline
infer : MonadError Error m => {d, n : Nat} -> infer : {d, n : Nat} ->
(ctx : TyContext d n) -> (sg : Qty) -> {auto 0 sgs : IsSubj sg} -> (ctx : TyContext d n) -> (sg : Qty) -> (0 _ : IsSubj sg) =>
(subj : Elim d n) -> (subj : Elim d n) ->
m (InferResult d n) m (InferResult d n)
infer ctx sg subj = infer' ctx sg (pushSubstsE subj) infer ctx sg subj = infer' ctx sg (pushSubstsE subj)
export covering export covering
check' : MonadError Error m => {d, n : Nat} -> check' : {d, n : Nat} ->
(ctx : TyContext d n) -> (sg : Qty) -> {auto 0 sgs : IsSubj sg} -> (ctx : TyContext d n) -> (sg : Qty) -> (0 _ : IsSubj sg) =>
(subj : NotCloTerm d n) -> (ty : Term d n) -> (subj : NotCloTerm d n) -> (ty : Term d n) ->
m (CheckResult n) m (CheckResult n)
@ -99,7 +98,6 @@ mutual
unless (l < l') $ throwError $ BadUniverse l l' unless (l < l') $ throwError $ BadUniverse l l'
pure zero pure zero
-- [todo] factor this stuff out
check' ctx sg (Element (Pi qty x arg res) _) ty = do check' ctx sg (Element (Pi qty x arg res) _) ty = do
l <- expectTYPE ty l <- expectTYPE ty
expectEqualQ zero sg expectEqualQ zero sg
@ -119,21 +117,19 @@ mutual
check' ctx sg (Element (E e) _) ty = do check' ctx sg (Element (E e) _) ty = do
infres <- infer ctx sg e infres <- infer ctx sg e
ignore $ check ctx zero ty (TYPE UAny) ignore $ check ctx zero ty (TYPE UAny)
either (throwError . EqualError) pure $ infres.type `subT` ty subT infres.type ty
pure infres.qout pure infres.qout
export covering export covering
infer' : MonadError Error m => {d, n : Nat} -> infer' : {d, n : Nat} ->
(ctx : TyContext d n) -> (sg : Qty) -> {auto 0 sgs : IsSubj sg} -> (ctx : TyContext d n) -> (sg : Qty) -> (0 _ : IsSubj sg) =>
(subj : NotCloElim d n) -> (subj : NotCloElim d n) ->
m (InferResult d n) m (InferResult d n)
infer' ctx sg (Element (F x) _) = infer' ctx sg (Element (F x) _) = do
case lookup x ctx.globals of Just g <- asks $ lookup x | Nothing => throwError $ NotInScope x
Just g => do when (isZero g) $ expectEqualQ sg Zero
expectEqualQ (globalSubjQty g) sg pure $ InfRes {type = g.type.def, qout = zero}
pure $ InfRes {type = g.type, qout = zero}
Nothing => throwError $ NotInScope x
infer' ctx sg (Element (B i) _) = infer' ctx sg (Element (B i) _) =
pure $ lookupBound sg i ctx pure $ lookupBound sg i ctx

View file

@ -2,13 +2,17 @@ module Quox.Typing
import public Quox.Syntax import public Quox.Syntax
import public Quox.Context import public Quox.Context
import public Quox.Equal
import public Quox.Definition import public Quox.Definition
import Data.Nat import Data.Nat
import public Data.SortedMap import public Data.SortedMap
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Generics.Derive
%hide TT.Name
%default total
%language ElabReflection
%default total %default total
@ -35,7 +39,6 @@ QOutput = QContext
public export public export
record TyContext (d, n : Nat) where record TyContext (d, n : Nat) where
constructor MkTyContext constructor MkTyContext
globals : Globals
dctx : DContext d dctx : DContext d
tctx : TContext d n tctx : TContext d n
qctx : QContext n qctx : QContext n
@ -88,11 +91,18 @@ record InferResult d n where
qout : QOutput n qout : QOutput n
public export
data EqMode = Equal | Sub
%runElab derive "EqMode" [Generic, Meta, Eq, Ord, DecEq, Show]
public export public export
data Error data Error
= NotInScope Name = ExpectedTYPE (Term d n)
| ExpectedTYPE (Term d n)
| ExpectedPi (Term d n) | ExpectedPi (Term d n)
| BadUniverse Universe Universe | BadUniverse Universe Universe
| EqualError (Equal.Error)
%hide Equal.Error | ClashT EqMode (Term d n) (Term d n)
| ClashU EqMode Universe Universe
| ClashQ Qty Qty
| NotInScope Name

View file

@ -5,30 +5,45 @@ import Quox.Pretty
import TAP import TAP
export export
ToInfo Equal.Error where ToInfo Error where
toInfo (NotInScope x) =
[("type", "NotInScope"),
("name", show x)]
toInfo (ExpectedTYPE t) =
[("type", "ExpectedTYPE"),
("got", prettyStr True t)]
toInfo (ExpectedPi t) =
[("type", "ExpectedPi"),
("got", prettyStr True t)]
toInfo (BadUniverse k l) =
[("type", "BadUniverse"),
("low", show k),
("high", show l)]
toInfo (ClashT mode s t) = toInfo (ClashT mode s t) =
[("clash", "term"), [("type", "ClashT"),
("mode", show mode), ("mode", show mode),
("left", prettyStr True s), ("left", prettyStr True s),
("right", prettyStr True t)] ("right", prettyStr True t)]
toInfo (ClashU mode k l) = toInfo (ClashU mode k l) =
[("clash", "universe"), [("type", "ClashU"),
("mode", show mode), ("mode", show mode),
("left", prettyStr True k), ("left", prettyStr True k),
("right", prettyStr True l)] ("right", prettyStr True l)]
toInfo (ClashQ pi rh) = toInfo (ClashQ pi rh) =
[("clash", "quantity"), [("type", "ClashQ"),
("left", prettyStr True pi), ("left", prettyStr True pi),
("right", prettyStr True rh)] ("right", prettyStr True rh)]
M = Either Equal.Error M = ReaderT Definitions (Either Error)
testEq : String -> Lazy (M ()) -> Test parameters (label : String) (act : Lazy (M ()))
testEq = test {default empty globals : Definitions}
testEq : Test
testEq = test label $ runReaderT globals act
testNeq : String -> Lazy (M ()) -> Test testNeq : Test
testNeq label = testThrows label $ const True testNeq = testThrows label (const True) $ runReaderT globals act
subT : {default 0 d, n : Nat} -> Term d n -> Term d n -> M () subT : {default 0 d, n : Nat} -> Term d n -> Term d n -> M ()
@ -144,11 +159,22 @@ tests = "equality & subtyping" :- [
todo "term d-closure", todo "term d-closure",
"free var" :- [ "free var" :-
let au_bu = fromList
[("A", MkDef Any (TYPE (U 1)) (TYPE (U 0))),
("B", MkDef Any (TYPE (U 1)) (TYPE (U 0)))]
au_ba = fromList
[("A", MkDef Any (TYPE (U 1)) (TYPE (U 0))),
("B", MkDef Any (TYPE (U 1)) (FT "A"))]
in [
testEq "A ≡ A" $ testEq "A ≡ A" $
equalE (F "A") (F "A"), equalE (F "A") (F "A"),
testNeq "A ≢ B" $ testNeq "A ≢ B" $
equalE (F "A") (F "B"), equalE (F "A") (F "B"),
testEq "A ≔ ★₀, B ≔ ★₀ ⊢ A ≡ B" {globals = au_bu} $
equalE (F "A") (F "B"),
testEq "A ≔ ★₀, B ≔ A ⊢ A ≡ B" {globals = au_ba} $
equalE (F "A") (F "B"),
testEq "A <: A" $ testEq "A <: A" $
subE (F "A") (F "A"), subE (F "A") (F "A"),
testNeq "A ≮: B" $ testNeq "A ≮: B" $