我正在寻找一种更可读的方法来完成数组赋值(for循环部分(。
spike_train = np.zeros((2,2000)) # Array to be filled
spikes = np.random.randint(low=0, high = spike_train.shape[1]-1, size = (2,100)) # Indexes to fill
for i in range(spikes.shape[0]): # I feel that this part could be more readable
spike_train[i, spikes[i,:]] = 1
#I was thinking something along these lines: spike_train[spikes] = 1, but I know this doesn't work
您需要广播第一个维度(而不是像@Valdi_Bo和@Andre do那样广播repeat
——它构造了一个大的中间数组(:
spike_train = np.zeros((2,2000)) # Array to be filled
spikes = np.random.randint(low=0, high = spike_train.shape[1]-1, size = (2,100)) # Indexes to fill
spike_train[np.arange(spikes.shape[0])[:, None], spikes] = 1
在这种情况下,您也可以使用np.put_along_axis
:
np.put_along_axis(spike_train, spikes, values = 1, axis = 1)
编辑:正如@DanielF所指出的,我在这里描述的方法不是最佳实践。我将把它留在这里供参考,但请参阅@DanielF的asnwer以获得更优化的解决方案。
另一种方法是得到";CCD_ 3、CCD_;并将其设置为1。
其中x
指代行(即0
或1
(,y
指代索引。
以下是如何实现这一点:
import numpy as np
spike_train = np.zeros((2,2000)) # Array to be filled
n_spikes_per_row = 100
x = np.repeat([0,1], n_spikes_per_row)
y = np.random.randint(low=0, high = spike_train.shape[1]-1, size = 2 * n_spikes_per_row) # Indexes to fill
spike_train[x, y] = 1
也许有点多余,但x
和y
是平面阵列,包含随机选择的坐标,如下所示:
x = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y = [13, 5, 7, 4, 9, 14, 1, 1, 17, 4]
最后一句话:由于您选择的是随机整数,您可能会多次在y
中获得相同的值,这会导致最终数组中的"尖峰"更少。也许你已经知道了,或者这对你的实现并不重要,但我想提到它以防万一。祝你好运
在没有显式循环的情况下写入指示元素的一种可能性是:
spike_train[(np.repeat(np.arange(spikes.shape[0]), spikes.shape[1]), spikes.flatten())] = 1
但执行时间甚至更长。
另一种选择是将最后一条指令更改为:
spike_train[i, spikes[i]] = 1
(对于第二维度,不需要明确地传递":"(。这次执行速度比原来的代码快一点。
我不确定可读性部分,但如果你真的想消除for
循环,你可以试试这个(我的真正答案是在代码的倒数第二行中-代码的其余部分只是设置演示数据或打印东西(:
import numpy as np
TRAIN_LEN = 10 # Should be 2000 in reality
# Initialize with sequential values, just to demonstrate that this answer
# works correctly, and updates with ONES in the right places
spike_train = np.arange(2*TRAIN_LEN).reshape (2,TRAIN_LEN)
print (spike_train)
SPIKES_LEN = 5 # Should be 100 in reality
spikes = np.random.randint(low=0, high = spike_train.shape[1]-1, size = (2,SPIKES_LEN)) # Indexes to fill
print (spikes)
# Here's where we actually do the update
spike_train[np.arange(spikes.shape[0], dtype=np.intp).reshape(-1,1), spikes] = 1
print(spike_train)
此打印:
[[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]]
[[4 8 0 6 0]
[8 8 4 3 3]]
更新后:
[[ 1 1 2 3 1 5 1 7 1 9]
[10 11 12 1 1 15 16 17 1 19]]
请注意,2x10序列号数组现在已更新,在所有正确的位置都有1的。