如何在pytorch中得到具有特定值条件的亚二维张量



我想将一个二维火炬张量复制到一个只包含值的目标张量,直到202值第一次出现,其余项为零,如下所示:

source_t=tensor[[101,2001,2034,1045,202,3454,3453,1234,202]
,[101,1999,2808,202,17658,3454,202,0,0]
,[101,2012,3832,4027,3454,202,3454,9987,202]]
destination_t=tensor[[101,2001,2034,1045,202,0,0,0,0]
,[101,1999,2808,202,0,0,0,0,0]
,[101,2012,3832,4027,3454,202,0,0,0]]

我该怎么做?

我制作了一个有效且非常高效的解决方案。

我制作了一个更复杂的源张量,在不同的地方添加了202行:

import copy
import torch
source_t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
[101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
[101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
[101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]
])

一开始,我们应该发现第一个202的出现。我们可以找到所有出现的情况,然后选择第一个:

index_202 = (source_t == 202).nonzero(as_tuple=False).numpy()
rows_for_replace = list()
columns_to_replace = list()
elements = source_t.shape[1]
current_ind = 0
while current_ind < len(index_202)-1:
current = index_202[current_ind]
element_ind = current[1] + 1
rows_for_replace.extend([current[0]]*(elements-element_ind))
while element_ind < elements:
columns_to_replace.append(element_ind)
element_ind += 1
if current[0] == index_202[current_ind+1][0]:
current_ind += 1
current_ind += 1

在这个操作之后,我们得到了所有的索引,我们应该用零来替换这些索引。第一行有4个元素,第二行有5个,第四行有3个,第三行和第五行没有任何元素。rows_for_replace, columns_to_replace

([0, 0, 0, 0, 1, 1, 1, 1, 1, 3, 3, 3], [5, 5, 5, 5, 4, 4, 4, 4, 4, 6, 6, 6])

然后我们只需复制我们的源张量,并在适当的位置设置新值:

destination_t = copy.deepcopy(source_t)
destination_t[rows_for_replace, columns_to_replace] = 0

摘要:source_t

tensor([[  101,  2001,  2034,  1045,   202,  3454,  3453,  1234,   202],
[  101,  1999,  2808,   202, 17658,  3454,   202,     0,     0],
[  101,  2012,  3832,  4027,  3454,  2020,  3454,  9987,  2020],
[  101,  2012,  3832,  4027,  3454,   202,  3454,  9987,   202],
[  101,  2012,  3832,  4027,  3454,  2020,  3454,  9987,   202]])

destination_t

tensor([[ 101, 2001, 2034, 1045,  202,    0,    0,    0,    0],
[ 101, 1999, 2808,  202,    0,    0,    0,    0,    0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[ 101, 2012, 3832, 4027, 3454,  202,    0,    0,    0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987,  202]])

我认为有一个更好的解决方案,它要求每一行都有一个"202〃;,否则,您将不得不删除不属于这种情况的行。

import torch
t = torch.tensor([[101,2001,2034,1045,202,3454,3453,1234,202],
[101,1999,2808,202,17658,3454,202,0,0],
[101,2012,3832,4027,3454,202,3454,9987,202]])
out = t.clone() # make copy

检查张量等于202的位置,将boolean转换为int,并为每行取argmax,这意味着我们有第一个1出现的列,它对应于第一个202。

然后迭代每行

cols = t.eq(202).int().argmax(1)
k = t.shape[1] # number of columns
for idx, c in enumerate(cols):
if c + 1 < k:
out[idx, c+1:] = 0 # make all values right of c equal to zero

最新更新