朱莉娅中函数调用的安布吉斯



我有这个错误

ERROR: MethodError: vcat(::Array{Real,2}, ::TrackedArray{…,Array{Float32,2}}) is ambiguous. Candidates:
vcat(364::AbstractArray, x::Union{TrackedArray, TrackedReal}, xs::Union{Number, AbstractArray}...) in Tracker at C:UsersHenri.juliapackagesTracker6wcYJsrclibarray.jl:167
vcat(A::Union{AbstractArray{T,2}, AbstractArray{T,1}} where T...) in Base at abstractarray.jl:1296
Possible fix, define
vcat(::Union{AbstractArray{T,2}, AbstractArray{T,1}} where T, ::Union{TrackedArray{T,1,A} where A<:AbstractArray{T,1} where T, TrackedArray{T,2,A} where A<:AbstractArray{T,2} where T}, ::Vararg{Union{AbstractArray{T,2}, AbstractArray{T,1}} where T,N} where N)

告诉我两个vcat()函数是模棱两可的。我想使用Base.vcat()函数,但使用它会明确抛出相同的错误。为什么?错误抛出提出的这种"可能修复"是什么?

此外,当我手动调用 REPL 中的每一行时,不会抛出任何错误。我不理解这种行为。仅当 vcat() 位于另一个函数中调用的函数中时,才会发生这种情况。就像我下面的例子一样。

下面是重现错误的代码:

using Flux
function loss(a, b, net, net2)
net2(vcat(net(a),a))
end
function test()    
opt = ADAM()
net = Chain(Dense(3,3))
net2 = Chain(Dense(6,1))
L(a, b) = loss(a, b, net, net2)
data = tuple(rand(3,1), rand(3,1))
xs = Flux.params(net)
gs = Tracker.gradient(() -> L(data...), xs)
Tracker.update!(opt, xs, gs)
end

正如在 Henri.D 的评论中提到的,我们已经设法通过使用a类型来修复它,该类型是Float64Array,默认类型由rand返回,而net(a)返回了Float32TrackedArray,并且无法使用avcat它。

我设法通过更改您的损失函数来修复vcatnet2(vcat(net(a),Float32.(a)))因为vcat无法连接,因为net(a)是一个Float32 Arraya是一个Float64的。然后L(data...)是 1 个元素的TrackedArray,而我认为你需要一个Float32这就是为什么我最终用net2(vcat(net(a),Float32.(a)))[1]替换loss function

最新更新