tf2.1.0-keras中的Using条件



我试图在我自己的tf2.1.0-keras模型中使用bool条件,下面是一个简单的例子:

import tensorflow as tf
class TestKeras:
def __init__(self):
pass
def build_graph(self):
x = tf.keras.Input(shape=(2),batch_size=1)
x_value = x[0,0]
y = tf.cond(x_value > 0, lambda :tf.add(x_value,0), lambda :tf.add(x_value,0))
return tf.keras.models.Model(inputs=[x], outputs=[y])
if __name__ == "__main__":
tk = TestKeras()
model = tk.build_graph()
model.summary(line_length=100)

,但似乎不工作,抛出异常:

using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

我已经尝试用tf.keras.backend.switch代替tf.cond,但它仍然得到相同的错误。

我还尝试将代码y = tf.cond(xxx)拆分为单个函数并添加@tf.funcion装饰器:

@tf.function
def compute_y(self,x):
return tf.cond(x > 0, lambda :tf.add(x,0), lambda :tf.add(x,0))

,但它得到了另一个错误:

Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'strided_slice:0' shape=() dtype=float32>]

有人知道tf2.1.0-keras中的条件是如何工作的吗?

tf.keras.Input是一个符号张量,用于定义keras模型的输入。当你想在keras模型中应用自定义逻辑时,你应该创建Layer类的子类,或者使用Lambda层。

例如,对于Lambda层:

class TestKeras:
def __init__(self):
pass

def build_graph(self):
x = tf.keras.Input(shape=(2),batch_size=1)
def custom_fct(x):
x_value = x[0,0]
return tf.cond(x_value > 0, lambda :tf.add(x_value,0), lambda :tf.add(x_value,0))
y = tf.keras.layers.Lambda(custom_fct)(x)
return tf.keras.models.Model(inputs=[x], outputs=[y])

相关内容

  • 没有找到相关文章

最新更新