如何在管道(Pipeline)中正确使用 FunctionTransformer?



我正在尝试用通用句子编码器(Universal Sentence Encoder)生成的句子嵌入来训练支持向量机(SVM)。我在管道(Pipeline)内部使用 FunctionTransformer 来拟合模型,但得到了以下错误:

TypeError: can't pickle _thread.RLock objects

%tensorflow_version 1.x
import tensorflow as tf
import tensorflow_hub as hub
import pandas as pd
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import FunctionTransformer
# Minimal reproduction: a two-row toy dataset of speaker names and their lines.
tos = pd.DataFrame({
"Character" : ["KIRK", "SPOCK"],
"Lines" : ["Shall we pick some flowers, Doctor?","Check the circuit."]
})
# Features are the single text column "Lines"; the target is the speaker name.
X = pd.DataFrame(tos["Lines"], columns = ["Lines"])
Y = tos["Character"]
# NOTE(review): train_test_split (and SVC below) are used but not imported in
# this snippet -- presumably sklearn.model_selection.train_test_split and
# sklearn.svm.SVC were imported in an earlier notebook cell; confirm.
x_train, x_test, y_train, y_test = train_test_split(X,Y)
# Load the sentence-embedding module from a local path (TF 1.x hub.Module API).
embed = hub.Module("/content/module/")
pipe = make_pipeline(
make_column_transformer(
# The hub module object itself is passed as FunctionTransformer's func and
# is applied to the "Lines" column.
(FunctionTransformer(embed), "Lines")
),
SVC()
)
# This call raises "TypeError: can't pickle _thread.RLock objects". Per the
# traceback, ColumnTransformer.fit_transform clones its transformers, and
# sklearn.base.clone falls back to copy.deepcopy on a constructor parameter
# (presumably func=embed), which cannot be deep-copied.
pipe.fit(x_train,y_train);

我注意到FunctionTransformer的文档提到

如果使用 lambda 作为函数,则生成的转换器将无法被 pickle(序列化)。

但这似乎不是问题,因为我没有将此函数定义为lambda。

完整回溯

---------------------------------------------------------------------------
Empty                                     Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/joblib/parallel.py in dispatch_one_batch(self, iterator)
821             try:
--> 822                 tasks = self._ready_batches.get(block=False)
823             except queue.Empty:
21 frames
/usr/lib/python3.7/queue.py in get(self, block, timeout)
166                 if not self._qsize():
--> 167                     raise Empty
168             elif timeout is None:
Empty: 

During handling of the above exception, another exception occurred:
TypeError                                 Traceback (most recent call last)
<ipython-input-69-a981c354b190> in <module>()
----> 1 pipe.fit(x_train,y_train)
/usr/local/lib/python3.7/dist-packages/sklearn/pipeline.py in fit(self, X, y, **fit_params)
388         """
389         fit_params_steps = self._check_fit_params(**fit_params)
--> 390         Xt = self._fit(X, y, **fit_params_steps)
391         with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
392             if self._final_estimator != "passthrough":
/usr/local/lib/python3.7/dist-packages/sklearn/pipeline.py in _fit(self, X, y, **fit_params_steps)
353                 message_clsname="Pipeline",
354                 message=self._log_message(step_idx),
--> 355                 **fit_params_steps[name],
356             )
357             # Replace the transformer of the step with the fitted
/usr/local/lib/python3.7/dist-packages/joblib/memory.py in __call__(self, *args, **kwargs)
347 
348     def __call__(self, *args, **kwargs):
--> 349         return self.func(*args, **kwargs)
350 
351     def call_and_shelve(self, *args, **kwargs):
/usr/local/lib/python3.7/dist-packages/sklearn/pipeline.py in _fit_transform_one(transformer, X, y, weight, message_clsname, message, **fit_params)
891     with _print_elapsed_time(message_clsname, message):
892         if hasattr(transformer, "fit_transform"):
--> 893             res = transformer.fit_transform(X, y, **fit_params)
894         else:
895             res = transformer.fit(X, y, **fit_params).transform(X)
/usr/local/lib/python3.7/dist-packages/sklearn/compose/_column_transformer.py in fit_transform(self, X, y)
673         self._validate_remainder(X)
674 
--> 675         result = self._fit_transform(X, y, _fit_transform_one)
676 
677         if not result:
/usr/local/lib/python3.7/dist-packages/sklearn/compose/_column_transformer.py in _fit_transform(self, X, y, func, fitted, column_as_strings)
613                     message=self._log_message(name, idx, len(transformers)),
614                 )
--> 615                 for idx, (name, trans, column, weight) in enumerate(transformers, 1)
616             )
617         except ValueError as e:
/usr/local/lib/python3.7/dist-packages/joblib/parallel.py in __call__(self, iterable)
1041             # remaining jobs.
1042             self._iterating = False
-> 1043             if self.dispatch_one_batch(iterator):
1044                 self._iterating = self._original_iterator is not None
1045 
/usr/local/lib/python3.7/dist-packages/joblib/parallel.py in dispatch_one_batch(self, iterator)
831                 big_batch_size = batch_size * n_jobs
832 
--> 833                 islice = list(itertools.islice(iterator, big_batch_size))
834                 if len(islice) == 0:
835                     return False
/usr/local/lib/python3.7/dist-packages/sklearn/compose/_column_transformer.py in <genexpr>(.0)
613                     message=self._log_message(name, idx, len(transformers)),
614                 )
--> 615                 for idx, (name, trans, column, weight) in enumerate(transformers, 1)
616             )
617         except ValueError as e:
/usr/local/lib/python3.7/dist-packages/sklearn/base.py in clone(estimator, safe)
84     new_object_params = estimator.get_params(deep=False)
85     for name, param in new_object_params.items():
---> 86         new_object_params[name] = clone(param, safe=False)
87     new_object = klass(**new_object_params)
88     params_set = new_object.get_params(deep=False)
/usr/local/lib/python3.7/dist-packages/sklearn/base.py in clone(estimator, safe)
65     elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
66         if not safe:
---> 67             return copy.deepcopy(estimator)
68         else:
69             if isinstance(estimator, type):
/usr/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
178                     y = x
179                 else:
--> 180                     y = _reconstruct(x, memo, *rv)
181 
182     # If is its own copy, don't memoize.
/usr/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
279     if state is not None:
280         if deep:
--> 281             state = deepcopy(state, memo)
282         if hasattr(y, '__setstate__'):
283             y.__setstate__(state)
/usr/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
148     copier = _deepcopy_dispatch.get(cls)
149     if copier:
--> 150         y = copier(x, memo)
151     else:
152         try:
/usr/lib/python3.7/copy.py in _deepcopy_dict(x, memo, deepcopy)
239     memo[id(x)] = y
240     for key, value in x.items():
--> 241         y[deepcopy(key, memo)] = deepcopy(value, memo)
242     return y
243 d[dict] = _deepcopy_dict
/usr/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
178                     y = x
179                 else:
--> 180                     y = _reconstruct(x, memo, *rv)
181 
182     # If is its own copy, don't memoize.
/usr/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
279     if state is not None:
280         if deep:
--> 281             state = deepcopy(state, memo)
282         if hasattr(y, '__setstate__'):
283             y.__setstate__(state)
/usr/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
148     copier = _deepcopy_dispatch.get(cls)
149     if copier:
--> 150         y = copier(x, memo)
151     else:
152         try:
/usr/lib/python3.7/copy.py in _deepcopy_dict(x, memo, deepcopy)
239     memo[id(x)] = y
240     for key, value in x.items():
--> 241         y[deepcopy(key, memo)] = deepcopy(value, memo)
242     return y
243 d[dict] = _deepcopy_dict
/usr/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
167                     reductor = getattr(x, "__reduce_ex__", None)
168                     if reductor:
--> 169                         rv = reductor(4)
170                     else:
171                         reductor = getattr(x, "__reduce__", None)
TypeError: can't pickle _thread.RLock objects

您应该直接在管道中传递FunctionTransformer的实例,而不是将其包装在ColumnTransformer中。

我没有实际运行这段代码,因为我的机器上没有安装 tensorflow_hub。所以,如果它对你不起作用,请见谅。

# Suggested workaround: place the FunctionTransformer directly in the pipeline
# instead of wrapping it in a ColumnTransformer.
# NOTE(review): the answer's author states this was not run (tensorflow_hub
# was not installed), so treat it as an untested sketch.
pipe = make_pipeline(        
# kw_arg_nm / kw_arg_value look like placeholders for whatever keyword
# arguments should be forwarded to embed via kw_args -- they are not defined
# anywhere in the visible code.
FunctionTransformer(embed, kw_args={"kw_arg_nm":kw_arg_value}), 
# TODO: find a way to select the "Lines" column (e.g. through kw_args);
# presumably embed would otherwise receive the whole x_train DataFrame -- verify.
SVC()
)
pipe.fit(x_train,y_train);

相关内容

  • 没有找到相关文章

最新更新