Switch monad transformer stack to a type alias

Since we already allowed injecting any Session via runTLS or throwing any Error
via throwE, this does not reduce safety at all but improves ergonomics
considerably.

The only downside here is that we must say goodbye to our transitional MonadIO
instance.
This commit is contained in:
Stephen Paul Weber 2021-02-13 20:59:00 -05:00
parent 17b9279287
commit 955b054ff4

View file

@ -20,7 +20,6 @@ module Network.Protocol.TLS.GNU
, Session , Session
, Error (..) , Error (..)
, throwE , throwE
, catchE
, fromExceptT , fromExceptT
, runTLS , runTLS
@ -41,13 +40,12 @@ module Network.Protocol.TLS.GNU
, certificateCredentials , certificateCredentials
) where ) where
import Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M import qualified Control.Concurrent.MVar as M
import Control.Monad (ap, when, foldM, foldM_) import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B import qualified Data.ByteString.Unsafe as B
@ -62,7 +60,7 @@ import qualified UnexceptionalIO.Trans as UIO
import qualified Network.Protocol.TLS.GNU.Foreign as F import qualified Network.Protocol.TLS.GNU.Foreign as F
data Error = Error Integer | IOError IOError data Error = Error Integer
deriving (Show) deriving (Show)
globalInitMVar :: M.MVar () globalInitMVar :: M.MVar ()
@ -91,34 +89,16 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials] , sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
} }
newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a } type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
instance Functor TLS where
fmap f = TLS . fmap f . unTLS
instance Applicative TLS where
pure = TLS . return
(<*>) = ap
instance Monad TLS where
return = TLS . return
m >>= f = TLS $ unTLS m >>= unTLS . f
-- | This is a transitional instance and may be deprecated in the future
instance MonadIO TLS where
liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
throwE :: Error -> TLS a throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE throwE = fromExceptT . E.throwE
catchE :: TLS a -> (Error -> TLS a) -> TLS a
catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
fromExceptT :: E.ExceptT Error UIO a -> TLS a fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = TLS . E.mapExceptT lift fromExceptT = E.mapExceptT lift
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a) runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT tls) s
runClient :: Transport -> TLS a -> IO (Either Error a) runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do runClient transport tls = do
@ -149,18 +129,18 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
return (Session fp creds) return (Session fp creds)
getSession :: TLS Session getSession :: TLS Session
getSession = TLS $ lift R.ask getSession = lift R.ask
handshake :: TLS () handshake :: TLS ()
handshake = withSession F.gnutls_handshake >>= checkRC handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
rehandshake :: TLS () rehandshake :: TLS ()
rehandshake = withSession F.gnutls_rehandshake >>= checkRC rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
putBytes :: BL.ByteString -> TLS () putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks where putBytes = putChunks . BL.toChunks where
putChunks chunks = do putChunks chunks = do
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of case maybeErr of
Nothing -> return () Nothing -> return ()
Just err -> throwE $ mapError $ fromIntegral err Just err -> throwE $ mapError $ fromIntegral err
@ -179,7 +159,7 @@ putBytes = putChunks . BL.toChunks where
getBytes :: Integer -> TLS BL.ByteString getBytes :: Integer -> TLS BL.ByteString
getBytes count = do getBytes count = do
(mbytes, len) <- withSession $ \s -> (mbytes, len) <- unsafeWithSession $ \s ->
F.allocaBytes (fromInteger count) $ \ptr -> do F.allocaBytes (fromInteger count) $ \ptr -> do
len <- F.gnutls_record_recv s ptr (fromInteger count) len <- F.gnutls_record_recv s ptr (fromInteger count)
bytes <- if len >= 0 bytes <- if len >= 0
@ -194,7 +174,7 @@ getBytes count = do
Nothing -> throwE $ mapError $ fromIntegral len Nothing -> throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer checkPending :: TLS Integer
checkPending = withSession $ \s -> do checkPending = unsafeWithSession $ \s -> do
pending <- F.gnutls_record_check_pending s pending <- F.gnutls_record_check_pending s
return $ toInteger pending return $ toInteger pending
@ -227,31 +207,32 @@ data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
setCredentials :: Credentials -> TLS () setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do setCredentials (Credentials ctype fp) = do
rc <- withSession $ \s -> rc <- unsafeWithSession $ \s ->
F.withForeignPtr fp $ \ptr -> do F.withForeignPtr fp $ \ptr -> do
F.gnutls_credentials_set s ctype ptr F.gnutls_credentials_set s ctype ptr
s <- getSession s <- getSession
if F.unRC rc == 0 if F.unRC rc == 0
then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ()))) then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
else checkRC rc else checkRC rc
certificateCredentials :: TLS Credentials certificateCredentials :: TLS Credentials
certificateCredentials = do certificateCredentials = do
(rc, ptr) <- liftIO $ F.alloca $ \ptr -> do (rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
rc <- F.gnutls_certificate_allocate_credentials ptr rc <- F.gnutls_certificate_allocate_credentials ptr
ptr' <- if F.unRC rc < 0 ptr' <- if F.unRC rc < 0
then return F.nullPtr then return F.nullPtr
else F.peek ptr else F.peek ptr
return (rc, ptr') return (rc, ptr')
checkRC rc checkRC rc
fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
return $ Credentials (F.CredentialsType 1) fp return $ Credentials (F.CredentialsType 1) fp
withSession :: (F.Session -> IO a) -> TLS a -- | This must only be called with IO actions that do not throw NonPseudoException
withSession io = do unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession io = do
s <- getSession s <- getSession
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS () checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x