如何记忆长度为 n 的递归路径搜索

How to memoize recursive path of length n search

第一次发帖时我想试试这个社区。

我已经研究了几个小时,但我似乎找不到足够接近的例子来获取灵感。我不关心答案是什么语言,但更喜欢 java、c/c++ 或伪代码。

我正在寻找网格中长度为 n 的连续路径。

我找到了一个递归解决方案,我认为它很干净并且始终有效,但是如果路径数量太多,运行时间很短。我意识到我可以迭代地实现它,但我想先找到一个递归解决方案。

我不在乎答案是什么语言,但我更喜欢 java、c/c++。

问题是—— 对于 String[] 和 int pathLength,该长度有多少条路径。

{ "ABC", "CBZ", "CZC", "BZZ", "ZAA" } 长度为 3

This is the 3rd and 7th path from below.

A B C    A . C    A B .    A . .    A . .    A . .    . . .
. . .    . B .    C . .    C B .    . B .    . B .    . . .
. . .    . . .    . . .    . . .    C . .    . . C    C . .
. . .    . . .    . . .    . . .    . . .    . . .    B . .
. . .    . . .    . . .    . . .    . . .    . . .    . A .
(spaces are for clarity only) 

return 7 条长度为 3 (A-B-C) 的可能路径

这是最初的递归解决方案

public class SimpleRecursive {

    private int ofLength;
    private int paths = 0;
    private String[] grid;

    public int count(String[] grid, int ofLength) {
        this.grid = grid;
        this.ofLength = ofLength;
        paths = 0;

        long startTime = System.currentTimeMillis();
        for (int j = 0; j < grid.length; j++) {
            for (int index = grid[j].indexOf('A'); index >= 0; index = grid[j].indexOf('A', index + 1)) {

                recursiveFind(1, index, j);

            }

        }
        System.out.println(System.currentTimeMillis() - startTime);
        return paths;
    }

    private void recursiveFind(int layer, int x, int y) {

        if (paths >= 1_000_000_000) {

        }

        else if (layer == ofLength) {

            paths++;

        }

        else {

            int xBound = grid[0].length();
            int yBound = grid.length;

            for (int dx = -1; dx <= 1; ++dx) {
                for (int dy = -1; dy <= 1; ++dy) {
                    if (dx != 0 || dy != 0) {
                        if ((x + dx < xBound && y + dy < yBound) && (x + dx >= 0 && y + dy >= 0)) {
                            if (grid[y].charAt(x) + 1 == grid[y + dy].charAt(x + dx)) {

                                recursiveFind(layer + 1, x + dx, y + dy);

                            }

                        }

                    }
                }
            }

        }
    }
}

这非常慢,因为每个新字母都可以衍生出 8 次递归,因此复杂性飙升。

我决定使用记忆来提高性能。

这是我想到的。

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

public class AlphabetCount {

    private int ofLength;
    private int paths = 0;
    private String[] grid;
//  This was an optimization that helped a little.  It would store possible next paths  
//  private HashMap<Integer, ArrayList<int[]>> memoStack = new HashMap<Integer, ArrayList<int[]>>();
    //hashmap of indices that are part of a complete path(memoization saves)
    private HashMap<Integer, int[]> completedPath = new HashMap<Integer, int[]>();
    //entry point 
    public int count(String[] grid, int ofLength) {
        this.grid = grid;
        //Since i find the starting point ('A') by brute force then i just need the next n-1 letters
        this.ofLength = ofLength - 1;
        //variable to hold number of completed runs
        paths = 0;

        //holds the path that was taken to get to current place.  determined that i dont really need to memoize 'Z' hence ofLength -1 again
        List<int[]> fullPath = new ArrayList<int[]>(ofLength - 1);

        //just a timer to compare optimizations
        long startTime = System.currentTimeMillis();

        //this just loops around finding the next 'A'
        for (int j = 0; j < grid.length; j++) {
            for (int index = grid[j].indexOf('A'); index >= 0; index = grid[j].indexOf('A', index + 1)) {

                //into recursive function.  fullPath needs to be kept in this call so that it maintains state relevant to call stack?  also the 0 here is technically 'B' because we already found 'A'
                recursiveFind(fullPath, 0, index, j);

            }

        }
        System.out.println(System.currentTimeMillis() - startTime);
        return paths;
    }

