Nested cross-validation with spark_sklearn GridSearchCV causes a SPARK-5063 error

Performing nested cross-validation with spark_sklearn GridSearchCV as the inner CV and sklearn cross_validate/cross_val_score as the outer CV raises the error "It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation".

from pyspark import SparkContext
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from spark_sklearn import GridSearchCV

sparkcontext = SparkContext.getOrCreate()

inner_cv = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
outer_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
scoring_metric = ['roc_auc', 'average_precision', 'precision']
# spark_sklearn's GridSearchCV takes the SparkContext as its first argument
gs = GridSearchCV(sparkcontext,
                  estimator=RandomForestClassifier(
                      class_weight='balanced_subsample', n_jobs=-1),
                  param_grid=[{"max_depth": [5], "max_features": [.5, .8],
                               "min_samples_split": [2],
                               "min_samples_leaf": [1, 2, 5, 10],
                               "bootstrap": [True, False],
                               "criterion": ["gini", "entropy"],
                               "n_estimators": [300]}],
                  scoring=scoring_metric, cv=inner_cv, verbose=1, n_jobs=-1,
                  refit='roc_auc', return_train_score=False)
# X, y are the feature matrix and label vector prepared earlier
scores = cross_validate(gs, X, y, cv=outer_cv, scoring=scoring_metric,
                        n_jobs=-1, return_train_score=False)

I have already tried changing n_jobs=-1 to n_jobs=1 to remove the joblib-based parallelism and then retried, but it still raises the same exception.
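For reference, this is the serial retry described above; only the n_jobs value of the outer cross_validate call changes (the inner GridSearchCV can be changed the same way):

# Serial retry: disabling joblib parallelism in the outer CV does not help,
# because cross_validate() still clones (deep-copies) the estimator per split
scores = cross_validate(gs, X, y, cv=outer_cv, scoring=scoring_metric,
                        n_jobs=1, return_train_score=False)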

Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

Complete Traceback (most recent call last):
  File "model_evaluation.py", line 350, in <module>
    main()
  File "model_evaluation.py", line 269, in main
    scores = cross_validate(gs, X, y, cv=outer_cv, scoring=scoring_metric, n_jobs=-1, return_train_score=False)
  File "../python27/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 195, in cross_validate
    for train, test in cv.split(X, y, groups))
  File "../python27/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "../python27/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 620, in dispatch_one_batch
    tasks = BatchedCalls(itertools.islice(iterator, batch_size))
  File "../python27/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 127, in __init__
    self.items = list(iterator_slice)
  File "../python27/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 195, in <genexpr>
    for train, test in cv.split(X, y, groups))
  File "../python27/lib/python2.7/site-packages/sklearn/base.py", line 61, in clone
    new_object_params[name] = clone(param, safe=False)
  File "../python27/lib/python2.7/site-packages/sklearn/base.py", line 52, in clone
    return copy.deepcopy(estimator)
  File "/usr/local/lib/python2.7/copy.py", line 182, in deepcopy
    rv = reductor(2)
  File "/usr/local/lib/spark/python/pyspark/context.py", line 279, in __getnewargs__
    "It appears that you are attempting to reference SparkContext from a broadcast "
Exception: It appears that you are attempting to reference SparkContext from a broadcast
variable, action, or transformation. SparkContext can only be used on the driver, not
in code that it run on workers. For more information, see SPARK-5063.

EDIT: The problem seems to be that sklearn's cross_validate() clones the estimator for each fit in a way similar to pickling the estimator object, which is not allowed for the PySpark GridSearchCV estimator because the SparkContext() object cannot/should not be pickled. So how do we clone the estimator correctly?
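The failure can in fact be reproduced without the outer CV at all. A minimal sketch, assuming the gs estimator defined above; sklearn's clone() deep-copies every constructor parameter, including the SparkContext:

from sklearn.base import clone

# clone() rebuilds the estimator from its parameters via copy.deepcopy();
# deep-copying the SparkContext parameter raises the SPARK-5063 exception
gs_copy = clone(gs)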

I finally found a solution. The problem occurs when the scikit-learn clone() function tries to deep-copy the SparkContext object. The workaround I used is a bit crude, and if there were a better solution I would certainly take that route instead, but it works: import the copy module and override its deepcopy() function so that it ignores SparkContext objects when it encounters them.

# Mock the deep-copy function so it skips copying SparkContext objects;
# this avoids the pickling / broadcast-variable errors
import copy
from pyspark import SparkContext

_deepcopy = copy.deepcopy

def mock_deepcopy(*args, **kwargs):
    # Return SparkContext instances unchanged instead of deep-copying them
    if isinstance(args[0], SparkContext):
        return args[0]
    return _deepcopy(*args, **kwargs)

copy.deepcopy = mock_deepcopy

So now it no longer attempts to copy the SparkContext object, and everything appears to work correctly.
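Put together, a sketch of how the pieces fit. One assumption here: the outer loop is kept in-process with n_jobs=1, since the patch only intercepts deepcopy inside the driver process and would not help with pickling the estimator to separate joblib worker processes:

# Install the patch before the outer CV so that sklearn's clone() picks it up
copy.deepcopy = mock_deepcopy

# Keep the outer loop serial; the patch does not affect pickling to
# separate joblib worker processes
scores = cross_validate(gs, X, y, cv=outer_cv, scoring=scoring_metric,
                        n_jobs=1, return_train_score=False)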
