在 python 中使埃拉托色尼筛法的内存效率更高?

Making Sieve of Eratosthenes more memory efficient in python?

埃拉托色尼筛法内存限制问题

我目前正在尝试针对 Kattis 问题实施埃拉托色尼筛法的一个版本,但是,我 运行 遇到一些我的实施无法通过的内存限制。

这里是 link 问题 statement。简而言之,这个问题要我首先 return 小于或等于 n 的素数,然后解决一定数量的查询,如果一个数字 i 是否为素数。有 50 MB 内存使用限制以及仅使用 python 的标准库(无 numpy 等)。内存限制是我卡住的地方。

到目前为止,这是我的代码:

import sys

def sieve_of_eratosthenes(xs, n):
    count = len(xs) + 1
    p = 3 # start at three
    index = 0
    while p*p < n:
        for i in range(index + p, len(xs), p):
            if xs[i]:
                xs[i] = 0
                count -= 1

        temp_index = index
        for i in range(index + 1, len(xs)):
            if xs[i]:
                p = xs[i]
                temp_index += 1
                break
            temp_index += 1
        index = temp_index

    return count


def isPrime(xs, a):
    if a == 1:
        return False
    if a == 2:
        return True
    if not (a & 1):
        return False
    return bool(xs[(a >> 1) - 1])

def main():
    n, q = map(int, sys.stdin.readline().split(' '))
    odds = [num for num in range(2, n+1) if (num & 1)]
    print(sieve_of_eratosthenes(odds, n))

    for _ in range(q):
        query = int(input())
        if isPrime(odds, query):
            print('1')
        else:
            print('0')


if __name__ == "__main__":
    main()

到目前为止,我已经做了一些改进,比如只保留所有奇数的列表,这样可以减少一半的内存使用量。我也确信代码在计算素数时会按预期工作(没有得到错误的答案)。我现在的问题是,我怎样才能使我的代码更有效地存储内存?我应该使用其他一些数据结构吗?用布尔值替换我的整数列表?位数组?

非常感谢任何建议!

编辑

在对 python 中的代码进行一些调整后,我遇到了一个问题,我的分段筛的实现无法满足内存要求。

