我想实例化一个sklearn估计器,因为它的名字。例如:
name = 'RandomForestClassifier'
clf = BaseEstimator(name)
print(clf) # return something like RandomForestClassifier(...)
我尝试过的是clf = eval(name)
,但我不能做clf.fit(X, y)
,因为clf是<class 'sklearn.ensemble._forest.RandomForestClassifier'>
,而不是类似RandomForestClassifier()
的东西。
由于它的名字,我找不到如何创建sklearn估计器。
据我所知,您想要实现的是根据字符串中给出的一些用户输入动态实例化现有分类器。BaseEstimator
是所有估计器的基类,所以我认为只有在您计划构建自己的估计器时才应该使用它。
现在我不知道你是如何运行上述代码的,因为BaseEstimator()
不接受任何参数:
clf = BaseEstimator('RandomForestClassifier')
Traceback (most recent call last):
File "/usr/local/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-6-59dd1a62618a>", line 1, in <module>
clf = BaseEstimator('RandomForestClassifier')
TypeError: BaseEstimator() takes no arguments
在任何情况下,您都可以更接近分类器的动态实例化,如下所示:
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
def get_clf(name):
if name not in ('RandomForestClassifier', 'DecisionTreeClassifier'):
raise ValueError(f"{name} is not a recognised option")
classifiers = {
"DecisionTreeClassifier": DecisionTreeClassifier,
"RandomForestClassifier": RandomForestClassifier
}
classifier = classifiers[name]
return classifier()
或者,您可以选择更直接的解决方案:
def get_clf(name):
if name not in ('RandomForestClassifier', 'DecisionTreeClassifier'):
raise ValueError(f"{name} is not a recognised option")
if name == 'RandomForestClassifier':
return RandomForestClassifier()
elif name == 'DecisionTreeClassifier':
return DecisionTreeClassifier()
else:
raise RuntimeError()
您可以使用importlib
模块:
import importlib
name = "tree.DecisionTreeClassifier"
module_name, model_name = name.split(".")
clf = getattr(importlib.import_module(f"sklearn.{module_name}"), model_name)