我正在尝试使用TensorFlow创建一个自定义模型。我使用tf.TensorArray()
函数与dynamic_size=True
。
当我想删除一个基于if条件的特定元素时,我遇到了问题。
例如,
t_a = tf.TensorArray(dtype=tf.float32,size=0,dynamic_size=True)
t_a.write(0,[0,23])
t_a.write(1,[1,67])
t_a.write(2,[3,0])
t_a.write(3,[4,9])
输出:
t_a.read(2)
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[3., 0.]], dtype=float32)>
现在,我的应用程序将要求我读取每行第一个元素的值,如果为真,我需要从TensorArray中删除该元素。
访问单个元素,我使用t_a.gather([indices])
。
任何建议或解决这个问题是非常感谢的。
可以将"clear_after_read
"设置为"True
"。读取元素后,清除该值。
工作代码示例
import tensorflow as tf
ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=True)
ta = ta.write(0, 10)
ta = ta.write(1, 20)
ta = ta.write(2, 30)
indices = 1
ta.read(indices)
ta.gather([indices])
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>