How to seamlessly blend (B x C x H x W) tensor tiles together to hide tile boundaries?

For completeness, here is a text summary of what I am trying to do:

  1. Split the image into tiles.
  2. Run each tile through a new copy of the model for a set number of iterations.
  3. Feather the tiles and put them into rows.
  4. Feather the rows and put them back together into the original image/tensor.
  5. Maybe save the output, then split the output into tiles again.
  6. Repeat steps 2 and 3 for a set number of iterations.

I only need help with steps 1, 3, and 4. The style transfer process causes some subtle differences to form in the processed tiles, so I need to blend them back together. By feathering I basically mean fading one tile into another in order to blur the boundaries (like in ImageMagick, Photoshop, etc.). I am attempting to accomplish this blending by using torch.linspace() to create masks, though I'm not sure if there's a better approach.
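For example, for two horizontally overlapping tiles I picture something like the following ramp blend (just a minimal sketch to illustrate what I mean by feathering; the tile sizes and overlap here are made up):

import torch

# Two fake tiles that overlap by `overlap` columns (the values are arbitrary).
overlap = 16
left = torch.zeros(1, 3, 64, 64)   # e.g. a dark tile
right = torch.ones(1, 3, 64, 64)   # e.g. a bright tile

# Linear ramp 0 -> 1 across the overlap, broadcast over B, C, H.
ramp = torch.linspace(0, 1, overlap).view(1, 1, 1, overlap)

# Blend the overlapping columns: left fades out while right fades in.
blended_overlap = left[..., -overlap:] * (1 - ramp) + right[..., :overlap] * ramp

# Stitch: left body + blended seam + right body.
row = torch.cat([left[..., :-overlap], blended_overlap, right[..., overlap:]], dim=3)
print(row.shape)  # torch.Size([1, 3, 64, 112])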

What I am trying to accomplish is based on / inspired by https://github.com/VaKonS/neural-style/blob/Multi-resolution/neural_style.lua, though I'm working with PyTorch. The code I am trying to implement tiling with can be found here: https://gist.github.com/ProGamerGov/e64fcb309274c2946f5a9a679ed45669, but you shouldn't need to look at it, as everything you need can be found below.

Essentially, this is what I am trying to do (the red areas overlap with another tile):

[image: diagram of tiles with the overlapping regions highlighted in red]

This is the code I have so far. The feathering and adding of rows is not implemented yet, because I haven't been able to get the feathering of individual tiles working yet.

import torch
from PIL import Image
import torchvision.transforms as transforms

def tile_calc(tile_size, v, d):
    # Compute the [min_val, max_val) range of tile index v along a dimension of
    # size d, shifting the last tile back so that it still spans a full tile_size.
    max_val = max(min(tile_size * v + tile_size, d), 0)
    min_val = tile_size * v
    if abs(min_val - max_val) < tile_size:
        min_val = max_val - tile_size
    return min_val, max_val

def split_tensor(tensor, tile_size=256):
    tiles, tile_idx = [], []
    tile_size_y, tile_size_x = tile_size + 8, tile_size + 5  # Make H and W different for testing
    h, w = tensor.size(2), tensor.size(3)
    h_range, w_range = int(-(h // -tile_size_y)), int(-(w // -tile_size_x))  # ceiling division

    for y in range(h_range):
        for x in range(w_range):
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)

            tiles.append(tensor[:, :, ty:y_val, tx:x_val])
            tile_idx.append([ty, y_val, tx, x_val])

    # Overlap (in pixels) between horizontally / vertically adjacent tiles
    w_overlap = tile_idx[0][3] - tile_idx[1][2]
    h_overlap = tile_idx[0][1] - tile_idx[w_range][0]

    if tensor.is_cuda:
        base_tensor = torch.zeros(tensor.squeeze(0).size(), device=tensor.get_device())
    else:
        base_tensor = torch.zeros(tensor.squeeze(0).size())
    return tiles, base_tensor.unsqueeze(0), (h_range, w_range), (h_overlap, w_overlap)

# Feather the left edge of each tile so tiles can be blended across the vertical seams within a row
def feather_tiles(tensor_list, hxw, w_overlap):
    print(len(tensor_list))
    mask_list = []
    if w_overlap > 0:
        for i, tile in enumerate(tensor_list):
            if i % hxw[1] != 0:
                # Ramp from 0 to 1 across the overlapping columns, 1 elsewhere
                lin_mask = torch.linspace(0, 1, w_overlap).repeat(tile.size(2), 1)
                mask_part = torch.ones(tile.size(2), tile.size(3) - w_overlap)
                mask = torch.cat([lin_mask, mask_part], 1)
                mask = mask.repeat(tile.size(1), 1, 1).unsqueeze(0)
                mask_list.append(mask)
            else:
                # First tile in each row gets an all-ones mask (left unchanged)
                mask = torch.ones(tile.squeeze().size()).unsqueeze(0)
                mask_list.append(mask)
    return mask_list


def build_row(tensor_tiles, tile_masks, hxw, w_overlap, bt, tile_size):
    print(len(tensor_tiles), len(tile_masks))
    # One blank row canvas per row of tiles, spanning the full image width
    if bt.is_cuda:
        row_base = torch.ones(bt.size(1), tensor_tiles[0].size(2), bt.size(3), device=bt.get_device()).unsqueeze(0)
    else:
        row_base = torch.ones(bt.size(1), tensor_tiles[0].size(2), bt.size(3)).unsqueeze(0)
    row_list = []
    for v in range(hxw[0]):
        row_list.append(row_base.clone())

    num_tiles = 0
    row_val = 0
    tile_size_y, tile_size_x = tile_size + 8, tile_size + 5
    h, w = bt.size(2), bt.size(3)
    h_range, w_range = hxw[0], hxw[1]
    for y in range(h_range):
        for x in range(w_range):
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)

            if num_tiles % hxw[1] != 0:
                # Attempt to blend the masked tile into the row while preserving the mean
                new_mean = (row_list[row_val][:, :, :, tx:x_val].mean() + tensor_tiles[num_tiles].mean()) / 2
                row_list[row_val][:, :, :, tx:x_val] = row_list[row_val][:, :, :, tx:x_val] - row_list[row_val][:, :, :, tx:x_val].mean()
                tensor_tiles[num_tiles] = tensor_tiles[num_tiles] - tensor_tiles[num_tiles].mean()

                row_list[row_val][:, :, :, tx:x_val] = (row_list[row_val][:, :, :, tx:x_val] + (tensor_tiles[num_tiles] * tile_masks[num_tiles])) + new_mean

            else:
                # First tile in each row is copied in directly
                row_list[row_val][:, :, :, tx:x_val] = tensor_tiles[num_tiles]
            num_tiles += 1
        row_val += 1
    return row_list


def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    tensor = (Loader(image) * 256).unsqueeze(0)
    return tensor

def deprocess(output_tensor):
    output_tensor = output_tensor.squeeze(0).cpu() / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image


input_tensor = preprocess('test.jpg', 256)

tile_tensors, base_t, hxw, ovlp = split_tensor(input_tensor, 128)
tile_masks = feather_tiles(tile_tensors, hxw, ovlp[1])
row_tensors = build_row(tile_tensors, tile_masks, hxw, ovlp[1], base_t, 128)

ft = deprocess(row_tensors[0]) # save tensor to view it 
ft.save('ft_row_0.png')

You are looking for torch.nn.functional.unfold and torch.nn.functional.fold. These functions allow you to apply "sliding window" operations to an image with an arbitrary window size and stride.
This answer gives more information about these functions, and this one gives an example of how to use fold to "blend" overlapping windows. These references should give you the information you need to implement a blending scheme.

I was able to create a solution here that works with any tile size, image size, and pattern: https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py

I use masks to blend the tiles back together again.
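Reduced to just two horizontally overlapping tiles, the idea looks roughly like this (only a sketch; the linked file builds the complementary masks for full tile grids along both axes):

import torch

def lin_mask(tile_w, overlap, fade_in=True):
    # 0 -> 1 ramp over the overlap (or 1 -> 0 if fading out), 1 elsewhere.
    ramp = torch.linspace(0, 1, overlap)
    body = torch.ones(tile_w - overlap)
    if fade_in:
        m = torch.cat([ramp, body])          # fade in on the left edge
    else:
        m = torch.cat([body, ramp.flip(0)])  # fade out on the right edge
    return m.view(1, 1, 1, tile_w)  # broadcast over B, C, H

# Two 64-wide tiles overlapping by 16 columns on a 112-wide canvas.
overlap, tile_w, width = 16, 64, 112
canvas = torch.zeros(1, 3, 64, width)
tiles = [torch.zeros(1, 3, 64, tile_w), torch.ones(1, 3, 64, tile_w)]
x_starts = [0, tile_w - overlap]

# The two masks sum to 1 inside the overlap, so adding the masked tiles blends the seam.
canvas[..., x_starts[0]:x_starts[0] + tile_w] += tiles[0] * lin_mask(tile_w, overlap, fade_in=False)
canvas[..., x_starts[1]:x_starts[1] + tile_w] += tiles[1] * lin_mask(tile_w, overlap, fade_in=True)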