What is causing this strange type error in my random forest code?



I have some very simple code that trains a random forest on data from a csv. The code, minus imports and constants, is below:

def build_estimator(model_dir):
    """Build an estimator."""
    params = tensor_forest.ForestHParams(
        num_classes=2, num_features=5,
        num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
    graph_builder_class = tensor_forest.RandomForestGraphs
    if FLAGS.use_training_loss:
        graph_builder_class = tensor_forest.TrainingLossForest
    # Use the SKCompat wrapper, which gives us a convenient way to split
    # in-memory data like MNIST into batches.
    return estimator.SKCompat(random_forest.TensorForestEstimator(
        params, graph_builder_class=graph_builder_class,
        model_dir=model_dir))

model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
est = build_estimator(model_dir)
COLUMNS = ["a", "b", "c",
           "d", "e", "f"]
postData = pd.read_csv("PostData2Cut.csv", names=COLUMNS, skipinitialspace=True, dtype=np.float32)
est.fit(x=postData[["a", "b", "c",
                    "d", "e"]], y=postData[["f"]],
        batch_size=FLAGS.batch_size)

When execution reaches the est.fit line, though, it crashes with:

TypeError: Input 'input_data' of 'CountExtremelyRandomStats' Op has type float64 that does not match expected type of float32.

Apparently this happens in a TensorFlow file named op_def_library.py, at the following lines:

apply_op
(prefix, dtypes.as_dtype(input_arg.type).name))

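I'm not sure what is causing this. As far as I can tell, I am specifying that the values read from the csv should be float32. This is very frustrating. Any ideas on how to fix it?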

I ran into the same error, and tf.cast() solved the problem.

training_set = tf.cast(training_set, tf.float32)
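The same cast can also be done on the NumPy side before the data reaches est.fit. The following is only a minimal sketch under that assumption: as_float32 and raw_features are hypothetical names, the random array stands in for the csv data, and it has not been verified that this alone removes the error in the question's code.

```python
import numpy as np


def as_float32(data):
    """Return `data` as a NumPy array with dtype float32.

    Hypothetical helper, not part of TensorFlow or pandas.
    """
    return np.asarray(data, dtype=np.float32)


# Stand-in for the five feature columns; NumPy produces float64 by default,
# which is the dtype named in the error message.
raw_features = np.random.rand(8, 5)

features = as_float32(raw_features)
print(raw_features.dtype, "->", features.dtype)
```

With the question's data, the call would presumably look like as_float32(postData[["a", "b", "c", "d", "e"]]), with the result passed as x to est.fit.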

See this answer: Type error when training a TensorFlow random forest using TensorForestEstimator
