在 Python 中实现 "Wave Collapse Function" 算法的问题

Issues implementing the "Wave Collapse Function" algorithm in Python

简而言之:

我在 Python 2.7 中对 Wave Collapse Function algorithm 的实现存在缺陷,但我无法确定问题出在哪里。我需要帮助找出我可能遗漏或做错了什么。

什么是Wave Collapse Function算法?

它是 Maxim Gumin 于 2016 年编写的一种算法,可以从样本图像生成程序模式。您可以看到它的实际效果 here (2D overlapping model) and here(3D 瓷砖模型)。

本次实施的目标:

将算法(2D 重叠模型)归结为本质,避免 original C# script 的冗余和笨拙(出奇地长且难以阅读)。这是试图使该算法更短、更清晰和 pythonic 版本。

这个实现的特点:

我正在使用 Processing(Python 模式),一种视觉设计软件,可以更轻松地处理图像(没有 PIL,没有 Matplotlib,...)。主要缺点是我仅限于 Python 2.7,无法导入 numpy。

与原始版本不同,此实现:

算法(据我了解)

1/ 读取输入位图,存储每个 NxN 模式并计算它们的出现次数。 (可选: 通过旋转和反射增强模式数据。)

例如当N=3时:

2/ 预先计算并存储模式之间所有可能的邻接关系。 在下面的示例中,模式 207、242、182 和 125 可以与模式 246

的右侧重叠

3/ 创建一个具有输出维度的数组(对于 wave 称为 W)。该数组的每个元素都是一个数组,其中包含每个模式的状态(True of False)。

例如,假设我们计算输入中的 326 个独特模式,并且我们希望输出的尺寸为 20 x 20(400 个单元格)。然后 "Wave" 数组将包含 400 (20x20) 个数组,每个数组包含 326 个布尔值。

开始时,所有布尔值都设置为 True,因为每个模式都允许出现在 Wave 的任何位置。

W = [[True for pattern in xrange(len(patterns))] for cell in xrange(20*20)]

4/ 使用输出的维度创建另一个数组(称为 H)。该数组的每个元素都是一个浮点数,在输出中保存其对应单元格的 "entropy" 值。

这里的熵是指Shannon Entropy,是根据Wave中特定位置的有效模式的数量计算的。一个单元格的有效模式越多(在 Wave 中设置为 True),它的熵就越高。

例如,为了计算单元格 22 的熵,我们查看其在波 (W[22]) 中的相应索引并计算设置为 True 的布尔值的数量。有了这个计数,我们现在可以用香农公式计算熵。该计算的结果将存储在 H 中的相同索引 H[22]

开始时,所有单元格都具有相同的熵值(H 中每个位置都有相同的浮点数),因为每个单元格的所有模式都设置为 True

H = [entropyValue for cell in xrange(20*20)]

这 4 个步骤是入门步骤,它们是初始化算法所必需的。现在开始算法的核心

5/观察:

找到具有 minimum nonzero 熵的单元格索引(请注意,在第一次迭代时,所有熵都相等,因此我们需要随机选择一个单元格的索引。)

然后,在 Wave 中的相应索引处查看仍然有效的模式,select 随机选择其中一个,按模式在输入图像中出现的频率加权(加权选择)。

例如,如果 H 中的最低值位于索引 22 (H[22]),我们会查看在 W[22] 处设置为 True 的所有模式并选择一个随机基于它在输入中出现的次数。 (请记住,在第 1 步,我们计算了每个模式出现的次数)。这确保模式在输出中的分布与在输入中的分布相似。

6/ 收起:

我们现在将 selected 模式的索引分配给具有最小熵的单元格。这意味着 Wave 中相应位置的每个模式都设置为 False 除了已选择的模式。

例如,如果 W[22] 中的模式 246 设置为 True 并且已经 selected,则所有其他模式都设置为 False .单元格 22 分配有模式 246输出单元格 22 将填充图案 246 的第一种颜色(左上角)。(本例中为蓝色)

