如何使用tf.assertTypeEqual测试张量的dtype



我有一个张量

t = tf.io.decode_raw(tf.constant('asdf'), tf.uint8)
print(t.dtype)
<dtype: 'uint8'>

我想使用tf.test.TestCase中的函数来验证这个数据类型。我使用的是下面的函数

self.assertDTypeEqual(t.dtype, tf.uint8)

但它给了我一个错误

TypeError: Cannot interpret 'tf.uint8' as a data type

请告诉我如何使用tf.test.TestCase验证张量数据类型?

尝试了您的案例,我运行了tensorflow 2.3,似乎它将字符串作为输入

import tensorflow as tf
class dtype_testcase(tf.test.TestCase):
def setUp(self):
super(dtype_testcase, self).setUp()

def tearDown(self):
pass

def test_dtype(self):
data = tf.constant([1,2, 3], dtype=tf.uint8)
self.assertDTypeEqual(data, 'uint8')


if __name__ == '__main__':
tf.test.main()
[       OK ] dtype_testcase.test_dtype
[ RUN      ] dtype_testcase.test_session
[  SKIPPED ] dtype_testcase.test_session
----------------------------------------------------------------------
Ran 2 tests in 0.394s

相关内容

最新更新