JAX: time to jit a function grows superlinearly with memory accessed by function
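Here is a simple example that numerically integrates the product of two Gaussian pdfs. One of the Gaussians is fixed, with its mean always at 0. The other Gaussian has a varying mean: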
```python
import time
import jax.numpy as np
from jax import jit
from jax.scipy.stats.norm import pdf

# set up evaluation points for numerical integration
integr_resolution = 6400
lower_bound = -100
upper_bound = 100
integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution

# integrate with new mean
def integrate(mu_new):
    x_new = integr_grid - mu_new
    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)
    return total_proba

print('starting jit')
start = time.perf_counter()
integrate = jit(integrate)
integrate(1)
stop = time.perf_counter()
print('took: ', stop - start)
```
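The function seems simple, but it does not scale at all. The following list contains pairs of (value of `integr_resolution`, time it took to run the code):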
- 100 | 0.107s
- 200 | 0.23s
- 400 | 0.537s
- 800 | 1.52s
- 1600 | 5.2s
- 3200 | 19s
- 6400 | 134s
For reference, the uncompiled function, applied to `integr_resolution=6400`, takes 0.02s.
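I thought this might have to do with the function accessing global variables. But moving the code that sets up the integration points inside the function has no notable effect on the timing. The code below takes 5.36s to run. It corresponds to the table entry for 1600, which previously took 5.2s: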
```python
# integrate with new mean
def integrate(mu_new):
    # set up evaluation points for numerical integration
    integr_resolution = 1600
    lower_bound = -100
    upper_bound = 100
    integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
    proba = pdf(integr_grid)
    integration_weight = (upper_bound - lower_bound) / integr_resolution

    x_new = integr_grid - mu_new
    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)
    return total_proba
```
What is going on here?
I also answered this at https://github.com/google/jax/issues/1776, but I'm adding the answer here as well.
The issue is that the code uses `sum` where it should use `np.sum`.
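`sum` is a Python built-in that extracts each element of a sequence and adds them one by one using the `+` operator. Under `jit`, this builds a large, unrolled chain of additions that XLA takes a long time to compile.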
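To see the unrolling concretely, here is a minimal sketch that traces both kinds of reduction with `jax.make_jaxpr` and counts the primitive equations in each resulting jaxpr. The helper names `builtin_sum`, `reduction_sum` and `count_eqns` are hypothetical, and `jnp` is `jax.numpy` (imported as `np` in the question). Exact counts may differ between JAX versions; what matters is that the built-in `sum` count grows with `n` while the `jnp.sum` count does not.

```python
import jax
import jax.numpy as jnp


def builtin_sum(x):
    # Python's built-in sum iterates over x while tracing, so it emits
    # per-element indexing ops plus one `add` per element.
    return sum(x)


def reduction_sum(x):
    # jnp.sum traces to a single reduce_sum primitive, whatever len(x) is.
    return jnp.sum(x)


def count_eqns(f, n):
    """Number of primitive equations in the jaxpr of f traced on a length-n array."""
    closed = jax.make_jaxpr(f)(jnp.ones(n, dtype=jnp.float32))
    return len(closed.jaxpr.eqns)


for n in (10, 100, 1000):
    print(n, count_eqns(builtin_sum, n), count_eqns(reduction_sum, n))
```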
If you use `np.sum` instead, JAX builds a single XLA reduction operator, which compiles much faster.
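Applied to the question's code, the fix is a one-line change. Below is a sketch of the corrected function with the question's setup; the only differences are `jnp.sum` in place of `sum` and applying `jax.jit` as a decorator. `jnp` is the same `jax.numpy` module that the question imports as `np`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats.norm import pdf

# Same setup as in the question.
integr_resolution = 6400
lower_bound = -100
upper_bound = 100
integr_grid = jnp.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution


@jax.jit
def integrate(mu_new):
    x_new = integr_grid - mu_new
    proba_new = pdf(x_new)
    # jnp.sum traces to one reduce_sum, so the size of the compiled
    # program no longer depends on integr_resolution.
    return jnp.sum(proba * proba_new * integration_weight)


result = integrate(1.0)
print(result)
```

As a sanity check: the integral of the product of two unit-variance Gaussian pdfs whose means differ by `mu` is `exp(-mu**2 / 4) / sqrt(4 * pi)`, which is about 0.2197 for `mu_new = 1`.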
Just to show how I figured this out: I used `jax.make_jaxpr`, which dumps JAX's internal trace representation of a function. Here, it shows:
```
In [3]: import jax

In [4]: jax.make_jaxpr(integrate)(1)
Out[4]:
{ lambda b c ; ; a.
  let d = convert_element_type[ new_dtype=float32
                                old_dtype=int32 ] a
      e = sub c d
      f = sub e 0.0
      g = pow f 2.0
      h = div g 1.0
      i = add 1.8378770351409912 h
      j = neg i
      k = div j 2.0
      l = exp k
      m = mul b l
      n = mul m 2.0
      o = slice[ start_indices=(0,)
                 limit_indices=(1,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      p = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] o
      q = add p 0.0
      r = slice[ start_indices=(1,)
                 limit_indices=(2,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      s = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] r
      t = add q s
      u = slice[ start_indices=(2,)
                 limit_indices=(3,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      v = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] u
      w = add t v
      x = slice[ start_indices=(3,)
                 limit_indices=(4,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      y = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] x
      z = add w y
      ... similarly ...
```
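Then it's obvious why this is slow: the program is very big, with one slice, one reshape and one add for every element of the array.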
Contrast this with the `np.sum` version:
```
In [5]: def integrate(mu_new):
   ...:     x_new = integr_grid - mu_new
   ...:
   ...:     proba_new = pdf(x_new)
   ...:     total_proba = np.sum(proba * proba_new * integration_weight)
   ...:
   ...:     return total_proba
   ...:

In [6]: jax.make_jaxpr(integrate)(1)
Out[6]:
{ lambda b c ; ; a.
  let d = convert_element_type[ new_dtype=float32
                                old_dtype=int32 ] a
      e = sub c d
      f = sub e 0.0
      g = pow f 2.0
      h = div g 1.0
      i = add 1.8378770351409912 h
      j = neg i
      k = div j 2.0
      l = exp k
      m = mul b l
      n = mul m 2.0
      o = reduce_sum[ axes=(0,)
                      input_shape=(100,) ] n
  in [o] }
```
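The jaxprs show the difference in program size. To check that the extra time is spent compiling rather than running, one option is JAX's ahead-of-time API, which splits `jit` into a trace/lower step (`jax.jit(f).lower(...)`) and an XLA compile step (`.compile()`). This is a sketch; `make_integrand` and `time_stages` are hypothetical helpers that rebuild the question's integrand for a given resolution `n`.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.stats.norm import pdf


def make_integrand(n):
    # Hypothetical helper: the question's integrand at resolution n,
    # once with Python's built-in sum and once with jnp.sum.
    grid = jnp.linspace(-100.0, 100.0, n)
    proba = pdf(grid)
    weight = 200.0 / n

    def with_builtin_sum(mu):
        return sum(proba * pdf(grid - mu) * weight)

    def with_jnp_sum(mu):
        return jnp.sum(proba * pdf(grid - mu) * weight)

    return with_builtin_sum, with_jnp_sum


def time_stages(f, *args):
    """Return (trace+lower seconds, XLA compile seconds, compiled executable)."""
    t0 = time.perf_counter()
    lowered = jax.jit(f).lower(*args)
    t1 = time.perf_counter()
    compiled = lowered.compile()
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, compiled


for n in (100, 400):
    slow, fast = make_integrand(n)
    for name, f in (("sum", slow), ("jnp.sum", fast)):
        lower_s, compile_s, compiled = time_stages(f, 1.0)
        value = float(compiled(1.0))
        print(f"n={n:4d} {name:8s} lower={lower_s:.3f}s "
              f"compile={compile_s:.3f}s value={value:.4f}")
```

Both versions compute the same value; only the built-in `sum` rows should show lowering and compile times that grow with `n`. Absolute timings depend on the JAX version and hardware.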
Hope that helps!