使用半值高效地索引numpy数组



我想用整数值和半数值对numpy数组进行索引。这大致就是我的想法:

>>> a = HalfIndexedArray([[0,1,2,3],[10,11,12,13],[20,21,22,23]])
>>> print(a[0,0])
0
>>> print(a[0.5,0.5])
11
>>> print(a[0,1+0.5])
3

只有一半的值会被用作索引,所以我想可以构建某种包装器,将值存储在整数索引中,这些整数索引是通过乘以给定的分数索引来访问的。

可以构造某种帮助函数,从而相应地获取和设置值;然而,如果仍然可以使用本地numpy索引功能(切片等(,那就更好了。此外,大量的开销实际上是不可接受的,因为这些数组将用于数值计算。

这基本上是我正在寻找的一些有效的语法糖:这是可以在numpy(或者更广泛地说,在Python(中实现的吗?

编辑:为了更清楚:我的目标是同时使用整数和半指数。重点是让代码更干净,因为在数值求解器中,半索引将对应于时间步进过程中的半步(其中也存在整数步(。例如,这里是一个简单的例子,在数学符号中经常使用分数指数。

编辑#2:根据@matszwecja的建议,我尝试重新实现numpy.ndarray__(get/set)item__函数,如下所示:

import numpy as np
class half_indexed_ndarray(np.ndarray):
def __getitem__(self, key):
print('Getting by key {0}'.format(key))
if isinstance(key, tuple):
tuple_double = tuple(int(i*2) for i in key)
return super(half_indexed_ndarray, self).__getitem__(tuple_double)
if isinstance(key, int) or isinstance(key, float):
return super(half_indexed_ndarray, self).__getitem__(int(key * 2))

def __setitem__(self, key, value):
print('Setting by key {0}'.format(key))
if isinstance(key, tuple):
tuple_double = tuple(int(i*2) for i in key)
return super(half_indexed_ndarray, self).__setitem__(tuple_double, value)
if isinstance(key, int) or isinstance(key, float):
return super(half_indexed_ndarray, self).__setitem__(int(key * 2), value)

对于简单的索引,这确实有效:

a = half_indexed_ndarray((3,3))
a[0,0]=1
a[0,0.5]=5
a[0.5,0.5]=505
assert a[0,0]==1
assert a[0,0.5]==5
assert a[0.5,0.5]==505

然而,索引范围还不起作用,numpy的行为有点令人费解。例如:

>>> print(a[-3,])
Getting by key 0
Getting by key (-3,)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-239-23821ec76b06> in <module>
----> 1 print(a[0])
~/anaconda3/lib/python3.8/site-packages/numpy/core/arrayprint.py in _array_str_implementation(a, max_line_width, precision, suppress_small, array2string)
1504         return _guarded_repr_or_str(np.ndarray.__getitem__(a, ()))
1505 
-> 1506     return array2string(a, max_line_width, precision, suppress_small, ' ', "")
1507 
1508 
<several more calls omitted by me>
<ipython-input-233-eb1d93dc766b> in __getitem__(self, key)
6         if isinstance(key, tuple):
7             tuple_double = tuple(int(i*2) for i in key)
----> 8             return super(half_indexed_ndarray, self).__getitem__(tuple_double)
9         if isinstance(key, int) or isinstance(key, float):
10             return super(half_indexed_ndarray, self).__getitem__(int(key * 2))
IndexError: index -6 is out of bounds for axis 0 with size 3

我的解释是,numpy出于某种原因将元组索引(0,)转换为(-3,),然后再次调用__getitem__。但是,在这两个调用中,索引都要乘以2,而索引只应该乘以一次。不确定如何规避。

为了更改[]的行为,您需要重新实现__getitem__类方法。由于您的类在其他情况下会表现为标准列表,因此您可以这样做:

class HalfIndexedList(list):
def __getitem__(self, key):
return super().__getitem__(int(key * 2))
a = HalfIndexedList([10,11,12,13,14,15])
for i in range(0, 6):
print(f"{i/2 = }, {a[i/2] = }") 

(当然,这只会影响使用[]运算符获取项目,a.index返回的值等内容将不受影响

然而,我同意@Otto的回答和这样一种说法,即为了更干净的代码,处理输入端要好得多。半索引是没有意义的,而且真的很不直观。

顺便说一句,Python中的2D数组索引通常使用a[i][j]而不是a[i, j]来完成,因为2D数组实际上是列表的列表。

我的建议是,与其创建一个单独的类来处理半整数索引,不如在输入端处理它。如果你采用一个半整数索引系统,并将你的输入乘以2,你可以简单地将其转换为一个普通的整数索引。

这可能会导致一段更干净、更易于维护的代码。

但是,如果您想继续创建一个自定义的可迭代项,这可能会有所帮助:https://thispointer.com/python-how-to-make-a-class-iterable-create-iterator-class-for-it/