7/ 传播:

由于邻接约束,该模式 selection 对 Wave 中的相邻单元格有影响。对应于最近折叠的单元格的左侧和右侧单元格的布尔数组需要相应地更新。

例如,如果单元格 22 已折叠并分配了模式 246,则 W[21](左)、W[23](右)、W[2](向上)和 W[42](向下)必须修改,因为它们只保留 True 与模式 246.

相邻的模式

例如,回看步骤2的图片,我们可以看到只有207、242、182和125的图案可以放在图案246的右边。这意味着 W[23](单元格右侧 22)需要将模式 207、242、182 和 125 保留为 True,并将数组中的所有其他模式设置为 False。如果这些模式不再有效(由于先前的约束已设置为 False),则算法将面临 矛盾 .

8/ 更新熵

因为一个单元格已经折叠(一个模式selected,设置为True)并且其周围的单元格相应地更新(将非相邻模式设置为False)熵所有这些单元格都发生了变化,需要重新计算。 (请记住,细胞的熵与其在 Wave 中持有的有效模式的数量相关。)

在示例中,单元格 22 的熵现在为 0,(H[22] = 0,因为只有模式 246W[22] 处设置为 True)并且其相邻单元格的熵已降低(与模式 246 不相邻的模式已设置为 False)。

现在算法到达第一次迭代的末尾,并将循环执行步骤 5(找到具有最小非零熵的单元格)到 8(更新熵)直到所有单元格都折叠。

我的脚本

您需要 Processing with Python mode 安装到 运行 这个脚本。 它包含大约 80 行代码(与原始脚本的约 1000 行相比较短),这些代码带有完整注释,因此可以快速理解。您还需要下载 input image 并相应地更改第 16 行的路径。

from collections import Counter
from itertools import chain, izip
import math

d = 20  # dimensions of output (array of dxd cells)
N = 3 # dimensions of a pattern (NxN matrix)

Output = [120 for i in xrange(d*d)] # array holding the color value for each cell in the output (at start each cell is grey = 120)

