Rust字符串比较与Python速度相同.想要使程序并行化



我是个新手。我想写一个函数,稍后可以使用pyo3机箱将其作为模块导入Python。

下面是我想在Rust:中实现的函数的Python实现

def pcompare(a, b):
letters = []
for i, letter in enumerate(a):
if letter != b[i]:
letters.append(f'{letter}{i + 1}{b[i]}')
return letters

我写的第一个Rust实现看起来是这样的:

use pyo3::prelude::*;

#[pyfunction]
fn compare_strings_to_vec(a: &str, b: &str) -> PyResult<Vec<String>> {
if a.len() != b.len() {
panic!(
"Reads are not the same length! 
First string is length {} and second string is length {}.",
a.len(), b.len());
}
let a_vec: Vec<char> = a.chars().collect();
let b_vec: Vec<char> = b.chars().collect();
let mut mismatched_chars = Vec::new();
for (mut index,(i,j)) in a_vec.iter().zip(b_vec.iter()).enumerate() {
if i != j {
index += 1;
let mutation = format!("{i}{index}{j}");
mismatched_chars.push(mutation);
} 
}
Ok(mismatched_chars)
}

#[pymodule]
fn compare_strings(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compare_strings_to_vec, m)?)?;
Ok(())
}

它是我在--release模式下构建的。该模块可以导入到Python中,但性能与Python实现的性能非常相似。

我的第一个问题是:为什么Python和Rust函数的速度相似?

现在我正在Rust中进行并行化实现。仅打印结果变量时,函数工作:

use rayon::prelude::*;
fn main() {

let a: Vec<char> = String::from("aaaa").chars().collect();
let b: Vec<char> = String::from("aaab").chars().collect();
let length = a.len();
let index: Vec<_> = (1..=length).collect();

let mut mismatched_chars: Vec<String> = Vec::new();

(a, index, b).into_par_iter().for_each(|(x, i, y)| {
if x != y {
let mutation = format!("{}{}{}", x, i, y).to_string();
println!("{mutation}");
//mismatched_chars.push(mutation);
}
});

}

然而,当我试图将突变变量推送到mismatched_chars载体时:

use rayon::prelude::*;
fn main() {

let a: Vec<char> = String::from("aaaa").chars().collect();
let b: Vec<char> = String::from("aaab").chars().collect();
let length = a.len();
let index: Vec<_> = (1..=length).collect();

let mut mismatched_chars: Vec<String> = Vec::new();

(a, index, b).into_par_iter().for_each(|(x, i, y)| {
if x != y {
let mutation = format!("{}{}{}", x, i, y).to_string();
//println!("{mutation}");
mismatched_chars.push(mutation);
}
});

}

我得到以下错误:

error[E0596]: cannot borrow `mismatched_chars` as mutable, as it is a captured variable in a `Fn` closure
--> src/main.rs:16:13
|
16 |             mismatched_chars.push(mutation);
|             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot borrow as mutable
For more information about this error, try `rustc --explain E0596`.
error: could not compile `testing_compare_strings` due to previous error

我尝试了很多不同的东西。当我这样做时:

use rayon::prelude::*;
fn main() {

let a: Vec<char> = String::from("aaaa").chars().collect();
let b: Vec<char> = String::from("aaab").chars().collect();
let length = a.len();
let index: Vec<_> = (1..=length).collect();

let mut mismatched_chars: Vec<&str> = Vec::new();

(a, index, b).into_par_iter().for_each(|(x, i, y)| {
if x != y {
let mutation = format!("{}{}{}", x, i, y).to_string();
mismatched_chars.push(&mutation);
}
});

}

错误变为:

error[E0596]: cannot borrow `mismatched_chars` as mutable, as it is a captured variable in a `Fn` closure
--> src/main.rs:16:13
|
16 |             mismatched_chars.push(&mutation);
|             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot borrow as mutable
error[E0597]: `mutation` does not live long enough
--> src/main.rs:16:35
|
10 |     let mut mismatched_chars: Vec<&str> = Vec::new();
|         -------------------- lifetime `'1` appears in the type of `mismatched_chars`
...
16 |             mismatched_chars.push(&mutation);
|             ----------------------^^^^^^^^^-
|             |                     |
|             |                     borrowed value does not live long enough
|             argument requires that `mutation` is borrowed for `'1`
17 |         }
|         - `mutation` dropped here while still borrowed

我怀疑解决办法很简单,但我自己看不出来。

