加速 Zygote.jl AD

Speeding up Zygote.jl AD

我正在使用 Julia 的 Zygote.jl 包来计算期权希腊字母(即期权价格相对于参数的导数)的自动差异功能。请参阅下面使用 Black Scholes 计算看涨期权的希腊字母。在我的笔记本电脑上需要 40 秒才能达到 运行。我是不是做错了什么会导致它花费这么多时间?

我的猜测是,当 Zygote 必须通过 Distributions 进行区分时,困难的部分就来了,但我不确定。

using Distributions
using Zygote


function bs_call(theta)
    s = theta[1]
    k = theta[2]
    r = theta[3]
    t = theta[4]
    sigma = theta[5]
    vol = sigma * sqrt(t)
    d1 = (log(s / k) + (r + 0.5 * sigma ^ 2) * t) / vol
    d2 = d1 - vol
    n = Normal()
    price = cdf(n, d1) * s - cdf(n, d2) * k * exp(-1.0 * r * t)
    price
end


function main()
    theta = [100, 110, .20, 1.0, .50]
    println(bs_call(theta))
    println(bs_call'(theta))
end


main()

编辑:使用 SpecialFunctions(从 erf 构建 cdf 函数)而不是 Distributions 让我减少到 25 秒。见下文:

using SpecialFunctions
using Zygote


function cdf(x)
    0.5 * (1 + erf(x / sqrt(2)))
end


function bs_call(theta)
    s = theta[1]
    k = theta[2]
    r = theta[3]
    t = theta[4]
    sigma = theta[5]
    vol = sigma * sqrt(t)
    d1 = (log(s / k) + (r + 0.5 * sigma ^ 2) * t) / vol
    d2 = d1 - vol
    price = cdf(d1) * s - cdf(d2) * k * exp(-1.0 * r * t)
    price
end


function main()
    theta = [100.0, 110.0, .20, 1.0, .50]
    println(bs_call(theta))
    println(bs_call'(theta))
end


main()

鉴于您的 main 功能,您可能会在脚本中执行此操作。在 Julia 中,最好启动一个会话(在 REPL、VSCode、Jupyter notebook 或其他环境中)并从同一会话中 运行 多个工作负载。正如 Antonello 在评论中建议的那样,您的第一次调用将由编译时间主导,但后来的调用(具有相同的参数类型)仅使用编译后的代码,并且可能与第一次调用完全不同。

可以在 https://docs.julialang.org/en/v1/manual/workflow-tips/ 中找到一些工作流程提示。