在 Julia 中使用 ForwardDiff 时限制函数签名
Restricting function signatures while using ForwardDiff in Julia
我正在尝试在一个库中使用 ForwardDiff,其中几乎所有函数都被限制为只能接收浮点数。我想概括这些函数签名,以便可以使用 ForwardDiff,同时仍然具有足够的限制性,因此函数只采用数值而不是日期之类的东西。我有很多具有相同名称但类型不同的函数(即函数将 "time" 作为 float 或具有相同函数名称的 Date 接受)并且不想从头到尾删除类型限定符。
最小工作示例
using ForwardDiff
x = [1.0, 2.0, 3.0, 4.0 ,5.0]
typeof(x) # Array{Float64,1}
function G(x::Array{Real,1})
return sum(exp.(x))
end
function grad_F(x::Array)
return ForwardDiff.gradient(G, x)
end
G(x) # Method Error
grad_F(x) # Method error
function G(x::Array{Float64,1})
return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This has a method error
function G(x)
return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This works
# But now I cannot restrict the function G to only take numeric arrays and not for instance arrays of Dates.
有没有办法限制函数只接受数值(整数和浮点数)以及 ForwardDiff 使用但不允许符号、日期等的任何双数结构
ForwardDiff.Dual
是抽象类型 Real
的子类型。但是,您遇到的问题是 Julia 的类型参数是不变的,而不是协变的。下面的话,returns false.
# check if `Array{Float64, 1}` is a subtype of `Array{Real, 1}`
julia> Array{Float64, 1} <: Array{Real, 1}
false
这使得你的函数定义
function G(x::Array{Real,1})
return sum(exp.(x))
end
不正确(不适合您使用)。这就是您收到以下错误的原因。
julia> G(x)
ERROR: MethodError: no method matching G(::Array{Float64,1})
正确的定义应该是
function G(x::Array{<:Real,1})
return sum(exp.(x))
end
或者如果您需要以某种方式轻松访问数组的具体元素类型
function G(x::Array{T,1}) where {T<:Real}
return sum(exp.(x))
end
您的 grad_F
功能也是如此。
您可能会发现阅读 the relevant section 的 Julia 类型文档很有用。
您可能还想为 AbstractArray{<:Real,1}
类型而不是 Array{<:Real, 1}
类型注释您的函数,以便您的函数可以处理其他类型的数组,例如 StaticArrays
、OffsetArrays
等,无需重新定义。
这将接受由任何类型的数字参数化的任何类型的数组:
function foo(xs::AbstractArray{<:Number})
@show typeof(xs)
end
或:
function foo(xs::AbstractArray{T}) where T<:Number
@show typeof(xs)
end
如果你需要引用函数体中的类型参数T
。
x1 = [1.0, 2.0, 3.0, 4.0 ,5.0]
x2 = [1, 2, 3,4, 5]
x3 = 1:5
x4 = 1.0:5.0
x5 = [1//2, 1//4, 1//8]
xss = [x1, x2, x3, x4, x5]
function foo(xs::AbstractArray{T}) where T<:Number
@show xs typeof(xs) T
println()
end
for xs in xss
foo(xs)
end
输出:
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
typeof(xs) = Array{Float64,1}
T = Float64
xs = [1, 2, 3, 4, 5]
typeof(xs) = Array{Int64,1}
T = Int64
xs = 1:5
typeof(xs) = UnitRange{Int64}
T = Int64
xs = 1.0:1.0:5.0
typeof(xs) = StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}
T = Float64
xs = Rational{Int64}[1//2, 1//4, 1//8]
typeof(xs) = Array{Rational{Int64},1}
T = Rational{Int64}
您可以运行此处的示例代码:https://repl.it/@SalchiPapa/Restricting-function-signatures-in-Julia
我正在尝试在一个库中使用 ForwardDiff,其中几乎所有函数都被限制为只能接收浮点数。我想概括这些函数签名,以便可以使用 ForwardDiff,同时仍然具有足够的限制性,因此函数只采用数值而不是日期之类的东西。我有很多具有相同名称但类型不同的函数(即函数将 "time" 作为 float 或具有相同函数名称的 Date 接受)并且不想从头到尾删除类型限定符。
最小工作示例
using ForwardDiff
x = [1.0, 2.0, 3.0, 4.0 ,5.0]
typeof(x) # Array{Float64,1}
function G(x::Array{Real,1})
return sum(exp.(x))
end
function grad_F(x::Array)
return ForwardDiff.gradient(G, x)
end
G(x) # Method Error
grad_F(x) # Method error
function G(x::Array{Float64,1})
return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This has a method error
function G(x)
return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This works
# But now I cannot restrict the function G to only take numeric arrays and not for instance arrays of Dates.
有没有办法限制函数只接受数值(整数和浮点数)以及 ForwardDiff 使用但不允许符号、日期等的任何双数结构
ForwardDiff.Dual
是抽象类型 Real
的子类型。但是,您遇到的问题是 Julia 的类型参数是不变的,而不是协变的。下面的话,returns false.
# check if `Array{Float64, 1}` is a subtype of `Array{Real, 1}`
julia> Array{Float64, 1} <: Array{Real, 1}
false
这使得你的函数定义
function G(x::Array{Real,1})
return sum(exp.(x))
end
不正确(不适合您使用)。这就是您收到以下错误的原因。
julia> G(x)
ERROR: MethodError: no method matching G(::Array{Float64,1})
正确的定义应该是
function G(x::Array{<:Real,1})
return sum(exp.(x))
end
或者如果您需要以某种方式轻松访问数组的具体元素类型
function G(x::Array{T,1}) where {T<:Real}
return sum(exp.(x))
end
您的 grad_F
功能也是如此。
您可能会发现阅读 the relevant section 的 Julia 类型文档很有用。
您可能还想为 AbstractArray{<:Real,1}
类型而不是 Array{<:Real, 1}
类型注释您的函数,以便您的函数可以处理其他类型的数组,例如 StaticArrays
、OffsetArrays
等,无需重新定义。
这将接受由任何类型的数字参数化的任何类型的数组:
function foo(xs::AbstractArray{<:Number})
@show typeof(xs)
end
或:
function foo(xs::AbstractArray{T}) where T<:Number
@show typeof(xs)
end
如果你需要引用函数体中的类型参数T
。
x1 = [1.0, 2.0, 3.0, 4.0 ,5.0]
x2 = [1, 2, 3,4, 5]
x3 = 1:5
x4 = 1.0:5.0
x5 = [1//2, 1//4, 1//8]
xss = [x1, x2, x3, x4, x5]
function foo(xs::AbstractArray{T}) where T<:Number
@show xs typeof(xs) T
println()
end
for xs in xss
foo(xs)
end
输出:
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
typeof(xs) = Array{Float64,1}
T = Float64
xs = [1, 2, 3, 4, 5]
typeof(xs) = Array{Int64,1}
T = Int64
xs = 1:5
typeof(xs) = UnitRange{Int64}
T = Int64
xs = 1.0:1.0:5.0
typeof(xs) = StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}
T = Float64
xs = Rational{Int64}[1//2, 1//4, 1//8]
typeof(xs) = Array{Rational{Int64},1}
T = Rational{Int64}
您可以运行此处的示例代码:https://repl.it/@SalchiPapa/Restricting-function-signatures-in-Julia