如何使用pytest测试迭代器函数?



我有一个函数,它使用批处理大小参数在可迭代对象上提供批处理迭代器:

def batch_iterator(iterable: Iterable[Row], batch_size: int) -> Iterator:
"""
Slices off a batch of values from an iterable, and returns it as an iterator.
"""
return iter(lambda: list(itertools.islice(iterable, batch_size)), [])

我想用pytest测试函数。下面是我尝试过的错误:

def test_batch_iterator():
words_tuple = ('jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python')
result = batch_iterator(3, words_tuple)
assert result == iter(words_tuple[:2])
#----------------------------------------------------------
the pytest result at the console:
...
Expected :<tuple_iterator object at 0x00000216F2F23130>
Actual   :<callable_iterator object at 0x00000216F2F23070>
<Click to see difference>
def test_batch_iterator():

words_tuple = ('jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python')
result = batch_iterator(3, words_tuple)
>       assert result == iter(words_tuple[:2])
E       assert <callable_iterator object at 0x00000216F2F23070> == <tuple_iterator object at 0x00000216F2F23130>

batch_iterator(words_tuple, 3)返回具有['jimi', 'bertrand', 'alain']的迭代器。要从中获取列表,可以使用next。您还需要在调用batch_iterator()时切换参数,并将words_tuple更改为列表,或者在断言

中强制转换它
words_tuple = ['jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python']
result = batch_iterator(words_tuple, 3)
assert next(result) == words_tuple[:3]

更新当前batch_iterator创建的iterator只有前3个值。如果你想要words_tuple中的所有条目你可以使用

def batch_iterator(iterable: Iterable[Row], batch_size: int) -> Iterator:
"""
Slices off a batch of values from an iterable, and returns it as an iterator.
"""
return iter(iterable[i:i+batch_size] for i in range(0, len(iterable), batch_size))
words_tuple = ('jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python')
result = batch_iterator(words_tuple, 3)
assert next(result) == words_tuple[:3]

最新更新