在 R 中使用 Keras 顺序模型"reinforcelearn"包时出现问题



我正试图使用keras(2.2.50版(神经网络/序列模型,根据以下小插曲,使用reinforcelearn软件包(0.2.1版(在强化学习设置中创建一个简单的代理:https://cran.r-project.org/web/packages/reinforcelearn/vignettes/agents.html。这是我使用的代码:

library('reinforcelearn')
library('keras')
model = keras_model_sequential() %>% 
layer_dense(units = 10, input_shape = 4, activation = "linear") %>%
compile(optimizer = optimizer_sgd(lr = 0.1), loss = "mae")
agent = makeAgent(policy = "softmax", val.fun = "neural.network", algorithm = "qlearning",
val.fun.args = list(model= model))

然而,当我尝试运行makeAgent函数时,我会收到以下错误消息:

Error in .subset2(public_bind_env, "initialize")(...) : 
Assertion on 'model' failed: Must inherit from class 'keras.models.Sequential', but has classes 'keras.engine.sequential.Sequential','keras.engine.training.Model','keras.engine.network.Network','keras.engine.base_layer.Layer','tensorflow.python.module.module.Module','tensorflow.python.training.tracking.tracking.AutoTrackable','tensorflow.python.training.tracking.base.Trackable','python.builtin.object'.

问题似乎是模型的错误类,但我能做些什么来解决这个问题?

我能够通过从CRAN下载源代码来解决这个问题(https://cran.r-project.org/src/contrib/reinforcelearn_0.2.1.tar.gz)并注释掉ValueNetworkR6类/initialise函数定义中的相应行:

ValueNetwork = R6::R6Class("ValueNetwork",
public = list(
model = NULL,
# keras model # fixme: add support for mxnet
initialize = function(model) {
# checkmate::assertClass(model, "keras.models.Sequential")
self$model = model
},
...

然后我通过以下方式重新安装了源代码包:install.packages([file path], repos = NULL, type="source")

相关内容

  • 没有找到相关文章