Switch monad transformer stack to a type alias
Since we already allowed injecting any Session via runTLS or throwing any Error via throwE, this does not reduce safety at all but improves ergonomics considerably. The only downside here is that we must say goodbye to our transitional MonadIO instance.
This commit is contained in:
parent
17b9279287
commit
955b054ff4
1 changed files with 20 additions and 39 deletions
|
@ -20,7 +20,6 @@ module Network.Protocol.TLS.GNU
|
||||||
, Session
|
, Session
|
||||||
, Error (..)
|
, Error (..)
|
||||||
, throwE
|
, throwE
|
||||||
, catchE
|
|
||||||
, fromExceptT
|
, fromExceptT
|
||||||
|
|
||||||
, runTLS
|
, runTLS
|
||||||
|
@ -41,13 +40,12 @@ module Network.Protocol.TLS.GNU
|
||||||
, certificateCredentials
|
, certificateCredentials
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Applicative (Applicative, pure, (<*>))
|
|
||||||
import qualified Control.Concurrent.MVar as M
|
import qualified Control.Concurrent.MVar as M
|
||||||
import Control.Monad (ap, when, foldM, foldM_)
|
import Control.Monad (when, foldM, foldM_)
|
||||||
import Control.Monad.Trans.Class (lift)
|
import Control.Monad.Trans.Class (lift)
|
||||||
import qualified Control.Monad.Trans.Except as E
|
import qualified Control.Monad.Trans.Except as E
|
||||||
import qualified Control.Monad.Trans.Reader as R
|
import qualified Control.Monad.Trans.Reader as R
|
||||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
import qualified Data.ByteString.Unsafe as B
|
import qualified Data.ByteString.Unsafe as B
|
||||||
|
@ -62,7 +60,7 @@ import qualified UnexceptionalIO.Trans as UIO
|
||||||
|
|
||||||
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
||||||
|
|
||||||
data Error = Error Integer | IOError IOError
|
data Error = Error Integer
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
globalInitMVar :: M.MVar ()
|
globalInitMVar :: M.MVar ()
|
||||||
|
@ -91,34 +89,16 @@ data Session = Session
|
||||||
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
||||||
}
|
}
|
||||||
|
|
||||||
newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }
|
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
|
||||||
|
|
||||||
instance Functor TLS where
|
|
||||||
fmap f = TLS . fmap f . unTLS
|
|
||||||
|
|
||||||
instance Applicative TLS where
|
|
||||||
pure = TLS . return
|
|
||||||
(<*>) = ap
|
|
||||||
|
|
||||||
instance Monad TLS where
|
|
||||||
return = TLS . return
|
|
||||||
m >>= f = TLS $ unTLS m >>= unTLS . f
|
|
||||||
|
|
||||||
-- | This is a transitional instance and may be deprecated in the future
|
|
||||||
instance MonadIO TLS where
|
|
||||||
liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
|
|
||||||
|
|
||||||
throwE :: Error -> TLS a
|
throwE :: Error -> TLS a
|
||||||
throwE = fromExceptT . E.throwE
|
throwE = fromExceptT . E.throwE
|
||||||
|
|
||||||
catchE :: TLS a -> (Error -> TLS a) -> TLS a
|
|
||||||
catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
|
|
||||||
|
|
||||||
fromExceptT :: E.ExceptT Error UIO a -> TLS a
|
fromExceptT :: E.ExceptT Error UIO a -> TLS a
|
||||||
fromExceptT = TLS . E.mapExceptT lift
|
fromExceptT = E.mapExceptT lift
|
||||||
|
|
||||||
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
|
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
|
||||||
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
|
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT tls) s
|
||||||
|
|
||||||
runClient :: Transport -> TLS a -> IO (Either Error a)
|
runClient :: Transport -> TLS a -> IO (Either Error a)
|
||||||
runClient transport tls = do
|
runClient transport tls = do
|
||||||
|
@ -149,18 +129,18 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
|
||||||
return (Session fp creds)
|
return (Session fp creds)
|
||||||
|
|
||||||
getSession :: TLS Session
|
getSession :: TLS Session
|
||||||
getSession = TLS $ lift R.ask
|
getSession = lift R.ask
|
||||||
|
|
||||||
handshake :: TLS ()
|
handshake :: TLS ()
|
||||||
handshake = withSession F.gnutls_handshake >>= checkRC
|
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
|
||||||
|
|
||||||
rehandshake :: TLS ()
|
rehandshake :: TLS ()
|
||||||
rehandshake = withSession F.gnutls_rehandshake >>= checkRC
|
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
|
||||||
|
|
||||||
putBytes :: BL.ByteString -> TLS ()
|
putBytes :: BL.ByteString -> TLS ()
|
||||||
putBytes = putChunks . BL.toChunks where
|
putBytes = putChunks . BL.toChunks where
|
||||||
putChunks chunks = do
|
putChunks chunks = do
|
||||||
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||||
case maybeErr of
|
case maybeErr of
|
||||||
Nothing -> return ()
|
Nothing -> return ()
|
||||||
Just err -> throwE $ mapError $ fromIntegral err
|
Just err -> throwE $ mapError $ fromIntegral err
|
||||||
|
@ -179,7 +159,7 @@ putBytes = putChunks . BL.toChunks where
|
||||||
|
|
||||||
getBytes :: Integer -> TLS BL.ByteString
|
getBytes :: Integer -> TLS BL.ByteString
|
||||||
getBytes count = do
|
getBytes count = do
|
||||||
(mbytes, len) <- withSession $ \s ->
|
(mbytes, len) <- unsafeWithSession $ \s ->
|
||||||
F.allocaBytes (fromInteger count) $ \ptr -> do
|
F.allocaBytes (fromInteger count) $ \ptr -> do
|
||||||
len <- F.gnutls_record_recv s ptr (fromInteger count)
|
len <- F.gnutls_record_recv s ptr (fromInteger count)
|
||||||
bytes <- if len >= 0
|
bytes <- if len >= 0
|
||||||
|
@ -194,7 +174,7 @@ getBytes count = do
|
||||||
Nothing -> throwE $ mapError $ fromIntegral len
|
Nothing -> throwE $ mapError $ fromIntegral len
|
||||||
|
|
||||||
checkPending :: TLS Integer
|
checkPending :: TLS Integer
|
||||||
checkPending = withSession $ \s -> do
|
checkPending = unsafeWithSession $ \s -> do
|
||||||
pending <- F.gnutls_record_check_pending s
|
pending <- F.gnutls_record_check_pending s
|
||||||
return $ toInteger pending
|
return $ toInteger pending
|
||||||
|
|
||||||
|
@ -227,31 +207,32 @@ data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
|
||||||
|
|
||||||
setCredentials :: Credentials -> TLS ()
|
setCredentials :: Credentials -> TLS ()
|
||||||
setCredentials (Credentials ctype fp) = do
|
setCredentials (Credentials ctype fp) = do
|
||||||
rc <- withSession $ \s ->
|
rc <- unsafeWithSession $ \s ->
|
||||||
F.withForeignPtr fp $ \ptr -> do
|
F.withForeignPtr fp $ \ptr -> do
|
||||||
F.gnutls_credentials_set s ctype ptr
|
F.gnutls_credentials_set s ctype ptr
|
||||||
|
|
||||||
s <- getSession
|
s <- getSession
|
||||||
if F.unRC rc == 0
|
if F.unRC rc == 0
|
||||||
then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
||||||
else checkRC rc
|
else checkRC rc
|
||||||
|
|
||||||
certificateCredentials :: TLS Credentials
|
certificateCredentials :: TLS Credentials
|
||||||
certificateCredentials = do
|
certificateCredentials = do
|
||||||
(rc, ptr) <- liftIO $ F.alloca $ \ptr -> do
|
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
|
||||||
rc <- F.gnutls_certificate_allocate_credentials ptr
|
rc <- F.gnutls_certificate_allocate_credentials ptr
|
||||||
ptr' <- if F.unRC rc < 0
|
ptr' <- if F.unRC rc < 0
|
||||||
then return F.nullPtr
|
then return F.nullPtr
|
||||||
else F.peek ptr
|
else F.peek ptr
|
||||||
return (rc, ptr')
|
return (rc, ptr')
|
||||||
checkRC rc
|
checkRC rc
|
||||||
fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
|
fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
|
||||||
return $ Credentials (F.CredentialsType 1) fp
|
return $ Credentials (F.CredentialsType 1) fp
|
||||||
|
|
||||||
withSession :: (F.Session -> IO a) -> TLS a
|
-- | This must only be called with IO actions that do not throw NonPseudoException
|
||||||
withSession io = do
|
unsafeWithSession :: (F.Session -> IO a) -> TLS a
|
||||||
|
unsafeWithSession io = do
|
||||||
s <- getSession
|
s <- getSession
|
||||||
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||||
|
|
||||||
checkRC :: F.ReturnCode -> TLS ()
|
checkRC :: F.ReturnCode -> TLS ()
|
||||||
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
|
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
|
||||||
|
|
Loading…
Reference in a new issue