焦点丢失:ValueError:batch_dims=1必须小于rank(indexes)=1



我正在尝试从https://github.com/artemmavrin/focal-loss/使用Tensorflow1.14,但我在使用示例进行测试时遇到了以下错误

ValueError:batch_dims=1必须小于rank(indexes(=1。

probs = tf.gather(probs, y_true, axis=-1, batch_dims=y_true_rank)

尝试更新TensorFlow。我在2.0.0中也出现了同样的错误,但对于2.4.1来说还可以。

最新更新