为什么"过滤器"适用于一个列表而不是其他给定"The truth value of an array with more than one element is ambiguous."?



我有以下助手函数:

flatten = lambda sentences : [word for sentence in sentences for word in sentence] # flattens list of lists 
get_feature = lambda f, input: np.array(list(map(f, input))) # applies f() to each element of input list and returns list of resultant elements
is_numeric = lambda words: get_feature(lambda word: word.isnumeric(), words)

然后执行以下操作:

ls = ['a','1','b','2','c','d','e']
print(type(ls))
print(list(filter(is_numeric, ls)))
print(type(train_tokens))
print(train_tokens[:10])
print(list(filter(is_numeric, train_tokens))[:10])

给出以下输出:

<class 'list'>
['1', '2']
<class 'list'>
['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.', 'Peter']
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-59836dbfb218> in <module>()
11 print(type(train_tokens))
12 print(train_tokens[:10])
---> 13 print(list(filter(is_numeric, train_tokens))[:10])
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

lstrain_tokens都是列表。那么我不明白为什么filterls有效,而对train_tokens无效。

我犯了什么愚蠢的错误?

PS:这就是我形成train_tokens:的方式

!pip install datasets
from datasets import load_dataset
conll2003dataset = load_dataset("conll2003")
train_tokens = flatten(conll2003dataset['train']['tokens'])

这是笔记本的链接。

filter一起使用的函数没有完成预期的操作。filter将该函数应用于其输入的可迭代元素,每次一个。对于您的函数,这意味着您要检查单词的每个字符,看看它是否是数字,并返回一个booleannumpy数组。只有当长度为1时,该数组才能与filter一起使用。也就是说,如果单词中只有一个字母。对于较长的单词,您会得到较长的数组,正如错误消息所说,您无法在布尔上下文中计算长度大于1的numpy数组。

如果您真的想使用filter,那么实际上不需要编写自己的函数(当然也不需要编写任何像您所展示的那样复杂的函数)。只需将str.isnumeric作为函数传递,它就会一次性检查整个单词(而不是逐个字母)。

print(list(filter(str.isnumeric, train_tokens)))

另一方面,如果你真的想使用numpy和你自己的数字检查代码,你可以放弃filter,在整个列表上调用is_numericlambda函数,得到一个可以用作掩码的布尔numpy数组。如果你也把输入转换成一个数组,你可以用掩码对它进行索引,只得到数字条目:

numeric_mask = is_numeric(train_tokens)   # this is calling your function is_numeric
train_tokens_array = np.asarray(train_tokens)
print(train_tokens_array[numeric_mask][:10])

相关内容