How do I find the best encoding of a string from a table of text strings in pseudopolynomial time?

Question:

Consider the following data compression technique. We have a table of m text strings, each at most k in length. We want to encode a data string D of length n using as few text strings as possible. For example, if our table contains (a,ba,abab,b) and the data string is bababbaababa, the best way to encode it is (b,abab,ba,abab,a)—a total of five code words. Give an O(nmk) algorithm to find the length of the best encoding. You may assume that every string has at least one encoding in terms of the table.

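I found this problem in Skiena's The Algorithm Design Manual and tried to solve it.

My best guess is to match every possible substring of the string D against every text string in the table (m strings), for a time complexity of O(n*n*k*m).

Is there a better way to solve this in O(nmk) pseudopolynomial time?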

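Correct me if I've misunderstood, but it sounds like your solution's time complexity is higher than O(nnkm).

Generating a list of every possible partition of D takes at least O(nn). Then, for each piece in each partition (there are at most n pieces), you need to check whether it matches an entry in the table, which costs O(km) per piece. With n*n possible partitions (in reality far more than n*n: the Bell number of n if you count arbitrary set partitions, or 2^(n-1) if you count only contiguous splits) and up to n pieces to match in each, the total is at least O(nnnmk).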
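To get a feel for how quickly exhaustive splitting blows up, here is a small Python sketch that enumerates every way to cut a string into contiguous pieces. The helper name all_splits is made up for illustration; a string of length n has n-1 possible cut points, so there are 2^(n-1) splits.

```python
from itertools import combinations

def all_splits(s):
    """Yield every way to cut s into contiguous, non-empty pieces."""
    n = len(s)
    for r in range(n):  # r = number of cut points chosen
        for cuts in combinations(range(1, n), r):
            bounds = (0, *cuts, n)
            yield [s[a:b] for a, b in zip(bounds, bounds[1:])]

splits = list(all_splits("abab"))
print(len(splits))  # 2 ** (4 - 1) = 8
```

Checking each of those splits against the table is what makes the brute-force approach impractical for long strings.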

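That said, I think building a tree is a faster approach.

Starting at the beginning of D, try to match every string in the table. Each table string represents a branch; if a table string does not match D at the current position, that branch terminates. If it does match, the process repeats from the end of the table string that was just matched.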

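The trick with this approach is that you only need to keep one sequence of matches per unique length reached. Repeating the process breadth-first means that once some sequence of table strings has reached a particular length, any sequence that reaches that length again is guaranteed to use the same number of table strings or more while matching the same prefix of D. For example, with the table strings {abab, ab}, both 'abab' and 'ab'+'ab' produce the length-4 string 'abab', but one matches on the first iteration and the other on the second. An encoding containing 'ab'+'ab' is never better than the same encoding with 'abab' in its place, so we discard the 'ab'+'ab' branch. This can be checked with a simple O(1) lookup table, and any branch that reaches a previously seen length can be terminated. So at most n unique lengths can be discovered, each only once, which means there are at most n possible branches.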

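Every string in the table still has to be checked at each branch, which is O(mk). Multiplying by the number of branches, I believe this gives O(nmk) overall, since at most n branches are created.

For example, with the table {a, aa} on the string 'aaaaaa', each iteration produces 2 branches, over 3 iterations. Each branch takes O(mk) to check, so the time complexity for this example is O(nmk).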
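The breadth-first idea described above can be sketched compactly in Python. This is only an illustration of the approach under my own naming (min_codewords, seen, and frontier are not from the original), returning just the number of code words:

```python
def min_codewords(data, table):
    """Breadth-first search over end positions in data.

    Returns the fewest table strings that concatenate to data,
    or None if data cannot be encoded.
    """
    n = len(data)
    if n == 0:
        return 0
    seen = [False] * (n + 1)  # seen[p]: position p was already reached
    seen[0] = True
    frontier = [0]            # positions reached in the previous round
    rounds = 0
    while frontier:
        rounds += 1
        next_frontier = []
        for pos in frontier:                      # at most n positions overall
            for word in table:                    # m table strings
                end = pos + len(word)
                if end > n or seen[end]:
                    continue                      # too long, or length already found
                if data.startswith(word, pos):    # up to k character checks
                    if end == n:
                        return rounds
                    seen[end] = True
                    next_frontier.append(end)
        frontier = next_frontier
    return None

print(min_codewords("aaaaaa", ["a", "aa"]))
print(min_codewords("bababbaababa", ["a", "ba", "abab", "b"]))
```

Because each position enters the frontier at most once, the two inner loops run at most n times in total, which is where the O(nmk) bound comes from.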

Edit: since there was some doubt about my proposed implementation, I implemented it in C and did some analysis.


#include <stdlib.h>
#include <string.h>
#include <stdio.h>

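static const char alphabet[] = "abcdef";
// sizeof counts the terminating NUL, hence the -1; a strlen() call is not a valid file-scope initializer in C
enum { alphabetLen = sizeof(alphabet) - 1 };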

void printTable(char** table, int k){
    printf("Table: {%s", table[0]);
    for(int i=1; i<k; i++){
        printf(", %s", table[i]);
    }
    printf("}\n");
}

char** generateTable(int k, int m){

    char** table = (char**)malloc(sizeof(char*)*k);
    for(int i=0; i<k; i++){
        int tableStringLen = (rand() % (m-1)) + 1;
        table[i] = (char*)malloc(tableStringLen+1);
        for(int j=0; j<tableStringLen; j++){
            table[i][j] = alphabet[rand() % alphabetLen];
        }
        table[i][tableStringLen] = 0;
    }

    return table;
}

// Unfortunately, the length of the string is partially determined by the size of m
// it's done this way to avoid generating a random string and ensuring the table can encode it
char* generateString(char** table, int k, int m){
    int minLength = rand()%200 + m*2;
    char* string = (char*)malloc(minLength + m);

    int j = 0;
    for(int i=0; j<minLength; i++){
        int tableStringIndex = rand() % k;
        strcpy(string+j, table[tableStringIndex]);
        j += strlen(table[tableStringIndex]);
    }
    string[j] = 0;

    return string;
}

void printSolution(char* string, int* branchLengths, int n){
    if(!n){ return; }
    printSolution(string, branchLengths, branchLengths[n]);
    printf("%.*s ", n-branchLengths[n], string+branchLengths[n]);
}

// Note: in this program k is the number of table strings and m bounds their length,
// the reverse of the problem statement's m and k
int* solve(char* string, char** table, int n, int m, int k, int* comparisons){
    int* currentBranchEnds = (int*)calloc(n, sizeof(int)); // keeps track of all of the current ends
    int currentBranchCount = 1; // keeps track of how many active branch ends there are

    int* newBranchEnds = (int*)calloc(n, sizeof(int)); // keeps track of the new branches on each iteration
    int newBranchCount = 0;

    // branchLengths[e] holds the start of the table string that ends at e, or -1 if e is unreached;
    // -1 is the sentinel because 0 is a valid start position
    int* branchLengths = (int*)malloc(sizeof(int) * (n+1)); // used to reconstruct the solution in the end
    for(int i=0; i<=n; i++){ branchLengths[i] = -1; }

    // used for O(1) length lookups
    int* tableStringLengths = (int*)malloc(sizeof(int) * k);
    for(int i=0; i<k; i++){ tableStringLengths[i] = strlen(table[i]); }

    *comparisons = 0;

    // continue searching while the entire string hasn't been encoded
    // (assumes the string is encodable, as the problem guarantees)
    while(branchLengths[n] < 0){

        // for every active branch
        for(int i=0; i<currentBranchCount; i++){

            // try all table strings
            for(int j=0; j<k; j++){
                int newBranchEnd = currentBranchEnds[i] + tableStringLengths[j];

                // if the new length (branch end) already exists OR
                // if the new branch would be too long for the string, discard the branch
                *comparisons += 1;
                if(newBranchEnd > n || branchLengths[newBranchEnd] >= 0){ continue; }

                // check to see if the table string matches the target string at this position
                // could be done with strcmp, but is done in a loop here to be explicit about complexity
                char match = 1;
                for(int l=0; table[j][l]; l++){
                    *comparisons += 1;
                    if(string[currentBranchEnds[i] + l] != table[j][l]){
                        match = 0;
                        break;
                    }
                }

                // if it matches, we can create a new branch at this position
                *comparisons += 1;
                if(match){
                    branchLengths[newBranchEnd] = currentBranchEnds[i];
                    newBranchEnds[newBranchCount] = newBranchEnd;
                    newBranchCount += 1;
                }
            }
        }

        // swap the branch ends arrays to save copying
        int* tmp = currentBranchEnds;
        currentBranchEnds = newBranchEnds;
        newBranchEnds = tmp;

        currentBranchCount = newBranchCount;
        newBranchCount = 0;
    }

    free(currentBranchEnds);
    free(newBranchEnds);
    free(tableStringLengths);

    return branchLengths;
}

int main(){
    int k = rand() % 30 + 2;
    int m = rand() % 15 + 2;

    char** table = generateTable(k, m);
    printTable(table, k);

    char* string = generateString(table, k, m);
    int n = strlen(string);
    printf("String: %s\n", string);

    int comparisons;
    int* solution = solve(string, table, n, m, k, &comparisons);
    printf("Solution: ");
    printSolution(string, solution, n);
    printf("\n");
    printf("Comparisons: %d\n", comparisons);

    for(int i=0; i<k; i++){ free(table[i]); }
    free(table);
    free(solution);
    free(string);
}


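The analysis was done by running the algorithm more than 500,000 times on randomly generated values of n, m, and k. The "comparisons" counter is incremented at every if statement, as shown in the code, and the average number of comparisons was plotted for each value of n, m, and k.

There is a clear linear relationship between the (counted) number of comparisons and the sizes of n, m, and k. This suggests the complexity is O(nmk).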
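For anyone who wants to repeat this kind of measurement without compiling the C program, here is a minimal Python sketch. It is not the harness used for the analysis above: count_char_comparisons is a hypothetical helper, it counts only character comparisons (the C code also counts the other if statements), and it uses the problem statement's naming (m table strings of length at most k).

```python
def count_char_comparisons(data, table):
    """Scan positions left to right; return (best length, character comparisons)."""
    n = len(data)
    best = [None] * (n + 1)  # best[p]: fewest table strings covering data[:p]
    best[0] = 0
    comparisons = 0
    for pos in range(n):
        if best[pos] is None:
            continue  # position not reachable, nothing to extend
        for word in table:
            end = pos + len(word)
            if end > n:
                continue
            matched = True
            for offset, ch in enumerate(word):
                comparisons += 1
                if data[pos + offset] != ch:
                    matched = False
                    break
            if matched and (best[end] is None or best[pos] + 1 < best[end]):
                best[end] = best[pos] + 1
    return best[n], comparisons

data = "bababbaababa"
table = ["a", "ba", "abab", "b"]
length, comparisons = count_char_comparisons(data, table)
n, m, k = len(data), len(table), max(len(w) for w in table)
print(length, comparisons, n * m * k)
```

Each of the n positions tries m strings with at most k character checks each, so the counter can never exceed n*m*k.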

An O(nmk) solution

from collections import defaultdict
from typing import Set

def smallest_cut(sentence: str, words: Set[str]) -> int:
    lookup = [1e10 for i in range(len(sentence)+1)]
    step_results = defaultdict(set)
    step_results[0].add(sentence)
    curr_step = 1
    while step_results[curr_step-1]: # this loop and the next together run O(n) times in total, thanks to 'lookup'
        for curr_sentence in step_results[curr_step-1]: # see the comment above
            for word in words: # O(m)
                if curr_sentence.startswith(word): # O(k)
                    new_sentence = curr_sentence[len(word):]  # remove prefix
                    if new_sentence:
                        if lookup[len(new_sentence)] > curr_step:
                            lookup[len(new_sentence)] = curr_step  # the lookup array gives an O(1) check instead of O(log(n))
                            step_results[curr_step].add(new_sentence)
                    else: # we found the shortest encoding
                        return curr_step
        curr_step += 1

smallest_cut('bababbaababa', {'ba', 'a', 'abab', 'b'})

We can work incrementally: for each prefix of the sentence, construct a solution with the fewest cuts by trying to append one word to a shorter prefix.

words = {'ba', 'a', 'abab', 'b'}
sentence = 'bababbaababa'

def smallest_cut(sentence, words):
    max_word_length = max(len(w) for w in words)
    words_by_length = [[] for i in range(max_word_length + 1)]
    for w in words:
        words_by_length[len(w)].append(w)

    results_by_length = dict()
    results_by_length[0] = []

    for i in range(len(sentence) + 1):
        for j in range(max(0, i-max_word_length), i):
            if j in results_by_length:
                for w in words_by_length[i-j]:
                    if sentence[j:i] == w:
                        if i not in results_by_length or len(results_by_length[j]) + 1 < len(results_by_length[i]):
                            results_by_length[i] = results_by_length[j] + [w]

    return results_by_length[len(sentence)]

print(smallest_cut(sentence, words))

The first for loop runs n iterations, the second and third combined run m iterations, and the comparison sentence[j:i] == w takes k operations.
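One caveat with the version above is that it copies a whole list of words for every improvement, which adds work beyond the n*m*k comparisons. A possible variant, sketched here under my own naming (best_encoding, count, and last are not from the original), stores only a count and the last word per prefix and rebuilds the encoding at the end:

```python
import math

def best_encoding(sentence, words):
    """Return one shortest list of words that concatenates to sentence, or None."""
    n = len(sentence)
    count = [math.inf] * (n + 1)  # count[i]: fewest words covering sentence[:i]
    last = [None] * (n + 1)       # last[i]: final word of one optimal cover of sentence[:i]
    count[0] = 0
    for i in range(1, n + 1):
        for w in words:
            j = i - len(w)
            # extend the best cover of sentence[:j] with w if w occupies sentence[j:i]
            if j >= 0 and count[j] + 1 < count[i] and sentence.startswith(w, j):
                count[i] = count[j] + 1
                last[i] = w
    if math.isinf(count[n]):
        return None
    pieces = []
    i = n
    while i > 0:  # walk backwards through the stored last words
        pieces.append(last[i])
        i -= len(last[i])
    return pieces[::-1]

result = best_encoding('bababbaababa', ['a', 'ba', 'abab', 'b'])
print(len(result), result)
```

If only the length is needed, as the problem asks, count[n] alone is enough and the last array can be dropped.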

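A memoized solution in Python. solve(i) computes the result for the first i letters, so the answer is solve(n). For each prefix length i, try every word as a possible suffix of that prefix.

This is O(nmk) because there are n possible values of the argument i, each checks m words, and each word check costs at most k letter comparisons.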


from functools import lru_cache

words = 'a', 'ba', 'abab', 'b'
sentence = 'bababbaababa'

@lru_cache
def solve(i):
    return i and min((solve(i - len(word)) + 1
                      for word in words
                      if sentence.endswith(word, None, i)),
                     default=float('inf'))

print(solve(len(sentence)))