scikit-learn管道_transform()采用'x'位置参数,但给出了'y'



问题:

我正在使用scikit-learn的管道设计一个自定义转换器,但位置参数不匹配。我定义的类是:

class DataSubsetGenerator(BaseEstimator, TransformerMixin):
def __init__(self, sub_percentage, random_state = 42):
self.sub_percentage = sub_percentage
self.random_state = random_state
def fit(self):
return self
def transform(self, X_train, X_test, y_train, y_test):
# Do data processing stuff here, removed to simplify example here...
return X_train_sub, X_test_sub, y_train_sub, y_test_sub

然后,我将其放入一个1步自定义管道中进行测试:

reduce_pipeline = Pipeline([
('Prototype dataset', DataSubsetGenerator(0.5, random_state = random_state))
])
X_train, X_test, y_train, y_test = reduce_pipeline.transform(X_train, X_test, y_train, y_test)

我收到错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-42-4b2a20eb8b63> in <module>()
3 ])
4 
----> 5 X_train, X_test, y_train, y_test = reduce_pipeline.transform(X_train, X_test, y_train, y_test)
TypeError: _transform() takes 2 positional arguments but 5 were given

这毫无意义,因为我已经将DataSubGenerator类的transform()函数定义为接受4个参数。

我的测试:

我在不使用sklearn的管道的情况下通过实例化DataSubGenerator和调用transform()进行了测试,它的功能符合设计:

dsg = DataSubsetGenerator(0.5, random_state = random_state)
X_train, X_test, y_train, y_test = dsg.transform(X_train, X_test, y_train, y_test)

我的问题是:为什么transform()函数在sklearn管道中使用时不能识别这4个参数

相关问答;A:

我试过研究,最接近的问答;线程是这样的:_transform((接受2个位置参数,但给出了3个。然而,我无法理解该解决方案以及它如何应用于我的场景。

由于这一行而出现错误。这里的期望是,当管道的最后一步有transform方法时,您将只提供X,这意味着它是从regression Mixin或classifierMixin继承的。

首先,我们需要理解sklearn的估计器遵循(X, y)的API设计。这也是API管道设计的原因。

因此,在将数据输入管道之前,需要进行数据拆分或采样。

我已经通过修改transformer类来解决这个问题,并返回单个列表(包含多个数据帧(:

class DataSubsetGenerator(BaseEstimator, TransformerMixin):
def __init__(self, sub_percentage, random_state = 42):
self.sub_percentage = sub_percentage
self.random_state = random_state
def fit(self):
return self
def transform(self, dataframes):
X_train, X_test, y_train, y_test = dataframes
# Do data processing stuff here, removed to simplify example here...
return [X_train_sub, X_test_sub, y_train_sub, y_test_sub]

如果有更好的解决方案或公认的模式,请随时告诉我。

最新更新