如何在Tensorflow Java/scala中禁用学习阶段,就像我们在python中做的那样tf.keras.bac



我正在使用lstm模型并加载保存的模型,如下所示

val modelFiles: List[String] = List(
"saved_model.pb",
"keras_metadata.pb",
"variables/variables.data-00000-of-00001",
"variables/variables.index"
)
val modelDir = "/Models/lstm_model_fullname_with_watchlist"
val tempDirectory = Files.createTempDirectory("lstm").toFile
val savedModel = SavedModelBundle.load(tempDirectory.getAbsolutePath, "serve")
new LSTMModel(
session = savedModel.session()
)

和调用下面的预测

val input1: Tensor[_] = Tensor.create(embed(name1))
val input2: Tensor[_] = Tensor.create(embed(name2))
val result: Tensor[_] = session.runner()
.fetch("StatefulPartitionedCall:0")
.feed("serving_default_input_1:0", input1)
.feed("serving_default_input_2:0", input2)
.run().get(0)
  1. 我需要检查learning_phase是否启用或禁用?
  2. 如何禁用Scala或Java中的学习阶段。我正在使用scala。但是Java参考也很好,我希望两者相似。

我偶然发现了这个链接,这个选项本身是不可用的?

我当前的tensorflow版本是1.15.0。

Thanks in advance ....

根据这个对话,这是由模型控制的。https://github.com/tensorflow/java/issues/465

相关内容

  • 没有找到相关文章