使用 scan 模拟嵌套 for 循环很慢
Emulating nested for loops with scan is slow
我正在尝试使用扫描函数模拟循环嵌套,但这很慢。有没有更好的方法用 Tensorflow 模拟循环嵌套?我不是只用 numpy 做这个计算,这样我就可以自动微分。
具体来说,我在使用 Tensorflow 控制操作的同时对带有双边滤波器的图像进行卷积。为此,我嵌套了 scan() 函数,但这让我的性能非常差——过滤一张小图像需要超过 5 分钟。
有没有比嵌套扫描函数更好的方法,我使用 Tensorflow 控制流操作有多糟糕? 我对一般答案感兴趣,而不是针对我的代码的具体答案。
这里是原始的,如果你想看更快的代码:
def bilateralFilter(image, sigma_space=1, sigma_range=None, win_size=None):
if sigma_range is None:
sigma_range = sigma_space
if win_size is None: win_size = max(5, 2 * int(np.ceil(3*sigma_space)) + 1)
win_ext = (win_size - 1) / 2
height = image.shape[0]
width = image.shape[1]
# pre-calculate spatial_gaussian
spatial_gaussian = []
for i in range(-win_ext, win_ext+1):
for j in range(-win_ext, win_ext+1):
spatial_gaussian.append(np.exp(-0.5*(i**2+j**2)/sigma_space**2))
padded = np.pad(image, win_ext, mode="edge")
out_image = np.zeros(image.shape)
weight = np.zeros(image.shape)
idx = 0
for row in xrange(-win_ext, 1+win_ext):
for col in xrange(-win_ext, 1+win_ext):
slice = padded[win_ext+row:height+win_ext+row,
win_ext+col:width+win_ext+col]
value = np.exp(-0.5*((image - slice)/sigma_range)**2) \
* spatial_gaussian[idx]
out_image += value*slice
weight += value
idx += 1
out_image /= weight
return out_image
这是 Tensorflow 版本:
sess = tf.InteractiveSession()
with sess.as_default():
def bilateralFilter(image, sigma_space, sigma_range):
win_size = max(5., 2 * np.ceil(3 * sigma_space) + 1)
win_ext = int((win_size - 1) / 2)
height = tf.shape(image)[0].eval()
width = tf.shape(image)[1].eval()
spatial_gaussian = []
for i in range(-win_ext, win_ext + 1):
for j in range(-win_ext, win_ext + 1):
spatial_gaussian.append(np.exp(-0.5 * (i ** 2 +\
j ** 2) / sigma_space ** 2))
# we use "symmetric" as it best approximates "edge" padding
padded = tf.pad(image, [[win_ext, win_ext], [win_ext, win_ext]],
mode='SYMMETRIC')
out_image = tf.zeros(tf.shape(image))
weight = tf.zeros(tf.shape(image))
spatial_index = tf.constant(0)
row = tf.constant(-win_ext)
col = tf.constant(-win_ext)
def cond(padded, row, col, weight, out_image, spatial_index):
return tf.less(row, win_ext + 1)
def body(padded, row, col, weight, out_image, spatial_index):
sub_image = tf.slice(padded, [win_ext + row, win_ext + col],
[height, width])
value = tf.exp(-0.5 *
(((image - sub_image) / sigma_range) ** 2)) *
spatial_gaussian[spatial_index.eval()]
out_image += value * sub_image
weight += value
spatial_index += 1
row, col = tf.cond(tf.not_equal(tf.mod(col,
tf.constant(2*win_ext + 1)), 0),
lambda: (row + 1, tf.constant(-win_ext)),
lambda: (row, col))
return padded, row, col, weight, out_image, spatial_index
padded, row, col, weight, out_image, spatial_index =
tf.while_loop(cond, body,
[padded, row, col, weight, out_image, spatial_index])
out_image /= weight
return out_image
cat = plt.imread("cat.png") # grayscale
cat = tf.reshape(tf.constant(cat), [276, 276])
cat_blurred = bilateralFilter(cat, 2., 0.25)
cat_blurred = cat_blurred.eval()
plt.figure()
plt.gray()
plt.imshow(cat_blurred)
plt.show()
您的代码存在一个问题。 cols() 有一堆 python 全局变量,您似乎希望它们在每次循环迭代时更新。请查看有关图形构造和执行的 TensorFlow 教程。简而言之,那些 python 全局变量及其相关代码只会在图构建时执行,甚至不在 TensorFlow 的执行图中。如果操作是 tf 运算符,则它只能包含在执行图中。
此外,tf.while_loop 似乎比扫描更适合您的代码。
我正在尝试使用扫描函数模拟循环嵌套,但这很慢。有没有更好的方法用 Tensorflow 模拟循环嵌套?我不是只用 numpy 做这个计算,这样我就可以自动微分。
具体来说,我在使用 Tensorflow 控制操作的同时对带有双边滤波器的图像进行卷积。为此,我嵌套了 scan() 函数,但这让我的性能非常差——过滤一张小图像需要超过 5 分钟。
有没有比嵌套扫描函数更好的方法,我使用 Tensorflow 控制流操作有多糟糕? 我对一般答案感兴趣,而不是针对我的代码的具体答案。
这里是原始的,如果你想看更快的代码:
def bilateralFilter(image, sigma_space=1, sigma_range=None, win_size=None):
if sigma_range is None:
sigma_range = sigma_space
if win_size is None: win_size = max(5, 2 * int(np.ceil(3*sigma_space)) + 1)
win_ext = (win_size - 1) / 2
height = image.shape[0]
width = image.shape[1]
# pre-calculate spatial_gaussian
spatial_gaussian = []
for i in range(-win_ext, win_ext+1):
for j in range(-win_ext, win_ext+1):
spatial_gaussian.append(np.exp(-0.5*(i**2+j**2)/sigma_space**2))
padded = np.pad(image, win_ext, mode="edge")
out_image = np.zeros(image.shape)
weight = np.zeros(image.shape)
idx = 0
for row in xrange(-win_ext, 1+win_ext):
for col in xrange(-win_ext, 1+win_ext):
slice = padded[win_ext+row:height+win_ext+row,
win_ext+col:width+win_ext+col]
value = np.exp(-0.5*((image - slice)/sigma_range)**2) \
* spatial_gaussian[idx]
out_image += value*slice
weight += value
idx += 1
out_image /= weight
return out_image
这是 Tensorflow 版本:
sess = tf.InteractiveSession()
with sess.as_default():
def bilateralFilter(image, sigma_space, sigma_range):
win_size = max(5., 2 * np.ceil(3 * sigma_space) + 1)
win_ext = int((win_size - 1) / 2)
height = tf.shape(image)[0].eval()
width = tf.shape(image)[1].eval()
spatial_gaussian = []
for i in range(-win_ext, win_ext + 1):
for j in range(-win_ext, win_ext + 1):
spatial_gaussian.append(np.exp(-0.5 * (i ** 2 +\
j ** 2) / sigma_space ** 2))
# we use "symmetric" as it best approximates "edge" padding
padded = tf.pad(image, [[win_ext, win_ext], [win_ext, win_ext]],
mode='SYMMETRIC')
out_image = tf.zeros(tf.shape(image))
weight = tf.zeros(tf.shape(image))
spatial_index = tf.constant(0)
row = tf.constant(-win_ext)
col = tf.constant(-win_ext)
def cond(padded, row, col, weight, out_image, spatial_index):
return tf.less(row, win_ext + 1)
def body(padded, row, col, weight, out_image, spatial_index):
sub_image = tf.slice(padded, [win_ext + row, win_ext + col],
[height, width])
value = tf.exp(-0.5 *
(((image - sub_image) / sigma_range) ** 2)) *
spatial_gaussian[spatial_index.eval()]
out_image += value * sub_image
weight += value
spatial_index += 1
row, col = tf.cond(tf.not_equal(tf.mod(col,
tf.constant(2*win_ext + 1)), 0),
lambda: (row + 1, tf.constant(-win_ext)),
lambda: (row, col))
return padded, row, col, weight, out_image, spatial_index
padded, row, col, weight, out_image, spatial_index =
tf.while_loop(cond, body,
[padded, row, col, weight, out_image, spatial_index])
out_image /= weight
return out_image
cat = plt.imread("cat.png") # grayscale
cat = tf.reshape(tf.constant(cat), [276, 276])
cat_blurred = bilateralFilter(cat, 2., 0.25)
cat_blurred = cat_blurred.eval()
plt.figure()
plt.gray()
plt.imshow(cat_blurred)
plt.show()
您的代码存在一个问题。 cols() 有一堆 python 全局变量,您似乎希望它们在每次循环迭代时更新。请查看有关图形构造和执行的 TensorFlow 教程。简而言之,那些 python 全局变量及其相关代码只会在图构建时执行,甚至不在 TensorFlow 的执行图中。如果操作是 tf 运算符,则它只能包含在执行图中。
此外,tf.while_loop 似乎比扫描更适合您的代码。