利用数组索引将二维数组函数应用于三维数组



我编写了一个函数,它接受一组随机笛卡尔坐标,并返回留在某个空间域中的子集。说明:

grid = np.ones((5,5))
grid = np.lib.pad(grid, ((10,10), (10,10)), 'constant')
>> np.shape(grid)
(25, 25)
random_pts = np.random.random(size=(100, 2)) * len(grid)
def inside(input):
     idx = np.floor(input).astype(np.int)
     mask = grid[idx[:,0], idx[:,1]] == 1
     return input[mask]
>> inside(random_pts)
array([[ 10.59441506,  11.37998288],
       [ 10.39124766,  13.27615815],
       [ 12.28225713,  10.6970708 ],
       [ 13.78351949,  12.9933591 ]])

但是现在我想要同时生成n个random_pts集合并保持n个满足相同功能条件的对应子集的能力。因此,如果n=3

random_pts = np.random.random(size=(3, 100, 2)) * len(grid)

不诉诸for循环,我如何索引我的变量,使inside(random_pts)返回类似

的东西
array([[[ 17.73323523,   9.81956681],
        [ 10.97074592,   2.19671642],
        [ 21.12081044,  12.80412997]],
       [[ 11.41995519,   2.60974757]],
       [[  9.89827156,   9.74580059],
        [ 17.35840479,   7.76972241]]])

一种方法-

def inside3d(input):
    # Get idx in 3D
    idx3d = np.floor(input).astype(np.int)
    # Create a similar mask as witrh 2D case, but in 3D now
    mask3d = grid[idx3d[:,:,0], idx3d[:,:,1]]==1
    # Count of mask matches for each index in 0th dim    
    counts = np.sum(mask3d,axis=1)
    # Index into input to get masked matches across all elements in 0th dim
    out_cat_array = input.reshape(-1,2)[mask3d.ravel()]
    # Split the rows based on the counts, as the final output
    return np.split(out_cat_array,counts.cumsum()[:-1])

验证结果-

创建3D随机输入:
In [91]: random_pts3d = np.random.random(size=(3, 100, 2)) * len(grid)
与inside3d:

In [92]: inside3d(random_pts3d)
Out[92]: 
[array([[ 10.71196268,  12.9875877 ],
        [ 10.29700184,  10.00506662],
        [ 13.80111411,  14.80514828],
        [ 12.55070282,  14.63155383]]), array([[ 10.42636137,  12.45736944],
        [ 11.26682474,  13.01632751],
        [ 13.23550598,  10.99431284],
        [ 14.86871413,  14.19079225],
        [ 10.61103434,  14.95970597]]), array([[ 13.67395756,  10.17229061],
        [ 10.01518846,  14.95480515],
        [ 12.18167251,  12.62880968],
        [ 11.27861513,  14.45609646],
        [ 10.895685  ,  13.35214678],
        [ 13.42690335,  13.67224414]])]
与内部:

In [93]: inside(random_pts3d[0])
Out[93]: 
array([[ 10.71196268,  12.9875877 ],
       [ 10.29700184,  10.00506662],
       [ 13.80111411,  14.80514828],
       [ 12.55070282,  14.63155383]])
In [94]: inside(random_pts3d[1])
Out[94]: 
array([[ 10.42636137,  12.45736944],
       [ 11.26682474,  13.01632751],
       [ 13.23550598,  10.99431284],
       [ 14.86871413,  14.19079225],
       [ 10.61103434,  14.95970597]])
In [95]: inside(random_pts3d[2])
Out[95]: 
array([[ 13.67395756,  10.17229061],
       [ 10.01518846,  14.95480515],
       [ 12.18167251,  12.62880968],
       [ 11.27861513,  14.45609646],
       [ 10.895685  ,  13.35214678],
       [ 13.42690335,  13.67224414]])

最新更新