可能优化或使用并行计算
Optimization possible or use parallel computing
我遇到这个问题,我需要找到等于一个数的幂和的数量。例如:
一个输入
100 2
会产生 3
的输出,因为 100 = 10^2 = 6^2 + 8^2 = 1^2 + 3^2 + 4^2 + 5^2 + 7^2
而 100 3
的输入会产生 1 的输出,因为 100 = 1^3 + 2^3 + 3^3 + 4^3
所以我解决这个问题的函数是:
findNums :: Int -> Int -> Int
findNums a b = length [xs | xs <- (drop 1 .) subsequences [pow x b | x <- [1..c]], foldr (+) (head xs) (tail xs) == a] where c = root a b 0
root :: Int -> Int -> Int -> Int
root n a i
| pow i a <= n && pow (i+1) a > n = i
| otherwise = root n a (i+1)
pow :: Int -> Int -> Int
pow _ 0 = 1
pow n a = n * pow n (a - 1)
我找到了我的数字集中所有可能的值,这些值加起来等于所需的数字。然后我在该列表中找到所有可能的子列表,并查看其中有多少加起来达到所需的数量。这是可行的,但由于它是一个详尽的搜索,因此需要很长时间才能输入 800 2
。是否可以优化序列以便仅返回 "plausible" 子序列?还是这种问题最好看一下并行计算?
正如您所建议的,这里有一些优化空间而无需诉诸并行化(请记住,如果您要从一个到四个并行线程,并行化最多可以提供 4 倍的加速)。
subsequences
函数所做的基本上是遍历列表,并且它为每个元素创建两个执行分支:一个包含该元素的分支,一个不包含该元素的分支。即,subsequences [1,2,3]
会:
start
/-------/ \-------\ (take 1 or not)
[1,..] [..]
/ \ / \ (take 2 or not)
[1,2,..] [1,..] [2,..] [..]
/ \ / \ / \ / \ (take 3 or not)
[1,2,3] [1,2] [1,3] [1] [2,3] [2] [3] []
subsequences [1,2,3]
的结果是一个包含底部叶节点的列表。
现在,在每个中间节点(即[1,2,..]
),我们可以检查将值函数(即平方和或立方之和等)应用于已取数字的结果.如果我们已经超过了目标,那么继续该分支就没有意义了。如果我们自己写这个子序列生成逻辑,我们可以这样做:
findNums :: Int -> Int -> Int
findNums a b = findNums' a b 1 0
findNums' :: Int -> Int -> Int -> Int -> Int
findNums' a b c s
| s + c^b > a = 0
| s + c^b == a = 1
| otherwise = findNums' a b (c+1) (s + c^b) +
findNums' a b (c+1) s
这里c
是我们的计数器,s
是我们选择的数字的幂和。 findNums'
中有三种情况:
在第一种情况下,我们检查包括这个数字是否会使我们超出目标。在这种情况下,该分支不会给出任何有效结果,因此我们终止它并通过 returning 0.
指示它不包含任何解决方案
在第二种情况下,我们检查包括这个数字是否会让我们正确。在那种情况下,我们 return 1,基本上注意到我们已经找到了解决方案。
如果其中 none 为真,我们将进一步递归两个不同的分支:一个是我们将 c^b
添加到我们的总和,另一个是我们不这样做。我们把结果加在一起,也就是说这里的结果会是这个点以下找到有效解的分支数。
让我们浏览一下一些东西。
基准测试
首先:让我们确保我们确实在进行改进!为此,我们需要一些基准。 criterion 包非常适合这个。我们还将确保通过优化进行编译(因此 -O2
所有对 GHC 的调用)。设置基准可以如此简单:
import Criterion.Main
-- your code goes here
main = defaultMain
[ bench "findNums 100 2" (nf (uncurry findNums) (100, 2))
, bench "findNums 800 2" (nf (uncurry findNums) (800, 2))
]
也可以将基准实现为 nf (findNums 100) 2
,但我选择这种方式,这样我们就不能 "cheat" 通过为 100
预先计算查找 table,从而将所有工作推入基准设置而不是基准实际 运行 的部分。这是原始实施的结果:
benchmarking 100 2
time 762.7 ns (757.4 ns .. 768.5 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 762.5 ns (760.4 ns .. 765.3 ns)
std dev 7.706 ns (6.378 ns .. 10.59 ns)
benchmarking 800 2
time 29.17 s (28.28 s .. 29.87 s)
1.000 R² (1.000 R² .. 1.000 R²)
mean 29.26 s (29.08 s .. 29.35 s)
std dev 159.2 ms (0.0 s .. 165.2 ms)
variance introduced by outliers: 19% (moderately inflated)
使用库
现在,唾手可得的成果是使用事物的现有实现,并希望他们的作者比我们做得更好。为此,我们将使用标准函数 (^)
而不是 pow
,并使用 arithmoi 包中的 integerRoot
而不是 root
。此外,我们将把惰性 foldr
换成严格的 foldl
。为了我自己的理智,我还将很长的行重新格式化为较短的行。完整结果现在如下所示:
import Criterion.Main
import Data.List
import Math.NumberTheory.Powers
sum' :: Num a => [a] -> a
sum' = foldl' (+) 0
findNums :: Int -> Int -> Int
findNums a b = length
[ xs
| xs <- drop 1 . subsequences $ [x ^ b | x <- [1..c]]
, sum' xs == a
] where c = integerRoot b a
main = defaultMain
[ bench "100 2" (nf (uncurry findNums) (100, 2))
, bench "800 2" (nf (uncurry findNums) (800, 2))
]
基准测试结果现在如下所示:
benchmarking 100 2
time 722.8 ns (721.3 ns .. 724.3 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 722.6 ns (721.4 ns .. 724.1 ns)
std dev 4.440 ns (3.738 ns .. 5.674 ns)
benchmarking 800 2
time 17.16 s (16.93 s .. 17.64 s)
1.000 R² (1.000 R² .. 1.000 R²)
mean 17.05 s (16.99 s .. 17.15 s)
std dev 88.10 ms (0.0 s .. 94.58 ms)
不费吹灰之力就快了两倍。不错!
更好的算法
subsequences
的一个重要问题是,即使我们计算 sum' [x,y,z] > a
,我们仍然会查看所有以 [x,y,z]
开头的较长子序列。鉴于 subsequences
' return 类型的结构,我们对此无能为力;所以让我们设计一个能给我们更多结构的实现。我们将构建一棵树,其中从根到任何内部节点的路径都会给我们一个子序列。
import Data.Tree
subsequences :: [a] -> Forest a
subsequences [] = []
subsequences (x:xs) = Node x rest : rest where
rest = subsequences xs
(只是为了好玩,这会产生指数级大的语义树,但 space 使用率非常低——与原始列表大致相同 space——由于积极的子树共享。)关于这种表示,如果我们中断搜索,我们就会切断大量无趣的结果。这可以通过为列表实现类似 takeWhile
的东西来实现:
takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
goForest m forest = forest >>= goTree m
goTree m (Node m' children) =
[Node m (goForest (m <> m') children) | predicate m']
让我们试一试。完整代码现在是:
import Criterion.Main
import Data.Foldable
import Data.Monoid
import Data.Tree
import Math.NumberTheory.Powers
subsequencesTree :: [a] -> Forest a
subsequencesTree [] = []
subsequencesTree (x:xs) = Node x rest : rest where
rest = subsequencesTree xs
takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
goForest m forest = forest >>= goTree m
goTree m (Node m' children) = let m'' = m <> m' in
[Node m' (goForest m'' children) | predicate m'']
leaves :: Forest a -> [[a]]
leaves [] = [[]]
leaves forest = do
Node x children <- forest
xs <- leaves children
return (x:xs)
findNums :: Int -> Int -> Int
findNums a b = length
[ xs
| xs <- leaves
. takeWhileTree (<= Sum a)
. subsequencesTree
$ [Sum (x ^ b) | x <- [1..c]]
, fold xs == Sum a
] where c = integerRoot b a
main = defaultMain
[ bench "100 2" (nf (uncurry findNums) (100, 2))
, bench "800 2" (nf (uncurry findNums) (800, 2))
]
这看起来工作量很大,但从时间安排来看,它确实得到了回报:
benchmarking 100 2
time 16.67 μs (16.53 μs .. 16.77 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 16.60 μs (16.52 μs .. 16.72 μs)
std dev 325.4 ns (270.5 ns .. 444.1 ns)
variance introduced by outliers: 17% (moderately inflated)
benchmarking 800 2
time 22.59 ms (22.26 ms .. 22.89 ms)
0.999 R² (0.999 R² .. 1.000 R²)
mean 22.44 ms (22.34 ms .. 22.57 ms)
std dev 260.3 μs (191.6 μs .. 332.2 μs)
findNums 800 2
上的加速因子约为 1000。
并行化
我尝试通过在 takeWhileTree
中使用 concat
和 parMap
而不是 (>>=)
来并行化它,以便在中探索树的单独分支平行线。在每种情况下,并行化的开销都远远超过拥有多个线程的好处。幸好我们在一开始就设定了基准!
在这种情况下,编写一个 returns 的函数很有用
实际序列,因为可以编写该函数
根据自身递归。
为了简化事情,我们只考虑平方和。
此外,我们将首先考虑有序序列(与
允许重复值);稍后我们将看看如何修改
只产生无序序列而没有任何重复的算法
数字。
这是我们的第一次尝试。该算法基于这样的思想:
Idea 1:
To obtain a sequence whose sum of squares is n, first
pick a value c and a sequence xs whose sum of squares
is n-c*c and put the two together.
-- an integer sqrt function
isqrt n = floor $ (sqrt (fromIntegral n) :: Double)
pows2a :: Int -> [ [Int] ]
pows2a n
| n < 0 = []
| n == 0 = [ [] ]
| otherwise = [ (c:xs) | c <- [start,start-1..1], xs <- pows2a (n-c*c) ]
where start = isqrt n
这有效,但是 returns 解决方案的排列以及解决方案
重复元素 - 例如pos2a 6
returns [2,1,1]
、[1,2,1]
、[1,1,2]
和 [1,1,1,1,1,1]
.
只有return个无序序列(没有重复)我们
使用这个想法:
Idea 2:
To obtain a sequence whose sum of squares is n, first
pick a value c and a sequence xs whose sum of squares
is n-c*c and all of whose elements are < c and
put the two together.
这只是我们第一个算法的轻微修改:
pows2b :: Int -> [[Int]]
pows2b n
| n < 0 = []
| n == 0 = [ [] ]
| otherwise = [ (c:xs) | c <- [start, start-1..1], xs <- pows2b (n-c*c), all (< c) xs ]
where
start = isqrt n
这有效,但是像 pows2b 100
这样的调用需要很长时间才能完成,因为我们多次使用相同的参数调用 pows2b
。
我们可以通过记忆结果来解决这个问题,这就是 pows2c
所做的:
powslist = map pows2c [0..]
pows2c n
| n == 0 = [ [] ]
| otherwise = [ (c:xs) | c <- [s,s-1..1], xs <- powslist !! (n-c*c), all (< c) xs ]
where s = isqrt n
此处使用参数 n-c*c
的递归调用被替换为对列表的查找,这是缓存答案的一种方式。
我遇到这个问题,我需要找到等于一个数的幂和的数量。例如:
一个输入
100 2
会产生 3
的输出,因为 100 = 10^2 = 6^2 + 8^2 = 1^2 + 3^2 + 4^2 + 5^2 + 7^2
而 100 3
的输入会产生 1 的输出,因为 100 = 1^3 + 2^3 + 3^3 + 4^3
所以我解决这个问题的函数是:
findNums :: Int -> Int -> Int
findNums a b = length [xs | xs <- (drop 1 .) subsequences [pow x b | x <- [1..c]], foldr (+) (head xs) (tail xs) == a] where c = root a b 0
root :: Int -> Int -> Int -> Int
root n a i
| pow i a <= n && pow (i+1) a > n = i
| otherwise = root n a (i+1)
pow :: Int -> Int -> Int
pow _ 0 = 1
pow n a = n * pow n (a - 1)
我找到了我的数字集中所有可能的值,这些值加起来等于所需的数字。然后我在该列表中找到所有可能的子列表,并查看其中有多少加起来达到所需的数量。这是可行的,但由于它是一个详尽的搜索,因此需要很长时间才能输入 800 2
。是否可以优化序列以便仅返回 "plausible" 子序列?还是这种问题最好看一下并行计算?
正如您所建议的,这里有一些优化空间而无需诉诸并行化(请记住,如果您要从一个到四个并行线程,并行化最多可以提供 4 倍的加速)。
subsequences
函数所做的基本上是遍历列表,并且它为每个元素创建两个执行分支:一个包含该元素的分支,一个不包含该元素的分支。即,subsequences [1,2,3]
会:
start
/-------/ \-------\ (take 1 or not)
[1,..] [..]
/ \ / \ (take 2 or not)
[1,2,..] [1,..] [2,..] [..]
/ \ / \ / \ / \ (take 3 or not)
[1,2,3] [1,2] [1,3] [1] [2,3] [2] [3] []
subsequences [1,2,3]
的结果是一个包含底部叶节点的列表。
现在,在每个中间节点(即[1,2,..]
),我们可以检查将值函数(即平方和或立方之和等)应用于已取数字的结果.如果我们已经超过了目标,那么继续该分支就没有意义了。如果我们自己写这个子序列生成逻辑,我们可以这样做:
findNums :: Int -> Int -> Int
findNums a b = findNums' a b 1 0
findNums' :: Int -> Int -> Int -> Int -> Int
findNums' a b c s
| s + c^b > a = 0
| s + c^b == a = 1
| otherwise = findNums' a b (c+1) (s + c^b) +
findNums' a b (c+1) s
这里c
是我们的计数器,s
是我们选择的数字的幂和。 findNums'
中有三种情况:
在第一种情况下,我们检查包括这个数字是否会使我们超出目标。在这种情况下,该分支不会给出任何有效结果,因此我们终止它并通过 returning 0.
指示它不包含任何解决方案在第二种情况下,我们检查包括这个数字是否会让我们正确。在那种情况下,我们 return 1,基本上注意到我们已经找到了解决方案。
如果其中 none 为真,我们将进一步递归两个不同的分支:一个是我们将 c^b
添加到我们的总和,另一个是我们不这样做。我们把结果加在一起,也就是说这里的结果会是这个点以下找到有效解的分支数。
让我们浏览一下一些东西。
基准测试
首先:让我们确保我们确实在进行改进!为此,我们需要一些基准。 criterion 包非常适合这个。我们还将确保通过优化进行编译(因此 -O2
所有对 GHC 的调用)。设置基准可以如此简单:
import Criterion.Main
-- your code goes here
main = defaultMain
[ bench "findNums 100 2" (nf (uncurry findNums) (100, 2))
, bench "findNums 800 2" (nf (uncurry findNums) (800, 2))
]
也可以将基准实现为 nf (findNums 100) 2
,但我选择这种方式,这样我们就不能 "cheat" 通过为 100
预先计算查找 table,从而将所有工作推入基准设置而不是基准实际 运行 的部分。这是原始实施的结果:
benchmarking 100 2
time 762.7 ns (757.4 ns .. 768.5 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 762.5 ns (760.4 ns .. 765.3 ns)
std dev 7.706 ns (6.378 ns .. 10.59 ns)
benchmarking 800 2
time 29.17 s (28.28 s .. 29.87 s)
1.000 R² (1.000 R² .. 1.000 R²)
mean 29.26 s (29.08 s .. 29.35 s)
std dev 159.2 ms (0.0 s .. 165.2 ms)
variance introduced by outliers: 19% (moderately inflated)
使用库
现在,唾手可得的成果是使用事物的现有实现,并希望他们的作者比我们做得更好。为此,我们将使用标准函数 (^)
而不是 pow
,并使用 arithmoi 包中的 integerRoot
而不是 root
。此外,我们将把惰性 foldr
换成严格的 foldl
。为了我自己的理智,我还将很长的行重新格式化为较短的行。完整结果现在如下所示:
import Criterion.Main
import Data.List
import Math.NumberTheory.Powers
sum' :: Num a => [a] -> a
sum' = foldl' (+) 0
findNums :: Int -> Int -> Int
findNums a b = length
[ xs
| xs <- drop 1 . subsequences $ [x ^ b | x <- [1..c]]
, sum' xs == a
] where c = integerRoot b a
main = defaultMain
[ bench "100 2" (nf (uncurry findNums) (100, 2))
, bench "800 2" (nf (uncurry findNums) (800, 2))
]
基准测试结果现在如下所示:
benchmarking 100 2
time 722.8 ns (721.3 ns .. 724.3 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 722.6 ns (721.4 ns .. 724.1 ns)
std dev 4.440 ns (3.738 ns .. 5.674 ns)
benchmarking 800 2
time 17.16 s (16.93 s .. 17.64 s)
1.000 R² (1.000 R² .. 1.000 R²)
mean 17.05 s (16.99 s .. 17.15 s)
std dev 88.10 ms (0.0 s .. 94.58 ms)
不费吹灰之力就快了两倍。不错!
更好的算法
subsequences
的一个重要问题是,即使我们计算 sum' [x,y,z] > a
,我们仍然会查看所有以 [x,y,z]
开头的较长子序列。鉴于 subsequences
' return 类型的结构,我们对此无能为力;所以让我们设计一个能给我们更多结构的实现。我们将构建一棵树,其中从根到任何内部节点的路径都会给我们一个子序列。
import Data.Tree
subsequences :: [a] -> Forest a
subsequences [] = []
subsequences (x:xs) = Node x rest : rest where
rest = subsequences xs
(只是为了好玩,这会产生指数级大的语义树,但 space 使用率非常低——与原始列表大致相同 space——由于积极的子树共享。)关于这种表示,如果我们中断搜索,我们就会切断大量无趣的结果。这可以通过为列表实现类似 takeWhile
的东西来实现:
takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
goForest m forest = forest >>= goTree m
goTree m (Node m' children) =
[Node m (goForest (m <> m') children) | predicate m']
让我们试一试。完整代码现在是:
import Criterion.Main
import Data.Foldable
import Data.Monoid
import Data.Tree
import Math.NumberTheory.Powers
subsequencesTree :: [a] -> Forest a
subsequencesTree [] = []
subsequencesTree (x:xs) = Node x rest : rest where
rest = subsequencesTree xs
takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
goForest m forest = forest >>= goTree m
goTree m (Node m' children) = let m'' = m <> m' in
[Node m' (goForest m'' children) | predicate m'']
leaves :: Forest a -> [[a]]
leaves [] = [[]]
leaves forest = do
Node x children <- forest
xs <- leaves children
return (x:xs)
findNums :: Int -> Int -> Int
findNums a b = length
[ xs
| xs <- leaves
. takeWhileTree (<= Sum a)
. subsequencesTree
$ [Sum (x ^ b) | x <- [1..c]]
, fold xs == Sum a
] where c = integerRoot b a
main = defaultMain
[ bench "100 2" (nf (uncurry findNums) (100, 2))
, bench "800 2" (nf (uncurry findNums) (800, 2))
]
这看起来工作量很大,但从时间安排来看,它确实得到了回报:
benchmarking 100 2
time 16.67 μs (16.53 μs .. 16.77 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 16.60 μs (16.52 μs .. 16.72 μs)
std dev 325.4 ns (270.5 ns .. 444.1 ns)
variance introduced by outliers: 17% (moderately inflated)
benchmarking 800 2
time 22.59 ms (22.26 ms .. 22.89 ms)
0.999 R² (0.999 R² .. 1.000 R²)
mean 22.44 ms (22.34 ms .. 22.57 ms)
std dev 260.3 μs (191.6 μs .. 332.2 μs)
findNums 800 2
上的加速因子约为 1000。
并行化
我尝试通过在 takeWhileTree
中使用 concat
和 parMap
而不是 (>>=)
来并行化它,以便在中探索树的单独分支平行线。在每种情况下,并行化的开销都远远超过拥有多个线程的好处。幸好我们在一开始就设定了基准!
在这种情况下,编写一个 returns 的函数很有用 实际序列,因为可以编写该函数 根据自身递归。
为了简化事情,我们只考虑平方和。 此外,我们将首先考虑有序序列(与 允许重复值);稍后我们将看看如何修改 只产生无序序列而没有任何重复的算法 数字。
这是我们的第一次尝试。该算法基于这样的思想:
Idea 1:
To obtain a sequence whose sum of squares is n, first pick a value c and a sequence xs whose sum of squares is n-c*c and put the two together.
-- an integer sqrt function
isqrt n = floor $ (sqrt (fromIntegral n) :: Double)
pows2a :: Int -> [ [Int] ]
pows2a n
| n < 0 = []
| n == 0 = [ [] ]
| otherwise = [ (c:xs) | c <- [start,start-1..1], xs <- pows2a (n-c*c) ]
where start = isqrt n
这有效,但是 returns 解决方案的排列以及解决方案
重复元素 - 例如pos2a 6
returns [2,1,1]
、[1,2,1]
、[1,1,2]
和 [1,1,1,1,1,1]
.
只有return个无序序列(没有重复)我们 使用这个想法:
Idea 2:
To obtain a sequence whose sum of squares is n, first pick a value c and a sequence xs whose sum of squares is n-c*c and all of whose elements are < c and put the two together.
这只是我们第一个算法的轻微修改:
pows2b :: Int -> [[Int]]
pows2b n
| n < 0 = []
| n == 0 = [ [] ]
| otherwise = [ (c:xs) | c <- [start, start-1..1], xs <- pows2b (n-c*c), all (< c) xs ]
where
start = isqrt n
这有效,但是像 pows2b 100
这样的调用需要很长时间才能完成,因为我们多次使用相同的参数调用 pows2b
。
我们可以通过记忆结果来解决这个问题,这就是 pows2c
所做的:
powslist = map pows2c [0..]
pows2c n
| n == 0 = [ [] ]
| otherwise = [ (c:xs) | c <- [s,s-1..1], xs <- powslist !! (n-c*c), all (< c) xs ]
where s = isqrt n
此处使用参数 n-c*c
的递归调用被替换为对列表的查找,这是缓存答案的一种方式。