卡拉苏巴算法太多递归



我正在尝试在c ++中实现Karatsuba乘法算法,但现在我只是想让它在python中工作。

这是我的代码:

def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    x0 = x / bm
    x1 = x % bm
    y0 = y / bm
    y1 = y % bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

我不明白的是:z2z1z0应该如何创建?递归使用 mult 函数是否正确?如果是这样,我在某处搞砸了,因为递归没有停止。

有人可以指出错误在哪里吗?

注意:下面的回答直接解决了OP关于 过度递归,但它不会尝试提供正确的 卡拉苏巴算法。 其他回答的信息量要大得多 这方面。

试试这个版本:

def mult(x, y, b, m):
    bm = pow(b, m)
    if min(x, y) <= bm:
        return x * y
    # NOTE the following 4 lines
    x0 = x % bm
    x1 = x / bm
    y0 = y % bm
    y1 = y / bm
    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval

您的版本最严重的问题是 x0 和 x1 以及 y0 和 y1 的计算被翻转。 此外,如果x1y1为 0,则算法的推导不成立,因为在这种情况下,因式分解步骤将变得无效。 因此,必须通过确保 x 和 y 都大于 b**m 来避免这种可能性。

编辑:修复了代码中的拼写错误;添加了澄清

编辑2:

为了更清楚起见,直接评论您的原始版本:

def mult(x, y, b, m):
    # The termination condition will never be true when the recursive 
    # call is either
    #    mult(z2, bm ** 2, b, m)
    # or mult(z1, bm, b, m)
    #
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    # Even without the recursion problem, the next four lines are wrong
    x0 = x / bm  # RHS should be x % bm
    x1 = x % bm  # RHS should be x / bm
    y0 = y / bm  # RHS should be y % bm
    y1 = y % bm  # RHS should be y / bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

通常大数字存储为整数数组。每个整数代表一个数字。这种方法允许将任何数字乘以基数的幂,只需将数组左移。

这是我基于列表的实现(可能包含错误(:

def normalize(l,b):
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l
def sum_lists(x,y,b):
    l = min(len(x),len(y))
    res = map(operator.add,x[:l],y[:l])
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)
def sub_lists(x,y,b):
    res = map(operator.sub,x[:len(y)],y)
    res.extend(x[len(y):])
    return normalize(res,b)
def lshift(x,n):
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x
def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    x0,x1 = x[:m],x[m:]
    y0,y1 = y[:m],y[m:]
    z0 = mult_lists(x0,y0,b)
    z1 = mult_lists(x1,y1,b)
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)

sum_listssub_lists 返回非规范化结果 - 个位数可以大于基值。 normalize函数解决了这个问题。

所有函数都希望以相反的顺序获取数字列表。例如,以 10 为底的 12 应写为 [2,1]。让我们拿一个9987654321的平方。

» a = [1,2,3,4,5,6,7,8,9]
» res = mult_lists(a,a,10)
» res.reverse()
» res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]

Karatsuba 乘法的目标是通过进行 3 次递归调用而不是 4 次来改进除法和征服乘法算法。因此,脚本中唯一应该包含对乘法的递归调用的行是那些分配z0z1z2的行。 其他任何事情都会给你带来更糟糕的复杂性。当你还没有定义乘法(更何况是幂(时,也不能使用 pow 来计算 bm

为此,该算法关键地使用了它使用位置符号系统的事实。如果你有一个以 b 为底的数字的表示 x,那么 x*bm 只需将该表示的数字向左移动 m 次即可获得。这种移位操作对于任何位置符号系统基本上都是"自由的"。这也意味着,如果你想实现这一点,你必须重现这个位置符号,以及"自由"的移位。要么你选择以 b=2 为基数计算并使用 python 的位运算符(或给定小数、十六进制、...如果你的测试平台有它们,则为基础(,或者你决定出于教育目的实现一些适用于任意 B 的东西,并且你用字符串、数组或列表等东西重现这个位置算术

您已经有一个带有列表的解决方案。我喜欢在python中使用字符串,因为int(s, base)会给你与字符串相对应的整数s在基数base中被视为数字表示:它使测试变得容易。我在这里发布了一个大量注释的基于字符串的实现作为要点,包括字符串到数字和数字到字符串原语。

您可以通过提供带有基数及其(相等(长度的填充字符串作为参数来测试它mult

In [169]: mult("987654321","987654321",10,9)
Out[169]: '966551847789971041'

如果您不想弄清楚填充或计算字符串长度,填充函数可以为您完成:

In [170]: padding("987654321","2")
Out[170]: ('987654321', '000000002', 9)

当然,它适用于b>10

In [171]: mult('987654321', '000000002', 16, 9)
Out[171]: '130eca8642'

(与 wolfram alpha 核对(

我相信

该技术背后的想法是使用递归算法计算zi项,但结果不是以这种方式统一在一起的。 由于您想要的最终结果是

z0 B^2m + z1 B^m + z2

假设你选择一个合适的值 B(比如 2(,你可以计算 B^m 而不做任何乘法。 例如,当使用 B = 2 时,您可以使用位移而不是乘法来计算 B^m。 这意味着最后一步可以在不做任何乘法的情况下完成。

还有一件事 - 我注意到你为整个算法选择了固定值 m。 通常,您将通过使 m 始终是一个值来实现此算法,这样当 x 和 y 以 B 为基数写入时,B^m 是 x 和 y 中位数的一半。 如果您使用的是 2 的幂,这将通过选择 m = ceil((log x(/2( 来完成。

希望这有帮助!

在 Python 2.7 中:将此文件另存为 Karatsuba.py

   def karatsuba(x,y):
        """Karatsuba multiplication algorithm.
        Return the product of two numbers in an efficient manner
        @author Shashank
        date: 23-09-2018
        Parameters
        ----------
        x : int
            First Number 
        y : int
            Second Number   
        Returns
        -------
        prod : int
               The product of two numbers 
        Examples
        --------
        >>> import Karatsuba.karatsuba
        >>> a = 1234567899876543211234567899876543211234567899876543211234567890
        >>> b = 9876543211234567899876543211234567899876543211234567899876543210
        >>> Karatsuba.karatsuba(a,b)
        12193263210333790590595945731931108068998628253528425547401310676055479323014784354458161844612101832860844366209419311263526900
        """
        if len(str(x)) == 1 or len(str(y)) == 1:
            return x*y
        else:
            n = max(len(str(x)), len(str(y)))
            m = n/2
            a = x/10**m
            b = x%10**m
            c = y/10**m
            d = y%10**m
            ac = karatsuba(a,c)                             #step 1
            bd = karatsuba(b,d)                             #step 2
            ad_plus_bc = karatsuba(a+b, c+d) - ac - bd      #step 3
            prod = ac*10**(2*m) + bd + ad_plus_bc*10**m     #step 4
            return prod

最新更新