您对正在做的事情有正确的想法,但您希望尝试使用带有filtermap的迭代器链来删除迭代器项或将其转换为不同的值。Rayon还提供了类似于常规迭代器的collect方法,以将项转换为类型T: FromIterator(例如Vec<T>(。

fn compare_strings_to_vec(a: &str, b: &str) -> Vec<String> {
// Same as with the if statement, but just a little shorter to write
// Plus, it will print out the two values it is comparing if it errors.
assert_eq!(a.len(), b.len(), "Reads are not the same length!");

// Zip the character iterators from a and b together
a.chars().zip(b.chars())
// Iterate with the index of each item
.enumerate()
// Rayon function which turns a regular iterator into a parallel one 
.par_bridge()
// Filter out values where the characters are the same
.filter(|(_, (a, b))| a != b)
// Convert the remaining values into an error string
.map(|(index, (a, b))| {
format!("{}{}{}", a, index + 1, b)
})
// Turn the items of this iterator into a Vec (Or any other FromIterator type).
.collect()
}

Rust Playground

优化速度

另一方面,如果你想要速度,我们需要从不同的方向来解决这个问题。您可能已经注意到了,但rayon版本的速度相当慢,因为生成线程和使用并发结构的成本比简单地比较原始线程中的字节要高出几个数量级。在我的基准测试中,我发现即使有更好的工作负载分布,当字符串长度至少为1-2百万字节时,额外的线程也只对我的机器(64GB RAM,16核(有用。考虑到您已经说过它们通常约为30000字节长,我认为使用rayon(或任何其他用于比较这种大小的线程(只会减慢代码的速度。

使用criterion进行基准测试,我最终实现了这个实现。它通常在具有10个不同字节的30000个字符的字符串上每次运行大约获得2.8156µs。为了进行比较,在相同的条件下,原始问题中发布的代码在我的系统上通常会达到61.156µs左右,因此这应该会加速约20倍。它可能会有一些变化,但它始终在基准测试中获得最佳结果。我猜这应该足够快,使这一步骤不再是代码中的瓶颈。

这个实现的重点是批量进行比较。我们可以利用大多数CPU上的128位寄存器来比较16字节批中的输入。一旦发现不等式,它所覆盖的16字节部分将被重新扫描,以确定差异的确切位置。这大大提高了性能。我最初认为usize会更好地工作,但事实似乎并非如此。我还尝试使用portable_simd夜间功能来编写此代码的simd版本,但我无法与此代码的速度相匹配。我怀疑这可能是由于错过了优化,或者我缺乏有效使用simd的经验。

我担心由于u128值没有强制执行块对齐而导致速度下降,但这似乎基本上不是问题。首先,通常很难找到愿意分配给不是系统字大小倍数的地址的分配器。当然,这是由于实用性,而不是任何实际要求。当我手动给它未对齐的切片(u128s未对齐(时,它没有受到显著影响。这就是为什么我不试图强制切片的起始索引与align_of::<u128>()对齐。

fn compare_strings_to_vec(a: &str, b: &str) -> Vec<String> {
let a_bytes = a.as_bytes();
let b_bytes = b.as_bytes();
let remainder = a_bytes.len() % size_of::<u128>();
// Strongly suggest to the compiler we are iterating though u128
a_bytes
.chunks_exact(size_of::<u128>())
.zip(b_bytes.chunks_exact(size_of::<u128>()))
.enumerate()
.filter(|(_, (a, b))| {
let a_block: &[u8; 16] = (*a).try_into().unwrap();
let b_block: &[u8; 16] = (*b).try_into().unwrap();
u128::from_ne_bytes(*a_block) != u128::from_ne_bytes(*b_block)
})
.flat_map(|(word_index, (a, b))| {
fast_path(a, b).map(move |x| word_index * size_of::<u128>() + x)
})
.chain(
fast_path(
&a_bytes[a_bytes.len() - remainder..],
&b_bytes[b_bytes.len() - remainder..],
)
.map(|x| a_bytes.len() - remainder + x),
)
.map(|index| {
format!(
"{}{}{}",
char::from(a_bytes[index]),
index + 1,
char::from(b_bytes[index])
)
})
.collect()
}
/// Very similar to regular route, but with nothing fancy, just get the indices of the overlays
#[inline(always)]
fn fast_path<'a>(a: &'a [u8], b: &'a [u8]) -> impl 'a + Iterator<Item = usize> {
a.iter()
.zip(b.iter())
.enumerate()
.filter_map(|(x, (a, b))| (a != b).then_some(x))
}

在多线程环境中不能直接访问字段mismatched_chars

您可以使用Arc<RwLock>来访问多线程中的字段。

use rayon::prelude::*;
use std::sync::{Arc, RwLock};
fn main() {
let a: Vec<char> = String::from("aaaa").chars().collect();
let b: Vec<char> = String::from("aaab").chars().collect();
let length = a.len();
let index: Vec<_> = (1..=length).collect();
let mismatched_chars: Arc<RwLock<Vec<String>>> = Arc::new(RwLock::new(Vec::new()));
(a, index, b).into_par_iter().for_each(|(x, i, y)| {
if x != y {
let mutation = format!("{}{}{}", x, i, y);
mismatched_chars
.write()
.expect("could not acquire write lock")
.push(mutation);
}
});
for mismatch in mismatched_chars
.read()
.expect("could not acquire read lock")
.iter()
{
eprintln!("{}", mismatch);
}
}

最新更新