可能优化或使用并行计算

Optimization possible or use parallel computing

我遇到这个问题,我需要找到等于一个数的幂和的数量。例如:

一个输入 100 2 会产生 3 的输出,因为 100 = 10^2 = 6^2 + 8^2 = 1^2 + 3^2 + 4^2 + 5^2 + 7^2100 3 的输入会产生 1 的输出,因为 100 = 1^3 + 2^3 + 3^3 + 4^3

所以我解决这个问题的函数是:

findNums :: Int -> Int -> Int
findNums a b = length [xs | xs <- (drop 1 .) subsequences [pow x b | x <- [1..c]], foldr (+) (head xs) (tail xs) == a] where c = root a b 0 

root :: Int -> Int -> Int -> Int
root n a i
    | pow i a <= n && pow (i+1) a > n = i
    | otherwise = root n a (i+1)

pow :: Int -> Int -> Int
pow _ 0 = 1
pow n a = n * pow n (a - 1)

我找到了我的数字集中所有可能的值,这些值加起来等于所需的数字。然后我在该列表中找到所有可能的子列表,并查看其中有多少加起来达到所需的数量。这是可行的,但由于它是一个详尽的搜索,因此需要很长时间才能输入 800 2。是否可以优化序列以便仅返回 "plausible" 子序列?还是这种问题最好看一下并行计算?

正如您所建议的,这里有一些优化空间而无需诉诸并行化(请记住,如果您要从一个到四个并行线程,并行化最多可以提供 4 倍的加速)。

subsequences 函数所做的基本上是遍历列表,并且它为每个元素创建两个执行分支:一个包含该元素的分支,一个不包含该元素的分支。即,subsequences [1,2,3] 会:

                           start
                   /-------/   \-------\         (take 1 or not)
             [1,..]                    [..]
            /      \                  /    \     (take 2 or not)
    [1,2,..]        [1,..]       [2,..]    [..]
      /  \           /  \         /  \     /  \  (take 3 or not)
[1,2,3]  [1,2]   [1,3]  [1]   [2,3]  [2] [3]  []

subsequences [1,2,3]的结果是一个包含底部叶​​节点的列表。

现在,在每个中间节点(即[1,2,..]),我们可以检查将值函数(即平方和或立方之和等)应用于已取数字的结果.如果我们已经超过了目标,那么继续该分支就没有意义了。如果我们自己写这个子序列生成逻辑,我们可以这样做:

findNums :: Int -> Int -> Int
findNums a b = findNums' a b 1 0

findNums' :: Int -> Int -> Int -> Int -> Int
findNums' a b c s
  | s + c^b > a  = 0
  | s + c^b == a = 1
  | otherwise    = findNums' a b (c+1) (s + c^b) +
                   findNums' a b (c+1) s

这里c是我们的计数器,s是我们选择的数字的幂和。 findNums'中有三种情况:

在第一种情况下,我们检查包括这个数字是否会使我们超出目标。在这种情况下,该分支不会给出任何有效结果,因此我们终止它并通过 returning 0.

指示它不包含任何解决方案

在第二种情况下,我们检查包括这个数字是否会让我们正确。在那种情况下,我们 return 1,基本上注意到我们已经找到了解决方案。

如果其中 none 为真,我们将进一步递归两个不同的分支:一个是我们将 c^b 添加到我们的总和,另一个是我们不这样做。我们把结果加在一起,也就是说这里的结果会是这个点以下找到有效解的分支数。

让我们浏览一下一些东西。

基准测试

首先:让我们确保我们确实在进行改进!为此,我们需要一些基准。 criterion 包非常适合这个。我们还将确保通过优化进行编译(因此 -O2 所有对 GHC 的调用)。设置基准可以如此简单:

import Criterion.Main

-- your code goes here

main = defaultMain
    [ bench "findNums 100 2" (nf (uncurry findNums) (100, 2))
    , bench "findNums 800 2" (nf (uncurry findNums) (800, 2))
    ]

也可以将基准实现为 nf (findNums 100) 2,但我选择这种方式,这样我们就不能 "cheat" 通过为 100 预先计算查找 table,从而将所有工作推入基准设置而不是基准实际 运行 的部分。这是原始实施的结果:

