Switch from monads-tf and custom transformer to ExceptT
When this code was written ExceptT didn't exist yet, but there's no reason to vendor a custom implementation of it any longer. We're taking very little advantage of the monads-tf features so just remove that dependency (and the language extension reliance that goes with it).
This commit is contained in:
parent
b32c6617ed
commit
2882576126
3 changed files with 15 additions and 105 deletions
|
@ -35,8 +35,7 @@ library
|
||||||
build-depends:
|
build-depends:
|
||||||
base >= 4.0 && < 5.0
|
base >= 4.0 && < 5.0
|
||||||
, bytestring >= 0.9
|
, bytestring >= 0.9
|
||||||
, transformers >= 0.2
|
, transformers >= 0.4.0.0
|
||||||
, monads-tf >= 0.1 && < 0.2
|
|
||||||
|
|
||||||
extra-libraries: gnutls
|
extra-libraries: gnutls
|
||||||
pkgconfig-depends: gnutls
|
pkgconfig-depends: gnutls
|
||||||
|
@ -45,5 +44,4 @@ library
|
||||||
Network.Protocol.TLS.GNU
|
Network.Protocol.TLS.GNU
|
||||||
|
|
||||||
other-modules:
|
other-modules:
|
||||||
Network.Protocol.TLS.GNU.ErrorT
|
|
||||||
Network.Protocol.TLS.GNU.Foreign
|
Network.Protocol.TLS.GNU.Foreign
|
||||||
|
|
|
@ -41,10 +41,10 @@ module Network.Protocol.TLS.GNU
|
||||||
import Control.Applicative (Applicative, pure, (<*>))
|
import Control.Applicative (Applicative, pure, (<*>))
|
||||||
import qualified Control.Concurrent.MVar as M
|
import qualified Control.Concurrent.MVar as M
|
||||||
import Control.Monad (ap, when, foldM, foldM_)
|
import Control.Monad (ap, when, foldM, foldM_)
|
||||||
import qualified Control.Monad.Error as E
|
import Control.Monad.Trans.Class (lift)
|
||||||
import Control.Monad.Error (ErrorType)
|
import Control.Monad.Trans.Except
|
||||||
import qualified Control.Monad.Reader as R
|
import qualified Control.Monad.Trans.Reader as R
|
||||||
import Control.Monad.Trans (MonadIO, liftIO)
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
import qualified Data.ByteString.Unsafe as B
|
import qualified Data.ByteString.Unsafe as B
|
||||||
|
@ -55,7 +55,6 @@ import Foreign.Concurrent as FC
|
||||||
import qualified System.IO as IO
|
import qualified System.IO as IO
|
||||||
import System.IO.Unsafe (unsafePerformIO)
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
|
|
||||||
import Network.Protocol.TLS.GNU.ErrorT
|
|
||||||
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
||||||
|
|
||||||
data Error = Error Integer
|
data Error = Error Integer
|
||||||
|
@ -65,11 +64,11 @@ globalInitMVar :: M.MVar ()
|
||||||
{-# NOINLINE globalInitMVar #-}
|
{-# NOINLINE globalInitMVar #-}
|
||||||
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
||||||
|
|
||||||
globalInit :: ErrorT Error IO ()
|
globalInit :: ExceptT Error IO ()
|
||||||
globalInit = do
|
globalInit = do
|
||||||
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
||||||
F.ReturnCode rc <- liftIO init_
|
F.ReturnCode rc <- liftIO init_
|
||||||
when (rc < 0) $ E.throwError $ mapError rc
|
when (rc < 0) $ throwE $ mapError rc
|
||||||
|
|
||||||
globalDeinit :: IO ()
|
globalDeinit :: IO ()
|
||||||
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
|
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
|
||||||
|
@ -87,7 +86,7 @@ data Session = Session
|
||||||
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
||||||
}
|
}
|
||||||
|
|
||||||
newtype TLS a = TLS { unTLS :: ErrorT Error (R.ReaderT Session IO) a }
|
newtype TLS a = TLS { unTLS :: ExceptT Error (R.ReaderT Session IO) a }
|
||||||
|
|
||||||
instance Functor TLS where
|
instance Functor TLS where
|
||||||
fmap f = TLS . fmap f . unTLS
|
fmap f = TLS . fmap f . unTLS
|
||||||
|
@ -103,13 +102,8 @@ instance Monad TLS where
|
||||||
instance MonadIO TLS where
|
instance MonadIO TLS where
|
||||||
liftIO = TLS . liftIO
|
liftIO = TLS . liftIO
|
||||||
|
|
||||||
instance E.MonadError TLS where
|
|
||||||
type ErrorType TLS = Error
|
|
||||||
throwError = TLS . E.throwError
|
|
||||||
catchError m h = TLS $ E.catchError (unTLS m) (unTLS . h)
|
|
||||||
|
|
||||||
runTLS :: Session -> TLS a -> IO (Either Error a)
|
runTLS :: Session -> TLS a -> IO (Either Error a)
|
||||||
runTLS s tls = R.runReaderT (runErrorT (unTLS tls)) s
|
runTLS s tls = R.runReaderT (runExceptT (unTLS tls)) s
|
||||||
|
|
||||||
runClient :: Transport -> TLS a -> IO (Either Error a)
|
runClient :: Transport -> TLS a -> IO (Either Error a)
|
||||||
runClient transport tls = do
|
runClient transport tls = do
|
||||||
|
@ -119,10 +113,10 @@ runClient transport tls = do
|
||||||
Right session -> runTLS session tls
|
Right session -> runTLS session tls
|
||||||
|
|
||||||
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
||||||
newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
|
newSession transport end = F.alloca $ \sPtr -> runExceptT $ do
|
||||||
globalInit
|
globalInit
|
||||||
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
||||||
when (rc < 0) $ E.throwError $ mapError rc
|
when (rc < 0) $ throwE $ mapError rc
|
||||||
liftIO $ do
|
liftIO $ do
|
||||||
ptr <- F.peek sPtr
|
ptr <- F.peek sPtr
|
||||||
let session = F.Session ptr
|
let session = F.Session ptr
|
||||||
|
@ -140,7 +134,7 @@ newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
|
||||||
return (Session fp creds)
|
return (Session fp creds)
|
||||||
|
|
||||||
getSession :: TLS Session
|
getSession :: TLS Session
|
||||||
getSession = TLS R.ask
|
getSession = TLS $ lift R.ask
|
||||||
|
|
||||||
handshake :: TLS ()
|
handshake :: TLS ()
|
||||||
handshake = withSession F.gnutls_handshake >>= checkRC
|
handshake = withSession F.gnutls_handshake >>= checkRC
|
||||||
|
@ -154,7 +148,7 @@ putBytes = putChunks . BL.toChunks where
|
||||||
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
||||||
case maybeErr of
|
case maybeErr of
|
||||||
Nothing -> return ()
|
Nothing -> return ()
|
||||||
Just err -> E.throwError $ mapError $ fromIntegral err
|
Just err -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral err
|
||||||
|
|
||||||
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
||||||
loop ptr len = do
|
loop ptr len = do
|
||||||
|
@ -182,7 +176,7 @@ getBytes count = do
|
||||||
|
|
||||||
case mbytes of
|
case mbytes of
|
||||||
Just bytes -> return bytes
|
Just bytes -> return bytes
|
||||||
Nothing -> E.throwError $ mapError $ fromIntegral len
|
Nothing -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral len
|
||||||
|
|
||||||
checkPending :: TLS Integer
|
checkPending :: TLS Integer
|
||||||
checkPending = withSession $ \s -> do
|
checkPending = withSession $ \s -> do
|
||||||
|
@ -245,7 +239,7 @@ withSession io = do
|
||||||
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
||||||
|
|
||||||
checkRC :: F.ReturnCode -> TLS ()
|
checkRC :: F.ReturnCode -> TLS ()
|
||||||
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwError $ mapError x
|
checkRC (F.ReturnCode x) = when (x < 0) $ TLS $ mapExceptT lift $ throwE $ mapError x
|
||||||
|
|
||||||
mapError :: F.CInt -> Error
|
mapError :: F.CInt -> Error
|
||||||
mapError = Error . toInteger
|
mapError = Error . toInteger
|
||||||
|
|
|
@ -1,82 +0,0 @@
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
|
||||||
|
|
||||||
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
|
|
||||||
--
|
|
||||||
-- This program is free software: you can redistribute it and/or modify
|
|
||||||
-- it under the terms of the GNU General Public License as published by
|
|
||||||
-- the Free Software Foundation, either version 3 of the License, or
|
|
||||||
-- any later version.
|
|
||||||
--
|
|
||||||
-- This program is distributed in the hope that it will be useful,
|
|
||||||
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
-- GNU General Public License for more details.
|
|
||||||
--
|
|
||||||
-- You should have received a copy of the GNU General Public License
|
|
||||||
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
module Network.Protocol.TLS.GNU.ErrorT
|
|
||||||
( ErrorT (..)
|
|
||||||
, mapErrorT
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Control.Applicative (Applicative, pure, (<*>))
|
|
||||||
import Control.Monad (ap,liftM)
|
|
||||||
import Control.Monad.Trans (MonadIO, liftIO)
|
|
||||||
import Control.Monad.Trans.Class (MonadTrans, lift)
|
|
||||||
import qualified Control.Monad.Error as E
|
|
||||||
import Control.Monad.Error (ErrorType)
|
|
||||||
import qualified Control.Monad.Reader as R
|
|
||||||
import Control.Monad.Reader (EnvType)
|
|
||||||
|
|
||||||
-- A custom version of ErrorT, without the 'Error' class restriction.
|
|
||||||
|
|
||||||
newtype ErrorT e m a = ErrorT { runErrorT :: m (Either e a) }
|
|
||||||
|
|
||||||
instance Functor m => Functor (ErrorT e m) where
|
|
||||||
fmap f = ErrorT . fmap (fmap f) . runErrorT
|
|
||||||
|
|
||||||
instance (Functor m, Monad m) => Applicative (ErrorT e m) where
|
|
||||||
pure a = ErrorT $ return (Right a)
|
|
||||||
f <*> v = ErrorT $ do
|
|
||||||
mf <- runErrorT f
|
|
||||||
case mf of
|
|
||||||
Left e -> return (Left e)
|
|
||||||
Right k -> do
|
|
||||||
mv <- runErrorT v
|
|
||||||
case mv of
|
|
||||||
Left e -> return (Left e)
|
|
||||||
Right x -> return (Right (k x))
|
|
||||||
|
|
||||||
instance Monad m => Monad (ErrorT e m) where
|
|
||||||
return = ErrorT . return . Right
|
|
||||||
(>>=) m k = ErrorT $ do
|
|
||||||
x <- runErrorT m
|
|
||||||
case x of
|
|
||||||
Left l -> return $ Left l
|
|
||||||
Right r -> runErrorT $ k r
|
|
||||||
|
|
||||||
instance Monad m => E.MonadError (ErrorT e m) where
|
|
||||||
type ErrorType (ErrorT e m) = e
|
|
||||||
throwError = ErrorT . return . Left
|
|
||||||
catchError m h = ErrorT $ do
|
|
||||||
x <- runErrorT m
|
|
||||||
case x of
|
|
||||||
Left l -> runErrorT $ h l
|
|
||||||
Right r -> return $ Right r
|
|
||||||
|
|
||||||
instance MonadTrans (ErrorT e) where
|
|
||||||
lift = ErrorT . liftM Right
|
|
||||||
|
|
||||||
instance R.MonadReader m => R.MonadReader (ErrorT e m) where
|
|
||||||
type EnvType (ErrorT e m) = EnvType m
|
|
||||||
ask = lift R.ask
|
|
||||||
local = mapErrorT . R.local
|
|
||||||
|
|
||||||
instance MonadIO m => MonadIO (ErrorT e m) where
|
|
||||||
liftIO = lift . liftIO
|
|
||||||
|
|
||||||
mapErrorT :: (m (Either e a) -> n (Either e' b))
|
|
||||||
-> ErrorT e m a
|
|
||||||
-> ErrorT e' n b
|
|
||||||
mapErrorT f m = ErrorT $ f (runErrorT m)
|
|
Loading…
Reference in a new issue