在位置参数和关键字参数中压缩可迭代项



问题

我有一个函数func,以及一组位置和关键字参数(argskwargs)。这些参数中的一些未知数量是特定类型(SpecialType)的可迭代参数,每个参数的长度相同。我想同时并行迭代所有这些特殊的可迭代项,并将func仅应用于这些可迭代项中的一组相应值,同时每次都传递其他函数参数。在伪代码中,我想要的可能大致如下:

from itertools import repeat
args_all_as_iterables = [a.iterable if isinstance(a, SpecialType) else repeat(a) for a in args]
kwargs_all_as_iterables = {k: v.iterable if isinstance(v, SpecialType) else repeat(v) for k, v in kwargs}
for args_values, kwarg_values in zip(*args_all_as_iterables, **kwargs_all_as_iterables):
this_result = func(*args_values, **kwargs_values)

上面的代码显然是无效的(zip不会接受未打包的kwargs),但我不知道如何像这样将位置参数和关键字参数压缩在一起。(这里itertools.repeat的目的是允许我将每个参数压缩在一起,只需根据需要多次重复不可迭代的参数。)


尝试

如果我所要做的就是将函数应用于多个可迭代对象,那么我会做一些类似的事情

for values in zip(*[a.iterable for a in args]):
result = func(*[v for v in values])

但我不知道如何概括这一点来接受夸尔格。

我可以尝试将kwargs转换为(key, value)元组的迭代器列表?我可以试着制作一个单独的列表,记录argskwargs中的哪些参数属于SpecialType?理想情况下,我希望有一种方法可以快速浏览args列表和kwargs字典。


上下文

上下文是SpecialType是包含多个数据集的Tree类。我实际调用的不是a.iterable,而是tree.subtree,它在树中的所有数据节点上返回一个迭代器。zip旨在允许我同时沿着任意数量的并排树的节点行走,同时将func应用于每棵树的相应节点中的数据。然后,我将在每组树节点上存储调用func的各种结果,并使用它来构建新的树。有关详细信息,请参阅本期。

您可以将kwargs.values()视为args。

我的解决方案,在其实际的原始上下文中看起来是这样的:

# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
args_as_tree_length_iterables = [a.subtree if isinstance(a, DataTree) else repeat(a) for a in args]
n_args = len(args_as_tree_length_iterables)
kwargs_as_tree_length_iterables = {k: v.subtree if isinstance(v, DataTree) else repeat(v) for k, v in kwargs.items()}
for all_node_args in zip(*args_as_tree_length_iterables, *list(kwargs_as_tree_length_iterables.values())):
node_args_as_datasets = [a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]]
node_kwargs_as_datasets = dict(zip([k for k in kwargs_as_tree_length_iterables.keys()],
[v.ds if isinstance(v, DataTree) else v for v in all_node_args[n_args:]]))
# Now we can call func on the data in this particular set of corresponding nodes
results = func(*node_args_as_datasets, **node_kwargs_as_datasets)

最新更新