benchmarking 100 2
time                 762.7 ns   (757.4 ns .. 768.5 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 762.5 ns   (760.4 ns .. 765.3 ns)
std dev              7.706 ns   (6.378 ns .. 10.59 ns)

benchmarking 800 2
time                 29.17 s    (28.28 s .. 29.87 s)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 29.26 s    (29.08 s .. 29.35 s)
std dev              159.2 ms   (0.0 s .. 165.2 ms)
variance introduced by outliers: 19% (moderately inflated)

使用库

现在,唾手可得的成果是使用事物的现有实现,并希望他们的作者比我们做得更好。为此,我们将使用标准函数 (^) 而不是 pow,并使用 arithmoi 包中的 integerRoot 而不是 root。此外,我们将把惰性 foldr 换成严格的 foldl。为了我自己的理智,我还将很长的行重新格式化为较短的行。完整结果现在如下所示:

import Criterion.Main
import Data.List
import Math.NumberTheory.Powers

sum' :: Num a => [a] -> a
sum' = foldl' (+) 0

findNums :: Int -> Int -> Int
findNums a b = length
    [ xs
    | xs <- drop 1 . subsequences $ [x ^ b | x <- [1..c]]
    , sum' xs == a
    ] where c = integerRoot b a

main = defaultMain
    [ bench "100 2" (nf (uncurry findNums) (100, 2))
    , bench "800 2" (nf (uncurry findNums) (800, 2))
    ]

基准测试结果现在如下所示:

benchmarking 100 2
time                 722.8 ns   (721.3 ns .. 724.3 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 722.6 ns   (721.4 ns .. 724.1 ns)
std dev              4.440 ns   (3.738 ns .. 5.674 ns)

benchmarking 800 2
time                 17.16 s    (16.93 s .. 17.64 s)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 17.05 s    (16.99 s .. 17.15 s)
std dev              88.10 ms   (0.0 s .. 94.58 ms)

不费吹灰之力就快了两倍。不错!

更好的算法

subsequences 的一个重要问题是,即使我们计算 sum' [x,y,z] > a,我们仍然会查看所有以 [x,y,z] 开头的较长子序列。鉴于 subsequences' return 类型的结构,我们对此无能为力;所以让我们设计一个能给我们更多结构的实现。我们将构建一棵树,其中从根到任何内部节点的路径都会给我们一个子序列。

import Data.Tree

subsequences :: [a] -> Forest a
subsequences [] = []
subsequences (x:xs) = Node x rest : rest where
    rest = subsequences xs

(只是为了好玩,这会产生指数级大的语义树,但 space 使用率非常低——与原始列表大致相同 space——由于积极的子树共享。)关于这种表示,如果我们中断搜索,我们就会切断大量无趣的结果。这可以通过为列表实现类似 takeWhile 的东西来实现:

takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
    goForest m forest = forest >>= goTree m
    goTree   m (Node m' children) =
        [Node m (goForest (m <> m') children) | predicate m']

让我们试一试。完整代码现在是:

import Criterion.Main
import Data.Foldable
import Data.Monoid
import Data.Tree
import Math.NumberTheory.Powers

subsequencesTree :: [a] -> Forest a
subsequencesTree [] = []
subsequencesTree (x:xs) = Node x rest : rest where
    rest = subsequencesTree xs

takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
    goForest m forest = forest >>= goTree m
    goTree   m (Node m' children) = let m'' = m <> m' in
        [Node m' (goForest m'' children) | predicate m'']

leaves :: Forest a -> [[a]]
leaves [] = [[]]
leaves forest = do
    Node x children <- forest
    xs <- leaves children
    return (x:xs)

findNums :: Int -> Int -> Int
findNums a b = length
    [ xs
    | xs <- leaves
          . takeWhileTree (<= Sum a)
          . subsequencesTree
          $ [Sum (x ^ b) | x <- [1..c]]
    , fold xs == Sum a
    ] where c = integerRoot b a

main = defaultMain
    [ bench "100 2" (nf (uncurry findNums) (100, 2))
    , bench "800 2" (nf (uncurry findNums) (800, 2))
    ]

这看起来工作量很大,但从时间安排来看,它确实得到了回报:

benchmarking 100 2
time                 16.67 μs   (16.53 μs .. 16.77 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 16.60 μs   (16.52 μs .. 16.72 μs)
std dev              325.4 ns   (270.5 ns .. 444.1 ns)
variance introduced by outliers: 17% (moderately inflated)

benchmarking 800 2
time                 22.59 ms   (22.26 ms .. 22.89 ms)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 22.44 ms   (22.34 ms .. 22.57 ms)
std dev              260.3 μs   (191.6 μs .. 332.2 μs)

findNums 800 2 上的加速因子约为 1000。

并行化

我尝试通过在 takeWhileTree 中使用 concatparMap 而不是 (>>=) 来并行化它,以便在中探索树的单独分支平行线。在每种情况下,并行化的开销都远远超过拥有多个线程的好处。幸好我们在一开始就设定了基准!

在这种情况下,编写一个 returns 的函数很有用 实际序列,因为可以编写该函数 根据自身递归。

为了简化事情,我们只考虑平方和。 此外,我们将首先考虑有序序列(与 允许重复值);稍后我们将看看如何修改 只产生无序序列而没有任何重复的算法 数字。

这是我们的第一次尝试。该算法基于这样的思想:

Idea 1:

To obtain a sequence whose sum of squares is n, first pick a value c and a sequence xs whose sum of squares is n-c*c and put the two together.

 -- an integer sqrt function
 isqrt n = floor $ (sqrt (fromIntegral n) :: Double)

 pows2a :: Int -> [ [Int] ]
 pows2a n
   | n < 0     = []
   | n == 0    = [ [] ]
   | otherwise = [ (c:xs) | c <- [start,start-1..1], xs <- pows2a (n-c*c) ]
     where start = isqrt n

这有效,但是 returns 解决方案的排列以及解决方案 重复元素 - 例如pos2a 6 returns [2,1,1][1,2,1][1,1,2][1,1,1,1,1,1].

只有return个无序序列(没有重复)我们 使用这个想法:

Idea 2:

To obtain a sequence whose sum of squares is n, first pick a value c and a sequence xs whose sum of squares is n-c*c and all of whose elements are < c and put the two together.

这只是我们第一个算法的轻微修改:

 pows2b :: Int -> [[Int]]
 pows2b n
   | n < 0     = []
   | n == 0    = [ [] ]
   | otherwise =  [ (c:xs) | c <- [start, start-1..1], xs <- pows2b (n-c*c), all (< c) xs ]
   where
     start = isqrt n

这有效,但是像 pows2b 100 这样的调用需要很长时间才能完成,因为我们多次使用相同的参数调用 pows2b

我们可以通过记忆结果来解决这个问题,这就是 pows2c 所做的:

 powslist = map pows2c [0..]
 pows2c n
   | n == 0    = [ [] ]
   | otherwise = [ (c:xs) | c <- [s,s-1..1], xs <- powslist !! (n-c*c), all (< c) xs ]
   where s = isqrt n

此处使用参数 n-c*c 的递归调用被替换为对列表的查找,这是缓存答案的一种方式。