从2D numpy数组中提取2D等大小补丁的索引



我有一个numpy数组,如下所示。

a = np.random.rand(5,6)
a
Out[52]: 
array([[0.08649968, 0.24360955, 0.27972609, 0.21566217, 0.00194021,
0.69750779],
[0.09327379, 0.7579194 , 0.34634515, 0.78285156, 0.50981823,
0.17256468],
[0.7386456 , 0.78608358, 0.80615647, 0.72471626, 0.14825363,
0.62044455],
[0.32171325, 0.10889609, 0.56453828, 0.41675939, 0.09400235,
0.32373844],
[0.52850344, 0.0783796 , 0.74144658, 0.2363739 , 0.24535204,
0.9930051 ]])

然后,我使用以下函数从原始阵列a中获得非重叠补丁。我使用了之前提问的代码。

def select_random_windows(arr, number_of_windows, window_size):
# Get sliding windows
w = view_as_windows(arr,window_size, 3)
# Store shape info
m,n =  w.shape[:2]
# Get random row, col indices for indexing into windows array
lidx = np.random.choice(m*n,number_of_windows,replace=False)
r,c = np.unravel_index(lidx,(m,n))
# If duplicate windows are allowed, use replace=True or np.random.randint
# Finally index into windows and return output
return w[r,c]

然后我调用select_random_windows来获得以下两个不重叠的补丁,每个补丁的大小都是3x3

select_random_windows(a, number_of_windows=2, window_size=(3,3))
Out[54]: 
array([[[0.08649968, 0.24360955, 0.27972609],
[0.09327379, 0.7579194 , 0.34634515],
[0.7386456 , 0.78608358, 0.80615647]],
[[0.21566217, 0.00194021, 0.69750779],
[0.78285156, 0.50981823, 0.17256468],
[0.72471626, 0.14825363, 0.62044455]]])

现在,我如何获得两个补丁中每个补丁相对于主数组a的索引。例如,第一个补丁应该具有(1x1)的索引,而第二个补丁应该有(1x4)的索引。有没有任何方法可以提取这些补丁相对于原始数组a的中心索引。

您只需使用np.where(X == [value1, value2])即可获得值的索引

最新更新