如何让这段 PyTorch 张量 (B, C, H, W) 分块与混合代码更简洁、更高效?

How can I make this PyTorch tensor (B, C, H, W) tiling & blending code simpler and more efficient?

我在好几个月前写了下面这段代码,它一直运行良好。不过,我一直在琢磨如何简化它并让它更高效。

下面的函数将形状为 (B, C, H, W) 的图像张量拆分为大小相同的图块(每个图块同样是 (B, C, H, W) 形状),这样就可以对每个图块单独进行操作以节省内存。从图块重建张量时,代码使用掩码(mask)来确保图块之间无缝混合。当最右列或最底行的图块无法使用与其他图块相同的重叠量时,由掩码函数中的"特殊掩码"(special masks)来处理。这意味着最右侧和最底部的图块有时几乎没有内容是可见的。这样做是为了保证无论原始图像/张量的尺寸如何,图块始终恰好是指定的大小(这对可视化/DeepDream、神经风格迁移等很重要)。与边缘行/列相邻的那一行/列,在与边缘行/列重叠的地方也使用特殊掩码。

每个图块有 8 种可能的掩码,一次最多同时使用其中 4 种。4 种基本掩码分别对应左、右、上、下四条边,每种都有一个特殊版本。

# Improved version of: https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py
import torch


# Apply blend masks to tiles
def mask_tile(tile, overlap, side='bottom'):
    """Multiply ``tile`` by linear blend ramps along the requested edges.

    Args:
        tile: 4D tensor; dims 2 and 3 are treated as height and width.
        overlap: Indexable of four ints ``[top, bottom, right, left]`` giving
            the ramp width in pixels for each edge.
        side: Mask names, matched by substring (``add_tiles`` passes them
            comma-separated).
            - ``left``/``top``: ramp 0 -> 1 over the first ``overlap`` columns/rows.
            - ``right``/``bottom``: ramp 1 -> 0 over the last ``overlap`` columns/rows.
            - ``left-special``/``top-special``: zeros for the first
              ``size - 2 * overlap`` pixels, then a 0 -> 1 ramp of ``overlap``
              pixels, then ``overlap`` ones. Requires ``size >= 2 * overlap``.
            - ``right-special``/``bottom-special``: ones, then a 1 -> 0 ramp
              over the last ``overlap`` pixels.
            The plain ramp for an edge is skipped when its ``-special``
            variant is present.

    Returns:
        ``tile`` multiplied by the combined mask. The result has ``tile``'s
        dtype when ``tile`` is floating point.
    """
    h, w = tile.size(2), tile.size(3)
    top_overlap, bottom_overlap, right_overlap, left_overlap = overlap[0], overlap[1], overlap[2], overlap[3]
    device = tile.device
    # Keep the weights in the tile's dtype so the output dtype matches the input.
    dtype = tile.dtype if tile.is_floating_point() else torch.get_default_dtype()

    def fade_in_special(size, width):
        # Zeros, then a 0 -> 1 ramp of `width`, then `width` ones.
        return torch.cat([
            torch.zeros(size - 2 * width, device=device),
            torch.linspace(0, 1, width, device=device),
            torch.ones(width, device=device),
        ])

    def fade_out_special(size, width):
        # Ones, then a 1 -> 0 ramp over the last `width` entries.
        return torch.cat([
            torch.ones(size - width, device=device),
            torch.linspace(1, 0, width, device=device),
        ])

    # The blend mask is separable (row weight x column weight).
    # Keeping one vector per axis avoids materialising several full
    # (B, C, H, W) masks; broadcasting applies it to the whole tile.
    row_mask = torch.ones(h, dtype=dtype, device=device)
    col_mask = torch.ones(w, dtype=dtype, device=device)

    if 'left' in side and 'left-special' not in side:
        col_mask[:left_overlap] *= torch.linspace(0, 1, left_overlap, device=device)
    if 'right' in side and 'right-special' not in side:
        # `w - right_overlap` rather than `-right_overlap`: a slice starting
        # at -0 would select the whole axis when the overlap is 0.
        col_mask[w - right_overlap:] *= torch.linspace(1, 0, right_overlap, device=device)
    if 'top' in side and 'top-special' not in side:
        row_mask[:top_overlap] *= torch.linspace(0, 1, top_overlap, device=device)
    if 'bottom' in side and 'bottom-special' not in side:
        row_mask[h - bottom_overlap:] *= torch.linspace(1, 0, bottom_overlap, device=device)

    if 'left-special' in side:
        col_mask *= fade_in_special(w, left_overlap)
    if 'right-special' in side:
        col_mask *= fade_out_special(w, right_overlap)
    if 'top-special' in side:
        row_mask *= fade_in_special(h, top_overlap)
    if 'bottom-special' in side:
        row_mask *= fade_out_special(h, bottom_overlap)

    # (H, 1) * (1, W) -> (H, W), broadcast over the batch and channel dims.
    return tile * (row_mask.unsqueeze(1) * col_mask.unsqueeze(0))


