在 Rust 中获取切片或 Vec 中最大或最小浮点值的索引的惯用方法是什么

  • 本文关键字:索引 是什么 方法 获取 Rust 切片 Vec rust
  • 更新时间 :
  • 英文 :


假设 -- Vec<f32>没有任何NaN值或表现出任何NaN行为。

取以下示例集:

0.28  
0.3102
0.9856
0.3679
0.3697
0.46  
0.4311
0.9781
0.9891
0.5052
0.9173
0.932 
0.8365
0.5822
0.9981
0.9977

获取上述列表中最高值的索引(值可以是负数)的最整洁和最稳定的方法是什么?

我最初的尝试如下:

let _tmp = *nets.iter().max_by(|i, j| i.partial_cmp(j).unwrap()).unwrap();    
let _i = nets.iter().position(|&element| element == _tmp).unwrap();

nets&Vec<f32>.在我看来,这似乎是公然不正确的。

Python等效的工作(考虑到上述假设):

_i = nets.index(max(nets))

这有什么原因不起作用吗?

>= 生锈 1.62.0 (2022-06-30)

use std::cmp::Ordering;
   
fn example(nets: &Vec<f32>) {
    let index_of_max: Option<usize> = nets
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(index, _)| index);
}

use std::cmp::Ordering;
   
fn example(nets: &Vec<f32>) {
    let index_of_max: Option<usize> = nets
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
        .map(|(index, _)| index);
}

很棘手的原因是因为f32没有实现Ord。这是因为NaN值阻止浮点数形成总订单,这违反了Ord的合约。

有第三方板条箱通过定义不允许包含NaN的数字类型包装器来解决此问题。一个例子是有序浮点数。如果你使用这个 crate 首先准备集合以包含NotNan值,那么你可以编写非常接近你最初想法的代码:

use ordered_float::NotNan;
let non_nan_floats: Vec<_> = nets.iter()
    .cloned()
    .map(NotNan::new)       // Attempt to convert each f32 to a NotNan
    .filter_map(Result::ok) // Unwrap the `NotNan`s and filter out the `NaN` values 
    .collect();
let max = non_nan_floats.iter().max().unwrap();
let index = non_nan_floats.iter().position(|element| element == max).unwrap();

将此添加到Cargo.toml

[dependencies]
ordered-float = "1.0.1"

奖励材料:通过利用NotNan具有透明表示的事实,可以使类型转换真正实现零成本(假设您真的确定没有NaN值!):

let non_nan_floats: Vec<NotNan<f32>> = unsafe { mem::transmute(nets) };

我可能会做这样的事情:

fn main() -> Result<(), Box<std::error::Error>> {
    let samples = vec![
        0.28, 0.3102, 0.9856, 0.3679, 0.3697, 0.46, 0.4311, 0.9781, 0.9891, 0.5052, 0.9173, 0.932,
        0.8365, 0.5822, 0.9981, 0.9977,
    ];
    // Use enumerate to get the index
    let mut iter = samples.iter().enumerate();
    // we get the first entry
    let init = iter.next().ok_or("Need at least one input")?;
    // we process the rest
    let result = iter.try_fold(init, |acc, x| {
        // return None if x is NaN
        let cmp = x.1.partial_cmp(acc.1)?;
        // if x is greater the acc
        let max = if let std::cmp::Ordering::Greater = cmp {
            x
        } else {
            acc
        };
        Some(max)
    });
    println!("{:?}", result);
    Ok(())
}

这可以通过在迭代器上添加一个特征来实现,例如函数 try_max_by .

您可以使用以下内容找到最大值:

let mut max_value = my_vec.iter().fold(0.0f32, |max, &val| if val > max{ val } else{ max });

找到max_value后,您可以跟踪其在向量本身中的位置:

let index = my_vec.iter().position(|&r| r == max_value).unwrap();

要获得此结果,您需要对同一向量进行两次迭代。为了提高性能,您可以在fold迭代中返回索引值,最大值作为元组。

操场

我从 Alkan @Akiner那里得到了答案并对其进行了一点调整,这是一个没有任何解开包装的简单单行代码,可以完成这项工作:

let maxi = my_vec.iter().enumerate().fold((0, 0.0), |max, (ind, &val)| if val > max.1 {(ind, val)} else {max});

(PS:Rust 新手和 StackOverflow 中的第一篇文章,如果我做错了,请不要评判我:D)

最新更新