Python scikit-learn:不能克隆对象.因为构造函数似乎没有设置参数



我修改了scikit-learn的BernoulliRBM类,使用softmax可见单元组。在此过程中,我添加了一个额外的Numpy数组visible_config作为类属性,它在构造函数中初始化,使用如下:

self.visible_config = np.cumsum(np.concatenate((np.asarray([0]),
                                visible_config), axis=0))

其中visible_config是作为输入传递给构造函数的Numpy数组。当我直接使用fit()函数来训练模型时,代码运行没有错误。但是,当我使用GridSearchCV结构时,我得到以下错误

Cannot clone object SoftmaxRBM(batch_size=100, learning_rate=0.01, n_components=100, n_iter=100,
  random_state=0, verbose=True, visible_config=[ 0 21 42 63]), as the constructor does not seem to set parameter visible_config

这似乎是在类的实例和sklearn.base.clone创建的副本之间的相等性检查中的问题,因为visible_config没有得到正确的复制。我不知道怎么解决这个问题。它在文档中说,sklearn.base.clone使用deepcopy(),所以visible_config不应该被复制吗?有人能解释一下我能在这里做什么吗?谢谢!

如果没有看到您的代码,很难确切地知道哪里出了问题,但是您在这里违反了scikit-learn API约定。估算器中的构造函数应该将属性设置为用户作为参数传递的值。所有的计算都应该在fit中进行,如果fit需要存储计算结果,它应该在带有下划线的属性(_)中进行计算。这个约定是clone和元估计器(如GridSearchCV)工作的原因。

(*)如果您在主代码库中看到一个违反此规则的估计器:那将是一个错误,并且欢迎补丁。

最新更新