tf.keras.utils.to_categorical 在模型训练(fit)过程中报 TypeError



tf.keras.utils.to_categorical 似乎无法与模型训练(fit)配合使用。如果我在模型内部使用它,在训练之前可以直接调用模型来预测一些值(并且能验证输出的形状是正确的),但调用 fit 时会报错,提示错误出在用户代码中(in user code)。我在 TensorFlow 的 stable 版和 nightly 版中都试过了。

我已经将我的模型简化为:

# Minimal example
class TestModel1(tf.keras.Model):
    """Minimal model that one-hot encodes its input with ``to_categorical``.

    Calling the model directly (eagerly) works and returns a NumPy array, but
    ``model.fit`` fails with ``TypeError`` when the call is traced.
    """

    def call(self, x):
        # ``to_categorical`` is a NumPy utility: it runs ``np.array(x, dtype='int')``.
        # When called eagerly, ``x`` holds concrete values, so this works.
        # Inside ``fit`` the call is traced into a graph function, ``x`` is a
        # symbolic tensor, and that NumPy conversion raises
        # ``TypeError: __array__() takes 1 positional argument but 2 were given``.
        # No ``num_classes`` is given, so the number of output columns is
        # inferred from the values in ``x`` (for x = [1, 2, 3, 4] it is 5).
        return tf.keras.utils.to_categorical(x)
# Fixed width and a trainable layer
class TestModel2(tf.keras.Model):
    """Fixed-width (5) one-hot encoding via ``to_categorical``, then a Dense layer.

    Like the minimal example, calling it directly works, but ``model.fit``
    fails with the same ``TypeError``.
    """

    def __init__(self):
        super(TestModel2, self).__init__()
        # Trainable layer: 2 output units with ReLU activation.
        self.d1 = tf.keras.layers.Dense(2, activation='relu')

    def call(self, x):
        # ``num_classes=5`` pins the encoding width. However, ``to_categorical``
        # still converts ``x`` with NumPy (``np.array``). That fails on the
        # symbolic tensor that ``fit`` passes in while tracing the training step.
        x = tf.keras.utils.to_categorical(x, num_classes=5)
        return self.d1(x)

两个模型在直接调用时都运行正常,但在调用 fit 时都会遇到相同的错误。

>>> x = [1, 2, 3, 4]
>>> y = [0, 1, 0, 1]
>>> TestModel1()(x)
array([[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]], dtype=float32)
>>> TestModel2()(x)
<tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[0.1654321 , 0.        ],
[0.        , 0.07433152],
[0.87672186, 0.        ],
[0.25229335, 0.        ]], dtype=float32)>
>>> loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
>>> model = TestModel1()
>>> model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
>>> model.fit(x, y, epochs=1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:
/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
return step_function(self, iterator)
/Users/user/golorry/tf/my_model.py:30 call  *
return tf.keras.utils.to_categorical(x)
/Users/user/.virtualenvs/tf/lib/python3.8/site-packages/tensorflow/python/keras/utils/np_utils.py:69 to_categorical  **
y = np.array(y, dtype='int')
TypeError: __array__() takes 1 positional argument but 2 were given

这是有意为之的行为,还是一个 bug?

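to_categorical 是基于 NumPy 的工具函数:回溯中可以看到它调用了 np.array(y, dtype='int')。因此它无法处理 fit 在图模式下传入的符号张量。我建议在模型内部改用 TensorFlow 原生运算 tf.one_hot: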

class TestModel1(tf.keras.Model):
    """One-hot encode integer class indices with the graph-compatible ``tf.one_hot``.

    Unlike ``tf.keras.utils.to_categorical`` (a NumPy utility), ``tf.one_hot``
    is a TensorFlow op. It therefore works both when the model is called
    eagerly and when ``fit`` traces the call into a graph.

    Args:
        num_classes: Width of the one-hot encoding. Defaults to 5, the value
            that was previously hard-coded.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_classes=5, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def call(self, x):
        # tf.one_hot only accepts integer indices. The cast lets float inputs
        # (e.g. a float NumPy array of class ids) work too. Indices outside
        # [0, num_classes) produce an all-zero row rather than an error.
        return tf.one_hot(tf.cast(x, tf.int32), depth=self.num_classes)
class TestModel2(tf.keras.Model):
    """One-hot encode class indices with ``tf.one_hot``, then apply a Dense layer.

    Args:
        num_classes: Width of the one-hot encoding. Defaults to 5, the value
            that was previously hard-coded.
        units: Number of output units of the Dense layer. Defaults to 2.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_classes=5, units=2, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # NOTE(review): if this output feeds a loss with from_logits=True, a
        # linear final layer (activation=None) is usual. ReLU clips logits at 0.
        self.d1 = tf.keras.layers.Dense(units, activation='relu')

    def call(self, x):
        # tf.one_hot requires integer indices. Cast so float class ids also work.
        x = tf.one_hot(tf.cast(x, tf.int32), depth=self.num_classes)
        return self.d1(x)

这是一个可运行的 Colab 笔记本:https://colab.research.google.com/drive/1udRnIdYGO0iBE3PfLBCJ3zT61a4qGWd2?usp=sharing

最新更新