def setup():
    size(800, 800, P2D)
    textSize(11)

    global W, H, A, freqs, patterns, directions, xs, ys, npat

    img = loadImage('Flowers.png') # path to the input image
    iw, ih = img.width, img.height # dimensions of input image
    xs, ys = width//d, height//d # dimensions of cells (squares) in output
    kernel = [[i + n*iw for i in xrange(N)] for n in xrange(N)] # NxN matrix to read every patterns contained in input image
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # (x, y) tuples to access the 4 neighboring cells of a collapsed cell
    all = [] # array list to store all the patterns found in input



    # Stores the different patterns found in input
    for y in xrange(ih):
        for x in xrange(iw):

            ''' The one-liner below (cmat) creates a NxN matrix with (x, y) being its top left corner.
                This matrix will wrap around the edges of the input image.
                The whole snippet reads every NxN part of the input image and store the associated colors.
                Each NxN part is called a 'pattern' (of colors). Each pattern can be rotated or flipped (not mandatory). '''


            cmat = [[img.pixels[((x+n)%iw)+(((a[0]+iw*y)/iw)%ih)*iw] for n in a] for a in kernel]

            # Storing rotated patterns (90°, 180°, 270°, 360°) 
            for r in xrange(4):
                cmat = zip(*cmat[::-1]) # +90° rotation
                all.append(cmat) 

            # Storing reflected patterns (vertical/horizontal flip)
            all.append(cmat[::-1])
            all.append([a[::-1] for a in cmat])




    # Flatten pattern matrices + count occurences 

    ''' Once every pattern has been stored,
        - we flatten them (convert to 1D) for convenience
        - count the number of occurences for each one of them (one pattern can be found multiple times in input)
        - select unique patterns only
        - store them from less common to most common (needed for weighted choice)'''

    all = [tuple(chain.from_iterable(p)) for p in all] # flattern pattern matrices (NxN --> [])
    c = Counter(all)
    freqs = sorted(c.values()) # number of occurences for each unique pattern, in sorted order
    npat = len(freqs) # number of unique patterns
    total = sum(freqs) # sum of frequencies of unique patterns
    patterns = [p[0] for p in c.most_common()[:-npat-1:-1]] # list of unique patterns sorted from less common to most common



    # Computes entropy

    ''' The entropy of a cell is correlated to the number of possible patterns that cell holds.
        The more a cell has valid patterns (set to 'True'), the higher its entropy is.
        At start, every pattern is set to 'True' for each cell. So each cell holds the same high entropy value'''

    ent = math.log(total) - sum(map(lambda x: x * math.log(x), freqs)) / total



    # Initializes the 'wave' (W), entropy (H) and adjacencies (A) array lists

    W = [[True for _ in xrange(npat)] for i in xrange(d*d)] # every pattern is set to 'True' at start, for each cell
    H = [ent for i in xrange(d*d)] # same entropy for each cell at start (every pattern is valid)
    A = [[set() for dir in xrange(len(directions))] for i in xrange(npat)] #see below for explanation




    # Compute patterns compatibilities (check if some patterns are adjacent, if so -> store them based on their location)

    ''' EXAMPLE:
    If pattern index 42 can placed to the right of pattern index 120,
    we will store this adjacency rule as follow:

                     A[120][1].add(42)

    Here '1' stands for 'right' or 'East'/'E'

    0 = left or West/W
    1 = right or East/E
    2 = up or North/N
    3 = down or South/S '''

    # Comparing patterns to each other
    for i1 in xrange(npat):
        for i2 in xrange(npat):
            for dir in (0, 2):
                if compatible(patterns[i1], patterns[i2], dir):
                    A[i1][dir].add(i2)
                    A[i2][dir+1].add(i1)


def compatible(p1, p2, dir):

    '''NOTE: 
    what is refered as 'columns' and 'rows' here below is not really columns and rows 
    since we are dealing with 1D patterns. Remember here N = 3'''

    # If the first two columns of pattern 1 == the last two columns of pattern 2 
    # --> pattern 2 can be placed to the left (0) of pattern 1
    if dir == 0:
        return [n for i, n in enumerate(p1) if i%N!=2] == [n for i, n in enumerate(p2) if i%N!=0]

    # If the first two rows of pattern 1 == the last two rows of pattern 2
    # --> pattern 2 can be placed on top (2) of pattern 1
    if dir == 2:
        return p1[:6] == p2[-6:]



