我如何断言 Tensor 是 0 到 N 个“True”,后跟 0 到 N 个“False”的序列?

How can I assert that a Tensor is a sequence of 0 to N `True`s, followed by 0 to N `False`s?

如何断言 Tensor 如下所示(例如):

[ True, True ]
[ True, True, True, False, False, False, False ]
[ True, True, True, False, False ]
[ True, False, False, False, False ]
[ False, False ]

但拒绝这样的输入:

[ True, False, True, False, False, True, False ]
[ False, False, False, False, True ]

或者更笼统地说:我想测试张量是否仅由一系列 0 到 N 值的 True 组成,然后是 0 到 N 值的 False。我如何使用 Tensorflow 2 做到这一点?

这是您可以做到的一种方法:

import tensorflow as tf

def is_valid(a):
    # a is assumed to be a 1D boolean array
    a = tf.convert_to_tensor(a)
    # Convert to integer
    a_int = tf.dtypes.cast(a, tf.int32)
    # Take pairwise differences
    diff = a_int[1:] - a_int[:-1]
    # Check all differences are zero or negative (no transitions from False to True)
    return tf.reduce_all(diff <= 0)

# Valid examples
tf.print(is_valid([ True, True ]))
# 1
tf.print(is_valid([ True, True, True, False, False, False, False ]))
# 1
tf.print(is_valid([ True, True, True, False, False ]))
# 1
tf.print(is_valid([ True, False, False, False, False ]))
# 1
tf.print(is_valid([ False, False ]))
# 1

# Invalid examples
tf.print(is_valid([ True, False, True, False, False, True, False ]))
# 0
tf.print(is_valid([ False, False, False, False, True ]))
# 0

注意:is_valid returns 标量布尔张量,即使 tf.print 将其打印为整数。

另一种方法,研究元素的索引:

import tensorflow as tf

def is_valid(t):
  where_false = tf.where(~t)
  return len(where_false) == 0 or all( idx_true < min(where_false) for idx_true in tf.where(t))

assert is_valid(tf.constant([ True, True ]))
assert is_valid(tf.constant([ True, True, True, False, False, False, False ]))
assert is_valid(tf.constant([ True, True, True, False, False ]))
assert is_valid(tf.constant([ True, False, False, False, False ]))
assert is_valid(tf.constant([ False, False ]))
assert not is_valid(tf.constant([ True, False, True, False, False, True, False ]))
assert not is_valid(tf.constant([ False, False, False, False, True ]))

想法是,所有 True 值都应出现在第一个 False 之前(如果存在的话)。