keras 中的一个热编码器
onehot encoder in keras
我想索引 keras 模型中输入实例的最大值(onehot 编码器)。我知道如何预处理它,但我想在模型中执行它。
inputs = [[ 1,2,3,5],
[ 2,4,6,3],
[ 0,7,1,2]]
argmax =[3,2,1]
onehot = [[0,0,0,1],
[0,0,1,0],
[0,1,0,0]]
concated= [[ 1,2,3,5, 0,0,0,1],
[ 2,4,6,3, 0,0,1,0],
[ 0,7,1,2, 0,1,0,0]]
inputs = Input()
argmax = K.argmax(inputs, axis=-1)
onehot = #How to perform this operation with keras ??#
concated = Concat([inputs,onehot], axis=-1)
已编辑
是否有可能像 python 中的代码?
argmax = [3,2,1]
eye = np.eye(4)
eye[argmax]
Out[6]:
array([[0., 0., 0., 1.],
[0., 0., 1., 0.],
[0., 1., 0., 0.]])
您可以尝试这样的操作:
import tensorflow as tf
x = tf.constant([[
[ 1,2,3,5],
[ 2,4,6,3],
[ 0,7,1,2]],
[
[ 9,2,3,5],
[ 2,2,2,3],
[ 0,1,5,2]]
])
inputs = tf.keras.layers.Input((3, 4))
argmax = tf.argmax(inputs, axis=-1)
ij = tf.stack(tf.meshgrid(
tf.range(tf.shape(inputs)[0], dtype=tf.int64),
tf.range(tf.shape(inputs)[1], dtype=tf.int64),
indexing='ij'), axis=-1)
gather_indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
onehot = tf.tensor_scatter_nd_update(tf.zeros_like(x, dtype=tf.float32), gather_indices, tf.ones_like(argmax, dtype=tf.float32))
outputs = tf.keras.layers.Concatenate()([inputs,onehot])
model = tf.keras.Model(inputs, outputs)
print(model(x))
tf.Tensor(
[[[1. 2. 3. 5. 0. 0. 0. 1.]
[2. 4. 6. 3. 0. 0. 1. 0.]
[0. 7. 1. 2. 0. 1. 0. 0.]]
[[9. 2. 3. 5. 1. 0. 0. 0.]
[2. 2. 2. 3. 0. 0. 0. 1.]
[0. 1. 5. 2. 0. 0. 1. 0.]]], shape=(2, 3, 8), dtype=float32)
关于 tf.meshgrid
是什么以及为什么需要它的解释可以在 中找到。您可以考虑将这些操作包装在自定义 Keras
层中。
如果您使用 tf.reduce_max
而不是 tf.argmax
,那就简单多了:
import tensorflow as tf
x = tf.constant([[
[ 1,2,3,5],
[ 2,4,6,3],
[ 0,7,1,2]],
[
[ 9,2,3,5],
[ 2,2,2,3],
[ 0,1,5,2]]
])
inputs = tf.keras.layers.Input((3, 4))
onehot = tf.where(tf.not_equal(inputs, tf.reduce_max(inputs, keepdims=True, axis=-1)), tf.zeros_like(inputs), tf.ones_like(inputs))
outputs = tf.keras.layers.Concatenate()([inputs,onehot])
model = tf.keras.Model(inputs, outputs)
print(model(x))
tf.Tensor(
[[[1. 2. 3. 5. 0. 0. 0. 1.]
[2. 4. 6. 3. 0. 0. 1. 0.]
[0. 7. 1. 2. 0. 1. 0. 0.]]
[[9. 2. 3. 5. 1. 0. 0. 0.]
[2. 2. 2. 3. 0. 0. 0. 1.]
[0. 1. 5. 2. 0. 0. 1. 0.]]], shape=(2, 3, 8), dtype=float32)
但我想你想使用 argmax
是有原因的。
我想索引 keras 模型中输入实例的最大值(onehot 编码器)。我知道如何预处理它,但我想在模型中执行它。
inputs = [[ 1,2,3,5],
[ 2,4,6,3],
[ 0,7,1,2]]
argmax =[3,2,1]
onehot = [[0,0,0,1],
[0,0,1,0],
[0,1,0,0]]
concated= [[ 1,2,3,5, 0,0,0,1],
[ 2,4,6,3, 0,0,1,0],
[ 0,7,1,2, 0,1,0,0]]
inputs = Input()
argmax = K.argmax(inputs, axis=-1)
onehot = #How to perform this operation with keras ??#
concated = Concat([inputs,onehot], axis=-1)
已编辑 是否有可能像 python 中的代码?
argmax = [3,2,1]
eye = np.eye(4)
eye[argmax]
Out[6]:
array([[0., 0., 0., 1.],
[0., 0., 1., 0.],
[0., 1., 0., 0.]])
您可以尝试这样的操作:
import tensorflow as tf
x = tf.constant([[
[ 1,2,3,5],
[ 2,4,6,3],
[ 0,7,1,2]],
[
[ 9,2,3,5],
[ 2,2,2,3],
[ 0,1,5,2]]
])
inputs = tf.keras.layers.Input((3, 4))
argmax = tf.argmax(inputs, axis=-1)
ij = tf.stack(tf.meshgrid(
tf.range(tf.shape(inputs)[0], dtype=tf.int64),
tf.range(tf.shape(inputs)[1], dtype=tf.int64),
indexing='ij'), axis=-1)
gather_indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
onehot = tf.tensor_scatter_nd_update(tf.zeros_like(x, dtype=tf.float32), gather_indices, tf.ones_like(argmax, dtype=tf.float32))
outputs = tf.keras.layers.Concatenate()([inputs,onehot])
model = tf.keras.Model(inputs, outputs)
print(model(x))
tf.Tensor(
[[[1. 2. 3. 5. 0. 0. 0. 1.]
[2. 4. 6. 3. 0. 0. 1. 0.]
[0. 7. 1. 2. 0. 1. 0. 0.]]
[[9. 2. 3. 5. 1. 0. 0. 0.]
[2. 2. 2. 3. 0. 0. 0. 1.]
[0. 1. 5. 2. 0. 0. 1. 0.]]], shape=(2, 3, 8), dtype=float32)
关于 tf.meshgrid
是什么以及为什么需要它的解释可以在 Keras
层中。
如果您使用 tf.reduce_max
而不是 tf.argmax
,那就简单多了:
import tensorflow as tf
x = tf.constant([[
[ 1,2,3,5],
[ 2,4,6,3],
[ 0,7,1,2]],
[
[ 9,2,3,5],
[ 2,2,2,3],
[ 0,1,5,2]]
])
inputs = tf.keras.layers.Input((3, 4))
onehot = tf.where(tf.not_equal(inputs, tf.reduce_max(inputs, keepdims=True, axis=-1)), tf.zeros_like(inputs), tf.ones_like(inputs))
outputs = tf.keras.layers.Concatenate()([inputs,onehot])
model = tf.keras.Model(inputs, outputs)
print(model(x))
tf.Tensor(
[[[1. 2. 3. 5. 0. 0. 0. 1.]
[2. 4. 6. 3. 0. 0. 1. 0.]
[0. 7. 1. 2. 0. 1. 0. 0.]]
[[9. 2. 3. 5. 1. 0. 0. 0.]
[2. 2. 2. 3. 0. 0. 0. 1.]
[0. 1. 5. 2. 0. 0. 1. 0.]]], shape=(2, 3, 8), dtype=float32)
但我想你想使用 argmax
是有原因的。