哪个Rust RNG应该用于多线程采样?



我试图在Rust中创建一个函数,该函数将从M正态分布N次采样。我有下面的顺序版本,运行得很好。我正在尝试使用Rayon并行化它,但遇到错误

Rc<UnsafeCell<ReseedingRng<rand_chacha::chacha::ChaCha12Core, OsRng>>> cannot be sent between threads safely

我的rand::thread_rng似乎没有实现SendSync的特征。我尝试使用StdRngOsRng,两者都这样做,无济于事,因为然后我得到错误,变量predrng不能被借用为可变的,因为它们是在Fn闭包中捕获的。

这是下面的工作代码。当我将第一个into_iter()更改为into_par_iter()时,它会出错。

use rand_distr::{Normal, Distribution};
use std::time::Instant;
use rayon::prelude::*;
fn rprednorm(n: i32, means: Vec<f64>, sds: Vec<f64>) -> Vec<Vec<f64>> {
let mut rng = rand::thread_rng();
let mut preds = vec![vec![0.0; n as usize]; means.len()];
(0..means.len()).into_iter().for_each(|i| {
(0..n).into_iter().for_each(|j| {
let normal = Normal::new(means[i], sds[i]).unwrap();
preds[i][j as usize] = normal.sample(&mut rng);
})
});
preds
}
fn main() {
let means = vec![0.0; 67000];
let sds = vec![1.0; 67000];
let start = Instant::now();
let preds = rprednorm(100, means, sds);
let duration = start.elapsed();

println!("{:?}", duration);
}

关于如何使这两个迭代器并行有什么建议吗?

谢谢。

我的rand::thread_rng似乎没有实现SendSync的特征。

你为什么要发送一个thread_rng?thread_rng的重点在于它是一个单线程RNG。

然后我得到错误,变量pred和rng不能作为可变的借用,因为它们是在Fn闭包中捕获的。

嗯,是的,你需要克隆StdRng(或复制OsRng)到每个闭包。至于pred,由于类似的原因,这不能工作:一旦你并行循环编译器不知道每个i是不同的,所以就它所关心的i的写访问可能重叠(你可以并行运行两个迭代,试图同时写到同一个地方),这是非法的。

解决方案是使用rayon在目标向量上并行迭代:
fn rprednorm(n: i32, means: Vec<f64>, sds: Vec<f64>) -> Vec<Vec<f64>> {
let mut preds = vec![vec![0.0; n as usize]; means.len()];
preds.par_iter_mut().enumerate().for_each(|(i, e)| {
let mut rng = rand::thread_rng();
(0..n).into_iter().for_each(|j| {
let normal = Normal::new(means[i], sds[i]).unwrap();
e[j as usize] = normal.sample(&mut rng);
})
});
preds
}

对于OsRng,它只是一个标记ZST,因此您可以将其作为值引用:

fn rprednorm(n: i32, means: Vec<f64>, sds: Vec<f64>) -> Vec<Vec<f64>> {
let mut preds = vec![vec![0.0; n as usize]; means.len()];
preds.par_iter_mut().enumerate().for_each(|(i, e)| {
(0..n).into_iter().for_each(|j| {
let normal = Normal::new(means[i], sds[i]).unwrap();
e[j as usize] = normal.sample(&mut rand::rngs::OsRng);
})
});
preds
}

StdRng似乎不太适合这个用例,因为你要么必须创建一个每个顶层迭代获得不同的采样,或者你必须初始化一个基本的rng,然后克隆它一次每个火花,他们都有相同的序列(因为他们将共享一个种子)。

最新更新