    private void recursiveFind(List<int[]> fullPath, int layer, int x, int y) {
        //hashing key. mimics strings tohash.  should not have any duplicates to my knowledge
        int key = 31 * (x) + 62 * (y) + 93 * layer;

        //if there is more than 1000000000 paths then just stop counting and tell me its over 1000000000
        if (paths >= 1_000_000_000) {

        //this if statement never returns true unfortunately.. this is the optimization that would actually help me.
        } else if (completedPath.containsKey(key)) {
            paths++;
            for (int i = 0; i < fullPath.size() - 1; i++) {
                int mkey = 31 * fullPath.get(i)[0] + 62 * fullPath.get(i)[1] + 93 * (i);
                if (!completedPath.containsKey(mkey)) {
                    completedPath.put(mkey, fullPath.get(i));
                }
            }

        }
        //if we have a full run then save the path we took into the memoization hashmap and then increase paths 
        else if (layer == ofLength) {

            for (int i = 0; i < fullPath.size() - 1; i++) {
                int mkey = 31 * fullPath.get(i)[0] + 62 * fullPath.get(i)[1] + 93 * (i);
                if (!completedPath.containsKey(mkey)) {
                    completedPath.put(mkey, fullPath.get(i));
                }
            }

            paths++;

        }


//everything with memoStack is an optimization that i used that increased performance marginally.
//      else if (memoStack.containsKey(key)) {
//          for (int[] path : memoStack.get(key)) {
//              recursiveFind(fullPath,layer + 1, path[0], path[1]);
//          }
//      } 

        else {

            int xBound = grid[0].length();
            int yBound = grid.length;

            // ArrayList<int[]> newPaths = new ArrayList<int[]>();
            int[] pair = new int[2];

            //this loop checks indices adjacent in all 8 directions ignoring index you are in then checks to see if you are out of bounds then checks to see if one of those directions has the next character
            for (int dx = -1; dx <= 1; ++dx) {
                for (int dy = -1; dy <= 1; ++dy) {
                    if (dx != 0 || dy != 0) {
                        if ((x + dx < xBound && y + dy < yBound) && (x + dx >= 0 && y + dy >= 0)) {
                            if (grid[y].charAt(x) + 1 == grid[y + dy].charAt(x + dx)) {

                                pair[0] = x + dx;
                                pair[1] = y + dy;
                                // newPaths.add(pair.clone());
                                //not sure about this... i wanted to save space by not allocating everything but i needed fullPath to only have the path up to the current call
                                fullPath.subList(layer, fullPath.size()).clear();
                                //i reuse the int[] pair so it needs to be cloned
                                fullPath.add(pair.clone());
                                //recursive call
                                recursiveFind(fullPath, layer + 1, x + dx, y + dy);

                            }

                        }

                    }
                }
            }
            // memoStack.putIfAbsent(key, newPaths);

            // memo thought! if layer, x and y are the same as a successful runs then you can use a
            // previous run

        }
    }

}

问题是我的记忆从未真正被使用过。递归调用有点模仿深度优先搜索。前

     1
   / | \
  2  5  8
 /\  |\  |\
3 4  6 7 9 10

因此保存 运行 不会与另一个 运行 以任何节省性能的方式重叠,因为它在返回调用堆栈之前在树的底部进行搜索。所以问题是......我如何记住这个?或者一旦我得到一个完整的 运行 我如何递归回到树的开头以便我写的记忆有效。

