Numba通过影响到位来破坏数据



Numba和NumPy不以相同的方式执行以下foo函数:

from numba import jit
import numpy as np
@jit
def foo(a):
a[:] = a[::-1] # reverse the array
a = np.array([0, 1, 2])
foo(a)
print(a)

使用NumPy(不带@jit)可打印[2, 1, 0],而使用Numba(带@jit)则打印[2, 1, 2]。看起来Numba在适当的位置修改了数组,这导致了数据损坏。通过复制阵列很容易解决问题:

a[:] = a[::-1].copy()

但这是想要的行为吗?Numba和NumPy不应该给出相同的结果吗?

我在Python 3.5.2中使用Numba v0.26.0。

这是一个已知的问题(https://github.com/numba/numba/issues/1960)并固定在numba 0.27。根据NumPy行为,修复程序检测重叠并制作临时副本以避免损坏数据。

您的jit具有与此Python循环相同的就地问题。

In [718]: x=list(range(3))
In [719]: for i in range(3):
...:     x[i] = x[2-i]
In [720]: x
Out[720]: [2, 1, 2]

x[:] = x[::-1]被缓冲,不是因为numpy识别出发生了一些特殊的事情,而是因为它总是在执行赋值时使用某种缓冲。

Python解释器将[]表示法转换为对__setitem____getitem__的调用。因此681和682做了相同的事情:

In [680]: x=np.arange(3)
In [681]: x[:] = x[::-1]
In [682]: x.__setitem__(slice(None), x.__getitem__(slice(None,None,-1)))
In [683]: x
Out[683]: array([0, 1, 2])

这意味着x[::-1]在被复制到x[:]之前会被完全求值到一个临时数组。现在x[::-1]是一个视图,而不是一个副本,所以setitem步骤必须执行某种缓冲复制。

另一种复制方法是使用

np.copyto(x, x[::-1])

检查x.__array_interface__,我发现数据缓冲区地址保持不变。因此,它正在进行复制,而不仅仅是更改数据缓冲区地址。但它是在低级别编译的代码中。

通常,缓冲只是一个实现问题,用户不需要担心。ufunc.at旨在处理缓冲产生问题的情况。这个话题会定期出现;搜索CCD_ 16。

===============

请注意,Python列表的行为方式相同。"get/setitem"的翻译是相同的。

In [699]: x=list(range(3))
In [700]: x[:] = x[::-1]
In [701]: x
Out[701]: [2, 1, 0]

==========================

我不完全确定这是否相关,但既然我测试了这些想法,我会把它们记录下来。https://docs.scipy.org/doc/numpy/reference/arrays.nditer.html建议使用CCD_ 17作为在CCD_ 18中实现迭代任务的垫脚石。

使用nditer的第一个尝试是:

In [769]: x=np.arange(5)
In [770]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']])
In [771]: for i,j in it:
...:     print(i,j)
...:     i[...] = j
...:     
0 4
1 3
2 2
3 3
4 4
In [772]: x
Out[772]: array([4, 3, 2, 3, 4])

这将产生与numba相同类型的重叠结果。

添加副本可以实现完全反转。

it = np.nditer((x,x[::-1].copy()), op_flags=[['readwrite'], ['readonly']])

如果我加上external_loop标志,我也会得到一个干净的反转:

In [781]: x=np.arange(5)
In [782]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']], fl
...: ags = ['external_loop'])
In [783]: for i,j in it:
...:     print(i,j)
...:     i[...] = j
...:     
[0 1 2 3 4] [4 3 2 1 0]
In [784]: x
Out[784]: array([4, 3, 2, 1, 0])

最新更新