JAX - 功能区分问题
JAX - Problem in differentiating of function
我正在尝试对调用执行蒙特卡洛模拟,然后在 Python 中计算其关于标的资产的一阶导数,但它仍然不起作用
from jax import random
from jax import jit, grad, vmap
import jax.numpy as jnp
xi = jnp.linspace(1,1.2,5)
def Simulation(xi):
K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
mean = -.5 * sigma * sigma * T
volatility = sigma*jnp.sqrt(T)
r_numb = random.PRNGKey(10)
BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))
product = S*jnp.exp(BM)
payoff = jnp.maximum(product-K,0)
result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
return result
first_derivative = vmap(grad(Simulation))(xi)
我不知道实现算法的方式是否是使用“AD 方法”计算导数的最佳方式;该算法以这种方式工作:
S = 模拟包含所有底层证券的矩阵;对于每一行,我都使用“xi = jnp.linspace”生成每个底层证券,并且在矩阵的每一行内,我有相同的值多次等于“number_sim”
product = 生成 BM(包含正态数的向量)后,我需要将 BM 的每个元素相乘(带 exp)与 S
的每一行的每个元素
所以这是对算法的简短解释,我非常感谢任何解决这个问题的建议或技巧,并用 AD 方法计算导数!
提前致谢
您的函数似乎映射了一个向量 Rᴺ→Rᴺ。在这种情况下,有两个有意义的导数概念:逐元素导数(在 JAX 中,您可以通过组合 jax.vmap
和 jax.grad
来计算)。这将 return 一个长度为 N 的导数向量,其中元素 i 包含 i[=第 36=] 个输出相对于第 i 个输入。
或者,您可以计算雅可比矩阵(使用 jax.jacobian
),它将 return 形状 [N, N]
矩阵,其中元素 i,j 包含第 i 个输出相对于第 j 个输入的导数。
您遇到的问题是,您的函数是在假定矢量输入的情况下编写的(您要求 xi
的长度),这意味着您对雅可比矩阵感兴趣,但您要求的是逐元素导数,需要标量值函数。
所以你有两种可能的方法来解决这个问题,这取决于你对什么导数感兴趣。如果你对雅可比矩阵感兴趣,你可以使用所写的函数并使用 jax.jacobian
转换:
from jax import jacobian
print(jacobian(Simulation)(xi))
# [[0.6528027 0. 0. 0. 0. ]
# [0. 0.6819291 0. 0. 0. ]
# [0. 0. 0.7003516 0. 0. ]
# [0. 0. 0. 0.7181915 0. ]
# [0. 0. 0. 0. 0.7608434]]
或者,如果您对逐元素梯度感兴趣,您可以重写您的函数以与标量输入兼容,并像示例中那样使用 grad 的 vmap。只需更改两行:
def Simulation_scalar(xi):
K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
# S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
S = jnp.broadcast_to(xi,(number_sim,) + xi.shape).T
mean = -.5 * sigma * sigma * T
volatility = sigma*jnp.sqrt(T)
r_numb = random.PRNGKey(10)
BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))
product = S*jnp.exp(BM)
payoff = jnp.maximum(product-K,0)
# result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
result = jnp.average(payoff, axis=-1)*jnp.exp(-q*T)
return result
print(vmap(grad(Simulation_scalar))(xi))
# [0.6528027 0.6819291 0.7003516 0.7181915 0.7608434]
我正在尝试对调用执行蒙特卡洛模拟,然后在 Python 中计算其关于标的资产的一阶导数,但它仍然不起作用
from jax import random
from jax import jit, grad, vmap
import jax.numpy as jnp
xi = jnp.linspace(1,1.2,5)
def Simulation(xi):
K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
mean = -.5 * sigma * sigma * T
volatility = sigma*jnp.sqrt(T)
r_numb = random.PRNGKey(10)
BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))
product = S*jnp.exp(BM)
payoff = jnp.maximum(product-K,0)
result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
return result
first_derivative = vmap(grad(Simulation))(xi)
我不知道实现算法的方式是否是使用“AD 方法”计算导数的最佳方式;该算法以这种方式工作:
S = 模拟包含所有底层证券的矩阵;对于每一行,我都使用“xi = jnp.linspace”生成每个底层证券,并且在矩阵的每一行内,我有相同的值多次等于“number_sim”
product = 生成 BM(包含正态数的向量)后,我需要将 BM 的每个元素相乘(带 exp)与 S
的每一行的每个元素
所以这是对算法的简短解释,我非常感谢任何解决这个问题的建议或技巧,并用 AD 方法计算导数! 提前致谢
您的函数似乎映射了一个向量 Rᴺ→Rᴺ。在这种情况下,有两个有意义的导数概念:逐元素导数(在 JAX 中,您可以通过组合 jax.vmap
和 jax.grad
来计算)。这将 return 一个长度为 N 的导数向量,其中元素 i 包含 i[=第 36=] 个输出相对于第 i 个输入。
或者,您可以计算雅可比矩阵(使用 jax.jacobian
),它将 return 形状 [N, N]
矩阵,其中元素 i,j 包含第 i 个输出相对于第 j 个输入的导数。
您遇到的问题是,您的函数是在假定矢量输入的情况下编写的(您要求 xi
的长度),这意味着您对雅可比矩阵感兴趣,但您要求的是逐元素导数,需要标量值函数。
所以你有两种可能的方法来解决这个问题,这取决于你对什么导数感兴趣。如果你对雅可比矩阵感兴趣,你可以使用所写的函数并使用 jax.jacobian
转换:
from jax import jacobian
print(jacobian(Simulation)(xi))
# [[0.6528027 0. 0. 0. 0. ]
# [0. 0.6819291 0. 0. 0. ]
# [0. 0. 0.7003516 0. 0. ]
# [0. 0. 0. 0.7181915 0. ]
# [0. 0. 0. 0. 0.7608434]]
或者,如果您对逐元素梯度感兴趣,您可以重写您的函数以与标量输入兼容,并像示例中那样使用 grad 的 vmap。只需更改两行:
def Simulation_scalar(xi):
K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
# S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
S = jnp.broadcast_to(xi,(number_sim,) + xi.shape).T
mean = -.5 * sigma * sigma * T
volatility = sigma*jnp.sqrt(T)
r_numb = random.PRNGKey(10)
BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))
product = S*jnp.exp(BM)
payoff = jnp.maximum(product-K,0)
# result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
result = jnp.average(payoff, axis=-1)*jnp.exp(-q*T)
return result
print(vmap(grad(Simulation_scalar))(xi))
# [0.6528027 0.6819291 0.7003516 0.7181915 0.7608434]