Python functional-style iterative algorithm?

In Haskell there is a simple list function available:

iterate :: (a -> a) -> a -> [a]
iterate f x =  x : iterate f (f x)

In Python it could be implemented as follows:

def iterate(f, init):
  while True:
    yield init
    init = f(init)
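Because this generator never ends, it has to be consumed lazily, for example by taking a finite prefix with itertools.islice. A quick usage sketch (the doubling function is just an illustrative choice):

```python
from itertools import islice

def iterate(f, init):
    # Infinite generator: init, f(init), f(f(init)), ...
    while True:
        yield init
        init = f(init)

# Take only the first 6 values; list(iterate(...)) on its own would never finish.
powers_of_two = list(islice(iterate(lambda x: 2 * x, 1), 6))
print(powers_of_two)  # [1, 2, 4, 8, 16, 32]
```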

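I'm a bit surprised there isn't something this basic in the functools/itertools modules. Can it be constructed simply in functional style (i.e. without the loop) using the tools those libraries provide? (Mostly code golf, trying to learn about functional style in Python.)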

You can do it with a couple of functions from itertools:

from itertools import accumulate, repeat

def iterate(func, initial):
    return accumulate(repeat(None), func=lambda tot, _: func(tot), initial=initial)
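The one-liner above relies on accumulate(..., initial=...) (available since Python 3.8) emitting the initial value first and then calling func(running_total, item) for every item. The items coming from repeat(None) are simply ignored, so each step is just func applied to the previous result. A small check of that behaviour:

```python
from itertools import accumulate, islice, repeat

def iterate(func, initial):
    # accumulate emits `initial` first, then func(total, item) for each item;
    # every item is None and is ignored, so each step is func(previous).
    return accumulate(repeat(None), func=lambda tot, _: func(tot), initial=initial)

first_five = list(islice(iterate(lambda x: x + 3, 0), 5))
print(first_five)  # [0, 3, 6, 9, 12]
```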

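It's clearly not very clean, though. Itertools is missing some fundamental functions for constructing streams, like unfoldr. Most of the itertools functions could be defined in terms of unfoldr, but functional programming is a little uncomfortable in Python anyway, so that might not be much of a benefit.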
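For comparison, here is a minimal sketch of what a Haskell-style unfoldr might look like in Python. It is not part of itertools: the name unfoldr and the convention of returning None to stop (standing in for Haskell's Nothing) are assumptions made for this illustration, and the walrus operator requires Python 3.8+. Both an infinite stream (iterate) and a finite one (the hypothetical countdown) can then be written on top of it:

```python
from itertools import islice

def unfoldr(step, seed):
    # Hypothetical helper modelled on Haskell's unfoldr:
    # step(seed) returns (value, next_seed) to continue, or None to stop.
    while (result := step(seed)) is not None:
        value, seed = result
        yield value

def iterate(f, x):
    # Never stops: emit the current seed, continue from f(seed).
    return unfoldr(lambda s: (s, f(s)), x)

def countdown(n):
    # A finite stream: stop once the seed reaches 0.
    return unfoldr(lambda s: (s, s - 1) if s > 0 else None, n)

print(list(islice(iterate(lambda x: 3 * x, 1), 5)))  # [1, 3, 9, 27, 81]
print(list(countdown(3)))  # [3, 2, 1]
```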

There is a third-party "extension" to the itertools module, more-itertools, that includes (among other things) an iterate function defined exactly as you observed:

# Exact definition, minus the doc string...
def iterate(func, start):
    while True:
        yield start
        start = func(start)

Python lacks the optimizations necessary for a recursive definition like

from itertools import chain

def iterate(func, start):
    yield from chain([start], iterate(func, func(start)))

to be practical.


If you're curious, Coconut is a superset of Python that does do things like tail-call optimization. Try the following code at https://cs121-team-panda.github.io/coconut-interpreter/:

@recursive_iterator
def iterate(f, s) = (s,) :: iterate(f, f(s))

for x in iterate(x -> x + 1, 0)$[1000:1010]:
    print(x)

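(I'm not entirely sure the recursive_iterator decorator is necessary. Iterating over the slice shows, I think, that this avoids the recursion-depth error that similar code would produce in Python.)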

You can get the desired output using the walrus operator* in a generator expression.

from itertools import chain, repeat

def f(x):
    return x + 10

x = 0
it = chain([x], (x:=f(x) for _ in repeat(None)))

>>> next(it)
0
>>> next(it)
10
>>> next(it)
20

* The walrus operator is available in Python 3.8 and above.
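The module-level version above keeps its state in the global x. A minimal sketch of the same idea wrapped in a function (iterate_walrus is just a name chosen for this example): an assignment expression inside a generator expression binds the name in the enclosing function's scope, so each call gets its own independent x.

```python
from itertools import chain, islice, repeat

def iterate_walrus(f, init):
    # `x := init` creates a local x; the `x := f(x)` inside the generator
    # expression rebinds that same enclosing-function variable on each step.
    return chain([x := init], (x := f(x) for _ in repeat(None)))

a = iterate_walrus(lambda v: v + 10, 0)
b = iterate_walrus(lambda v: v * 2, 1)
print(list(islice(a, 4)))  # [0, 10, 20, 30]
print(list(islice(b, 4)))  # [1, 2, 4, 8]
```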

Yes, we can construct it "loop-less" with map and itertools, and it's faster than the others:

from itertools import tee, chain, islice

def iterate(f, init):
    parts = [[init]]
    values1, values2 = tee(chain.from_iterable(parts))
    parts.append(map(f, values2))
    return values1

def f(x):
    return 3 * x

print(*islice(iterate(f, 1), 10))

Output (Try it online!):

1 3 9 27 81 243 729 2187 6561 19683

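The first problem is that we need the values both for the outside user and to feed back into the computation of further values. We can use tee to duplicate the values.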
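As a small illustration of what tee provides: the two returned iterators advance independently over the same underlying values, with tee buffering whatever one of them has seen and the other has not yet consumed.

```python
from itertools import count, islice, tee

# count() is infinite; tee splits it into two independent iterators.
left, right = tee(count())

first_from_left = list(islice(left, 3))    # left advances by 3 items
first_from_right = list(islice(right, 2))  # right still starts from the beginning
print(first_from_left, first_from_right)   # [0, 1, 2] [0, 1]
```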

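Next we have a chicken-and-egg problem: we want to use map(f, values2) to get the function values, where values2 comes from the very map iterator we're about to create! Fortunately, chain.from_iterable accepts an iterable that we can extend after creating the chain.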
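The late-extension trick can be seen in isolation: chain.from_iterable walks the outer list lazily through a list iterator, so a sub-iterable appended after the chain was created, but before the chain reaches the end of the list, is still picked up.

```python
from itertools import chain

parts = [[1, 2]]
combined = chain.from_iterable(parts)  # nothing has been read from parts yet

parts.append([3, 4])  # appended after the chain object already exists
print(list(combined))  # [1, 2, 3, 4]
```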

Alternatively, we can make parts a generator, since it only accesses values2 after values2 has been created:

def iterate(f, init):
    def parts():
        yield init,
        yield map(f, values2)
    values1, values2 = tee(chain.from_iterable(parts()))
    return values1

Benchmark computing 100,000 values with f = abs and init = 0 (a cheap f, to minimize dilution of the speed differences between the solutions):

CPython 3.8.0b4 on tio.run:

    mean  stdev  (from best 5 of 20 attempts)
 2.75 ms  0.03 ms  with_itertools1  (my first solution)
 2.76 ms  0.02 ms  with_itertools2  (my second solution)
 5.29 ms  0.02 ms  with_generator   (the question's solution)
 5.73 ms  0.04 ms  with_walrus
 9.00 ms  0.09 ms  with_accumulate

CPython 3.10.4 on my Windows laptop:

    mean  stdev  (from best 5 of 20 attempts)
 8.37 ms  0.02 ms  with_itertools2
 8.37 ms  0.00 ms  with_itertools1
17.86 ms  0.00 ms  with_generator
20.73 ms  0.01 ms  with_walrus
26.03 ms  0.24 ms  with_accumulate

CPython 3.10.4 on a Debian Google Compute Engine instance:

    mean  stdev  (from best 5 of 20 attempts)
 2.25 ms  0.00 ms  with_itertools1
 2.26 ms  0.00 ms  with_itertools2
 3.91 ms  0.00 ms  with_generator
 4.43 ms  0.00 ms  with_walrus
 7.14 ms  0.01 ms  with_accumulate

Benchmark code (Try it online!):

from itertools import accumulate, tee, chain, islice, repeat
import timeit
from bisect import insort
from random import shuffle
from statistics import mean, stdev
import sys

def with_accumulate(f, init):
    return accumulate(repeat(None), func=lambda tot, _: f(tot), initial=init)

def with_generator(f, init):
    while True:
        yield init
        init = f(init)

def with_walrus(f, init):
    return chain([x:=init], (x:=f(x) for _ in repeat(None)))

def with_itertools1(f, init):
    parts = [[init]]
    values1, values2 = tee(chain.from_iterable(parts))
    parts.append(map(f, values2))
    return values1

def with_itertools2(f, init):
    def parts():
        yield init,
        yield map(f, values2)
    values1, values2 = tee(chain.from_iterable(parts()))
    return values1

solutions = [
    with_accumulate,
    with_generator,
    with_walrus,
    with_itertools1,
    with_itertools2,
]

for solution in solutions:
    iterator = solution(lambda x: 3 * x, 1)
    print(*islice(iterator, 10), solution.__name__)

def consume(iterator, n):
    next(islice(iterator, n, n), None)

attempts, best = 20, 5
times = {solution: [] for solution in solutions}
for _ in range(attempts):
    shuffle(solutions)
    for solution in solutions:
        time = min(timeit.repeat(lambda: consume(solution(abs, 0), 10**5), number=1))
        insort(times[solution], time)
print(f'    mean  stdev  (from best {best} of {attempts} attempts)')
for solution in sorted(solutions, key=times.get):
    ts = times[solution][:best]
    print('%5.2f ms ' * 2 % (mean(ts) * 1e3, stdev(ts) * 1e3), solution.__name__)

print()
print(sys.implementation.name, sys.version)