我试图在我自己的tf2.1.0-keras模型中使用bool条件,下面是一个简单的例子:
import tensorflow as tf
class TestKeras:
def __init__(self):
pass
def build_graph(self):
x = tf.keras.Input(shape=(2),batch_size=1)
x_value = x[0,0]
y = tf.cond(x_value > 0, lambda :tf.add(x_value,0), lambda :tf.add(x_value,0))
return tf.keras.models.Model(inputs=[x], outputs=[y])
if __name__ == "__main__":
tk = TestKeras()
model = tk.build_graph()
model.summary(line_length=100)
,但似乎不工作,抛出异常:
using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
我已经尝试用tf.keras.backend.switch
代替tf.cond
,但它仍然得到相同的错误。
我还尝试将代码y = tf.cond(xxx)
拆分为单个函数并添加@tf.funcion
装饰器:
@tf.function
def compute_y(self,x):
return tf.cond(x > 0, lambda :tf.add(x,0), lambda :tf.add(x,0))
,但它得到了另一个错误:
Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'strided_slice:0' shape=() dtype=float32>]
有人知道tf2.1.0-keras中的条件是如何工作的吗?
tf.keras.Input
是一个符号张量,用于定义keras模型的输入。当你想在keras模型中应用自定义逻辑时,你应该创建Layer
类的子类,或者使用Lambda
层。
例如,对于Lambda
层:
class TestKeras:
def __init__(self):
pass
def build_graph(self):
x = tf.keras.Input(shape=(2),batch_size=1)
def custom_fct(x):
x_value = x[0,0]
return tf.cond(x_value > 0, lambda :tf.add(x_value,0), lambda :tf.add(x_value,0))
y = tf.keras.layers.Lambda(custom_fct)(x)
return tf.keras.models.Model(inputs=[x], outputs=[y])