我打算对一个tflearn模型的超参数执行网格搜索。似乎tflearn.DNN
产生的模型与sklearn的GridSearchCV期望不兼容:
from sklearn.grid_search import GridSearchCV
import tflearn
import tflearn.datasets.mnist as mnist
import numpy as np
X, Y, testX, testY = mnist.load_data(one_hot=True)
encoder = tflearn.input_data(shape=[None, 784])
encoder = tflearn.fully_connected(encoder, 256)
encoder = tflearn.fully_connected(encoder, 64)
# Building the decoder
decoder = tflearn.fully_connected(encoder, 256)
decoder = tflearn.fully_connected(decoder, 784)
# Regression, with mean square error
net = tflearn.regression(decoder, optimizer='adam', learning_rate=0.01,
loss='mean_square', metric=None)
model = tflearn.DNN(net, tensorboard_verbose=0)
grid_hyperparams = {'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)}
grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2)
grid.fit(X, X)
我得到错误:
TypeError Traceback (most recent call last)
<ipython-input-3-fd63245cd0a3> in <module>()
22 grid_hyperparams = {'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)}
23 grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2)
---> 24 grid.fit(X, X)
25
26
/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in fit(self, X, y)
802
803 """
--> 804 return self._fit(X, y, ParameterGrid(self.param_grid))
805
806
/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in _fit(self, X, y, parameter_iterable)
539 n_candidates * len(cv)))
540
--> 541 base_estimator = clone(self.estimator)
542
543 pre_dispatch = self.pre_dispatch
/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/base.py in clone(estimator, safe)
45 "it does not seem to be a scikit-learn estimator "
46 "as it does not implement a 'get_params' methods."
---> 47 % (repr(estimator), type(estimator)))
48 klass = estimator.__class__
49 new_object_params = estimator.get_params(deep=False)
TypeError: Cannot clone object '<tflearn.models.dnn.DNN object at 0x7fead09948d0>' (type <class 'tflearn.models.dnn.DNN'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods.
任何想法我怎么能得到一个对象适合GridSearchCV?
我没有使用tflearn的经验,但是我有一些Python和sklearn的基本背景。从你的StackOverflow截图中的错误判断,tflearn模型没有与scikit-learn估计器相同的方法或属性。这是可以理解的,因为它们不是科学学习评估器。
Sklearn的网格搜索CV仅适用于与scikit-learn估计器具有相同方法和属性的对象(例如具有fit()和predict()方法)。如果你打算使用sklearn的网格搜索,你必须在tflearn模型周围编写自己的包装器,使其作为sklearn估计器的替代品,这意味着你必须编写自己的类,该类具有与任何其他scikit-learn估计器相同的方法,但使用tflearn库来实际实现这些方法。
要做到这一点,请理解基本scikit-learn估计器(最好是您熟悉的那个)的代码,并查看fit()、predict()、get_params()等方法对对象及其内部实际做了什么。然后使用tflearn库编写自己的类。一开始,一个快速的谷歌搜索显示这个存储库是"一个薄薄的scikit-learn风格的tensorflow框架包装器":DSLituiev/tflearn (https://github.com/DSLituiev/tflearn)。我不知道这是否会作为网格搜索的替代品,但值得一看。