haskell-gnutls/lib/Network/Protocol/TLS/GNU.hs
Stephen Paul Weber 955b054ff4 Switch monad transformer stack to a type alias
Since we already allowed injecting any Session via runTLS or throwing any Error
via throwE, this does not reduce safety at all but improves ergonomics
considerably.

The only downside here is that we must say goodbye to our transitional MonadIO
instance.
2021-02-16 13:12:55 -05:00

241 lines
7.4 KiB
Haskell

{-# LANGUAGE TypeFamilies #-}
-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
module Network.Protocol.TLS.GNU
( TLS
, Session
, Error (..)
, throwE
, fromExceptT
, runTLS
, runClient
, getSession
, handshake
, rehandshake
, putBytes
, getBytes
, checkPending
-- * Settings
, Transport (..)
, handleTransport
, Credentials
, setCredentials
, certificateCredentials
) where
import qualified Control.Concurrent.MVar as M
import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
import Data.IORef
import qualified Foreign as F
import qualified Foreign.C as F
import Foreign.Concurrent as FC
import qualified System.IO as IO
import System.IO.Unsafe (unsafePerformIO)
import UnexceptionalIO.Trans (UIO, Unexceptional)
import qualified UnexceptionalIO.Trans as UIO
import qualified Network.Protocol.TLS.GNU.Foreign as F
-- | A failure reported while talking to GnuTLS.
--
-- The payload is the numeric return code, widened to 'Integer' by
-- 'mapError'.
data Error
  = Error Integer
  deriving (Show)
-- | Lock held around the calls made by 'globalInit' and 'globalDeinit'.
--
-- It is created once with 'unsafePerformIO'; the @NOINLINE@ pragma keeps
-- that creation from being duplicated by inlining.
globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO (M.newMVar ())
-- | Call @gnutls_global_init@ while holding 'globalInitMVar'.
--
-- A negative return code is converted with 'mapError' and thrown in the
-- 'E.ExceptT' layer.
globalInit :: E.ExceptT Error IO ()
globalInit = do
  F.ReturnCode code <-
    liftIO (M.withMVar globalInitMVar (const F.gnutls_global_init))
  when (code < 0) (E.throwE (mapError code))
-- | Call @gnutls_global_deinit@ while holding 'globalInitMVar'.
globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar (const F.gnutls_global_deinit)
-- | A TLS session: the native session pointer plus the credentials that
-- have been attached to it.
data Session = Session
  { sessionPtr :: F.ForeignPtr F.Session
    -- ^ Managed pointer to the gnutls session struct.
  , sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
    -- ^ TLS credentials are not copied into the gnutls session struct,
    -- so pointers to them must be kept alive until the credentials
    -- are no longer needed.
    --
    -- TODO: Have some way to mark credentials as no longer needed.
    -- The current code just keeps them alive for the duration
    -- of the session, which may be excessive.
  }
-- | The monad in which TLS operations run: an 'E.ExceptT' carrying
-- 'Error' over a 'R.ReaderT' that supplies the current 'Session', on
-- top of 'UIO'.
type TLS a =
  E.ExceptT Error (R.ReaderT Session UIO) a
-- | Abort the current 'TLS' computation with the given 'Error'.
throwE :: Error -> TLS a
throwE err = fromExceptT (E.throwE err)
-- | Embed an 'E.ExceptT' computation over 'UIO' into 'TLS' by lifting its
-- underlying action through the reader layer. The embedded action does
-- not get access to the 'Session'.
fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT action = E.ExceptT (lift (E.runExceptT action))
-- | Run a 'TLS' computation against an existing 'Session'.
--
-- The result is 'Left' if the computation threw an 'Error' and 'Right'
-- with its value otherwise.
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s tls = UIO.lift (R.runReaderT unwrapped s)
  where
    -- Peel off the error layer first, leaving a reader over 'UIO'.
    unwrapped = E.runExceptT tls
-- | Create a fresh session over the given 'Transport' and run a 'TLS'
-- computation in it.
--
-- If the session cannot be created, that 'Error' is returned and the
-- computation is never started.
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls =
  -- NOTE(review): connection end 2 is presumably the client role; confirm
  -- against the gnutls headers.
  newSession transport (F.ConnectionEnd 2)
    >>= either (return . Left) (\session -> runTLS session tls)
-- | Allocate and configure a new gnutls session that performs its I/O
-- through the given 'Transport'.
--
-- Steps, in order:
--
-- 1. 'globalInit' is run; a failure is returned as 'Left'.
-- 2. @gnutls_init@ is called; on failure the global initialisation from
--    step 1 is undone with 'globalDeinit' before the error is returned,
--    because no session finalizer will exist to do it later.
-- 3. Push and pull callbacks wrapping 'pushImpl' and 'pullImpl' are
--    installed, and the default priority is applied (its return code is
--    ignored).
-- 4. The session pointer is wrapped in a 'F.ForeignPtr' whose finalizer
--    deinitialises the session, calls 'globalDeinit', and frees both
--    callback function pointers.
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
  globalInit
  F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
  when (rc < 0) $ do
    -- Balance the successful 'globalInit' above; otherwise this failure
    -- path would leave the global initialisation without a matching
    -- deinitialisation.
    liftIO globalDeinit
    E.throwE $ mapError rc
  liftIO $ do
    ptr <- F.peek sPtr
    let session = F.Session ptr
    push <- F.wrapTransportFunc (pushImpl transport)
    pull <- F.wrapTransportFunc (pullImpl transport)
    F.gnutls_transport_set_push_function session push
    F.gnutls_transport_set_pull_function session pull
    _ <- F.gnutls_set_default_priority session
    creds <- newIORef []
    fp <- FC.newForeignPtr ptr $ do
      F.gnutls_deinit session
      globalDeinit
      F.freeHaskellFunPtr push
      F.freeHaskellFunPtr pull
    return (Session fp creds)
-- | Fetch the 'Session' the current 'TLS' computation is running in.
getSession :: TLS Session
getSession = E.ExceptT (fmap Right R.ask)
-- | Call @gnutls_handshake@ on the current session and throw an 'Error'
-- if it returns a negative code.
handshake :: TLS ()
handshake = checkRC =<< unsafeWithSession F.gnutls_handshake
-- | Call @gnutls_rehandshake@ on the current session and throw an 'Error'
-- if it returns a negative code.
rehandshake :: TLS ()
rehandshake = checkRC =<< unsafeWithSession F.gnutls_rehandshake
-- | Send a lazy 'BL.ByteString' over the current session, one strict chunk
-- at a time.
--
-- Each chunk is passed to @gnutls_record_send@ repeatedly until every byte
-- of it has been accepted. As soon as a send fails, the remaining chunks
-- are skipped and the failure is thrown as an 'Error' built from the
-- reported code.
putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks
  where
    putChunks chunks = do
      maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
      case maybeErr of
        Nothing -> return ()
        Just err -> throwE $ mapError $ fromIntegral err

    -- Once a failure has been recorded, later chunks pass through untouched.
    putChunk _ err@(Just _) _ = return err
    putChunk s Nothing chunk =
      B.unsafeUseAsCStringLen chunk $ uncurry (sendAll s)

    -- Push @len@ bytes starting at @ptr@, retrying after partial writes.
    -- Returns 'Nothing' on success or 'Just' the offending value on failure.
    sendAll s ptr len = do
      sent <- F.gnutls_record_send s ptr (fromIntegral len)
      let sent' = fromIntegral sent :: Int
      if sent' < 0
        -- A negative result is an error code, not a byte count. It must be
        -- reported here: subtracting it from @len@ would yield a positive
        -- "remaining" count and move the pointer backwards.
        then return (Just sent')
        else case len - sent' of
          0 -> return Nothing
          remaining
            | remaining > 0 -> sendAll s (F.plusPtr ptr sent') remaining
            -- More bytes reported than were offered; surface it as a failure.
            | otherwise -> return (Just remaining)
-- | Receive up to @count@ bytes from the current session.
--
-- A temporary buffer of @count@ bytes is filled by @gnutls_record_recv@.
-- A non-negative result is the number of bytes copied out of the buffer
-- into the returned 'BL.ByteString'; a negative result is thrown as an
-- 'Error'.
getBytes :: Integer -> TLS BL.ByteString
getBytes count = do
  outcome <- unsafeWithSession $ \s ->
    F.allocaBytes (fromInteger count) $ \buf -> do
      len <- F.gnutls_record_recv s buf (fromInteger count)
      if len >= 0
        then do
          -- Copy the data out before the temporary buffer is released.
          chunk <- B.packCStringLen (buf, fromIntegral len)
          return (Right (BL.fromChunks [chunk]))
        else return (Left len)
  case outcome of
    Right bytes -> return bytes
    Left code -> throwE (mapError (fromIntegral code))
-- | Return the result of @gnutls_record_check_pending@ for the current
-- session, converted to 'Integer'.
checkPending :: TLS Integer
checkPending =
  unsafeWithSession (fmap toInteger . F.gnutls_record_check_pending)
-- | The pair of I/O actions a session uses to move raw bytes.
data Transport = Transport
  { transportPush :: BL.ByteString -> IO ()
    -- ^ Write the given bytes to the peer.
  , transportPull :: Integer -> IO BL.ByteString
    -- ^ Read bytes from the peer; the argument is the size of the
    -- destination buffer the result will be copied into.
  }
-- | Pull callback handed to gnutls: asks the 'Transport' for up to
-- @bufSize@ bytes, copies them into @buf@, and returns how many bytes
-- were written.
--
-- The data returned by 'transportPull' is truncated to @bufSize@ before
-- copying, so a transport that returns more than it was asked for cannot
-- write past the end of the destination buffer. Any excess is discarded.
pullImpl :: Transport -> F.TransportFunc
pullImpl t _ buf bufSize = do
  received <- transportPull t $ toInteger bufSize
  let -- Never copy more than the destination can hold.
      bytes = BL.take (fromIntegral bufSize) received
      -- Copy one strict chunk to @ptr@ and return the position after it.
      copyChunk ptr chunk =
        B.unsafeUseAsCStringLen chunk $ \(cstr, len) -> do
          F.copyArray (F.castPtr ptr) cstr len
          return $ F.plusPtr ptr len
  foldM_ copyChunk buf $ BL.toChunks bytes
  return $ fromIntegral $ BL.length bytes
-- | Push callback handed to gnutls: wraps the @bufSize@ bytes at @buf@ in
-- a 'BL.ByteString', passes them to 'transportPush', and reports the
-- whole buffer as written.
--
-- The bytes are wrapped with 'B.unsafePackCStringLen', which does not copy
-- them, so the value given to 'transportPush' aliases @buf@.
-- NOTE(review): a transport that keeps the bytes after returning would
-- presumably see them change or become invalid; confirm with callers.
pushImpl :: Transport -> F.TransportFunc
pushImpl t _ buf bufSize = do
  chunk <- B.unsafePackCStringLen (F.castPtr buf, fromIntegral bufSize)
  transportPush t (BL.fromChunks [chunk])
  return bufSize
-- | Build a 'Transport' that writes to and reads from a 'IO.Handle' using
-- 'BL.hPut' and 'BL.hGet'.
--
-- No flush is performed after writing; NOTE(review): callers presumably
-- need to configure the handle's buffering themselves.
handleTransport :: IO.Handle -> Transport
handleTransport h =
  Transport
    { transportPush = BL.hPut h
    , transportPull = \wanted -> BL.hGet h (fromInteger wanted)
    }
-- | A set of credentials: the gnutls credentials type tag together with a
-- managed pointer to the native credentials structure.
data Credentials
  = Credentials
      F.CredentialsType
      (F.ForeignPtr F.Credentials)
-- | Attach credentials to the current session with
-- @gnutls_credentials_set@.
--
-- On a zero return code the credentials pointer is also prepended to the
-- session's 'sessionCredentials' list, which keeps it alive. Any other
-- code is passed to 'checkRC', which throws if it is negative.
setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do
  rc <- unsafeWithSession $ \s ->
    F.withForeignPtr fp (F.gnutls_credentials_set s ctype)
  s <- getSession
  case F.unRC rc of
    0 ->
      UIO.unsafeFromIO $
        atomicModifyIORef (sessionCredentials s) $ \creds ->
          (fp : creds, ())
    _ -> checkRC rc
-- | Allocate a new, empty set of certificate credentials.
--
-- Throws an 'Error' if @gnutls_certificate_allocate_credentials@ returns
-- a negative code. Otherwise the allocated structure is wrapped in a
-- 'F.ForeignPtr' with @gnutls_certificate_free_credentials@ as finalizer.
certificateCredentials :: TLS Credentials
certificateCredentials = do
  (rc, credPtr) <- UIO.unsafeFromIO $ F.alloca $ \out -> do
    code <- F.gnutls_certificate_allocate_credentials out
    -- Only read the out-parameter when the allocation succeeded.
    allocated <-
      if F.unRC code < 0
        then return F.nullPtr
        else F.peek out
    return (code, allocated)
  checkRC rc
  managed <-
    UIO.unsafeFromIO $
      F.newForeignPtr F.gnutls_certificate_free_credentials_funptr credPtr
  -- NOTE(review): credentials type 1 is presumably the certificate type;
  -- confirm against the gnutls headers.
  return (Credentials (F.CredentialsType 1) managed)
-- | Run an 'IO' action with the native session of the current 'TLS'
-- computation, keeping the session pointer alive for the duration.
--
-- This must only be called with IO actions that do not throw
-- NonPseudoException: the action is injected with 'UIO.unsafeFromIO',
-- which does not check for exceptions.
unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession io =
  getSession >>= \s ->
    UIO.unsafeFromIO $
      F.withForeignPtr (sessionPtr s) (\raw -> io (F.Session raw))
-- | Throw an 'Error' if the return code is negative; otherwise do nothing.
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode code)
  | code < 0 = throwE (mapError code)
  | otherwise = return ()
-- | Wrap a C return code in an 'Error', widening it to 'Integer'.
mapError :: F.CInt -> Error
mapError code = Error (toInteger code)