我有一个torch tensors
列表
list_tensor = [tensor([[1, 2, 3],
[3, 4, 5]]),
tensor([[4, 5, 6],
[6, 4, 3]]),
tensor([[4, 2, 1],
[3, 3, 1]]),
tensor([[1, 4, 5],
[3, 1, 0]]),
tensor([[1, 3, 3],
[2, 2, 2]])]
我想对这个集合进行交叉验证,所以我想考虑四个张量作为训练,并保留 1 个进行测试 - 我想这样做len(list_tensor)
次。
所以我想做,
for num in range(1, len(list_tensor) + 1):
train_x = torch.cat((list_tensor[:num], list_tensor[num:]))
问题是我不能将列表用于torch.cat
操作,因为list_tensor[:num]
和list_tensor[num:]
都返回列表。例如,对于num = 1
,
list_tensor[:num] = [tensor([[1, 2, 3], [3, 4, 5]])]
list_tensor[num:] = [tensor([[4, 5, 6], [6, 4, 3]]), tensor([[4, 2, 1],[3, 3, 1]]), tensor([[1, 4, 5],
[3, 1, 0]]), tensor([[1, 3, 3], [2, 2, 2]])]
我该如何对此执行 torch.cat?
我找到了不使用reduce
的解决方法。
train_x = torch.cat((torch.cat(list_tensor[:num+1]),torch.cat(list_tensor[num+1:])))
基本上连接单个列表中的所有张量,这将返回一个torch.tensor
对象,然后在两者上使用torch.cat
。
您可以使用reduce
import torch as T
from functools import reduce
reduce(lambda x,y: T.cat((x,y)), list_tensor[:-1])
基本思想是将 concat 运算符应用于列表中的所有张量,除了最后一个张量,并继续聚合结果。您可以在此处阅读更多内容。