完美正方形 leetcode - 带记忆的递归解决方案

perfect squares leetcode - recursive solution with memoization

试图通过递归和记忆解决 问题,但对于输入 7168,我得到了错误的答案。

    public int numSquares(int n) {
        Map<Integer, Integer> memo = new HashMap();
        List<Integer> list = fillSquares(n, memo);
        if (list == null)
            return 1;
        return helper(list.size()-1, list, n, memo);
    }
    
    private int helper(int index, List<Integer> list, int left, Map<Integer, Integer> memo) {
        
        if (left == 0)
            return 0;
        if (left < 0 || index < 0)
            return Integer.MAX_VALUE-1;
        
        if (memo.containsKey(left)) {
            return memo.get(left);
        }
        
        int d1 = 1+helper(index, list, left-list.get(index), memo);
        int d2 = 1+helper(index-1, list, left-list.get(index),  memo);
        int d3 = helper(index-1, list, left, memo);
        
        int d = Math.min(Math.min(d1,d2), d3);
        memo.put(left, d);
        return d;
    }
    
    private List<Integer> fillSquares(int n, Map<Integer, Integer> memo) {
        int curr = 1;
        List<Integer> list = new ArrayList();
        int d = (int)Math.pow(curr, 2);
        while (d < n) {
            list.add(d);
            memo.put(d, 1);
            curr++;
            d = (int)Math.pow(curr, 2);
        }
        if (d == n)
            return null;
        return list;
    }

我是这样打电话的:

numSquares(7168)

所有测试用例都通过了(甚至是复杂的用例),但是这个失败了。我怀疑我的记忆有问题,但无法准确指出是什么。任何帮助将不胜感激。

您的记忆以要获得的价值为键,但这并没有考虑 index 的价值,这实际上限制了您可以使用哪些权力来获得该价值。这意味着如果(在极端情况下)index 为 0,则您只能减少剩下的一平方 (1²),这很少是形成该数字的最佳方式。因此,在第一个实例中,memo.set() 将注册 non-optimal 个方块,稍后将由递归树中挂起的其他递归调用更新。

如果您添加一些条件调试代码,您会看到 map.set 被多次调用以获得 left 的相同值,并且具有不同的值。这不好,因为这意味着 if (memo.has(left)) 块将在不能保证该值是最佳的情况下执行(还)。

您可以通过在记忆密钥中加入 index 来解决这个问题。这增加了用于记忆的 space,但它会起作用。我想你可以解决这个问题。

但是根据 Lagrange's four square theorem 每个自然数最多可以写成四个平方和,所以 returned 值永远不应该是 5 或更多。当你通过了那个数量的术语时,你可以缩短递归。这降低了使用记忆的好处。

最后,fillSquares有一个错误:当它是一个完美的正方形时,它也应该加上n,否则你找不到应该return的解决方案1 .

  • 不确定你的错误,这是一个简短的动态规划解决方案:

Java

public class Solution {
    public static final int numSquares(
        final int n
    ) {
        int[] dp = new int[n + 1];
        Arrays.fill(dp, Integer.MAX_VALUE);
        dp[0] = 0;

        for (int i = 1; i <= n; i++) {
            int j = 1;
            int min = Integer.MAX_VALUE;

            while (i - j * j >= 0) {
                min = Math.min(min, dp[i - j * j] + 1);
                ++j;
            }

            dp[i] = min;
        }

        return dp[n];
    }
}

C++

// Most of headers are already included;
// Can be removed;
#include <iostream>
#include <cstdint>
#include <vector>
#include <algorithm>

// The following block might slightly improve the execution time;
// Can be removed;
static const auto __optimize__ = []() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();


#define MAX INT_MAX

using ValueType = std::uint_fast32_t;

struct Solution {
    static const int numSquares(
        const int n
    ) {
        if (n < 1) {
            return 0;
        }

        static std::vector<ValueType> count_perfect_squares{0};

        while (std::size(count_perfect_squares) <= n) {
            const ValueType len = std::size(count_perfect_squares);
            ValueType count_squares = MAX;

            for (ValueType index = 1; index * index <= len; ++index) {
                count_squares = std::min(count_squares, 1 + count_perfect_squares[len - index * index]);
            }

            count_perfect_squares.emplace_back(count_squares);
        }

        return count_perfect_squares[n];
    }
};

int main() {
    std::cout <<  std::to_string(Solution().numSquares(12) == 3) << "\n";

    return 0;
}

Python

  • 这里我们可以简单地使用lru_cache:
class Solution:
    dp = [0]
    @functools.lru_cache
    def numSquares(self, n):
        dp = self.dp
        while len(dp) <= n:
            dp += min(dp[-i * i] for i in range(1, int(len(dp) ** 0.5 + 1))) + 1, 
        return dp[n]

这里是 LeetCode 的官方解法和注释:

Java: DP

class Solution {

  public int numSquares(int n) {
    int dp[] = new int[n + 1];
    Arrays.fill(dp, Integer.MAX_VALUE);
    // bottom case
    dp[0] = 0;

    // pre-calculate the square numbers.
    int max_square_index = (int) Math.sqrt(n) + 1;
    int square_nums[] = new int[max_square_index];
    for (int i = 1; i < max_square_index; ++i) {
      square_nums[i] = i * i;
    }

    for (int i = 1; i <= n; ++i) {
      for (int s = 1; s < max_square_index; ++s) {
        if (i < square_nums[s])
          break;
        dp[i] = Math.min(dp[i], dp[i - square_nums[s]] + 1);
      }
    }
    return dp[n];
  }
}

Java:贪婪

class Solution {
  Set<Integer> square_nums = new HashSet<Integer>();

  protected boolean is_divided_by(int n, int count) {
    if (count == 1) {
      return square_nums.contains(n);
    }

    for (Integer square : square_nums) {
      if (is_divided_by(n - square, count - 1)) {
        return true;
      }
    }
    return false;
  }

  public int numSquares(int n) {
    this.square_nums.clear();

    for (int i = 1; i * i <= n; ++i) {
      this.square_nums.add(i * i);
    }

    int count = 1;
    for (; count <= n; ++count) {
      if (is_divided_by(n, count))
        return count;
    }
    return count;
  }
}

Java:广度优先搜索

class Solution {
  public int numSquares(int n) {

    ArrayList<Integer> square_nums = new ArrayList<Integer>();
    for (int i = 1; i * i <= n; ++i) {
      square_nums.add(i * i);
    }

    Set<Integer> queue = new HashSet<Integer>();
    queue.add(n);

    int level = 0;
    while (queue.size() > 0) {
      level += 1;
      Set<Integer> next_queue = new HashSet<Integer>();

      for (Integer remainder : queue) {
        for (Integer square : square_nums) {
          if (remainder.equals(square)) {
            return level;
          } else if (remainder < square) {
            break;
          } else {
            next_queue.add(remainder - square);
          }
        }
      }
      queue = next_queue;
    }
    return level;
  }
}

Java:使用数学的最​​有效解决方案

  • 运行时间:O(N^0.5)
  • 内存:O(1)
class Solution {

  protected boolean isSquare(int n) {
    int sq = (int) Math.sqrt(n);
    return n == sq * sq;
  }

  public int numSquares(int n) {
    // four-square and three-square theorems.
    while (n % 4 == 0)
      n /= 4;
    if (n % 8 == 7)
      return 4;

    if (this.isSquare(n))
      return 1;
    // enumeration to check if the number can be decomposed into sum of two squares.
    for (int i = 1; i * i <= n; ++i) {
      if (this.isSquare(n - i * i))
        return 2;
    }
    // bottom case of three-square theorem.
    return 3;
  }
}