Convolution - Calculating a Neighbour Element Index for a Vectorised Image

Assume the following matrix acts as both the image and the kernel in a matrix convolution operation:

0 1 2  
3 4 5  
6 7 8

To calculate a neighbour pixel index, you can use the following formulas:

neighbourColumn = imageColumn + (maskColumn - centerMaskColumn);
neighbourRow = imageRow + (maskRow - centerMaskRow);
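As a quick numeric check of these formulas (a minimal sketch; the class and helper names are mine, not from the original code):

```java
public class NeighbourDemo {
    // One axis of the formula above: neighbour = image + (mask - centerMask)
    static int neighbourIndex(int imageIndex, int maskIndex, int centerMaskIndex) {
        return imageIndex + (maskIndex - centerMaskIndex);
    }

    public static void main(String[] args) {
        // 3x3 mask centred at (1,1), image pixel at row 0, column 0:
        // mask element (0,0) maps to (-1,-1) -> outside the image, discarded
        System.out.println(neighbourIndex(0, 0, 1)); // -1
        // mask element (2,2) maps to (1,1) -> the pixel holding value 4
        System.out.println(neighbourIndex(0, 2, 1)); // 1
    }
}
```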

So the output of the convolution would be:

output0 = {0,1,3,4} x {4,5,7,8} = 58
output1 = {0,1,2,3,4,5} x {3,4,5,6,7,8} = 100
output2 = {1,2,4,5} x {3,4,6,7} = 70
output3 = {0,1,3,4,6,7} x {1,2,4,5,7,8} = 132
output4 = {0,1,2,3,4,5,6,7,8} x {0,1,2,3,4,5,6,7,8} = 204
output5 = {1,2,4,5,7,8} x {0,1,3,4,6,7} = 132
output6 = {3,4,6,7} x {1,2,4,5} = 70
output7 = {3,4,5,6,7,8} x {0,1,2,3,4,5} = 100
output8 = {4,5,7,8} x {0,1,3,4} = 58

So the output matrix is:

58  100 70
132 204 132
70  100 58

Now assume the matrix is flattened, giving the following vector:

0 1 2 3 4 5 6 7 8

This vector now acts as both the image and the kernel in a vector convolution operation, whose output should be:

58 100 70 132 204 132 70 100 58

Given the code below, how would you calculate the neighbour element index for the vector so that it corresponds to the same neighbour element in the matrix?

public int[] convolve(int[] image, int[] kernel)
{
  int imageValue;
  int kernelValue;
  int outputValue;
  int neighbour;
  int[] outputImage = new int[image.length];

  // loop through image
  for(int i = 0; i < image.length; i++)
  {
    outputValue = 0;

    // loop through kernel
    for(int j = 0; j < kernel.length; j++)
    {
      neighbour = ?;

      // discard out of bound neighbours
      if (neighbour >= 0 && neighbour < image.length)
      {
        imageValue = image[neighbour];
        kernelValue = kernel[j];
        outputValue += imageValue * kernelValue;
      }
    }

    outputImage[i] = outputValue;
  }

  return outputImage;
}

The neighbour index is calculated by offsetting the original pixel index by the difference between the current element's index and half the mask size. For example, to calculate the column index:

int neighbourCol = imageCol + col - (size / 2);
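Putting the row and column offsets together answers the `neighbour = ?` question for the flattened vector. The following is a sketch assuming a square kernel of width `size`; the helper name and parameters are mine, not from the original code. Note that a single 1D offset is not enough: the row and column must be bounds-checked separately, otherwise a neighbour that runs off the left edge of one row would silently wrap into the previous row.

```java
public class FlatNeighbour {
    /**
     * Computes the 1D neighbour index for image index i and kernel index j,
     * or -1 when the neighbour falls outside the image.
     */
    static int neighbour(int i, int j, int width, int height, int size) {
        // Recover 2D coordinates from the flat indices
        int imageCol = i % width, imageRow = i / width;
        int col = j % size, row = j / size;
        // Offset by the distance from the kernel centre
        int neighbourCol = imageCol + col - size / 2;
        int neighbourRow = imageRow + row - size / 2;
        // Check both axes; a plain 1D bounds check would wrap across rows
        if (neighbourCol < 0 || neighbourCol >= width
                || neighbourRow < 0 || neighbourRow >= height) {
            return -1;
        }
        // Flatten back to a 1D index
        return neighbourRow * width + neighbourCol;
    }

    public static void main(String[] args) {
        // Pixel 0 (top-left), kernel element 8 (bottom-right of a 3x3 mask):
        System.out.println(neighbour(0, 8, 3, 3, 3)); // prints 4
        // Pixel 0, kernel element 0: off the image on both axes
        System.out.println(neighbour(0, 0, 3, 3, 3)); // prints -1
    }
}
```

In the `convolve` method from the question, `neighbour(i, j, width, height, size)` would replace the `?`, with the `if (neighbour >= 0 ...)` guard reduced to `neighbour >= 0`.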

I have put a working demo on GitHub, trying to keep the whole convolution algorithm readable:

int[] dstImage = new int[srcImage.width() * srcImage.height()];

srcImage.forEachElement((image, imageCol, imageRow) -> {
  Pixel pixel = new Pixel();
  forEachElement((filter, col, row) -> {
    int neighbourCol = imageCol + col - (size / 2);
    int neighbourRow = imageRow + row - (size / 2);
    if (srcImage.hasElementAt(neighbourCol, neighbourRow)) {
      int color = srcImage.at(neighbourCol, neighbourRow);
      int weight = filter.at(col, row);
      pixel.addWeightedColor(color, weight);
    }
  });

  dstImage[(imageRow * srcImage.width() + imageCol)] = pixel.rgb();
});

When you process a 2D image, you have to retain some information about the image beyond the plain 1D pixel array. In particular, you need at least the width of the image (and of the mask) in order to work out which indices in the 1D array correspond to which indices in the original 2D image. As already pointed out, there is a general rule for converting between these ("virtual") 2D coordinates and the 1D coordinates in such a pixel array:

int pixelX = ...;
int pixelY = ...;
int index = pixelX + pixelY * imageSizeX;
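A round trip through this conversion, as a minimal sketch (the helper name is mine):

```java
public class IndexConversion {
    // 2D -> 1D: the general rule quoted above
    static int toIndex(int pixelX, int pixelY, int imageSizeX) {
        return pixelX + pixelY * imageSizeX;
    }

    public static void main(String[] args) {
        int imageSizeX = 3; // width of the 3x3 example image
        int index = toIndex(2, 1, imageSizeX);
        System.out.println(index);              // 5: the value at row 1, column 2
        // Inverse: recover the 2D coordinates from the 1D index
        System.out.println(index % imageSizeX); // pixelX = 2
        System.out.println(index / imageSizeX); // pixelY = 1
    }
}
```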

Based on this, you can simply perform the convolution as if it were on the 2D image. The limits for which pixels you may access are easy to check, and the loops are plain 2D loops over the image and the mask. As described above, it all boils down to accessing the 1D data with 2D coordinates.

Here is an example. It applies a Sobel filter to an input image. (The pixel values may still be somewhat odd, but the convolution itself and the index computation should be correct.)

import java.awt.Graphics2D;
import java.awt.GridLayout;
import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.io.IOException;

import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.SwingUtilities;

public class ConvolutionWithArrays1D
{
    public static void main(String[] args) throws IOException
    {
        final BufferedImage image = 
            asGrayscaleImage(ImageIO.read(new File("lena512color.png")));
        SwingUtilities.invokeLater(new Runnable()
        {
            @Override
            public void run()
            {
                createAndShowGUI(image);
            }
        });

    }

    private static void createAndShowGUI(BufferedImage image0)
    {
        JFrame f = new JFrame();
        f.getContentPane().setLayout(new GridLayout(1,2));


        f.getContentPane().add(new JLabel(new ImageIcon(image0)));

        BufferedImage image1 = compute(image0);
        f.getContentPane().add(new JLabel(new ImageIcon(image1)));


        f.pack();
        f.setLocationRelativeTo(null);
        f.setVisible(true);
    }

    private static BufferedImage asGrayscaleImage(BufferedImage image)
    {
        BufferedImage gray = new BufferedImage(
            image.getWidth(), image.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.drawImage(image, 0, 0, null);
        g.dispose();
        return gray;
    }

    private static int[] obtainGrayscaleIntArray(BufferedImage image)
    {
        BufferedImage gray = new BufferedImage(
            image.getWidth(), image.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.drawImage(image, 0, 0, null);
        g.dispose();
        DataBuffer dataBuffer = gray.getRaster().getDataBuffer();
        DataBufferByte dataBufferByte = (DataBufferByte)dataBuffer;
        byte data[] = dataBufferByte.getData();
        int result[] = new int[data.length];
        for (int i=0; i<data.length; i++)
        {
            result[i] = data[i] & 0xFF; // bytes are signed in Java; mask off sign extension
        }
        return result;
    }

    private static BufferedImage createImageFromGrayscaleIntArray(
        int array[], int imageSizeX, int imageSizeY)
    {
        BufferedImage gray = new BufferedImage(
            imageSizeX, imageSizeY, BufferedImage.TYPE_BYTE_GRAY);
        DataBuffer dataBuffer = gray.getRaster().getDataBuffer();
        DataBufferByte dataBufferByte = (DataBufferByte)dataBuffer;
        byte data[] = dataBufferByte.getData();
        for (int i=0; i<data.length; i++)
        {
            data[i] = (byte)array[i];
        }
        return gray;
    }

    private static BufferedImage compute(BufferedImage image)
    {
        int imagePixels[] = obtainGrayscaleIntArray(image);
        int mask[] = 
        {
             1,0,-1,
             2,0,-2,
             1,0,-1,
        };
        int outputPixels[] = 
            Convolution.filter(imagePixels, image.getWidth(), mask, 3);
        return createImageFromGrayscaleIntArray(
            outputPixels, image.getWidth(), image.getHeight());

    }

}


class Convolution
{
    public static final int[] filter(
        final int[] image, int imageSizeX,
        final int[] mask, int maskSizeX)
    {
        int imageSizeY = image.length / imageSizeX;
        int maskSizeY = mask.length / maskSizeX;
        int output[] = new int[image.length];
        for (int y=0; y<imageSizeY; y++)
        {
            for (int x=0; x<imageSizeX; x++)
            {
                int outputPixelValue = 0;
                for (int my=0; my< maskSizeY; my++)
                {
                    for (int mx=0; mx< maskSizeX; mx++)
                    {
                        int neighborX = x + mx - maskSizeX / 2;
                        int neighborY = y + my - maskSizeY / 2;

                        if (neighborX >= 0 && neighborX < imageSizeX &&
                            neighborY >= 0 && neighborY < imageSizeY)
                        {
                            int imageIndex = 
                                neighborX + neighborY * imageSizeX;
                            int maskIndex = mx + my * maskSizeX;

                            int imagePixelValue = image[imageIndex];
                            int maskPixelValue = mask[maskIndex];
                            outputPixelValue += 
                                imagePixelValue * maskPixelValue;
                        }
                    }
                }
                outputPixelValue = truncate(outputPixelValue);

                int outputIndex = x + y * imageSizeX;
                output[outputIndex] = outputPixelValue;
            }
        }
        return output;
    }

    private static final int truncate(final int pixelValue)
    {
        return Math.min(255, Math.max(0, pixelValue));
    }
}
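As a quick sanity check, running the same algorithm on the flattened 3x3 example from the question reproduces the expected output vector. The filter logic is inlined here so the snippet is self-contained; truncation is omitted because all values in this example already fit in 0..255:

```java
import java.util.Arrays;

public class ConvolutionCheck {
    // Same algorithm as Convolution.filter above, without truncation
    static int[] filter(int[] image, int imageSizeX, int[] mask, int maskSizeX) {
        int imageSizeY = image.length / imageSizeX;
        int maskSizeY = mask.length / maskSizeX;
        int[] output = new int[image.length];
        for (int y = 0; y < imageSizeY; y++) {
            for (int x = 0; x < imageSizeX; x++) {
                int sum = 0;
                for (int my = 0; my < maskSizeY; my++) {
                    for (int mx = 0; mx < maskSizeX; mx++) {
                        int nx = x + mx - maskSizeX / 2;
                        int ny = y + my - maskSizeY / 2;
                        // Discard out-of-bound neighbours on both axes
                        if (nx >= 0 && nx < imageSizeX && ny >= 0 && ny < imageSizeY) {
                            sum += image[nx + ny * imageSizeX] * mask[mx + my * maskSizeX];
                        }
                    }
                }
                output[x + y * imageSizeX] = sum;
            }
        }
        return output;
    }

    public static void main(String[] args) {
        int[] image = {0, 1, 2, 3, 4, 5, 6, 7, 8}; // flattened 3x3 matrix
        int[] result = filter(image, 3, image, 3); // image convolved with itself
        System.out.println(Arrays.toString(result));
        // [58, 100, 70, 132, 204, 132, 70, 100, 58]
    }
}
```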