在python/numpy中有没有一种更简单的方法来完成这个数组赋值



我正在寻找一种更可读的方法来完成数组赋值(for循环部分(。

spike_train = np.zeros((2,2000)) # Array to be filled
spikes = np.random.randint(low=0, high = spike_train.shape[1]-1, size = (2,100)) # Indexes to fill
for i in range(spikes.shape[0]):  # I feel that this part could be more readable
spike_train[i, spikes[i,:]] = 1
#I was thinking something along these lines:  spike_train[spikes] = 1, but I know this doesn't work

您需要广播第一个维度(而不是像@Valdi_Bo和@Andre do那样广播repeat——它构造了一个大的中间数组(:

spike_train = np.zeros((2,2000)) # Array to be filled
spikes = np.random.randint(low=0, high = spike_train.shape[1]-1, size = (2,100)) # Indexes to fill
spike_train[np.arange(spikes.shape[0])[:, None], spikes] = 1

在这种情况下,您也可以使用np.put_along_axis

np.put_along_axis(spike_train, spikes, values = 1, axis = 1)

编辑:正如@DanielF所指出的,我在这里描述的方法不是最佳实践。我将把它留在这里供参考,但请参阅@DanielF的asnwer以获得更优化的解决方案。


另一种方法是得到";CCD_ 3、CCD_;并将其设置为1。

其中x指代行(即01(,y指代索引。

以下是如何实现这一点:

import numpy as np
spike_train = np.zeros((2,2000)) # Array to be filled
n_spikes_per_row = 100
x = np.repeat([0,1], n_spikes_per_row)
y = np.random.randint(low=0, high = spike_train.shape[1]-1, size = 2 * n_spikes_per_row) # Indexes to fill
spike_train[x, y] = 1

也许有点多余,但xy是平面阵列,包含随机选择的坐标,如下所示:

x = [0,  0,  0,  0,  0,  1,  1,  1,  1,  1]
y = [13, 5,  7,  4,  9, 14,  1,  1, 17,  4]

最后一句话:由于您选择的是随机整数,您可能会多次在y中获得相同的值,这会导致最终数组中的"尖峰"更少。也许你已经知道了,或者这对你的实现并不重要,但我想提到它以防万一。祝你好运

在没有显式循环的情况下写入指示元素的一种可能性是:

spike_train[(np.repeat(np.arange(spikes.shape[0]), spikes.shape[1]), spikes.flatten())] = 1

但执行时间甚至更长。

另一种选择是将最后一条指令更改为:

spike_train[i, spikes[i]] = 1

(对于第二维度,不需要明确地传递":"(。这次执行速度比原来的代码快一点。

我不确定可读性部分,但如果你真的想消除for循环,你可以试试这个(我的真正答案是在代码的倒数第二行中-代码的其余部分只是设置演示数据或打印东西(:

import numpy as np
TRAIN_LEN = 10                       # Should be 2000 in reality
# Initialize with sequential values, just to demonstrate that this answer
# works correctly, and updates with ONES in the right places
spike_train = np.arange(2*TRAIN_LEN).reshape (2,TRAIN_LEN)
print (spike_train)
SPIKES_LEN = 5                        # Should be 100 in reality
spikes = np.random.randint(low=0, high = spike_train.shape[1]-1, size = (2,SPIKES_LEN)) # Indexes to fill
print (spikes)
# Here's where we actually do the update
spike_train[np.arange(spikes.shape[0], dtype=np.intp).reshape(-1,1), spikes] = 1
print(spike_train)

此打印:

[[ 0  1  2  3  4  5  6  7  8  9]
[10 11 12 13 14 15 16 17 18 19]]
[[4 8 0 6 0]
[8 8 4 3 3]]

更新后:

[[ 1  1  2  3  1  5  1  7  1  9]
[10 11 12  1  1 15 16 17  1 19]]

请注意,2x10序列号数组现在已更新,在所有正确的位置都有1的

最新更新