scipy stats zmap 函数的替代方法
Alternative to scipy stats zmap function
zmap 函数的 scipy stats 模块是否有替代方案?我目前正在使用它来获取两个非常大的数组的 zmap 分数,这需要相当长的时间。
是否有任何库或替代品可以提高其性能?或者甚至是另一个获取 zmap 函数的功能?
我们将不胜感激您的想法和意见!
下面是我的最小可重现代码:
from scipy import stats
import numpy as np
FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
FeatureNorm= stats.zmap(FeatureData, goodData)
下面是 scipy stats.zmap 在幕后所做的事情:
def zmap(scores, compare, axis=0, ddof=0):
scores, compare = map(np.asanyarray, [scores, compare])
mns = compare.mean(axis=axis, keepdims=True)
sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
return (scores - mns) / sstd
关于如何针对我的用例优化它的任何想法?我可以使用像 numba 或 JAX 这样的库来进一步提升它吗?
幸运的是,zmap
代码非常简单。然而,numpy 的开销来自于它必须实例化中间数组这一事实。如果您使用 numba
或 jax
中可用的数值编译器,它可以融合这些操作并以更少的开销进行计算。
不幸的是,numba 不支持 mean
和 std
的可选参数,所以让我们看一下 JAX。作为参考,这里是 scipy 和函数的原始 numpy 版本的基准,在 Google Colab CPU 运行时计算:
import numpy as np
from scipy import stats
FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
%timeit stats.zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.9 ms per loop
def np_zmap(scores, compare, axis=0, ddof=0):
scores, compare = map(np.asanyarray, [scores, compare])
mns = compare.mean(axis=axis, keepdims=True)
sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
return (scores - mns) / sstd
%timeit np_zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.8 ms per loop
下面是在 JAX 中执行的等效代码,包括急切模式和 JIT 编译:
import jax.numpy as jnp
from jax import jit
def jnp_zmap(scores, compare, axis=0, ddof=0):
scores, compare = map(jnp.asarray, [scores, compare])
mns = compare.mean(axis=axis, keepdims=True)
sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
return (scores - mns) / sstd
jit_jnp_zmap = jit(jnp_zmap)
FeatureData = jnp.array(FeatureData)
goodData = jnp.array(goodData)
%timeit jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 8.59 ms per loop
jit_jnp_zmap(FeatureData, goodData) # trigger compilation
%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 2.78 ms per loop
JIT-compiled 版本比 scipy 或 numpy 代码快大约 5 倍。在 Colab T4 GPU 运行时上,编译后的版本又增加了 10 倍:
%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
1000 loops, best of 3: 286 µs per loop
如果这种操作是您分析中的瓶颈,那么像 JAX 这样的编译器可能是一个不错的选择。
zmap 函数的 scipy stats 模块是否有替代方案?我目前正在使用它来获取两个非常大的数组的 zmap 分数,这需要相当长的时间。
是否有任何库或替代品可以提高其性能?或者甚至是另一个获取 zmap 函数的功能?
我们将不胜感激您的想法和意见!
下面是我的最小可重现代码:
from scipy import stats
import numpy as np
FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
FeatureNorm= stats.zmap(FeatureData, goodData)
下面是 scipy stats.zmap 在幕后所做的事情:
def zmap(scores, compare, axis=0, ddof=0):
scores, compare = map(np.asanyarray, [scores, compare])
mns = compare.mean(axis=axis, keepdims=True)
sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
return (scores - mns) / sstd
关于如何针对我的用例优化它的任何想法?我可以使用像 numba 或 JAX 这样的库来进一步提升它吗?
幸运的是,zmap
代码非常简单。然而,numpy 的开销来自于它必须实例化中间数组这一事实。如果您使用 numba
或 jax
中可用的数值编译器,它可以融合这些操作并以更少的开销进行计算。
不幸的是,numba 不支持 mean
和 std
的可选参数,所以让我们看一下 JAX。作为参考,这里是 scipy 和函数的原始 numpy 版本的基准,在 Google Colab CPU 运行时计算:
import numpy as np
from scipy import stats
FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
%timeit stats.zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.9 ms per loop
def np_zmap(scores, compare, axis=0, ddof=0):
scores, compare = map(np.asanyarray, [scores, compare])
mns = compare.mean(axis=axis, keepdims=True)
sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
return (scores - mns) / sstd
%timeit np_zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.8 ms per loop
下面是在 JAX 中执行的等效代码,包括急切模式和 JIT 编译:
import jax.numpy as jnp
from jax import jit
def jnp_zmap(scores, compare, axis=0, ddof=0):
scores, compare = map(jnp.asarray, [scores, compare])
mns = compare.mean(axis=axis, keepdims=True)
sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
return (scores - mns) / sstd
jit_jnp_zmap = jit(jnp_zmap)
FeatureData = jnp.array(FeatureData)
goodData = jnp.array(goodData)
%timeit jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 8.59 ms per loop
jit_jnp_zmap(FeatureData, goodData) # trigger compilation
%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 2.78 ms per loop
JIT-compiled 版本比 scipy 或 numpy 代码快大约 5 倍。在 Colab T4 GPU 运行时上,编译后的版本又增加了 10 倍:
%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
1000 loops, best of 3: 286 µs per loop
如果这种操作是您分析中的瓶颈,那么像 JAX 这样的编译器可能是一个不错的选择。