Pytorch中有没有任何方法可以根据组中子向量的范数来实现maxpooling?具体来说,这就是我想要实现的:
输入:
x:二维浮动张量,形状#Nodes*dim
簇:一维长张量,形状#节点
输出:
y,一个二维浮动张量,和:
y[i]=x[k],其中k=argmax_{cluster[k]=i}(torc.norm(x[k]p=2((.
我尝试了torch.scatter
和reduce="max"
,但这只适用于dim=1 and x[i]>0
。
有人能帮我解决这个问题吗?
我不认为有任何内置函数可以实现您想要的功能。基本上,这将是对x
的范数的某种形式的scatter_reduce,但不是选择最大范数,而是选择与最大范数对应的行。
一个简单的实现可能看起来像这个
"""
input
x: float tensor of size [NODES, DIMS]
cluster: long tensor of size [NODES]
output
float tensor of size [cluster.max()+1, DIMS]
"""
num_clusters = cluster.max().item() + 1
y = torch.zeros((num_clusters, DIMS), dtype=x.dtype, device=x.device)
for cluster_id in torch.unique(cluster):
x_cluster = x[cluster == cluster_id]
y[cluster_id] = x_cluster[torch.argmax(torch.norm(x_cluster, dim=1), dim=0)]
如果clusters.max()
相对较小,它应该可以正常工作。如果有很多集群,那么这种方法必须在cluster
上为每个唯一的集群id创建不必要的掩码。为了避免这种情况,可以使用argsort
。在纯python中,我能想到的最好的是以下内容。
num_clusters = cluster.max().item() + 1
x_norm = torch.norm(x, dim=1)
cluster_sortidx = torch.argsort(cluster)
cluster_ids, cluster_counts = torch.unique_consecutive(cluster[cluster_sortidx], return_counts=True)
end_indices = torch.cumsum(cluster_counts, dim=0).cpu().tolist()
start_indices = [0] + end_indices[:-1]
y = torch.zeros((num_clusters, DIMS), dtype=x.dtype, device=x.device)
for cluster_id, a, b in zip(cluster_ids, start_indices, end_indices):
indices = cluster_sortidx[a:b]
y[cluster_id] = x[indices[torch.argmax(x_norm[indices], dim=0)]]
例如,在NODES = 60000
、DIMS = 512
、cluster.max()=6000
的随机测试中,第一个版本大约需要620ms,而第二个版本大约花费78ms。