相反,我选择实施 Java 中的解决方案,这几乎不费吹灰之力。这是代码:

  public int sieveOfEratosthenes(int n){
    sieve = new BitSet((n+1) / 2);
    int count = (n + 1) / 2;

    for (int i=3; i*i <= n; i += 2){
      if (isComposite(i)) {
        continue;
      }

      // Increment by two, skipping all even numbers
      for (int c = i * i; c <= n; c += 2 * i){
        if(!isComposite(c)){
          setComposite(c);
          count--;
        }
      }
    }

    return count;

  }

  public boolean isComposite(int k) {
    return sieve.get((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public void setComposite(int k) {
    sieve.set((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public boolean isPrime(int a) {
    if (a < 3)
      return a > 1;

    if (a == 2)
      return true;

    if ((a & 1) == 1)
      return !isComposite(a);
    else
      return false;

  }

  public void run() throws Exception{
    BufferedReader scan = new BufferedReader(new InputStreamReader(System.in));
    String[] line = scan.readLine().split(" ");

    int n = Integer.parseInt(line[0]); int q = Integer.parseInt(line[1]);
    System.out.println(sieveOfEratosthenes(n));

    for (int i=0; i < q; i++){
      line = scan.readLine().split(" ");
      System.out.println( isPrime(Integer.parseInt(line[0])) ? '1' : '0');
    }
  }

我个人还没有找到在 Python 中实现此 BitSet 解决方案的方法(仅使用标准库)。

如果有人偶然发现 python 中问题的巧妙实现,使用分段筛、位数组或其他东西,我很想看看解决方案。

我认为你可以尝试使用布尔值列表来标记其索引是否为质数:

def sieve_of_erato(range_max):
    primes_count = range_max
    is_prime = [True for i in range(range_max + 1)]
    # Cross out all even numbers first.
    for i in range(4, range_max, 2):
        is_prime[i] = False
        primes_count -=1
    i = 3
    while i * i <= range_max:
        if is_prime[i]:
            # Update all multiples of this prime number
            # CAREFUL: Take note of the range args.
            # Reason for i += 2*i instead of i += i:
            # Since p and p*p, both are odd, (p*p + p) will be even,
            # which means that it would have already been marked before
            for multiple in range(i * i, range_max + 1, i * 2):
                is_prime[multiple] = False
                primes_count -= 1
        i += 1

    return primes_count


def main():
    num_primes = sieve_of_erato(100)
    print(num_primes)


if __name__ == "__main__":
    main()

稍后您可以使用 is_prime 数组来检查一个数是否为质数,只需检查 is_prime[number] == True

如果这不起作用,请尝试 segmented sieve

作为奖励,您可能会惊讶地发现有一种方法可以在 O(n) 而不是 O(nloglogn) 中生成筛子。检查代码 here.

这确实是一个非常具有挑战性的问题。最大可能 N 为 10^8,假设没有任何开销,每个值使用一个字节将导致几乎 100 MB 的数据。即使通过仅存储奇数来将数据减半,在考虑开销后也会使您非常接近 50 MB。

这意味着解决方案必须使用以下几种策略中的一种或多种:

  1. 为我们的素数标志数组使用更高效的数据类型。 Python 列表维护一个指向每个列表项的指针数组(64 位 python 每个 4 个字节)。我们实际上需要原始二进制存储,在标准 python.
  2. 中几乎只剩下 bytearray
  3. 筛选中的每个值仅使用一位而不是整个字节(Bool 在技术上只需要一位,但通常使用完整字节)。
  4. 细分以删除偶数,可能 还有 3、5、7 等的倍数
  5. 使用 segmented sieve

我最初试图通过在筛子中每个值只存储 1 位来解决问题,虽然内存使用确实在要求之内,但 Python 的慢位操作也将执行时间推得很远长。找出复杂的索引以确保可靠地计算正确的位也相当困难。

然后我使用 bytearray 实现了奇数解决方案,虽然速度快了很多,但内存仍然是个问题。

字节数组奇数实现:

class Sieve:
    def __init__(self, n):
        self.not_prime = bytearray(n+1)
        self.not_prime[0] = self.not_prime[1] = 1
        for i in range(2, int(n**.5)+1):
            if self.not_prime[i] == 0:
                self.not_prime[i*i::i] = [1]*len(self.not_prime[i*i::i])
        self.n_prime = n + 1 - sum(self.not_prime)
        
    def is_prime(self, n):
        return int(not self.not_prime[n])
        


def main():
    n, q = map(int, input().split())
    s = Sieve(n)
    print(s.n_prime)
    for _ in range(q):
        i = int(input())
        print(s.is_prime(i))

if __name__ == "__main__":
    main()

进一步减少内存应该*使其工作。

编辑: 删除 2 和 3 的倍数似乎也不足以减少内存,尽管 guppy.hpy().heap() 似乎表明我的使用量实际上略低于 50MB。 ‍♂️

我学到了一个技巧 just yesterday - 如果将数字分成 6 组,则 6 个中只有 2 个可能是素数。其他的可以除以2或3。这意味着只需要2位来跟踪6个数字的素数;一个包含 8 位的字节可以跟踪 24 个数字的素数!这大大降低了筛子的内存需求。

在 Python 3.7.5 64 位 Windows 10 中,以下代码没有超过 36.4 MB。

remainder_bit = [0, 0x01, 0, 0, 0, 0x02,
                 0, 0x04, 0, 0, 0, 0x08,
                 0, 0x10, 0, 0, 0, 0x20,
                 0, 0x40, 0, 0, 0, 0x80]

def is_prime(xs, a):
    if a <= 3:
        return a > 1
    index, rem = divmod(a, 24)
    bit = remainder_bit[rem]
    if not bit:
        return False
    return not (xs[index] & bit)

def sieve_of_eratosthenes(xs, n):
    count = (n // 3) + 1 # subtract out 1 and 4, add 2 3 and 5
    p = 5
    while p*p <= n:
        if is_prime(xs, p):
            for i in range(5 * p, n + 1, p):
                index, rem = divmod(i, 24)
                bit = remainder_bit[rem]
                if bit and not (xs[index] & bit):
                    xs[index] |= bit
                    count -= 1
        p += 2
        if is_prime(xs, p):
            for i in range(5 * p, n + 1, p):
                index, rem = divmod(i, 24)
                bit = remainder_bit[rem]
                if bit and not (xs[index] & bit):
                    xs[index] |= bit
                    count -= 1
        p += 4

    return count


def init_sieve(n):
    return bytearray((n + 23) // 24)

n = 100000000
xs = init_sieve(n)
sieve_of_eratosthenes(xs, n)
5761455
sum(is_prime(xs, i) for i in range(n+1))
5761455

编辑:理解其工作原理的关键是筛子会产生重复图案。对于素数 2 和 3,模式每 2*3 或 6 个数字重复一次,并且在这 6 个中,4 个不可能成为素数,只剩下 2 个。在选择素数来生成模式方面没有任何限制,除了也许是递减规律returns。我决定尝试将 5 添加到组合中,使模式每 2*3*5=30 个数字重复一次。在这 30 个数字中,只有 8 个可以是质数,这意味着每个字节可以跟踪 30 个数字,而不是上面的 24 个!这使您在内存使用方面有 20% 的优势。

这是更新后的代码。我也稍微简化了一下,并在进行时去掉了素数计数。

remainder_bit30 = [0,    0x01, 0,    0,    0,    0,    0, 0x02, 0,    0,
                   0,    0x04, 0,    0x08, 0,    0,    0, 0x10, 0,    0x20,
                   0,    0,    0,    0x40, 0,    0,    0, 0,    0,    0x80]

def is_prime(xs, a):
    if a <= 5:
        return (a > 1) and (a != 4)
    index, rem = divmod(a, 30)
    bit = remainder_bit30[rem]
    return (bit != 0) and not (xs[index] & bit)

def sieve_of_eratosthenes(xs):
    n = 30 * len(xs) - 1
    p = 0
    while p*p < n:
        for offset in (1, 7, 11, 13, 17, 19, 23, 29):
            p += offset
            if is_prime(xs, p):
                for i in range(p * p, n + 1, p):
                    index, rem = divmod(i, 30)
                    if index < len(xs):
                        bit = remainder_bit30[rem]
                        xs[index] |= bit
            p -= offset
        p += 30

def init_sieve(n):
    b = bytearray((n + 30) // 30)
    return b

这里是一个分段筛选方法的例子,内存不应超过 8MB。

def primeSieve(n,X,window=10**6): 
    primes     = []       # only store minimum number of primes to shift windows
    primeCount = 0        # count primes beyond the ones stored
    flags      = list(X)  # numbers will be replaced by 0 or 1 as we progress
    base       = 1        # number corresponding to 1st element of sieve
    isPrime    = [False]+[True]*(window-1) # starting sieve
    
    def flagPrimes(): # flag x values for current sieve window
        flags[:] = [isPrime[x-base]*1 if x in range(base,base+window) else x
                    for x in flags]
    for p in (2,*range(3,n+1,2)):       # potential primes: 2 and odd numbers
        if p >= base+window:            # shift sieve window as needed
            flagPrimes()                # set X flags before shifting window
            isPrime = [True]*window     # initialize next sieve window
            base    = p                 # 1st number in window
            for k in primes:            # update sieve using known primes 
                if k>base+window:break
                i = (k-base%k)%k + k*(k==p)  
                isPrime[i::k] = (False for _ in range(i,window,k))
        if not isPrime[p-base]: continue
        primeCount += 1                 # count primes 
        if p*p<=n:primes.append(p)      # store shifting primes, update sieve
        isPrime[p*p-base::p] = (False for _ in range(p*p-base,window,p))

    flagPrimes() # update flags with last window (should cover the rest of them)
    return primeCount,flags     
        

输出:

print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229, [0, 1, 1, 0, 0, 1]

print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]

您可以使用 window 大小来在执行时间和内存消耗之间取得最佳平衡。尽管 n 的大值的执行时间(在我的笔记本电脑上)仍然相当长:

from timeit import timeit
for w in range(3,9):
    t = timeit(lambda:primeSieve(10**8,[],10**w),number=1)
    print(f"10e{w} window:",t)

10e3 window: 119.463959956
10e4 window: 33.33273301199999
10e5 window: 24.153761258999992
10e6 window: 24.649398391000005
10e7 window: 27.616014667
10e8 window: 27.919413531000004

奇怪的是,window 超过 10^6 的尺寸性能更差。最佳点似乎介于 10^5 和 10^6 之间。 window 的 10^7 无论如何都会超过您的 50MB 限制。

关于如何以内存高效的方式快速生成素数,我有了另一个想法。它基于与埃拉托色尼筛法相同的概念,但使用字典来保存每个素数将失效(即跳过)的下一个值。这只需要为每个质数存储一个字典条目,直到 n.

的平方根
def genPrimes(maxPrime):
    if maxPrime>=2: yield 2           # special processing for 2
    primeSkips = dict()               # skipValue:prime
    for n in range(3,maxPrime+1,2):
        if n not in primeSkips:       # if not in skip list, it is a new prime
            yield n
            if n*n <= maxPrime:       # first skip will be at n^2
                primeSkips[n*n] = n
            continue
        prime = primeSkips.pop(n)     # find next skip for n's prime
        skip  = n+2*prime
        while skip in primeSkips:     # must not already be skipped
            skip += 2*prime                
        if skip<=maxPrime:            # don't skip beyond maxPrime
            primeSkips[skip]=prime           

使用这个,primeSieve 函数可以简单地 运行 通过素数,计算它们,并标记 x 值:

def primeSieve(n,X):
    primeCount = 0
    nonPrimes  = set(X)
    for prime in genPrimes(n):
        primeCount += 1
        nonPrimes.discard(prime)
    return primeCount,[0 if x in nonPrimes else 1 for x in X]


print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229, [0, 1, 1, 0, 0, 1]

print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]

这个 运行 比我之前的回答稍微快一点,并且只消耗 78K 内存来生成最多 10^8 个素数(在 21 秒内)。