Switch from monads-tf and custom transformer to ExceptT
When this code was written ExceptT didn't exist yet, but there's no reason to vendor a custom implementation of it any longer. We're taking very little advantage of the monads-tf features so just remove that dependency (and the language extension reliance that goes with it).
This commit is contained in:
parent
b32c6617ed
commit
2882576126
3 changed files with 15 additions and 105 deletions
|
@ -35,8 +35,7 @@ library
|
|||
build-depends:
|
||||
base >= 4.0 && < 5.0
|
||||
, bytestring >= 0.9
|
||||
, transformers >= 0.2
|
||||
, monads-tf >= 0.1 && < 0.2
|
||||
, transformers >= 0.4.0.0
|
||||
|
||||
extra-libraries: gnutls
|
||||
pkgconfig-depends: gnutls
|
||||
|
@ -45,5 +44,4 @@ library
|
|||
Network.Protocol.TLS.GNU
|
||||
|
||||
other-modules:
|
||||
Network.Protocol.TLS.GNU.ErrorT
|
||||
Network.Protocol.TLS.GNU.Foreign
|
||||
|
|
|
@ -41,10 +41,10 @@ module Network.Protocol.TLS.GNU
|
|||
import Control.Applicative (Applicative, pure, (<*>))
|
||||
import qualified Control.Concurrent.MVar as M
|
||||
import Control.Monad (ap, when, foldM, foldM_)
|
||||
import qualified Control.Monad.Error as E
|
||||
import Control.Monad.Error (ErrorType)
|
||||
import qualified Control.Monad.Reader as R
|
||||
import Control.Monad.Trans (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Class (lift)
|
||||
import Control.Monad.Trans.Except
|
||||
import qualified Control.Monad.Trans.Reader as R
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.ByteString.Unsafe as B
|
||||
|
@ -55,7 +55,6 @@ import Foreign.Concurrent as FC
|
|||
import qualified System.IO as IO
|
||||
import System.IO.Unsafe (unsafePerformIO)
|
||||
|
||||
import Network.Protocol.TLS.GNU.ErrorT
|
||||
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
||||
|
||||
data Error = Error Integer
|
||||
|
@ -65,11 +64,11 @@ globalInitMVar :: M.MVar ()
|
|||
{-# NOINLINE globalInitMVar #-}
|
||||
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
||||
|
||||
globalInit :: ErrorT Error IO ()
|
||||
globalInit :: ExceptT Error IO ()
|
||||
globalInit = do
|
||||
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
||||
F.ReturnCode rc <- liftIO init_
|
||||
when (rc < 0) $ E.throwError $ mapError rc
|
||||
when (rc < 0) $ throwE $ mapError rc
|
||||
|
||||
globalDeinit :: IO ()
|
||||
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
|
||||
|
@ -87,7 +86,7 @@ data Session = Session
|
|||
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
||||
}
|
||||
|
||||
newtype TLS a = TLS { unTLS :: ErrorT Error (R.ReaderT Session IO) a }
|
||||
newtype TLS a = TLS { unTLS :: ExceptT Error (R.ReaderT Session IO) a }
|
||||
|
||||
instance Functor TLS where
|
||||
fmap f = TLS . fmap f . unTLS
|
||||
|
@ -103,13 +102,8 @@ instance Monad TLS where
|
|||
instance MonadIO TLS where
|
||||
liftIO = TLS . liftIO
|
||||
|
||||
instance E.MonadError TLS where
|
||||
type ErrorType TLS = Error
|
||||
throwError = TLS . E.throwError
|
||||
catchError m h = TLS $ E.catchError (unTLS m) (unTLS . h)
|
||||
|
||||
runTLS :: Session -> TLS a -> IO (Either Error a)
|
||||
runTLS s tls = R.runReaderT (runErrorT (unTLS tls)) s
|
||||
runTLS s tls = R.runReaderT (runExceptT (unTLS tls)) s
|
||||
|
||||
runClient :: Transport -> TLS a -> IO (Either Error a)
|
||||
runClient transport tls = do
|
||||
|
@ -119,10 +113,10 @@ runClient transport tls = do
|
|||
Right session -> runTLS session tls
|
||||
|
||||
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
||||
newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
|
||||
newSession transport end = F.alloca $ \sPtr -> runExceptT $ do
|
||||
globalInit
|
||||
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
||||
when (rc < 0) $ E.throwError $ mapError rc
|
||||
when (rc < 0) $ throwE $ mapError rc
|
||||
liftIO $ do
|
||||
ptr <- F.peek sPtr
|
||||
let session = F.Session ptr
|
||||
|
@ -140,7 +134,7 @@ newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
|
|||
return (Session fp creds)
|
||||
|
||||
getSession :: TLS Session
|
||||
getSession = TLS R.ask
|
||||
getSession = TLS $ lift R.ask
|
||||
|
||||
handshake :: TLS ()
|
||||
handshake = withSession F.gnutls_handshake >>= checkRC
|
||||
|
@ -154,7 +148,7 @@ putBytes = putChunks . BL.toChunks where
|
|||
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||
case maybeErr of
|
||||
Nothing -> return ()
|
||||
Just err -> E.throwError $ mapError $ fromIntegral err
|
||||
Just err -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral err
|
||||
|
||||
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
||||
loop ptr len = do
|
||||
|
@ -182,7 +176,7 @@ getBytes count = do
|
|||
|
||||
case mbytes of
|
||||
Just bytes -> return bytes
|
||||
Nothing -> E.throwError $ mapError $ fromIntegral len
|
||||
Nothing -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral len
|
||||
|
||||
checkPending :: TLS Integer
|
||||
checkPending = withSession $ \s -> do
|
||||
|
@ -245,7 +239,7 @@ withSession io = do
|
|||
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||
|
||||
checkRC :: F.ReturnCode -> TLS ()
|
||||
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwError $ mapError x
|
||||
checkRC (F.ReturnCode x) = when (x < 0) $ TLS $ mapExceptT lift $ throwE $ mapError x
|
||||
|
||||
mapError :: F.CInt -> Error
|
||||
mapError = Error . toInteger
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
||||
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU General Public License as published by
|
||||
-- the Free Software Foundation, either version 3 of the License, or
|
||||
-- any later version.
|
||||
--
|
||||
-- This program is distributed in the hope that it will be useful,
|
||||
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
-- GNU General Public License for more details.
|
||||
--
|
||||
-- You should have received a copy of the GNU General Public License
|
||||
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
module Network.Protocol.TLS.GNU.ErrorT
|
||||
( ErrorT (..)
|
||||
, mapErrorT
|
||||
) where
|
||||
|
||||
import Control.Applicative (Applicative, pure, (<*>))
|
||||
import Control.Monad (ap,liftM)
|
||||
import Control.Monad.Trans (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Class (MonadTrans, lift)
|
||||
import qualified Control.Monad.Error as E
|
||||
import Control.Monad.Error (ErrorType)
|
||||
import qualified Control.Monad.Reader as R
|
||||
import Control.Monad.Reader (EnvType)
|
||||
|
||||
-- A custom version of ErrorT, without the 'Error' class restriction.
|
||||
|
||||
newtype ErrorT e m a = ErrorT { runErrorT :: m (Either e a) }
|
||||
|
||||
instance Functor m => Functor (ErrorT e m) where
|
||||
fmap f = ErrorT . fmap (fmap f) . runErrorT
|
||||
|
||||
instance (Functor m, Monad m) => Applicative (ErrorT e m) where
|
||||
pure a = ErrorT $ return (Right a)
|
||||
f <*> v = ErrorT $ do
|
||||
mf <- runErrorT f
|
||||
case mf of
|
||||
Left e -> return (Left e)
|
||||
Right k -> do
|
||||
mv <- runErrorT v
|
||||
case mv of
|
||||
Left e -> return (Left e)
|
||||
Right x -> return (Right (k x))
|
||||
|
||||
instance Monad m => Monad (ErrorT e m) where
|
||||
return = ErrorT . return . Right
|
||||
(>>=) m k = ErrorT $ do
|
||||
x <- runErrorT m
|
||||
case x of
|
||||
Left l -> return $ Left l
|
||||
Right r -> runErrorT $ k r
|
||||
|
||||
instance Monad m => E.MonadError (ErrorT e m) where
|
||||
type ErrorType (ErrorT e m) = e
|
||||
throwError = ErrorT . return . Left
|
||||
catchError m h = ErrorT $ do
|
||||
x <- runErrorT m
|
||||
case x of
|
||||
Left l -> runErrorT $ h l
|
||||
Right r -> return $ Right r
|
||||
|
||||
instance MonadTrans (ErrorT e) where
|
||||
lift = ErrorT . liftM Right
|
||||
|
||||
instance R.MonadReader m => R.MonadReader (ErrorT e m) where
|
||||
type EnvType (ErrorT e m) = EnvType m
|
||||
ask = lift R.ask
|
||||
local = mapErrorT . R.local
|
||||
|
||||
instance MonadIO m => MonadIO (ErrorT e m) where
|
||||
liftIO = lift . liftIO
|
||||
|
||||
mapErrorT :: (m (Either e a) -> n (Either e' b))
|
||||
-> ErrorT e m a
|
||||
-> ErrorT e' n b
|
||||
mapErrorT f m = ErrorT $ f (runErrorT m)
|
Loading…
Reference in a new issue