2012-05-13 17:28:58 -04:00
|
|
|
{-# LANGUAGE TypeFamilies #-}
|
|
|
|
|
2010-04-26 12:59:24 -04:00
|
|
|
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
|
2012-05-13 17:28:58 -04:00
|
|
|
--
|
2010-04-26 12:59:24 -04:00
|
|
|
-- This program is free software: you can redistribute it and/or modify
|
|
|
|
-- it under the terms of the GNU General Public License as published by
|
|
|
|
-- the Free Software Foundation, either version 3 of the License, or
|
|
|
|
-- any later version.
|
2012-05-13 17:28:58 -04:00
|
|
|
--
|
2010-04-26 12:59:24 -04:00
|
|
|
-- This program is distributed in the hope that it will be useful,
|
|
|
|
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
-- GNU General Public License for more details.
|
2012-05-13 17:28:58 -04:00
|
|
|
--
|
2010-04-26 12:59:24 -04:00
|
|
|
-- You should have received a copy of the GNU General Public License
|
|
|
|
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
|
|
|
|
module Network.Protocol.TLS.GNU
|
|
|
|
( TLS
|
|
|
|
, Session
|
|
|
|
, Error (..)
|
2021-02-13 20:31:35 -05:00
|
|
|
, throwE
|
|
|
|
, fromExceptT
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
, runTLS
|
2021-02-13 21:18:46 -05:00
|
|
|
, runTLS'
|
2010-04-26 12:59:24 -04:00
|
|
|
, runClient
|
|
|
|
, getSession
|
|
|
|
, handshake
|
2012-02-23 21:33:43 -05:00
|
|
|
, rehandshake
|
2010-04-26 12:59:24 -04:00
|
|
|
, putBytes
|
|
|
|
, getBytes
|
|
|
|
, checkPending
|
|
|
|
|
|
|
|
-- * Settings
|
|
|
|
, Transport (..)
|
|
|
|
, handleTransport
|
|
|
|
|
|
|
|
, Credentials
|
|
|
|
, setCredentials
|
|
|
|
, certificateCredentials
|
|
|
|
) where
|
2012-05-13 17:28:58 -04:00
|
|
|
|
|
|
|
import qualified Control.Concurrent.MVar as M
|
2021-02-13 20:59:00 -05:00
|
|
|
import Control.Monad (when, foldM, foldM_)
|
2021-02-13 20:07:24 -05:00
|
|
|
import Control.Monad.Trans.Class (lift)
|
2021-02-13 20:31:35 -05:00
|
|
|
import qualified Control.Monad.Trans.Except as E
|
2021-02-13 20:07:24 -05:00
|
|
|
import qualified Control.Monad.Trans.Reader as R
|
2021-02-13 20:59:00 -05:00
|
|
|
import Control.Monad.IO.Class (liftIO)
|
2010-04-26 12:59:24 -04:00
|
|
|
import qualified Data.ByteString as B
|
|
|
|
import qualified Data.ByteString.Lazy as BL
|
2012-05-13 17:28:58 -04:00
|
|
|
import qualified Data.ByteString.Unsafe as B
|
2012-10-26 22:01:52 -04:00
|
|
|
import Data.IORef
|
2010-04-26 12:59:24 -04:00
|
|
|
import qualified Foreign as F
|
|
|
|
import qualified Foreign.C as F
|
2012-05-13 17:28:58 -04:00
|
|
|
import Foreign.Concurrent as FC
|
|
|
|
import qualified System.IO as IO
|
|
|
|
import System.IO.Unsafe (unsafePerformIO)
|
2021-02-13 20:14:32 -05:00
|
|
|
import UnexceptionalIO.Trans (UIO, Unexceptional)
|
|
|
|
import qualified UnexceptionalIO.Trans as UIO
|
2012-05-13 17:28:58 -04:00
|
|
|
|
2010-04-26 12:59:24 -04:00
|
|
|
import qualified Network.Protocol.TLS.GNU.Foreign as F
|
|
|
|
|
2021-02-13 20:59:00 -05:00
|
|
|
-- | A failure reported by GnuTLS.
--
-- The wrapped value is the raw numeric return code of the failing library
-- call, as converted by 'mapError'.
data Error
  = Error Integer
  deriving (Show)
|
|
|
|
|
|
|
|
-- | Process-wide lock that serialises calls to @gnutls_global_init@ and
-- @gnutls_global_deinit@.
--
-- 'unsafePerformIO' is used only to create one top-level 'M.MVar'. The
-- @NOINLINE@ pragma stops GHC from inlining the definition, which could
-- otherwise create a separate MVar at each use site.
globalInitMVar :: M.MVar ()
globalInitMVar = unsafePerformIO (M.newMVar ())
{-# NOINLINE globalInitMVar #-}
|
|
|
|
|
2021-02-13 20:31:35 -05:00
|
|
|
-- | Call @gnutls_global_init@ while holding 'globalInitMVar'.
--
-- A negative return code is raised as the corresponding 'Error'.
globalInit :: E.ExceptT Error IO ()
globalInit = do
  F.ReturnCode code <-
    liftIO (M.withMVar globalInitMVar (const F.gnutls_global_init))
  when (code < 0) (E.throwE (mapError code))
|
2013-09-07 15:32:39 -04:00
|
|
|
|
|
|
|
-- | Call @gnutls_global_deinit@ while holding 'globalInitMVar'.
globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar (const F.gnutls_global_deinit)
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | A GnuTLS session, plus the credentials that must outlive it.
data Session = Session
  { sessionPtr :: F.ForeignPtr F.Session
    -- ^ Managed pointer to the underlying GnuTLS session.
  , sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
    -- ^ TLS credentials are not copied into the gnutls session struct,
    -- so pointers to them must be kept alive until the credentials are
    -- no longer needed.
    --
    -- TODO: Have some way to mark credentials as no longer needed. The
    -- current code just keeps them alive for the duration of the
    -- session, which may be excessive.
  }
|
|
|
|
|
2021-02-13 20:59:00 -05:00
|
|
|
-- | A computation that reads the current 'Session', can fail with an
-- 'Error', and runs its effects in 'UIO'.
type TLS a =
  E.ExceptT Error (R.ReaderT Session UIO) a
|
2021-02-13 20:31:35 -05:00
|
|
|
|
|
|
|
-- | Abort the current 'TLS' computation with the given 'Error'.
throwE :: Error -> TLS a
throwE = E.throwE
|
|
|
|
|
|
|
|
-- | Embed an 'E.ExceptT' computation over 'UIO' into 'TLS'.
--
-- The embedded computation does not see the session environment; its
-- result or error passes through unchanged.
fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT action = E.ExceptT (lift (E.runExceptT action))
|
2010-04-26 12:59:24 -04:00
|
|
|
|
2021-02-13 20:14:32 -05:00
|
|
|
-- | Run a 'TLS' computation against a session.
--
-- Returns either the first 'Error' raised or the final result.
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s tls = E.runExceptT (runTLS' s tls)
|
|
|
|
|
|
|
|
-- | Like 'runTLS', but keeps errors in an 'E.ExceptT' layer over any
-- 'Unexceptional' monad instead of returning an 'Either'.
runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a
runTLS' s tls = E.ExceptT (UIO.lift (R.runReaderT (E.runExceptT tls) s))
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | Create a new session over the given transport and run a 'TLS'
-- computation with it.
--
-- The session is created with @F.ConnectionEnd 2@, which presumably
-- corresponds to GnuTLS's client end (NOTE(review): confirm against the
-- Foreign bindings). If session creation fails, its 'Error' is returned
-- and the computation is never run.
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls =
  newSession transport (F.ConnectionEnd 2) >>= either (return . Left) runWith
  where
    runWith session = runTLS session tls
|
|
|
|
|
|
|
|
-- | Allocate and configure a new GnuTLS session for the given transport.
--
-- Steps:
--
-- 1. Take a reference on the GnuTLS global state ('globalInit').
-- 2. Call @gnutls_init@ for the requested connection end.
-- 3. Install push/pull callbacks that forward to the 'Transport'.
-- 4. Apply @gnutls_set_default_priority@ (its return code is discarded).
-- 5. Wrap the raw pointer in a 'F.ForeignPtr' whose finaliser runs
--    @gnutls_deinit@, releases the global reference ('globalDeinit'), and
--    frees both callback function pointers.
--
-- If @gnutls_init@ fails, the global reference from step 1 is released
-- before the 'Error' is returned, so a failed call does not leak it.
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
  globalInit
  F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
  when (rc < 0) $ do
    -- No session (and therefore no finaliser) exists yet, so the
    -- reference taken by 'globalInit' must be dropped here.
    liftIO globalDeinit
    E.throwE $ mapError rc
  liftIO $ do
    ptr <- F.peek sPtr
    let session = F.Session ptr
    push <- F.wrapTransportFunc (pushImpl transport)
    pull <- F.wrapTransportFunc (pullImpl transport)
    F.gnutls_transport_set_push_function session push
    F.gnutls_transport_set_pull_function session pull
    -- NOTE(review): the return code is discarded, so a failure to set the
    -- default priority goes unreported here.
    _ <- F.gnutls_set_default_priority session
    creds <- newIORef []
    fp <- FC.newForeignPtr ptr $ do
      F.gnutls_deinit session
      globalDeinit
      F.freeHaskellFunPtr push
      F.freeHaskellFunPtr pull
    return (Session fp creds)
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | The 'Session' that the current 'TLS' computation is running against.
getSession :: TLS Session
getSession = lift (R.asks id)
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | Perform the TLS handshake with @gnutls_handshake@.
--
-- A negative return code is raised as an 'Error'.
handshake :: TLS ()
handshake = checkRC =<< unsafeWithSession F.gnutls_handshake
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | Request a new handshake with @gnutls_rehandshake@.
--
-- A negative return code is raised as an 'Error'.
rehandshake :: TLS ()
rehandshake = checkRC =<< unsafeWithSession F.gnutls_rehandshake
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | Send a lazy 'BL.ByteString' over the session with
-- @gnutls_record_send@, one strict chunk at a time.
--
-- Partial sends are continued with the unsent remainder of the chunk.
-- The first negative return code stops all further sending and is raised
-- as an 'Error'.
putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks
  where
    putChunks chunks = do
      maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
      case maybeErr of
        Nothing -> return ()
        Just err -> throwE $ mapError $ fromIntegral err

    -- Send one chunk. Once an error has been recorded, later chunks are
    -- skipped and the error is carried through the fold.
    putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry (sendAll s)
    putChunk _ err _ = return err

    -- Call gnutls_record_send until the buffer is fully sent or it
    -- reports an error. The sign must be checked before the result is used
    -- for pointer arithmetic: a negative error code would otherwise move
    -- the pointer backwards and enlarge the length to send.
    sendAll s ptr len = do
      sent <- F.gnutls_record_send s ptr (fromIntegral len)
      let sent' = fromIntegral sent
      if sent' < 0
        then return (Just sent')
        else if sent' < len
          then sendAll s (F.plusPtr ptr sent') (len - sent')
          else return Nothing
|
|
|
|
|
|
|
|
-- | Receive up to @count@ bytes with @gnutls_record_recv@.
--
-- A non-negative return value @n@ yields exactly the first @n@ bytes of
-- the receive buffer (empty when @n@ is 0). A negative return code is
-- raised as an 'Error'.
getBytes :: Integer -> TLS BL.ByteString
getBytes count = do
  (received, rc) <- unsafeWithSession $ \s ->
    F.allocaBytes (fromInteger count) $ \buf -> do
      n <- F.gnutls_record_recv s buf (fromInteger count)
      if n < 0
        then return (Nothing, n)
        else do
          chunk <- B.packCStringLen (buf, fromIntegral n)
          return (Just (BL.fromChunks [chunk]), n)
  maybe (throwE (mapError (fromIntegral rc))) return received
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | The pending-data count reported by @gnutls_record_check_pending@ for
-- the current session.
checkPending :: TLS Integer
checkPending = toInteger <$> unsafeWithSession F.gnutls_record_check_pending
|
|
|
|
|
|
|
|
-- | How a session moves raw bytes to and from the peer.
data Transport = Transport
  { transportPush :: BL.ByteString -> IO ()
    -- ^ Deliver the given bytes. The GnuTLS push callback reports every
    -- byte as sent, so this should write all of them.
  , transportPull :: Integer -> IO BL.ByteString
    -- ^ Fetch bytes. The argument is the size of the buffer GnuTLS
    -- supplied, i.e. the most bytes that can be accepted in one call.
  }
|
|
|
|
|
|
|
|
-- | GnuTLS pull callback. It fills @buf@, which has room for @bufSize@
-- bytes, with data from 'transportPull' and returns the number of bytes
-- written.
--
-- The transport is asked for at most @bufSize@ bytes. If it returns more
-- anyway, only the first @bufSize@ bytes are copied and the rest is
-- discarded, instead of being written past the end of the buffer.
pullImpl :: Transport -> F.TransportFunc
pullImpl t _ buf bufSize = do
  bytes <- transportPull t $ toInteger bufSize
  -- Bound the copy by the buffer's capacity; the transport's result is
  -- not trusted to respect the requested size.
  let limited = BL.take (fromIntegral bufSize) bytes
      copyChunk ptr chunk =
        B.unsafeUseAsCStringLen chunk $ \(cstr, len) -> do
          F.copyArray (F.castPtr ptr) cstr len
          return $ F.plusPtr ptr len
  foldM_ copyChunk buf $ BL.toChunks limited
  return $ fromIntegral $ BL.length limited
|
|
|
|
|
|
|
|
-- | GnuTLS push callback. It passes the @bufSize@ bytes at @buf@ to
-- 'transportPush' and reports the whole buffer as sent.
--
-- The bytes are copied into a fresh 'B.ByteString' before they reach the
-- transport. The buffer belongs to GnuTLS, and nothing here guarantees it
-- outlives this call. A transport that keeps the chunk, or consumes it
-- lazily, must not see that memory later change or be freed.
pushImpl :: Transport -> F.TransportFunc
pushImpl t _ buf bufSize = do
  bytes <- B.packCStringLen (F.castPtr buf, fromIntegral bufSize)
  transportPush t $ BL.fromChunks [bytes]
  return bufSize
|
|
|
|
|
|
|
|
-- | A 'Transport' that writes to and reads from an 'IO.Handle'.
--
-- Pushes use 'BL.hPut'. Pulls use 'BL.hGet' with the requested size.
-- NOTE(review): 'BL.hGet' waits for the full count or end-of-file; confirm
-- this suits the read sizes GnuTLS requests.
handleTransport :: IO.Handle -> Transport
handleTransport h = Transport
  { transportPush = BL.hPut h
  , transportPull = \n -> BL.hGet h (fromInteger n)
  }
|
|
|
|
|
|
|
|
-- | A GnuTLS credentials object paired with its credentials type, as
-- consumed by 'setCredentials'.
data Credentials
  = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
|
|
|
|
|
|
|
|
-- | Attach credentials to the current session with @gnutls_credentials_set@.
--
-- GnuTLS does not copy credentials into the session. On success the
-- 'F.ForeignPtr' is therefore added to 'sessionCredentials', which keeps
-- it alive for as long as the session. A negative return code is raised
-- as an 'Error' and nothing is recorded.
setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do
  rc <- unsafeWithSession $ \s ->
    F.withForeignPtr fp $ \ptr -> F.gnutls_credentials_set s ctype ptr
  -- Use the same success test as every other call ('checkRC': any
  -- non-negative code is success). Whenever GnuTLS accepted the
  -- credentials, they must be retained.
  checkRC rc
  s <- getSession
  UIO.unsafeFromIO $
    atomicModifyIORef' (sessionCredentials s) (\creds -> (fp : creds, ()))
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | Allocate a new certificate credentials object with
-- @gnutls_certificate_allocate_credentials@.
--
-- A negative return code is raised as an 'Error'. On success the pointer
-- is wrapped in a 'F.ForeignPtr' whose finaliser is
-- @gnutls_certificate_free_credentials_funptr@. The result is tagged with
-- @F.CredentialsType 1@ (presumably GnuTLS's certificate credentials
-- type; NOTE(review): confirm).
certificateCredentials :: TLS Credentials
certificateCredentials = do
  (rc, raw) <- UIO.unsafeFromIO $ F.alloca $ \out -> do
    code <- F.gnutls_certificate_allocate_credentials out
    if F.unRC code < 0
      then return (code, F.nullPtr)
      else do
        p <- F.peek out
        return (code, p)
  checkRC rc
  fp <- UIO.unsafeFromIO $
    F.newForeignPtr F.gnutls_certificate_free_credentials_funptr raw
  return (Credentials (F.CredentialsType 1) fp)
|
|
|
|
|
2021-02-13 20:59:00 -05:00
|
|
|
-- | Run an IO action on the raw 'F.Session' of the current 'TLS'
-- computation. The session's 'F.ForeignPtr' is kept alive while the
-- action runs.
--
-- This must only be called with IO actions that do not throw
-- NonPseudoException: the action is lifted with 'UIO.unsafeFromIO'
-- without any exception check.
unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession io = do
  fp <- sessionPtr <$> getSession
  UIO.unsafeFromIO (F.withForeignPtr fp (\raw -> io (F.Session raw)))
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | Raise the mapped 'Error' if a GnuTLS return code is negative.
-- Zero and positive codes count as success.
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode code)
  | code < 0 = throwE (mapError code)
  | otherwise = return ()
|
2010-04-26 12:59:24 -04:00
|
|
|
|
|
|
|
-- | Wrap a raw GnuTLS return code in an 'Error'.
mapError :: F.CInt -> Error
mapError code = Error (toInteger code)
|