如何在 TF2 的 call() 函数中获取 batch_size?
How to get batch_size in call() function in TF2?
我正在尝试在 TF2 模型的 call()
函数中获取 batch_size
。
但是,我无法得到它,因为我知道的所有方法 returns None
或张量而不是维度元组。
这是一个简短的例子
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
def call(self, x):
print(len(x))
print(x.shape)
print(tf.size(x))
print(np.shape(x))
print(x.get_shape())
print(x.get_shape().as_list())
print(tf.rank(x))
print(tf.shape(x))
print(tf.shape(x)[0])
print(tf.shape(x)[1])
return tf.random.uniform((2, 10))
m = MyModel()
m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
m.fit(np.array([[1,2,3,4], [5,6,7,8]]), np.array([0, 1]), epochs=1)
输出为:
Tensor("my_model_26/strided_slice:0", shape=(), dtype=int32)
(None, 4)
Tensor("my_model_26/Size:0", shape=(), dtype=int32)
(None, 4)
(None, 4)
[None, 4]
Tensor("my_model_26/Rank:0", shape=(), dtype=int32)
Tensor("my_model_26/Shape_2:0", shape=(2,), dtype=int32)
Tensor("my_model_26/strided_slice_1:0", shape=(), dtype=int32)
Tensor("my_model_26/strided_slice_2:0", shape=(), dtype=int32)
1/1 [==============================] - 0s 1ms/step - loss: 3.1796 - accuracy: 0.0000e+00
在此示例中,我将 (2,4)
numpy 数组作为输入,将 (2, )
作为目标提供给模型。
但是如您所见,我无法在 call()
函数中获取 batch_size
。
我需要它的原因是因为我必须为 batch_size
迭代张量,这在我的真实模型中是动态的。
例如,如果数据集大小为 10,批量大小为 3,则最后一批中的最后一批大小将为 1。因此,我必须动态知道批量大小。
谁能帮帮我?
- 张量流 2.3.3
- CUDA 10.2
- python 3.6.9
如果你想得到准确的数据和形状,你可以转为 eager 运行 true,但这不是一个好的解决方案,因为它会使训练变慢。
这样设置:
m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy",
metrics=['accuracy'], run_eagerly=True)
那么输出将是:
(2, 4)
tf.Tensor(8, shape=(), dtype=int32)
(2, 4)
(2, 4)
[2, 4]
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor([2 4], shape=(2,), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
这是因为您正在使用 TensorFlow(这是强制性的,因为 Keras 现在在 TensorFlow 中),并且通过使用 TensorFlow,您需要了解将动态图“编译”为静态图的过程。
简而言之,您的 call
方法(在幕后)用 @tf.function
装饰器装饰。
这个装饰者:
- 跟踪 python 函数执行
- 转换 TensorFlow 操作中的 python 操作(例如
if a > b
变为 tf.cond(tf.greater(a,b), something, something_else)
)
- 创建一个
tf.Graph
(静态图)
- 执行刚刚创建的静态图。
所有 print
调用都在第一步执行(python 执行跟踪),这就是为什么即使您训练模型也只能看到输出一次。
也就是说,要获得张量的运行时间(动态形状),您必须使用 tf.shape(x)
,批量大小仅为 batch_size = tf.shape(x)[0]
请注意,如果你想看到形状(使用打印)你不能使用打印,但你必须使用tf.print
。
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
def call(self, x):
shape = tf.shape(x)
batch_size = shape[0]
tf.print(shape, batch_size)
return tf.random.uniform((2, 10))
m = MyModel()
m.compile(
optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
m.fit(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), np.array([0, 1]), epochs=1)
有关静态和动态形状的更多信息:https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/
有关 tf.function 行为的更多信息:https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/
注:这些文章是我写的。
我正在尝试在 TF2 模型的 call()
函数中获取 batch_size
。
但是,我无法得到它,因为我知道的所有方法 returns None
或张量而不是维度元组。
这是一个简短的例子
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
def call(self, x):
print(len(x))
print(x.shape)
print(tf.size(x))
print(np.shape(x))
print(x.get_shape())
print(x.get_shape().as_list())
print(tf.rank(x))
print(tf.shape(x))
print(tf.shape(x)[0])
print(tf.shape(x)[1])
return tf.random.uniform((2, 10))
m = MyModel()
m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
m.fit(np.array([[1,2,3,4], [5,6,7,8]]), np.array([0, 1]), epochs=1)
输出为:
Tensor("my_model_26/strided_slice:0", shape=(), dtype=int32)
(None, 4)
Tensor("my_model_26/Size:0", shape=(), dtype=int32)
(None, 4)
(None, 4)
[None, 4]
Tensor("my_model_26/Rank:0", shape=(), dtype=int32)
Tensor("my_model_26/Shape_2:0", shape=(2,), dtype=int32)
Tensor("my_model_26/strided_slice_1:0", shape=(), dtype=int32)
Tensor("my_model_26/strided_slice_2:0", shape=(), dtype=int32)
1/1 [==============================] - 0s 1ms/step - loss: 3.1796 - accuracy: 0.0000e+00
在此示例中,我将 (2,4)
numpy 数组作为输入,将 (2, )
作为目标提供给模型。
但是如您所见,我无法在 call()
函数中获取 batch_size
。
我需要它的原因是因为我必须为 batch_size
迭代张量,这在我的真实模型中是动态的。
例如,如果数据集大小为 10,批量大小为 3,则最后一批中的最后一批大小将为 1。因此,我必须动态知道批量大小。
谁能帮帮我?
- 张量流 2.3.3
- CUDA 10.2
- python 3.6.9
如果你想得到准确的数据和形状,你可以转为 eager 运行 true,但这不是一个好的解决方案,因为它会使训练变慢。
这样设置:
m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy",
metrics=['accuracy'], run_eagerly=True)
那么输出将是:
(2, 4)
tf.Tensor(8, shape=(), dtype=int32)
(2, 4)
(2, 4)
[2, 4]
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor([2 4], shape=(2,), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
这是因为您正在使用 TensorFlow(这是强制性的,因为 Keras 现在在 TensorFlow 中),并且通过使用 TensorFlow,您需要了解将动态图“编译”为静态图的过程。
简而言之,您的 call
方法(在幕后)用 @tf.function
装饰器装饰。
这个装饰者:
- 跟踪 python 函数执行
- 转换 TensorFlow 操作中的 python 操作(例如
if a > b
变为tf.cond(tf.greater(a,b), something, something_else)
) - 创建一个
tf.Graph
(静态图) - 执行刚刚创建的静态图。
所有 print
调用都在第一步执行(python 执行跟踪),这就是为什么即使您训练模型也只能看到输出一次。
也就是说,要获得张量的运行时间(动态形状),您必须使用 tf.shape(x)
,批量大小仅为 batch_size = tf.shape(x)[0]
请注意,如果你想看到形状(使用打印)你不能使用打印,但你必须使用tf.print
。
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
def call(self, x):
shape = tf.shape(x)
batch_size = shape[0]
tf.print(shape, batch_size)
return tf.random.uniform((2, 10))
m = MyModel()
m.compile(
optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
m.fit(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), np.array([0, 1]), epochs=1)
有关静态和动态形状的更多信息:https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/
有关 tf.function 行为的更多信息:https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/
注:这些文章是我写的。