在张量列表上使用 torch.cat



我有一个torch tensors列表

list_tensor = [tensor([[1, 2, 3],
[3, 4, 5]]), 
tensor([[4, 5, 6],
[6, 4, 3]]), 
tensor([[4, 2, 1],
[3, 3, 1]]), 
tensor([[1, 4, 5],
[3, 1, 0]]), 
tensor([[1, 3, 3],
[2, 2, 2]])]

我想对这个集合进行交叉验证,所以我想考虑四个张量作为训练,并保留 1 个进行测试 - 我想这样做len(list_tensor)次。

所以我想做,

for num in range(1, len(list_tensor) + 1):
train_x = torch.cat((list_tensor[:num], list_tensor[num:]))

问题是我不能将列表用于torch.cat操作,因为list_tensor[:num]list_tensor[num:]都返回列表。例如,对于num = 1

list_tensor[:num] = [tensor([[1, 2, 3], [3, 4, 5]])]
list_tensor[num:] = [tensor([[4, 5, 6], [6, 4, 3]]), tensor([[4, 2, 1],[3, 3, 1]]), tensor([[1, 4, 5], 
[3, 1, 0]]), tensor([[1, 3, 3], [2, 2, 2]])]

我该如何对此执行 torch.cat?

我找到了不使用reduce的解决方法。

train_x = torch.cat((torch.cat(list_tensor[:num+1]),torch.cat(list_tensor[num+1:])))基本上连接单个列表中的所有张量,这将返回一个torch.tensor对象,然后在两者上使用torch.cat

您可以使用reduce

import torch as T
from functools import reduce
reduce(lambda x,y: T.cat((x,y)), list_tensor[:-1])

基本思想是将 concat 运算符应用于列表中的所有张量,除了最后一个张量,并继续聚合结果。您可以在此处阅读更多内容。

最新更新