检查二维子数组是否有序 - Python JAX
Check if 2D sub-array is ordered - Pyhthon JAX
让我们假设我们有一个数组ordered
。我们要检查子数组 t
和 t_inv
是否遵循与 order
数组中强加顺序相同的顺序。
从左到右阅读:第一个元素是[0,0]
,依此类推,直到[0,3]
。
t_inv
是反转的,因为交换了第一个元素,它们不遵循 ordered
.
中的顺序
# imposed order
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])
# array with permuted order
t = jnp.array([[[0, 0],[0, 1], [0,3]]])
t_inv = jnp.array([[[0, 1],[0, 0], [0,3]]])
我期望如下:
result: ordered(t) = 1, because "ordered"
and ordered(t_inv) = -1, because "swapped/not ordered"
如何检查子数组确实是有序数组的一部分并输出顺序是否正确?
你可以这样做:
import jax.numpy as jnp
# imposed order
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])
# array with permuted order
t = jnp.array([[0, 0],[0, 1], [0,3]])
t_inv = jnp.array([[0, 1],[0, 0], [0,3]])
def is_sorted(t, ordered):
index = jnp.where((t[:, None] == ordered).all(-1))[1]
return jnp.where((index == jnp.sort(index)).all(), 1, -1)
print(is_sorted(t, ordered))
# 1
print(is_sorted(t_inv, ordered))
# -1
Scaling-wise,使用基于searchsorted
的解决方案可能会更快,但是目前JAX中jnp.searchsorted
的实现相对较慢,因为XLA没有任何native二进制搜索算法,因此在实践中完整的成对比较通常可以更高效。
让我们假设我们有一个数组ordered
。我们要检查子数组 t
和 t_inv
是否遵循与 order
数组中强加顺序相同的顺序。
从左到右阅读:第一个元素是[0,0]
,依此类推,直到[0,3]
。
t_inv
是反转的,因为交换了第一个元素,它们不遵循 ordered
.
# imposed order
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])
# array with permuted order
t = jnp.array([[[0, 0],[0, 1], [0,3]]])
t_inv = jnp.array([[[0, 1],[0, 0], [0,3]]])
我期望如下:
result: ordered(t) = 1, because "ordered"
and ordered(t_inv) = -1, because "swapped/not ordered"
如何检查子数组确实是有序数组的一部分并输出顺序是否正确?
你可以这样做:
import jax.numpy as jnp
# imposed order
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])
# array with permuted order
t = jnp.array([[0, 0],[0, 1], [0,3]])
t_inv = jnp.array([[0, 1],[0, 0], [0,3]])
def is_sorted(t, ordered):
index = jnp.where((t[:, None] == ordered).all(-1))[1]
return jnp.where((index == jnp.sort(index)).all(), 1, -1)
print(is_sorted(t, ordered))
# 1
print(is_sorted(t_inv, ordered))
# -1
Scaling-wise,使用基于searchsorted
的解决方案可能会更快,但是目前JAX中jnp.searchsorted
的实现相对较慢,因为XLA没有任何native二进制搜索算法,因此在实践中完整的成对比较通常可以更高效。