检查二维子数组是否有序 - Python JAX

Check if 2D sub-array is ordered - Pyhthon JAX

让我们假设我们有一个数组ordered。我们要检查子数组 t t_inv 是否遵循与 order 数组中强加顺序相同的顺序。

从左到右阅读:第一个元素是[0,0],依此类推,直到[0,3]t_inv 是反转的,因为交换了第一个元素,它们不遵循 ordered.

中的顺序
# imposed order 
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# array with permuted order
t = jnp.array([[[0, 0],[0, 1], [0,3]]])
t_inv = jnp.array([[[0, 1],[0, 0], [0,3]]])

我期望如下:

 result: ordered(t) = 1, because "ordered"  
and ordered(t_inv) = -1, because "swapped/not ordered"

如何检查子数组确实是有序数组的一部分并输出顺序是否正确?

你可以这样做:

import jax.numpy as jnp

# imposed order 
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# array with permuted order
t = jnp.array([[0, 0],[0, 1], [0,3]])
t_inv = jnp.array([[0, 1],[0, 0], [0,3]])


def is_sorted(t, ordered):
  index = jnp.where((t[:, None] == ordered).all(-1))[1]
  return jnp.where((index == jnp.sort(index)).all(), 1, -1)

print(is_sorted(t, ordered))
# 1
print(is_sorted(t_inv, ordered))
# -1

Scaling-wise,使用基于searchsorted的解决方案可能会更快,但是目前JAX中jnp.searchsorted的实现相对较慢,因为XLA没有任何native二进制搜索算法,因此在实践中完整的成对比较通常可以更高效。