在 GPU 上使用 DiffEqFlux 反向传播神经常微分方程期间的标量操作



我正在使用Julia中的DiffEqFlux包来实现神经ODE。我在让它在 GPU 上工作时遇到问题。

最简单的例子在这里:

# Minimal reproducer: build a neural ODE on the GPU and backpropagate through it.
using DifferentialEquations
using Flux, DiffEqFlux
using CuArrays
# Disallow scalar indexing of GPU arrays; this is what turns the scalar getindex
# during the backward pass into the "scalar getindex is disallowed" error.
CuArrays.allowscalar(false)
# Two-element Float32 initial state, moved to the GPU.
x = Float32[2.; 0.]|>gpu
# Integration time span as a tuple of Float32 values.
tspan = Float32.((0.0f0,25.0f0))
# Network passed to neural_ode: Dense 2 -> 50 with tanh, then Dense 50 -> 2, moved to the GPU.
dudt = Chain(Dense(2,50,tanh),Dense(50,2))|>gpu
# Scalar loss: sum over the neural_ode result, solved with Tsit5 and with
# save_everystep and save_start both turned off.
loss() = sum(neural_ode(dudt,x,tspan,Tsit5(),save_everystep=false,save_start=false))
# Forward pass: evaluates the loss and prints it (this step works).
@show(loss())
# Backward pass: this call is the one that raises the error shown in the stack trace.
Flux.back!(loss())

对应的堆栈跟踪如下:

loss() = -24.529072f0 (tracked)
ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] assertscalar(::String) at /home/vshanka2/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:14
[3] getindex at /home/vshanka2/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:54 [inlined]
[4] getindex at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/adjtrans.jl:129 [inlined]
[5] _unsafe_getindex_rs at ./reshapedarray.jl:245 [inlined]
[6] _unsafe_getindex at ./reshapedarray.jl:242 [inlined]
[7] getindex at ./reshapedarray.jl:231 [inlined]
[8] macro expansion at ./multidimensional.jl:671 [inlined]
[9] macro expansion at ./cartesian.jl:64 [inlined]
[10] macro expansion at ./multidimensional.jl:666 [inlined]
[11] _unsafe_getindex! at ./multidimensional.jl:662 [inlined]
[12] _unsafe_getindex(::IndexLinear, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::UnitRange{Int64}) at ./multidimensional.jl:656
[13] _getindex at ./multidimensional.jl:642 [inlined]
[14] getindex(::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::UnitRange{Int64}) at ./abstractarray.jl:927
[15] (::getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}})(::TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/array.jl:196
[16] iterate at ./generator.jl:47 [inlined]
[17] collect(::Base.Generator{Tuple{TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}},TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}}},getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}}}) at ./array.jl:606
[18] #428 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/array.jl:193 [inlined]
[19] back_(::Tracker.Call{getfield(Tracker, Symbol("##428#431")){Tuple{TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}},TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}}}},Tuple{Tracker.Tracked{CuArray{Float32,1,CuArray{Float32,2,Nothing}}},Tracker.Tracked{CuArray{Float32,1,Nothing}},Tracker.Tracked{CuArray{Float32,1,CuArray{Float32,2,Nothing}}},Tracker.Tracked{CuArray{Float32,1,Nothing}}}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:35
[20] back(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
[21] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
[22] foreach(::Function, ::Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}},Nothing,Nothing,Nothing}, ::Tuple{Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}},CuArray{Float32,1,Nothing},Nothing,Nothing}) at ./abstractarray.jl:1867
[23] back_(::Tracker.Call{getfield(DiffEqFlux, Symbol("##25#28")){DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:save_everystep,),Tuple{Bool}}},TrackedArray{…,CuArray{Float32,1,Nothing}},CuArray{Float32,1,Nothing},Tuple{Tsit5},ODESolution{Float32,2,Array{CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArray{Float32,1,Nothing},Tuple{Float32,Float32},false,CuArray{Float32,1,Nothing},ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}},Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}},Nothing,Nothing,Nothing}}, ::CuArray{Float32,1,Nothing}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
[24] back(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::CuArray{Float32,1,Nothing}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
[25] #13 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38 [inlined]
[26] foreach at ./abstractarray.jl:1867 [inlined]
[27] back_(::Tracker.Call{getfield(Tracker, Symbol("##484#485")){TrackedArray{…,CuArray{Float32,1,Nothing}}},Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}}}}, ::Float32, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
[28] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
[29] #back!#15 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:77 [inlined]
[30] #back! at ./none:0 [inlined]
[31] #back!#32 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/real.jl:16 [inlined]
[32] back!(::Tracker.TrackedReal{Float32}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/real.jl:14
[33] top-level scope at none:0
[34] include at ./boot.jl:326 [inlined]
[35] include_relative(::Module, ::String) at ./loading.jl:1038
[36] include(::Module, ::String) at ./sysimg.jl:29
[37] include(::String) at ./client.jl:403
[38] top-level scope at none:0

前向传播工作正常,但反向传播调用会触发标量 getindex。如果我允许标量操作,当然就不会报错,但它比 CPU 慢得多(即使对于我正在处理的更大规模的问题也是如此)。

我不确定它是否与任何包依赖项有关,但这是我目前安装的内容。

[fbb218c0] BSON v0.2.4
[c5f51814] CUDAdrv v5.0.1
[be33ccc6] CUDAnative v2.7.0
[3a865a2d] CuArrays v1.6.0
[aae7a2af] DiffEqFlux v0.7.0
[9fdde737] DiffEqOperators v4.6.1
[0c46a032] DifferentialEquations v6.9.0
[587475ba] Flux v0.8.3
[a98d9a8b] Interpolations v0.12.5
[15e1cf62] NPZ v0.4.0
[91a5bcdd] Plots v0.28.4

有什么想法吗?

这个问题已在 DiffEqFlux 0.10.1 中修复。可以在 Pkg 模式下运行 ]up 进行升级,或者运行 ]add DiffEqFlux@0.10.1 来明确指定安装该版本。

最新更新