ano张量上的重叠迭代



我试图在ano中实现一个扫描循环,给定一个张量将使用输入的"移动切片"。它不一定是一个移动的片,它可以是一个预处理张量到另一个代表移动片的张量。

基本上

:

[0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16]
 |-------|                                 (first  iteration)
   |-------|                               (second iteration)
     |-------|                             (third  iteration)
               ...
                    ...
                        ...
                               |-------|   (last   iteration)

其中|-------|为每次迭代的输入。

我正试图找出最有效的方法来做到这一点,也许使用某种形式的引用或操纵跨步,但我还没有设法得到一些工作,甚至为纯numpy。

我找到的一个可能的解决方案可以在这里找到,但是我不知道如何使用strides,我没有看到一个方法来使用ano。

您可以构建一个包含每个时间步长切片的起始索引的向量,并使用该向量作为序列调用Scan,而原始向量作为非序列。然后,在Scan内部,您可以在每次迭代中获得所需的切片。

我包含了一个示例,其中我也将切片的大小作为符号输入,以防您想从一次调用Theano函数更改到下一次调用:

import theano
import theano.tensor as T
# Input variables
x = T.vector("x")
slice_size = T.iscalar("slice_size")

def step(idx, vect, length):
    # From the idx of the start of the slice, the vector and the length of
    # the slice, obtain the desired slice.
    my_slice = vect[idx:idx + length]
    # Do something with the slice here. I don't know what you want to do
    # to I'll just return the slice itself.
    output = my_slice
    return output
# Make a vector containing the start idx of every slice
slice_start_indices = T.arange(x.shape[0] - slice_size + 1)
out, updates = theano.scan(fn=step,
                        sequences=[slice_start_indices],
                        non_sequences=[x, slice_size])
fct = theano.function([x, slice_size], out)

使用参数运行函数会产生如下输出:

print fct(range(17), 5)
[[  0.   1.   2.   3.   4.]
 [  1.   2.   3.   4.   5.]
 [  2.   3.   4.   5.   6.]
 [  3.   4.   5.   6.   7.]
 [  4.   5.   6.   7.   8.]
 [  5.   6.   7.   8.   9.]
 [  6.   7.   8.   9.  10.]
 [  7.   8.   9.  10.  11.]
 [  8.   9.  10.  11.  12.]
 [  9.  10.  11.  12.  13.]
 [ 10.  11.  12.  13.  14.]
 [ 11.  12.  13.  14.  15.]
 [ 12.  13.  14.  15.  16.]]

您可以使用以下rollling_window recipe:

import numpy as np
def rolling_window_lastaxis(arr, winshape):
    """
    Directly taken from Erik Rigtorp's post to numpy-discussion.
    http://www.mail-archive.com/numpy-discussion@scipy.org/msg29450.html
    (Erik Rigtorp, 2010-12-31)
    See also:
    http://mentat.za.net/numpy/numpy_advanced_slides/ (Stéfan van der Walt, 2008-08)
    https://stackoverflow.com/a/21059308/190597 (Warren Weckesser, 2011-01-11)
    https://stackoverflow.com/a/4924433/190597 (Joe Kington, 2011-02-07)
    https://stackoverflow.com/a/4947453/190597 (Joe Kington, 2011-02-09)
    """
    if winshape < 1:
        raise ValueError("winshape must be at least 1.")
    if winshape > arr.shape[-1]:
        print(winshape, arr.shape)
        raise ValueError("winshape is too long.")
    shape = arr.shape[:-1] + (arr.shape[-1] - winshape + 1, winshape)
    strides = arr.strides + (arr.strides[-1], )
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)
x = np.arange(17)
print(rolling_window_lastaxis(x, 5))

打印

[[ 0  1  2  3  4]
 [ 1  2  3  4  5]
 [ 2  3  4  5  6]
 [ 3  4  5  6  7]
 [ 4  5  6  7  8]
 [ 5  6  7  8  9]
 [ 6  7  8  9 10]
 [ 7  8  9 10 11]
 [ 8  9 10 11 12]
 [ 9 10 11 12 13]
 [10 11 12 13 14]
 [11 12 13 14 15]
 [12 13 14 15 16]]

请注意,还有一些更奇特的扩展,例如Joe Kington的rollling_window可以在多维窗口上滚动,Sebastian Berg的实现可以按步骤跳转。

最新更新