Switch monad transformer stack to a type alias
Since we already allowed injecting any Session via runTLS or throwing any Error via throwE, this does not reduce safety at all but improves ergonomics considerably. The only downside here is that we must say goodbye to our transitional MonadIO instance.
This commit is contained in:
parent
17b9279287
commit
955b054ff4
1 changed files with 20 additions and 39 deletions
|
@ -20,7 +20,6 @@ module Network.Protocol.TLS.GNU
|
|||
, Session
|
||||
, Error (..)
|
||||
, throwE
|
||||
, catchE
|
||||
, fromExceptT
|
||||
|
||||
, runTLS
|
||||
|
@ -41,13 +40,12 @@ module Network.Protocol.TLS.GNU
|
|||
, certificateCredentials
|
||||
) where
|
||||
|
||||
import Control.Applicative (Applicative, pure, (<*>))
|
||||
import qualified Control.Concurrent.MVar as M
|
||||
import Control.Monad (ap, when, foldM, foldM_)
|
||||
import Control.Monad (when, foldM, foldM_)
|
||||
import Control.Monad.Trans.Class (lift)
|
||||
import qualified Control.Monad.Trans.Except as E
|
||||
import qualified Control.Monad.Trans.Reader as R
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.ByteString.Unsafe as B
|
||||
|
@ -62,7 +60,7 @@ import qualified UnexceptionalIO.Trans as UIO
|
|||
|
||||
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
||||
|
||||
data Error = Error Integer | IOError IOError
|
||||
data Error = Error Integer
|
||||
deriving (Show)
|
||||
|
||||
globalInitMVar :: M.MVar ()
|
||||
|
@ -91,34 +89,16 @@ data Session = Session
|
|||
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
||||
}
|
||||
|
||||
newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }
|
||||
|
||||
instance Functor TLS where
|
||||
fmap f = TLS . fmap f . unTLS
|
||||
|
||||
instance Applicative TLS where
|
||||
pure = TLS . return
|
||||
(<*>) = ap
|
||||
|
||||
instance Monad TLS where
|
||||
return = TLS . return
|
||||
m >>= f = TLS $ unTLS m >>= unTLS . f
|
||||
|
||||
-- | This is a transitional instance and may be deprecated in the future
|
||||
instance MonadIO TLS where
|
||||
liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
|
||||
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
|
||||
|
||||
throwE :: Error -> TLS a
|
||||
throwE = fromExceptT . E.throwE
|
||||
|
||||
catchE :: TLS a -> (Error -> TLS a) -> TLS a
|
||||
catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
|
||||
|
||||
fromExceptT :: E.ExceptT Error UIO a -> TLS a
|
||||
fromExceptT = TLS . E.mapExceptT lift
|
||||
fromExceptT = E.mapExceptT lift
|
||||
|
||||
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
|
||||
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
|
||||
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT tls) s
|
||||
|
||||
runClient :: Transport -> TLS a -> IO (Either Error a)
|
||||
runClient transport tls = do
|
||||
|
@ -149,18 +129,18 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
|
|||
return (Session fp creds)
|
||||
|
||||
getSession :: TLS Session
|
||||
getSession = TLS $ lift R.ask
|
||||
getSession = lift R.ask
|
||||
|
||||
handshake :: TLS ()
|
||||
handshake = withSession F.gnutls_handshake >>= checkRC
|
||||
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
|
||||
|
||||
rehandshake :: TLS ()
|
||||
rehandshake = withSession F.gnutls_rehandshake >>= checkRC
|
||||
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
|
||||
|
||||
putBytes :: BL.ByteString -> TLS ()
|
||||
putBytes = putChunks . BL.toChunks where
|
||||
putChunks chunks = do
|
||||
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||
case maybeErr of
|
||||
Nothing -> return ()
|
||||
Just err -> throwE $ mapError $ fromIntegral err
|
||||
|
@ -179,7 +159,7 @@ putBytes = putChunks . BL.toChunks where
|
|||
|
||||
getBytes :: Integer -> TLS BL.ByteString
|
||||
getBytes count = do
|
||||
(mbytes, len) <- withSession $ \s ->
|
||||
(mbytes, len) <- unsafeWithSession $ \s ->
|
||||
F.allocaBytes (fromInteger count) $ \ptr -> do
|
||||
len <- F.gnutls_record_recv s ptr (fromInteger count)
|
||||
bytes <- if len >= 0
|
||||
|
@ -194,7 +174,7 @@ getBytes count = do
|
|||
Nothing -> throwE $ mapError $ fromIntegral len
|
||||
|
||||
checkPending :: TLS Integer
|
||||
checkPending = withSession $ \s -> do
|
||||
checkPending = unsafeWithSession $ \s -> do
|
||||
pending <- F.gnutls_record_check_pending s
|
||||
return $ toInteger pending
|
||||
|
||||
|
@ -227,31 +207,32 @@ data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
|
|||
|
||||
setCredentials :: Credentials -> TLS ()
|
||||
setCredentials (Credentials ctype fp) = do
|
||||
rc <- withSession $ \s ->
|
||||
rc <- unsafeWithSession $ \s ->
|
||||
F.withForeignPtr fp $ \ptr -> do
|
||||
F.gnutls_credentials_set s ctype ptr
|
||||
|
||||
s <- getSession
|
||||
if F.unRC rc == 0
|
||||
then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
||||
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
||||
else checkRC rc
|
||||
|
||||
certificateCredentials :: TLS Credentials
|
||||
certificateCredentials = do
|
||||
(rc, ptr) <- liftIO $ F.alloca $ \ptr -> do
|
||||
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
|
||||
rc <- F.gnutls_certificate_allocate_credentials ptr
|
||||
ptr' <- if F.unRC rc < 0
|
||||
then return F.nullPtr
|
||||
else F.peek ptr
|
||||
return (rc, ptr')
|
||||
checkRC rc
|
||||
fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
|
||||
fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
|
||||
return $ Credentials (F.CredentialsType 1) fp
|
||||
|
||||
withSession :: (F.Session -> IO a) -> TLS a
|
||||
withSession io = do
|
||||
-- | This must only be called with IO actions that do not throw NonPseudoException
|
||||
unsafeWithSession :: (F.Session -> IO a) -> TLS a
|
||||
unsafeWithSession io = do
|
||||
s <- getSession
|
||||
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||
|
||||
checkRC :: F.ReturnCode -> TLS ()
|
||||
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
|
||||
|
|
Loading…
Reference in a new issue