Scala:使用foldRight实现flatMap



我很难理解函数式编程练习的解决方案:

仅使用foldRightNil::实现flatMap(cons)。

解决方案如下:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = 
xs.foldRight(List[B]())((outCurr, outAcc) =>
f(outCurr).foldRight(outAcc)((inCurr, inAcc) => inCurr :: inAcc))

我曾尝试将匿名函数纳入函数定义中,以重写解决方案,但没有成功。我无法理解正在发生的事情,也无法想办法将其分解,这样就不那么复杂了。因此,如有任何关于解决方案的帮助或解释,我们将不胜感激。

谢谢!

首先,忽略约束,在这种情况下考虑flatMap函数。您有一个List[A]和一个函数f: A => List[B]。通常,如果你只在列表上做一个map并应用f函数,你会得到一个List[List[B]],对吧?那么,要获得List[B],你会怎么做呢?您可以在List[List[B]]上添加foldRight,只需将List[List[B]]中的所有元素追加即可返回List[B]。所以代码看起来有点像这样:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
val tmp = xs.map(f) // List[List[B]]
tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
}

为了验证我们目前所拥有的,在REPL中运行代码,并根据内置的flatMap方法验证结果:

scala> def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
|     val tmp = xs.map(f) // List[List[B]]
|     tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
| }
flatMap: [A, B](xs: List[A])(f: A => List[B])List[B]
scala> flatMap(List(1, 2, 3))(i => List(i, 2*i, 3*i))
res0: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)
scala> List(1,2,3).flatMap(i => List(i, 2*i, 3*i))
res1: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)

好的,现在,看看我们的约束,我们不允许在这里使用map。但我们并不真的需要,因为这里的map只是用于迭代列表xs。然后,我们可以将foldRight用于相同的目的。因此,让我们使用foldRight:重写map部分

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
val tmp = xs.foldRight(List[List[B]]())((curr, acc) => f(curr) :: acc) // List[List[B]]
tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
}

好的,让我们验证一下新代码:

scala> def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
|         val tmp = xs.foldRight(List[List[B]]())((curr, acc) => f(curr) :: acc) // List[List[B]]
|         tmp.foldRight(List[B]())((outCurr, outAcc) => outCurr ++ outAcc)
|     }
flatMap: [A, B](xs: List[A])(f: A => List[B])List[B]
scala> flatMap(List(1, 2, 3))(i => List(i, 2*i, 3*i))
res3: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)

好的,到目前为止还不错。因此,让我们对代码进行一点优化,这就是我们将把它们组合成一个foldRight,而不是按顺序有两个foldRight。这不应该太难:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
xs.foldRight(List[B]()) { (curr, acc) => // Note: acc is List[B]
val tmp2 = f(curr) // List[B]
tmp2 ++ acc
}
}

再次验证:

scala> def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
|     xs.foldRight(List[B]()) { (curr, acc) => // Note: acc is List[B]
|         val tmp2 = f(curr) // List[B]
|         tmp2 ++ acc
|     }
| }
flatMap: [A, B](xs: List[A])(f: A => List[B])List[B]
scala> flatMap(List(1, 2, 3))(i => List(i, 2*i, 3*i))
res4: List[Int] = List(1, 2, 3, 2, 4, 6, 3, 6, 9)

好的,让我们看看我们的约束,看起来我们不能使用++运算。好吧,++只是将两个List[B]附加在一起的一种方法,所以我们当然可以使用foldRight方法来实现同样的事情,比如:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = {
xs.foldRight(List[B]()) { (curr, acc) => // Note: acc is List[B]
val tmp2 = f(curr) // List[B]
tmp2.foldRight(acc)((inCurr, inAcc) => inCurr :: inAcc)
}
}

然后,我们可以通过以下方式将它们合并为一行:

def flatMap[A, B](xs: List[A])(f: A => List[B]): List[B] = 
xs.foldRight(List[B]())((curr, acc) =>
f(curr).foldRight(acc)((inCurr, inAcc) => inCurr :: inAcc))

这不是给出的答案吗:)

有时用一个简单的例子更容易理解:假设我们有一个val xs = List[Int](1,2,3)和一个函数f: Int => List[Int], f(x) = List(x,x) (lambda x => List(x,x))将f应用于xsList(f(1),f(2), f(3))的每个元素将产生List(List(1,1),List(2,2),List(3,3)),因此我们需要将此List[List[Int]]展平。最终结果应该是List(1,1,2,2,3,3)。如果Cons(:)是非空列表的构造函数,那么它应该是Cons(1, Cons(1, Cons(2, Cons(2, Cons(3, Cons(3, Nil))))))。观察结果中的foldRight操作,该操作将构造函数Cons(::)应用于应用于列表xs的每个元素的f的结果。因此flatMap的第一个实现将是

def flatMap[A,B](xs: List[A])(f: A => List[B]): List[B] = xs match {
case Cons(head, tail) => foldRight(f(head), Nil)((a,b) => Cons(a,b))
}

在这种形式中,flatMap(List(1,2,3))将返回List(1,1)或Cons(1,1,Nil)(通过替换)。因此,我们需要继续在尾部递归调用flatMap(将问题从3(元素)减少1到2(元素)),并将空列表的基本情况添加为Nil(Nil是Cons操作的"零"元素)

def flatMap[A,B](xs: List[A])(f: A => List[B]): List[B] = xs match {
case Nil => Nil
case Cons(head, tail) => foldRight(f(head), flatMap(tail)(f))((a,b) => Cons(a,b))
}

这是最终实现。

最新更新