根据部分和添加填充



我有四个给定的变量:

  • 集团大小
  • 分组总数
  • 部分和
  • 一维张量当一组内的和达到部分和时,我想加零。例如:
groupsize = 4
totalgroups = 3
partialsum = 15
d1tensor = torch.tensor([ 3, 12,  5,  5,  5,  4, 11])

预期结果为:[3, 12, 0, 0, 5, 5, 0, 4, 11, 0, 0]

我不知道如何在纯pytorch中实现这一点。在python中应该是这样的:

target = [0]*(groupsize*totalgroups)
cursor = 0
current_count = 0
d1tensor = [ 3, 12,  5,  5,  5,  4, 11]
for idx, ele in enumerate(target):
subgroup_start = (idx//groupsize) *groupsize
subgroup_end = subgroup_start + groupsize 
if sum(target[subgroup_start:subgroup_end]) < partialsum: 
target[idx] = d1tensor[cursor]
cursor +=1
有谁能帮我一下吗?我已经用谷歌搜索过了,但是什么也没找到。

一些逻辑、Numpy和列表推导式在这里就足够了。我会一步一步地分解它,你可以把它做得更细更漂亮:

import numpy as np
my_val = 15
block_size = 4
total_groups = 3
d1 = [3, 12,  5,  5,  5,  4, 11]
d2 = np.cumsum(d1)
d3 = d2 % my_val == 0 #find where sum of elements is 15 or multiple
split_points= [i+1 for i, x in enumerate(d3) if x] # find index where cumsum == my_val
#### Option 1
split_array = np.split(d1, split_points, axis=0)
padded_arrays = [np.pad(array, (0, block_size - len(array)), mode='constant') for array in split_array] #pad arrays
padded_d1 = np.concatenate(padded_arrays[:total_groups]) #put them together, discard extra group if present
#### Option 2
split_points = [el for el in split_points if el <len(d1)] #make sure we are not splitting on the last element of d1
split_array = np.split(d1, split_points, axis=0)
padded_arrays = [np.pad(array, (0, block_size - len(array)), mode='constant') for array in split_array] #pad arrays
padded_d1 = np.concatenate(padded_arrays)

最新更新