首先,我想说的是,这是一个学校的作业,我只是寻求一些指导。
我的任务是编写一个算法,使用快速选择找到序列中第k个最小的元素。这应该很容易,但在运行一些测试时,我遇到了瓶颈。由于某种原因,如果我输入(List(1, 1, 1, 1), 1)
,它就会进入无限循环。
这是我的实现:
val rand = new scala.util.Random()
def find(seq: Seq[Int], k: Int): Int = {
require(0 <= k && k < seq.length)
val a: Array[Int] = seq.toArray[Int] // Can't modify the argument sequence
val pivot = rand.nextInt(a.length)
val (low, high) = a.partition(_ < a(pivot))
if (low.length == k) a(pivot)
else if (low.length < k) find(high, k - low.length)
else find(low, k)
}
由于某种原因(或因为我累了),我不能发现我的错误。如果有人能提示我哪里做错了,我会很高兴。
基本上您依赖于这一行- val (low, high) = a.partition(_ < a(pivot))
将数组拆分为2个数组。第一个包含小于pivot-element的连续元素序列,第二个包含其余元素。
然后你说,如果第一个数组的长度为k
,这意味着你已经看到k
个元素小于你的枢轴元素。这意味着pivot-element实际上是k+1
最小的元素,并且实际上返回的是k+1
最小的元素,而不是k
。这是你的第一个错误。
也……当所有元素都相同时,会出现更大的问题,因为第一个数组总是有0个元素。
不仅如此……如果你在k
最小的元素中有重复的元素,比如- (1, 3, 4, 1, 2)
,你的代码会给你错误的答案。
解决方案在于观察序列(1,1,1,1)中4
最小的元素是4
第1
。这意味着你必须使用<=
而不是<
。
也……因为partition
函数在boolean
条件为false之前不会分割数组,所以不能使用partition来实现这个数组分割。您必须自己编写分割