patuni/Context.lhs
2024-06-23 18:03:32 +02:00

266 lines
8.8 KiB
Text

This module defines unification problems and metacontexts, together with
operations for working on them in the |Contextual| monad.

> module Context where
> import Control.Applicative
> import Control.Monad.Identity
> import Control.Monad.Except
> import Control.Monad.Reader
> import Control.Monad.State
> import Data.Bifunctor
> import Debug.Trace
> import Unbound.Generics.LocallyNameless
> import Unbound.Generics.LocallyNameless.Bind (Bind(..))
> import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)
> import GHC.Generics
> import qualified Data.Set as Set
> import Kit
> import Tm
> -- | What a metavariable entry carries: either a hole (printed with @?@)
> -- or a definition by a value (printed with @:=@).
> data Dec
>   = HOLE
>   | DEFN VAL
>   deriving stock (Show, Generic)
>   deriving anyclass (Alpha, Subst VAL)
> -- | A hole mentions nothing; a definition mentions whatever its value does.
> instance Occurs Dec where
>   occurrence xs dec = case dec of
>     HOLE   -> Nothing
>     DEFN v -> occurrence xs v
>   frees isMeta dec = case dec of
>     HOLE   -> Set.empty
>     DEFN v -> frees isMeta v
> -- | @EQN _S s _T t@ is the (possibly heterogeneous) equation between
> -- @s : _S@ and @t : _T@; see its 'Pretty' instance for the rendering.
> data Equation = EQN Type VAL Type VAL
>   deriving stock (Show, Generic)
>   deriving anyclass (Alpha, Subst VAL)
> -- | An equation is inspected as the list of its two types and two values.
> instance Occurs Equation where
>   occurrence xs = occurrence xs . components
>     where components (EQN _S s _T t) = [_S, s, _T, t]
>   frees isMeta = frees isMeta . components
>     where components (EQN _S s _T t) = [_S, s, _T, t]
> -- | Render an equation as @(s : S) == (t : T)@.
> instance Pretty Equation where
>   pretty (EQN _S s _T t) = do
>     _S' <- pretty _S
>     s'  <- pretty s
>     _T' <- pretty _T
>     t'  <- pretty t
>     let lhs = parens (s' <+> colon <+> _S')
>         rhs = parens (t' <+> colon <+> _T')
>     return (lhs <+> text "==" <+> rhs)
> -- | A unification problem: a single equation, or a problem under a bound
> -- parameter ('All'). 'Subst' is given by a hand-written instance.
> data Problem
>   = Unify Equation
>   | All Param (Bind Nom Problem)
>   deriving stock (Show, Generic)
>   deriving anyclass (Alpha)
> -- | Occurrences in a problem. For 'All', the body is examined through the
> -- raw 'B' constructor, i.e. without freshening the bound name; presumably
> -- only free (meta)variables are of interest here -- confirm.
> instance Occurs Problem where
>   occurrence xs prob = case prob of
>     Unify q       -> occurrence xs q
>     All e (B _ p) -> occurrence xs e `max` occurrence xs p
>   frees isMeta prob = case prob of
>     Unify q       -> frees isMeta q
>     All e (B _ p) -> Set.union (frees isMeta e) (frees isMeta p)
> -- | Substitution into problems.
> --
> -- Under the binder of 'All', the work is delegated to unbound-generics'
> -- own 'Subst' instance for 'Bind'. That instance operates directly on the
> -- locally nameless representation, where bound occurrences are indices
> -- untouched by substitution. Opening the binder with 'unsafeUnbind' and
> -- re-binding with 'bind' instead costs two extra traversals of the body.
> -- It is also not capture-safe: the opened name could clash with a name
> -- in the substitution.
> instance Subst VAL Problem where
>   substs s (Unify q) = Unify (substs s q)
>   substs s (All e b) = All (substs s e) (substs s b)
> -- | Equations print as themselves. A bound problem prints as
> -- @(x : e) -> p@, with 'lunbind' choosing a locally fresh name for @x@.
> instance Pretty Problem where
>   pretty (Unify q) = pretty q
>   pretty (All e b) = lunbind b $ \ (x, p) ->
>     render <$> pretty e <*> pretty x <*> pretty p
>     where
>       render e' x' p' = parens (x' <+> colon <+> e') <+> text "->" <+> p'
> -- | Abstract a problem over a single parameter @x : _T@.
> allProb :: Nom -> Type -> Problem -> Problem
> allProb x _T = All (P _T) . bind x
> -- | Abstract a problem over a twin parameter @x@ with left type @_S@ and
> -- right type @_T@.
> allTwinsProb :: Nom -> Type -> Type -> Problem -> Problem
> allTwinsProb x _S _T = All (Twins _S _T) . bind x
> -- | Abstract a problem over a list of parameters. The head of the list
> -- becomes the outermost binder.
> wrapProb :: [(Nom, Param)] -> Problem -> Problem
> wrapProb _Gam p = foldr (\ (x, e) inner -> All e (bind x inner)) p _Gam
> -- | Identifier of a problem entry in the context. It pretty-prints exactly
> -- like the underlying name.
> newtype ProbId = ProbId Nom
>   deriving anyclass (Alpha, Subst VAL)
>   deriving newtype (Pretty)
>   deriving stock (Eq, Show, Generic)
> -- | Progress of a problem entry. 'Pending' lists other problems by id;
> -- 'Failed' carries an error message.
> data ProblemState
>   = Blocked
>   | Active
>   | Pending [ProbId]
>   | Solved
>   | Failed Err
>   deriving stock (Eq, Show, Generic)
>   deriving anyclass (Alpha, Subst VAL)
> -- | Upper-case tag for each state. Pending ids are shown with 'show';
> -- failures include their message.
> instance Pretty ProblemState where
>   pretty st = case st of
>     Blocked    -> say "BLOCKED"
>     Active     -> say "ACTIVE"
>     Pending xs -> say ("PENDING " ++ show xs)
>     Solved     -> say "SOLVED"
>     Failed e   -> say ("FAILED: " ++ e)
>     where
>       say = return . text
> -- | A context entry: a metavariable declaration (name, type, hole or
> -- definition), or a problem with its id and current state.
> data Entry
>   = E Nom Type Dec
>   | Prob ProbId Problem ProblemState
>   deriving stock (Show, Generic)
>   deriving anyclass (Alpha, Subst VAL)
> -- | A declaration mentions what its type and declaration mention; a
> -- problem entry mentions what its problem mentions (id and state ignored).
> instance Occurs Entry where
>   occurrence xs entry = case entry of
>     E _ _T d     -> max (occurrence xs _T) (occurrence xs d)
>     Prob _ p _   -> occurrence xs p
>   frees isMeta entry = case entry of
>     E _ _T d     -> Set.union (frees isMeta _T) (frees isMeta d)
>     Prob _ p _   -> frees isMeta p
> -- | Holes print with @? :@, definitions with @:= d :@, and problems with
> -- @?? :@ and their state after @<=@. Components are rendered in the same
> -- order as they appear in the entry.
> instance Pretty Entry where
>   pretty (E x _T HOLE) = do
>     x'  <- pretty x
>     _T' <- pretty _T
>     return (between (text "? :") x' _T')
>   pretty (E x _T (DEFN d)) = do
>     d'  <- prettyAt PiSize d
>     x'  <- pretty x
>     _T' <- prettyAt PiSize _T
>     return (between (text ":=" <+> d' <+> text ":") x' _T')
>   pretty (Prob x p s) = do
>     x' <- pretty x
>     p' <- pretty p
>     s' <- pretty s
>     return (between (text "<=") (between (text "?? :") x' p') s')
> -- | Entries to the left of the cursor, held in a backwards list.
> type ContextL = Bwd Entry
> -- | Items to the right of the cursor: entries, interleaved with pending
> -- substitutions ('Left').
> type ContextR = [Either Subs Entry]
> -- | The whole metacontext as a (left, right) pair around a cursor; it is
> -- printed with a @*@ between the two halves.
> type Context = (ContextL, ContextR)
> -- | A variable name paired with its type.
> type VarEntry = (Nom, Type)
> -- | A hole (metavariable) name paired with its type.
> type HoleEntry = (Nom, Type)
> -- | A parameter in scope: a single type, or twins with a left and a right
> -- type (selected by 'TwinL' / 'TwinR' in 'lookupVar').
> data Param
>   = P Type
>   | Twins Type Type
>   deriving stock (Show, Generic)
>   deriving anyclass (Alpha, Subst VAL)
> -- | A parameter mentions what its type(s) mention.
> instance Occurs Param where
>   occurrence xs param = case param of
>     P _T        -> occurrence xs _T
>     Twins _S _T -> occurrence xs _S `max` occurrence xs _T
>   frees isMeta param = case param of
>     P _T        -> frees isMeta _T
>     Twins _S _T -> frees isMeta _S `Set.union` frees isMeta _T
> -- | A single parameter prints as its type; twins print as @S & T@.
> instance Pretty Param where
>   pretty (P _T) = pretty _T
>   pretty (Twins _S _T) = do
>     _S' <- pretty _S
>     _T' <- pretty _T
>     return (between (text "&") _S' _T')
> -- | Parameters in scope, as read by the 'Contextual' reader. 'inScope'
> -- appends to the end, and 'lookupVar' returns the first match.
> type Params = [(Nom, Param)]
> -- | Print the left entries, a @*@ line marking the cursor, then the right
> -- items (substitutions via 'prettySubs', entries via 'pretty').
> instance Pretty Context where
>   pretty (cl, cr) = do
>     leftDoc  <- prettyEntries (trail cl)
>     rightDoc <- vcat <$> mapM prettyItem cr
>     return (leftDoc $+$ text "*" $+$ rightDoc)
>     where
>       prettyItem item = either prettySubs pretty item
> -- | Print entries one per line.
> prettyEntries :: (Applicative m, LFresh m, MonadReader Size m) => [Entry] -> m Doc
> prettyEntries = fmap vcat . mapM pretty
> -- | Print a substitution as a bracketed, comma-separated list of
> -- @x |-> v@ bindings.
> prettySubs :: (Applicative m, LFresh m, MonadReader Size m) => Subs -> m Doc
> prettySubs ns = do
>   bindings <- forM ns $ \ (name, val) ->
>     between (text "|->") <$> pretty name <*> pretty val
>   return (brackets (commaSep bindings))
> -- | Print name/type pairs as a bracketed, comma-separated list of
> -- @x : T@.
> prettyDeps :: (Applicative m, LFresh m, MonadReader Size m) => [(Nom, Type)] -> m Doc
> prettyDeps ns = brackets . commaSep <$> traverse prettyDep ns
>   where
>     prettyDep (x, _T) = between (text ":") <$> pretty x <*> pretty _T
> -- | Print definitions as a bracketed, comma-separated list of
> -- @x := v : T@. Each triple is rendered in the order name, type, value.
> prettyDefns :: (Applicative m, LFresh m, MonadReader Size m) => [(Nom, Type, VAL)] -> m Doc
> prettyDefns ns = do
>   defs <- forM ns $ \ (x, _T, v) -> do
>     x'  <- pretty x
>     _T' <- pretty _T
>     v'  <- pretty v
>     return (x' <+> text ":=" <+> v' <+> text ":" <+> _T')
>   return (brackets (commaSep defs))
> -- | Print parameters one per line as @x : param@.
> prettyParams :: (Applicative m, LFresh m, MonadReader Size m) => Params -> m Doc
> prettyParams xs = vcat <$> mapM prettyParam xs
>   where
>     prettyParam (x, p) = between colon <$> pretty x <*> pretty p
> -- | Error messages raised in 'Contextual' (also carried by 'Failed').
> type Err = String
> -- | The monad for manipulating metacontexts. It reads the 'Params' in
> -- scope, threads the 'Context' as state, supplies fresh names, and can
> -- fail with an 'Err'. 'Alternative' / 'MonadPlus' come from the
> -- underlying stack.
> newtype Contextual a = Contextual
>   { unContextual ::
>       ReaderT Params (StateT Context (FreshMT (ExceptT Err Identity))) a
>   }
>   deriving newtype
>     ( Functor, Applicative, Monad
>     , Alternative, MonadPlus
>     , Fresh, MonadError Err
>     , MonadState Context, MonadReader Params
>     )
> -- | 'fail' (e.g. from failed pattern binds) becomes an ordinary 'Err'.
> instance MonadFail Contextual where
>   fail msg = throwError msg
> -- | Debug helper. When the action runs, it traces the message, the current
> -- context, a @---@ separator and the parameters in scope via 'traceM'.
> ctrace :: String -> Contextual ()
> ctrace s = do
>   cx   <- get
>   _Gam <- ask
>   traceM (s ++ "\n" ++ pp cx ++ "\n---\n" ++ ppWith prettyParams _Gam)
> -- | Run a computation from the given context with no parameters in scope.
> -- Returns the error, or the result together with the final context.
> runContextual :: Context -> Contextual a -> Either Err (a, Context)
> runContextual cx m =
>   runIdentity . runExceptT . runFreshMT $
>     runStateT (runReaderT (unContextual m) []) cx
> -- | Apply a function to the left half of the context.
> modifyL :: (ContextL -> ContextL) -> Contextual ()
> modifyL = modify . first
> -- | Apply a function to the right half of the context.
> modifyR :: (ContextR -> ContextR) -> Contextual ()
> modifyR = modify . second
> -- | Append an entry at the right end of the left context (next to the
> -- cursor).
> pushL :: Entry -> Contextual ()
> pushL e = modifyL (\ cl -> cl :< e)
> -- | Push onto the front of the right context. Substitutions go through
> -- 'pushSubs'; entries are consed on directly.
> pushR :: Either Subs Entry -> Contextual ()
> pushR = either pushSubs (modifyR . (:) . Right)
> -- | Queue a substitution at the front of the right context. An empty
> -- substitution is ignored. So is any substitution pushed while the right
> -- context is empty, since nothing to the right would receive it.
> pushSubs :: Subs -> Contextual ()
> pushSubs [] = return ()
> pushSubs n = modifyR queue
>   where
>     queue [] = []
>     queue cr = Left n : cr
> -- | Remove and return the entry nearest the cursor on the left. Calls
> -- 'error' (not a recoverable 'Err') if the left context is empty.
> popL :: Contextual Entry
> popL = getL >>= \ cl -> case cl of
>   B0         -> error "popL ran out of context"
>   rest :< e  -> e <$ putL rest
> -- | Remove and return the first item of the right context, or 'Nothing'
> -- when it is empty.
> popR :: Contextual (Maybe (Either Subs Entry))
> popR = do
>   cr <- getR
>   case cr of
>     []          -> return Nothing
>     item : rest -> Just item <$ putR rest
> -- | Read the left half of the context.
> getL :: MonadState Context m => m ContextL
> getL = fst <$> get
> -- | Read the right half of the context.
> --
> -- Polymorphic in the state monad, like 'getL'. 'Contextual' is an
> -- instance of @MonadState Context@, so existing callers are unaffected,
> -- and other @MonadState Context@ code can use it too.
> getR :: MonadState Context m => m ContextR
> getR = gets snd
> -- | Replace the left half of the context.
> putL :: ContextL -> Contextual ()
> putL = modifyL . const
> -- | Replace the right half of the context.
> putR :: ContextR -> Contextual ()
> putR cr = modifyR (\ _ -> cr)
> -- | Run a computation with parameter @x@ added at the END of the
> -- parameter list.
> inScope :: MonadReader Params m => Nom -> Param -> m a -> m a
> inScope x p = local (\ _Gam -> _Gam ++ [(x, p)])
> -- | Run a computation with a modified parameter list.
> localParams :: (Params -> Params) -> Contextual a -> Contextual a
> localParams f m = local f m
> -- | Type of parameter @x@ as seen from twin side @w@. Only the first
> -- binding for @x@ in the parameter list is considered. A plain parameter
> -- is reached with 'Only'; a twin gives its left type for 'TwinL' and its
> -- right type for 'TwinR'. Any other pairing fails with
> -- @"lookupVar: evil twin"@, and a missing name fails with
> -- @"lookupVar: missing ..."@.
> lookupVar :: (MonadReader Params m, MonadFail m) => Nom -> Twin -> m Type
> lookupVar x w = do
>   _Gam <- ask
>   case lookup x _Gam of
>     Nothing    -> fail $ "lookupVar: missing " ++ show x
>     Just param -> select param w
>   where
>     select (P _T)       Only  = return _T
>     select (Twins _S _) TwinL = return _S
>     select (Twins _ _T) TwinR = return _T
>     select _            _     = fail "lookupVar: evil twin"
> -- | Type of metavariable @x@. The left context is searched starting from
> -- the entry nearest the cursor; problem entries are skipped, and a missing
> -- name fails with @"lookupMeta: missing ..."@.
> lookupMeta :: (MonadState Context m, MonadFail m) => Nom -> m Type
> lookupMeta x = getL >>= search
>   where
>     search :: MonadFail n => ContextL -> n Type
>     search B0 = fail $ "lookupMeta: missing " ++ show x
>     search (rest :< entry) = case entry of
>       E y _T _ | x == y -> return _T
>       _                 -> search rest