如何解压缩元组以便在不特定于维度的情况下进行循环



我想这样做:

    if dim==2:
        a,b=grid_shape
        for i in range(a):
            for j in range(b):
                A[i,j] = ...things...

其中dim是元组grid_shape中元素的个数。A是一个维度为dim的numpy数组。有没有一种方法可以在不特定于维度的情况下做到这一点?不需要编写像

这样难看的代码
    if dim==2:
        a,b=grid_shape
        for i in range(a):
            for j in range(b):
                A[i,j] = ...things...
    if dim==3:
        a,b,c=grid_shape
        for i in range(a):
            for j in range(b):
                for k in range(c):
                    A[i,j,k] = ...things...

使用itertools,您可以这样做:

for index in itertools.product(*(range(x) for x in grid_shape)):
    A[index] = ...things...

这依赖于几个技巧。首先,itertools.product()是一个从可迭代对象生成元组的函数。

for i in range(a):
    for j in range(b):
        index = i,j
        do_something_with(index)

可简化为

for index in itertools.product(range(a),range(b)):
    do_something_with(index)

这适用于itertools.product()的任意数量的参数,因此您可以有效地创建任意深度的嵌套循环。

另一个技巧是将网格形状转换为itertools.product的参数:
(range(x) for x in grid_shape)

等价于

(range(grid_shape[0]),range(grid_shape[1]),...)

也就是说,它是每个grid_shape维度的范围元组。然后使用*将其扩展到参数中。

itertools.product(*(range(x1),range(x2),...))

等价于

itertools.product(range(x1),range(x2),...)

同样,由于A[i,j,k]等价于A[(i,j,k)],我们可以直接使用A[index]

正如DSM指出的那样,由于您正在使用numpy,您可以减少

itertools.product(*(for range(x) for x in grid_shape))

numpy.ndindex(grid_shape)

最后一个循环变成

for index in numpy.ndindex(grid_shape):
    A[index] = ...things...

您可以通过在最后一个变量前面加上一个星号来捕获元组的其余部分,并通过在其周围加上括号来创建一个数组。

>>> tupl = ((1, 2), 3, 4, 5, 6)
>>> a, *b = tupl
>>> a
(1, 2)
>>> b
[3, 4, 5, 6]
>>> 

然后你可以循环b,所以它看起来像

a,*b=grid_shape
for i in a:
    for j in range(i):
        for k in b:
            for l in range(k):
                A[j, l] = ...things...

相关内容

  • 没有找到相关文章

最新更新