def _edge_masks(index, count, last_step, tile_dim, before, after):
    """Return ``(mask_name, overlap_override)`` pairs for one tile along one axis.

    ``before``/``after`` name the leading/trailing edge of the axis
    ('top'/'bottom' or 'left'/'right'). ``overlap_override`` is None when the
    default overlap for that edge should be used. A tile that is alone on
    its axis (``count == 1``) gets no masks, so its edges are not faded out.
    """
    if last_step > 0 and 2 * last_step <= tile_dim:
        # Short final step: the '-special' masks fade the last tile in over
        # the final `last_step` pixels of the second-to-last tile.
        tail_before, tail_after, tail_overlap = before + '-special', after + '-special', last_step
    elif last_step > 0:
        # Final step longer than half a tile: the '-special' fade-in would
        # need a negative-length run of zeros. Plain ramps across the real
        # overlap of the last two tiles (tile_dim - last_step) blend them.
        tail_before, tail_after, tail_overlap = before, after, max(tile_dim - last_step, 0)
    else:
        tail_before, tail_after, tail_overlap = before, after, None

    pairs = []
    if index > 0:
        pairs.append((tail_before, tail_overlap) if index == count - 1 else (before, None))
    if index < count - 1:
        pairs.append((tail_after, tail_overlap) if index == count - 2 else (after, None))
    return pairs


def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
    """Blend ``tiles`` into ``base_img`` (modified in place) and return it.

    Args:
        tiles: Tiles in row-major order, one per (y, x) pair of ``tile_coords``.
        base_img: 4D tensor the masked tiles are summed into.
        tile_coords: ``(y_coords, x_coords)`` tile origins.
        tile_size: ``(tile_h, tile_w)``.
        overlap: ``[top, bottom, right, left]`` default blend widths.

    Returns:
        ``base_img`` after every tile, multiplied by its blend mask from
        ``mask_tile``, has been added at its coordinates.
    """
    ys, xs = tile_coords[0], tile_coords[1]
    tile_h, tile_w = tile_size[0], tile_size[1]
    # Position of each edge's width inside the [top, bottom, right, left] list.
    overlap_index = {'top': 0, 'bottom': 1, 'right': 2, 'left': 3}

    # Step between the last two tile origins on each axis; it can differ from
    # the regular step because the last tile is placed at `d - tile_dim`.
    last_step_y = ys[-1] - ys[-2] if len(ys) > 1 else 0
    last_step_x = xs[-1] - xs[-2] if len(xs) > 1 else 0

    for row, y in enumerate(ys):
        row_edges = _edge_masks(row, len(ys), last_step_y, tile_h, 'top', 'bottom')
        for col, x in enumerate(xs):
            col_edges = _edge_masks(col, len(xs), last_step_x, tile_w, 'left', 'right')
            tile_overlap = list(overlap)
            names = []
            for name, override in row_edges + col_edges:
                names.append(name)
                if override is not None:
                    tile_overlap[overlap_index[name.split('-')[0]]] = override
            tile = mask_tile(tiles[row * len(xs) + col], tile_overlap, side=','.join(names))
            base_img[:, :, y:y + tile_h, x:x + tile_w] += tile
    return base_img


# Calculate the coordinates for tiles
def get_tile_coords(d, tile_dim, overlap=0):
    """Return the start offsets of tiles covering a dimension of length ``d``.

    Tiles of length ``tile_dim`` start every ``int(tile_dim * (1 - overlap))``
    pixels. The last tile is placed at ``d - tile_dim`` so it ends exactly
    at ``d``. When ``d <= tile_dim``, a single offset ``[0]`` is returned.

    Raises:
        ValueError: if more than one tile is needed but the step is not
            positive (e.g. ``overlap >= 1``). The tiles would never advance.
    """
    if d <= tile_dim:
        return [0]
    move = int(tile_dim * (1 - overlap))
    if move <= 0:
        raise ValueError(
            f"tile step int({tile_dim} * (1 - {overlap})) = {move} must be positive"
        )
    return list(range(0, d - tile_dim, move)) + [d - tile_dim]


