Export local throwE/catchE/fromExceptT helpers
This commit is contained in:
parent
decd5d9cb2
commit
17b9279287
1 changed files with 23 additions and 11 deletions
|
@ -19,6 +19,9 @@ module Network.Protocol.TLS.GNU
|
||||||
( TLS
|
( TLS
|
||||||
, Session
|
, Session
|
||||||
, Error (..)
|
, Error (..)
|
||||||
|
, throwE
|
||||||
|
, catchE
|
||||||
|
, fromExceptT
|
||||||
|
|
||||||
, runTLS
|
, runTLS
|
||||||
, runClient
|
, runClient
|
||||||
|
@ -42,7 +45,7 @@ import Control.Applicative (Applicative, pure, (<*>))
|
||||||
import qualified Control.Concurrent.MVar as M
|
import qualified Control.Concurrent.MVar as M
|
||||||
import Control.Monad (ap, when, foldM, foldM_)
|
import Control.Monad (ap, when, foldM, foldM_)
|
||||||
import Control.Monad.Trans.Class (lift)
|
import Control.Monad.Trans.Class (lift)
|
||||||
import Control.Monad.Trans.Except
|
import qualified Control.Monad.Trans.Except as E
|
||||||
import qualified Control.Monad.Trans.Reader as R
|
import qualified Control.Monad.Trans.Reader as R
|
||||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
|
@ -66,11 +69,11 @@ globalInitMVar :: M.MVar ()
|
||||||
{-# NOINLINE globalInitMVar #-}
|
{-# NOINLINE globalInitMVar #-}
|
||||||
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
||||||
|
|
||||||
globalInit :: ExceptT Error IO ()
|
globalInit :: E.ExceptT Error IO ()
|
||||||
globalInit = do
|
globalInit = do
|
||||||
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
||||||
F.ReturnCode rc <- liftIO init_
|
F.ReturnCode rc <- liftIO init_
|
||||||
when (rc < 0) $ throwE $ mapError rc
|
when (rc < 0) $ E.throwE $ mapError rc
|
||||||
|
|
||||||
globalDeinit :: IO ()
|
globalDeinit :: IO ()
|
||||||
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
|
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
|
||||||
|
@ -88,7 +91,7 @@ data Session = Session
|
||||||
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
||||||
}
|
}
|
||||||
|
|
||||||
newtype TLS a = TLS { unTLS :: ExceptT Error (R.ReaderT Session UIO) a }
|
newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }
|
||||||
|
|
||||||
instance Functor TLS where
|
instance Functor TLS where
|
||||||
fmap f = TLS . fmap f . unTLS
|
fmap f = TLS . fmap f . unTLS
|
||||||
|
@ -103,10 +106,19 @@ instance Monad TLS where
|
||||||
|
|
||||||
-- | This is a transitional instance and may be deprecated in the future
|
-- | This is a transitional instance and may be deprecated in the future
|
||||||
instance MonadIO TLS where
|
instance MonadIO TLS where
|
||||||
liftIO = TLS . withExceptT IOError . UIO.fromIO' (userError . show)
|
liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
|
||||||
|
|
||||||
|
throwE :: Error -> TLS a
|
||||||
|
throwE = fromExceptT . E.throwE
|
||||||
|
|
||||||
|
catchE :: TLS a -> (Error -> TLS a) -> TLS a
|
||||||
|
catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
|
||||||
|
|
||||||
|
fromExceptT :: E.ExceptT Error UIO a -> TLS a
|
||||||
|
fromExceptT = TLS . E.mapExceptT lift
|
||||||
|
|
||||||
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
|
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
|
||||||
runTLS s tls = UIO.lift $ R.runReaderT (runExceptT (unTLS tls)) s
|
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
|
||||||
|
|
||||||
runClient :: Transport -> TLS a -> IO (Either Error a)
|
runClient :: Transport -> TLS a -> IO (Either Error a)
|
||||||
runClient transport tls = do
|
runClient transport tls = do
|
||||||
|
@ -116,10 +128,10 @@ runClient transport tls = do
|
||||||
Right session -> runTLS session tls
|
Right session -> runTLS session tls
|
||||||
|
|
||||||
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
||||||
newSession transport end = F.alloca $ \sPtr -> runExceptT $ do
|
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
|
||||||
globalInit
|
globalInit
|
||||||
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
||||||
when (rc < 0) $ throwE $ mapError rc
|
when (rc < 0) $ E.throwE $ mapError rc
|
||||||
liftIO $ do
|
liftIO $ do
|
||||||
ptr <- F.peek sPtr
|
ptr <- F.peek sPtr
|
||||||
let session = F.Session ptr
|
let session = F.Session ptr
|
||||||
|
@ -151,7 +163,7 @@ putBytes = putChunks . BL.toChunks where
|
||||||
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||||
case maybeErr of
|
case maybeErr of
|
||||||
Nothing -> return ()
|
Nothing -> return ()
|
||||||
Just err -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral err
|
Just err -> throwE $ mapError $ fromIntegral err
|
||||||
|
|
||||||
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
||||||
loop ptr len = do
|
loop ptr len = do
|
||||||
|
@ -179,7 +191,7 @@ getBytes count = do
|
||||||
|
|
||||||
case mbytes of
|
case mbytes of
|
||||||
Just bytes -> return bytes
|
Just bytes -> return bytes
|
||||||
Nothing -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral len
|
Nothing -> throwE $ mapError $ fromIntegral len
|
||||||
|
|
||||||
checkPending :: TLS Integer
|
checkPending :: TLS Integer
|
||||||
checkPending = withSession $ \s -> do
|
checkPending = withSession $ \s -> do
|
||||||
|
@ -242,7 +254,7 @@ withSession io = do
|
||||||
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||||
|
|
||||||
checkRC :: F.ReturnCode -> TLS ()
|
checkRC :: F.ReturnCode -> TLS ()
|
||||||
checkRC (F.ReturnCode x) = when (x < 0) $ TLS $ mapExceptT lift $ throwE $ mapError x
|
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
|
||||||
|
|
||||||
mapError :: F.CInt -> Error
|
mapError :: F.CInt -> Error
|
||||||
mapError = Error . toInteger
|
mapError = Error . toInteger
|
||||||
|
|
Loading…
Reference in a new issue