Switch TLS to TLST to allow any Unexceptional base monad
This commit is contained in:
parent
ceac3318da
commit
06a662e63a
1 changed files with 30 additions and 34 deletions
|
@ -17,10 +17,9 @@
|
||||||
|
|
||||||
module Network.Protocol.TLS.GNU
|
module Network.Protocol.TLS.GNU
|
||||||
( TLS
|
( TLS
|
||||||
|
, TLST
|
||||||
, Session
|
, Session
|
||||||
, Error (..)
|
, Error (..)
|
||||||
, throwE
|
|
||||||
, fromExceptT
|
|
||||||
|
|
||||||
, runTLS
|
, runTLS
|
||||||
, runTLS'
|
, runTLS'
|
||||||
|
@ -46,7 +45,6 @@ import Control.Monad (when, foldM, foldM_)
|
||||||
import Control.Monad.Trans.Class (lift)
|
import Control.Monad.Trans.Class (lift)
|
||||||
import qualified Control.Monad.Trans.Except as E
|
import qualified Control.Monad.Trans.Except as E
|
||||||
import qualified Control.Monad.Trans.Reader as R
|
import qualified Control.Monad.Trans.Reader as R
|
||||||
import Control.Monad.IO.Class (liftIO)
|
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
import qualified Data.ByteString.Unsafe as B
|
import qualified Data.ByteString.Unsafe as B
|
||||||
|
@ -56,7 +54,7 @@ import qualified Foreign.C as F
|
||||||
import Foreign.Concurrent as FC
|
import Foreign.Concurrent as FC
|
||||||
import qualified System.IO as IO
|
import qualified System.IO as IO
|
||||||
import System.IO.Unsafe (unsafePerformIO)
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
import UnexceptionalIO.Trans (UIO, Unexceptional)
|
import UnexceptionalIO.Trans (Unexceptional)
|
||||||
import qualified UnexceptionalIO.Trans as UIO
|
import qualified UnexceptionalIO.Trans as UIO
|
||||||
|
|
||||||
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
||||||
|
@ -68,10 +66,10 @@ globalInitMVar :: M.MVar ()
|
||||||
{-# NOINLINE globalInitMVar #-}
|
{-# NOINLINE globalInitMVar #-}
|
||||||
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
||||||
|
|
||||||
globalInit :: E.ExceptT Error IO ()
|
globalInit :: (Unexceptional m) => E.ExceptT Error m ()
|
||||||
globalInit = do
|
globalInit = do
|
||||||
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
||||||
F.ReturnCode rc <- liftIO init_
|
F.ReturnCode rc <- UIO.unsafeFromIO init_
|
||||||
when (rc < 0) $ E.throwE $ mapError rc
|
when (rc < 0) $ E.throwE $ mapError rc
|
||||||
|
|
||||||
globalDeinit :: IO ()
|
globalDeinit :: IO ()
|
||||||
|
@ -90,33 +88,31 @@ data Session = Session
|
||||||
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
|
type TLS a = TLST IO a
|
||||||
|
type TLST m a = E.ExceptT Error (R.ReaderT Session m) a
|
||||||
|
|
||||||
throwE :: Error -> TLS a
|
runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a)
|
||||||
throwE = fromExceptT . E.throwE
|
|
||||||
|
|
||||||
fromExceptT :: E.ExceptT Error UIO a -> TLS a
|
|
||||||
fromExceptT = E.mapExceptT lift
|
|
||||||
|
|
||||||
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
|
|
||||||
runTLS s = E.runExceptT . runTLS' s
|
runTLS s = E.runExceptT . runTLS' s
|
||||||
|
|
||||||
runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a
|
runTLS' :: Session -> TLST m a -> E.ExceptT Error m a
|
||||||
runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s)
|
runTLS' s = E.mapExceptT (flip R.runReaderT s)
|
||||||
|
|
||||||
runClient :: Transport -> TLS a -> IO (Either Error a)
|
runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a)
|
||||||
runClient transport tls = do
|
runClient transport tls = do
|
||||||
eitherSession <- newSession transport (F.ConnectionEnd 2)
|
eitherSession <- newSession transport (F.ConnectionEnd 2)
|
||||||
case eitherSession of
|
case eitherSession of
|
||||||
Left err -> return (Left err)
|
Left err -> return (Left err)
|
||||||
Right session -> runTLS session tls
|
Right session -> runTLS session tls
|
||||||
|
|
||||||
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
newSession :: (Unexceptional m) =>
|
||||||
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
|
Transport
|
||||||
|
-> F.ConnectionEnd
|
||||||
|
-> m (Either Error Session)
|
||||||
|
newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do
|
||||||
globalInit
|
globalInit
|
||||||
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end
|
||||||
when (rc < 0) $ E.throwE $ mapError rc
|
when (rc < 0) $ E.throwE $ mapError rc
|
||||||
liftIO $ do
|
UIO.unsafeFromIO $ do
|
||||||
ptr <- F.peek sPtr
|
ptr <- F.peek sPtr
|
||||||
let session = F.Session ptr
|
let session = F.Session ptr
|
||||||
push <- F.wrapTransportFunc (pushImpl transport)
|
push <- F.wrapTransportFunc (pushImpl transport)
|
||||||
|
@ -132,22 +128,22 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
|
||||||
F.freeHaskellFunPtr pull
|
F.freeHaskellFunPtr pull
|
||||||
return (Session fp creds)
|
return (Session fp creds)
|
||||||
|
|
||||||
getSession :: TLS Session
|
getSession :: (Monad m) => TLST m Session
|
||||||
getSession = lift R.ask
|
getSession = lift R.ask
|
||||||
|
|
||||||
handshake :: TLS ()
|
handshake :: (Unexceptional m) => TLST m ()
|
||||||
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
|
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
|
||||||
|
|
||||||
rehandshake :: TLS ()
|
rehandshake :: (Unexceptional m) => TLST m ()
|
||||||
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
|
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
|
||||||
|
|
||||||
putBytes :: BL.ByteString -> TLS ()
|
putBytes :: (Unexceptional m) => BL.ByteString -> TLST m ()
|
||||||
putBytes = putChunks . BL.toChunks where
|
putBytes = putChunks . BL.toChunks where
|
||||||
putChunks chunks = do
|
putChunks chunks = do
|
||||||
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
|
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||||
case maybeErr of
|
case maybeErr of
|
||||||
Nothing -> return ()
|
Nothing -> return ()
|
||||||
Just err -> throwE $ mapError $ fromIntegral err
|
Just err -> E.throwE $ mapError $ fromIntegral err
|
||||||
|
|
||||||
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
||||||
loop ptr len = do
|
loop ptr len = do
|
||||||
|
@ -161,7 +157,7 @@ putBytes = putChunks . BL.toChunks where
|
||||||
|
|
||||||
putChunk _ err _ = return err
|
putChunk _ err _ = return err
|
||||||
|
|
||||||
getBytes :: Integer -> TLS BL.ByteString
|
getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString
|
||||||
getBytes count = do
|
getBytes count = do
|
||||||
(mbytes, len) <- unsafeWithSession $ \s ->
|
(mbytes, len) <- unsafeWithSession $ \s ->
|
||||||
F.allocaBytes (fromInteger count) $ \ptr -> do
|
F.allocaBytes (fromInteger count) $ \ptr -> do
|
||||||
|
@ -175,9 +171,9 @@ getBytes count = do
|
||||||
|
|
||||||
case mbytes of
|
case mbytes of
|
||||||
Just bytes -> return bytes
|
Just bytes -> return bytes
|
||||||
Nothing -> throwE $ mapError $ fromIntegral len
|
Nothing -> E.throwE $ mapError $ fromIntegral len
|
||||||
|
|
||||||
checkPending :: TLS Integer
|
checkPending :: (Unexceptional m) => TLST m Integer
|
||||||
checkPending = unsafeWithSession $ \s -> do
|
checkPending = unsafeWithSession $ \s -> do
|
||||||
pending <- F.gnutls_record_check_pending s
|
pending <- F.gnutls_record_check_pending s
|
||||||
return $ toInteger pending
|
return $ toInteger pending
|
||||||
|
@ -209,7 +205,7 @@ handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)
|
||||||
|
|
||||||
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
|
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
|
||||||
|
|
||||||
setCredentials :: Credentials -> TLS ()
|
setCredentials :: (Unexceptional m) => Credentials -> TLST m ()
|
||||||
setCredentials (Credentials ctype fp) = do
|
setCredentials (Credentials ctype fp) = do
|
||||||
rc <- unsafeWithSession $ \s ->
|
rc <- unsafeWithSession $ \s ->
|
||||||
F.withForeignPtr fp $ \ptr -> do
|
F.withForeignPtr fp $ \ptr -> do
|
||||||
|
@ -220,7 +216,7 @@ setCredentials (Credentials ctype fp) = do
|
||||||
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
||||||
else checkRC rc
|
else checkRC rc
|
||||||
|
|
||||||
certificateCredentials :: TLS Credentials
|
certificateCredentials :: (Unexceptional m) => TLST m Credentials
|
||||||
certificateCredentials = do
|
certificateCredentials = do
|
||||||
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
|
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
|
||||||
rc <- F.gnutls_certificate_allocate_credentials ptr
|
rc <- F.gnutls_certificate_allocate_credentials ptr
|
||||||
|
@ -233,13 +229,13 @@ certificateCredentials = do
|
||||||
return $ Credentials (F.CredentialsType 1) fp
|
return $ Credentials (F.CredentialsType 1) fp
|
||||||
|
|
||||||
-- | This must only be called with IO actions that do not throw NonPseudoException
|
-- | This must only be called with IO actions that do not throw NonPseudoException
|
||||||
unsafeWithSession :: (F.Session -> IO a) -> TLS a
|
unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a
|
||||||
unsafeWithSession io = do
|
unsafeWithSession io = do
|
||||||
s <- getSession
|
s <- getSession
|
||||||
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||||
|
|
||||||
checkRC :: F.ReturnCode -> TLS ()
|
checkRC :: (Monad m) => F.ReturnCode -> TLST m ()
|
||||||
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
|
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x
|
||||||
|
|
||||||
mapError :: F.CInt -> Error
|
mapError :: F.CInt -> Error
|
||||||
mapError = Error . toInteger
|
mapError = Error . toInteger
|
||||||
|
|
Loading…
Reference in a new issue