如何将torch.norm转换为余弦距离



我想将范数距离更改为余弦距离,请帮助我将此函数转换为余弦距离

def feat_prototype_distance(self, feat):
N, C, H, W = feat.shape
feat_proto_distance = -torch.ones((N, self.class_numbers, H, W)).to(feat.device)
for i in range(self.class_numbers):
feat_proto_distance[:, i, :, :] = torch.norm(self.objective_vectors[i].reshape(-1,1,1).expand(-1, H, W) - feat, 2, dim=1)
return feat_proto_distance

这是使用形状的范数距离的原始函数:self.objective_vectors[i].reform(-1,1,1(.expand(-1,H,W(:torch。大小([256122224](壮举:火炬。带8的大小([8256128224](为batch_Size

您可以使用torch.nn.CosineSimilarity

我不太清楚你的代码结构,但你可能会做一些类似的事情:

def feat_prototype_distance(self, feat):
distance_metric = torch.nn.CosineSimilarity(dim=1)
N, C, H, W = feat.shape
feat_proto_distance = -torch.ones((N, self.class_numbers, H, W)).to(feat.device)
for i in range(self.class_numbers):
feat_proto_distance[:, i, :, :] = distance_metric(self.objective_vectors[i].reshape(-1,1,1).expand(-1, H, W), feat)
return feat_proto_distance

最新更新