加速 Zygote.jl AD
Speeding up Zygote.jl AD
我正在使用 Julia 的 Zygote.jl
包来计算期权希腊字母(即期权价格相对于参数的导数)的自动差异功能。请参阅下面使用 Black Scholes 计算看涨期权的希腊字母。在我的笔记本电脑上需要 40 秒才能达到 运行。我是不是做错了什么会导致它花费这么多时间?
我的猜测是,当 Zygote
必须通过 Distributions
进行区分时,困难的部分就来了,但我不确定。
using Distributions
using Zygote
function bs_call(theta)
s = theta[1]
k = theta[2]
r = theta[3]
t = theta[4]
sigma = theta[5]
vol = sigma * sqrt(t)
d1 = (log(s / k) + (r + 0.5 * sigma ^ 2) * t) / vol
d2 = d1 - vol
n = Normal()
price = cdf(n, d1) * s - cdf(n, d2) * k * exp(-1.0 * r * t)
price
end
function main()
theta = [100, 110, .20, 1.0, .50]
println(bs_call(theta))
println(bs_call'(theta))
end
main()
编辑:使用 SpecialFunctions
(从 erf
构建 cdf
函数)而不是 Distributions
让我减少到 25 秒。见下文:
using SpecialFunctions
using Zygote
function cdf(x)
0.5 * (1 + erf(x / sqrt(2)))
end
function bs_call(theta)
s = theta[1]
k = theta[2]
r = theta[3]
t = theta[4]
sigma = theta[5]
vol = sigma * sqrt(t)
d1 = (log(s / k) + (r + 0.5 * sigma ^ 2) * t) / vol
d2 = d1 - vol
price = cdf(d1) * s - cdf(d2) * k * exp(-1.0 * r * t)
price
end
function main()
theta = [100.0, 110.0, .20, 1.0, .50]
println(bs_call(theta))
println(bs_call'(theta))
end
main()
鉴于您的 main
功能,您可能会在脚本中执行此操作。在 Julia 中,最好启动一个会话(在 REPL、VSCode、Jupyter notebook 或其他环境中)并从同一会话中 运行 多个工作负载。正如 Antonello 在评论中建议的那样,您的第一次调用将由编译时间主导,但后来的调用(具有相同的参数类型)仅使用编译后的代码,并且可能与第一次调用完全不同。
可以在 https://docs.julialang.org/en/v1/manual/workflow-tips/ 中找到一些工作流程提示。
我正在使用 Julia 的 Zygote.jl
包来计算期权希腊字母(即期权价格相对于参数的导数)的自动差异功能。请参阅下面使用 Black Scholes 计算看涨期权的希腊字母。在我的笔记本电脑上需要 40 秒才能达到 运行。我是不是做错了什么会导致它花费这么多时间?
我的猜测是,当 Zygote
必须通过 Distributions
进行区分时,困难的部分就来了,但我不确定。
using Distributions
using Zygote
function bs_call(theta)
s = theta[1]
k = theta[2]
r = theta[3]
t = theta[4]
sigma = theta[5]
vol = sigma * sqrt(t)
d1 = (log(s / k) + (r + 0.5 * sigma ^ 2) * t) / vol
d2 = d1 - vol
n = Normal()
price = cdf(n, d1) * s - cdf(n, d2) * k * exp(-1.0 * r * t)
price
end
function main()
theta = [100, 110, .20, 1.0, .50]
println(bs_call(theta))
println(bs_call'(theta))
end
main()
编辑:使用 SpecialFunctions
(从 erf
构建 cdf
函数)而不是 Distributions
让我减少到 25 秒。见下文:
using SpecialFunctions
using Zygote
function cdf(x)
0.5 * (1 + erf(x / sqrt(2)))
end
function bs_call(theta)
s = theta[1]
k = theta[2]
r = theta[3]
t = theta[4]
sigma = theta[5]
vol = sigma * sqrt(t)
d1 = (log(s / k) + (r + 0.5 * sigma ^ 2) * t) / vol
d2 = d1 - vol
price = cdf(d1) * s - cdf(d2) * k * exp(-1.0 * r * t)
price
end
function main()
theta = [100.0, 110.0, .20, 1.0, .50]
println(bs_call(theta))
println(bs_call'(theta))
end
main()
鉴于您的 main
功能,您可能会在脚本中执行此操作。在 Julia 中,最好启动一个会话(在 REPL、VSCode、Jupyter notebook 或其他环境中)并从同一会话中 运行 多个工作负载。正如 Antonello 在评论中建议的那样,您的第一次调用将由编译时间主导,但后来的调用(具有相同的参数类型)仅使用编译后的代码,并且可能与第一次调用完全不同。
可以在 https://docs.julialang.org/en/v1/manual/workflow-tips/ 中找到一些工作流程提示。