# Calculates info required for tiling
def tile_setup(tile_size, overlap_percent, base_size):
    """Compute tile origins, tile size and blend widths for an image.

    Args:
        tile_size: int, or ``(tile_h, tile_w)`` tuple/list.
        overlap_percent: Fraction of a tile shared with its neighbour, or a
            ``(y, x)`` tuple/list of fractions.
        base_size: ``(H, W)`` of the image being tiled.

    Returns:
        ``((y_coords, x_coords), tile_size, [top, bottom, right, left])``.
        The last item is the blend width, in pixels, for each edge.
    """
    if not isinstance(tile_size, (tuple, list)):
        tile_size = (tile_size, tile_size)
    if not isinstance(overlap_percent, (tuple, list)):
        overlap_percent = (overlap_percent, overlap_percent)
    x_coords = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
    y_coords = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
    # Neighbouring tiles start int(tile * (1 - overlap)) apart, the same
    # formula as get_tile_coords, so they share `tile - step` pixels.
    # int(tile * overlap) can be one pixel short (tile=261, overlap=0.5:
    # step 130 but int(130.5) == 130), leaving a seam where the ramps do
    # not sum to 1.
    y_ovlp = tile_size[0] - int(tile_size[0] * (1 - overlap_percent[0]))
    x_ovlp = tile_size[1] - int(tile_size[1] * (1 - overlap_percent[1]))
    return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]


# Split tensor into tiles
def tile_image(img, tile_size, overlap_percent, info_only=False):
    """Split ``img`` into overlapping tiles of equal size.

    Args:
        img: 4D tensor; dims 2 and 3 are the height and width that get tiled.
        tile_size: int or ``(tile_h, tile_w)``.
        overlap_percent: Fraction of a tile shared with its neighbour, or a
            ``(y, x)`` pair.
        info_only: Accepted but not used by this function.

    Returns:
        List of tiles in row-major order (all x positions for the first y,
        then the next y, ...). Each tile is a slice of ``img``.
    """
    coords, size, _ = tile_setup(tile_size, overlap_percent, (img.size(2), img.size(3)))
    tile_h, tile_w = size[0], size[1]
    return [
        img[:, :, top:top + tile_h, left:left + tile_w]
        for top in coords[0]
        for left in coords[1]
    ]


# Put tiles back into the original tensor
def rebuild_image(tiles, image_size, tile_size, overlap_percent):
    """Reassemble tiles produced by ``tile_image`` into one tensor.

    Args:
        tiles: Non-empty list of tiles in row-major order.
        image_size: Full 4D size of the output; dims 2 and 3 are height and width.
        tile_size: int or ``(tile_h, tile_w)``; must match the one used for tiling.
        overlap_percent: Same overlap value(s) that were used for tiling.

    Returns:
        A new tensor of ``image_size`` on the tiles' device, holding the
        blended tiles.
    """
    first = tiles[0]
    # Accumulate in the tiles' floating dtype (e.g. float64) instead of the
    # global default, so precision is not silently dropped.
    dtype = first.dtype if first.is_floating_point() else None
    base_img = torch.zeros(image_size, dtype=dtype, device=first.device)
    tile_coords, tile_size, overlap = tile_setup(tile_size, overlap_percent, (base_img.size(2), base_img.size(3)))
    return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)

上面的代码可以用下面的代码进行测试:

import torchvision.transforms as transforms
from PIL import Image
import random

# Load image
def preprocess_simple(image_name, image_size):
    """Load ``image_name`` as an RGB image, resize it and return a batch of one.

    Args:
        image_name: Path of the image file to open.
        image_size: Size passed to ``transforms.Resize``.

    Returns:
        The ``ToTensor`` output with a leading batch dimension added.
    """
    loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    # Open via a context manager so the file handle is released promptly;
    # convert() returns a new, fully loaded image that outlives the file.
    with Image.open(image_name) as image:
        rgb_image = image.convert('RGB')
    return loader(rgb_image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    """Save ``output_tensor`` as an image file at ``output_name``.

    Values are clamped to [0, 1] on a detached copy, so the caller's tensor
    is neither modified nor required to allow in-place ops. The leading
    dimension is squeezed before conversion; presumably the input is
    (1, C, H, W). TODO confirm for other callers.
    """
    image_tensor = output_tensor.detach().clamp(0, 1)
    to_pil = transforms.ToPILImage()
    to_pil(image_tensor.squeeze(0)).save(output_name)

# Demo: tile an image, optionally shuffle the tiles, and blend them back.
# Guarded so that importing this module does not read or write image files.
if __name__ == '__main__':
    test_input = preprocess_simple('tubingen.jpg', (1024, 1024))
    tile_size = 260
    overlap_percent = 0.5

    img_tiles = tile_image(test_input, tile_size=tile_size, overlap_percent=overlap_percent)

    random.shuffle(img_tiles)  # Comment this out to not randomize tile positions

    output_tensor = rebuild_image(img_tiles, test_input.size(), tile_size=tile_size, overlap_percent=overlap_percent)
    deprocess_simple(output_tensor, 'tiled_image.jpg')

下面是一个示例(上方是原始图像,下方是将图块随机打乱后再放回的结果,用于展示混合效果):

我已经修复了所有错误,并在这里提供了简化后的代码:https://github.com/ProGamerGov/dream-creator/blob/master/utils/tile_utils.py

实际上只有两种情况需要使用特殊掩码,而这两种情况源自 rebuild_tensor 中的错误,我已经将其修复。重叠百分比应小于或等于 50%。