After setting credentials, save a reference to the gnutls credentials

struct to keep them alive for the duration of the session.

Fixes a potential crash when opening connections, reported by Joey Hess.
This commit is contained in:
John Millikin 2012-10-26 19:01:52 -07:00
parent 04064950bd
commit 777d600326
No known key found for this signature in database
GPG key ID: 59A38F85F9C7C59E

View file

@ -51,6 +51,7 @@ import Control.Monad.Trans (MonadIO, liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
import Data.IORef
import qualified Foreign as F
import qualified Foreign.C as F
import Foreign.Concurrent as FC
@ -81,6 +82,15 @@ globalInit = do
data Session = Session
{ sessionPtr :: F.ForeignPtr F.Session
, sessionGlobalState :: GlobalState
-- TLS credentials are not copied into the gnutls session struct,
-- so pointers to them must be kept alive until the credentials
-- are no longer needed.
--
-- TODO: Have some way to mark credentials as no longer needed.
-- The current code just keeps them alive for the duration
-- of the session, which may be excessive.
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
newtype TLS a = TLS { unTLS :: ErrorT Error (R.ReaderT Session IO) a }
@ -123,11 +133,12 @@ newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
F.gnutls_transport_set_push_function session push
F.gnutls_transport_set_pull_function session pull
_ <- F.gnutls_set_default_priority session
creds <- newIORef []
fp <- FC.newForeignPtr ptr $ do
F.gnutls_deinit session
F.freeHaskellFunPtr push
F.freeHaskellFunPtr pull
return $ Session fp global
return (Session fp global creds)
getSession :: TLS Session
getSession = TLS R.ask
@ -211,7 +222,11 @@ setCredentials (Credentials ctype fp) = do
rc <- withSession $ \s ->
F.withForeignPtr fp $ \ptr -> do
F.gnutls_credentials_set s ctype ptr
checkRC rc
s <- getSession
if F.unRC rc == 0
then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
else checkRC rc
certificateCredentials :: TLS Credentials
certificateCredentials = do