我想以某种方式在tf.while_loop
中维护一个常量列表,它可以支持以下函数
- 我能够读取和写入(多次(索引处的常数值
- 我可以在它上运行
tf.cond
,方法是在一个索引与某个常量处检查它的值
TensorArray
在此不起作用,因为它不支持重写。我还有什么其他选择?
您可以定义一个普通的Tensor
,并用tf.tensor_scatter_nd_update
更新它,如下所示:
%tensorflow_version 1.x
import tensorflow as tf
data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32)
data_tensor = tf.zeros_like(data)
tensor_size = data_tensor.shape[0]
init_state = (0, data_tensor)
condition = lambda i, _: i < tensor_size
def custom_body(i, tensor):
special_index = 3 # index for which a value should be changed
new_value = 8
tensor = tf.where(tf.equal(i, special_index),
tf.tensor_scatter_nd_update(tensor, [[special_index]], [new_value]),
tf.tensor_scatter_nd_update(tensor, [[i]], [data[i]*2]))
return i + 1, tensor
body = lambda i, tensor: (custom_body(i, tensor))
_, final_result = tf.while_loop(condition, body, init_state)
with tf.Session() as sess:
final_result_values = final_result.eval()
print(final_result_values)
[2. 2. 2. 8. 2. 0. 2. 2. 0. 0.]