Switch from monads-tf and custom transformer to ExceptT

When this code was written ExceptT didn't exist yet, but there's no reason to
vendor a custom implementation of it any longer.  We're taking very little
advantage of the monads-tf features so just remove that dependency (and the
language extension reliance that goes with it).
This commit is contained in:
Stephen Paul Weber 2021-02-13 20:07:24 -05:00
parent b32c6617ed
commit 2882576126
3 changed files with 15 additions and 105 deletions

View file

@ -35,8 +35,7 @@ library
build-depends:
base >= 4.0 && < 5.0
, bytestring >= 0.9
, transformers >= 0.2
, monads-tf >= 0.1 && < 0.2
, transformers >= 0.4.0.0
extra-libraries: gnutls
pkgconfig-depends: gnutls
@ -45,5 +44,4 @@ library
Network.Protocol.TLS.GNU
other-modules:
Network.Protocol.TLS.GNU.ErrorT
Network.Protocol.TLS.GNU.Foreign

View file

@ -41,10 +41,10 @@ module Network.Protocol.TLS.GNU
import Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M
import Control.Monad (ap, when, foldM, foldM_)
import qualified Control.Monad.Error as E
import Control.Monad.Error (ErrorType)
import qualified Control.Monad.Reader as R
import Control.Monad.Trans (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
@ -55,7 +55,6 @@ import Foreign.Concurrent as FC
import qualified System.IO as IO
import System.IO.Unsafe (unsafePerformIO)
import Network.Protocol.TLS.GNU.ErrorT
import qualified Network.Protocol.TLS.GNU.Foreign as F
data Error = Error Integer
@ -65,11 +64,11 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()
globalInit :: ErrorT Error IO ()
globalInit :: ExceptT Error IO ()
globalInit = do
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
F.ReturnCode rc <- liftIO init_
when (rc < 0) $ E.throwError $ mapError rc
when (rc < 0) $ throwE $ mapError rc
globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
@ -87,7 +86,7 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
newtype TLS a = TLS { unTLS :: ErrorT Error (R.ReaderT Session IO) a }
newtype TLS a = TLS { unTLS :: ExceptT Error (R.ReaderT Session IO) a }
instance Functor TLS where
fmap f = TLS . fmap f . unTLS
@ -103,13 +102,8 @@ instance Monad TLS where
instance MonadIO TLS where
liftIO = TLS . liftIO
instance E.MonadError TLS where
type ErrorType TLS = Error
throwError = TLS . E.throwError
catchError m h = TLS $ E.catchError (unTLS m) (unTLS . h)
runTLS :: Session -> TLS a -> IO (Either Error a)
runTLS s tls = R.runReaderT (runErrorT (unTLS tls)) s
runTLS s tls = R.runReaderT (runExceptT (unTLS tls)) s
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do
@ -119,10 +113,10 @@ runClient transport tls = do
Right session -> runTLS session tls
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
newSession transport end = F.alloca $ \sPtr -> runExceptT $ do
globalInit
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
when (rc < 0) $ E.throwError $ mapError rc
when (rc < 0) $ throwE $ mapError rc
liftIO $ do
ptr <- F.peek sPtr
let session = F.Session ptr
@ -140,7 +134,7 @@ newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
return (Session fp creds)
getSession :: TLS Session
getSession = TLS R.ask
getSession = TLS $ lift R.ask
handshake :: TLS ()
handshake = withSession F.gnutls_handshake >>= checkRC
@ -154,7 +148,7 @@ putBytes = putChunks . BL.toChunks where
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
Just err -> E.throwError $ mapError $ fromIntegral err
Just err -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral err
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
loop ptr len = do
@ -182,7 +176,7 @@ getBytes count = do
case mbytes of
Just bytes -> return bytes
Nothing -> E.throwError $ mapError $ fromIntegral len
Nothing -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer
checkPending = withSession $ \s -> do
@ -245,7 +239,7 @@ withSession io = do
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwError $ mapError x
checkRC (F.ReturnCode x) = when (x < 0) $ TLS $ mapExceptT lift $ throwE $ mapError x
mapError :: F.CInt -> Error
mapError = Error . toInteger

View file

@ -1,82 +0,0 @@
{-# LANGUAGE TypeFamilies #-}
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
module Network.Protocol.TLS.GNU.ErrorT
( ErrorT (..)
, mapErrorT
) where
import Control.Applicative (Applicative, pure, (<*>))
import Control.Monad (ap,liftM)
import Control.Monad.Trans (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
import qualified Control.Monad.Error as E
import Control.Monad.Error (ErrorType)
import qualified Control.Monad.Reader as R
import Control.Monad.Reader (EnvType)
-- A custom version of ErrorT, without the 'Error' class restriction.
newtype ErrorT e m a = ErrorT { runErrorT :: m (Either e a) }
instance Functor m => Functor (ErrorT e m) where
fmap f = ErrorT . fmap (fmap f) . runErrorT
instance (Functor m, Monad m) => Applicative (ErrorT e m) where
pure a = ErrorT $ return (Right a)
f <*> v = ErrorT $ do
mf <- runErrorT f
case mf of
Left e -> return (Left e)
Right k -> do
mv <- runErrorT v
case mv of
Left e -> return (Left e)
Right x -> return (Right (k x))
instance Monad m => Monad (ErrorT e m) where
return = ErrorT . return . Right
(>>=) m k = ErrorT $ do
x <- runErrorT m
case x of
Left l -> return $ Left l
Right r -> runErrorT $ k r
instance Monad m => E.MonadError (ErrorT e m) where
type ErrorType (ErrorT e m) = e
throwError = ErrorT . return . Left
catchError m h = ErrorT $ do
x <- runErrorT m
case x of
Left l -> runErrorT $ h l
Right r -> return $ Right r
instance MonadTrans (ErrorT e) where
lift = ErrorT . liftM Right
instance R.MonadReader m => R.MonadReader (ErrorT e m) where
type EnvType (ErrorT e m) = EnvType m
ask = lift R.ask
local = mapErrorT . R.local
instance MonadIO m => MonadIO (ErrorT e m) where
liftIO = lift . liftIO
mapErrorT :: (m (Either e a) -> n (Either e' b))
-> ErrorT e m a
-> ErrorT e' n b
mapErrorT f m = ErrorT $ f (runErrorT m)