真正杀死性能的测试字符串是 { "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "ABCDEFGHIJKLMNOPQRSTUVWXYZ" }; 对于所有长度为 26 的路径 (应该return1000000000)

PS。作为第一次发布者,任何关于一般代码改进或不良编码习惯的评论都将不胜感激。另外,因为我之前没有发帖让我知道这个问题是否不清楚或格式不正确或太长等。

我不确定你在记忆什么(也许你可以用文字解释一下?)但是这里似乎有重叠的子问题。如果我理解正确,除了 "A",一个字母的任何特定实例都只能从字母表中相邻的前一个字母到达。这意味着我们可以存储来自字母的每个特定实例的路径数。当在后续场合到达该特定实例时,我们可以避免递归到它。

深度优先搜索:

 d1 d2 d3 d4
   c1   c2
      b
    a1 a2

 .....f(c1) = f(d1) + f(d2) = 2
 .....f(c2) = f(d3) + f(d4) = 2
 ...f(b) = f(c1) + f(c2) = 4
 f(a1) = f(b) = 4
 f(a2) = f(b) = 4

嗯,我想通了!部分感谢@גלעד ברקן 的推荐。我最初只是试图通过说任何两条路径有任何相同的索引然后它是一个完整的路径,所以我们不必进一步递归,这是一个严重的过度简化。我最终写了一个小图形可视化工具,这样我就可以准确地看到我在看什么。 (这是上面的第一个例子({ "ABC"、"CBZ"、"CZC"、"BZZ"、"ZAA" },长度为 3))

L代表layer-每一层对应一个字母即layer 1 == 'A'

由此我确定每个节点都可以保存从它出发的完整路径的数量。在图片中,这意味着节点 L[2]X1Y1 将被赋予数字 4,因为任何时候您到达该节点时都有 4 条完整路径。

无论如何,我记忆到一个 int[][] 中,所以我想从这里做的唯一一件事就是制作一个哈希图,这样就不会浪费太多 space。

这是我想出的代码。

package practiceproblems;

import java.util.ArrayDeque;


public class AlphabetCount {

    private int ofLength;
    private int paths = 0;
    private String[] grid;

    //this is the array that we memoize.  could be hashmap
    private int[][] memoArray;// spec says it initalizes to zero so i am just going with it

    //entry point func
    public int count(String[] grid, int ofLength) {

        //initialize all the member vars
        memoArray = new int[grid[0].length()][grid.length];
        this.grid = grid;
        // this is minus 1 because we already found "A"
        this.ofLength = ofLength - 1;
        paths = 0;
        //saves the previous nodes visited.
        ArrayDeque<int[]> goodPathStack = new ArrayDeque<int[]>();


        long startTime = System.currentTimeMillis();
        for (int j = 0; j < grid.length; j++) {
            for (int index = grid[j].indexOf('A'); index >= 0; index = grid[j].indexOf('A', index + 1)) {
                //kinda wasteful to clone i would think... but easier because it stays in its stack frame
                recursiveFind(goodPathStack.clone(), 0, index, j);

            }

        }
        System.out.println(System.currentTimeMillis() - startTime);
        //if we have more than a bil then just return a bil
        return paths >= 1_000_000_000 ? 1_000_000_000 : paths;
    }

    //recursive func
    private void recursiveFind(ArrayDeque<int[]> fullPath, int layer, int x, int y) {

        //debugging
        System.out.println("Preorder " + layer + " " + (x) + " " + (y));

        //start pushing onto the path stack so that we know where we have been in a given recursion
        int[] pair = { x, y };
        fullPath.push(pair);

        if (paths >= 1_000_000_000) {
            return;

            //we found a full path 'A' thru length
        } else if (layer == this.ofLength) {

            paths++;
            //poll is just pop for deques apparently.
            // all of these paths start at 'A' which we find manually. so pop last.
            // all of these paths incluse the last letter which wouldnt help us because if
            // we find the last letter we already know we are at the end.
            fullPath.pollFirst();
            fullPath.pollLast();

            //this is where we save memoization info
            //each node on fullPath leads to a full path
            for (int[] p : fullPath) {
                memoArray[p[0]][p[1]]++;
            }
            return;

        } else if (memoArray[x][y] > 0) {

            //this is us using our memoization cache
            paths += memoArray[x][y];
            fullPath.pollLast();
            fullPath.pollFirst();
            for (int[] p : fullPath) {
                memoArray[p[0]][p[1]] += memoArray[x][y];
            }

        }

//      else if (memoStack.containsKey(key)) {
//          for (int[] path : memoStack.get(key)) {
//              recursiveFind(fullPath,layer + 1, path[0], path[1]);
//          }
//      } 

        else {

            int xBound = grid[0].length();
            int yBound = grid.length;

            //search in all 8 directions for a letter that comes after the one that you are on.
            for (int dx = -1; dx <= 1; ++dx) {
                for (int dy = -1; dy <= 1; ++dy) {
                    if (dx != 0 || dy != 0) {
                        if ((x + dx < xBound && y + dy < yBound) && (x + dx >= 0 && y + dy >= 0)) {
                            if (grid[y].charAt(x) + 1 == grid[y + dy].charAt(x + dx)) {

                                recursiveFind(fullPath.clone(), layer + 1, x + dx, y + dy);

                            }
                        }

                    }

                }
            }
        }

        // memoStack.putIfAbsent(key, newPaths);

        // memo thought! if one runs layer, x and y are the same then you can use a
        // previous run

    }

}

有效!并且完成 1_000_000_000 路径所需的时间减少了很多。像次秒。

希望这个例子可以帮助那些最终被难住了好几天的人。