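I want to copy a 2-D torch tensor into a destination tensor that keeps each row's values only up to and including the first occurrence of 202, with the remaining entries set to zero, like this: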
source_t = tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
                   [101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
                   [101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202]])
destination_t = tensor([[101, 2001, 2034, 1045, 202, 0, 0, 0, 0],
                        [101, 1999, 2808, 202, 0, 0, 0, 0, 0],
                        [101, 2012, 3832, 4027, 3454, 202, 0, 0, 0]])
How can I do this?
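I came up with a solution that works and is quite efficient.
I built a more complex source tensor, with extra rows that have 202 in different positions or not at all: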
import copy
import torch
source_t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
[101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
[101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
[101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]
])
First, we need to find the first occurrence of 202 in each row. We can find all occurrences and then pick the first one per row:
# (row, column) pairs of every 202, in row-major order, as plain Python ints
index_202 = (source_t == 202).nonzero(as_tuple=False).tolist()
rows_for_replace = list()
columns_to_replace = list()
elements = source_t.shape[1]  # number of columns
current_ind = 0
while current_ind < len(index_202):  # also visit the last match
    row, col = index_202[current_ind]
    # every column to the right of the first 202 in this row gets zeroed
    for element_ind in range(col + 1, elements):
        rows_for_replace.append(row)
        columns_to_replace.append(element_ind)
    # skip any further 202s that belong to the same row
    current_ind += 1
    while current_ind < len(index_202) and index_202[current_ind][0] == row:
        current_ind += 1
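After this step we have all the indices that should be replaced with zeros. The first row has 4 such elements, the second has 5, the fourth has 3, and the third and fifth rows have none.
rows_for_replace, columns_to_replace
([0, 0, 0, 0, 1, 1, 1, 1, 1, 3, 3, 3], [5, 6, 7, 8, 4, 5, 6, 7, 8, 6, 7, 8])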
Then we just copy the source tensor and set the new values at the appropriate positions:
destination_t = copy.deepcopy(source_t)
destination_t[rows_for_replace, columns_to_replace] = 0
Summary:
source_t
tensor([[ 101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
[ 101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[ 101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])
destination_t
tensor([[ 101, 2001, 2034, 1045, 202, 0, 0, 0, 0],
[ 101, 1999, 2808, 202, 0, 0, 0, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[ 101, 2012, 3832, 4027, 3454, 202, 0, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])
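For comparison, here is a minimal sketch of how the same result might be obtained without collecting indices in a Python loop. It swaps in a different technique, a running count of 202s along each row, and the names is_202 and seen_before are only illustrative. Rows without any 202 are left unchanged.

```python
import torch

source_t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
                         [101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
                         [101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])

# 1 where the value is 202, 0 elsewhere (illustrative name)
is_202 = (source_t == 202).long()

# number of 202s strictly to the left of each position in its row
seen_before = is_202.cumsum(dim=1) - is_202

# keep a value only if no 202 has appeared before it; zero out the rest
destination_t = torch.where(seen_before == 0, source_t, torch.zeros_like(source_t))

print(destination_t)
```

The first 202 itself is kept because the count at that position excludes the element itself.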
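I think there is a neater solution, but it requires every row to contain a 202; otherwise you will have to remove the rows where that is not the case. (For a row without any 202, argmax returns 0, so everything after the first column would be zeroed.)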
import torch
t = torch.tensor([[101,2001,2034,1045,202,3454,3453,1234,202],
[101,1999,2808,202,17658,3454,202,0,0],
[101,2012,3832,4027,3454,202,3454,9987,202]])
out = t.clone() # make copy
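Check where the tensor equals 202, convert the booleans to int, and take the argmax of each row. This gives the column where the first 1 appears, which corresponds to the first 202.
Then iterate over the rows: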
cols = t.eq(202).int().argmax(1)
k = t.shape[1]  # number of columns
for idx, c in enumerate(cols):
    if c + 1 < k:
        out[idx, c+1:] = 0  # make all values right of c equal to zero
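If some rows might not contain 202, the argmax idea above could be guarded rather than dropping those rows. The following is only a sketch under that assumption: has_202, col_ids, and the choice of k - 1 as the cut-off for such rows are illustrative additions, not part of the answer above. It also replaces the row loop with a broadcast comparison.

```python
import torch

t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
                  [101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
                  [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020]])  # last row has no 202

k = t.shape[1]                      # number of columns
is_202 = t.eq(202)
cols = is_202.int().argmax(dim=1)   # column of the first 202 (0 if the row has none)
has_202 = is_202.any(dim=1)         # illustrative guard for rows without 202

# for rows without 202, move the cut-off to the last column so nothing is zeroed
cols = torch.where(has_202, cols, torch.full_like(cols, k - 1))

# keep columns up to and including the cut-off, zero everything to the right
col_ids = torch.arange(k).unsqueeze(0)           # shape (1, k)
out = torch.where(col_ids <= cols.unsqueeze(1), t, torch.zeros_like(t))

print(out)
```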