如果使用两个相同的参数调用函数,Tensorflow `tf.function` 将失败
Tensorflow `tf.function` fails if function is called with two identical arguments
在我的 TF 模型中,我的 call
函数调用一个外部能量函数,该函数依赖于一个函数,其中单个参数被传递两次(参见下面的简化版本):
import tensorflow as tf
@tf.function
def calc_sw3(gamma,gamma2, cutoff_jk):
E3 = 2.0
return E3
@tf.function
def calc_sw3_noerr( gamma0, cutoff_jk):
E3 = 2.0
return E3
@tf.function # without tf.function this works fine
def energy(coords, gamma):
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
# E3 = calc_sw3_noerr( gamma, norm_rij) # this gives no error
return E3
class SWLayer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
self.gamma = tf.Variable(2.51412, dtype=tf.float32)
def call(self, coords_all):
total_conf_energy = energy( coords_all, self.gamma)
return total_conf_energy
# =============================================================================
SWL = SWLayer()
coords2 = tf.constant([[
1.9434, 1.0817, 1.0803,
2.6852, 2.7203, 1.0802,
1.3807, 1.3573, 1.3307]])
with tf.GradientTape() as tape:
tape.watch(coords2)
E = SWL( coords2)
此处如果 gamma 仅传递一次,或者如果我不使用 tf.function
装饰器。但是使用 tf.function
并两次传递相同的变量,我得到以下错误:
Traceback (most recent call last):
File "temp_tf.py", line 47, in <module>
E = SWL( coords2)
File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "temp_tf.py", line 34, in call
total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).
in user code:
File "temp_tf.py", line 22, in energy *
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
IndexError: list index out of range
Call arguments received:
• coords_all=tf.Tensor(shape=(1, 9), dtype=float32)
这是预期的行为吗?
有趣的问题!我认为错误源于回溯,这导致 tf.function 不止一次评估 energy
中的 python 片段。看到这个 issue. Also, this could be related to a bug.
几个观察结果:
1.从 calc_sw3
中删除 tf.function 装饰器有效并且与 docs:
一致
[...] tf.function applies to a function and all other functions it calls.
因此,如果您再次将 tf.function
显式应用到 calc_sw3
,您可能会触发回溯,但您可能想知道为什么 calc_sw3_noerr
有效?也就是一定和变量gamma
.
有关
2。将输入 signatures 添加到 energy
函数上方的 tf.function,同时将其余代码保持原样,也可以工作 :
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32), tf.TensorSpec(shape=None, dtype=tf.float32)])
def energy(coords, gamma):
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
E3 = calc_sw3(gamma, gamma, norm_rij)
return E3
这个方法:
[...] ensures only one ConcreteFunction is created, and restricts the GenericFunction to the specified shapes and types. It is an effective way to limit retracing when Tensors have dynamic shapes.
所以也许假设 gamma
每次都以不同的形状调用,从而触发回溯(只是一个假设)。触发错误的事实实际上是有意或故意设计的 here. Also another interesting comment:
tf.functions can only handle a pre defined input shape, if the shape changes, or if different python objects get passed, tensorflow automagically rebuilds the function
最后,为什么我觉得是tracing的问题呢?因为实际错误来自这部分代码片段:
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
您可以通过将其注释掉并将 norm_rij
替换为某个值然后调用 calc_sw3
来确认。它会起作用。
这意味着这段代码可能被执行了不止一次,可能 由于上述原因。这也有据可查 here:
In the first stage, referred to as "tracing", Function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.
In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage
在我的 TF 模型中,我的 call
函数调用一个外部能量函数,该函数依赖于一个函数,其中单个参数被传递两次(参见下面的简化版本):
import tensorflow as tf
@tf.function
def calc_sw3(gamma,gamma2, cutoff_jk):
E3 = 2.0
return E3
@tf.function
def calc_sw3_noerr( gamma0, cutoff_jk):
E3 = 2.0
return E3
@tf.function # without tf.function this works fine
def energy(coords, gamma):
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
# E3 = calc_sw3_noerr( gamma, norm_rij) # this gives no error
return E3
class SWLayer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
self.gamma = tf.Variable(2.51412, dtype=tf.float32)
def call(self, coords_all):
total_conf_energy = energy( coords_all, self.gamma)
return total_conf_energy
# =============================================================================
SWL = SWLayer()
coords2 = tf.constant([[
1.9434, 1.0817, 1.0803,
2.6852, 2.7203, 1.0802,
1.3807, 1.3573, 1.3307]])
with tf.GradientTape() as tape:
tape.watch(coords2)
E = SWL( coords2)
此处如果 gamma 仅传递一次,或者如果我不使用 tf.function
装饰器。但是使用 tf.function
并两次传递相同的变量,我得到以下错误:
Traceback (most recent call last):
File "temp_tf.py", line 47, in <module>
E = SWL( coords2)
File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "temp_tf.py", line 34, in call
total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).
in user code:
File "temp_tf.py", line 22, in energy *
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
IndexError: list index out of range
Call arguments received:
• coords_all=tf.Tensor(shape=(1, 9), dtype=float32)
这是预期的行为吗?
有趣的问题!我认为错误源于回溯,这导致 tf.function 不止一次评估 energy
中的 python 片段。看到这个 issue. Also, this could be related to a bug.
几个观察结果:
1.从 calc_sw3
中删除 tf.function 装饰器有效并且与 docs:
[...] tf.function applies to a function and all other functions it calls.
因此,如果您再次将 tf.function
显式应用到 calc_sw3
,您可能会触发回溯,但您可能想知道为什么 calc_sw3_noerr
有效?也就是一定和变量gamma
.
2。将输入 signatures 添加到 energy
函数上方的 tf.function,同时将其余代码保持原样,也可以工作 :
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32), tf.TensorSpec(shape=None, dtype=tf.float32)])
def energy(coords, gamma):
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
E3 = calc_sw3(gamma, gamma, norm_rij)
return E3
这个方法:
[...] ensures only one ConcreteFunction is created, and restricts the GenericFunction to the specified shapes and types. It is an effective way to limit retracing when Tensors have dynamic shapes.
所以也许假设 gamma
每次都以不同的形状调用,从而触发回溯(只是一个假设)。触发错误的事实实际上是有意或故意设计的 here. Also another interesting comment:
tf.functions can only handle a pre defined input shape, if the shape changes, or if different python objects get passed, tensorflow automagically rebuilds the function
最后,为什么我觉得是tracing的问题呢?因为实际错误来自这部分代码片段:
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
您可以通过将其注释掉并将 norm_rij
替换为某个值然后调用 calc_sw3
来确认。它会起作用。
这意味着这段代码可能被执行了不止一次,可能 由于上述原因。这也有据可查 here:
In the first stage, referred to as "tracing", Function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.
In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage