在 pycosat 中将 dnf 慢到 cnf

Slow dnf to cnf in pycosat

简答题

要为 pycosat 提供正确的输入,有没有办法加快从 dnf 到 cnf 的计算,或者完全绕过它?

问题详细

我一直在看this video from Raymond Hettinger about modern solvers. I downloaded the code, and implemented a solver for the game Towers里面的。下面我分享了这样做的代码。

塔谜题示例(已解决):

    3 3 2 1    
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2    

我遇到的问题是从dnf到cnf的转换需要很长时间。假设您知道从特定视线可以看到 3 座塔。这导致该行有 35 种可能的排列 1-5。

[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
 ('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
 ...
 ('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
 ('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]

这是一个析取范式:几个 AND 语句的 OR。这需要转换为合取范式:几个 OR 语句的 AND。然而,这非常慢。在我的 Macbook Pro 上,单行 5 分钟后它没有完成此 cnf 的计算。对于整个拼图,最多应完成 20 次(对于 5x5 网格)。

优化此代码以使计算机能够解决此 Towers 难题的最佳方法是什么?

此代码也可从 this Github repository 获得。

import string

import itertools
from sys import intern
from typing import Collection, Dict, List

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table

    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """

    return intern(f'{point} {value}')


def visible_from_line(line: Collection[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line

    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """

    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """

        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """

        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """

        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
            elems = [str(self.visible_from_left[index]) or ' ', '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(self.visible_from_right[index]) or ' ']
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()

实际花费在这个辅助函数上的时间来自视频附带的辅助代码。

def from_dnf(groups) -> 'cnf':
    'Convert from or-of-ands to and-of-ors'
    cnf = {frozenset()}
    for group_index, group in enumerate(groups, start=1):
        print(f'Group {group_index}/{len(groups)}')
        nl = {frozenset([literal]): neg(literal) for literal in group}
        # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
        # The nl check skips over identities: {x, ~x, y} -> True
        cnf = {clause | literal for literal in nl for clause in cnf
               if nl[literal] not in clause}
        # The sc check removes clauses with superfluous terms:
        #     {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
        # Should this be left until the end?
        sc = min(cnf, key=len)  # XXX not deterministic
        cnf -= {clause for clause in cnf if clause > sc}
    return list(map(tuple, cnf))

使用 4x4 网格时 pyinstrument 的输出显示这里的行 cnf = { ... } 是罪魁祸首:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:05:58  Samples:  146
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.515     CPU time: 0.506
/   _/                      v3.4.2

Program: ./src/towers.py

0.515 <module>  ../<string>:1
   [7 frames hidden]  .., runpy
      0.513 _run_code  runpy.py:62
      └─ 0.513 <module>  towers.py:1
         ├─ 0.501 display_solution  towers.py:64
         │  └─ 0.501 solution  towers.py:188
         │     ├─ 0.408 cnf  towers.py:101
         │     │  ├─ 0.397 from_dnf  sat_utils.py:65
         │     │  │  ├─ 0.329 <setcomp>  sat_utils.py:73
         │     │  │  ├─ 0.029 [self]
         │     │  │  ├─ 0.021 min  ../<built-in>:0
         │     │  │  │     [2 frames hidden]  ..
         │     │  │  └─ 0.016 <setcomp>  sat_utils.py:79
         │     │  └─ 0.009 [self]
         │     └─ 0.093 solve_one  sat_utils.py:53
         │        └─ 0.091 itersolve  sat_utils.py:43
         │           ├─ 0.064 translate  sat_utils.py:32
         │           │  ├─ 0.049 <listcomp>  sat_utils.py:39
         │           │  │  ├─ 0.028 [self]
         │           │  │  └─ 0.021 <listcomp>  sat_utils.py:39
         │           │  └─ 0.015 make_translate  sat_utils.py:12
         │           └─ 0.024 itersolve  ../<built-in>:0
         │                 [2 frames hidden]  ..
         └─ 0.009 <module>  typing.py:1
               [26 frames hidden]  typing, abc, ..

首先,最好注意等价性和等可满足性之间的区别。通常,将任意布尔公式(例如 DNF 中的某些内容)转换为 CNF 会导致大小呈指数 blow-up。

这个 blow-up 是您的 from_dnf 方法的问题:每当您处理另一个产品术语时,该产品中的每个 文字都需要一个新的当前 cnf 子句集的副本(它将在每个子句中添加自己)。如果您有 n 个大小为 k 的乘积项,则增长为 O(k^n).

在你的例子中 n 实际上是 k! 的函数。保留为乘积项的内容被过滤为满足视图约束的那些,但总体而言,程序的运行时间大致在 O(k^f(k!)) 范围内。即使 f 呈对数增长,这仍然是 O(k^(k lg k)) 而不是很理想!

因为你问的是“这是可满足的吗?”,你不需要一个等价的公式,而只需要一个可满足的公式.这是一些新公式,当且仅当原始公式是可满足的,但相同的赋值可能不满足。

例如,(a ∨ b)(a ∨ c) ∧ (¬b)都是明显可满足的,所以它们是等可满足的。但是设置 b true 满足第一个并伪造第二个,所以它们不等价。此外,第一个甚至没有 c 作为变量,再次使其不等同于第二个。

这种松弛足以用 线性 大小的平移代替这个指数 blow-up。


关键思想是使用扩展变量。这些是允许我们缩写表达式的新变量(即,尚未出现在公式中),因此我们最终不会在翻译中制作它们的多个副本。由于原始变量中不存在新变量,因此我们将不再有等效公式;但是因为当且仅当表达式为真时,变量才为真,它将是可满足的。

如果我们想使用 x 作为 y 的缩写,我们会声明 x ≡ y。这与x → yy → x相同,与(¬x ∨ y) ∧ (¬y ∨ x)相同,后者已经在CNF中。

考虑产品术语的缩写:x ≡ (a ∧ b)。这是 x → (a ∧ b)(a ∧ b) → x,结果是三个子句:(¬x ∨ a) ∧ (¬x ∨ b) ∧ (¬a ∨ ¬b ∨ x)。通常,用 x 缩写 k 个文字的乘积项将产生 k 个二进制子句,表示 x 表示它们中的每一个,以及一个 (k+1) 子句表示它们一起表示 x。这是线性的 k.

要真正了解为什么这有帮助,请尝试将 (a ∧ b ∧ c) ∨ (d ∧ e ∧ f) ∨ (g ∧ h ∧ i) 转换为等效的 CNF,有和没有第一个乘积项的扩展变量。当然,我们不会只停留在一个术语上:如果我们缩写每个术语,那么结果恰好是单个 CNF 子句:(x ∨ y ∨ z),其中每个都缩写一个产品术语。这个小多了!

这种方法可用于将任何 电路 转化为一个等可满足的公式,在大小和 CNF 中是线性的。这称为 Tseitin transformation。你的 DNF 公式只是一个由一堆任意 fan-in 与门组成的电路,所有这些都馈入一个任意 fan-in 或门。

最重要的是,虽然这个公式由于附加变量而不等价,但我们可以通过简单地删除扩展变量来恢复原始公式的赋值。它是一种 'best case' 可满足公式,是原始公式的严格超集。


为了将其修补到您的代码中,我添加了:

# Uses pseudo-namespacing to avoid collisions.
_EXT_SUFFIX = "___"
_NEXT_EXT_INDEX = 0


def is_ext_var(element) -> bool:
    return element.endswith(_EXT_SUFFIX)


def ext_var() -> str:
    global _NEXT_EXT_INDEX
    ext_index = _NEXT_EXT_INDEX
    _NEXT_EXT_INDEX += 1

    return intern(f"{ext_index}{_EXT_SUFFIX}")

这让我们可以凭空提取一个新的命名变量。由于这些扩展变量名称对您的解决方案显示功能没有有意义的语义,因此我更改了:

point_to_value = {
    point: value for point, value in [fact.split() for fact in self.solution]
}

进入:

point_to_value = {
    point: value
    for point, value in [
        fact.split() for fact in self.solution if not is_ext_var(fact)
    ]
}

肯定有更好的方法来做到这一点,这只是一个补丁。 :)

用上述想法重新实现你的from_dnf,我们得到:

def from_dnf(groups) -> "cnf":
    "Convert from or-of-ands to and-of-ors, equisatisfiably"
    cnf = []

    extension_vars = []
    for group in groups:
        extension_var = ext_var()
        neg_extension_var = neg(extension_var)

        imply_ext_clause = []
        for literal in group:
            imply_ext_clause.append(neg(literal))
            cnf.append((neg_extension_var, literal))

        imply_ext_clause.append(extension_var)
        cnf.append(tuple(imply_ext_clause))

        extension_vars.append(extension_var)

    cnf.append(tuple(extension_vars))
    return cnf

每组获得一个扩展变量。组中的每个文字都将其否定添加到 (k+1) 大小的蕴含子句中,并由扩展隐含。处理文字后,扩展变量完成剩余的蕴涵并将其自身添加到新扩展变量列表中。最后,这些扩展变量中至少有一个必须为真。

仅此一项更改就可以让我立即解决这个 5x5 难题:

self.visible_from_top = [3, 2, 1, 4, 2]
self.visible_from_bottom = [2, 2, 4, 1, 2]
self.visible_from_left = [3, 2, 3, 1, 3]
self.visible_from_right = [2, 2, 1, 3, 2]
self.given_numbers = {}

我也添加了一些定时输出:

@property
def solution(self):
    if self._solution is None:
        start_time = time.perf_counter()

        cnf = self.cnf
        cnf_time = time.perf_counter()
        print(f"CNF: {cnf_time - start_time}s")

        self._solution = solve_one(cnf)
        end_time = time.perf_counter()
        print(f"Solve: {end_time - cnf_time}s")
    return self._solution

5x5 拼图给我:

CNF: 0.00565183162689209s
Solve: 0.005589433014392853s

但是,在枚举可行的塔高排列时,我们仍然遇到令人讨厌的 k! 增长。

我生成了a 9x9 puzzle(站点允许的最大),对应于:

self.visible_from_top = [3, 3, 3, 3, 1, 4, 2, 4, 2]
self.visible_from_bottom = [3, 1, 4, 2, 5, 3, 3, 2, 3]
self.visible_from_left = [3, 3, 1, 2, 4, 5, 2, 3, 2]
self.visible_from_right = [3, 1, 7, 4, 3, 3, 2, 2, 4]
self.given_numbers = {
    "AB": 5,
    "AD": 4,
    "BD": 3,
    "BE": 2,
    "CD": 7,
    "CF": 5,
    "CG": 1,
    "DB": 1,
    "DH": 7,
    "EA": 4,
    "EI": 2,
    "FA": 2,
    "FE": 8,
    "GG": 7,
    "GI": 6,
    "HA": 3,
    "HF": 2,
    "HH": 1,
    "IG": 6,
}

这给了我:

CNF: 28.505195066332817s
Solve: 40.48229135945439s

我们应该花更多的时间解决问题,减少生成的时间,但将近一半的时间用于生成。

在我看来,在 CNF-SAT 翻译中使用 DNF离子通常是错误方法的标志。求解器比我们更善于探索和了解解决方案 space — 花费阶乘的时间 pre-exploring 实际上比求解器的指数更差情况更糟。

对于DNF'fall back'是可以理解的,因为程序员天生就想着“写一个能给出解的算法”。但是,当您 将其编码到问题 中时,求解器的真正好处就会显现出来。让求解器推理解决方案变得不可行的条件。为此,我们要从电路的角度来思考。我们很幸运,我们也知道如何快速将电路变成 CNF。

†我说的是“经常”;如果您的 DNF 很小并且可以快速生成(如单个电路门),或者如果将其编码到电路中非常复杂,那么它可能有助于 pre-compute 某些解决方案 space。


您实际上已经完成了其中的一些工作!例如,我们需要一个电路来计算某个数字在一个跨度(行或列)中出现的次数,并断言这个数字正好是一个。然后对于每个跨度和每个数字,我们将发出此电路。这样,如果一个大小的塔,例如3 连续出现两次,该行的 3 计数器将发出“2”,我们断言它是“1”的说法将不成立。

您的 one_of 约束是此的 one possible 实现。你的使用 'obvious' 成对编码:对于跨度中的每个位置,如果 N 存在于该位置,则它不存在于任何其他位置。这实际上是一种非常好的编码,因为它几乎完全由二进制子句组成,而 SAT 求解器喜欢二进制子句(它们使用的内存少得多并且经常传播)。但是对于要计算的大量事物,这种 O(n^2) 缩放可能会成为一个问题。

你可以想象另一种方法,你直接编码一个 adder circuit:每个位置都是电路的输入位,电路产生 n 位输出告诉你最终的总和(上面的论文是一个很好的阅读!)。然后您使用强制特定输出位的单元子句断言这个总和正好是一个。

对电路进行编码只是为了强制其某些输出为恒定值似乎是多余的。然而,这更容易推理,现代求解器知道编码可以做到这一点并针对它进行优化。它们执行的 in-processing 比初始编码过程合理执行的要复杂得多。使用求解器的 'art' 在于了解和测试这些替代编码何时比其他编码更有效。

请注意 exactly_k_ofat_least_k_ofat_most_k_of。您已在 Q class == 实施中注意到这一点。实现 at_least_1_of 是微不足道的,是一个子句; at_most_1_of 很常见,所以通常简称为 AMO。我鼓励您尝试以本文中讨论的其他一些方式实现 <>(甚至可能根据输入大小选择使用哪个)以感受它。


将我们的注意力转回到 k! 可见性约束,我们需要的是一个电路,它告诉我们从某个方向可以看到多少座塔,然后我们可以断言它是一个特定的值。

停下来想一想这是怎么做到的,不容易啊!

类似于各种 one_of 方法,我们可以使用 'pure' 电路进行计数或使用更简单但 worse-scaling 成对方法。我在这个答案的最底部 (‡) 附上了纯电路方法的草图。现在我们将使用成对方法。

要进行的主要观察是在 non-visible 塔中,我们不关心它们的排列。考虑:

3 -> 1 5 _ _ _ 9 _ _ _
     A B C D E F G H I

只要 CDE 组包含 234,我们就会从左边看到 3 个塔,同样如果 GHI 组包含 678。但是它们在组中出现的顺序对可见塔没有影响。

而不是计算哪些塔是可见的,我们将声明哪些塔是可见的并遵循它们的含义。我们将填写此函数:

def visible_in_span(points: Collection[str], desired: int) -> "cnf":
    """Assert desired visible towers in span. Wlog, visibility is from index 0."""
    points = list(points)
    n = len(points)
    assert desired <= n

    cnf = []

    # ...

    return cnf

假设一个固定的跨度和观察方向:每个位置都有 k 个相关变量,Av1Avk 说明“这是第 k 个可见的塔”。我们还将有 Av ≡ (Av1 ∨ Av2 ∨ ⋯ ∨ Avk) 意思是“A 有一座可见的塔”。

在上面的例子中,Av1Bv2Fv3都为真。有一些明显的含义要发出。在一个位置,其中最多一个是真的(你不能同时是第一座和第二座可见的塔)——但不是恰好一个,因为有一个non-visible塔。另一个是,如果一个位置是第k个可见塔,那么没有其他位置也是第k个可见塔。

到目前为止我们可以添加:

is_kth_visible_tower_at = {}
is_kth_visible_tower_vars = collections.defaultdict(list)
is_visible_tower_at = {}
for point in points:
    is_visible_tower_vars = []
    for k in range(1, n + 1):
        # Xvk
        is_kth_visible_tower_var = ext_var()

        is_kth_visible_tower_at[(point, k)] = is_kth_visible_tower_var
        is_kth_visible_tower_vars[k].append(is_kth_visible_tower_var)
        is_visible_tower_vars.append(is_kth_visible_tower_var)

    # Xv
    is_visible_tower_at_var = ext_var()
    # Xv → (Xv1 ∨ Xv2 ∨ ⋯)
    cnf.append(tuple([neg(is_visible_tower_at_var)] + is_visible_tower_vars))
    # (Xv1 ∨ Xv2 ∨ ⋯) → Xv
    for is_visible_tower_var in is_visible_tower_vars:
        cnf.append((neg(is_visible_tower_var), is_visible_tower_at_var))

    is_visible_tower_at[point] = is_visible_tower_at_var

    # At most one visible tower here.
    cnf += Q(is_visible_tower_vars) <= 1

# At most one kth visible tower anywhere.
for k in range(1, n + 1):
    cnf += Q(is_kth_visible_tower_vars[k]) <= 1

接下来我们需要订购ig 在可见塔中,因此第 k + 1 个可见塔位于第 k 个可见塔之后。这是通过第 k+1 个可见塔迫使至少一个先前位置成为第 k 个可见塔来实现的。例如,Dv3 → (Av2 ∨ Bv2 ∨ Cv2)Cv2 → (Av1 ∨ Bv1)。我们知道 Av1 始终为真,这提供了基本情况。 (如果我们进入需要 B 是第三个可见塔的情况,这将要求 A 是第二个可见塔,这与 Av1 相矛盾。)

# Towers are ordered.
for index, point in enumerate(points):
    if index == 0:
        cnf += basic_fact(is_kth_visible_tower_at[(point, 1)])
        continue

    for k in range(1, n + 1):
        # Xvk → ⋯
        implication = [neg(is_kth_visible_tower_at[(point, k)])]

        j = k - 1
        if j > 0:
            for index_j, point_j in enumerate(points):
                if index_j == index:
                    break

                # ⋯ ∨ Wxj ∨ ⋯
                implication.append(is_kth_visible_tower_at[(point_j, j)])

        cnf.append(tuple(implication))

目前一切顺利,但我们尚未将塔高与能见度相关联。以上将允许 9 8 7 作为解决方案,调用 9 第一个可见塔, 8 第二个, 7 第三个。为了解决这个问题,我们需要一个塔放置来禁止较小的塔也可见。

每个位置将再次收到一组缩写,指示它是否在特定高度以下被遮挡,称为 Ao1Ao2,等等。这会给我们带来 'grid' 的启示,让事情变得更简单。第一个是较高的塔被遮挡意味着同一位置的下一个最高塔也被遮挡,因此 Ao3 → Ao2Ao2 → Ao1。第二个是,如果一座塔在一个位置被遮挡,那么它在以后的所有位置也会被遮挡。这是 Ao3 → Bo3Bo3 → Co3 等等。

is_height_obscured_at = {}
is_height_obscured_previous = [None] * n
for point in points:
    is_obscured_previous = None
    for k in range(1, n + 1):
        # Xok
        is_height_obscured_var = ext_var()

        # Wok → Xok
        is_k_obscured_previous = is_height_obscured_previous[k - 1]
        if is_k_obscured_previous is not None:
            cnf.append((neg(is_k_obscured_previous), is_height_obscured_var))

        # Xok → Xo(k-1)
        if is_obscured_previous is not None:
            cnf.append((neg(is_height_obscured_var), is_obscured_previous))

        is_height_obscured_at[(point, k)] = is_height_obscured_var
        is_height_obscured_previous[k - 1] = is_height_obscured_var
        is_obscured_previous = is_height_obscured_var

从这里很容易看出,例如Bo4 表示其余高度等于或小于 4 的塔全部被遮挡。我们现在可以很容易地将塔的放置与默默无闻联系起来:A5 → Bo4.

# A placed tower obscures smaller later towers.
for index, point in enumerate(points):
    if index + 1 == len(points):
        break

    next_point = points[index + 1]
    for k in range(2, n + 1):
        j = k - 1

        # Xk → Yo(k-1)
        cnf.append((neg(comb(point, k)), is_height_obscured_at[(next_point, j)]))

最后,我们需要将模糊性与可见性联系起来。我们需要最后一组缩写,说明在某个位置可以看到 特定的 塔高。冒着容易出现拼写错误的风险,我们将某些高度 h 称为 Ahv,因此 Ahv ≡ (Ah ∧ Av)。一个具体的例子是 C3v ≡ (C3 ∧ Cv):当且仅当在 C 处有一座塔可见,并且该塔是高度为 3 的塔时,在 C 处可见高度为 3 的塔。

is_height_visible_at = {}
for point in points:
    for k in range(1, n + 1):
        # Xhv
        height_visible_at_var = ext_var()

        # Xhv ≡ (Xh ∧ Xv)
        cnf.append((neg(height_visible_at_var), comb(point, k)))
        cnf.append((neg(height_visible_at_var), is_visible_tower_at[point]))
        cnf.append(
            (
                neg(comb(point, k)),
                neg(is_visible_tower_at[point]),
                height_visible_at_var,
            )
        )

        is_height_visible_at[(point, k)] = height_visible_at_var

这使我们能够发出对塔放置的最终影响。如果高度为 h 的塔被遮挡,则它不可见:Bo4 → ¬B4v。这是不是等价,我们不能把Bo4 ≡ ¬B4v;可能 ¬B4v 成立是因为 B4 根本没有放在那里(但它是可见的!)。

for point in points:
    for k in range(1, n + 1):
        # Xok → ¬Xkv
        cnf.append(
            (
                neg(is_height_obscured_at[(point, k)]),
                neg(is_height_visible_at[(point, k)]),
            )
        )

为了将此与 puzzle-specific 可见性值联系起来,我们只需要禁止太多可见的塔并确保所需的数量至少可见一个(因此恰好一次):

# At least one of the towers is the desired kth visible.
cnf.append(tuple(is_kth_visible_tower_vars[desired]))

# None of the towers can be visible above the desired kth.
if desired < n:
    for is_kth_visible_tower_var in is_kth_visible_tower_vars[desired + 1]:
        cnf += basic_fact(neg(is_kth_visible_tower_var))

return cnf

我们只需要挡住第一层不需要的第k个可见塔。由于第 k+1 层将暗示存在第 k 层可见塔,因此它也被排除在外。 (等等。)

最后,我们将其挂接到 CNF 构建器中:

# Set visible from left
if self.visible_from_left:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_left[index]
        if not target_visible:
            continue

        cnf += visible_in_span(row, target_visible)

# Set visible from right
if self.visible_from_right:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_right[index]
        if not target_visible:
            continue

        cnf += visible_in_span(reversed(row), target_visible)

# Set visible from top
if self.visible_from_top:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_top[index]
        if not target_visible:
            continue

        cnf += visible_in_span(col, target_visible)

# Set visible from bottom
if self.visible_from_bottom:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_bottom[index]
        if not target_visible:
            continue

        cnf += visible_in_span(reversed(col), target_visible)

上面更快地给出了 9x9 的解决方案:

CNF: 0.028973951935768127s
Solve: 0.07169117406010628s

大约快了 685 倍,求解器完成了更多的整体工作。又快又脏还不错!

有很多方法可以清理它。例如,为了可读性,我们看到的每个地方 cnf.append((neg(a), b)) 都可以是 cnf += implies(a, b)。我们可以避免分配毫无意义的大的第 k 个可见变量,等等。

这不是well-tested;我可能错过了一些含义或规则。希望此时修复起来很容易。


最后要说的是SAT的适用性。也许现在痛苦地清楚了,SAT 求解器并不擅长计数和算术。你必须降低到一个电路,从求解过程中隐藏 higher-level 语义。

其他方法可以让您自然地表达算术、区间、集合等。答案集编程 (ASP) 就是一个例子,SMT 求解器是另一个例子。对于小问题 SAT 很好,但是对于困难的问题,这些 higher-level 方法可以大大简化问题。

这些人中的每一个实际上可能在内部决定通过 SAT-solving(特别是 SMT)进行推理,但他们将在对问题有一定 higher-level 理解的情况下这样做。


‡ 这是计算塔的纯电路方法。

是否优于pairwise将取决于被计算的塔数;也许常数因子太高以至于永远没有用,或者即使在小尺寸时它也很有用。老实说,我不知道——我以前编写过巨大的电路并且让它们工作得很好。需要实验才知道。

我将调用 Ah 位置 A 中塔的整数高度。也就是说,而不是 A1 或 [=114= 的 one-hot 编码] 或 … 或 A9 我们将 Ah0Ah1、… 和 Ahn 作为 n-bit 整数的低位到高位(统称为 Ah).对于 9x9 的限制,4 位就足够了。我们还将有 BhCh 等。

您可以使用A1 ≡ (¬Ah3 ∧ ¬Ah2 ∧ ¬Ah1 ∧ Ah0)A2 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ ¬Ah0)A3 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ Ah0)等连接两个表示。当且仅当设置了 A3 时,我们才有 Ah = 3。 (我们不需要添加只有一个值的约束Ah 一次是可能的,因为与每个关联的 one-hot 变量已经这样做了。)

有了整数,可能更容易了解如何计算可见性。我们可以将每个位置与最大可见塔高度相关联,命名为 AmBm,依此类推;显然第一座塔总是可见的,而且是最高的,所以 Am ≡ Ah。同样,这实际上是一个 n-bit 值 Am0Amn.

当且仅当塔的值大于先前看到的最高值时,塔才可见。我们将使用 AvBv 等来跟踪可见性。这可以通过 digital comparator;所以Bv ≡ Bh > Am。 (Av 是一个基本案例,并且总是正确的。)

这让我们也可以填写其余的最大值。 Bm ≡ Bv ? Bh : Am,等等。 conditional/if-then-else/ite是一个数字多路复用器。对于简单的 2 比 1,这很简单:Bv ? Bh : Am(Bv ∧ Bh) ∨ (¬Bv ∧ Am),每个 i ∈ 0..n.

实际上是 (Bv ∧ Bhi) ∨ (¬Bv ∧ Ami)

然后,我们将有一堆输入 AvIv 进入加法器电路,告诉我们这些输入中有多少是真的(即有多少塔是可见的)。这将是另一个 n-bit 值;然后我们使用单元子句来断言它正是例如3,如果特定的谜题需要 3 个可见的塔。

我们为每个方向的每个跨度生成相同的电路。这将是规则的一些 polynomial-sized 编码,添加许多扩展变量和许多子句。求解器可以了解某个塔放置不可行,不是因为我们这么说,而是因为它暗示了一些不可接受的中间可见性。 “应该有 4 个可见,而 2 个已经可见,所以我只剩下...”。