Tensorflow 和 Keras:不同的输出形状取决于训练或推断
Tensorflow & Keras: Different output shape depends on training or inferring
我正在对 tensorflow.keras.Model
进行子类化以实现某个模型。预期行为:
- 训练(拟合)时间:returns包含最终输出和辅助输出的张量列表;
- 推断(预测)时间:returns 单个输出张量。
代码是:
class SomeModel(tensorflow.keras.Model):
# ......
def call(self, x, training=True):
# ......
return [aux1, aux2, net] if training else net
我是这样使用它的:
model=SomeModel(...)
model.compile(...,
loss=keras.losses.SparseCategoricalCrossentropy(),
loss_weights=[0.4, 0.4, 1],...)
# ......
model.fit(data, [labels, labels, labels])
并得到:
AssertionError: in converted code:
ipython-input-33-862e679ab098:140 call *
`return [aux1, aux2, net] if training else net`
...\tensorflow_core\python\autograph\operators\control_flow.py:918 if_stmt
那么问题是if
语句转换成计算图当然会出问题。我发现整个堆栈跟踪很长而且没有用,所以它没有包含在这里。
那么,有没有办法让TensorFlow根据training
生成不同的图呢?
您使用的是哪个 tensorflow 版本?您可以覆盖 Tensorflow 2.2 中 .fit、.predict 和 .evaluate 方法中的行为,这将为这些方法生成不同的图形(我假设)并且可能适用于您的用例。
早期版本的问题是子类模型是通过跟踪 call
方法创建的。这意味着 Python 条件语句变成了 Tensorflow 条件语句,并且在图形创建和执行过程中面临一些限制。
首先,必须定义两个分支 (if-else),并且对于 python 集合(例如列表),分支必须具有相同的结构(例如元素数量)。您可以阅读有关 Autograph here and here 的限制和效果的信息。
(此外,如果条件基于 Python 变量而不是张量,则条件可能不会在每个 运行 都被评估。)
我正在对 tensorflow.keras.Model
进行子类化以实现某个模型。预期行为:
- 训练(拟合)时间:returns包含最终输出和辅助输出的张量列表;
- 推断(预测)时间:returns 单个输出张量。
代码是:
class SomeModel(tensorflow.keras.Model):
# ......
def call(self, x, training=True):
# ......
return [aux1, aux2, net] if training else net
我是这样使用它的:
model=SomeModel(...)
model.compile(...,
loss=keras.losses.SparseCategoricalCrossentropy(),
loss_weights=[0.4, 0.4, 1],...)
# ......
model.fit(data, [labels, labels, labels])
并得到:
AssertionError: in converted code:
ipython-input-33-862e679ab098:140 call *
`return [aux1, aux2, net] if training else net`
...\tensorflow_core\python\autograph\operators\control_flow.py:918 if_stmt
那么问题是if
语句转换成计算图当然会出问题。我发现整个堆栈跟踪很长而且没有用,所以它没有包含在这里。
那么,有没有办法让TensorFlow根据training
生成不同的图呢?
您使用的是哪个 tensorflow 版本?您可以覆盖 Tensorflow 2.2 中 .fit、.predict 和 .evaluate 方法中的行为,这将为这些方法生成不同的图形(我假设)并且可能适用于您的用例。
早期版本的问题是子类模型是通过跟踪 call
方法创建的。这意味着 Python 条件语句变成了 Tensorflow 条件语句,并且在图形创建和执行过程中面临一些限制。
首先,必须定义两个分支 (if-else),并且对于 python 集合(例如列表),分支必须具有相同的结构(例如元素数量)。您可以阅读有关 Autograph here and here 的限制和效果的信息。
(此外,如果条件基于 Python 变量而不是张量,则条件可能不会在每个 运行 都被评估。)