使用Numpy对除第一个维度以外的所有维度进行平面索引



是否有一些方法可以使用NumPy对其余维度使用平面索引?我正在尝试将以下MATLAB函数转换为Python

function [indices, weights] = locate(values, gridpoints)
indices = ones(size(values));
weights = zeros([2, size(values)]);
for ix = 1:numel(values)
if values(ix) <= gridpoints(1)
indices(ix) = 1;
weights(:, ix) = [1; 0];
elseif values(ix) >= gridpoints(end)
indices(ix) = length(gridpoints) - 1;
weights(:, ix) = [0; 1];
else
indices(ix) = find(gridpoints <= values(ix), 1, 'last');    
weights(:, ix) = ...
[gridpoints(indices(ix) + 1) - values(ix); ...
values(ix) - gridpoints(indices(ix))] ...
/ (gridpoints(indices(ix) + 1) - gridpoints(indices(ix)));
end
end
end

但我无法理解MATLAB的weights(:, ix)的NumPy等价物是什么——也就是说,只在剩下的维度上进行线性索引。

我希望语法可以直接翻译,但假设values是一个3乘4的数组,那么weights就变成了一个2乘3乘4数组。在MATLAB中,weights(:, ix)是一个2乘1的数组,而在Python中weights[:, ix]是一个2-乘3的数组。

我想我已经处理了下面函数中的所有其他内容。

import numpy as np

def locate(values, gridpoints):
indices = np.zeros(np.shape(values), dtype=int)
weights = np.zeros((2,) + np.shape(values))
for ix in range(values.size):
if values.flat[ix] <= gridpoints[0]:
indices.flat[ix] = 0
# weights[:, ix] = [1, 0]
elif values.flat[ix] >= gridpoints[-1]:
indices.flat[ix] = gridpoints.size - 2
# weights[:, ix] = [0, 1]
else:
indices.flat[ix] = (
np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
)
# weights[:, ix] = (
#         np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
#                   values.flat[ix] - gridpoints[indices.flat[ix]]])
#         / (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
# )
return indices, weights

你有什么建议吗?也许我只是把这个问题想错了。我也试着尽可能简单地编写代码,因为我打算稍后使用Numba来加快速度。

根据hpaulj的评论,似乎没有直接的NumPy等价物。如果没有,我能想到的最好的办法就是按照下面的代码和NumPy给Matlab用户的建议来重塑weights数组。

import numpy as np

def locate(values, gridpoints):
indices = np.zeros(values.shape, dtype=int)
weights = np.zeros((2, values.size))  # Temporarily make weights 2-by-N
for ix in range(values.size):
if values.flat[ix] <= gridpoints[0]:
indices.flat[ix] = 0
weights[:, ix] = [1, 0]
elif values.flat[ix] >= gridpoints[-1]:
indices.flat[ix] = gridpoints.size - 2
weights[:, ix] = [0, 1]
else:
indices.flat[ix] = (
np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
)
weights[:, ix] = (
np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
values.flat[ix] - gridpoints[indices.flat[ix]]])
/ (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
)

# Give weights correct dimensions
weights.shape = (2,) + values.shape

return indices, weights

相关内容

  • 没有找到相关文章

最新更新