I have a function that provides a batched iterator over an iterable, using a batch size parameter:
def batch_iterator(iterable: Iterable[Row], batch_size: int) -> Iterator:
    """
    Slices off a batch of values from an iterable, and returns it as an iterator.
    """
    return iter(lambda: list(itertools.islice(iterable, batch_size)), [])
I want to test this function with pytest. Here is what I tried, followed by the error I got:
def test_batch_iterator():
    words_tuple = ('jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python')
    result = batch_iterator(3, words_tuple)
    assert result == iter(words_tuple[:2])
#----------------------------------------------------------
The pytest result at the console:
...
Expected :<tuple_iterator object at 0x00000216F2F23130>
Actual :<callable_iterator object at 0x00000216F2F23070>
    def test_batch_iterator():
        words_tuple = ('jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python')
        result = batch_iterator(3, words_tuple)
>       assert result == iter(words_tuple[:2])
E       assert <callable_iterator object at 0x00000216F2F23070> == <tuple_iterator object at 0x00000216F2F23130>
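batch_iterator(words_tuple, 3) returns an iterator whose first item is the list ['jimi', 'bertrand', 'alain']. To get that list out of the iterator, you can use next. You also need to swap the arguments in the call to batch_iterator(), and either change words_tuple to a list or convert the slice to a list in the assertion: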
words_tuple = ['jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python']
result = batch_iterator(words_tuple, 3)
assert next(result) == words_tuple[:3]
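As a side note on why the original assertion could never pass: as far as I know, iterator objects in Python do not compare by content, so == between two iterators falls back to identity. The short sketch below illustrates this with a made-up list named words (not from the question), and shows that pulling the values out first gives a comparison that means something.

```python
words = ['jimi', 'bertrand', 'alain']

# Two distinct iterators over the same data are not equal: iterators do not
# define a content-based __eq__, so == only checks whether they are the
# same object.
first = iter(words)
second = iter(words)
assert first != second

# Materializing both sides (or calling next) compares the actual values.
assert list(first) == list(second)
assert next(iter(words)) == 'jimi'
```

This is why the fixed test compares next(result) against a list rather than comparing result against another iterator.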
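Update: when it is given a sequence such as a tuple or a list, the current batch_iterator only ever produces the first 3 values, because every islice call starts again from the beginning of the sequence. If you want all the entries in words_tuple, you can use: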
def batch_iterator(iterable: Iterable[Row], batch_size: int) -> Iterator:
    """
    Slices off a batch of values from an iterable, and returns it as an iterator.
    """
    return iter(iterable[i:i+batch_size] for i in range(0, len(iterable), batch_size))
words_tuple = ('jimi', 'bertrand', 'alain', 'buck', 'apple', 'banana', 'cherry', 'oak', 'maple', 'python')
result = batch_iterator(words_tuple, 3)
assert next(result) == words_tuple[:3]
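The slicing version above relies on len() and indexing, so it only works for sequences. If the input may be a generator or another one-shot iterable, one possible alternative is to keep the original islice approach but create the iterator once, so each batch continues where the previous one stopped. This is only a sketch under that assumption: the name batch_iterator_any, the type variable T, and the test data are made up for illustration and are not part of the question.

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")  # hypothetical stand-in for the question's Row type


def batch_iterator_any(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield successive lists of up to batch_size items from any iterable."""
    # Create the iterator once so that each islice call resumes where the
    # previous one stopped, instead of restarting from the beginning.
    it = iter(iterable)
    # iter(callable, sentinel) keeps calling the lambda until it returns [].
    return iter(lambda: list(itertools.islice(it, batch_size)), [])


def test_batch_iterator_any():
    words = ('jimi', 'bertrand', 'alain', 'buck', 'apple')
    batches = list(batch_iterator_any(words, 3))
    # Every batch is a list, and the last one may be shorter.
    assert batches == [['jimi', 'bertrand', 'alain'], ['buck', 'apple']]
    # Generators work too, since nothing is indexed or measured with len().
    assert list(batch_iterator_any((n for n in range(5)), 2)) == [[0, 1], [2, 3], [4]]
```

Note that this variant returns lists for every batch, so a test against a tuple slice would need list(words_tuple[:3]) on the right-hand side.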