Python MeanShift内存错误



我正在sklearn.cluster模块中运行一个名为MeanShift()的集群算法(这里是文档)。我正在处理的对象有310057个点分布在三维空间中。我运行它的计算机总共有128Gb的内存,所以当我遇到以下错误时,我很难相信我真的在使用所有的内存

[user@host ~]$ python meanshifttest.py
Traceback (most recent call last):
  File "meanshifttest.py", line 13, in <module>
    ms = MeanShift().fit(X)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 280, in fit
    cluster_all=self.cluster_all)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 99, in mean_shift
bandwidth = estimate_bandwidth(X)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 45, in estimate_bandwidth
d, _ = nbrs.kneighbors(X, return_distance=True)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/neighbors/base.py", line 313, in kneighbors
return_distance=return_distance)
  File "binary_tree.pxi", line 1313, in sklearn.neighbors.kd_tree.BinaryTree.query (sklearn/neighbors/kd_tree.c:10007)
  File "binary_tree.pxi", line 595, in sklearn.neighbors.kd_tree.NeighborsHeap.__init__ (sklearn/neighbors/kd_tree.c:4709)
MemoryError

我运行的代码如下:

from sklearn.cluster import MeanShift
import asciitable
import numpy as np
import time
data = asciitable.read('./multidark_MDR1_FOFID85000000000_ParticlePos.csv',delimiter=',')
x = [data[i][2] for i in range(len(data))]
y = [data[i][3] for i in range(len(data))]
z = [data[i][4] for i in range(len(data))]
X = np.array(zip(x,y,z))
t0 = time.time()
ms = MeanShift().fit(X)
t1 = time.time()
print str(t1-t0) + " seconds."
labels = ms.labels_
print set(labels)

有人知道发生了什么吗?不幸的是,我无法切换聚类算法,因为这是我发现的唯一一个除了不接受链接长度/k个聚类数/先验信息之外还做得很好的算法。

提前感谢!

**更新:我又看了一下文件,上面写着:

可扩展性:

因为此实现使用了一个平面内核和
一个球树来查找每个内核的成员,其复杂性将为
到O(T*n*log(n)),其中n为样本数
T是点数。在更高维度中,复杂性将
趋向于O(T*n^2)。

可以通过使用更少的种子来提高可扩展性,例如使用
在get_bin_seeds函数中的min_bin_freq的较高值。

请注意,estimate_bandwidth函数的可伸缩性远不如
均值偏移算法,如果使用它将成为瓶颈。

这似乎有一定的道理,因为如果你仔细观察这个错误,它会抱怨estimate_bandwidth。这是否表明我只是在算法中使用了太多粒子?

从错误消息判断,我怀疑它正在尝试计算点之间的所有成对距离,这意味着它需要310057²的浮点数或716GB的RAM。

可以通过向MeanShift构造函数提供显式bandwidth参数来禁用此行为。

这可以说是一个bug;考虑为其提交一份错误报告。(包括我在内的scikit学习团队最近一直在努力消除各种地方过于昂贵的距离计算,但显然没有人关注meanshift。)

EDIT:上面的计算结果相差了3倍,但内存使用率确实是二次型的。我刚刚在scikit-learn的开发版本中修复了这个问题。

相关内容

  • 没有找到相关文章

最新更新