将手动realWorld#状态传递与任意Monad交错是否安全



考虑这个为任意Monad:生成列表的函数

generateListM :: Monad m => Int -> (Int -> m a) -> m [a]
generateListM sz f = go 0
where go i | i < sz = do x <- f i
xs <- go (i + 1)
return (x:xs)
| otherwise = pure []

实现可能并不完美,但它在这里只是为了演示所需的效果,这非常简单。例如,如果monad是一个列表,那么获取列表列表:

λ> generateListM 3 (i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]

我想做的是实现同样的效果,但对于ByteArray而不是List。事实证明,这比我第一次偶然发现这个问题时想象的要棘手得多。最终目标是使用该生成器在massiv中实现mapM,但这不是重点。

需要最少精力的方法是使用矢量包中的函数generateM,同时进行一些手动转换。但事实证明,有一种方法可以通过手动处理状态令牌并将其与monad:交织的巧妙小技巧来实现至少x2倍的性能增益

{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples       #-}
import           Data.Primitive.ByteArray
import           Data.Primitive.Types
import qualified Data.Vector.Primitive    as VP
import           GHC.Int
import           GHC.Magic
import           GHC.Prim
-- | Can't `return` unlifted types, so we need a wrapper for the state and MutableByteArray
data MutableByteArrayState s = MutableByteArrayState !(State# s) !(MutableByteArray# s)
generatePrimM :: forall m a . (Prim a, Monad m) => Int -> (Int -> m a) -> m (VP.Vector a)
generatePrimM (I# sz#) f =
runRW# $ s0# -> do
let go i# = do
case i# <# sz# of
0# ->
case newByteArray# (sz# *# sizeOf# (undefined :: a)) (noDuplicate# s0#) of
(# s1#, mba# #) -> return (MutableByteArrayState s1# mba#)
_ -> do
res <- f (I# i#)
MutableByteArrayState si# mba# <- go (i# +# 1#)
return (MutableByteArrayState (writeByteArray# mba# i# res si#) mba#)
MutableByteArrayState s# mba# <- go 0#
case unsafeFreezeByteArray# mba# s# of
(# _, ba# #) -> return (VP.Vector 0 (I# sz#) (ByteArray ba#))

我们可以像以前一样使用它,只是现在我们将获得一个由ByteArray支持的原始Vector,这正是我真正需要的:

λ> generatePrimM 3 (i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]

这似乎很有效,在ghc 8.0和8.2版本中表现良好,只是8.4和8.6中有回归,但这个问题是正交的。

最后我谈到了实际的问题。这种方法真的安全吗?有没有一些我不知道的边缘案例可能会在以后咬我?欢迎就上述职能提出任何其他建议或意见。

PS。m不必局限于MonadApplicative也可以很好地工作,但当使用do语法时,示例会更清晰一些。

TLDR从我目前收集的信息来看,以我最初提出的方式生成原始Vector似乎是一种安全的方法。此外,noDuplicate#的使用并不是真正必要的,因为所有的运算都是幂等的,并且运算的顺序不会对结果数组产生影响。

披露:我第一次想到这个问题已经一年多了。直到上个月,我才试图回到它。我之所以这么说,是因为现在查看原始包时,我注意到了一个新的模块Data.Primitive.PrimArray。正如@chi在评论中提到的那样,实际上没有必要为了获得解决方案而下降到低级基元,因为它可能已经存在了。它正好包含generatePrimArrayA函数,这正是我想要的(源代码的一个简化副本(:

newtype STA a = STA {_runSTA :: forall s. MutableByteArray# s -> ST s (PrimArray a)}
runSTA :: forall a. Prim a => Int -> STA a -> PrimArray a
runSTA !sz =
(STA m) -> runST $ newPrimArray sz >>= (ar :: MutablePrimArray s a) -> m (unMutablePrimArray ar)
generatePrimArrayA :: (Applicative f, Prim a) => Int -> (Int -> f a) -> f (PrimArray a)
generatePrimArrayA len f =
let go !i
| i == len = pure $ STA $ mary -> unsafeFreezePrimArray (MutablePrimArray mary)
| otherwise =
liftA2
(b (STA m) -> STA $ mary -> writePrimArray (MutablePrimArray mary) i b >> m mary)
(f i)
(go (i + 1))
in runSTA len <$> go 0

作为一个有趣的练习,如果我们用通常的归约规则进行基本的简化,我们会得到一个与我最初得到的非常相似的东西:

generatePrimArrayA :: forall f a. (Applicative f, Prim a) => Int -> (Int -> f a) -> f (PrimArray a)
generatePrimArrayA !(I# n#) f =
let go i# = case i# <# n# of
0# -> pure $ mary s# ->
case unsafeFreezeByteArray# mary s# of
(# s'#, arr'# #) -> (# s'#, PrimArray arr'# #)
_ -> liftA2
(b m ->
mary s ->
case writeByteArray# mary i# b s of
s'# -> m mary s'#)
(f (I# i#))
(go (i# +# 1#))
in (m -> runRW# $ s0# ->
case newByteArray# (n# *# sizeOf# (undefined :: a)) s0# of
(# s'#, arr# #) -> case m arr# s'# of
(# _, a #) -> a)
<$> go 0#

这是我为Applicative而不是Monad:调整的版本

generatePrimM :: forall m a . (Prim a, Applicative m) => Int -> (Int -> m a) -> m (PrimArray a)
generatePrimM (I# sz#) f =
let go i# = case i# <# sz# of
0# -> runRW# $ s0# ->
case newByteArray# (sz# *# sizeOf# (undefined :: a)) s0# of
(# s1#, mba# #) -> pure (MutableByteArrayState s1# mba#)
_  -> liftA2
(b (MutableByteArrayState si# mba#) ->
MutableByteArrayState (writeByteArray# mba# i# b si#) mba#)
(f (I# i#))
(go (i# +# 1#))
in ((MutableByteArrayState s# mba#) ->
case unsafeFreezeByteArray# mba# s# of
(# _, ba# #) -> PrimArray ba#) <$>
(go 0#)

在功能和性能方面,它们非常接近,最终它们都会产生完全相同的答案。不同之处在于内部环路go最终产生了什么。后者将返回一个包含可以构建MutableByteArray#s的闭包的应用程序,该闭包稍后将被冻结。前者有一个循环,返回一个包含将创建冻结ByteArray#s的操作的应用程序,一旦向其提供了可以创建MutableByteArray#的操作。

然而,使这两种方法都安全的原因是,循环中每个生成的数组的每个元素都只被写入一次,并且创建的每个MutableByteArray#在被生成函数返回之前都会被冻结,但在它完成对它们的写入之前不会被冻结。

最新更新