我有一个维度(l,n,m)的大型3D numpy数组,其中的元素分别对应于维度为l,n和m的x,y和z的1D数组。我想通过在x和y的每个组合的z值之间插值来找到给定a值(长度为b)的元素。这将给出维度(l,n,b)的输出3D阵列。我希望完全使用numpy数组,而不是使用for循环。
例如,如果我的3D阵列有尺寸(2,3,4):
x = 1 | z = 1 | 2 | 3 | 4
- - - - - - - - - - - - - -
y = 1 |[[[ 0, 1, 2, 3],
y = 2 | [ 4, 5, 6, 7],
y = 3 | [ 8, 9, 10, 11]],
x = 2 | z = 1 | 2 | 3 | 4
- - - - - - - - - - - - -
y = 1 | [[ 12, 13, 14, 15],
y = 2 | [ 16, 17, 18, 19],
y = 3 | [ 20, 21, 22, 23]]]
我想在每行{(x=1,y=1),(x=1),y=2),(y=1,y=3),(x_2,y=1
[[[ 0.3, 0.8, 1.34, 1.9, 2.45],
[ 4.3, 4.8, 5.34, 5.9, 6.45],
[ 8.3, 8.8, 9.34, 9.9, 10.45]],
[[ 12.3, 12.8, 13.34, 13.9, 14.45],
[ 16.3, 16.8, 17.34, 17.9, 18.45],
[ 20.3, 20.8, 21.34, 21.9, 22.45]]]
目前,我使用for循环来迭代x和y的每个组合,并将3D数组的行输入numpy.iterpolate函数,并将输出保存到另一个数组中;然而,对于大型阵列,这是非常缓慢的。
# array is the 3D array with dimensions (l, n, m)
# x, y and z have length l, n and m respectively
# a is the values at which I wish to interpolate at with length b
# new_array is set up with dimensions (l, n, b)
new_array = N.zeros(len(x)*len(y)*len(a)).reshape(len(x), len(y), len(a))
for i in range(len(x)):
for j in range(len(y)):
new_array[i,j,:] = numpy.interpolate(a, z, array[i,j,:])
任何帮助都将不胜感激。
您不需要for循环来通过scipy.interpolate.griddata
:运行数据
>>> from itertools import product
>>>from scipy.interpolate import griddata
>>> data = np.arange(24).reshape(2, 3, 4)
>>> x = np.arange(1, 3)
>>> y = np.arange(1, 4)
>>> z = np.arange(1, 5)
>>> points = np.array(list(product(x, y, z)))
# This is needed if your x, y and z are not consecutive ints
>>> _, x_idx = np.unique(x, return_inverse=True)
>>> _, y_idx = np.unique(y, return_inverse=True)
>>> _, z_idx = np.unique(z, return_inverse=True)
>>> point_idx = np.array(list(product(x_idx, y_idx, z_idx)))
>>> values = data[point_idx[:, 0], point_idx[:, 1], point_idx[:, 2]]
>>> new_z = np.array( [1.3, 1.8, 2.34, 2.9, 3.45])
>>> new_points = np.array(list(product(x, y, new_z)))
>>> new_values = griddata(points, values, new_points)
>>> new_values.reshape(2, 3, -1)
array([[[ 0.3 , 0.8 , 1.34, 1.9 , 2.45],
[ 4.3 , 4.8 , 5.34, 5.9 , 6.45],
[ 8.3 , 8.8 , 9.34, 9.9 , 10.45]],
[[ 12.3 , 12.8 , 13.34, 13.9 , 14.45],
[ 16.3 , 16.8 , 17.34, 17.9 , 18.45],
[ 20.3 , 20.8 , 21.34, 21.9 , 22.45]]])