def draw():    # Equivalent of a 'while' loop in Processing (all the code below will be looped over and over until all cells are collapsed)
    global H, W, grid

    ### OBSERVATION
    # Find cell with minimum non-zero entropy (not collapsed yet)

    '''Randomly select 1 cell at the first iteration (when all entropies are equal), 
       otherwise select cell with minimum non-zero entropy'''

    emin = int(random(d*d)) if frameCount <= 1 else H.index(min(H)) 



    # Stoping mechanism

    ''' When 'H' array is full of 'collapsed' cells --> stop iteration '''

    if H[emin] == 'CONT' or H[emin] == 'collapsed': 
        print 'stopped'
        noLoop()
        return



    ### COLLAPSE
    # Weighted choice of a pattern

    ''' Among the patterns available in the selected cell (the one with min entropy), 
        select one pattern randomly, weighted by the frequency that pattern appears in the input image.
        With Python 2.7 no possibility to use random.choice(x, weight) so we have to hard code the weighted choice '''

    lfreqs = [b * freqs[i] for i, b in enumerate(W[emin])] # frequencies of the patterns available in the selected cell
    weights = [float(f) / sum(lfreqs) for f in lfreqs] # normalizing these frequencies
    cumsum = [sum(weights[:i]) for i in xrange(1, len(weights)+1)] # cumulative sums of normalized frequencies
    r = random(1)
    idP = sum([cs < r for cs in cumsum])  # index of selected pattern 

    # Set all patterns to False except for the one that has been chosen   
    W[emin] = [0 if i != idP else 1 for i, b in enumerate(W[emin])]

    # Marking selected cell as 'collapsed' in H (array of entropies)
    H[emin] = 'collapsed' 

    # Storing first color (top left corner) of the selected pattern at the location of the collapsed cell
    Output[emin] = patterns[idP][0]



    ### PROPAGATION
    # For each neighbor (left, right, up, down) of the recently collapsed cell
    for dir, t in enumerate(directions):
        x = (emin%d + t[0])%d
        y = (emin/d + t[1])%d
        idN = x + y * d #index of neighbor

        # If that neighbor hasn't been collapsed yet
        if H[idN] != 'collapsed': 

            # Check indices of all available patterns in that neighboring cell
            available = [i for i, b in enumerate(W[idN]) if b]

            # Among these indices, select indices of patterns that can be adjacent to the collapsed cell at this location
            intersection = A[idP][dir] & set(available) 

            # If the neighboring cell contains indices of patterns that can be adjacent to the collapsed cell
            if intersection:

                # Remove indices of all other patterns that cannot be adjacent to the collapsed cell
                W[idN] = [True if i in list(intersection) else False for i in xrange(npat)]


                ### Update entropy of that neighboring cell accordingly (less patterns = lower entropy)

                # If only 1 pattern available left, no need to compute entropy because entropy is necessarily 0
                if len(intersection) == 1: 
                    H[idN] = '0' # Putting a str at this location in 'H' (array of entropies) so that it doesn't return 0 (float) when looking for minimum entropy (min(H)) at next iteration


                # If more than 1 pattern available left --> compute/update entropy + add noise (to prevent cells to share the same minimum entropy value)
                else:
                    lfreqs = [b * f for b, f in izip(W[idN], freqs) if b] 
                    ent = math.log(sum(lfreqs)) - sum(map(lambda x: x * math.log(x), lfreqs)) / sum(lfreqs)
                    H[idN] = ent + random(.001)


            # If no index of adjacent pattern in the list of pattern indices of the neighboring cell
            # --> mark cell as a 'contradiction'
            else:
                H[idN] = 'CONT'



    # Draw output

    ''' dxd grid of cells (squares) filled with their corresponding color.      
        That color is the first (top-left) color of the pattern assigned to that cell '''

    for i, c in enumerate(Output):
        x, y = i%d, i/d
        fill(c)
        rect(x * xs, y * ys, xs, ys)

        # Displaying corresponding entropy value
        fill(0)
        text(H[i], x * xs + xs/2 - 12, y * ys + ys/2)

问题

尽管我尽了最大努力将上述所有步骤仔细地写入代码,但此实现 returns 非常奇怪且令人失望的结果:

20x20 输出示例

模式分布和邻接约束似乎都得到尊重(与输入中相同数量的蓝色、绿色、黄色和棕色以及相同的种类 个图案:水平地面,绿色茎)。

但是这些模式:

关于最后一点,我应该澄清的是,矛盾的状态是正常的,但应该很少发生(如 this paper and in this 文章第 6 页中间所述)

数小时的调试让我确信介绍性步骤(1 到 5)是正确的(计算和存储模式、邻接和熵计算、数组初始化)。 这让我认为 某些东西 一定与算法的核心部分(步骤 6 到 8) 脱节了。要么我错误地执行了这些步骤之一,要么我遗漏了逻辑的关键元素。

因此,我们将不胜感激有关此事的任何帮助!

此外,欢迎任何基于提供的脚本(使用或不使用处理)的答案

有用的额外资源:

详细 article from Stephen Sherratt and this explanatory paper 来自 Karth & Smith。 此外,为了比较,我建议检查其他 Python implementation(包含非强制性的回溯机制)。

