索引 Keras 张量
Indexing a Keras Tensor
我的 Keras 函数模型的输出层是一个维度 (None, 1344, 2)
的张量 x
。我希望从 x
的第二维中提取 n < 1344
个条目,并创建一个大小为 (None, n, 2)
的新张量 y
。
通过简单地访问 x[:, :n,:]
来提取 n
连续的条目似乎很简单,但是如果 n
索引是不连续的,(看起来)很难。 Keras 有一种干净的方法吗?
这是我目前的方法。
实验 1(切片张量,连续索引,有效):
print('My tensor shape is', K.int_shape(x)) #my tensor
(None, 1344, 2) # as printed in my code
print('Slicing first 5 entries, shape is', K.int_shape(x[:, :5, :]))
(None, 5, 2) # as printed in my code, works!
实验 2(在任意索引处索引张量,失败)
print('My tensor shape is', K.int_shape(x)) #my tensor
(None, 1344, 2) # as printed in my code
foo = np.array([1,2,4,5,8])
print('arbitrary indexing, shape is', K.int_shape(x[:,foo,:]))
Keras returns出现如下错误:
ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'strided_slice_17/stack_1' (op:
'Pack') with input shapes: [], [5], [].
实验3(张量流后端函数)
我也尝试过 K.backend.gather
但它的用法不清楚,因为 1) Keras 文档指出索引应该是整数张量,如果我的目标是提取 numpy.where
中的条目,则没有 Keras 等价物=13=] 满足特定条件和 2) K.backend.gather
似乎从 axis = 0
中提取条目,而我想从 x
.
的第二个维度中提取条目
您正在寻找 tf.gather_nd,它将根据索引数组进行索引:
# From documentation
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
要在 Keras 模型中使用它,请确保将其包裹在像 Lambda
这样的层中。
我的 Keras 函数模型的输出层是一个维度 (None, 1344, 2)
的张量 x
。我希望从 x
的第二维中提取 n < 1344
个条目,并创建一个大小为 (None, n, 2)
的新张量 y
。
通过简单地访问 x[:, :n,:]
来提取 n
连续的条目似乎很简单,但是如果 n
索引是不连续的,(看起来)很难。 Keras 有一种干净的方法吗?
这是我目前的方法。
实验 1(切片张量,连续索引,有效):
print('My tensor shape is', K.int_shape(x)) #my tensor
(None, 1344, 2) # as printed in my code
print('Slicing first 5 entries, shape is', K.int_shape(x[:, :5, :]))
(None, 5, 2) # as printed in my code, works!
实验 2(在任意索引处索引张量,失败)
print('My tensor shape is', K.int_shape(x)) #my tensor
(None, 1344, 2) # as printed in my code
foo = np.array([1,2,4,5,8])
print('arbitrary indexing, shape is', K.int_shape(x[:,foo,:]))
Keras returns出现如下错误:
ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'strided_slice_17/stack_1' (op:
'Pack') with input shapes: [], [5], [].
实验3(张量流后端函数)
我也尝试过 K.backend.gather
但它的用法不清楚,因为 1) Keras 文档指出索引应该是整数张量,如果我的目标是提取 numpy.where
中的条目,则没有 Keras 等价物=13=] 满足特定条件和 2) K.backend.gather
似乎从 axis = 0
中提取条目,而我想从 x
.
您正在寻找 tf.gather_nd,它将根据索引数组进行索引:
# From documentation
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
要在 Keras 模型中使用它,请确保将其包裹在像 Lambda
这样的层中。