(Conv1D) Tensorflow 和 Jax 对同一输入产生不同的输出
(Conv1D) Tensorflow and Jax Resulting Different Outputs for The Same Input
我正在尝试使用 conv1d 函数分别在 jax 和 tensorflow 上进行转置转换。我阅读了关于 con1d_transposed 操作的 jax 和 tensorflow 的文档,但它们对相同的输入产生了不同的输出。
我找不到问题所在。而且我不知道哪一个会产生正确的结果。请帮助我。
我的 Jax 实现(Jax 代码)
x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
dtype=np.float32).transpose((2, 1, 0))
kernel_rot = np.rot90(np.rot90(filters))
print(f"x strides: {x.strides}\nfilters strides: {kernel_rot.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")
dn1 = lax.conv_dimension_numbers(x.shape, filters.shape,('NWC', 'WIO', 'NWC'))
print(dn1)
res = lax.conv_general_dilated(x,kernel_rot,(1,),'SAME',(1,),(1,),dn1)
res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")
我的 TensorFlow 实现(TensorFlow 代码)
x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
dtype=np.float32).transpose((2, 1, 0))
print(f"x strides: {x.strides}\nfilters strides: {filters.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")
res = tf.nn.conv1d_transpose(x, filters, output_shape = x.shape, strides = (1, 1, 1), padding = 'SAME', data_format='NWC', dilations=1)
res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")
Jax 的输出
result strides: (40, 8, 4)
result shape: (1, 5, 2)
result:
[[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[10. 10.]
[ 0. 10.]]]
TensorFlow 的输出
result strides: (40, 8, 4)
result shape: (1, 5, 2)
result:
[[[ 5. -5.]
[ 8. -8.]
[ 11. -11.]
[ 4. -4.]
[ 5. -5.]]]
函数 conv1d_transpose
需要形状 [filter_width, output_channels, in_channels]
的过滤器。如果上面片段中的 filters
被转置以满足这个形状,那么对于 jax 到 return 正确的结果,同时计算 dn1
参数应该是 WOI
(W idth - Output_channels - Input_channels) 而不是 WIO
(Width - Input_channels - Output_channels)。之后:
result.strides = (40, 8, 4)
result.shape = (1, 5, 2)
result:
[[[ -5., 5.],
[ -8., 8.],
[-11., 11.],
[ -4., 4.],
[ -5., 5.]]]
结果与 tensorflow 不同,但是 jax 的内核被翻转了,所以这实际上是预期的。
我正在尝试使用 conv1d 函数分别在 jax 和 tensorflow 上进行转置转换。我阅读了关于 con1d_transposed 操作的 jax 和 tensorflow 的文档,但它们对相同的输入产生了不同的输出。
我找不到问题所在。而且我不知道哪一个会产生正确的结果。请帮助我。
我的 Jax 实现(Jax 代码)
x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
dtype=np.float32).transpose((2, 1, 0))
kernel_rot = np.rot90(np.rot90(filters))
print(f"x strides: {x.strides}\nfilters strides: {kernel_rot.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")
dn1 = lax.conv_dimension_numbers(x.shape, filters.shape,('NWC', 'WIO', 'NWC'))
print(dn1)
res = lax.conv_general_dilated(x,kernel_rot,(1,),'SAME',(1,),(1,),dn1)
res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")
我的 TensorFlow 实现(TensorFlow 代码)
x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
dtype=np.float32).transpose((2, 1, 0))
print(f"x strides: {x.strides}\nfilters strides: {filters.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")
res = tf.nn.conv1d_transpose(x, filters, output_shape = x.shape, strides = (1, 1, 1), padding = 'SAME', data_format='NWC', dilations=1)
res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")
Jax 的输出
result strides: (40, 8, 4)
result shape: (1, 5, 2)
result:
[[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[10. 10.]
[ 0. 10.]]]
TensorFlow 的输出
result strides: (40, 8, 4)
result shape: (1, 5, 2)
result:
[[[ 5. -5.]
[ 8. -8.]
[ 11. -11.]
[ 4. -4.]
[ 5. -5.]]]
函数 conv1d_transpose
需要形状 [filter_width, output_channels, in_channels]
的过滤器。如果上面片段中的 filters
被转置以满足这个形状,那么对于 jax 到 return 正确的结果,同时计算 dn1
参数应该是 WOI
(W idth - Output_channels - Input_channels) 而不是 WIO
(Width - Input_channels - Output_channels)。之后:
result.strides = (40, 8, 4)
result.shape = (1, 5, 2)
result:
[[[ -5., 5.],
[ -8., 8.],
[-11., 11.],
[ -4., 4.],
[ -5., 5.]]]
结果与 tensorflow 不同,但是 jax 的内核被翻转了,所以这实际上是预期的。