Switch monad transformer stack to a type alias

Since we already allowed injecting any Session via runTLS or throwing any Error
via throwE, this does not reduce safety at all but improves ergonomics
considerably.

The only downside here is that we must say goodbye to our transitional MonadIO
instance.
This commit is contained in:
Stephen Paul Weber 2021-02-13 20:59:00 -05:00
parent 17b9279287
commit 955b054ff4

View file

@ -20,7 +20,6 @@ module Network.Protocol.TLS.GNU
, Session
, Error (..)
, throwE
, catchE
, fromExceptT
, runTLS
@ -41,13 +40,12 @@ module Network.Protocol.TLS.GNU
, certificateCredentials
) where
import Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M
import Control.Monad (ap, when, foldM, foldM_)
import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
@ -62,7 +60,7 @@ import qualified UnexceptionalIO.Trans as UIO
import qualified Network.Protocol.TLS.GNU.Foreign as F
data Error = Error Integer | IOError IOError
data Error = Error Integer
deriving (Show)
globalInitMVar :: M.MVar ()
@ -91,34 +89,16 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }
instance Functor TLS where
fmap f = TLS . fmap f . unTLS
instance Applicative TLS where
pure = TLS . return
(<*>) = ap
instance Monad TLS where
return = TLS . return
m >>= f = TLS $ unTLS m >>= unTLS . f
-- | This is a transitional instance and may be deprecated in the future
instance MonadIO TLS where
liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE
catchE :: TLS a -> (Error -> TLS a) -> TLS a
catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = TLS . E.mapExceptT lift
fromExceptT = E.mapExceptT lift
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT tls) s
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do
@ -149,18 +129,18 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
return (Session fp creds)
getSession :: TLS Session
getSession = TLS $ lift R.ask
getSession = lift R.ask
handshake :: TLS ()
handshake = withSession F.gnutls_handshake >>= checkRC
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
rehandshake :: TLS ()
rehandshake = withSession F.gnutls_rehandshake >>= checkRC
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks where
putChunks chunks = do
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
Just err -> throwE $ mapError $ fromIntegral err
@ -179,7 +159,7 @@ putBytes = putChunks . BL.toChunks where
getBytes :: Integer -> TLS BL.ByteString
getBytes count = do
(mbytes, len) <- withSession $ \s ->
(mbytes, len) <- unsafeWithSession $ \s ->
F.allocaBytes (fromInteger count) $ \ptr -> do
len <- F.gnutls_record_recv s ptr (fromInteger count)
bytes <- if len >= 0
@ -194,7 +174,7 @@ getBytes count = do
Nothing -> throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer
checkPending = withSession $ \s -> do
checkPending = unsafeWithSession $ \s -> do
pending <- F.gnutls_record_check_pending s
return $ toInteger pending
@ -227,31 +207,32 @@ data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do
rc <- withSession $ \s ->
rc <- unsafeWithSession $ \s ->
F.withForeignPtr fp $ \ptr -> do
F.gnutls_credentials_set s ctype ptr
s <- getSession
if F.unRC rc == 0
then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
else checkRC rc
certificateCredentials :: TLS Credentials
certificateCredentials = do
(rc, ptr) <- liftIO $ F.alloca $ \ptr -> do
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
rc <- F.gnutls_certificate_allocate_credentials ptr
ptr' <- if F.unRC rc < 0
then return F.nullPtr
else F.peek ptr
return (rc, ptr')
checkRC rc
fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
return $ Credentials (F.CredentialsType 1) fp
withSession :: (F.Session -> IO a) -> TLS a
withSession io = do
-- | This must only be called with IO actions that do not throw NonPseudoException
unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession io = do
s <- getSession
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x