keras (tensorflow backend) 条件赋值与 K.switch()



我正在尝试实现类似的东西

if np.max(subgrid) == np.min(subgrid):
middle_middle = cur_subgrid + 1
else:
middle_middle = cur_subgrid

由于条件只能在运行时确定,因此我使用 Keras 语法

如下
middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

但是我收到此错误:

<ipython-input-112-0504ce070e71> in col_loop(j, gray_map, mask_A)
56 
57 
---> 58             middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
59 
60             print ('ml',middle_left.shape)
/nfs/isicvlnas01/share/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in switch(condition, then_expression, else_expression)    2561         The selected tensor.    2562     """
-> 2563     if condition.dtype != tf.bool:    2564         condition = tf.cast(condition, 'bool')    2565     if not callable(then_expression):
AttributeError: 'bool' object has no attribute 'dtype'

middle_middlecur_subgrid和子网格都是张量NxN。任何帮助,不胜感激。

我认为问题在于,使用K.max(subgrid) == K.min(subgrid)您正在创建一个比较两个张量对象的python布尔值,而不是一个包含两个输入张量比较值的Tensorflow 布尔张量

换句话说,您编写的内容将被评估为

K.switch(False, lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

而不是

comparison = ... # Some tensor, that at runtime will contain True if min and max are the same, False otherwise. 
K.switch(comparison , lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

所以你需要做的是使用 keras.backend.equal(( 而不是==

K.switch(K.equal(K.max(subgrid),K.min(subgrid)), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

最新更新