我可以验证我的函数接收到正确类型的输入:
def foo(x: np.ndarray, y: float):
return x * y
确保如果我尝试使用此函数与x
不是np.ndarray
,我会得到一个错误,甚至在运行代码之前。
我不知道的是,如何验证数组类型。例如:
def return_valid_points_only(points: np.ndarray, valid: np.ndarray):
assert points.shape == valid.shape
return points[valid]
我想检查valid
不仅是np.ndarray
,而且是valid.dtype == bool
。
对于这个例子,如果valid
被提供0和1来表示有效性,程序不会失败,我将得到糟糕的结果。
Python就是请求原谅,而不是允许。这意味着,即使在你的第一个定义中,def foo(x: np.ndarray, y: float):
实际上也依赖于用户对提示的尊重,除非你使用像mypy这样的东西。
这里有几种方法可以使用,通常是串联使用的。一种方法是以一种与传入的输入一起工作的方式编写函数,这可能意味着失败或强制无效输入。另一种方法是仔细记录代码,这样用户就可以做出明智的决定。第二种方法尤其重要,但我将着重于第一种方法。
Numpy为你做了大部分的检查。例如,与其期望一个数组,习惯做法是强制一个:
x = np.asanyarray(x)
np.asanyarray
通常是array(a, dtype, copy=False, order=order, subok=True)
的别名。您可以为y
做类似的事情:
y = np.asanyarray(y).item()
这将允许任何类数组,只要它有一个元素,无论是否是标量。另一种方法是尊重numpy将数组一起广播的能力,因此,如果用户将y
作为x.shape[-1]
元素的列表传入。
对于第二个函数,有几个选项。一种选择是允许花哨的索引。因此,如果用户传入一个索引列表和一个布尔掩码,你可以两者都使用。另一方面,如果您坚持使用布尔掩码,则可以检查或强制使用dtype。
如果您检查,请记住,如果数组大小不匹配,numpy索引操作将为您引发错误。您只需要检查类型本身:
points = np.asanyarray(points)
valid = np.asanyarray(valid)
if valid.dtype != bool:
raise ValueError('valid argument must be a boolean mask')
如果你选择强制,用户将被允许使用0和1,但有效的输入将不会被不必要地复制:
valid = np.asanyarray(valid, bool)