diff --git a/src/Data/ByteArray/Builder.hs b/src/Data/ByteArray/Builder.hs index cc19f98..b42000b 100644 --- a/src/Data/ByteArray/Builder.hs +++ b/src/Data/ByteArray/Builder.hs @@ -101,6 +101,7 @@ module Data.ByteArray.Builder , flush ) where +import Control.Exception (SomeException,toException) import Control.Monad.ST (ST,runST) import Control.Monad.IO.Class (MonadIO,liftIO) import Data.ByteArray.Builder.Unsafe (Builder(Builder)) @@ -108,14 +109,14 @@ import Data.ByteArray.Builder.Unsafe (BuilderState(BuilderState),pasteIO) import Data.ByteArray.Builder.Unsafe (Commits(Initial,Mutable,Immutable)) import Data.ByteArray.Builder.Unsafe (reverseCommitsOntoChunks) import Data.ByteArray.Builder.Unsafe (stringUtf8,cstring) +import Data.ByteArray.Builder.Unsafe (addCommitsLength,copyReverseCommits) import Data.ByteString.Short.Internal (ShortByteString(SBS)) import Data.Bytes.Chunks (Chunks(ChunksNil)) -import Data.Bytes.Types (Bytes(Bytes)) +import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes)) import Data.Char (ord) import Data.Foldable (foldlM) import Data.Int (Int64,Int32,Int16,Int8) import Data.Primitive (ByteArray(..),MutableByteArray(..),PrimArray(..)) -import Data.Primitive.Unlifted.Array (MutableUnliftedArray,UnliftedArray) import Data.Text.Short (ShortText) import Data.WideWord (Word128) import Data.Word (Word64,Word32,Word16,Word8) @@ -127,11 +128,9 @@ import GHC.ST (ST(ST)) import qualified Arithmetic.Nat as Nat import qualified Arithmetic.Types as Arithmetic -import qualified Control.Monad.Primitive as PM import qualified Data.ByteArray.Builder.Bounded as Bounded import qualified Data.ByteArray.Builder.Bounded.Unsafe as UnsafeBounded import qualified Data.Primitive as PM -import qualified Data.Primitive.Unlifted.Array as PM import qualified Data.Text.Short as TS import qualified GHC.Exts as Exts @@ -152,27 +151,49 @@ run hint@(I# hint# ) (Builder f) = runST $ do -- the callback escape from the callback (i.e. do not write it to an -- @IORef@). Also, do not @unsafeFreezeByteArray@ any of the mutable -- byte arrays in the callback. The intent is that the callback will --- write the buffers out, preferably using vectored I/O. +-- write the buffer out. putMany :: Foldable f => Int -- ^ Size of shared chunk (use 8176 if uncertain) -> (a -> Builder) -- ^ Value builder -> f a -- ^ Collection of values - -> (UnliftedArray (MutableByteArray RealWorld) -> IO b) -- ^ Consume chunks. + -> (MutableBytes RealWorld -> IO b) -- ^ Consume chunks. -> IO () {-# inline putMany #-} -putMany hint@(I# hint#) g xs cb = do +putMany hint0 g xs cb = do MutableByteArray buf0 <- PM.newByteArray hint BuilderState bufZ offZ _ cmtsZ <- foldlM (\st0 a -> do st1@(BuilderState buf off _ cmts) <- pasteIO (g a) st0 case cmts of - Initial -> pure st1 + Initial -> if I# off < threshold + then pure st1 + else do + _ <- cb (MutableBytes (MutableByteArray buf) 0 (I# off)) + pure (BuilderState buf0 0# hint# Initial) _ -> do - _ <- cb =<< commitsToArray buf off cmts - pure (BuilderState buf0 0# hint# Initial) + let total = addCommitsLength (I# off) cmts + doff0 = total - I# off + large <- PM.newByteArray total + stToIO (PM.copyMutableByteArray large doff0 (MutableByteArray buf) 0 (I# off)) + r <- stToIO (copyReverseCommits large doff0 cmts) + case r of + 0 -> do + _ <- cb (MutableBytes large 0 total) + pure (BuilderState buf0 0# hint# Initial) + _ -> IO (\s0 -> Exts.raiseIO# putManyError s0) ) (BuilderState buf0 0# hint# Initial) xs - _ <- cb =<< commitsToArray bufZ offZ cmtsZ + _ <- case cmtsZ of + Initial -> cb (MutableBytes (MutableByteArray bufZ) 0 (I# offZ)) + _ -> IO (\s0 -> Exts.raiseIO# putManyError s0) pure () + where + !hint@(I# hint#) = max hint0 8 + !threshold = div (hint * 3) 4 + +putManyError :: SomeException +{-# noinline putManyError #-} +putManyError = toException + (userError "small-bytearray-builder: putMany implementation error") -- | Variant of 'putMany' that prefixes each pushed array of chunks -- with the number of bytes that the chunks in each batch required. @@ -184,75 +205,54 @@ putManyConsLength :: (Foldable f, MonadIO m) -> Int -- ^ Size of shared chunk (use 8176 if uncertain) -> (a -> Builder) -- ^ Value builder -> f a -- ^ Collection of values - -> (UnliftedArray (MutableByteArray RealWorld) -> m b) -- ^ Consume chunks. + -> (MutableBytes RealWorld -> m b) -- ^ Consume chunks. -> m () {-# inline putManyConsLength #-} putManyConsLength n buildSize hint g xs cb = do let !(I# n# ) = Nat.demote n let !(I# actual# ) = max hint (I# n# ) + let !threshold = div (I# actual# * 3) 4 MutableByteArray buf0 <- liftIO (PM.newByteArray (I# actual# )) BuilderState bufZ offZ _ cmtsZ <- foldlM (\st0 a -> do st1@(BuilderState buf off _ cmts) <- liftIO (pasteIO (g a) st0) case cmts of - Initial -> pure st1 + Initial -> if I# off < threshold + then pure st1 + else do + let !dist = off -# n# + _ <- liftIO $ stToIO $ UnsafeBounded.pasteST + (buildSize (fromIntegral (I# dist))) + (MutableByteArray buf0) 0 + _ <- cb (MutableBytes (MutableByteArray buf) 0 (I# off)) + pure (BuilderState buf0 n# (actual# -# n# ) Initial) _ -> do let !dist = commitDistance1 buf0 n# buf off cmts _ <- liftIO $ stToIO $ UnsafeBounded.pasteST (buildSize (fromIntegral (I# dist))) - (MutableByteArray buf0) - 0 - _ <- cb =<< liftIO (IO (PM.internal (commitsToArray buf off cmts))) - pure (BuilderState buf0 n# (actual# -# n# ) Initial) + (MutableByteArray buf0) 0 + let total = addCommitsLength (I# off) cmts + doff0 = total - I# off + large <- liftIO (PM.newByteArray total) + liftIO (stToIO (PM.copyMutableByteArray large doff0 (MutableByteArray buf) 0 (I# off))) + r <- liftIO (stToIO (copyReverseCommits large doff0 cmts)) + case r of + 0 -> do + _ <- cb (MutableBytes large 0 total) + pure (BuilderState buf0 n# (actual# -# n# ) Initial) + _ -> liftIO (IO (\s0 -> Exts.raiseIO# putManyError s0)) ) (BuilderState buf0 n# (actual# -# n# ) Initial) xs - let !distZ = commitDistance1 bufZ n# bufZ offZ cmtsZ - _ <- liftIO $ stToIO $ UnsafeBounded.pasteST - (buildSize (fromIntegral (I# distZ))) - (MutableByteArray buf0) - 0 - _ <- cb =<< liftIO (IO (PM.internal (commitsToArray bufZ offZ cmtsZ))) + _ <- case cmtsZ of + Initial -> do + let !distZ = offZ -# n# + _ <- liftIO $ stToIO $ UnsafeBounded.pasteST + (buildSize (fromIntegral (I# distZ))) + (MutableByteArray buf0) + 0 + cb (MutableBytes (MutableByteArray bufZ) 0 (I# offZ)) + _ -> liftIO (IO (\s0 -> Exts.raiseIO# putManyError s0)) pure () -commitsToArray :: - MutableByteArray# RealWorld -- final chunk to append to commits - -> Int# -- offset - -> Commits RealWorld - -> IO (UnliftedArray (MutableByteArray RealWorld)) -commitsToArray buf off cmts = do - let ct = countCommits 1 cmts - bufs <- PM.unsafeNewUnliftedArray ct - -- Only shrink the last chunk. Crucially, this is never the first - -- chunk (except on the commitsToArray call at the end of folding - -- over the collection). We only perform this shrink in the hopes - -- that a future GHC will allow reclaiming bytes from shrunk arrays. - shrinkMutableByteArray (MutableByteArray buf) (I# off) - PM.writeUnliftedArray bufs (ct - 1) (MutableByteArray buf) - writeCommitsToArray (ct - 2) bufs cmts - PM.unsafeFreezeUnliftedArray bufs - --- See the documentation for putMany. -writeCommitsToArray :: - Int - -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld) - -> Commits RealWorld - -> IO () -writeCommitsToArray !ix !arrs x0 = case x0 of - Initial -> pure () - Mutable buf _ x1 -> do - PM.writeUnliftedArray arrs ix (MutableByteArray buf) - writeCommitsToArray (ix - 1) arrs x1 - Immutable arr off len x1 -> do - buf <- PM.newByteArray (I# len) - PM.copyByteArray buf 0 (ByteArray arr) (I# off) (I# len) - PM.writeUnliftedArray arrs ix buf - writeCommitsToArray (ix - 1) arrs x1 - -countCommits :: Int -> Commits s -> Int -countCommits !n x0 = case x0 of - Initial -> n - Mutable _ _ x1 -> countCommits (n + 1) x1 - Immutable _ _ _ x1 -> countCommits (n + 1) x1 - -- | Convert a bounded builder to an unbounded one. If the size -- is a constant, use @Arithmetic.Nat.constant@ as the first argument -- to let GHC conjure up this value for you. @@ -904,7 +904,3 @@ indexChar8Array (ByteArray b) (I# i) = C# (Exts.indexCharArray# b i) c2w :: Char -> Word8 c2w = fromIntegral . ord - -shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO () -shrinkMutableByteArray (MutableByteArray x) (I# i) = - IO (\s -> (# Exts.shrinkMutableByteArray# x i s, ()#)) diff --git a/src/Data/ByteArray/Builder/Unsafe.hs b/src/Data/ByteArray/Builder/Unsafe.hs index 2cc3716..54a480e 100644 --- a/src/Data/ByteArray/Builder/Unsafe.hs +++ b/src/Data/ByteArray/Builder/Unsafe.hs @@ -18,6 +18,8 @@ module Data.ByteArray.Builder.Unsafe , fromEffect -- * Finalization , reverseCommitsOntoChunks + , copyReverseCommits + , addCommitsLength -- * Safe Functions -- | These functions are actually completely safe, but they are defined -- here because they are used by typeclass instances. Import them from @@ -102,6 +104,13 @@ data Commits s !(Commits s) | Initial +-- | Add the total number of bytes in the commits to first +-- argument. +addCommitsLength :: Int -> Commits s -> Int +addCommitsLength !acc Initial = acc +addCommitsLength !acc (Immutable _ _ x cs) = addCommitsLength (acc + I# x) cs +addCommitsLength !acc (Mutable _ x cs) = addCommitsLength (acc + I# x) cs + -- | Cons the chunks from a list of @Commits@ onto an initial -- @Chunks@ list (this argument is often @ChunksNil@). This reverses -- the order of the chunks, which is desirable since builders assemble @@ -120,6 +129,37 @@ reverseCommitsOntoChunks !xs (Mutable buf len cs) = case len of arr <- PM.unsafeFreezeByteArray (MutableByteArray buf) reverseCommitsOntoChunks (ChunksCons (Bytes arr 0 (I# len)) xs) cs +-- | Copy the contents of the chunks into a mutable array, reversing +-- the order of the chunks. +-- Precondition: The destination must have enough space to house the +-- contents. This is not checked. +copyReverseCommits :: + MutableByteArray s -- ^ Destination + -> Int -- ^ Destination range successor + -> Commits s -- ^ Source + -> ST s Int +{-# inline copyReverseCommits #-} +copyReverseCommits (MutableByteArray dst) (I# off) cs = ST + (\s0 -> case copyReverseCommits# dst off cs s0 of + (# s1, nextOff #) -> (# s1, I# nextOff #) + ) + +copyReverseCommits# :: + MutableByteArray# s + -> Int# + -> Commits s + -> State# s + -> (# State# s, Int# #) +copyReverseCommits# _ off Initial s0 = (# s0, off #) +copyReverseCommits# marr prevOff (Mutable arr sz cs) s0 = + let !off = prevOff -# sz in + case Exts.copyMutableByteArray# arr 0# marr off sz s0 of + s1 -> copyReverseCommits# marr off cs s1 +copyReverseCommits# marr prevOff (Immutable arr soff sz cs) s0 = + let !off = prevOff -# sz in + case Exts.copyByteArray# arr soff marr off sz s0 of + s1 -> copyReverseCommits# marr off cs s1 + -- | Create a builder from a cons-list of 'Char'. These -- are be UTF-8 encoded. stringUtf8 :: String -> Builder diff --git a/test/Main.hs b/test/Main.hs index a533c65..ddd738e 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -8,7 +8,7 @@ import Control.Applicative (liftA2) import Control.Monad.ST (runST) import Data.ByteArray.Builder -import Data.Bytes.Types (Bytes(Bytes)) +import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes)) import Data.Bytes.Chunks (Chunks(ChunksNil,ChunksCons)) import Data.Primitive (PrimArray) import Data.Word @@ -30,7 +30,6 @@ import qualified Data.ByteString.Lazy.Char8 as LB import qualified Data.Bytes.Chunks as Chunks import qualified Data.List as L import qualified Data.Primitive as PM -import qualified Data.Primitive.Unlifted.Array as PM import qualified Data.Text as T import qualified Data.Text.Encoding as TE import qualified GHC.Exts as Exts @@ -204,15 +203,13 @@ tests = testGroup "Tests" [ THU.testCase "A" $ do ref <- newIORef [] let txt = "hello_world_are_you_listening" :: [Char] - putMany 7 ascii txt (ontoRef ref) + putMany 7 ascii txt (bytesOntoRef ref) res <- readIORef ref id $ - [ map c2w "hello_w" - , map c2w "o" - , map c2w "rld_are" - , map c2w "_" - , map c2w "you_lis" - , map c2w "t" + [ map c2w "hello_" + , map c2w "world_" + , map c2w "are_yo" + , map c2w "u_list" , map c2w "ening" ] @=? map Exts.toList (Exts.toList res) ] @@ -222,33 +219,26 @@ tests = testGroup "Tests" let txt = "hello_world_are_you_listening" :: [Char] putManyConsLength Nat.constant (\n -> Bounded.word16BE (fromIntegral n)) - 13 ascii txt (ontoRef ref) + 16 ascii txt (bytesOntoRef ref) res <- readIORef ref id $ - [ 0x00 : 0x0C : map c2w "hello_world" - , map c2w "_" - , 0x00 : 0x0C : map c2w "are_you_lis" - , map c2w "t" - , 0x00 : 0x05 : map c2w "ening" + [ 0x00 : 0x0A : map c2w "hello_worl" + , 0x00 : 0x0A : map c2w "d_are_you_" + , 0x00 : 0x09 : map c2w "listening" ] @=? map Exts.toList (Exts.toList res) ] ] -ontoRef :: +bytesOntoRef :: IORef [PM.ByteArray] - -> PM.UnliftedArray (PM.MutableByteArray Exts.RealWorld) + -> MutableBytes Exts.RealWorld -> IO () -ontoRef !ref xs = do +bytesOntoRef !ref (MutableBytes buf off len) = do rs <- readIORef ref - ps <- PM.foldlUnliftedArrayM' - (\ys buf -> do - len <- PM.getSizeofMutableByteArray buf - dst <- PM.newByteArray len - PM.copyMutableByteArray dst 0 buf 0 len - dst' <- PM.unsafeFreezeByteArray dst - pure (ys ++ [dst']) - ) [] xs - writeIORef ref (rs ++ ps) + dst <- PM.newByteArray len + PM.copyMutableByteArray dst 0 buf off len + dst' <- PM.unsafeFreezeByteArray dst + writeIORef ref (rs ++ [dst']) instance Arbitrary Chunks where arbitrary = do