我将要实现一个使用递归大量的程序。因此,在我开始获得堆栈溢出例外之前,我认为实施蹦床并使用thunks会很高兴。
我第一次尝试是与阶乘有关的。这里的代码:
callable(f) = !isempty(methods(f))
function trampoline(f, arg1, arg2)
v = f(arg1, arg2)
while callable(v)
v = v()
end
return v
end
function factorial(n, continuation)
if n == 1
continuation(1)
else
(() -> factorial(n-1, (z -> (() -> continuation(n*z)))))
end
end
function cont(x)
x
end
另外,我实施了一个天真的阶乘,以检查事实上,我是否会防止堆栈溢出:
function factorial_overflow(n)
if n == 1
1
else
n*factorial_overflow(n-1)
end
end
结果是:
julia> factorial_overflow(140000)
ERROR: StackOverflowError:
#JITing with a small input
julia> trampoline(factorial, 10, cont)
3628800
#Testing
julia> trampoline(factorial, 140000, cont)
0
所以,是的,我避免了stacksoverflows。是的,我知道结果是胡说八道,因为我让整数溢出了,但是在这里我只是关心堆栈。当然,生产版本将具有固定的。
(也知道,我知道有一个内置的情况,我都不使用这些,我让它们测试了我的蹦床(。
第一次运行时,蹦床版本需要很多时间,然后在计算相同或较低值时会变得很快。如果我做了trampoline(factorial, 150000, cont)
,我将有一些编译时间。
在我看来(受过教育的猜测(,我在刻有许多不同的符号:一个生成的thunk。
我的问题是:我可以避免吗?
我认为问题是每个闭合都是其自己的类型,专门针对被捕获的变量。为了避免这种专业化,可以使用尚未完全专业化的函子:
struct L1
f
n::Int
z::Int
end
(o::L1)() = o.f(o.n*o.z)
struct L2
f
n::Int
end
(o::L2)(z) = L1(o.f, o.n, z)
struct Factorial
f
c
n::Int
end
(o::Factorial)() = o.f(o.n-1, L2(o.c, o.n))
callable(f) = false
callable(f::Union{Factorial, L1, L2}) = true
function myfactorial(n, continuation)
if n == 1
continuation(1)
else
Factorial(myfactorial, continuation, n)
end
end
function cont(x)
x
end
function trampoline(f, arg1, arg2)
v = f(arg1, arg2)
while callable(v)
v = v()
end
return v
end
请注意,该功能字段是未型的。现在,该功能在第一次运行时运行得更快:
julia> @time trampoline(myfactorial, 10, cont)
0.020673 seconds (4.24 k allocations: 264.427 KiB)
3628800
julia> @time trampoline(myfactorial, 10, cont)
0.000009 seconds (37 allocations: 1.094 KiB)
3628800
julia> @time trampoline(myfactorial, 14000, cont)
0.001277 seconds (55.55 k allocations: 1.489 MiB)
0
julia> @time trampoline(myfactorial, 14000, cont)
0.001197 seconds (55.55 k allocations: 1.489 MiB)
0
我刚将您的代码中的每个封闭都转换为相应的函子。这可能不需要,可能有更好的解决方案,但是它有效并希望展示了方法。
编辑:
要使放缓的原因更加清晰,可以使用:
function factorial(n, continuation)
if n == 1
continuation(1)
else
tmp = (z -> (() -> continuation(n*z)))
@show typeof(tmp)
(() -> factorial(n-1, tmp))
end
end
此输出:
julia> trampoline(factorial, 10, cont)
typeof(tmp) = ##31#34{Int64,#cont}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,#cont}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}}}
3628800
tmp
是封闭。它自动创建的类型##31#34
看起来类似于
struct Tmp{T,F}
n::T
continuation::F
end
continuation
字段的F
类型的专业化是漫长汇编时间的原因。
通过使用L2
,而不是专门针对相应的字段f
,continuation
参数 factorial
始终具有L2
类型,并且避免了问题。