我有一个张量
t = tf.io.decode_raw(tf.constant('asdf'), tf.uint8)
print(t.dtype)
<dtype: 'uint8'>
我想使用tf.test.TestCase中的函数来验证这个数据类型。我使用的是下面的函数
self.assertDTypeEqual(t.dtype, tf.uint8)
但它给了我一个错误
TypeError: Cannot interpret 'tf.uint8' as a data type
请告诉我如何使用tf.test.TestCase验证张量数据类型?
尝试了您的案例,我运行了tensorflow 2.3,似乎它将字符串作为输入
import tensorflow as tf
class dtype_testcase(tf.test.TestCase):
def setUp(self):
super(dtype_testcase, self).setUp()
def tearDown(self):
pass
def test_dtype(self):
data = tf.constant([1,2, 3], dtype=tf.uint8)
self.assertDTypeEqual(data, 'uint8')
if __name__ == '__main__':
tf.test.main()
[ OK ] dtype_testcase.test_dtype
[ RUN ] dtype_testcase.test_session
[ SKIPPED ] dtype_testcase.test_session
----------------------------------------------------------------------
Ran 2 tests in 0.394s