Switch TLS to TLST to allow any Unexceptional base monad

This commit is contained in:
Stephen Paul Weber 2021-02-15 21:51:41 -05:00
parent ceac3318da
commit 06a662e63a
No known key found for this signature in database
GPG key ID: D11C2911CE519CDE

View file

@ -17,10 +17,9 @@
module Network.Protocol.TLS.GNU module Network.Protocol.TLS.GNU
( TLS ( TLS
, TLST
, Session , Session
, Error (..) , Error (..)
, throwE
, fromExceptT
, runTLS , runTLS
, runTLS' , runTLS'
@ -46,7 +45,6 @@ import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B import qualified Data.ByteString.Unsafe as B
@ -56,7 +54,7 @@ import qualified Foreign.C as F
import Foreign.Concurrent as FC import Foreign.Concurrent as FC
import qualified System.IO as IO import qualified System.IO as IO
import System.IO.Unsafe (unsafePerformIO) import System.IO.Unsafe (unsafePerformIO)
import UnexceptionalIO.Trans (UIO, Unexceptional) import UnexceptionalIO.Trans (Unexceptional)
import qualified UnexceptionalIO.Trans as UIO import qualified UnexceptionalIO.Trans as UIO
import qualified Network.Protocol.TLS.GNU.Foreign as F import qualified Network.Protocol.TLS.GNU.Foreign as F
@ -68,10 +66,10 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-} {-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar () globalInitMVar = unsafePerformIO $ M.newMVar ()
globalInit :: E.ExceptT Error IO () globalInit :: (Unexceptional m) => E.ExceptT Error m ()
globalInit = do globalInit = do
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
F.ReturnCode rc <- liftIO init_ F.ReturnCode rc <- UIO.unsafeFromIO init_
when (rc < 0) $ E.throwE $ mapError rc when (rc < 0) $ E.throwE $ mapError rc
globalDeinit :: IO () globalDeinit :: IO ()
@ -90,33 +88,31 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials] , sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
} }
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a type TLS a = TLST IO a
type TLST m a = E.ExceptT Error (R.ReaderT Session m) a
throwE :: Error -> TLS a runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a)
throwE = fromExceptT . E.throwE
fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = E.mapExceptT lift
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s = E.runExceptT . runTLS' s runTLS s = E.runExceptT . runTLS' s
runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a runTLS' :: Session -> TLST m a -> E.ExceptT Error m a
runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s) runTLS' s = E.mapExceptT (flip R.runReaderT s)
runClient :: Transport -> TLS a -> IO (Either Error a) runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a)
runClient transport tls = do runClient transport tls = do
eitherSession <- newSession transport (F.ConnectionEnd 2) eitherSession <- newSession transport (F.ConnectionEnd 2)
case eitherSession of case eitherSession of
Left err -> return (Left err) Left err -> return (Left err)
Right session -> runTLS session tls Right session -> runTLS session tls
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session) newSession :: (Unexceptional m) =>
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do Transport
-> F.ConnectionEnd
-> m (Either Error Session)
newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do
globalInit globalInit
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end
when (rc < 0) $ E.throwE $ mapError rc when (rc < 0) $ E.throwE $ mapError rc
liftIO $ do UIO.unsafeFromIO $ do
ptr <- F.peek sPtr ptr <- F.peek sPtr
let session = F.Session ptr let session = F.Session ptr
push <- F.wrapTransportFunc (pushImpl transport) push <- F.wrapTransportFunc (pushImpl transport)
@ -132,22 +128,22 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
F.freeHaskellFunPtr pull F.freeHaskellFunPtr pull
return (Session fp creds) return (Session fp creds)
getSession :: TLS Session getSession :: (Monad m) => TLST m Session
getSession = lift R.ask getSession = lift R.ask
handshake :: TLS () handshake :: (Unexceptional m) => TLST m ()
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
rehandshake :: TLS () rehandshake :: (Unexceptional m) => TLST m ()
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
putBytes :: BL.ByteString -> TLS () putBytes :: (Unexceptional m) => BL.ByteString -> TLST m ()
putBytes = putChunks . BL.toChunks where putBytes = putChunks . BL.toChunks where
putChunks chunks = do putChunks chunks = do
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of case maybeErr of
Nothing -> return () Nothing -> return ()
Just err -> throwE $ mapError $ fromIntegral err Just err -> E.throwE $ mapError $ fromIntegral err
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
loop ptr len = do loop ptr len = do
@ -161,7 +157,7 @@ putBytes = putChunks . BL.toChunks where
putChunk _ err _ = return err putChunk _ err _ = return err
getBytes :: Integer -> TLS BL.ByteString getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString
getBytes count = do getBytes count = do
(mbytes, len) <- unsafeWithSession $ \s -> (mbytes, len) <- unsafeWithSession $ \s ->
F.allocaBytes (fromInteger count) $ \ptr -> do F.allocaBytes (fromInteger count) $ \ptr -> do
@ -175,9 +171,9 @@ getBytes count = do
case mbytes of case mbytes of
Just bytes -> return bytes Just bytes -> return bytes
Nothing -> throwE $ mapError $ fromIntegral len Nothing -> E.throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer checkPending :: (Unexceptional m) => TLST m Integer
checkPending = unsafeWithSession $ \s -> do checkPending = unsafeWithSession $ \s -> do
pending <- F.gnutls_record_check_pending s pending <- F.gnutls_record_check_pending s
return $ toInteger pending return $ toInteger pending
@ -209,7 +205,7 @@ handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials) data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
setCredentials :: Credentials -> TLS () setCredentials :: (Unexceptional m) => Credentials -> TLST m ()
setCredentials (Credentials ctype fp) = do setCredentials (Credentials ctype fp) = do
rc <- unsafeWithSession $ \s -> rc <- unsafeWithSession $ \s ->
F.withForeignPtr fp $ \ptr -> do F.withForeignPtr fp $ \ptr -> do
@ -220,7 +216,7 @@ setCredentials (Credentials ctype fp) = do
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ()))) then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
else checkRC rc else checkRC rc
certificateCredentials :: TLS Credentials certificateCredentials :: (Unexceptional m) => TLST m Credentials
certificateCredentials = do certificateCredentials = do
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do (rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
rc <- F.gnutls_certificate_allocate_credentials ptr rc <- F.gnutls_certificate_allocate_credentials ptr
@ -233,13 +229,13 @@ certificateCredentials = do
return $ Credentials (F.CredentialsType 1) fp return $ Credentials (F.CredentialsType 1) fp
-- | This must only be called with IO actions that do not throw NonPseudoException -- | This must only be called with IO actions that do not throw NonPseudoException
unsafeWithSession :: (F.Session -> IO a) -> TLS a unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a
unsafeWithSession io = do unsafeWithSession io = do
s <- getSession s <- getSession
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS () checkRC :: (Monad m) => F.ReturnCode -> TLST m ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x
mapError :: F.CInt -> Error mapError :: F.CInt -> Error
mapError = Error . toInteger mapError = Error . toInteger