Tiling images in a grid (i.e. with wraparound) in TensorFlow

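The short version of what I want to do: take a stack of images of shape (h, w, num_images) and tile them to produce a single image that is easy to plot. I want them laid out in a grid, i.e. with wraparound, and I want to do this in TensorFlow, so that the graph outputs a grid image ready for plotting.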

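The grid would be specified by passing either:

a.) the number of columns (i.e. the maximum number of images in a single row), or

b.) a maximum width (e.g. the screen width), from which the number of columns is computed automatically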
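
Option (b) reduces to option (a) with one integer division. A minimal sketch, assuming all tiles share the same width; `cols_for_width` is a hypothetical helper name, not part of any library:

```python
def cols_for_width(max_width, tile_width, num_images):
    """Number of tiles per row that fit in max_width (hypothetical helper)."""
    if tile_width <= 0:
        raise ValueError("tile_width must be positive")
    # At least one tile per row, even if a single tile is wider than max_width.
    fit = max(1, max_width // tile_width)
    # Never use more columns than there are images.
    return min(fit, max(1, num_images))
```

The result can then be passed as the column count to whichever tiling routine is used.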

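I have numpy code that does this, but it is slow, and I think it makes more sense to do it on the GPU as part of the graph.

My TensorFlow graph code looks like this (t is the output of a convolutional layer, so the last axis contains the stack of images):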

act = tf.squeeze(t) # batch size is 1, so remove it
act = tf.unstack(act, num=num_filters, axis=-1) # split last axis (filters) into list of (h, w)
act = tf.stack(act) # re-stack on first axis

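This gives me (num_filters, h, w), which I feed into more general numpy code I wrote that places the images in a grid (my numpy code is long because it is more general and handles variable-size images, so I have not included it here).

Can this be done directly in TensorFlow?

(Note that if I used tf.concat instead of tf.stack, I could tile the images side by side, but they would not wrap around.)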

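My numpy code does this in under 20 lines and is very fast, even when I tile 10000x10000x3 images. If there are not enough images to fill the grid, it pads the last few tiles with zeros.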

from functools import reduce

import numpy as np

def reshape_row(arr):
    # join tiles side by side (along the width axis)
    return reduce(lambda x, y: np.concatenate((x, y), axis=1), arr)

def reshape_col(arr):
    # stack rows on top of each other (along the height axis)
    return reduce(lambda x, y: np.concatenate((x, y), axis=0), arr)

def arbitrary_rows_cols(arr, num_rows, num_cols, gray=False):
    # arr has shape (num_images, height, width, depth); gray is currently unused
    num_images, height, width, depth = arr.shape
    rows = []
    for i in range(num_rows):
        row_image = arr[i*num_cols:i*num_cols+num_cols]
        # pad a short last row with zero tiles of the same dtype
        n_missing = num_cols - row_image.shape[0]
        if n_missing > 0:
            padding = np.zeros((n_missing, height, width, depth), dtype=arr.dtype)
            row_image = np.concatenate((row_image, padding), axis=0)
        row_image = reshape_row(row_image)
        rows.append(row_image)
    mosaic = reshape_col(rows)
    return mosaic

You could translate this code into TensorFlow, which might be faster. It would be interesting to see a performance comparison.
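
One loop-free way to express the same tiling is pad, reshape, transpose, reshape. The sketch below is in NumPy so it can be checked without a TensorFlow install; `tile_grid` is a hypothetical name, and the layout assumed is (n, h, w, c) with all tiles the same size:

```python
import math

import numpy as np

def tile_grid(images, num_cols):
    """Tile images of shape (n, h, w, c) into one (rows*h, cols*w, c) mosaic."""
    n, h, w, c = images.shape
    num_rows = math.ceil(n / num_cols)
    # Pad the batch axis with zero tiles so it fills the grid exactly.
    pad = num_rows * num_cols - n
    images = np.pad(images, ((0, pad), (0, 0), (0, 0), (0, 0)), mode="constant")
    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows*h, cols*w, c)
    grid = images.reshape(num_rows, num_cols, h, w, c)
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(num_rows * h, num_cols * w, c)
```

In TensorFlow the same steps would presumably map onto tf.pad, tf.reshape and tf.transpose with perm=[0, 2, 1, 3, 4]. That port is not benchmarked here, so whether it beats the loop version is an open question.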

Actually, I just found a really simple way to do it by passing in the number of rows (not ideal, but good enough for now).

def make_grid(t, num_images, num_rows=2):
    '''takes stack of images as (1, h, w, num_images) and tiles them into a grid'''
    t = tf.squeeze(t, axis=0) # remove single batch only, TODO make more flexible to work with higher batch size
    t = tf.unstack(t, num=num_images, axis=-1) # split last axis (num_images) into list of (h, w)
    t = tf.concat(t, axis=1) # tile all images horizontally into single row
    t = tf.split(t, num_rows, axis=1) # split into desired number of rows (num_images must be divisible by num_rows)
    t = tf.concat(t, axis=0) # tile rows vertically
    return t
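
The concat, split, concat sequence wraps the tiles in row-major order. To make the layout visible, here is the same sequence mirrored in NumPy on tiny labeled tiles, with np.concatenate and np.split standing in for tf.concat and tf.split; `make_grid_np` is a hypothetical name used only for this illustration:

```python
import numpy as np

def make_grid_np(stack, num_rows):
    """NumPy mirror of the concat/split/concat sequence; stack is (h, w, n)."""
    tiles = [stack[..., i] for i in range(stack.shape[-1])]  # like tf.unstack
    strip = np.concatenate(tiles, axis=1)     # one long row: (h, n*w)
    rows = np.split(strip, num_rows, axis=1)  # raises if not evenly divisible
    return np.concatenate(rows, axis=0)       # (num_rows*h, n*w/num_rows)

# Six 1x2 tiles, each filled with its own index, arranged in 2 rows.
stack = np.stack([np.full((1, 2), i) for i in range(6)], axis=-1)
grid = make_grid_np(stack, num_rows=2)
# grid is [[0, 0, 1, 1, 2, 2],
#          [3, 3, 4, 4, 5, 5]]
```

Note that the split operates on pixel columns, not on images. If the number of images is not a multiple of num_rows, the split either fails or cuts through a tile, so padding the stack with blank tiles first is the safer route.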

For images of shape [Batch, Width, Height, Channels], this can be done in TensorFlow with:

def image_grid(x, size=6):
    t = tf.unstack(x[:size * size], num=size*size, axis=0)
    rows = [tf.concat(t[i*size:(i+1)*size], axis=0) 
            for i in range(size)]
    image = tf.concat(rows, axis=1)
    return image[None]

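A function was added to TensorFlow to do this: tf.contrib.gan.eval.image_grid. It takes an input tensor of shape [batch, width, height, channels] along with the shape of the image grid, the dimensions of each image, and the number of image channels. It works well and is easy to use.

I have a function that produces a square grid if you do not specify the number of columns, and otherwise a grid with the requested number of columns: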

import math

import tensorflow as tf

# Assumption: image_grid is tf.contrib.gan.eval.image_grid (TensorFlow 1.x only).
image_grid = tf.contrib.gan.eval.image_grid

def tf_batch_to_canvas(X, cols: int = None):
    """
    reshape a batch of images into a grid canvas to form a single image.

    Parameters
    ----------
    X: Tensor
        Batch of images to format. [N, H, W, C]-shaped
    cols: int
        how many columns the grid should have. If None, a square grid will be created.
    Returns
    -------
    image_grid: Tensor
        Tensor representing the image grid. [1, HH, WW, C]-shaped

    Raises
    ------
        ValueError: The input tensor must be 4 dimensional

    Examples
    --------

    x = np.ones((9, 100, 100, 3))
    x = tf.convert_to_tensor(x)
    canvas = batches.tf_batch_to_canvas(x)
    assert canvas.shape == (1, 300, 300, 3)

    canvas = batches.tf_batch_to_canvas(x, cols=5)
    assert canvas.shape == (1, 200, 500, 3)
    """
    if len(X.shape.as_list()) != 4:
        raise ValueError("input tensor must be 4 dimensional.")
    N, H, W, C = X.shape.as_list()
    rc = math.sqrt(N)
    if cols is None:
        # square grid: smallest side length that fits all N tiles
        rows = cols = math.ceil(rc)
    else:
        cols = max(1, cols)
        rows = math.ceil(N / cols)
    # fill the unused grid cells with zero tiles
    n_gray_tiles = cols * rows - N
    if n_gray_tiles > 0:
        gray_tiles = tf.zeros((n_gray_tiles, H, W, C), X.dtype)
        X = tf.concat([X, gray_tiles], 0)
    image_shape = (H, W)
    n_channels = C
    return image_grid(X, (rows, cols), image_shape, n_channels)

https://github.com/theRealSuperMario/edflow/blob/tf_batches/edflow/iterators/tf_batches.py#L8-L53
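
The grid arithmetic in that function can be checked without TensorFlow. The sketch below reproduces only the row, column, and padding computation; `canvas_shape` is a hypothetical helper, not part of edflow or TensorFlow:

```python
import math

def canvas_shape(n, h, w, c, cols=None):
    """Return ((1, rows*h, cols*w, c), n_padding_tiles) for n tiles of h x w x c."""
    if cols is None:
        rows = cols = math.ceil(math.sqrt(n))  # square grid
    else:
        cols = max(1, cols)
        rows = math.ceil(n / cols)
    return (1, rows * h, cols * w, c), rows * cols - n
```

For 9 tiles of 100x100x3 this gives a (1, 300, 300, 3) canvas with no padding, and with cols=5 a (1, 200, 500, 3) canvas with one padding tile, matching the docstring examples. A square grid can leave a whole row empty: 10 tiles yield a 4x4 grid with 6 padding tiles.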

Basically, you can now use tf.contrib.gan.eval.image_grid, depending on your TensorFlow version (tf.contrib was removed in TensorFlow 2.0). https://www.tensorflow.org/api_docs/python/tf/contrib/gan/eval/image_grid

https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py#L34-L80