Generalize the type of putManyConsLength

This commit is contained in:
Andrew Martin 2019-11-23 13:34:19 -05:00
parent 70a5c15e6c
commit 786a83332b

View file

@ -102,6 +102,7 @@ module Data.ByteArray.Builder
) where ) where
import Control.Monad.ST (ST,runST) import Control.Monad.ST (ST,runST)
import Control.Monad.IO.Class (MonadIO,liftIO)
import Data.ByteArray.Builder.Unsafe (Builder(Builder)) import Data.ByteArray.Builder.Unsafe (Builder(Builder))
import Data.ByteArray.Builder.Unsafe (BuilderState(BuilderState),pasteIO) import Data.ByteArray.Builder.Unsafe (BuilderState(BuilderState),pasteIO)
import Data.ByteArray.Builder.Unsafe (Commits(Initial,Mutable,Immutable)) import Data.ByteArray.Builder.Unsafe (Commits(Initial,Mutable,Immutable))
@ -126,6 +127,7 @@ import GHC.ST (ST(ST))
import qualified Arithmetic.Nat as Nat import qualified Arithmetic.Nat as Nat
import qualified Arithmetic.Types as Arithmetic import qualified Arithmetic.Types as Arithmetic
import qualified Control.Monad.Primitive as PM
import qualified Data.ByteArray.Builder.Bounded as Bounded import qualified Data.ByteArray.Builder.Bounded as Bounded
import qualified Data.ByteArray.Builder.Bounded.Unsafe as UnsafeBounded import qualified Data.ByteArray.Builder.Bounded.Unsafe as UnsafeBounded
import qualified Data.Primitive as PM import qualified Data.Primitive as PM
@ -176,39 +178,39 @@ putMany hint@(I# hint#) g xs cb = do
-- with the number of bytes that the chunks in each batch required. -- with the number of bytes that the chunks in each batch required.
-- (This excludes the bytes required to encode the length itself.) -- (This excludes the bytes required to encode the length itself.)
-- This is useful for chunked HTTP encoding. -- This is useful for chunked HTTP encoding.
putManyConsLength :: Foldable f putManyConsLength :: (Foldable f, MonadIO m)
=> Arithmetic.Nat n -- ^ Number of bytes used by the serialization of the length => Arithmetic.Nat n -- ^ Number of bytes used by the serialization of the length
-> (Int -> Bounded.Builder n) -- ^ Length serialization function -> (Int -> Bounded.Builder n) -- ^ Length serialization function
-> Int -- ^ Size of shared chunk (use 8176 if uncertain) -> Int -- ^ Size of shared chunk (use 8176 if uncertain)
-> (a -> Builder) -- ^ Value builder -> (a -> Builder) -- ^ Value builder
-> f a -- ^ Collection of values -> f a -- ^ Collection of values
-> (UnliftedArray (MutableByteArray RealWorld) -> IO b) -- ^ Consume chunks. -> (UnliftedArray (MutableByteArray RealWorld) -> m b) -- ^ Consume chunks.
-> IO () -> m ()
{-# inline putManyConsLength #-} {-# inline putManyConsLength #-}
putManyConsLength n buildSize hint g xs cb = do putManyConsLength n buildSize hint g xs cb = do
let !(I# n# ) = Nat.demote n let !(I# n# ) = Nat.demote n
let !(I# actual# ) = max hint (I# n# ) let !(I# actual# ) = max hint (I# n# )
MutableByteArray buf0 <- PM.newByteArray (I# actual# ) MutableByteArray buf0 <- liftIO (PM.newByteArray (I# actual# ))
BuilderState bufZ offZ _ cmtsZ <- foldlM BuilderState bufZ offZ _ cmtsZ <- foldlM
(\st0 a -> do (\st0 a -> do
st1@(BuilderState buf off _ cmts) <- pasteIO (g a) st0 st1@(BuilderState buf off _ cmts) <- liftIO (pasteIO (g a) st0)
case cmts of case cmts of
Initial -> pure st1 Initial -> pure st1
_ -> do _ -> do
let !dist = commitDistance1 buf0 n# buf off cmts let !dist = commitDistance1 buf0 n# buf off cmts
_ <- stToIO $ UnsafeBounded.pasteST _ <- liftIO $ stToIO $ UnsafeBounded.pasteST
(buildSize (fromIntegral (I# dist))) (buildSize (fromIntegral (I# dist)))
(MutableByteArray buf0) (MutableByteArray buf0)
0 0
_ <- cb =<< commitsToArray buf off cmts _ <- cb =<< liftIO (IO (PM.internal (commitsToArray buf off cmts)))
pure (BuilderState buf0 n# (actual# -# n# ) Initial) pure (BuilderState buf0 n# (actual# -# n# ) Initial)
) (BuilderState buf0 n# (actual# -# n# ) Initial) xs ) (BuilderState buf0 n# (actual# -# n# ) Initial) xs
let !distZ = commitDistance1 bufZ n# bufZ offZ cmtsZ let !distZ = commitDistance1 bufZ n# bufZ offZ cmtsZ
_ <- stToIO $ UnsafeBounded.pasteST _ <- liftIO $ stToIO $ UnsafeBounded.pasteST
(buildSize (fromIntegral (I# distZ))) (buildSize (fromIntegral (I# distZ)))
(MutableByteArray buf0) (MutableByteArray buf0)
0 0
_ <- cb =<< commitsToArray bufZ offZ cmtsZ _ <- cb =<< liftIO (IO (PM.internal (commitsToArray bufZ offZ cmtsZ)))
pure () pure ()
commitsToArray :: commitsToArray ::