我想这样做:
if dim==2:
a,b=grid_shape
for i in range(a):
for j in range(b):
A[i,j] = ...things...
其中dim
是元组grid_shape
中元素的个数。A
是一个维度为dim
的numpy数组。有没有一种方法可以在不特定于维度的情况下做到这一点?不需要编写像
if dim==2:
a,b=grid_shape
for i in range(a):
for j in range(b):
A[i,j] = ...things...
if dim==3:
a,b,c=grid_shape
for i in range(a):
for j in range(b):
for k in range(c):
A[i,j,k] = ...things...
使用itertools,您可以这样做:
for index in itertools.product(*(range(x) for x in grid_shape)):
A[index] = ...things...
这依赖于几个技巧。首先,itertools.product()
是一个从可迭代对象生成元组的函数。
for i in range(a):
for j in range(b):
index = i,j
do_something_with(index)
可简化为
for index in itertools.product(range(a),range(b)):
do_something_with(index)
这适用于itertools.product()
的任意数量的参数,因此您可以有效地创建任意深度的嵌套循环。
(range(x) for x in grid_shape)
等价于
(range(grid_shape[0]),range(grid_shape[1]),...)
也就是说,它是每个grid_shape维度的范围元组。然后使用*将其扩展到参数中。
itertools.product(*(range(x1),range(x2),...))
等价于
itertools.product(range(x1),range(x2),...)
同样,由于A[i,j,k]
等价于A[(i,j,k)]
,我们可以直接使用A[index]
。
正如DSM指出的那样,由于您正在使用numpy,您可以减少
itertools.product(*(for range(x) for x in grid_shape))
numpy.ndindex(grid_shape)
最后一个循环变成
for index in numpy.ndindex(grid_shape):
A[index] = ...things...
您可以通过在最后一个变量前面加上一个星号来捕获元组的其余部分,并通过在其周围加上括号来创建一个数组。
>>> tupl = ((1, 2), 3, 4, 5, 6)
>>> a, *b = tupl
>>> a
(1, 2)
>>> b
[3, 4, 5, 6]
>>>
然后你可以循环b,所以它看起来像
a,*b=grid_shape
for i in a:
for j in range(i):
for k in b:
for l in range(k):
A[j, l] = ...things...