Error when passing tf.keras.layers.Input into tf.strings.split
I'm new to TF, so this question may be naive. I'd like to understand why Test2 raises an error.
import tensorflow as tf

# Test 1
x = tf.constant([["aa,bbb,cc"], ["dd,,"]])
tf.strings.split(x, sep=',')
=> <tf.RaggedTensor [[[b'aa', b'bbb', b'cc']], [[b'dd', b'', b'']]]>
# Test2
x = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
tf.strings.split(x, sep=",")
...
ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'Cumsum:0' shape=(None,) dtype=int64>
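Some context for the error message (a sketch, not TensorFlow's own code): a `RaggedTensor` produced by `tf.strings.split` stores one flat list of tokens plus `row_splits`, the running (cumulative) sum of per-row token counts, and TensorFlow builds these partitions as int64 by default. The `Cumsum` tensor named in the `ValueError` is plausibly that int64 computation meeting an op that expected int32. The helper `split_to_ragged` below is hypothetical and only mimics the layout for a 1-D batch of strings:

```python
from itertools import accumulate


def split_to_ragged(rows, sep=","):
    """Hypothetical helper: mimic the (values, row_splits) layout of a RaggedTensor."""
    values = []
    row_lengths = []
    for s in rows:
        parts = s.split(sep)  # "dd,,".split(",") -> ['dd', '', ''], like tf.strings.split
        values.extend(parts)
        row_lengths.append(len(parts))
    # row_splits[i]:row_splits[i + 1] slices row i out of `values`;
    # it is a cumulative sum of the row lengths with a leading 0.
    row_splits = [0] + list(accumulate(row_lengths))
    return values, row_splits


values, row_splits = split_to_ragged(["aa,bbb,cc", "dd,,"])
print(values)      # ['aa', 'bbb', 'cc', 'dd', '', '']
print(row_splits)  # [0, 3, 6]
rows = [values[a:b] for a, b in zip(row_splits, row_splits[1:])]
print(rows)        # [['aa', 'bbb', 'cc'], ['dd', '', '']]
```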
IIUC, this should work as a simple model:
import tensorflow as tf
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
outputs = tf.strings.split(inputs, sep=",")
# ... your additional layers
model = tf.keras.Model(inputs, outputs)
x = tf.constant([["aa,bbb,cc"], ["dd,,"]])
print(model(x))
<tf.RaggedTensor [[[b'aa', b'bbb', b'cc']],
[[b'dd', b'', b'']]]>
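The `RaggedTensor` returned above keeps its row partitions (`row_splits`) as int64. If some downstream op insists on int32 partitions, `tf.RaggedTensor.with_row_splits_dtype` converts all nested partitions at once. A minimal sketch, assuming eager execution; whether this sidesteps the original `ValueError` likely depends on the TF version, which the question does not state:

```python
import tensorflow as tf

# Same input as Test 1: a (2, 1) string tensor
x = tf.constant([["aa,bbb,cc"], ["dd,,"]])
ragged = tf.strings.split(x, sep=",")

# Row partitions produced by tf.strings.split are int64 by default
print(ragged.row_splits.dtype)  # <dtype: 'int64'>

# Convert every nested row partition to int32 in one call
ragged32 = ragged.with_row_splits_dtype(tf.int32)
print(ragged32.row_splits.dtype)  # <dtype: 'int32'>
print(ragged32.to_list())  # [[[b'aa', b'bbb', b'cc']], [[b'dd', b'', b'']]]
```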
I assume your error is actually caused by a type mismatch after converting the strings to integers with a StringLookup or TextVectorization layer, because this runs:
x = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
tf.strings.split(x, sep=",")
# <KerasTensor: type_spec=RaggedTensorSpec(TensorShape([1, 1, None]), tf.string, 2, tf.int64) (created by layer 'tf.strings.split_2')>
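To illustrate the kind of mismatch suspected above (a sketch; the vocabulary is made up): Keras lookup layers typically return int64 ids in their integer output mode, so if a later op expects int32 it is safer to cast explicitly with `tf.cast` than to rely on implicit conversion. With default `StringLookup` settings the OOV index is 0 and the vocabulary starts at 1:

```python
import tensorflow as tf

# Hypothetical vocabulary, for illustration only
lookup = tf.keras.layers.StringLookup(vocabulary=["aa", "bbb", "cc", "dd"])

tokens = tf.constant([["aa"], ["zz"]])  # "zz" is out of vocabulary
ids = lookup(tokens)
print(ids.dtype)  # typically int64 for output_mode="int"

# Cast explicitly before feeding an op that expects int32
ids32 = tf.cast(ids, tf.int32)
print(ids32.numpy().tolist())  # [[1], [0]]
```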
Here is an example with a TextVectorization layer and an Embedding layer:
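import tensorflow as tf
x = tf.constant([["aa,bbb,cc"], ["dd,,"]])
# standardize=None keeps tokens unchanged; the default whitespace split leaves each single token intact
vectorizer_layer = tf.keras.layers.TextVectorization(standardize=None)
# The printed output below came from ['aa' 'bbb', ...]: the missing comma made Python concatenate
# 'aabbb', so 'aa' and 'bbb' both mapped to the OOV index (hence identical embeddings).
# Embedding weights are randomly initialized, so exact values differ between runs.
vectorizer_layer.adapt(['aa', 'bbb', 'cc', 'dd', ''])
embedding_layer = tf.keras.layers.Embedding(10, 5)
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
# Split on ",", drop the size-1 axis, and pad ragged rows with '' into a dense (batch, tokens) tensor
l = tf.keras.layers.Lambda(lambda x: tf.squeeze(tf.strings.split(x, sep=","), axis=1).to_tensor(), name='split')(inputs)
# Flatten to one token per row so each token is vectorized separately, then restore (batch, tokens)
outputs = vectorizer_layer(tf.reshape(l, (tf.reduce_prod(tf.shape(l)), 1)))
outputs = tf.reshape(outputs, tf.shape(l))
outputs = embedding_layer(outputs)
model = tf.keras.Model(inputs, outputs)
print(model(x))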
tf.Tensor(
[[[ 0.01325693 0.04536256 -0.04319841 -0.01714526 -0.03821129]
[ 0.01325693 0.04536256 -0.04319841 -0.01714526 -0.03821129]
[ 0.04633341 -0.02162067 0.04285793 -0.03961723 -0.04530612]]
[[ 0.04202544 -0.04769113 0.00436096 -0.04809079 -0.0097675 ]
[-0.00061743 -0.03051994 0.02737813 -0.04842547 0.03567551]
[-0.00061743 -0.03051994 0.02737813 -0.04842547 0.03567551]]], shape=(2, 3, 5), dtype=float32)
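The flatten, vectorize, reshape trick in this example can be mimicked in plain Python to see which token gets which index. The helper `vectorize_grid` below is hypothetical; it assumes TextVectorization's default layout (0 for padding, 1 for OOV, learned vocabulary from 2) and takes the vocabulary order as given, whereas the real layer orders it by frequency. It also shows why, in the output above, the first two vectors of the first row match (both tokens are OOV when the comma between 'aa' and 'bbb' is missing) and the last two vectors of the second row match (empty strings become padding):

```python
def vectorize_grid(grid, vocab):
    """Hypothetical helper mimicking flatten -> per-token lookup -> reshape."""
    # Assumed layout (TextVectorization defaults): 0 = padding, 1 = OOV, vocab from 2
    index = {tok: i + 2 for i, tok in enumerate(vocab)}
    n_rows, n_cols = len(grid), len(grid[0])
    # Like tf.reshape(l, (rows * cols, 1)): one token per row
    flat = [tok for row in grid for tok in row]
    # An empty string yields no tokens after splitting, so it ends up as padding (0)
    ids = [0 if tok == "" else index.get(tok, 1) for tok in flat]
    # Like tf.reshape(outputs, tf.shape(l)): back to (rows, cols)
    return [ids[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]


grid = [["aa", "bbb", "cc"], ["dd", "", ""]]
print(vectorize_grid(grid, ["aa", "bbb", "cc", "dd"]))  # [[2, 3, 4], [5, 0, 0]]
# With the comma missing, the vocabulary holds 'aabbb' instead of 'aa' and 'bbb'
print(vectorize_grid(grid, ["aabbb", "cc", "dd"]))      # [[1, 1, 3], [4, 0, 0]]
```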