Completely redo putMany and friends

This commit is contained in:
Andrew Martin 2019-11-25 10:52:00 -05:00
parent 786a83332b
commit e8de684ae2
3 changed files with 120 additions and 94 deletions

View file

@ -101,6 +101,7 @@ module Data.ByteArray.Builder
, flush , flush
) where ) where
import Control.Exception (SomeException,toException)
import Control.Monad.ST (ST,runST) import Control.Monad.ST (ST,runST)
import Control.Monad.IO.Class (MonadIO,liftIO) import Control.Monad.IO.Class (MonadIO,liftIO)
import Data.ByteArray.Builder.Unsafe (Builder(Builder)) import Data.ByteArray.Builder.Unsafe (Builder(Builder))
@ -108,14 +109,14 @@ import Data.ByteArray.Builder.Unsafe (BuilderState(BuilderState),pasteIO)
import Data.ByteArray.Builder.Unsafe (Commits(Initial,Mutable,Immutable)) import Data.ByteArray.Builder.Unsafe (Commits(Initial,Mutable,Immutable))
import Data.ByteArray.Builder.Unsafe (reverseCommitsOntoChunks) import Data.ByteArray.Builder.Unsafe (reverseCommitsOntoChunks)
import Data.ByteArray.Builder.Unsafe (stringUtf8,cstring) import Data.ByteArray.Builder.Unsafe (stringUtf8,cstring)
import Data.ByteArray.Builder.Unsafe (addCommitsLength,copyReverseCommits)
import Data.ByteString.Short.Internal (ShortByteString(SBS)) import Data.ByteString.Short.Internal (ShortByteString(SBS))
import Data.Bytes.Chunks (Chunks(ChunksNil)) import Data.Bytes.Chunks (Chunks(ChunksNil))
import Data.Bytes.Types (Bytes(Bytes)) import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
import Data.Char (ord) import Data.Char (ord)
import Data.Foldable (foldlM) import Data.Foldable (foldlM)
import Data.Int (Int64,Int32,Int16,Int8) import Data.Int (Int64,Int32,Int16,Int8)
import Data.Primitive (ByteArray(..),MutableByteArray(..),PrimArray(..)) import Data.Primitive (ByteArray(..),MutableByteArray(..),PrimArray(..))
import Data.Primitive.Unlifted.Array (MutableUnliftedArray,UnliftedArray)
import Data.Text.Short (ShortText) import Data.Text.Short (ShortText)
import Data.WideWord (Word128) import Data.WideWord (Word128)
import Data.Word (Word64,Word32,Word16,Word8) import Data.Word (Word64,Word32,Word16,Word8)
@ -127,11 +128,9 @@ import GHC.ST (ST(ST))
import qualified Arithmetic.Nat as Nat import qualified Arithmetic.Nat as Nat
import qualified Arithmetic.Types as Arithmetic import qualified Arithmetic.Types as Arithmetic
import qualified Control.Monad.Primitive as PM
import qualified Data.ByteArray.Builder.Bounded as Bounded import qualified Data.ByteArray.Builder.Bounded as Bounded
import qualified Data.ByteArray.Builder.Bounded.Unsafe as UnsafeBounded import qualified Data.ByteArray.Builder.Bounded.Unsafe as UnsafeBounded
import qualified Data.Primitive as PM import qualified Data.Primitive as PM
import qualified Data.Primitive.Unlifted.Array as PM
import qualified Data.Text.Short as TS import qualified Data.Text.Short as TS
import qualified GHC.Exts as Exts import qualified GHC.Exts as Exts
@ -152,27 +151,49 @@ run hint@(I# hint# ) (Builder f) = runST $ do
-- the callback escape from the callback (i.e. do not write it to an -- the callback escape from the callback (i.e. do not write it to an
-- @IORef@). Also, do not @unsafeFreezeByteArray@ any of the mutable -- @IORef@). Also, do not @unsafeFreezeByteArray@ any of the mutable
-- byte arrays in the callback. The intent is that the callback will -- byte arrays in the callback. The intent is that the callback will
-- write the buffers out, preferably using vectored I/O. -- write the buffer out.
{- NOTE(review): this span is a side-by-side diff rendering. Most lines
   carry the pre-change text followed by the post-change text; lines that
   appear only once belong to a single side of the diff. The comments
   added below describe the post-change logic. -}
putMany :: Foldable f putMany :: Foldable f
=> Int -- ^ Size of shared chunk (use 8176 if uncertain) => Int -- ^ Size of shared chunk (use 8176 if uncertain)
-> (a -> Builder) -- ^ Value builder -> (a -> Builder) -- ^ Value builder
-> f a -- ^ Collection of values -> f a -- ^ Collection of values
-> (UnliftedArray (MutableByteArray RealWorld) -> IO b) -- ^ Consume chunks. -> (MutableBytes RealWorld -> IO b) -- ^ Consume chunks.
-> IO () -> IO ()
{-# inline putMany #-} {-# inline putMany #-}
putMany hint@(I# hint#) g xs cb = do putMany hint0 g xs cb = do
-- One shared buffer is allocated up front; every flush below resets the
-- builder state to the start of this same buffer.
MutableByteArray buf0 <- PM.newByteArray hint MutableByteArray buf0 <- PM.newByteArray hint
BuilderState bufZ offZ _ cmtsZ <- foldlM BuilderState bufZ offZ _ cmtsZ <- foldlM
(\st0 a -> do (\st0 a -> do
-- Run the builder for this element against the current state.
st1@(BuilderState buf off _ cmts) <- pasteIO (g a) st0 st1@(BuilderState buf off _ cmts) <- pasteIO (g a) st0
case cmts of case cmts of
-- Post-change: while the commit list is still Initial, the callback is
-- only invoked once the write offset has reached the threshold;
-- below the threshold the state is carried forward unchanged.
Initial -> pure st1 Initial -> if I# off < threshold
_ -> do then pure st1
_ <- cb =<< commitsToArray buf off cmts else do
_ <- cb (MutableBytes (MutableByteArray buf) 0 (I# off))
pure (BuilderState buf0 0# hint# Initial) pure (BuilderState buf0 0# hint# Initial)
-- Post-change: commits are present, so the output is spread over several
-- chunks. Compute the total length, allocate one array of that size,
-- copy the current buffer into its tail and the commits in front of it,
-- then pass the single contiguous range to the callback.
_ -> do
let total = addCommitsLength (I# off) cmts
doff0 = total - I# off
large <- PM.newByteArray total
stToIO (PM.copyMutableByteArray large doff0 (MutableByteArray buf) 0 (I# off))
r <- stToIO (copyReverseCommits large doff0 cmts)
-- copyReverseCommits returns the offset that remains after all commits
-- were copied; anything other than 0 means the computed total and the
-- copied bytes disagree, which is reported as an implementation error.
case r of
0 -> do
_ <- cb (MutableBytes large 0 total)
pure (BuilderState buf0 0# hint# Initial)
_ -> IO (\s0 -> Exts.raiseIO# putManyError s0)
) (BuilderState buf0 0# hint# Initial) xs ) (BuilderState buf0 0# hint# Initial) xs
-- Post-change: after the fold, hand whatever remains in the buffer to the
-- callback. Every step above yields a state whose commits are Initial,
-- so any other value is treated as an implementation error.
-- NOTE(review): if the last element triggered a flush, this final call
-- receives a zero-length range.
_ <- cb =<< commitsToArray bufZ offZ cmtsZ _ <- case cmtsZ of
Initial -> cb (MutableBytes (MutableByteArray bufZ) 0 (I# offZ))
_ -> IO (\s0 -> Exts.raiseIO# putManyError s0)
pure () pure ()
where
-- The requested chunk size is raised to at least 8 bytes.
!hint@(I# hint#) = max hint0 8
-- Flush threshold: three quarters of the adjusted chunk size, rounded down.
!threshold = div (hint * 3) 4
-- | Exception raised by 'putMany' and 'putManyConsLength' when their
-- internal bookkeeping turns out to be inconsistent (for example, when
-- the commits copied into the combined array do not end at offset zero).
-- It is a shared top-level value, kept out of line so that the inlined
-- callers do not each carry a copy of it.
putManyError :: SomeException
putManyError = toException invariantViolation
  where
  invariantViolation :: IOError
  invariantViolation =
    userError "small-bytearray-builder: putMany implementation error"
{-# noinline putManyError #-}
-- | Variant of 'putMany' that prefixes each pushed array of chunks -- | Variant of 'putMany' that prefixes each pushed array of chunks
-- with the number of bytes that the chunks in each batch required. -- with the number of bytes that the chunks in each batch required.
@ -184,75 +205,54 @@ putManyConsLength :: (Foldable f, MonadIO m)
-> Int -- ^ Size of shared chunk (use 8176 if uncertain) -> Int -- ^ Size of shared chunk (use 8176 if uncertain)
-> (a -> Builder) -- ^ Value builder -> (a -> Builder) -- ^ Value builder
-> f a -- ^ Collection of values -> f a -- ^ Collection of values
-> (UnliftedArray (MutableByteArray RealWorld) -> m b) -- ^ Consume chunks. -> (MutableBytes RealWorld -> m b) -- ^ Consume chunks.
-> m () -> m ()
{-# inline putManyConsLength #-} {-# inline putManyConsLength #-}
putManyConsLength n buildSize hint g xs cb = do putManyConsLength n buildSize hint g xs cb = do
let !(I# n# ) = Nat.demote n let !(I# n# ) = Nat.demote n
let !(I# actual# ) = max hint (I# n# ) let !(I# actual# ) = max hint (I# n# )
let !threshold = div (I# actual# * 3) 4
MutableByteArray buf0 <- liftIO (PM.newByteArray (I# actual# )) MutableByteArray buf0 <- liftIO (PM.newByteArray (I# actual# ))
BuilderState bufZ offZ _ cmtsZ <- foldlM BuilderState bufZ offZ _ cmtsZ <- foldlM
(\st0 a -> do (\st0 a -> do
st1@(BuilderState buf off _ cmts) <- liftIO (pasteIO (g a) st0) st1@(BuilderState buf off _ cmts) <- liftIO (pasteIO (g a) st0)
case cmts of case cmts of
Initial -> pure st1 Initial -> if I# off < threshold
then pure st1
else do
let !dist = off -# n#
_ <- liftIO $ stToIO $ UnsafeBounded.pasteST
(buildSize (fromIntegral (I# dist)))
(MutableByteArray buf0) 0
_ <- cb (MutableBytes (MutableByteArray buf) 0 (I# off))
pure (BuilderState buf0 n# (actual# -# n# ) Initial)
_ -> do _ -> do
let !dist = commitDistance1 buf0 n# buf off cmts let !dist = commitDistance1 buf0 n# buf off cmts
_ <- liftIO $ stToIO $ UnsafeBounded.pasteST _ <- liftIO $ stToIO $ UnsafeBounded.pasteST
(buildSize (fromIntegral (I# dist))) (buildSize (fromIntegral (I# dist)))
(MutableByteArray buf0) (MutableByteArray buf0) 0
0 let total = addCommitsLength (I# off) cmts
_ <- cb =<< liftIO (IO (PM.internal (commitsToArray buf off cmts))) doff0 = total - I# off
large <- liftIO (PM.newByteArray total)
liftIO (stToIO (PM.copyMutableByteArray large doff0 (MutableByteArray buf) 0 (I# off)))
r <- liftIO (stToIO (copyReverseCommits large doff0 cmts))
case r of
0 -> do
_ <- cb (MutableBytes large 0 total)
pure (BuilderState buf0 n# (actual# -# n# ) Initial) pure (BuilderState buf0 n# (actual# -# n# ) Initial)
_ -> liftIO (IO (\s0 -> Exts.raiseIO# putManyError s0))
) (BuilderState buf0 n# (actual# -# n# ) Initial) xs ) (BuilderState buf0 n# (actual# -# n# ) Initial) xs
let !distZ = commitDistance1 bufZ n# bufZ offZ cmtsZ _ <- case cmtsZ of
Initial -> do
let !distZ = offZ -# n#
_ <- liftIO $ stToIO $ UnsafeBounded.pasteST _ <- liftIO $ stToIO $ UnsafeBounded.pasteST
(buildSize (fromIntegral (I# distZ))) (buildSize (fromIntegral (I# distZ)))
(MutableByteArray buf0) (MutableByteArray buf0)
0 0
_ <- cb =<< liftIO (IO (PM.internal (commitsToArray bufZ offZ cmtsZ))) cb (MutableBytes (MutableByteArray bufZ) 0 (I# offZ))
_ -> liftIO (IO (\s0 -> Exts.raiseIO# putManyError s0))
pure () pure ()
-- | Collect the current buffer and every chunk recorded in the commits
-- into a single unlifted array of mutable byte arrays. The current
-- buffer is shrunk to the number of bytes written and stored in the
-- last slot; the commits fill the slots before it, walking backwards
-- from the second-to-last slot.
commitsToArray ::
     MutableByteArray# RealWorld -- final chunk to append to commits
  -> Int# -- offset
  -> Commits RealWorld
  -> IO (UnliftedArray (MutableByteArray RealWorld))
commitsToArray lastBuf# used# commits = do
  let chunkCount = countCommits 1 commits
      lastChunk = MutableByteArray lastBuf#
  chunks <- PM.unsafeNewUnliftedArray chunkCount
  -- Only the final chunk is shrunk. Importantly, it is never the first
  -- chunk (apart from the commitsToArray call made once folding over
  -- the collection has finished). The shrink is done purely in the hope
  -- that a future GHC can reclaim the bytes of shrunk arrays.
  shrinkMutableByteArray lastChunk (I# used#)
  PM.writeUnliftedArray chunks (chunkCount - 1) lastChunk
  writeCommitsToArray (chunkCount - 2) chunks commits
  PM.unsafeFreezeUnliftedArray chunks
-- | Store the buffer of each commit into the array, starting at the
-- given index and moving one slot down per commit. Mutable commits are
-- stored as they are; immutable commits are first copied into a freshly
-- allocated mutable buffer of exactly their length.
-- See the documentation for putMany.
writeCommitsToArray ::
     Int
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Commits RealWorld
  -> IO ()
writeCommitsToArray !_ !_ Initial = pure ()
writeCommitsToArray !slot !chunks (Mutable chunk# _ older) = do
  PM.writeUnliftedArray chunks slot (MutableByteArray chunk#)
  writeCommitsToArray (slot - 1) chunks older
writeCommitsToArray !slot !chunks (Immutable src# srcOff# srcLen# older) = do
  let srcLen = I# srcLen#
  copy <- PM.newByteArray srcLen
  PM.copyByteArray copy 0 (ByteArray src#) (I# srcOff#) srcLen
  PM.writeUnliftedArray chunks slot copy
  writeCommitsToArray (slot - 1) chunks older
-- | Add the number of commits in the list to the starting count.
-- The accumulator is kept strict.
countCommits :: Int -> Commits s -> Int
countCommits !count Initial = count
countCommits !count (Mutable _ _ older) = countCommits (count + 1) older
countCommits !count (Immutable _ _ _ older) = countCommits (count + 1) older
-- | Convert a bounded builder to an unbounded one. If the size -- | Convert a bounded builder to an unbounded one. If the size
-- is a constant, use @Arithmetic.Nat.constant@ as the first argument -- is a constant, use @Arithmetic.Nat.constant@ as the first argument
-- to let GHC conjure up this value for you. -- to let GHC conjure up this value for you.
@ -904,7 +904,3 @@ indexChar8Array (ByteArray b) (I# i) = C# (Exts.indexCharArray# b i)
-- | Convert a 'Char' to a 'Word8' by taking its code point with 'ord' and
-- narrowing it with 'fromIntegral', which keeps only the low 8 bits.
c2w :: Char -> Word8 c2w :: Char -> Word8
c2w = fromIntegral . ord c2w = fromIntegral . ord
-- | Shrink a mutable byte array in place to the given size in bytes.
-- This is a thin 'IO' wrapper around the @shrinkMutableByteArray#@
-- primop.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray marr#) (I# newSize#) = IO
  (\s0 -> case Exts.shrinkMutableByteArray# marr# newSize# s0 of
    s1 -> (# s1, () #)
  )

View file

@ -18,6 +18,8 @@ module Data.ByteArray.Builder.Unsafe
, fromEffect , fromEffect
-- * Finalization -- * Finalization
, reverseCommitsOntoChunks , reverseCommitsOntoChunks
, copyReverseCommits
, addCommitsLength
-- * Safe Functions -- * Safe Functions
-- | These functions are actually completely safe, but they are defined -- | These functions are actually completely safe, but they are defined
-- here because they are used by typeclass instances. Import them from -- here because they are used by typeclass instances. Import them from
@ -102,6 +104,13 @@ data Commits s
!(Commits s) !(Commits s)
| Initial | Initial
-- | Sum the byte lengths of all commits and add the result to the
-- first argument, which serves as a strict accumulator.
addCommitsLength :: Int -> Commits s -> Int
addCommitsLength !total commits = case commits of
  Initial -> total
  Immutable _ _ len# older -> addCommitsLength (total + I# len#) older
  Mutable _ len# older -> addCommitsLength (total + I# len#) older
-- | Cons the chunks from a list of @Commits@ onto an initial -- | Cons the chunks from a list of @Commits@ onto an initial
-- @Chunks@ list (this argument is often @ChunksNil@). This reverses -- @Chunks@ list (this argument is often @ChunksNil@). This reverses
-- the order of the chunks, which is desirable since builders assemble -- the order of the chunks, which is desirable since builders assemble
@ -120,6 +129,37 @@ reverseCommitsOntoChunks !xs (Mutable buf len cs) = case len of
arr <- PM.unsafeFreezeByteArray (MutableByteArray buf) arr <- PM.unsafeFreezeByteArray (MutableByteArray buf)
reverseCommitsOntoChunks (ChunksCons (Bytes arr 0 (I# len)) xs) cs reverseCommitsOntoChunks (ChunksCons (Bytes arr 0 (I# len)) xs) cs
-- | Copy the bytes of every commit into a mutable array. The commits
-- are laid out back to front: the first commit in the list ends just
-- before the given offset, the next one ends where that one begins,
-- and so on, which reverses the order of the chunks. The result is the
-- offset at which the last copied commit begins.
-- Precondition: the destination must have room for all of the
-- contents before the given offset. This is not checked.
copyReverseCommits ::
     MutableByteArray s -- ^ Destination
  -> Int -- ^ Destination range successor
  -> Commits s -- ^ Source
  -> ST s Int
{-# inline copyReverseCommits #-}
copyReverseCommits (MutableByteArray dst#) (I# end#) commits = ST step
  where
  step s0 = case copyReverseCommits# dst# end# commits s0 of
    (# s1, start# #) -> (# s1, I# start# #)
-- Unboxed worker for 'copyReverseCommits'. Each commit is copied so
-- that it ends at the offset passed in; the recursion then continues
-- with the offset at which that commit begins. When the list is
-- exhausted the current offset is returned.
copyReverseCommits# ::
     MutableByteArray# s
  -> Int#
  -> Commits s
  -> State# s
  -> (# State# s, Int# #)
copyReverseCommits# dst# end# commits s0 = case commits of
  Initial -> (# s0, end# #)
  Mutable src# len# older ->
    let !start# = end# -# len# in
    case Exts.copyMutableByteArray# src# 0# dst# start# len# s0 of
      s1 -> copyReverseCommits# dst# start# older s1
  Immutable src# srcOff# len# older ->
    let !start# = end# -# len# in
    case Exts.copyByteArray# src# srcOff# dst# start# len# s0 of
      s1 -> copyReverseCommits# dst# start# older s1
-- | Create a builder from a cons-list of 'Char'. These -- | Create a builder from a cons-list of 'Char'. These
-- are UTF-8 encoded. -- are UTF-8 encoded.
stringUtf8 :: String -> Builder stringUtf8 :: String -> Builder

View file

@ -8,7 +8,7 @@
import Control.Applicative (liftA2) import Control.Applicative (liftA2)
import Control.Monad.ST (runST) import Control.Monad.ST (runST)
import Data.ByteArray.Builder import Data.ByteArray.Builder
import Data.Bytes.Types (Bytes(Bytes)) import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
import Data.Bytes.Chunks (Chunks(ChunksNil,ChunksCons)) import Data.Bytes.Chunks (Chunks(ChunksNil,ChunksCons))
import Data.Primitive (PrimArray) import Data.Primitive (PrimArray)
import Data.Word import Data.Word
@ -30,7 +30,6 @@ import qualified Data.ByteString.Lazy.Char8 as LB
import qualified Data.Bytes.Chunks as Chunks import qualified Data.Bytes.Chunks as Chunks
import qualified Data.List as L import qualified Data.List as L
import qualified Data.Primitive as PM import qualified Data.Primitive as PM
import qualified Data.Primitive.Unlifted.Array as PM
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
import qualified GHC.Exts as Exts import qualified GHC.Exts as Exts
@ -204,15 +203,13 @@ tests = testGroup "Tests"
[ THU.testCase "A" $ do [ THU.testCase "A" $ do
ref <- newIORef [] ref <- newIORef []
let txt = "hello_world_are_you_listening" :: [Char] let txt = "hello_world_are_you_listening" :: [Char]
putMany 7 ascii txt (ontoRef ref) putMany 7 ascii txt (bytesOntoRef ref)
res <- readIORef ref res <- readIORef ref
id $ id $
[ map c2w "hello_w" [ map c2w "hello_"
, map c2w "o" , map c2w "world_"
, map c2w "rld_are" , map c2w "are_yo"
, map c2w "_" , map c2w "u_list"
, map c2w "you_lis"
, map c2w "t"
, map c2w "ening" , map c2w "ening"
] @=? map Exts.toList (Exts.toList res) ] @=? map Exts.toList (Exts.toList res)
] ]
@ -222,33 +219,26 @@ tests = testGroup "Tests"
let txt = "hello_world_are_you_listening" :: [Char] let txt = "hello_world_are_you_listening" :: [Char]
putManyConsLength Nat.constant putManyConsLength Nat.constant
(\n -> Bounded.word16BE (fromIntegral n)) (\n -> Bounded.word16BE (fromIntegral n))
13 ascii txt (ontoRef ref) 16 ascii txt (bytesOntoRef ref)
res <- readIORef ref res <- readIORef ref
id $ id $
[ 0x00 : 0x0C : map c2w "hello_world" [ 0x00 : 0x0A : map c2w "hello_worl"
, map c2w "_" , 0x00 : 0x0A : map c2w "d_are_you_"
, 0x00 : 0x0C : map c2w "are_you_lis" , 0x00 : 0x09 : map c2w "listening"
, map c2w "t"
, 0x00 : 0x05 : map c2w "ening"
] @=? map Exts.toList (Exts.toList res) ] @=? map Exts.toList (Exts.toList res)
] ]
] ]
{- NOTE(review): side-by-side diff rendering of two test helpers. The
   left-hand text is the removed @ontoRef@, which folds over an unlifted
   array of buffers; the right-hand text is the added @bytesOntoRef@.
   Lines that appear only once belong to the removed helper. -}
-- Post-change behaviour: copy the delivered byte range (buffer, offset,
-- length) into a freshly allocated array, freeze that copy, and append
-- it to the list stored in the 'IORef'. The copy is what gets frozen;
-- the buffer passed to the callback is only read.
ontoRef :: bytesOntoRef ::
IORef [PM.ByteArray] IORef [PM.ByteArray]
-> PM.UnliftedArray (PM.MutableByteArray Exts.RealWorld) -> MutableBytes Exts.RealWorld
-> IO () -> IO ()
ontoRef !ref xs = do bytesOntoRef !ref (MutableBytes buf off len) = do
rs <- readIORef ref rs <- readIORef ref
ps <- PM.foldlUnliftedArrayM'
(\ys buf -> do
len <- PM.getSizeofMutableByteArray buf
dst <- PM.newByteArray len dst <- PM.newByteArray len
PM.copyMutableByteArray dst 0 buf 0 len PM.copyMutableByteArray dst 0 buf off len
dst' <- PM.unsafeFreezeByteArray dst dst' <- PM.unsafeFreezeByteArray dst
pure (ys ++ [dst']) writeIORef ref (rs ++ [dst'])
) [] xs
writeIORef ref (rs ++ ps)
instance Arbitrary Chunks where instance Arbitrary Chunks where
arbitrary = do arbitrary = do