如何在pytorch中进行分组argmax



Pytorch中有没有任何方法可以根据组中子向量的范数来实现maxpooling?具体来说,这就是我想要实现的:

输入

x:二维浮动张量,形状#Nodes*dim

:一维长张量,形状#节点

输出

y,一个二维浮动张量,和:

y[i]=x[k],其中k=argmax_{cluster[k]=i}(torc.norm(x[k]p=2((.

我尝试了torch.scatterreduce="max",但这只适用于dim=1 and x[i]>0

有人能帮我解决这个问题吗?

我不认为有任何内置函数可以实现您想要的功能。基本上,这将是对x的范数的某种形式的scatter_reduce,但不是选择最大范数,而是选择与最大范数对应的行。

一个简单的实现可能看起来像这个

"""
input
x: float tensor of size [NODES, DIMS]
cluster: long tensor of size [NODES]
output
float tensor of size [cluster.max()+1, DIMS]
"""
num_clusters = cluster.max().item() + 1
y = torch.zeros((num_clusters, DIMS), dtype=x.dtype, device=x.device)
for cluster_id in torch.unique(cluster):
x_cluster = x[cluster == cluster_id]
y[cluster_id] = x_cluster[torch.argmax(torch.norm(x_cluster, dim=1), dim=0)]

如果clusters.max()相对较小,它应该可以正常工作。如果有很多集群,那么这种方法必须在cluster上为每个唯一的集群id创建不必要的掩码。为了避免这种情况,可以使用argsort。在纯python中,我能想到的最好的是以下内容。

num_clusters = cluster.max().item() + 1
x_norm = torch.norm(x, dim=1)
cluster_sortidx = torch.argsort(cluster)
cluster_ids, cluster_counts = torch.unique_consecutive(cluster[cluster_sortidx], return_counts=True)
end_indices = torch.cumsum(cluster_counts, dim=0).cpu().tolist()
start_indices = [0] + end_indices[:-1]
y = torch.zeros((num_clusters, DIMS), dtype=x.dtype, device=x.device)
for cluster_id, a, b in zip(cluster_ids, start_indices, end_indices):
indices = cluster_sortidx[a:b]
y[cluster_id] = x[indices[torch.argmax(x_norm[indices], dim=0)]]

例如,在NODES = 60000DIMS = 512cluster.max()=6000的随机测试中,第一个版本大约需要620ms,而第二个版本大约花费78ms。

最新更新