生成组合拆分的Python迭代器



我知道itertools.combinations将从一个iterable:中遍历特定长度的所有组合

import itertools
import numpy as np
my_array = np.array([[0,1,2],[3,4,5],[6,7,8]])
for c in itertools.combinations(my_array, 2):
    print (c)

将打印

(array([0, 1, 2]), array([3, 4, 5]))
(array([0, 1, 2]), array([6, 7, 8]))
(array([3, 4, 5]), array([6, 7, 8]))

然而,我想要像一样的东西

for c, d in combo_splits(my_array, 2):
    print (c)
    print (d)

打印

(array([0, 1, 2]), array([3, 4, 5])) 
array([4, 5, 6])
(array([0, 1, 2]), array([6, 7, 8])) 
array([3, 4, 5])
(array([3, 4, 5]), array([6, 7, 8]))
array([0, 2, 3])

我会使用组合来生成掩码,然后只使用掩码对数组进行花式索引。看起来是这样的:

import itertools
import numpy as np
my_array = np.array([[0,1,2],[3,4,5],[6,7,8]])
n = len(my_array)
for r in itertools.combinations(xrange(n), 2):
    rows = np.array(r)
    on = np.zeros(n, dtype=bool)
    on[rows] = True
    print my_array[on]
    print my_array[~on]

这不是Python自带的,但您可以很容易地在itertools.combinations之上构建它:

def combinations_with_leftovers(pool, k):
    """itertools.combinations, but also returning the parts we didn't take.
    Each element of combinations_with_leftovers(pool, k) is a tuple (a, b)
    where a is the tuple that itertools.combinations would have given
    and b is a tuple of the elements of pool not used in a.
    """
    pool = tuple(pool)
    for chosen_indices in itertools.combinations(xrange(len(pool)), k):
        chosen_indices = set(chosen_indices)
        a = tuple(x for i, x in enumerate(pool) if i in chosen_indices)
        b = tuple(x for i, x in enumerate(pool) if i not in chosen_indices)
        yield a, b

最新更新