Python 动态规划性能差异

Python dynamic programming performance difference

我正在通过做 Leetcode 问题来学习动态编程,即使我正在缓存我的结果,我也经常遇到超出时间限制的错误。谁能解释一下为什么我的版本比官方版本慢这么多 this 问题?

代码上有明显的区别,比如我用了class函数递归,而官方的回答没有。我的递归函数 returns 数值,官方的没有,等等。None 虽然这些看起来是有意义的差异,但性能差异仍然是巨大的。

我的版本。这需要 0.177669 秒到 运行,并收到超出时间限制的错误。

import datetime as dt
from typing import List
from functools import lru_cache


class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        self.nums = nums
        total = sum(self.nums)
        if total % 2 == 1:
            return False
        half_total = total // 2
        return self.traverse(half_total, 0) == 0

    @lru_cache(maxsize=None)
    def traverse(self, subset_sum, index):
        if subset_sum < 0:
            return float('inf')
        elif index == len(self.nums):
            return subset_sum
        else:
            include = self.traverse(subset_sum - self.nums[index], index + 1)
            exclude = self.traverse(subset_sum, index + 1)
            best = min(include, exclude)
            return best


test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
solution = Solution()
start = dt.datetime.now()
print(solution.canPartition(test_case))
end = dt.datetime.now()
print((end-start).total_seconds())

这是官方回答。只需0.000165秒!

import datetime as dt
from typing import List, Tuple
from functools import lru_cache


class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        @lru_cache(maxsize=None)
        def dfs(nums: Tuple[int], n: int, subset_sum: int) -> bool:
            # Base cases
            if subset_sum == 0:
                return True
            if n == 0 or subset_sum < 0:
                return False
            result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
                    or dfs(nums, n - 1, subset_sum))
            return result

        # find sum of array elements
        total_sum = sum(nums)

        # if total_sum is odd, it cannot be partitioned into equal sum subsets
        if total_sum % 2 != 0:
            return False

        subset_sum = total_sum // 2
        n = len(nums)
        return dfs(tuple(nums), n - 1, subset_sum)


test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
solution = Solution()
start = dt.datetime.now()
print(solution.canPartition(test_case))
end = dt.datetime.now()
print((end-start).total_seconds())

如果您想了解性能,您需要剖析您的代码。分析可以让您了解您的代码将时间花在哪里。

CPython 带有名为 cProfile 的内置分析模块。 但是您可能想看看例如line_profiler.

在旧版本中,搜索所有可能的情况。而在后者中,算法在找到可行解时停止。

在第一个版本中:

include = self.traverse(subset_sum - self.nums[index], index + 1)
# Suppose {include} is zero, the answer is already obtained, 
# but the algorithm still try to compute {exclude}, which is not neccessary.
exclude = self.traverse(subset_sum, index + 1)

在第二个版本中:

result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
                    or dfs(nums, n - 1, subset_sum))
# Because of the short-circuit behavior of logical operator,
# if the first branch has already obtained the solution, 
# the second branch will not be executed.

只需添加 if-check 即可提高性能:

include = self.traverse(subset_sum - self.nums[index], index + 1)
# Check whether we are already done:
if include == 0:
    return include
exclude = self.traverse(subset_sum, index + 1)
  • 您的函数 return 是一个在执行期间在 float 和 int 之间变化的数字。 python 必须在整个执行过程中应对它。而您需要 return 对“可以分区吗?”的问题回答“是或否”。和简单的布尔值 Ture/False 就足够了,建议使用。
  • 出于与上述相同的原因,您正在对两个递归获得的结果使用比较函数 min,您必须 运行 两次递归到它们的最深层次。通过使用布尔值,您可以简化此过程,而此其他程序使用该快捷方式。