Switch TLS to TLST to allow any Unexceptional base monad

This commit is contained in:
Stephen Paul Weber 2021-02-15 21:51:41 -05:00
parent ceac3318da
commit 06a662e63a
No known key found for this signature in database
GPG key ID: D11C2911CE519CDE

View file

@ -17,10 +17,9 @@
module Network.Protocol.TLS.GNU
( TLS
, TLST
, Session
, Error (..)
, throwE
, fromExceptT
, runTLS
, runTLS'
@ -46,7 +45,6 @@ import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
@ -56,7 +54,7 @@ import qualified Foreign.C as F
import Foreign.Concurrent as FC
import qualified System.IO as IO
import System.IO.Unsafe (unsafePerformIO)
import UnexceptionalIO.Trans (UIO, Unexceptional)
import UnexceptionalIO.Trans (Unexceptional)
import qualified UnexceptionalIO.Trans as UIO
import qualified Network.Protocol.TLS.GNU.Foreign as F
@ -68,10 +66,10 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()
globalInit :: E.ExceptT Error IO ()
globalInit :: (Unexceptional m) => E.ExceptT Error m ()
globalInit = do
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
F.ReturnCode rc <- liftIO init_
F.ReturnCode rc <- UIO.unsafeFromIO init_
when (rc < 0) $ E.throwE $ mapError rc
globalDeinit :: IO ()
@ -90,33 +88,31 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
type TLS a = TLST IO a
type TLST m a = E.ExceptT Error (R.ReaderT Session m) a
throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE
fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = E.mapExceptT lift
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a)
runTLS s = E.runExceptT . runTLS' s
runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a
runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s)
runTLS' :: Session -> TLST m a -> E.ExceptT Error m a
runTLS' s = E.mapExceptT (flip R.runReaderT s)
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a)
runClient transport tls = do
eitherSession <- newSession transport (F.ConnectionEnd 2)
case eitherSession of
Left err -> return (Left err)
Right session -> runTLS session tls
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
newSession :: (Unexceptional m) =>
Transport
-> F.ConnectionEnd
-> m (Either Error Session)
newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do
globalInit
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end
when (rc < 0) $ E.throwE $ mapError rc
liftIO $ do
UIO.unsafeFromIO $ do
ptr <- F.peek sPtr
let session = F.Session ptr
push <- F.wrapTransportFunc (pushImpl transport)
@ -132,22 +128,22 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
F.freeHaskellFunPtr pull
return (Session fp creds)
getSession :: TLS Session
getSession :: (Monad m) => TLST m Session
getSession = lift R.ask
handshake :: TLS ()
handshake :: (Unexceptional m) => TLST m ()
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
rehandshake :: TLS ()
rehandshake :: (Unexceptional m) => TLST m ()
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
putBytes :: BL.ByteString -> TLS ()
putBytes :: (Unexceptional m) => BL.ByteString -> TLST m ()
putBytes = putChunks . BL.toChunks where
putChunks chunks = do
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
Just err -> throwE $ mapError $ fromIntegral err
Just err -> E.throwE $ mapError $ fromIntegral err
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
loop ptr len = do
@ -161,7 +157,7 @@ putBytes = putChunks . BL.toChunks where
putChunk _ err _ = return err
getBytes :: Integer -> TLS BL.ByteString
getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString
getBytes count = do
(mbytes, len) <- unsafeWithSession $ \s ->
F.allocaBytes (fromInteger count) $ \ptr -> do
@ -175,9 +171,9 @@ getBytes count = do
case mbytes of
Just bytes -> return bytes
Nothing -> throwE $ mapError $ fromIntegral len
Nothing -> E.throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer
checkPending :: (Unexceptional m) => TLST m Integer
checkPending = unsafeWithSession $ \s -> do
pending <- F.gnutls_record_check_pending s
return $ toInteger pending
@ -209,7 +205,7 @@ handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
setCredentials :: Credentials -> TLS ()
setCredentials :: (Unexceptional m) => Credentials -> TLST m ()
setCredentials (Credentials ctype fp) = do
rc <- unsafeWithSession $ \s ->
F.withForeignPtr fp $ \ptr -> do
@ -220,7 +216,7 @@ setCredentials (Credentials ctype fp) = do
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
else checkRC rc
certificateCredentials :: TLS Credentials
certificateCredentials :: (Unexceptional m) => TLST m Credentials
certificateCredentials = do
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
rc <- F.gnutls_certificate_allocate_credentials ptr
@ -233,13 +229,13 @@ certificateCredentials = do
return $ Credentials (F.CredentialsType 1) fp
-- | This must only be called with IO actions that do not throw NonPseudoException
unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a
unsafeWithSession io = do
s <- getSession
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
checkRC :: (Monad m) => F.ReturnCode -> TLST m ()
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x
mapError :: F.CInt -> Error
mapError = Error . toInteger