Export local throwE/catchE/fromExceptT helpers

This commit is contained in:
Stephen Paul Weber 2021-02-13 20:31:35 -05:00
parent decd5d9cb2
commit 17b9279287

View file

@ -19,6 +19,9 @@ module Network.Protocol.TLS.GNU
( TLS
, Session
, Error (..)
, throwE
, catchE
, fromExceptT
, runTLS
, runClient
@ -42,7 +45,7 @@ import Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M
import Control.Monad (ap, when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as B
@ -66,11 +69,11 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()
globalInit :: ExceptT Error IO ()
globalInit :: E.ExceptT Error IO ()
globalInit = do
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
F.ReturnCode rc <- liftIO init_
when (rc < 0) $ throwE $ mapError rc
when (rc < 0) $ E.throwE $ mapError rc
globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
@ -88,7 +91,7 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
newtype TLS a = TLS { unTLS :: ExceptT Error (R.ReaderT Session UIO) a }
newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }
instance Functor TLS where
fmap f = TLS . fmap f . unTLS
@ -103,10 +106,19 @@ instance Monad TLS where
-- | This is a transitional instance and may be deprecated in the future
instance MonadIO TLS where
liftIO = TLS . withExceptT IOError . UIO.fromIO' (userError . show)
liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE
catchE :: TLS a -> (Error -> TLS a) -> TLS a
catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = TLS . E.mapExceptT lift
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s tls = UIO.lift $ R.runReaderT (runExceptT (unTLS tls)) s
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do
@ -116,10 +128,10 @@ runClient transport tls = do
Right session -> runTLS session tls
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> runExceptT $ do
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
globalInit
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
when (rc < 0) $ throwE $ mapError rc
when (rc < 0) $ E.throwE $ mapError rc
liftIO $ do
ptr <- F.peek sPtr
let session = F.Session ptr
@ -151,7 +163,7 @@ putBytes = putChunks . BL.toChunks where
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
Just err -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral err
Just err -> throwE $ mapError $ fromIntegral err
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
loop ptr len = do
@ -179,7 +191,7 @@ getBytes count = do
case mbytes of
Just bytes -> return bytes
Nothing -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral len
Nothing -> throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer
checkPending = withSession $ \s -> do
@ -242,7 +254,7 @@ withSession io = do
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ TLS $ mapExceptT lift $ throwE $ mapError x
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
mapError :: F.CInt -> Error
mapError = Error . toInteger