Switch TLS to TLST to allow any Unexceptional base monad
This commit is contained in:
parent
ceac3318da
commit
06a662e63a
1 changed files with 30 additions and 34 deletions
|
@ -17,10 +17,9 @@
|
|||
|
||||
module Network.Protocol.TLS.GNU
|
||||
( TLS
|
||||
, TLST
|
||||
, Session
|
||||
, Error (..)
|
||||
, throwE
|
||||
, fromExceptT
|
||||
|
||||
, runTLS
|
||||
, runTLS'
|
||||
|
@ -46,7 +45,6 @@ import Control.Monad (when, foldM, foldM_)
|
|||
import Control.Monad.Trans.Class (lift)
|
||||
import qualified Control.Monad.Trans.Except as E
|
||||
import qualified Control.Monad.Trans.Reader as R
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.ByteString.Unsafe as B
|
||||
|
@ -56,7 +54,7 @@ import qualified Foreign.C as F
|
|||
import Foreign.Concurrent as FC
|
||||
import qualified System.IO as IO
|
||||
import System.IO.Unsafe (unsafePerformIO)
|
||||
import UnexceptionalIO.Trans (UIO, Unexceptional)
|
||||
import UnexceptionalIO.Trans (Unexceptional)
|
||||
import qualified UnexceptionalIO.Trans as UIO
|
||||
|
||||
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
||||
|
@ -68,10 +66,10 @@ globalInitMVar :: M.MVar ()
|
|||
{-# NOINLINE globalInitMVar #-}
|
||||
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
||||
|
||||
globalInit :: E.ExceptT Error IO ()
|
||||
globalInit :: (Unexceptional m) => E.ExceptT Error m ()
|
||||
globalInit = do
|
||||
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
||||
F.ReturnCode rc <- liftIO init_
|
||||
F.ReturnCode rc <- UIO.unsafeFromIO init_
|
||||
when (rc < 0) $ E.throwE $ mapError rc
|
||||
|
||||
globalDeinit :: IO ()
|
||||
|
@ -90,33 +88,31 @@ data Session = Session
|
|||
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
||||
}
|
||||
|
||||
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
|
||||
type TLS a = TLST IO a
|
||||
type TLST m a = E.ExceptT Error (R.ReaderT Session m) a
|
||||
|
||||
throwE :: Error -> TLS a
|
||||
throwE = fromExceptT . E.throwE
|
||||
|
||||
fromExceptT :: E.ExceptT Error UIO a -> TLS a
|
||||
fromExceptT = E.mapExceptT lift
|
||||
|
||||
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
|
||||
runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a)
|
||||
runTLS s = E.runExceptT . runTLS' s
|
||||
|
||||
runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a
|
||||
runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s)
|
||||
runTLS' :: Session -> TLST m a -> E.ExceptT Error m a
|
||||
runTLS' s = E.mapExceptT (flip R.runReaderT s)
|
||||
|
||||
runClient :: Transport -> TLS a -> IO (Either Error a)
|
||||
runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a)
|
||||
runClient transport tls = do
|
||||
eitherSession <- newSession transport (F.ConnectionEnd 2)
|
||||
case eitherSession of
|
||||
Left err -> return (Left err)
|
||||
Right session -> runTLS session tls
|
||||
|
||||
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
||||
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
|
||||
newSession :: (Unexceptional m) =>
|
||||
Transport
|
||||
-> F.ConnectionEnd
|
||||
-> m (Either Error Session)
|
||||
newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do
|
||||
globalInit
|
||||
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
||||
F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end
|
||||
when (rc < 0) $ E.throwE $ mapError rc
|
||||
liftIO $ do
|
||||
UIO.unsafeFromIO $ do
|
||||
ptr <- F.peek sPtr
|
||||
let session = F.Session ptr
|
||||
push <- F.wrapTransportFunc (pushImpl transport)
|
||||
|
@ -132,22 +128,22 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
|
|||
F.freeHaskellFunPtr pull
|
||||
return (Session fp creds)
|
||||
|
||||
getSession :: TLS Session
|
||||
getSession :: (Monad m) => TLST m Session
|
||||
getSession = lift R.ask
|
||||
|
||||
handshake :: TLS ()
|
||||
handshake :: (Unexceptional m) => TLST m ()
|
||||
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
|
||||
|
||||
rehandshake :: TLS ()
|
||||
rehandshake :: (Unexceptional m) => TLST m ()
|
||||
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
|
||||
|
||||
putBytes :: BL.ByteString -> TLS ()
|
||||
putBytes :: (Unexceptional m) => BL.ByteString -> TLST m ()
|
||||
putBytes = putChunks . BL.toChunks where
|
||||
putChunks chunks = do
|
||||
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||
case maybeErr of
|
||||
Nothing -> return ()
|
||||
Just err -> throwE $ mapError $ fromIntegral err
|
||||
Just err -> E.throwE $ mapError $ fromIntegral err
|
||||
|
||||
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
||||
loop ptr len = do
|
||||
|
@ -161,7 +157,7 @@ putBytes = putChunks . BL.toChunks where
|
|||
|
||||
putChunk _ err _ = return err
|
||||
|
||||
getBytes :: Integer -> TLS BL.ByteString
|
||||
getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString
|
||||
getBytes count = do
|
||||
(mbytes, len) <- unsafeWithSession $ \s ->
|
||||
F.allocaBytes (fromInteger count) $ \ptr -> do
|
||||
|
@ -175,9 +171,9 @@ getBytes count = do
|
|||
|
||||
case mbytes of
|
||||
Just bytes -> return bytes
|
||||
Nothing -> throwE $ mapError $ fromIntegral len
|
||||
Nothing -> E.throwE $ mapError $ fromIntegral len
|
||||
|
||||
checkPending :: TLS Integer
|
||||
checkPending :: (Unexceptional m) => TLST m Integer
|
||||
checkPending = unsafeWithSession $ \s -> do
|
||||
pending <- F.gnutls_record_check_pending s
|
||||
return $ toInteger pending
|
||||
|
@ -209,7 +205,7 @@ handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)
|
|||
|
||||
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
|
||||
|
||||
setCredentials :: Credentials -> TLS ()
|
||||
setCredentials :: (Unexceptional m) => Credentials -> TLST m ()
|
||||
setCredentials (Credentials ctype fp) = do
|
||||
rc <- unsafeWithSession $ \s ->
|
||||
F.withForeignPtr fp $ \ptr -> do
|
||||
|
@ -220,7 +216,7 @@ setCredentials (Credentials ctype fp) = do
|
|||
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
||||
else checkRC rc
|
||||
|
||||
certificateCredentials :: TLS Credentials
|
||||
certificateCredentials :: (Unexceptional m) => TLST m Credentials
|
||||
certificateCredentials = do
|
||||
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
|
||||
rc <- F.gnutls_certificate_allocate_credentials ptr
|
||||
|
@ -233,13 +229,13 @@ certificateCredentials = do
|
|||
return $ Credentials (F.CredentialsType 1) fp
|
||||
|
||||
-- | This must only be called with IO actions that do not throw NonPseudoException
|
||||
unsafeWithSession :: (F.Session -> IO a) -> TLS a
|
||||
unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a
|
||||
unsafeWithSession io = do
|
||||
s <- getSession
|
||||
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||
|
||||
checkRC :: F.ReturnCode -> TLS ()
|
||||
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
|
||||
checkRC :: (Monad m) => F.ReturnCode -> TLST m ()
|
||||
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x
|
||||
|
||||
mapError :: F.CInt -> Error
|
||||
mapError = Error . toInteger
|
||||
|
|
Loading…
Reference in a new issue