Python类型检查numpy数组包括它们的dtype



我可以验证我的函数接收到正确类型的输入:

def foo(x: np.ndarray, y: float):
return x * y

确保如果我尝试使用此函数与x不是np.ndarray,我会得到一个错误,甚至在运行代码之前。

我不知道的是,如何验证数组类型。例如:

def return_valid_points_only(points: np.ndarray, valid: np.ndarray):
assert points.shape == valid.shape
return points[valid]

我想检查valid不仅是np.ndarray,而且是valid.dtype == bool

对于这个例子,如果valid被提供0和1来表示有效性,程序不会失败,我将得到糟糕的结果。

感谢

Python就是请求原谅,而不是允许。这意味着,即使在你的第一个定义中,def foo(x: np.ndarray, y: float):实际上也依赖于用户对提示的尊重,除非你使用像mypy这样的东西。

这里有几种方法可以使用,通常是串联使用的。一种方法是以一种与传入的输入一起工作的方式编写函数,这可能意味着失败或强制无效输入。另一种方法是仔细记录代码,这样用户就可以做出明智的决定。第二种方法尤其重要,但我将着重于第一种方法。

Numpy为你做了大部分的检查。例如,与其期望一个数组,习惯做法是强制一个:

x = np.asanyarray(x)

np.asanyarray通常是array(a, dtype, copy=False, order=order, subok=True)的别名。您可以为y做类似的事情:

y = np.asanyarray(y).item()

这将允许任何类数组,只要它有一个元素,无论是否是标量。另一种方法是尊重numpy将数组一起广播的能力,因此,如果用户将y作为x.shape[-1]元素的列表传入。

对于第二个函数,有几个选项。一种选择是允许花哨的索引。因此,如果用户传入一个索引列表和一个布尔掩码,你可以两者都使用。另一方面,如果您坚持使用布尔掩码,则可以检查或强制使用dtype。

如果您检查,请记住,如果数组大小不匹配,numpy索引操作将为您引发错误。您只需要检查类型本身:

points = np.asanyarray(points)
valid = np.asanyarray(valid)
if valid.dtype != bool:
raise ValueError('valid argument must be a boolean mask')

如果你选择强制,用户将被允许使用0和1,但有效的输入将不会被不必要地复制:

valid = np.asanyarray(valid, bool)

最新更新