我试图使用 TensorFlow 的 `tf.map_fn` 函数,对一组函数参数列表逐一调用某个函数(每一行作为函数的参数)。下面是我的代码:
# NOTE(review): this snippet uses `tf` but does not import it; presumably
# `import tensorflow as tf` appears elsewhere in the asker's script.
def my_func(a,v,c,d):
    """Print the four values of one row, passed as separate arguments.

    NOTE(review): there is no return statement, so this returns None;
    tf.map_fn expects fn to return a value it can stack into its output.
    """
    print(a,v,c,d)
if __name__ == '__main__':
    # Restrict TensorFlow to the first GPU. NOTE(review): if no GPU is present,
    # list_physical_devices returns an empty list and the [0] index raises IndexError.
    tf.config.set_visible_devices(tf.config.list_physical_devices('GPU')[0],'GPU')
    # Two rows of four strings; each row is meant to be unpacked into my_func's
    # four parameters.
    iterable = [['a','b','c','s'],['s','e','f','c']]
    tensor = tf.convert_to_tensor(iterable)
    # from_tensor_slices yields one dataset element per row of `tensor`.
    dataset = tf.data.Dataset.from_tensor_slices(tensor)
    # This call fails with "ValueError: Slicing dataset elements is not supported
    # for rank 0." (see the traceback below). tf.map_fn is handed a tf.data.Dataset
    # instead of a tensor, and while inferring its result signature it tries to
    # unbatch the Dataset's own type spec (dataset_ops._unbatch in the traceback).
    tf.map_fn(lambda x: my_func(*x),dataset)
但是我得到了下面这个错误,实在看不懂它是什么意思:
Traceback (most recent call last):
File "/Volumes/WorkSSD/Notebooks/01 Convert Raw EDF to Raw CSV copy.py", line 134, in <module>
tf.map_fn(lambda x: my_func(*x),dataset)
File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 629, in new_func
return func(*args, **kwargs)
File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 561, in new_func
return func(*args, **kwargs)
File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 640, in map_fn_v2
return map_fn(
File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 561, in new_func
return func(*args, **kwargs)
File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 392, in map_fn
result_flat_signature = [
File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 393, in <listcomp>
_most_general_compatible_type(s)._unbatch() # pylint: disable=protected-access
File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4587, in _unbatch
raise ValueError("Slicing dataset elements is not supported for rank 0.")
ValueError: Slicing dataset elements is not supported for rank 0.
我做错了什么?该如何解决?
报错的原因是 `tf.map_fn` 的 `elems` 参数需要张量(或张量的嵌套结构),而您传入的是一个 `tf.data.Dataset`。如果您的目标只是打印每个元素,可以直接遍历数据集:
import tensorflow as tf
if __name__ == '__main__':
    # Two rows of four strings; from_tensor_slices yields one row per element.
    iterable = [['a','b','c','s'],['s','e','f','c']]
    tensor = tf.convert_to_tensor(iterable)
    dataset = tf.data.Dataset.from_tensor_slices(tensor)
    # Iterate the dataset directly in Python instead of passing it to
    # tf.map_fn, whose `elems` argument must be a tensor, not a Dataset.
    for element in dataset:
        print(element)
下面这种写法也可以:把整个矩阵作为数据集中的单个元素取出,再对这个张量调用 `tf.map_fn`(被映射的函数使用 `tf.print` 打印,并返回一个值):
import tensorflow as tf


def my_func(a):
    """Print one row tensor with tf.print and return it unchanged.

    tf.map_fn stacks the values returned by its function into the result,
    so the row is returned rather than None.
    """
    tf.print(a, [a])
    return a


if __name__ == '__main__':
    # The extra outer list makes the whole 2x4 matrix a single dataset element.
    iterable = [[['a','b','c','s'],['s','e','f','c']]]
    tensor = tf.convert_to_tensor(iterable)
    dataset = tf.data.Dataset.from_tensor_slices(tensor)
    iterator = iter(dataset)
    # get_next() returns that single element as a 2x4 tensor; tf.map_fn then
    # calls my_func once per row of it.
    tf.map_fn(my_func, iterator.get_next())