My Task:
我试图计算两个大张量中每两个样本之间的成对距离(对于k-最近邻),即给定形状为(b1,c,h,w)
的张量test
和形状为(b2,c,h,w)
的张量train
,我需要|| test[i]-train[j] ||
用于每个i
,j
。(其中test[i]
和train[j]
的形状都是(c,h,w)
,因为它们是批样品)。
问题
train
和test
都非常大,所以我无法将它们放入RAM
我当前的解决方案
首先,我没有一次构造这些张量——当我构建它们时,我拆分数据张量并将它们单独保存到内存中,所以我最终得到文件{Testtest_1,...,Testtest_n}
和{Traintrain_1,...,Traintrain_m}
。然后,我在每个Testtest_i
和Traintrain_j
中加载一个嵌套的for
循环,计算当前距离,并保存它。
这个半伪代码可以解释
test_files = [f'Testtest_{i}' for i in range(n)]
train_files = [f'Traintrain_{j}' for j in range(m)]
dist = lambda t1,t2: torch.cdist(t1.flatten(1), t2.flatten(1))
all_distances = []
for test_i in test_files:
test_i = torch.load(test_i) # shape (c,h,w)
dist_of_i_from_all_j = torch.Tensor([])
for train_j in train_files:
train_j = torch.load(train_j) # shape (c,h,w)
dist_of_i_from_all_j = torch.cat((dist_of_i_from_all_j, dist(test_i,train_j))
all_distances.append(dist_of_i_from_all_j)
# and now I can take the k-smallest from all_distances
我认为可能可行
我遇到了FAISS存储库,他们在其中解释说这个过程可以使用他们的解决方案来加速(也许?),尽管我不太确定如何加速。无论如何,任何方法都会有所帮助!
您是否检查了FAISS文档?
如果你需要的是L2范数(torch.cidst
使用p=2
作为默认参数),那么它是相当简单的。下面的代码是FAISS文档对您的示例的改编:
import faiss
import numpy as np
d = 64 # dimension
nb = 100000 # database size
nq = 10000 # nb of queries
np.random.seed(1234) # make reproducible
x_test = np.random.random((nb, d)).astype('float32')
x_test[:, 0] += np.arange(nb) / 1000.
x_train = np.random.random((nq, d)).astype('float32')
x_train[:, 0] += np.arange(nq) / 1000.
index = faiss.IndexFlatL2(d) # build the index
print(index.is_trained)
index.add(x_test) # add vectors to the index
print(index.ntotal)
k= 100 # take the 100 closest neighbors
D, I = index.search(x_train, k) # actual search
print(I[:5]) # neighbors of the 100 first queries
print(I[-5:]) # neighbors of the 100 last queries
因此,我选择了实现一些版本的Earth-Movers-Distance,建议在以下ai.StackExchange
职位。让我总结一下方法:
给定">我的任务&"中描述的任务;上面,我定义了
def cumsum_3d(test, train):
for i in [-1, -2, -3]:
test = torch.cumsum(test, i)
train = torch.cumsum(train, i)
return test, train
那么,给定张量test
和train
:
test,train = cumsum_3d(test,train)
dist = torch.cdist(test.flatten(1),train.flatten(1))
对于未来的观众-请记住:
- 我没有使用
FAISS
,因为它目前不支持windows,但最重要的是它不支持(据我所知)这个版本的EMD或任何其他版本的多维(=形状(c,h,w)
像我的例子)张量距离。为了解决RAM问题,我使用Google Colab
并将数据切片到更多文件 - 这个实现只与我处理浅激活层相关。如果我使用最后一层(
avgpool
)作为我的激活,不使用EMD就可以了,因为avgpool
之后的输出形状为(512,)