Custom metric with sklearn.neighbors.BallTree receives the wrong input



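I am trying to use a custom metric with sklearn.neighbors.BallTree, but when BallTree calls my metric the inputs look wrong. If I use scipy.spatial.distance.pdist with the same custom metric, it works as expected. However, when I try to instantiate a BallTree, an exception is raised at the point where my metric reshapes its input. When I inspect the actual inputs, both their shape and their values look incorrect.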

import numpy as np
import scipy.spatial.distance as spdist
import sklearn.neighbors.ball_tree as ball_tree

# custom metric
def minimum_average_direct_flip(x, y):
    x = np.reshape(x, (-1, 3))
    y = np.reshape(y, (-1, 3))
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
    return min(direct, flipped)

# create an X to test (each row holds 3 points of 3 coordinates, flattened)
X = np.array([
    [1, 2, 3, 4, 5, 6, 7, 8, 9],
    [11, 12, 13, 14, 15, 16, 17, 18, 19],
    [21, 22, 23, 24, 25, 26, 27, 28, 29],
])
# works as expected
distances = spdist.pdist(X, metric=minimum_average_direct_flip)
# outputs: [ 17.32050808  34.64101615  17.32050808]
print(distances)
# raises exception, inputs to minimum_average_direct_flip look wrong
# Traceback (most recent call last):
#   File ".../test_script.py", line 23, in <module>
#     ball_tree.BallTree(X, metric=minimum_average_direct_flip)
#   File "sklearn/neighbors/binary_tree.pxi", line 1059, in sklearn.neighbors.ball_tree.BinaryTree.__init__ (sklearn\neighbors\ball_tree.c:8381)
#   File "sklearn/neighbors/dist_metrics.pyx", line 262, in sklearn.neighbors.dist_metrics.DistanceMetric.get_metric (sklearn\neighbors\dist_metrics.c:4032)
#   File "sklearn/neighbors/dist_metrics.pyx", line 1091, in sklearn.neighbors.dist_metrics.PyFuncDistance.__init__ (sklearn\neighbors\dist_metrics.c:10586)
#   File "C:/Users/danrs/Documents/neuro_atlas/test_script.py", line 8, in minimum_average_direct_flip
#     x = np.reshape(x, (-1, 3))
#   File "C:\Anaconda2\lib\site-packages\numpy\core\fromnumeric.py", line 225, in reshape
#     return reshape(newshape, order=order)
# ValueError: total size of new array must be unchanged
ball_tree.BallTree(X, metric=minimum_average_direct_flip)

On the first call that the BallTree code makes to minimum_average_direct_flip, the inputs are:

x = [ 0.4238394   0.55205233  0.04699435  0.19542642  0.20331665  0.44594837 0.35634537  0.8200018   0.28598294  0.34236847]
y = [ 0.4238394   0.55205233  0.04699435  0.19542642  0.20331665  0.44594837 0.35634537  0.8200018   0.28598294  0.34236847]

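These look completely wrong: each vector has 10 elements, while the rows of X have 9, and the values do not come from X. Am I calling BallTree incorrectly, or is this a bug in sklearn?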

This appears to be a known issue: https://github.com/scikit-learn/scikit-learn/issues/6287
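The inputs printed in the question have 10 elements each, and 10 values cannot be reshaped into rows of 3, so np.reshape fails before the metric ever sees a row of X. Here is a minimal sketch that reproduces the failure without scikit-learn. It assumes, based on the printed x and y and on the linked issue, that the constructor probes the metric once with a random length-10 vector passed as both arguments. The name `probe` is a hypothetical stand-in and is not taken from the sklearn source.

```python
import numpy as np


def minimum_average_direct_flip(x, y):
    # Same metric as in the question: treat each input as a list of 3-D points.
    x = np.reshape(x, (-1, 3))
    y = np.reshape(y, (-1, 3))
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
    return min(direct, flipped)


# Assumption: a stand-in for the vector used to validate the metric.
probe = np.random.random(10)

try:
    minimum_average_direct_flip(probe, probe)
    probe_failed = False
except ValueError as exc:
    # 10 is not a multiple of 3, so the reshape to (-1, 3) is rejected.
    probe_failed = True
    print("metric rejected the probe:", exc)
```

If this assumption holds, the exception in the question comes from the validation call and not from anything in X.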

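The validation step they perform is the problem: judging by the inputs shown above, the metric is called once with a test vector that has nothing to do with X. As a workaround, I think I could add a check on the input size inside my metric, but as the issue points out this is undesirable, because I can then no longer do real validation of my own inputs.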
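A rough sketch of the input-size workaround mentioned above. It assumes that any input whose length is not a multiple of 3 comes from the library's validation call and not from real data. The name `minimum_average_direct_flip_guarded` and the constant `POINT_DIM` are hypothetical, and returning 0.0 for such inputs is a placeholder choice, not something the library requires.

```python
import numpy as np

POINT_DIM = 3  # coordinates per point, matching the reshape in the question


def minimum_average_direct_flip_guarded(x, y):
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Assumption: a length that is not a multiple of 3 means this is the
    # validation probe, so return a dummy float instead of raising.
    if x.size % POINT_DIM != 0 or y.size % POINT_DIM != 0:
        return 0.0
    x = x.reshape(-1, POINT_DIM)
    y = y.reshape(-1, POINT_DIM)
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
    return float(min(direct, flipped))


# The first two rows of X from the question.
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
b = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19])
print(minimum_average_direct_flip_guarded(a, b))
```

The guarded function would then be passed as metric=minimum_average_direct_flip_guarded when constructing the BallTree. The drawback is the one already noted: genuinely malformed input of the wrong length is now silently mapped to 0.0 instead of raising an error.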
