{-# LANGUAGE TypeFamilies #-}

-- Copyright (C) 2010 John Millikin
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.

module Network.Protocol.TLS.GNU
  ( TLS
  , Session
  , Error (..)
  , throwE
  , fromExceptT
  , runTLS
  , runTLS'
  , runClient
  , getSession
  , handshake
  , rehandshake
  , putBytes
  , getBytes
  , checkPending

    -- * Settings
  , Transport (..)
  , handleTransport
  , Credentials
  , setCredentials
  , certificateCredentials
  ) where

import qualified Control.Concurrent.MVar as M
import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
import Data.IORef
import qualified Foreign as F
import qualified Foreign.C as F
import qualified Foreign.Concurrent as FC
import qualified System.IO as IO
import System.IO.Unsafe (unsafePerformIO)
import UnexceptionalIO.Trans (UIO, Unexceptional)
import qualified UnexceptionalIO.Trans as UIO

import qualified Network.Protocol.TLS.GNU.Foreign as F

-- | A GnuTLS failure. Within this module it is only built from a negative
-- return code reported by a GnuTLS call.
data Error = Error Integer
  deriving (Show)

-- | Lock serialising calls to @gnutls_global_init@ / @gnutls_global_deinit@.
--
-- 'unsafePerformIO' is used to create a single process-wide 'M.MVar';
-- @NOINLINE@ prevents GHC from duplicating it at each use site.
globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()

-- | Call @gnutls_global_init@ under 'globalInitMVar', throwing the
-- return code as an 'Error' if it is negative.
globalInit :: E.ExceptT Error IO ()
globalInit = do
  F.ReturnCode rc <- liftIO $ M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
  when (rc < 0) $ E.throwE $ mapError rc

-- | Call @gnutls_global_deinit@ under 'globalInitMVar'. Every successful
-- 'globalInit' should be paired with one call to this.
globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit

-- | A GnuTLS session together with the credentials attached to it.
data Session = Session
  { sessionPtr :: F.ForeignPtr F.Session
    -- TLS credentials are not copied into the gnutls session struct,
    -- so pointers to them must be kept alive until the credentials
    -- are no longer needed.
    --
    -- TODO: Have some way to mark credentials as no longer needed.
    -- The current code just keeps them alive for the duration
    -- of the session, which may be excessive.
  , sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
  }

-- | Computations that run against a 'Session' and may fail with an 'Error'.
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a

-- | Abort the current 'TLS' computation with the given error.
throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE

-- | Embed a session-independent 'E.ExceptT' computation into 'TLS'.
fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = E.mapExceptT lift

-- | Run a 'TLS' computation against a session, returning its outcome.
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s = E.runExceptT . runTLS' s

-- | Like 'runTLS', but keep the failure inside 'E.ExceptT'.
runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a
runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s)

-- | Create a client-side session over the given transport and run a
-- 'TLS' computation in it.
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do
  -- 2 is GNUTLS_CLIENT in gnutls.h.
  eitherSession <- newSession transport (F.ConnectionEnd 2)
  case eitherSession of
    Left err -> return (Left err)
    Right session -> runTLS session tls

-- | Allocate and configure a GnuTLS session for the given connection end.
--
-- The session's push/pull callbacks are wired to @transport@ and the
-- default priority is applied. A Haskell finalizer deinitialises the
-- session, releases the global-init reference and frees the callbacks.
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
  globalInit
  F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
  when (rc < 0) $ do
    -- globalInit above succeeded; balance it before failing, otherwise
    -- this global-init reference is never released.
    liftIO globalDeinit
    E.throwE $ mapError rc
  liftIO $ do
    ptr <- F.peek sPtr
    let session = F.Session ptr
    push <- F.wrapTransportFunc (pushImpl transport)
    pull <- F.wrapTransportFunc (pullImpl transport)
    F.gnutls_transport_set_push_function session push
    F.gnutls_transport_set_pull_function session pull
    -- NOTE(review): the result of gnutls_set_default_priority is
    -- discarded, so a failure here goes unreported.
    _ <- F.gnutls_set_default_priority session
    creds <- newIORef []
    fp <- FC.newForeignPtr ptr $ do
      F.gnutls_deinit session
      globalDeinit
      F.freeHaskellFunPtr push
      F.freeHaskellFunPtr pull
    return (Session fp creds)

-- | The session the current 'TLS' computation is running against.
getSession :: TLS Session
getSession = lift R.ask
-- | Run @gnutls_handshake@ on the current session, throwing its return
-- code as an 'Error' if it is negative.
handshake :: TLS ()
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC

-- | Run @gnutls_rehandshake@ on the current session, throwing its return
-- code as an 'Error' if it is negative.
rehandshake :: TLS ()
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC

-- | Send a lazy 'BL.ByteString' over the session, one strict chunk at a time.
--
-- Each chunk is passed to @gnutls_record_send@ repeatedly until all of its
-- bytes have been accepted. A negative result is a GnuTLS error code: the
-- remaining data is not sent and the code is thrown as an 'Error'.
putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks
  where
    putChunks chunks = do
      maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
      case maybeErr of
        Nothing -> return ()
        Just err -> throwE $ mapError $ fromIntegral err

    -- Once an error has been recorded, later chunks are skipped.
    putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry (sendAll s)
    putChunk _ err _ = return err

    -- Send len bytes starting at ptr. Returns Nothing once every byte was
    -- accepted, or Just the negative GnuTLS error code.
    sendAll s ptr len
      | len <= 0 = return Nothing
      | otherwise = do
          sent <- F.gnutls_record_send s ptr (fromIntegral len)
          let sent' = fromIntegral sent
          -- A negative result is an error code, not a byte count: stop here
          -- rather than moving ptr backwards and re-sending foreign memory.
          if sent' < 0
            then return (Just sent')
            else sendAll s (F.plusPtr ptr sent') (len - sent')

-- | Receive at most @count@ bytes with @gnutls_record_recv@. A negative
-- result from GnuTLS is thrown as an 'Error'.
getBytes :: Integer -> TLS BL.ByteString
getBytes count = do
  (mbytes, len) <- unsafeWithSession $ \s ->
    F.allocaBytes (fromInteger count) $ \ptr -> do
      len <- F.gnutls_record_recv s ptr (fromInteger count)
      bytes <- if len >= 0
        then fmap (Just . BL.fromChunks . (: [])) (B.packCStringLen (ptr, fromIntegral len))
        else return Nothing
      return (bytes, len)
  maybe (throwE $ mapError $ fromIntegral len) return mbytes

-- | The result of @gnutls_record_check_pending@ for the current session.
checkPending :: TLS Integer
checkPending = fmap toInteger (unsafeWithSession F.gnutls_record_check_pending)

-- | How GnuTLS moves raw bytes to and from the peer.
data Transport = Transport
  { transportPush :: BL.ByteString -> IO ()
    -- ^ Deliver outgoing bytes.
  , transportPull :: Integer -> IO BL.ByteString
    -- ^ Read incoming bytes; the argument is the size of GnuTLS's buffer.
  }

-- | Adapt 'transportPull' to a GnuTLS pull callback: read into GnuTLS's
-- buffer and return the number of bytes copied.
pullImpl :: Transport -> F.TransportFunc
pullImpl t _ buf bufSize = do
  received <- transportPull t (toInteger bufSize)
  -- Never copy more than the buffer can hold, even if the transport
  -- returned more than it was asked for.
  let bytes = BL.take (fromIntegral bufSize) received
      copyChunk ptr chunk = B.unsafeUseAsCStringLen chunk $ \(cstr, len) -> do
        F.copyArray (F.castPtr ptr) cstr len
        return (F.plusPtr ptr len)
  foldM_ copyChunk buf (BL.toChunks bytes)
  return (fromIntegral (BL.length bytes))

-- | Adapt 'transportPush' to a GnuTLS push callback: hand GnuTLS's buffer
-- to the transport and report every byte as sent.
pushImpl :: Transport -> F.TransportFunc
pushImpl t _ buf bufSize = do
  -- Copy out of GnuTLS's buffer: the transport may keep the ByteString
  -- after this callback returns, and the buffer is not ours to retain.
  bytes <- B.packCStringLen (F.castPtr buf, fromIntegral bufSize)
  transportPush t (BL.fromChunks [bytes])
  return bufSize

-- | A 'Transport' over an 'IO.Handle'. Each push is flushed so the peer
-- receives it even when the handle is buffered; pulls use 'BL.hGet'.
handleTransport :: IO.Handle -> Transport
handleTransport h = Transport push pull
  where
    push bytes = BL.hPut h bytes >> IO.hFlush h
    pull = BL.hGet h . fromInteger

-- | A GnuTLS credentials object tagged with its credentials type.
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)

-- | Attach credentials to the current session with @gnutls_credentials_set@.
-- On success the credentials are recorded on the 'Session' so they stay
-- alive with it; a non-zero return code is passed to 'checkRC'.
setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do
  rc <- unsafeWithSession $ \s ->
    F.withForeignPtr fp $ F.gnutls_credentials_set s ctype
  session <- getSession
  if F.unRC rc == 0
    then UIO.unsafeFromIO $
      atomicModifyIORef' (sessionCredentials session) (\creds -> (fp : creds, ()))
    else checkRC rc

-- | Allocate certificate credentials, freed by a finalizer that calls
-- @gnutls_certificate_free_credentials@.
certificateCredentials :: TLS Credentials
certificateCredentials = do
  (rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \out -> do
    rc <- F.gnutls_certificate_allocate_credentials out
    -- Only read the output pointer when allocation reported success.
    ptr <- if F.unRC rc < 0 then return F.nullPtr else F.peek out
    return (rc, ptr)
  checkRC rc
  fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
  -- 1 is GNUTLS_CRD_CERTIFICATE in gnutls.h.
  return $ Credentials (F.CredentialsType 1) fp

-- | Run an IO action against the raw session pointer.
--
-- This must only be called with IO actions that do not throw NonPseudoException
unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession io = do
  s <- getSession
  UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session

-- | Throw a negative GnuTLS return code as an 'Error'; otherwise succeed.
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x

-- | Wrap a raw GnuTLS return code.
mapError :: F.CInt -> Error
mapError = Error . toInteger