在Flux中使用回调进行训练时记录损失



我正试图为Flux中的train!函数编写一个回调。我的代码是:

cb_loss = x -> push!(x, loss(x_train, y_train))
loss_vector = Vector{Float32}()
Flux.train!(loss, ps, train_data, opt, cb=cb_loss(loss_vector))

它给了我这个错误:

MethodError: objects of type Float32 are not callable
Stacktrace:
[1] call(::Float32) at C:Usersarjur.juliapackagesFluxFj3btsrcoptimisetrain.jl:36
[2] foreach at .abstractarray.jl:1920 [inlined]
[3] #10 at C:Usersarjur.juliapackagesFluxFj3btsrcoptimisetrain.jl:38 [inlined]
[4] macro expansion at C:Usersarjur.juliapackagesFluxFj3btsrcoptimisetrain.jl:93 [inlined]
[5] macro expansion at C:Usersarjur.juliapackagesJunooLB1dsrcprogress.jl:134 [inlined]
[6] #train!#12(::Array{Float32,1}, ::typeof(Flux.Optimise.train!), ::typeof(loss), ::Zygote.Params, ::DataLoader, ::Descent) at C:Usersarjur.juliapackagesFluxFj3btsrcoptimisetrain.jl:81
[7] (::Flux.Optimise.var"#kw##train!")(::NamedTuple{(:cb,),Tuple{Array{Float32,1}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::DataLoader, ::Descent) at .none:0
[8] top-level scope at In[108]:1

有趣的是,它正确地将第一个值添加到向量中,然后崩溃,所以我猜错误消息与此有关。

我检查了train!函数之外的函数,它是有效的,那么我应该如何重写这个函数来记录向量中的损失呢?

似乎需要像这样传递它:cb=callback。因此,可以使用全局变量或定义如下回调:

loss_vector = Vector{Float32}()
callback() = push!(loss_vector, loss(x_train, y_train))
Flux.train!(loss, ps, train_data, opt, cb=callback)

最新更新