为什么将 KerasTensor 传递给 Sequential.add 不会引发 TypeError?
Why does passing a KerasTensor to Sequential.add not raise a TypeError?
在 keras.Sequential.add
的 definition 中,我们有(根据文档)
if isinstance(layer, tf.Module):
if not isinstance(layer, base_layer.Layer):
layer = functional.ModuleWrapper(layer)
else:
raise TypeError('The added layer must be an instance of class Layer. '
f'Received: layer={layer} of type {type(layer)}.')
然而,当我执行
model = keras.Sequential()
model.add(keras.Input(100))
a TypeError
未引发。 keras.Input
returns 张量,KerasTensor
的 definition 表明它继承自 Object
,而不是 Layer
。
为什么我可以将 Input
添加到 Sequential
,而不是像我期望的那样被要求添加 InputLayer
?
好问题,如果你看一下 source code,我们读到:
If we are passed a Keras tensor created by keras.Input(), we can
extract the input layer from its keras history and use that without any loss of generality.
似乎 keras.Input
正在内部转换为 InputLayer
here:
if hasattr(layer, '_keras_history'):
origin_layer = layer._keras_history[0]
if isinstance(origin_layer, input_layer.InputLayer):
layer = origin_layer
这也可以用这个片段来验证:
inputs = tf.keras.Input(100)
print(inputs._keras_history[0])
<keras.engine.input_layer.InputLayer object at 0x7f9379a0cdd0>
这就是您没有看到任何错误的原因。
在 keras.Sequential.add
的 definition 中,我们有(根据文档)
if isinstance(layer, tf.Module):
if not isinstance(layer, base_layer.Layer):
layer = functional.ModuleWrapper(layer)
else:
raise TypeError('The added layer must be an instance of class Layer. '
f'Received: layer={layer} of type {type(layer)}.')
然而,当我执行
model = keras.Sequential()
model.add(keras.Input(100))
a TypeError
未引发。 keras.Input
returns 张量,KerasTensor
的 definition 表明它继承自 Object
,而不是 Layer
。
为什么我可以将 Input
添加到 Sequential
,而不是像我期望的那样被要求添加 InputLayer
?
好问题,如果你看一下 source code,我们读到:
If we are passed a Keras tensor created by keras.Input(), we can extract the input layer from its keras history and use that without any loss of generality.
似乎 keras.Input
正在内部转换为 InputLayer
here:
if hasattr(layer, '_keras_history'):
origin_layer = layer._keras_history[0]
if isinstance(origin_layer, input_layer.InputLayer):
layer = origin_layer
这也可以用这个片段来验证:
inputs = tf.keras.Input(100)
print(inputs._keras_history[0])
<keras.engine.input_layer.InputLayer object at 0x7f9379a0cdd0>
这就是您没有看到任何错误的原因。