按给定的索引从numpy 1d数组中取出多个切片,并将结果复制到2d数组中



我是Python新手。下面的代码片段中给出了一个名为randomWalk的numpy 1d数组。给定索引(可以解释为开始日期和结束日期,两者都可能因项目而异),我想从1d数组randomWalk中提取多个切片,并将结果排列在给定形状的2d数组中。

我正在尝试将其矢量化。我已经能够使用np.r_从1d数组中选出我想要的切片,但未能把这些切片存储为输出所需的格式(2d数组,行表示项目,列表示从min(startDates)到max(endDates)的时间)。

下面是可以运行的(但很丑陋的)代码。

import numpy as np
numItems = 20
numPeriods = 12

# Data
# One shared random walk: a per-period return for each of numPeriods periods.
randomWalk = np.random.normal(loc=0.0, scale=0.05, size=(numPeriods,))
# Per-item half-open windows [start, end): randint's high bound is exclusive,
# so start is in [1, 4] and end is in [5, numPeriods]; hence start < end always.
startDates = np.random.randint(low=1, high=5, size=numItems)
endDates = np.random.randint(low=5, high=numPeriods + 1, size=numItems)
# About 10% of the items are flagged as stochastic (p=0.1 for True).
stochasticItems = np.random.choice([False, True], size=(numItems,), p=[0.9, 0.1])

# Result needs to be in this shape (the snippet is designed so that only a
# relatively small fraction of resultMatrix's elements differ from 1).
resultMatrix = np.ones((numItems, numPeriods))

# Desired result (brute force): for every stochastic item, fill its window
# with the cumulative product of (1 + return) over that window; all other
# cells stay at 1.
for i in np.flatnonzero(stochasticItems):
    start, end = startDates[i], endDates[i]
    resultMatrix[i, start:end] = np.cumprod(randomWalk[start:end] + 1.0)

受@mozway答案的启发,将不规则切片转换为规则掩码数组:

>>> # Build all arrays with np.random.seed(0) (same data as the loop version).
>>> # Period indices 0 .. numPeriods-1, broadcast against each item's window below.
>>> x = np.arange(numPeriods)
>>> # mask[i, t] is True when period t lies in item i's [startDates[i], endDates[i]) window.
>>> mask = (startDates[:, None] <= x) & (endDates[:, None] > x)
>>> # Inner where: 1 + return inside each window and 1 elsewhere, so the row-wise
>>> # cumprod at a position inside a window equals the cumprod of that window alone.
>>> # Outer where: keep those values only for stochastic items inside their window; 1 everywhere else.
>>> result = np.where(mask & stochasticItems[:, None], np.where(mask, randomWalk + 1, 1).cumprod(-1), 1)
>>> np.allclose(result, resultMatrix)
True
>>> result
array([[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.0489369 , 1.16646468, 1.2753867 ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ],
[1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        , 1.        , 1.        , 1.        ,
1.        , 1.        ]])

如果目标是矢量化,Pig的回答已经解决了这个问题;如果矢量化本身并不重要(正如OP在评论中提到的,真正的目标是提高性能),那么我建议使用numba库来加速代码。我们可以编写与np.cumprod等效的numba代码,并用numba的no-python JIT加速它:

@nb.njit
def nb_cumprod(arr):
    """Return the running product of a 1-D array, like ``np.cumprod(arr)``.

    Hand-written loop intended to be compiled by numba in nopython mode.
    Assumes ``nb`` (numba) and ``np`` (numpy) are imported at module level.

    Parameters
    ----------
    arr : 1-D array
        Values to multiply cumulatively.

    Returns
    -------
    array with the same shape and dtype as ``arr``
        ``y[i] == arr[0] * arr[1] * ... * arr[i]``. An empty input yields an
        empty output (matching ``np.cumprod``) instead of raising IndexError
        on ``arr[0]``.
    """
    y = np.empty_like(arr)
    n = arr.shape[0]
    if n == 0:
        # Nothing to seed the running product with.
        return y
    y[0] = arr[0]
    for i in range(1, n):
        y[i] = arr[i] * y[i - 1]
    return y

@nb.njit
def nb_(numItems, numPeriods, stochasticItems, startDates, endDates, randomWalk):
    """Numba-compiled version of the question's brute-force loop.

    Builds a (numItems, numPeriods) float matrix of ones. For every item i
    with ``stochasticItems[i]`` true, it overwrites columns
    ``startDates[i]:endDates[i]`` with the cumulative product of
    ``randomWalk[startDates[i]:endDates[i]] + 1.0``.

    Assumes module-level ``nb`` (numba) and ``np`` (numpy). It also needs the
    njit-compiled ``nb_cumprod`` helper to be defined before the first call,
    since numba resolves it when this function is compiled.
    """
    resultMatrix = np.ones((numItems, numPeriods))
    for i in range(numItems):
        if stochasticItems[i]:
            start = startDates[i]
            end = endDates[i]
            resultMatrix[i, start:end] = nb_cumprod(randomWalk[start:end] + 1.0)
    return resultMatrix

在我的一些基准测试中,这段代码比OP的代码快约10倍。

相关内容

  • 没有找到相关文章

最新更新