Completely redo putMany and friends
This commit is contained in:
parent
786a83332b
commit
e8de684ae2
3 changed files with 120 additions and 94 deletions
|
@ -101,6 +101,7 @@ module Data.ByteArray.Builder
|
|||
, flush
|
||||
) where
|
||||
|
||||
import Control.Exception (SomeException,toException)
|
||||
import Control.Monad.ST (ST,runST)
|
||||
import Control.Monad.IO.Class (MonadIO,liftIO)
|
||||
import Data.ByteArray.Builder.Unsafe (Builder(Builder))
|
||||
|
@ -108,14 +109,14 @@ import Data.ByteArray.Builder.Unsafe (BuilderState(BuilderState),pasteIO)
|
|||
import Data.ByteArray.Builder.Unsafe (Commits(Initial,Mutable,Immutable))
|
||||
import Data.ByteArray.Builder.Unsafe (reverseCommitsOntoChunks)
|
||||
import Data.ByteArray.Builder.Unsafe (stringUtf8,cstring)
|
||||
import Data.ByteArray.Builder.Unsafe (addCommitsLength,copyReverseCommits)
|
||||
import Data.ByteString.Short.Internal (ShortByteString(SBS))
|
||||
import Data.Bytes.Chunks (Chunks(ChunksNil))
|
||||
import Data.Bytes.Types (Bytes(Bytes))
|
||||
import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
|
||||
import Data.Char (ord)
|
||||
import Data.Foldable (foldlM)
|
||||
import Data.Int (Int64,Int32,Int16,Int8)
|
||||
import Data.Primitive (ByteArray(..),MutableByteArray(..),PrimArray(..))
|
||||
import Data.Primitive.Unlifted.Array (MutableUnliftedArray,UnliftedArray)
|
||||
import Data.Text.Short (ShortText)
|
||||
import Data.WideWord (Word128)
|
||||
import Data.Word (Word64,Word32,Word16,Word8)
|
||||
|
@ -127,11 +128,9 @@ import GHC.ST (ST(ST))
|
|||
|
||||
import qualified Arithmetic.Nat as Nat
|
||||
import qualified Arithmetic.Types as Arithmetic
|
||||
import qualified Control.Monad.Primitive as PM
|
||||
import qualified Data.ByteArray.Builder.Bounded as Bounded
|
||||
import qualified Data.ByteArray.Builder.Bounded.Unsafe as UnsafeBounded
|
||||
import qualified Data.Primitive as PM
|
||||
import qualified Data.Primitive.Unlifted.Array as PM
|
||||
import qualified Data.Text.Short as TS
|
||||
import qualified GHC.Exts as Exts
|
||||
|
||||
|
@ -152,27 +151,49 @@ run hint@(I# hint# ) (Builder f) = runST $ do
|
|||
-- the callback escape from the callback (i.e. do not write it to an
|
||||
-- @IORef@). Also, do not @unsafeFreezeByteArray@ any of the mutable
|
||||
-- byte arrays in the callback. The intent is that the callback will
|
||||
-- write the buffers out, preferably using vectored I/O.
|
||||
-- write the buffer out.
|
||||
putMany :: Foldable f
|
||||
=> Int -- ^ Size of shared chunk (use 8176 if uncertain)
|
||||
-> (a -> Builder) -- ^ Value builder
|
||||
-> f a -- ^ Collection of values
|
||||
-> (UnliftedArray (MutableByteArray RealWorld) -> IO b) -- ^ Consume chunks.
|
||||
-> (MutableBytes RealWorld -> IO b) -- ^ Consume chunks.
|
||||
-> IO ()
|
||||
{-# inline putMany #-}
|
||||
putMany hint@(I# hint#) g xs cb = do
|
||||
putMany hint0 g xs cb = do
|
||||
MutableByteArray buf0 <- PM.newByteArray hint
|
||||
BuilderState bufZ offZ _ cmtsZ <- foldlM
|
||||
(\st0 a -> do
|
||||
st1@(BuilderState buf off _ cmts) <- pasteIO (g a) st0
|
||||
case cmts of
|
||||
Initial -> pure st1
|
||||
Initial -> if I# off < threshold
|
||||
then pure st1
|
||||
else do
|
||||
_ <- cb (MutableBytes (MutableByteArray buf) 0 (I# off))
|
||||
pure (BuilderState buf0 0# hint# Initial)
|
||||
_ -> do
|
||||
_ <- cb =<< commitsToArray buf off cmts
|
||||
pure (BuilderState buf0 0# hint# Initial)
|
||||
let total = addCommitsLength (I# off) cmts
|
||||
doff0 = total - I# off
|
||||
large <- PM.newByteArray total
|
||||
stToIO (PM.copyMutableByteArray large doff0 (MutableByteArray buf) 0 (I# off))
|
||||
r <- stToIO (copyReverseCommits large doff0 cmts)
|
||||
case r of
|
||||
0 -> do
|
||||
_ <- cb (MutableBytes large 0 total)
|
||||
pure (BuilderState buf0 0# hint# Initial)
|
||||
_ -> IO (\s0 -> Exts.raiseIO# putManyError s0)
|
||||
) (BuilderState buf0 0# hint# Initial) xs
|
||||
_ <- cb =<< commitsToArray bufZ offZ cmtsZ
|
||||
_ <- case cmtsZ of
|
||||
Initial -> cb (MutableBytes (MutableByteArray bufZ) 0 (I# offZ))
|
||||
_ -> IO (\s0 -> Exts.raiseIO# putManyError s0)
|
||||
pure ()
|
||||
where
|
||||
!hint@(I# hint#) = max hint0 8
|
||||
!threshold = div (hint * 3) 4
|
||||
|
||||
putManyError :: SomeException
|
||||
{-# noinline putManyError #-}
|
||||
putManyError = toException
|
||||
(userError "small-bytearray-builder: putMany implementation error")
|
||||
|
||||
-- | Variant of 'putMany' that prefixes each pushed array of chunks
|
||||
-- with the number of bytes that the chunks in each batch required.
|
||||
|
@ -184,75 +205,54 @@ putManyConsLength :: (Foldable f, MonadIO m)
|
|||
-> Int -- ^ Size of shared chunk (use 8176 if uncertain)
|
||||
-> (a -> Builder) -- ^ Value builder
|
||||
-> f a -- ^ Collection of values
|
||||
-> (UnliftedArray (MutableByteArray RealWorld) -> m b) -- ^ Consume chunks.
|
||||
-> (MutableBytes RealWorld -> m b) -- ^ Consume chunks.
|
||||
-> m ()
|
||||
{-# inline putManyConsLength #-}
|
||||
putManyConsLength n buildSize hint g xs cb = do
|
||||
let !(I# n# ) = Nat.demote n
|
||||
let !(I# actual# ) = max hint (I# n# )
|
||||
let !threshold = div (I# actual# * 3) 4
|
||||
MutableByteArray buf0 <- liftIO (PM.newByteArray (I# actual# ))
|
||||
BuilderState bufZ offZ _ cmtsZ <- foldlM
|
||||
(\st0 a -> do
|
||||
st1@(BuilderState buf off _ cmts) <- liftIO (pasteIO (g a) st0)
|
||||
case cmts of
|
||||
Initial -> pure st1
|
||||
Initial -> if I# off < threshold
|
||||
then pure st1
|
||||
else do
|
||||
let !dist = off -# n#
|
||||
_ <- liftIO $ stToIO $ UnsafeBounded.pasteST
|
||||
(buildSize (fromIntegral (I# dist)))
|
||||
(MutableByteArray buf0) 0
|
||||
_ <- cb (MutableBytes (MutableByteArray buf) 0 (I# off))
|
||||
pure (BuilderState buf0 n# (actual# -# n# ) Initial)
|
||||
_ -> do
|
||||
let !dist = commitDistance1 buf0 n# buf off cmts
|
||||
_ <- liftIO $ stToIO $ UnsafeBounded.pasteST
|
||||
(buildSize (fromIntegral (I# dist)))
|
||||
(MutableByteArray buf0)
|
||||
0
|
||||
_ <- cb =<< liftIO (IO (PM.internal (commitsToArray buf off cmts)))
|
||||
pure (BuilderState buf0 n# (actual# -# n# ) Initial)
|
||||
(MutableByteArray buf0) 0
|
||||
let total = addCommitsLength (I# off) cmts
|
||||
doff0 = total - I# off
|
||||
large <- liftIO (PM.newByteArray total)
|
||||
liftIO (stToIO (PM.copyMutableByteArray large doff0 (MutableByteArray buf) 0 (I# off)))
|
||||
r <- liftIO (stToIO (copyReverseCommits large doff0 cmts))
|
||||
case r of
|
||||
0 -> do
|
||||
_ <- cb (MutableBytes large 0 total)
|
||||
pure (BuilderState buf0 n# (actual# -# n# ) Initial)
|
||||
_ -> liftIO (IO (\s0 -> Exts.raiseIO# putManyError s0))
|
||||
) (BuilderState buf0 n# (actual# -# n# ) Initial) xs
|
||||
let !distZ = commitDistance1 bufZ n# bufZ offZ cmtsZ
|
||||
_ <- liftIO $ stToIO $ UnsafeBounded.pasteST
|
||||
(buildSize (fromIntegral (I# distZ)))
|
||||
(MutableByteArray buf0)
|
||||
0
|
||||
_ <- cb =<< liftIO (IO (PM.internal (commitsToArray bufZ offZ cmtsZ)))
|
||||
_ <- case cmtsZ of
|
||||
Initial -> do
|
||||
let !distZ = offZ -# n#
|
||||
_ <- liftIO $ stToIO $ UnsafeBounded.pasteST
|
||||
(buildSize (fromIntegral (I# distZ)))
|
||||
(MutableByteArray buf0)
|
||||
0
|
||||
cb (MutableBytes (MutableByteArray bufZ) 0 (I# offZ))
|
||||
_ -> liftIO (IO (\s0 -> Exts.raiseIO# putManyError s0))
|
||||
pure ()
|
||||
|
||||
commitsToArray ::
|
||||
MutableByteArray# RealWorld -- final chunk to append to commits
|
||||
-> Int# -- offset
|
||||
-> Commits RealWorld
|
||||
-> IO (UnliftedArray (MutableByteArray RealWorld))
|
||||
commitsToArray buf off cmts = do
|
||||
let ct = countCommits 1 cmts
|
||||
bufs <- PM.unsafeNewUnliftedArray ct
|
||||
-- Only shrink the last chunk. Crucially, this is never the first
|
||||
-- chunk (except on the commitsToArray call at the end of folding
|
||||
-- over the collection). We only perform this shrink in the hopes
|
||||
-- that a future GHC will allow reclaiming bytes from shrunk arrays.
|
||||
shrinkMutableByteArray (MutableByteArray buf) (I# off)
|
||||
PM.writeUnliftedArray bufs (ct - 1) (MutableByteArray buf)
|
||||
writeCommitsToArray (ct - 2) bufs cmts
|
||||
PM.unsafeFreezeUnliftedArray bufs
|
||||
|
||||
-- See the documentation for putMany.
|
||||
writeCommitsToArray ::
|
||||
Int
|
||||
-> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
|
||||
-> Commits RealWorld
|
||||
-> IO ()
|
||||
writeCommitsToArray !ix !arrs x0 = case x0 of
|
||||
Initial -> pure ()
|
||||
Mutable buf _ x1 -> do
|
||||
PM.writeUnliftedArray arrs ix (MutableByteArray buf)
|
||||
writeCommitsToArray (ix - 1) arrs x1
|
||||
Immutable arr off len x1 -> do
|
||||
buf <- PM.newByteArray (I# len)
|
||||
PM.copyByteArray buf 0 (ByteArray arr) (I# off) (I# len)
|
||||
PM.writeUnliftedArray arrs ix buf
|
||||
writeCommitsToArray (ix - 1) arrs x1
|
||||
|
||||
countCommits :: Int -> Commits s -> Int
|
||||
countCommits !n x0 = case x0 of
|
||||
Initial -> n
|
||||
Mutable _ _ x1 -> countCommits (n + 1) x1
|
||||
Immutable _ _ _ x1 -> countCommits (n + 1) x1
|
||||
|
||||
-- | Convert a bounded builder to an unbounded one. If the size
|
||||
-- is a constant, use @Arithmetic.Nat.constant@ as the first argument
|
||||
-- to let GHC conjure up this value for you.
|
||||
|
@ -904,7 +904,3 @@ indexChar8Array (ByteArray b) (I# i) = C# (Exts.indexCharArray# b i)
|
|||
|
||||
c2w :: Char -> Word8
|
||||
c2w = fromIntegral . ord
|
||||
|
||||
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
|
||||
shrinkMutableByteArray (MutableByteArray x) (I# i) =
|
||||
IO (\s -> (# Exts.shrinkMutableByteArray# x i s, ()#))
|
||||
|
|
|
@ -18,6 +18,8 @@ module Data.ByteArray.Builder.Unsafe
|
|||
, fromEffect
|
||||
-- * Finalization
|
||||
, reverseCommitsOntoChunks
|
||||
, copyReverseCommits
|
||||
, addCommitsLength
|
||||
-- * Safe Functions
|
||||
-- | These functions are actually completely safe, but they are defined
|
||||
-- here because they are used by typeclass instances. Import them from
|
||||
|
@ -102,6 +104,13 @@ data Commits s
|
|||
!(Commits s)
|
||||
| Initial
|
||||
|
||||
-- | Add the total number of bytes in the commits to first
|
||||
-- argument.
|
||||
addCommitsLength :: Int -> Commits s -> Int
|
||||
addCommitsLength !acc Initial = acc
|
||||
addCommitsLength !acc (Immutable _ _ x cs) = addCommitsLength (acc + I# x) cs
|
||||
addCommitsLength !acc (Mutable _ x cs) = addCommitsLength (acc + I# x) cs
|
||||
|
||||
-- | Cons the chunks from a list of @Commits@ onto an initial
|
||||
-- @Chunks@ list (this argument is often @ChunksNil@). This reverses
|
||||
-- the order of the chunks, which is desirable since builders assemble
|
||||
|
@ -120,6 +129,37 @@ reverseCommitsOntoChunks !xs (Mutable buf len cs) = case len of
|
|||
arr <- PM.unsafeFreezeByteArray (MutableByteArray buf)
|
||||
reverseCommitsOntoChunks (ChunksCons (Bytes arr 0 (I# len)) xs) cs
|
||||
|
||||
-- | Copy the contents of the chunks into a mutable array, reversing
|
||||
-- the order of the chunks.
|
||||
-- Precondition: The destination must have enough space to house the
|
||||
-- contents. This is not checked.
|
||||
copyReverseCommits ::
|
||||
MutableByteArray s -- ^ Destination
|
||||
-> Int -- ^ Destination range successor
|
||||
-> Commits s -- ^ Source
|
||||
-> ST s Int
|
||||
{-# inline copyReverseCommits #-}
|
||||
copyReverseCommits (MutableByteArray dst) (I# off) cs = ST
|
||||
(\s0 -> case copyReverseCommits# dst off cs s0 of
|
||||
(# s1, nextOff #) -> (# s1, I# nextOff #)
|
||||
)
|
||||
|
||||
copyReverseCommits# ::
|
||||
MutableByteArray# s
|
||||
-> Int#
|
||||
-> Commits s
|
||||
-> State# s
|
||||
-> (# State# s, Int# #)
|
||||
copyReverseCommits# _ off Initial s0 = (# s0, off #)
|
||||
copyReverseCommits# marr prevOff (Mutable arr sz cs) s0 =
|
||||
let !off = prevOff -# sz in
|
||||
case Exts.copyMutableByteArray# arr 0# marr off sz s0 of
|
||||
s1 -> copyReverseCommits# marr off cs s1
|
||||
copyReverseCommits# marr prevOff (Immutable arr soff sz cs) s0 =
|
||||
let !off = prevOff -# sz in
|
||||
case Exts.copyByteArray# arr soff marr off sz s0 of
|
||||
s1 -> copyReverseCommits# marr off cs s1
|
||||
|
||||
-- | Create a builder from a cons-list of 'Char'. These
|
||||
-- are be UTF-8 encoded.
|
||||
stringUtf8 :: String -> Builder
|
||||
|
|
44
test/Main.hs
44
test/Main.hs
|
@ -8,7 +8,7 @@
|
|||
import Control.Applicative (liftA2)
|
||||
import Control.Monad.ST (runST)
|
||||
import Data.ByteArray.Builder
|
||||
import Data.Bytes.Types (Bytes(Bytes))
|
||||
import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
|
||||
import Data.Bytes.Chunks (Chunks(ChunksNil,ChunksCons))
|
||||
import Data.Primitive (PrimArray)
|
||||
import Data.Word
|
||||
|
@ -30,7 +30,6 @@ import qualified Data.ByteString.Lazy.Char8 as LB
|
|||
import qualified Data.Bytes.Chunks as Chunks
|
||||
import qualified Data.List as L
|
||||
import qualified Data.Primitive as PM
|
||||
import qualified Data.Primitive.Unlifted.Array as PM
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.Text.Encoding as TE
|
||||
import qualified GHC.Exts as Exts
|
||||
|
@ -204,15 +203,13 @@ tests = testGroup "Tests"
|
|||
[ THU.testCase "A" $ do
|
||||
ref <- newIORef []
|
||||
let txt = "hello_world_are_you_listening" :: [Char]
|
||||
putMany 7 ascii txt (ontoRef ref)
|
||||
putMany 7 ascii txt (bytesOntoRef ref)
|
||||
res <- readIORef ref
|
||||
id $
|
||||
[ map c2w "hello_w"
|
||||
, map c2w "o"
|
||||
, map c2w "rld_are"
|
||||
, map c2w "_"
|
||||
, map c2w "you_lis"
|
||||
, map c2w "t"
|
||||
[ map c2w "hello_"
|
||||
, map c2w "world_"
|
||||
, map c2w "are_yo"
|
||||
, map c2w "u_list"
|
||||
, map c2w "ening"
|
||||
] @=? map Exts.toList (Exts.toList res)
|
||||
]
|
||||
|
@ -222,33 +219,26 @@ tests = testGroup "Tests"
|
|||
let txt = "hello_world_are_you_listening" :: [Char]
|
||||
putManyConsLength Nat.constant
|
||||
(\n -> Bounded.word16BE (fromIntegral n))
|
||||
13 ascii txt (ontoRef ref)
|
||||
16 ascii txt (bytesOntoRef ref)
|
||||
res <- readIORef ref
|
||||
id $
|
||||
[ 0x00 : 0x0C : map c2w "hello_world"
|
||||
, map c2w "_"
|
||||
, 0x00 : 0x0C : map c2w "are_you_lis"
|
||||
, map c2w "t"
|
||||
, 0x00 : 0x05 : map c2w "ening"
|
||||
[ 0x00 : 0x0A : map c2w "hello_worl"
|
||||
, 0x00 : 0x0A : map c2w "d_are_you_"
|
||||
, 0x00 : 0x09 : map c2w "listening"
|
||||
] @=? map Exts.toList (Exts.toList res)
|
||||
]
|
||||
]
|
||||
|
||||
ontoRef ::
|
||||
bytesOntoRef ::
|
||||
IORef [PM.ByteArray]
|
||||
-> PM.UnliftedArray (PM.MutableByteArray Exts.RealWorld)
|
||||
-> MutableBytes Exts.RealWorld
|
||||
-> IO ()
|
||||
ontoRef !ref xs = do
|
||||
bytesOntoRef !ref (MutableBytes buf off len) = do
|
||||
rs <- readIORef ref
|
||||
ps <- PM.foldlUnliftedArrayM'
|
||||
(\ys buf -> do
|
||||
len <- PM.getSizeofMutableByteArray buf
|
||||
dst <- PM.newByteArray len
|
||||
PM.copyMutableByteArray dst 0 buf 0 len
|
||||
dst' <- PM.unsafeFreezeByteArray dst
|
||||
pure (ys ++ [dst'])
|
||||
) [] xs
|
||||
writeIORef ref (rs ++ ps)
|
||||
dst <- PM.newByteArray len
|
||||
PM.copyMutableByteArray dst 0 buf off len
|
||||
dst' <- PM.unsafeFreezeByteArray dst
|
||||
writeIORef ref (rs ++ [dst'])
|
||||
|
||||
instance Arbitrary Chunks where
|
||||
arbitrary = do
|
||||
|
|
Loading…
Reference in a new issue