Scipy 的 N 维插值与 "sparse=True" 产生意外的 ValueError



我正在使用scipy.interpolate.interpn研究 n-D 插值示例。正如预期的那样,下面的玩具示例代码可以正常工作。

#!/usr/bin/env python3
from scipy.interpolate import interpn
import numpy as np
x=np.arange(4)
y=np.arange(3)
z=np.arange(2)
xx = np.linspace(0, 3, 7)
yy = np.linspace(0,2, 5)
zz = np.linspace(0,1,3)
a1=np.arange(24)
a1=a1.reshape((4,3,2))
grids=np.array(np.meshgrid(xx,yy,zz, indexing='ij'))   
grids=np.moveaxis(grids, 0, -1)
a2=interpn((x,y,z), a1, grids)

但是,如果我 改变
grids=np.array(np.meshgrid(xx,yy,zz, indexing='ij'))

grids=np.array(np.meshgrid(xx,yy,zz, sparse=True, indexing='ij'))

我得到了ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all(),这指向

/home/zell/.local/lib/python3.6/

site-packages/scipy-1.4.1-py3.6-linux-x86_64.egg/scipy/interpolate/interpolate.py(2645(interpn((

这个问题似乎是由"稀疏=真"引入的。
如何在保持"稀疏=True"的同时解决此问题(因为我的网格将占用大量内存(?

meshgrid创建 3 个数组,每个输入数组一个。 没有sparse,每个都是一个3D数组,形状相同:

In [95]: len(np.meshgrid(xx,yy,zz, indexing='ij'))                                                              
Out[95]: 3
In [96]: np.meshgrid(xx,yy,zz, indexing='ij')[0].shape                                                          
Out[96]: (7, 5, 3)

当你把它包装在np.array你会得到一个(3, 7, 5, 3(数组。

使用sparse,它制作 3 个数组,也是 3d,但不是完整的。 它们以相同的方式一起广播,但没有重复的元素

In [97]: np.meshgrid(xx,yy,zz, indexing='ij', sparse=True)[0].shape                                             
Out[97]: (7, 1, 1)
In [98]: np.meshgrid(xx,yy,zz, indexing='ij', sparse=True)[1].shape                                             
Out[98]: (1, 5, 1)
In [99]: np.meshgrid(xx,yy,zz, indexing='ij', sparse=True)[2].shape                                             
Out[99]: (1, 1, 3)

你不能像以前那样把它们变成一个 4D 阵列!

在最新的 1.19dev 中,我收到以下警告:

In [101]: np.array(np.meshgrid(xx,yy,zz, indexing='ij', sparse=True)).shape                                     
/usr/local/bin/ipython3:1: VisibleDeprecationWarning: Creating an 
ndarray from ragged nested sequences (which is a list-or-tuple of 
lists-or-tuples-or ndarrays with different lengths or shapes) is 
deprecated. If you meant to do this, you must specify 'dtype=object' 
when creating the ndarray
#!/usr/bin/python3
Out[101]: (3,)

正是这个 3 元素对象 dtype 数组带来了interpn问题。

interpn文档将xi参数指定为

xi - ndarray of shape (…, ndim)
The coordinates to sample the gridded data at

显然,它期望的是常规的numpy数组,而不是"破烂"的数组。

最新更新