如何在 TensorFlow 计算图中只计算所有候选操作中的一个随机子集?


import tensorflow as tf
# Candidate operations: three 64-unit Dense layers that differ only in activation.
operations = [
    tf.keras.layers.Dense(64, activation=activation)
    for activation in ('sigmoid', 'relu', None)
]

def call(x):
    """Apply 2 randomly sampled operations from `operations` to `x`.

    Only the sampled operations are evaluated: each sample is dispatched with
    `tf.switch_case`, which runs just the selected branch, so this function can
    also be wrapped in `tf.function`.

    Args:
        x: input tensor passed to each sampled operation.

    Returns:
        Tensor stacking the outputs of the 2 sampled operations along a new
        leading axis, in sampling order.
    """
    num_samples = 2
    # Sample 2 of the 3 operations (independent draws, so the same operation
    # can be chosen twice).
    sampled_ids = tf.random.categorical(
        tf.zeros((1, len(operations))), num_samples=num_samples, dtype=tf.int32,
    )[0]
    # `op=op` binds each operation when its lambda is created. A bare
    # `lambda: op(x)` captures the loop variable itself, so every branch would
    # call the *last* operation -- the reason results differed even eagerly.
    branch_fns = [lambda op=op: op(x) for op in operations]
    # Index the sample slots with Python ints: iterating a tensor is not
    # allowed while tracing a tf.function.
    op_results_graph = tf.stack([
        tf.switch_case(branch_index=sampled_ids[i], branch_fns=branch_fns)
        for i in range(num_samples)
    ])
    if tf.executing_eagerly():
        # Eager-only reference check: indexing a Python list with a tensor only
        # works in eager mode.
        op_results_eager = tf.stack([operations[op_id](x) for op_id in sampled_ids])
        tf.print(tf.reduce_all(tf.equal(op_results_eager, op_results_graph)))
    return op_results_graph

# Exercise call() repeatedly on the same (1, 5) input of ones.
sample_input = tf.ones(shape=(1, 5))
for _ in range(1000):
    call(sample_input)

我需要得到与 op_results_eager 相同的结果,但要使用能让 call() 被 tf.function 包装的语句,同时只计算被采样到的操作。

正如你所看到的,我尝试用 tf.switch_case 来实现这种索引,但即使在 eager(即时)执行模式下,它也得不到正确的结果。

可以这样做(不过每次调用时,三个 Dense 层都会被计算):

import tensorflow as tf
# Three candidate layers (sigmoid, relu, linear), also kept under their own names.
dense1, dense2, dense3 = (
    tf.keras.layers.Dense(64, activation=activation)
    for activation in ('sigmoid', 'relu', None)
)
operations = [dense1, dense2, dense3]
def call(x):
    """Apply 2 randomly sampled operations to `x` by evaluating all of them.

    Every layer in `operations` is applied to `x`; the outputs are stacked and
    the sampled ones are picked out with `tf.gather`. This works under
    `tf.function`, but it pays for every layer on every call.

    Args:
        x: input tensor fed to every layer.

    Returns:
        Tensor of the 2 sampled outputs stacked along a new leading axis, in
        sampling order.
    """
    # Sample 2 of the 3 operations.
    sampled_ids = tf.random.categorical(
        tf.zeros((1, len(operations))), num_samples=2, dtype=tf.int32,
    )[0]

    op_results_graph = tf.gather(tf.stack([op(x) for op in operations]), sampled_ids)
    if tf.executing_eagerly():
        # Eager-only reference check: iterating a tensor and indexing a Python
        # list with it raise while tracing a tf.function.
        op_results_eager = tf.stack([operations[op_id](x) for op_id in sampled_ids])
        tf.print(tf.reduce_all(tf.equal(op_results_eager, op_results_graph)))
    return op_results_graph

# Exercise call() ten times on the same (1, 5) input of ones.
sample_input = tf.ones(shape=(1, 5))
for _ in range(10):
    call(sample_input)
1
1
1
1
1
1
1
1
1
1

另一种不需要调用每一层的做法:

# Build the candidate layers as a list first, then expose each one by name.
operations = [
    tf.keras.layers.Dense(64, activation=activation)
    for activation in ('sigmoid', 'relu', None)
]
dense1, dense2, dense3 = operations
def call(x):
    """Apply 2 randomly sampled operations to `x`, evaluating only those two.

    Each sampled id is dispatched with `tf.switch_case`, which runs only the
    selected branch. Unsampled layers are therefore never called, and the
    function can be wrapped in `tf.function`.

    Args:
        x: input tensor fed to the sampled layers.

    Returns:
        Tensor of the 2 sampled outputs stacked along a new leading axis, in
        sampling order (the same layout as stacking ``operations[i](x)`` for
        each sampled ``i``).
    """
    num_samples = 2
    # Sample 2 of the 3 operations (independent draws, so the same operation
    # can be chosen twice).
    sampled_ids = tf.random.categorical(
        tf.zeros((1, len(operations))), num_samples=num_samples, dtype=tf.int32,
    )[0]
    # `op=op` binds each layer now; a bare `lambda: op(x)` would late-bind to
    # the last layer in `operations`.
    branch_fns = [lambda op=op: op(x) for op in operations]
    # One scalar branch index per sample slot keeps the sampling order and
    # keeps every sampled output, even one that happens to be all zeros.
    return tf.stack([
        tf.switch_case(branch_index=sampled_ids[i], branch_fns=branch_fns)
        for i in range(num_samples)
    ])

相关内容

最新更新