Exact Line Search Algorithm

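I'm trying to implement a simple line search algorithm in Julia. I'm new to Julia programming, so I'm learning it as I go. If possible, I'd like some help fixing the errors I get when running the code.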
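For reference, exact line search chooses the step size α that minimizes f along a search direction d from a point x. For the f, x and d used in the code, the one-dimensional objective and its derivative work out as follows (a worked sketch, worth double-checking):

```latex
\begin{aligned}
\alpha^\star &= \arg\min_{\alpha} \varphi(\alpha), \qquad \varphi(\alpha) = f(x + \alpha d) \\
f(x) &= \sin(x_1 x_2) + e^{x_2 + x_3} - x_3, \qquad x = (1, 2, 3), \qquad d = (0, -1, -1) \\
x + \alpha d &= (1,\ 2 - \alpha,\ 3 - \alpha) \\
\varphi(\alpha) &= \sin(2 - \alpha) + e^{5 - 2\alpha} - 3 + \alpha \\
\varphi'(\alpha) &= 1 - \cos(2 - \alpha) - 2 e^{5 - 2\alpha}
\end{aligned}
```

A local minimizer satisfies φ'(α*) = 0. Evaluating φ'(3) ≈ -0.28 and φ'(3.2) ≈ 0.14 shows that one lies between 3 and 3.2.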

Source code:

using LinearAlgebra

function bracket_minimum(f, x = 0, s = 1e-2, k = 2.0)
    a, fa = x, f(x)
    b, fb = x + s, f(x + s)

    if(fb > fa)
        a, b = b, a
        fa, fb = fb, fa
        s = -s
    end

    while(true)
        c, fc = b + s, f(b + s)
        if(fb < fc)
            return a < c ? (a, c) : (c, a)
        else
            a, fa, b, fb = b, fb, c, fc
            s *= k
        end
    end
end

function bisection(f, a₀, b₀, ϵ)

    function D(f,a)
        # Approximate the first derivative using central differences
        h = 0.001
        return (f(a + h) - f(a - h))/(2 * h)
    end

    a = a₀
    b = b₀

    while((b - a) > ϵ)
        c = (a + b)/2.0

        if D(f,c) > 0
            b = c
        else
            a = c
        end
    end

    return (a,b)
end

function line_search(f::Function, x::Vector{Float64}, d::Vector{Float64})
    println("Hello")
    objective = α -> f(x + α*d)
    a, b = bracket_minimum(objective)
    α = bisection(objective, a, b, 1e-5)
    return α, x + α*d
end

f(x) = sin(x[1] * x[2]) + exp(x[2] + x[3]) - x[3]

x = [1,2,3]
d = [0, -1, -1]
α, x_min = line_search(f, x, d)

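I'm getting a LinearAlgebra error, so I think I'm either not passing the vectors correctly or not doing the scalar-vector multiplication correctly, but I'm having a hard time figuring it out. If I step through the code, it fails at the call line_search(f, x, d) and never even enters the function body.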

Error message:

ERROR: MethodError: no method matching *(::Tuple{Float64,Float64}, ::Array{Int64,1})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:538
  *(::Adjoint{var"#s828",var"#s8281"} where var"#s8281"<:(AbstractArray{T,1} where T) where var"#s828"<:Number, ::AbstractArray{var"#s827",1} where var"#s827"<:Number) at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\adjtrans.jl:283
  *(::Transpose{T,var"#s828"} where var"#s828"<:(AbstractArray{T,1} where T), ::AbstractArray{T,1}) where T<:Real at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\adjtrans.jl:284
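To see the failure in isolation: multiplying a scalar by a vector is fine in Julia, but there is no * method for a tuple and a vector. The values below are hypothetical stand-ins, with α_tuple playing the role of the (a, b) tuple that the original bisection returns.

```julia
# Minimal reproduction of the MethodError from the question.
d = [0, -1, -1]                 # Vector{Int64}, as in the question

# Scalar times vector is defined and returns a Vector{Float64}
α_scalar = 0.5
scalar_result = α_scalar * d
println(scalar_result)          # [0.0, -0.5, -0.5]

# Tuple times vector has no method: this mirrors α being the (a, b)
# tuple returned by the original bisection (values are hypothetical)
α_tuple = (3.1, 3.2)
tuple_error = try
    α_tuple * d
    nothing
catch err
    err
end
println(tuple_error isa MethodError)   # true
```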

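Here is a fixed version of the code. I've cleaned up a few stylistic things, but the key problem is that your bisection function returned a tuple (a, b) instead of a single value. That tuple was assigned to α in line_search, so α*d tried to multiply a Tuple{Float64,Float64} by a vector, which is exactly the MethodError above. I've changed bisection to return the midpoint of the final bracketing interval. I also wrote x and d as Float64 literals, since [1, 2, 3] is a Vector{Int64} and would not match the Vector{Float64} signature of line_search: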

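# Step away from x with step s (growing by a factor k each time) until f
# increases, and return an interval (a, c) that brackets a local minimum
function bracket_minimum(f, x = 0.0, s = 1e-2, k = 2.0)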
    a, fa = x, f(x)
    b, fb = x + s, f(x + s)

    if fb > fa
        a, b = b, a
        fa, fb = fb, fa
        s = -s
    end

    while true
        s *= k
        c, fc = b + s, f(b + s)
        if fb < fc
            return minmax(a, c)
        else
            a, fa, b, fb = b, fb, c, fc
        end
    end
end

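# Bisection on the sign of a finite-difference derivative: halves [a₀, b₀]
# until its width is at most ϵ, then returns the midpoint
function bisection(f, a₀, b₀, ϵ)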

    function D(f, a)
        # Approximate the first derivative using central differences
        h = 0.001
        return (f(a + h) - f(a - h)) / (2 * h)
    end

    a = a₀
    b = b₀

    while (b - a) > ϵ
        c = (a + b) / 2.0

        if D(f, c) > 0
            b = c
        else
            a = c
        end
    end

    return (a + b) / 2 # this was changed
end

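function line_search(f::Function, x::Vector{Float64}, d::Vector{Float64})
    @assert length(x) == length(d)
    # Restrict f to the line x + α d; .+ and .* broadcast elementwise
    objective(α) = f(x .+ α .* d)
    a, b = bracket_minimum(objective)      # interval containing a local minimum
    α = bisection(objective, a, b, 1e-5)   # now a single Float64, not a tuple
    return α, x .+ α .* d
end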

f(x) = sin(x[1] * x[2]) + exp(x[2] + x[3]) - x[3]

x = [1.0, 2.0, 3.0]
d = [0.0, -1.0, -1.0]
α, x_min = line_search(f, x, d)
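Because φ'(α) has a closed form for this particular f, x and d, one way to sanity-check the result is to bisect on the exact derivative instead of the central-difference approximation D. This is a hypothetical standalone check, not part of the fix: dφ is derived by hand for these inputs, and the starting interval [3.0, 3.2] is assumed from the sign of dφ at its endpoints.

```julia
# Hypothetical check: bisection on the exact derivative of φ(α) = f(x + α d)
# for f(x) = sin(x₁x₂) + exp(x₂ + x₃) - x₃, x = [1, 2, 3], d = [0, -1, -1].
dφ(α) = 1 - cos(2 - α) - 2 * exp(5 - 2α)

function bisect_derivative(df, a, b, ϵ)
    # Assumes df(a) < 0 < df(b), i.e. [a, b] brackets a local minimum
    while b - a > ϵ
        c = (a + b) / 2
        if df(c) > 0
            b = c      # derivative positive: minimum is to the left
        else
            a = c      # derivative non-positive: minimum is to the right
        end
    end
    return (a + b) / 2
end

α = bisect_derivative(dφ, 3.0, 3.2, 1e-10)
x_min = [1.0, 2.0, 3.0] .+ α .* [0.0, -1.0, -1.0]
println((α, x_min))
```

The α returned by line_search should agree closely with this value, up to the 1e-5 tolerance passed to bisection and the small error of the finite-difference D.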

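I haven't commented on the algorithm itself, since I assume you're writing this as a programming exercise rather than trying to write the fastest and most robust algorithm.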
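If you do want something more robust later, one commonly used derivative-free option for the one-dimensional step is golden-section search, which only compares function values and so avoids the finite-difference derivative D entirely. This is my own illustrative sketch, not part of the fix; it assumes the function is unimodal on the starting interval, and the hardcoded φ is the one-dimensional objective for the question's f, x and d.

```julia
# Golden-section search: shrinks [a, b] by a constant factor each step using
# only function comparisons. Assumes f is unimodal on [a, b].
function golden_section(f, a, b, ϵ)
    ρ = (sqrt(5) - 1) / 2          # ≈ 0.618, satisfies ρ^2 = 1 - ρ
    c = b - ρ * (b - a)
    d = a + ρ * (b - a)
    fc, fd = f(c), f(d)
    while b - a > ϵ
        if fc < fd
            # Minimum lies in [a, d]; the old c becomes the new d
            b, d, fd = d, c, fc
            c = b - ρ * (b - a)
            fc = f(c)
        else
            # Minimum lies in [c, b]; the old d becomes the new c
            a, c, fc = c, d, fd
            d = a + ρ * (b - a)
            fd = f(d)
        end
    end
    return (a + b) / 2
end

# One-dimensional objective φ(α) = f(x + α d) for the question's inputs
φ(α) = sin(2 - α) + exp(5 - 2α) - 3 + α
α = golden_section(φ, 3.0, 3.2, 1e-6)
println(α)
```

Each iteration reuses one of the two interior function values, so only one new evaluation of f is needed per step. Inside line_search it could stand in for bisection, e.g. α = golden_section(objective, a, b, 1e-5), reusing the interval from bracket_minimum.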