注意:我已尽力使这个问题尽可能清楚(使用 GIF 和插图进行全面解释,带有有用链接和资源的完整注释代码)但是如果由于某些原因您决定否决它,请离开一个简短的评论来解释你为什么这样做。

在查看您的一个示例中链接的 live demo 并基于对原始算法代码的快速审查后,我认为您的错误在于 "Propagation" 步骤。

传播不仅仅是将相邻的 4 个单元格更新为折叠的单元格。您还必须递归地更新所有这些单元格的邻居,然后更新这些单元格的邻居等。好吧,具体来说,一旦你更新了一个相邻的单元格,你就会更新它的邻居(在到达第一个单元格的其他邻居之前),即 depth-first,而不是 breadth-first 更新。至少,这是我从现场演示中收集到的信息。

原算法的实际C#代码实现比较复杂,我也没有完全理解,但关键点似乎是"propagator"对象here, as well as the Propagate function itself, here的创建。

@mbrig 和@Leon 建议的假设是传播步骤遍历整个单元格堆栈(而不是限于一组 4 个直接邻居)是正确的。以下是在回答我自己的问题时尝试提供更多细节。

问题发生在第 7 步,传播时。原始算法确实更新了特定小区的 4 个直接邻居但是:

  • 该特定单元格的索引 依次被先前更新的邻居的索引替换
  • 每次折叠单元格时都会触发此级联过程
  • 最后 只要特定单元格的相邻模式在其相邻单元格的 1 个中可用

换句话说,正如评论中提到的,这是一种 递归 类型的传播,它不仅更新折叠单元格的邻居,还更新折叠单元格的邻居邻居...等等,只要邻接是可能的。

详细算法

折叠单元格后,其索引将放入堆栈中。该堆栈稍后用于临时存储相邻单元格的索引

stack = set([emin]) #emin = index of cell with minimum entropy that has been collapsed

只要堆栈中充满索引,传播就会持续:

while stack:

我们做的第一件事是 pop() 堆栈中包含的最后一个索引(目前唯一的一个)并获取其 4 个相邻单元格(E、W、N、S)的索引。我们必须让它们保持在边界内并确保它们环绕。

while stack:
    idC = stack.pop() # index of current cell
    for dir, t in enumerate(mat):
        x = (idC%w + t[0])%w
        y = (idC/w + t[1])%h
        idN = x + y * w  # index of neighboring cell

在继续之前,我们确保相邻单元格尚未折叠(我们不想更新只有 1 个可用模式的单元格):

        if H[idN] != 'c': 

然后我们检查所有可以被放置的模式。例如:如果相邻单元格在当前单元格的左侧(东侧),我们查看所有可以放置在当前单元格中包含的每个图案左侧的图案。

            possible = set([n for idP in W[idC] for n in A[idP][dir]])

我们还查看相邻单元格中可用的模式:

            available = W[idN]

现在我们确保相邻单元格真的必须更新。如果它的所有可用模式都已经在所有可能模式列表中 —> 就没有必要更新它(算法跳过这个邻居并继续下一个):

            if not available.issubset(possible):

但是,如果它不是 possible 列表的子集 —> 我们看一下 交集两组(所有可以放置在那个位置的图案,"luckily",在同一位置可用):

                intersection = possible & available

如果它们不相交(可以放置在那里但不可用的图案),则意味着我们 运行 变成 "contradiction"。我们必须停止整个 WFC 算法。

                if not intersection:
                    print 'contradiction'
                    noLoop()

相反,如果它们确实相交 --> 我们用改进后的模式索引列表更新相邻单元格:

                W[idN] = intersection

由于相邻单元格已更新,因此其熵也必须更新:

                lfreqs = [freqs[i] for i in W[idN]]
                H[idN] = (log(sum(lfreqs)) - sum(map(lambda x: x * log(x), lfreqs)) / sum(lfreqs)) - random(.001)

