Clean up the Cabal file, and move library source to lib/.

This commit is contained in:
John Millikin 2012-02-23 18:30:29 -08:00
parent f0f4eef863
commit c36fdda7d8
No known key found for this signature in database
GPG key ID: 59A38F85F9C7C59E
4 changed files with 17 additions and 5 deletions

View file

@ -0,0 +1,252 @@
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
{-# LANGUAGE TypeFamilies #-}
module Network.Protocol.TLS.GNU
( TLS
, Session
, Error (..)
, runTLS
, runClient
, getSession
, handshake
, putBytes
, getBytes
, checkPending
-- * Settings
, Transport (..)
, handleTransport
, Credentials
, setCredentials
, certificateCredentials
, Prioritised
, setPriority
, CertificateType (..)
) where
import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans (MonadIO, liftIO)
import qualified Control.Monad.Error as E
import Control.Monad.Error (ErrorType)
import qualified Control.Monad.Reader as R
import qualified Control.Concurrent.MVar as M
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Lazy as BL
import qualified System.IO as IO
import qualified Foreign as F
import qualified Foreign.C as F
import qualified Network.Protocol.TLS.GNU.Foreign as F
import Foreign.Concurrent as FC
import Network.Protocol.TLS.GNU.ErrorT
import System.IO.Unsafe (unsafePerformIO)
data Error = Error Integer
deriving (Show)
globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()
newtype GlobalState = GlobalState (F.ForeignPtr ())
globalInit :: ErrorT Error IO GlobalState
globalInit = do
let init = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
let deinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
F.ReturnCode rc <- liftIO $ init
when (rc < 0) $ E.throwError $ mapError rc
fp <- liftIO $ FC.newForeignPtr F.nullPtr deinit
return $ GlobalState fp
data Session = Session
{ sessionPtr :: F.ForeignPtr F.Session
, sessionGlobalState :: GlobalState
}
newtype TLS a = TLS { unTLS :: ErrorT Error (R.ReaderT Session IO) a }
instance Functor TLS where
fmap f = TLS . fmap f . unTLS
instance Monad TLS where
return = TLS . return
m >>= f = TLS $ unTLS m >>= unTLS . f
instance MonadIO TLS where
liftIO = TLS . liftIO
instance E.MonadError TLS where
type ErrorType TLS = Error
throwError = TLS . E.throwError
catchError m h = TLS $ E.catchError (unTLS m) (unTLS . h)
runTLS :: Session -> TLS a -> IO (Either Error a)
runTLS s tls = R.runReaderT (runErrorT (unTLS tls)) s
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do
eitherSession <- newSession transport (F.ConnectionEnd 2)
case eitherSession of
Left err -> return (Left err)
Right session -> runTLS session tls
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
global <- globalInit
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
when (rc < 0) $ E.throwError $ mapError rc
liftIO $ do
ptr <- F.peek sPtr
let session = F.Session ptr
push <- F.wrapTransportFunc (pushImpl transport)
pull <- F.wrapTransportFunc (pullImpl transport)
F.gnutls_transport_set_push_function session push
F.gnutls_transport_set_pull_function session pull
F.gnutls_set_default_priority session
fp <- FC.newForeignPtr ptr $ do
F.gnutls_deinit session
F.freeHaskellFunPtr push
F.freeHaskellFunPtr pull
return $ Session fp global
getSession :: TLS Session
getSession = TLS R.ask
handshake :: TLS ()
handshake = withSession F.gnutls_handshake >>= checkRC
rehandshake :: TLS ()
rehandshake = withSession F.gnutls_rehandshake >>= checkRC
putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks where
putChunks chunks = do
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
Just err -> E.throwError $ mapError $ fromIntegral err
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
loop ptr len = do
let len' = fromIntegral len
sent <- F.gnutls_record_send s ptr len'
let sent' = fromIntegral sent
case len - sent' of
0 -> return Nothing
x | x > 0 -> loop (F.plusPtr ptr sent') x
| otherwise -> return $ Just x
putChunk _ err _ = return err
getBytes :: Integer -> TLS BL.ByteString
getBytes count = do
(bytes, len) <- withSession $ \s ->
F.allocaBytes (fromInteger count) $ \ptr -> do
len <- F.gnutls_record_recv s ptr (fromInteger count)
bytes <- if len >= 0
then do
chunk <- B.packCStringLen (ptr, fromIntegral len)
return $ Just $ BL.fromChunks [chunk]
else return Nothing
return (bytes, len)
case bytes of
Just bytes -> return bytes
Nothing -> E.throwError $ mapError $ fromIntegral len
checkPending :: TLS Integer
checkPending = withSession $ \s -> do
pending <- F.gnutls_record_check_pending s
return $ toInteger pending
data Transport = Transport
{ transportPush :: BL.ByteString -> IO ()
, transportPull :: Integer -> IO BL.ByteString
}
pullImpl :: Transport -> F.TransportFunc
pullImpl t _ buf bufSize = do
bytes <- transportPull t $ toInteger bufSize
let loop ptr chunk =
B.unsafeUseAsCStringLen chunk $ \(cstr, len) -> do
F.copyArray (F.castPtr ptr) cstr len
return $ F.plusPtr ptr len
foldM_ loop buf $ BL.toChunks bytes
return $ fromIntegral $ BL.length bytes
pushImpl :: Transport -> F.TransportFunc
pushImpl t _ buf bufSize = do
let buf' = F.castPtr buf
bytes <- B.unsafePackCStringLen (buf', fromIntegral bufSize)
transportPush t $ BL.fromChunks [bytes]
return bufSize
handleTransport :: IO.Handle -> Transport
handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do
rc <- withSession $ \s ->
F.withForeignPtr fp $ \ptr -> do
F.gnutls_credentials_set s ctype ptr
checkRC rc
certificateCredentials :: TLS Credentials
certificateCredentials = do
(rc, ptr) <- liftIO $ F.alloca $ \ptr -> do
rc <- F.gnutls_certificate_allocate_credentials ptr
ptr <- if F.unRC rc < 0
then return F.nullPtr
else F.peek ptr
return (rc, ptr)
checkRC rc
fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
return $ Credentials (F.CredentialsType 1) fp
class Prioritised a where
priorityInt :: a -> F.CInt
priorityProc :: a -> F.Session -> F.Ptr F.CInt -> IO F.ReturnCode
data CertificateType = X509 | OpenPGP
deriving (Show)
instance Prioritised CertificateType where
priorityProc = const F.gnutls_certificate_type_set_priority
priorityInt x = case x of
X509 -> 1
OpenPGP -> 2
setPriority :: Prioritised a => [a] -> TLS ()
setPriority xs = do
let fake = head $ [undefined] ++ xs
rc <- withSession $ F.withArray0 0 (map priorityInt xs) . priorityProc fake
checkRC rc
withSession :: (F.Session -> IO a) -> TLS a
withSession io = do
s <- getSession
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwError $ mapError x
mapError :: F.CInt -> Error
mapError = Error . toInteger

View file

@ -0,0 +1,68 @@
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
{-# LANGUAGE TypeFamilies #-}
module Network.Protocol.TLS.GNU.ErrorT
( ErrorT (..)
, mapErrorT
) where
import Control.Monad (liftM)
import Control.Monad.Trans (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
import qualified Control.Monad.Error as E
import Control.Monad.Error (ErrorType)
import qualified Control.Monad.Reader as R
import Control.Monad.Reader (EnvType)
-- A custom version of ErrorT, without the 'Error' class restriction.
newtype ErrorT e m a = ErrorT { runErrorT :: m (Either e a) }
instance Functor m => Functor (ErrorT e m) where
fmap f = ErrorT . fmap (fmap f) . runErrorT
instance Monad m => Monad (ErrorT e m) where
return = ErrorT . return . Right
(>>=) m k = ErrorT $ do
x <- runErrorT m
case x of
Left l -> return $ Left l
Right r -> runErrorT $ k r
instance Monad m => E.MonadError (ErrorT e m) where
type ErrorType (ErrorT e m) = e
throwError = ErrorT . return . Left
catchError m h = ErrorT $ do
x <- runErrorT m
case x of
Left l -> runErrorT $ h l
Right r -> return $ Right r
instance MonadTrans (ErrorT e) where
lift = ErrorT . liftM Right
instance R.MonadReader m => R.MonadReader (ErrorT e m) where
type EnvType (ErrorT e m) = EnvType m
ask = lift R.ask
local = mapErrorT . R.local
instance MonadIO m => MonadIO (ErrorT e m) where
liftIO = lift . liftIO
mapErrorT :: (m (Either e a) -> n (Either e' b))
-> ErrorT e m a
-> ErrorT e' n b
mapErrorT f m = ErrorT $ f (runErrorT m)

View file

@ -0,0 +1,283 @@
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
{-# LANGUAGE ForeignFunctionInterface #-}
module Network.Protocol.TLS.GNU.Foreign where
import Foreign
import Foreign.C
-- Type aliases {{{
newtype ReturnCode = ReturnCode { unRC :: CInt }
deriving (Show, Eq)
newtype CipherAlgorithm = CipherAlgorithm CInt
deriving (Show, Eq)
newtype KXAlgorithm = KXAlgorithm CInt
deriving (Show, Eq)
newtype ParamsType = ParamsType CInt
deriving (Show, Eq)
newtype CredentialsType = CredentialsType CInt
deriving (Show, Eq)
newtype MACAlgorithm = MACAlgorithm CInt
deriving (Show, Eq)
newtype DigestAlgorithm = DigestAlgorithm CInt
deriving (Show, Eq)
newtype CompressionMethod = CompressionMethod CInt
deriving (Show, Eq)
newtype ConnectionEnd = ConnectionEnd CInt
deriving (Show, Eq)
newtype AlertLevel = AlertLevel CInt
deriving (Show, Eq)
newtype AlertDescription = AlertDescription CInt
deriving (Show, Eq)
newtype HandshakeDescription = HandshakeDescription CInt
deriving (Show, Eq)
newtype CertificateStatus = CertificateStatus CInt
deriving (Show, Eq)
newtype CertificateRequest = CertificateRequest CInt
deriving (Show, Eq)
newtype OpenPGPCrtStatus = OpenPGPCrtStatus CInt
deriving (Show, Eq)
newtype CloseRequest = CloseRequest CInt
deriving (Show, Eq)
newtype Protocol = Protocol CInt
deriving (Show, Eq)
newtype CertificateType = CertificateType CInt
deriving (Show, Eq)
newtype X509CrtFormat = X509CrtFormat CInt
deriving (Show, Eq)
newtype CertificatePrintFormats = CertificatePrintFormats CInt
deriving (Show, Eq)
newtype PKAlgorithm = PKAlgorithm CInt
deriving (Show, Eq)
newtype SignAlgorithm = SignAlgorithm CInt
deriving (Show, Eq)
newtype Credentials = Credentials (Ptr Credentials)
newtype Transport = Transport (Ptr Transport)
newtype Session = Session (Ptr Session)
newtype DHParams = DHParams (Ptr DHParams)
newtype RSAParams = RSAParams (Ptr RSAParams)
newtype Priority = Priority (Ptr Priority)
newtype Datum = Datum (Ptr Word8, CUInt)
-- }}}
-- Global library info / state {{{
foreign import ccall safe "gnutls_check_version"
gnutls_check_version :: CString -> IO CString
foreign import ccall safe "gnutls_extra_check_version"
gnutls_extra_check_version :: CString -> IO CString
foreign import ccall safe "gnutls_global_init"
gnutls_global_init :: IO ReturnCode
foreign import ccall safe "gnutls_global_init_extra"
gnutls_global_init_extra :: IO ReturnCode
foreign import ccall safe "gnutls_global_deinit"
gnutls_global_deinit :: IO ()
foreign import ccall safe "gnutls_global_set_log_function"
gnutls_global_set_log_function :: FunPtr (CInt -> CString -> IO ()) -> IO ()
foreign import ccall safe "gnutls_global_set_log_level"
gnutls_global_set_log_level :: CInt -> IO ()
-- }}}
-- Error handling {{{
foreign import ccall safe "gnutls_error_is_fatal"
gnutls_error_is_fatal :: ReturnCode -> IO CInt
foreign import ccall safe "gnutls_perror"
gnutls_perror :: ReturnCode -> IO ()
foreign import ccall safe "gnutls_strerror"
gnutls_strerror :: ReturnCode -> IO CString
foreign import ccall safe "gnutls_strerror_name"
gnutls_strerror_name :: ReturnCode -> IO CString
-- }}}
-- Sessions {{{
foreign import ccall safe "gnutls_init"
gnutls_init :: Ptr (Ptr Session) -> ConnectionEnd -> IO ReturnCode
foreign import ccall safe "gnutls_deinit"
gnutls_deinit :: Session -> IO ()
foreign import ccall safe "gnutls_handshake"
gnutls_handshake :: Session -> IO ReturnCode
foreign import ccall safe "gnutls_rehandshake"
gnutls_rehandshake :: Session -> IO ReturnCode
foreign import ccall safe "gnutls_bye"
gnutls_bye :: Session -> CloseRequest -> IO ReturnCode
foreign import ccall safe "gnutls_set_default_priority"
gnutls_set_default_priority :: Session -> IO ReturnCode
-- }}}
-- Alerts {{{
foreign import ccall safe "gnutls_alert_get_name"
gnutls_alert_get_name :: AlertDescription -> IO CString
foreign import ccall safe "gnutls_error_to_alert"
gnutls_error_to_alert :: ReturnCode -> Ptr AlertLevel -> IO AlertDescription
foreign import ccall safe "gnutls_alert_get"
gnutls_alert_get :: Session -> IO AlertDescription
foreign import ccall safe "gnutls_alert_send_appropriate"
gnutls_alert_send_appropriate :: Session -> ReturnCode -> IO ReturnCode
foreign import ccall safe "gnutls_alert_send"
gnutls_alert_send :: Session -> AlertLevel -> AlertDescription -> IO ReturnCode
-- }}}
-- Certificates {{{
foreign import ccall safe "gnutls_certificate_allocate_credentials"
gnutls_certificate_allocate_credentials :: Ptr (Ptr Credentials) -> IO ReturnCode
foreign import ccall safe "&gnutls_certificate_free_credentials"
gnutls_certificate_free_credentials_funptr :: FunPtr (Ptr Credentials -> IO ())
foreign import ccall safe "gnutls_certificate_type_get_id"
gnutls_certificate_type_get_id :: CString -> IO CertificateType
foreign import ccall safe "gnutls_certificate_type_get_name"
gnutls_certificate_type_get_name :: CertificateType -> IO CString
foreign import ccall safe "gnutls_certificate_type_get"
gnutls_certificate_type_get :: Session -> IO CertificateType
foreign import ccall safe "gnutls_certificate_type_list"
gnutls_certificate_type_list :: IO (Ptr CertificateType)
foreign import ccall safe "gnutls_certificate_type_set_priority"
gnutls_certificate_type_set_priority :: Session -> Ptr CInt -> IO ReturnCode
-- }}}
-- Credentials {{{
foreign import ccall safe "gnutls_credentials_clear"
gnutls_credentials_clear :: Session -> IO ()
foreign import ccall safe "gnutls_credentials_set"
gnutls_credentials_set :: Session -> CredentialsType -> Ptr a -> IO ReturnCode
-- }}}
-- Records {{{
foreign import ccall safe "gnutls_record_check_pending"
gnutls_record_check_pending :: Session -> IO CSize
foreign import ccall safe "gnutls_record_disable_padding"
gnutls_record_disable_padding :: Session -> IO ()
foreign import ccall safe "gnutls_record_get_direction"
gnutls_record_get_direction :: Session -> IO CInt
foreign import ccall safe "gnutls_record_get_max_size"
gnutls_record_get_max_size :: Session -> IO CSize
foreign import ccall safe "gnutls_record_recv"
gnutls_record_recv :: Session -> Ptr a -> CSize -> IO CSize
foreign import ccall safe "gnutls_record_send"
gnutls_record_send :: Session -> Ptr a -> CSize -> IO CSize
foreign import ccall safe "gnutls_record_set_max_size"
gnutls_record_set_max_size :: Session -> CSize -> IO CSize
-- }}}
-- Transports {{{
type TransportFunc = Transport -> Ptr () -> CSize -> IO CSize
foreign import ccall safe "gnutls_transport_set_push_function"
gnutls_transport_set_push_function :: Session -> FunPtr TransportFunc -> IO ()
foreign import ccall safe "gnutls_transport_set_pull_function"
gnutls_transport_set_pull_function :: Session -> FunPtr TransportFunc -> IO ()
foreign import ccall "wrapper"
wrapTransportFunc :: TransportFunc -> IO (FunPtr TransportFunc)
-- }}}
-- Utility {{{
foreign import ccall safe "gnutls_global_set_mem_functions"
gnutls_global_set_mem_functions
:: FunPtr (CSize -> IO (Ptr ()))
-> FunPtr (CSize -> CSize -> IO (Ptr ()))
-> FunPtr (Ptr () -> IO CInt)
-> FunPtr (Ptr () -> CSize -> IO (Ptr ()))
-> FunPtr (Ptr () -> IO ())
-> IO ()
foreign import ccall safe "gnutls_malloc"
gnutls_malloc :: CSize -> IO (Ptr a)
foreign import ccall safe "gnutls_free"
gnutls_free :: Ptr a -> IO ()
foreign import ccall safe "gnutls_hex2bin"
gnutls_hex2bin :: CString -> CSize -> Ptr Word8 -> Ptr CSize -> IO ReturnCode
foreign import ccall safe "gnutls_hex_decode"
gnutls_hex_decode :: Ptr Datum -> Ptr Word8 -> Ptr CSize -> IO ReturnCode
foreign import ccall safe "gnutls_hex_encode"
gnutls_hex_encode :: Ptr Datum -> CString -> Ptr CSize -> IO ReturnCode
-- }}}