0d7a60a1cc
GnuTLS has separate initialization and deinitialization procedures for global and per-session state. Previously, haskell-gnutls used Haskell's garbage collector (via ForeignPtr) to manage these separate states by creating a dummy GlobalState type representing an initialized global state. The Session type contained ForeignPtrs to the global and session state, with the idea that GC would collect them both at the same time (albeit in non-determinstic order). It turns out that session deinitialization *requires* an initialized global state, and calling gnutls_deinit() after gnutls_global_deinit() can cause a crash. This patch solves the crash by removing the GlobalState ForeignPtr hack, and ensuring that gnutls_global_deinit() is always called after gnutls_deinit(). Originally reported by Keven McKenzie and Joey Hess.
269 lines
7.9 KiB
Haskell
269 lines
7.9 KiB
Haskell
{-# LANGUAGE TypeFamilies #-}
|
|
|
|
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
|
|
--
|
|
-- This program is free software: you can redistribute it and/or modify
|
|
-- it under the terms of the GNU General Public License as published by
|
|
-- the Free Software Foundation, either version 3 of the License, or
|
|
-- any later version.
|
|
--
|
|
-- This program is distributed in the hope that it will be useful,
|
|
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
-- GNU General Public License for more details.
|
|
--
|
|
-- You should have received a copy of the GNU General Public License
|
|
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
module Network.Protocol.TLS.GNU
|
|
( TLS
|
|
, Session
|
|
, Error (..)
|
|
|
|
, runTLS
|
|
, runClient
|
|
, getSession
|
|
, handshake
|
|
, rehandshake
|
|
, putBytes
|
|
, getBytes
|
|
, checkPending
|
|
|
|
-- * Settings
|
|
, Transport (..)
|
|
, handleTransport
|
|
|
|
, Credentials
|
|
, setCredentials
|
|
, certificateCredentials
|
|
|
|
, Prioritised
|
|
, setPriority
|
|
, CertificateType (..)
|
|
) where
|
|
|
|
import qualified Control.Concurrent.MVar as M
|
|
import Control.Monad (when, foldM, foldM_)
|
|
import qualified Control.Monad.Error as E
|
|
import Control.Monad.Error (ErrorType)
|
|
import qualified Control.Monad.Reader as R
|
|
import Control.Monad.Trans (MonadIO, liftIO)
|
|
import qualified Data.ByteString as B
|
|
import qualified Data.ByteString.Lazy as BL
|
|
import qualified Data.ByteString.Unsafe as B
|
|
import Data.IORef
|
|
import qualified Foreign as F
|
|
import qualified Foreign.C as F
|
|
import Foreign.Concurrent as FC
|
|
import qualified System.IO as IO
|
|
import System.IO.Unsafe (unsafePerformIO)
|
|
|
|
import Network.Protocol.TLS.GNU.ErrorT
|
|
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
|
|
|
data Error = Error Integer
|
|
deriving (Show)
|
|
|
|
globalInitMVar :: M.MVar ()
|
|
{-# NOINLINE globalInitMVar #-}
|
|
globalInitMVar = unsafePerformIO $ M.newMVar ()
|
|
|
|
globalInit :: ErrorT Error IO ()
|
|
globalInit = do
|
|
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
|
|
F.ReturnCode rc <- liftIO init_
|
|
when (rc < 0) $ E.throwError $ mapError rc
|
|
|
|
globalDeinit :: IO ()
|
|
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
|
|
|
|
data Session = Session
|
|
{ sessionPtr :: F.ForeignPtr F.Session
|
|
|
|
-- TLS credentials are not copied into the gnutls session struct,
|
|
-- so pointers to them must be kept alive until the credentials
|
|
-- are no longer needed.
|
|
--
|
|
-- TODO: Have some way to mark credentials as no longer needed.
|
|
-- The current code just keeps them alive for the duration
|
|
-- of the session, which may be excessive.
|
|
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
|
|
}
|
|
|
|
newtype TLS a = TLS { unTLS :: ErrorT Error (R.ReaderT Session IO) a }
|
|
|
|
instance Functor TLS where
|
|
fmap f = TLS . fmap f . unTLS
|
|
|
|
instance Monad TLS where
|
|
return = TLS . return
|
|
m >>= f = TLS $ unTLS m >>= unTLS . f
|
|
|
|
instance MonadIO TLS where
|
|
liftIO = TLS . liftIO
|
|
|
|
instance E.MonadError TLS where
|
|
type ErrorType TLS = Error
|
|
throwError = TLS . E.throwError
|
|
catchError m h = TLS $ E.catchError (unTLS m) (unTLS . h)
|
|
|
|
runTLS :: Session -> TLS a -> IO (Either Error a)
|
|
runTLS s tls = R.runReaderT (runErrorT (unTLS tls)) s
|
|
|
|
runClient :: Transport -> TLS a -> IO (Either Error a)
|
|
runClient transport tls = do
|
|
eitherSession <- newSession transport (F.ConnectionEnd 2)
|
|
case eitherSession of
|
|
Left err -> return (Left err)
|
|
Right session -> runTLS session tls
|
|
|
|
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
|
|
newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
|
|
globalInit
|
|
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
|
|
when (rc < 0) $ E.throwError $ mapError rc
|
|
liftIO $ do
|
|
ptr <- F.peek sPtr
|
|
let session = F.Session ptr
|
|
push <- F.wrapTransportFunc (pushImpl transport)
|
|
pull <- F.wrapTransportFunc (pullImpl transport)
|
|
F.gnutls_transport_set_push_function session push
|
|
F.gnutls_transport_set_pull_function session pull
|
|
_ <- F.gnutls_set_default_priority session
|
|
creds <- newIORef []
|
|
fp <- FC.newForeignPtr ptr $ do
|
|
F.gnutls_deinit session
|
|
globalDeinit
|
|
F.freeHaskellFunPtr push
|
|
F.freeHaskellFunPtr pull
|
|
return (Session fp creds)
|
|
|
|
getSession :: TLS Session
|
|
getSession = TLS R.ask
|
|
|
|
handshake :: TLS ()
|
|
handshake = withSession F.gnutls_handshake >>= checkRC
|
|
|
|
rehandshake :: TLS ()
|
|
rehandshake = withSession F.gnutls_rehandshake >>= checkRC
|
|
|
|
putBytes :: BL.ByteString -> TLS ()
|
|
putBytes = putChunks . BL.toChunks where
|
|
putChunks chunks = do
|
|
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
|
|
case maybeErr of
|
|
Nothing -> return ()
|
|
Just err -> E.throwError $ mapError $ fromIntegral err
|
|
|
|
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
|
|
loop ptr len = do
|
|
let len' = fromIntegral len
|
|
sent <- F.gnutls_record_send s ptr len'
|
|
let sent' = fromIntegral sent
|
|
case len - sent' of
|
|
0 -> return Nothing
|
|
x | x > 0 -> loop (F.plusPtr ptr sent') x
|
|
| otherwise -> return $ Just x
|
|
|
|
putChunk _ err _ = return err
|
|
|
|
getBytes :: Integer -> TLS BL.ByteString
|
|
getBytes count = do
|
|
(mbytes, len) <- withSession $ \s ->
|
|
F.allocaBytes (fromInteger count) $ \ptr -> do
|
|
len <- F.gnutls_record_recv s ptr (fromInteger count)
|
|
bytes <- if len >= 0
|
|
then do
|
|
chunk <- B.packCStringLen (ptr, fromIntegral len)
|
|
return $ Just $ BL.fromChunks [chunk]
|
|
else return Nothing
|
|
return (bytes, len)
|
|
|
|
case mbytes of
|
|
Just bytes -> return bytes
|
|
Nothing -> E.throwError $ mapError $ fromIntegral len
|
|
|
|
checkPending :: TLS Integer
|
|
checkPending = withSession $ \s -> do
|
|
pending <- F.gnutls_record_check_pending s
|
|
return $ toInteger pending
|
|
|
|
data Transport = Transport
|
|
{ transportPush :: BL.ByteString -> IO ()
|
|
, transportPull :: Integer -> IO BL.ByteString
|
|
}
|
|
|
|
pullImpl :: Transport -> F.TransportFunc
|
|
pullImpl t _ buf bufSize = do
|
|
bytes <- transportPull t $ toInteger bufSize
|
|
let loop ptr chunk =
|
|
B.unsafeUseAsCStringLen chunk $ \(cstr, len) -> do
|
|
F.copyArray (F.castPtr ptr) cstr len
|
|
return $ F.plusPtr ptr len
|
|
foldM_ loop buf $ BL.toChunks bytes
|
|
return $ fromIntegral $ BL.length bytes
|
|
|
|
pushImpl :: Transport -> F.TransportFunc
|
|
pushImpl t _ buf bufSize = do
|
|
let buf' = F.castPtr buf
|
|
bytes <- B.unsafePackCStringLen (buf', fromIntegral bufSize)
|
|
transportPush t $ BL.fromChunks [bytes]
|
|
return bufSize
|
|
|
|
handleTransport :: IO.Handle -> Transport
|
|
handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)
|
|
|
|
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
|
|
|
|
setCredentials :: Credentials -> TLS ()
|
|
setCredentials (Credentials ctype fp) = do
|
|
rc <- withSession $ \s ->
|
|
F.withForeignPtr fp $ \ptr -> do
|
|
F.gnutls_credentials_set s ctype ptr
|
|
|
|
s <- getSession
|
|
if F.unRC rc == 0
|
|
then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
|
|
else checkRC rc
|
|
|
|
certificateCredentials :: TLS Credentials
|
|
certificateCredentials = do
|
|
(rc, ptr) <- liftIO $ F.alloca $ \ptr -> do
|
|
rc <- F.gnutls_certificate_allocate_credentials ptr
|
|
ptr' <- if F.unRC rc < 0
|
|
then return F.nullPtr
|
|
else F.peek ptr
|
|
return (rc, ptr')
|
|
checkRC rc
|
|
fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
|
|
return $ Credentials (F.CredentialsType 1) fp
|
|
|
|
class Prioritised a where
|
|
priorityInt :: a -> F.CInt
|
|
priorityProc :: a -> F.Session -> F.Ptr F.CInt -> IO F.ReturnCode
|
|
|
|
data CertificateType = X509 | OpenPGP
|
|
deriving (Show)
|
|
|
|
instance Prioritised CertificateType where
|
|
priorityProc = const F.gnutls_certificate_type_set_priority
|
|
priorityInt x = case x of
|
|
X509 -> 1
|
|
OpenPGP -> 2
|
|
|
|
setPriority :: Prioritised a => [a] -> TLS ()
|
|
setPriority xs = do
|
|
let fake = head $ [undefined] ++ xs
|
|
rc <- withSession $ F.withArray0 0 (map priorityInt xs) . priorityProc fake
|
|
checkRC rc
|
|
|
|
withSession :: (F.Session -> IO a) -> TLS a
|
|
withSession io = do
|
|
s <- getSession
|
|
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
|
|
|
|
checkRC :: F.ReturnCode -> TLS ()
|
|
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwError $ mapError x
|
|
|
|
mapError :: F.CInt -> Error
|
|
mapError = Error . toInteger
|