我有一个数组,如下所示:
[['A0' 'B0' 'C0']
['A1' 'B1' 'C1']
['A2' 'B2' 'C2']]
我想得到B1
的邻居,即B0 , C1 , B2 , A1
,以及它们的索引。
这就是我想到的:
import numpy as np
arr = np.array([
['A0','B0','C0'],
['A1','B1','C1'],
['A2','B2','C2'],
])
def get_neighbor_indices(x,y):
neighbors = []
try:
top = arr[y - 1, x]
neighbors.append((top, (y - 1, x)))
except IndexError:
pass
try:
bottom = arr[y + 1, x]
neighbors.append((bottom, (y + 1, x)))
except IndexError:
pass
try:
left = arr[y, x - 1]
neighbors.append((left, (y, x - 1)))
except IndexError:
pass
try:
right = arr[y, x + 1]
neighbors.append((right, (y, x + 1)))
except IndexError:
pass
return neighbors
这将返回元组(value, (y, x))
的列表。
有没有更好的方法可以做到这一点而不依赖try/except?
您可以在numpy中直接执行此操作,没有任何例外,因为您知道数组的大小。由给出了x, y
近邻的索引
inds = np.array([[x, y]]) + np.array([[1, 0], [-1, 0], [0, 1], [0, -1]])
你可以很容易地制作一个掩码来指示哪些索引是有效的:
valid = (inds[:, 0] >= 0) & (inds[:, 0] < arr.shape[0]) &
(inds[:, 1] >= 0) & (inds[:, 1] < arr.shape[1])
现在提取您想要的值:
inds = inds[valid, :]
vals = arr[inds[:, 0], inds[:, 1]]
最简单的返回值是inds, vals
,但如果您坚持保留原始格式,则可以将其转换为
[v, tuple(i) for v, i in zip(vals, inds)]
附录
你可以很容易地修改它来处理任意尺寸:
def neighbors(arr, *pos):
pos = np.array(pos).reshape(1, -1)
offset = np.zeros((2 * pos.size, pos.size), dtype=np.int)
offset[np.arange(0, offset.shape[0], 2), np.arange(offset.shape[1])] = 1
offset[np.arange(1, offset.shape[0], 2), np.arange(offset.shape[1])] = -1
inds = pos + offset
valid = np.all(inds >= 0, axis=1) & np.all(inds < arr.shape, axis=1)
inds = inds[valid, :]
vals = arr[tuple(inds.T)]
return vals, inds
给定一个N维数组arr
和pos
的N个元素,您可以通过将每个维度依次设置为1
或-1
来创建偏移。通过将inds
和arr.shape
一起广播,以及在每个N大小的行中调用np.all
,而不是针对每个维度手动执行,掩码valid
的计算大大简化。最后,转换tuple(inds.T)
通过将每列分配给一个单独的维度,将inds
转换为实际的花式索引。转置是必要的,因为数组在行上迭代(dim 0(。
您可以使用这个:
def get_neighbours(inds):
places = [(-1, 0), (1, 0), (0, -1), (0, 1)]
return [(arr[x, y], (y, x)) for x, y in [(inds[0] + p[0], inds[1] + p[1]) for p in places] if x >= 0 and y >= 0]
get_neighbours(1, 1)
# OUTPUT [('B0', (1, 0)), ('B2', (1, 2)), ('A1', (0, 1)), ('C1', (2, 1))]
get_neighbours(0, 0)
# OUTPUT [('A1', (0, 1)), ('B0', (1, 0))]
这个怎么样?
def get_neighbor_indices(x,y):
return ( [(arr[y-1,x], (y-1, x))] if y>0 else [] ) +
( [(arr[y+1,x], (y+1, x))] if y<arr.shape[0]-1 else [] ) +
( [(arr[y,x-1], (y, x-1))] if x>0 else [] ) +
( [(arr[y,x+1], (y, x+1))] if x<arr.shape[1]-1 else [] )