最后,也是最重要的,我们将该相邻单元格的索引添加到堆栈中,使其依次成为下一个 当前 单元格(其邻居将在期间更新的单元格)下一个 while 循环):

                stack.add(idN)

完整更新脚本

from collections import Counter
from itertools import chain
from random import choice

w, h = 40, 25
N = 3

def setup():
    size(w*20, h*20, P2D)
    background('#FFFFFF')
    frameRate(1000)
    noStroke()

    global W, A, H, patterns, freqs, npat, mat, xs, ys

    img = loadImage('Flowers.png') 
    iw, ih = img.width, img.height
    xs, ys = width//w, height//h
    kernel = [[i + n*iw for i in xrange(N)] for n in xrange(N)]
    mat = ((-1, 0), (1, 0), (0, -1), (0, 1))
    all = []

    for y in xrange(ih):
        for x in xrange(iw):
            cmat = [[img.pixels[((x+n)%iw)+(((a[0]+iw*y)/iw)%ih)*iw] for n in a] for a in kernel]
            for r in xrange(4):
                cmat = zip(*cmat[::-1])
                all.append(cmat)
                all.append(cmat[::-1])
                all.append([a[::-1] for a in cmat])

    all = [tuple(chain.from_iterable(p)) for p in all] 
    c = Counter(all)
    patterns = c.keys()
    freqs = c.values()
    npat = len(freqs) 

    W = [set(range(npat)) for i in xrange(w*h)] 
    A = [[set() for dir in xrange(len(mat))] for i in xrange(npat)]
    H = [100 for i in xrange(w*h)] 

    for i1 in xrange(npat):
        for i2 in xrange(npat):
            if [n for i, n in enumerate(patterns[i1]) if i%N!=(N-1)] == [n for i, n in enumerate(patterns[i2]) if i%N!=0]:
                A[i1][0].add(i2)
                A[i2][1].add(i1)
            if patterns[i1][:(N*N)-N] == patterns[i2][N:]:
                A[i1][2].add(i2)
                A[i2][3].add(i1)


def draw():    
    global H, W

    emin = int(random(w*h)) if frameCount <= 1 else H.index(min(H)) 

    if H[emin] == 'c': 
        print 'finished'
        noLoop()

    id = choice([idP for idP in W[emin] for i in xrange(freqs[idP])])
    W[emin] = [id]
    H[emin] = 'c' 

    stack = set([emin])
    while stack:
        idC = stack.pop() 
        for dir, t in enumerate(mat):
            x = (idC%w + t[0])%w
            y = (idC/w + t[1])%h
            idN = x + y * w 
            if H[idN] != 'c': 
                possible = set([n for idP in W[idC] for n in A[idP][dir]])
                if not W[idN].issubset(possible):
                    intersection = possible & W[idN] 
                    if not intersection:
                        print 'contradiction'
                        noLoop()
                        return

                    W[idN] = intersection
                    lfreqs = [freqs[i] for i in W[idN]]
                    H[idN] = (log(sum(lfreqs)) - sum(map(lambda x: x * log(x), lfreqs)) / sum(lfreqs)) - random(.001)
                    stack.add(idN)

    fill(patterns[id][0])
    rect((emin%w) * xs, (emin/w) * ys, xs, ys)

整体改进

除了这些修复之外,我还对 speed-up 观察和传播步骤进行了一些小的代码优化,并缩短了加权选择计算。

  • "Wave" 现在由 Python 指数 组成,其大小随着单元格 "collapsed"(替换大型固定大小的布尔值列表)而减少。

  • 熵存储在一个 defaultdict 中,其键被逐步删除。

  • 起始熵值被替换为 运行dom 整数(不需要第一次熵计算,因为开始时等概率的高度不确定性)

  • 单元格显示一次(避免将它们存储在数组中并在每一帧重新绘制)

  • 加权选择现在是one-liner(避免了列表理解的几行可有可无的行)