将 tf.data.Dataset 包装成 tf.function 会提高性能吗?
Does wrapping tf.data.Dataset into tf.function improve performance?
鉴于下面的两个示例,用 tf.function(AutoGraph)转换 tf.data.Dataset 的迭代时是否有性能改进?
数据集不在 tf.function 中:
import tensorflow as tf
class MyModel(tf.keras.Model):
    """Minimal Keras model that multiplies its input by a 1x1 tensor of ones."""

    def call(self, inputs):
        """Return ``tf.ones([1, 1]) * inputs``."""
        return tf.ones([1, 1]) * inputs
# Two separate instances; the training step applies them one after the other.
model, model2 = MyModel(), MyModel()
@tf.function
def train_step(data):
    """Pass ``data`` through ``model`` and then ``model2``; return the result.

    Only this per-element step is wrapped in ``tf.function``.
    """
    output = model(data)
    output = model2(output)
    return output
dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))
# The loop is ordinary top-level Python; only train_step is a tf.function.
for data in dataset:
    train_step(data)
数据集在 tf.function 中:
import tensorflow as tf
class MyModel(tf.keras.Model):
    """Minimal Keras model that multiplies its input by a 1x1 tensor of ones."""

    def call(self, inputs):
        """Return ``tf.ones([1, 1]) * inputs``."""
        return tf.ones([1, 1]) * inputs
# Two separate instances; the training step applies them one after the other.
model, model2 = MyModel(), MyModel()
@tf.function
def train():
    """Create the dataset and run the whole loop inside a single tf.function."""
    dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))

    def train_step(data):
        # Not decorated itself: it is only called from within the decorated train().
        output = model(data)
        output = model2(output)
        return output

    for data in dataset:
        train_step(data)
# Dataset creation and iteration both happen inside the decorated function.
train()
添加 @tf.function 确实显著提高了速度。看看这个:
import tensorflow as tf
data = tf.random.normal((1000, 10, 10, 1))
# NOTE(review): from_tensors wraps `data` as ONE dataset element, so batch(10)
# yields a single batch here; from_tensor_slices may have been intended —
# confirm before generalizing from the timings below.
dataset = tf.data.Dataset.from_tensors(data)
dataset = dataset.batch(10)
def iterate_1(dataset):
    """Iterate over ``dataset`` in plain Python, doing no work per element.

    This is the undecorated variant in the timing comparison.
    """
    for x in dataset:
        x = x  # intentional no-op: only iteration overhead is measured
@tf.function
def iterate_2(dataset):
    """Iterate over ``dataset`` inside a ``tf.function``, doing no work per element.

    Same body as the undecorated variant; only the decorator differs.
    """
    for x in dataset:
        x = x  # intentional no-op: only iteration overhead is measured
# IPython/Jupyter `%timeit` magics (not plain Python); the trailing comments
# are the timings recorded by the author.
%timeit -n 1000 iterate_1(dataset) # 1.46 ms ± 8.2 µs per loop
%timeit -n 1000 iterate_2(dataset) # 239 µs ± 10.2 µs per loop
如您所见,使用 @tf.function 进行迭代的速度提高了 6 倍以上(1.46 ms 对 239 µs)。
鉴于下面的两个示例,用 tf.function(AutoGraph)转换 tf.data.Dataset 的迭代时是否有性能改进?
数据集不在 tf.function 中:
import tensorflow as tf
class MyModel(tf.keras.Model):
    """Minimal Keras model that multiplies its input by a 1x1 tensor of ones."""

    def call(self, inputs):
        """Return ``tf.ones([1, 1]) * inputs``."""
        return tf.ones([1, 1]) * inputs
# Two separate instances; the training step applies them one after the other.
model, model2 = MyModel(), MyModel()
@tf.function
def train_step(data):
    """Pass ``data`` through ``model`` and then ``model2``; return the result.

    Only this per-element step is wrapped in ``tf.function``.
    """
    output = model(data)
    output = model2(output)
    return output
dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))
# The loop is ordinary top-level Python; only train_step is a tf.function.
for data in dataset:
    train_step(data)
数据集在 tf.function 中:
import tensorflow as tf
class MyModel(tf.keras.Model):
    """Minimal Keras model that multiplies its input by a 1x1 tensor of ones."""

    def call(self, inputs):
        """Return ``tf.ones([1, 1]) * inputs``."""
        return tf.ones([1, 1]) * inputs
# Two separate instances; the training step applies them one after the other.
model, model2 = MyModel(), MyModel()
@tf.function
def train():
    """Create the dataset and run the whole loop inside a single tf.function."""
    dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))

    def train_step(data):
        # Not decorated itself: it is only called from within the decorated train().
        output = model(data)
        output = model2(output)
        return output

    for data in dataset:
        train_step(data)
# Dataset creation and iteration both happen inside the decorated function.
train()
添加 @tf.function 确实显著提高了速度。看看这个:
import tensorflow as tf
data = tf.random.normal((1000, 10, 10, 1))
# NOTE(review): from_tensors wraps `data` as ONE dataset element, so batch(10)
# yields a single batch here; from_tensor_slices may have been intended —
# confirm before generalizing from the timings below.
dataset = tf.data.Dataset.from_tensors(data)
dataset = dataset.batch(10)
def iterate_1(dataset):
    """Iterate over ``dataset`` in plain Python, doing no work per element.

    This is the undecorated variant in the timing comparison.
    """
    for x in dataset:
        x = x  # intentional no-op: only iteration overhead is measured
@tf.function
def iterate_2(dataset):
    """Iterate over ``dataset`` inside a ``tf.function``, doing no work per element.

    Same body as the undecorated variant; only the decorator differs.
    """
    for x in dataset:
        x = x  # intentional no-op: only iteration overhead is measured
# IPython/Jupyter `%timeit` magics (not plain Python); the trailing comments
# are the timings recorded by the author.
%timeit -n 1000 iterate_1(dataset) # 1.46 ms ± 8.2 µs per loop
%timeit -n 1000 iterate_2(dataset) # 239 µs ± 10.2 µs per loop
如您所见,使用 @tf.function 进行迭代的速度提高了 6 倍以上(1.46 ms 对 239 µs)。