我可以从动态大小的TensorArray中删除特定的元素吗?



我正在尝试使用TensorFlow创建一个自定义模型。我使用tf.TensorArray()函数与dynamic_size=True

当我想删除一个基于if条件的特定元素时,我遇到了问题。

例如,

t_a = tf.TensorArray(dtype=tf.float32,size=0,dynamic_size=True)
t_a.write(0,[0,23])
t_a.write(1,[1,67])
t_a.write(2,[3,0])
t_a.write(3,[4,9])

输出:

t_a.read(2)
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[3., 0.]], dtype=float32)>

现在,我的应用程序将要求我读取每行第一个元素的值,如果为真,我需要从TensorArray中删除该元素。

访问单个元素,我使用t_a.gather([indices])

任何建议或解决这个问题是非常感谢的。

可以将"clear_after_read"设置为"True"。读取元素后,清除该值。

工作代码示例

import tensorflow as tf
ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=True)
ta = ta.write(0, 10)
ta = ta.write(1, 20)
ta = ta.write(2, 30)
indices = 1
ta.read(indices)
ta.gather([indices])

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>

相关内容

  • 没有找到相关文章

最新更新