优化:统计 18 位数以内(约 10^18)的素数个数

Counting primes up to 18 digits optimization

我在学校有一个任务:要在 2 分钟内统计出不超过 10^18 的素数个数,并且使用的内存不超过 2 GB。第一次尝试时,我实现了一个分段筛,并做了一些优化,例如筛中只表示奇数、用位数组存储标记(见下面的代码)。

问题是,在我的电脑上(配置相当不错),统计 10^9 以内的素数就需要 13 秒,照此推算,统计到 10^18 至少需要好几天。

我的问题是:我是否遗漏了某些优化?或者,是否有更好、更快的方法来统计素数的个数?代码如下:

#include <errno.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/*
 * Fixed-width integer types (int8_t ... uint64_t) come from the standard
 * header instead of hand-written typedefs: those have to guess the width of
 * int / long long and clash with the C library's own definitions (e.g.
 * int64_t is `long` on 64-bit Linux) as soon as a libc header declares them.
 */
#include <stdint.h>

/* Number of bits stored in one uint32_t element of the bit arrays used by
 * get_bit / set_bit / clear_bit; must equal the element width. */
#define  SIZE 32
/* Enables the diagnostic printf output in count_primes. */
#define DEBUG

/* ANSI terminal colour escape sequences (ESC = 0x1B). */
#define  KRED "\x1B[31m"
#define  KGRN "\x1B[32m"
#define  KYEL "\x1B[33m"
#define  KBLU "\x1B[34m"
#define  KMAG "\x1B[35m"
#define  KCYN "\x1B[36m"
#define  KWHT "\x1B[37m"
/* Resets all attributes; needs the ESC byte like the colours above. */
#define RESET "\x1B[0m"

/* One link of the singly linked list that backs Queue. */
typedef struct node {
    uint64_t     data;  /* value carried by this link */
    struct node* next;  /* following link; list walks in this file stop at NULL */
} Node;

/* FIFO queue: values are appended after `last` and removed from `first`. */
typedef struct queue {
    Node*    first;  /* head (oldest element) */
    Node*    last;   /* tail (newest element) */
    uint32_t size;   /* number of elements currently queued */
} Queue;

/* Queue model */
/*
 * Append `value` to the tail of `queue`.
 *
 * Returns 1 on success, or 0 if the node could not be allocated; in that
 * case the queue is left unchanged.
 */
uint8_t enqueue(Queue* queue, int64_t value) {
    /* Allocate a whole Node. sizeof(Node*) would reserve only a pointer's
     * worth of memory, and the writes below would overflow the block. */
    Node* node = malloc(sizeof *node);

    if (node == NULL)
        return 0;

    node->data = (uint64_t)value;
    node->next = NULL;  /* new tail: list walks must stop here */

    /* Decide emptiness from `first`: `last` is not guaranteed to be reset
     * when the queue is drained, so it may point at a freed node. */
    if (queue->first == NULL)
        queue->first = node;
    else
        queue->last->next = node;

    queue->last = node;
    queue->size++;
    return 1;
}

/*
 * Remove the head of `queue` and return its value.
 *
 * Returns 0 when the queue is empty. Because 0 is also a storable value,
 * callers that need to distinguish the two cases should check queue->size
 * first.
 */
uint64_t dequeue(Queue* queue) {
    Node*    node;
    uint64_t save_data;

    /* Check for emptiness before touching the head: for an empty queue
     * `first` is NULL and reading node->data would crash. */
    if (queue->size == 0 || queue->first == NULL)
        return 0;

    node      = queue->first;
    save_data = node->data;

    queue->first = node->next;
    if (queue->first == NULL)
        queue->last = NULL;  /* now empty: don't keep a dangling tail pointer */
    queue->size--;
    free(node);

    return save_data;
}

/* Return the head node of `queue` without removing it (NULL for a freshly
 * initialised queue). The node stays owned by the queue. */
Node* queue_peek(Queue* queue) {
    Node* head = queue->first;

    return head;
}

/* Number of elements currently held by `queue`. */
uint32_t queue_size(Queue* queue) {
    const uint32_t count = queue->size;

    return count;
}

/*
 * Allocate an empty queue.
 *
 * Returns NULL if memory is exhausted. The caller owns the result and should
 * free() it once it has been drained with dequeue().
 */
Queue* init_queue(void) {
    /* Size the allocation from the struct itself, not from a pointer to it. */
    Queue* queue = malloc(sizeof *queue);

    if (queue == NULL)
        return NULL;

    queue->first = queue->last = NULL;
    queue->size  = 0;

    return queue;
}

/* Working with bit arrays functions */
/*
 * Population count: the number of 1 bits in `nbr`.
 * `nbr &= nbr - 1` clears the lowest set bit, so the loop body runs exactly
 * once per set bit.
 */
uint8_t count_set_bits(uint64_t nbr) {
    uint8_t bits;

    for (bits = 0; nbr != 0; nbr &= nbr - 1)
        ++bits;

    return bits;
}

/* Read bit `position` of a packed bit array (SIZE bits per element, least
 * significant bit first). Returns 1 if the bit is set, 0 otherwise. */
uint8_t get_bit(uint32_t array[], uint32_t position) {
    const uint32_t word  = position / SIZE;
    const uint32_t shift = position % SIZE;

    return (uint8_t)((array[word] >> shift) & 1U);
}

/* Reset bit `position` of a packed bit array (SIZE bits per element). */
void clear_bit(uint32_t array[], uint32_t position) {
    uint32_t* word = &array[position / SIZE];

    *word &= ~(1U << (position % SIZE));
}

/* Set bit `position` of a packed bit array (SIZE bits per element). */
void set_bit(uint32_t array[], uint32_t position) {
    const uint32_t bit = 1U << (position % SIZE);

    array[position / SIZE] |= bit;
}

/* Solve the problem */
Queue* initial_sieve(uint64_t limit) {
    Queue*   queue   = init_queue();
    uint64_t _sqrt   = (uint64_t)sqrt(limit);
    uint32_t *primes = (uint32_t*)calloc(_sqrt / SIZE + 1, sizeof(uint32_t));

    set_bit(primes, 0);
    // working with reversed logic, otherwise primes should all me initialized to max uiint64_t

    enqueue(queue, 2);
    for (uint64_t number = 3; number <= _sqrt; number += 2) {
        if (get_bit(primes, number / 2) == 0) {
            enqueue(queue, number);

            for (uint64_t position = number * number; position <= _sqrt; position += (number * 2)) {    
                set_bit(primes, position / 2);
            }
        }
        else
            set_bit(primes, number / 2);
    }

    return queue;
}

/*
 * Count the primes p <= limit, i.e. compute pi(limit), with a segmented
 * sieve of Eratosthenes over the odd numbers.
 *
 *   pi(limit) = 1 (for the prime 2) + #{ odd n in [3, limit] not crossed off }
 *
 * Every odd sieving prime p crosses off p*p, p*p + 2p, ..., so the sieving
 * primes themselves are never removed. The primes <= sqrt(limit) therefore
 * need no separate bookkeeping, and consecutive segments never overlap.
 *
 * Requires initial_sieve() to return every prime <= floor(sqrt(limit)) in
 * ascending order. Returns 0 for limit < 2. Also returns 0, after printing
 * a message to stderr, if memory cannot be allocated.
 */
uint64_t count_primes(uint64_t limit) {
    /* Lower bound on odd numbers per segment, so that small limits do not
     * pay the per-prime setup cost for many tiny segments. */
    const uint64_t min_segment_bits = (uint64_t)1 << 17;

    Queue*    queue;
    uint32_t* primes;            /* odd sieving primes, ascending */
    size_t    nprimes = 0;
    uint32_t* segment;           /* bit k set <=> low + 2k is composite */
    uint64_t  segment_bits, segment_words;
    uint64_t  odd_candidates = 0, composites = 0;
    uint64_t  low;

    if (limit < 2)
        return 0;

    queue = initial_sieve(limit);
    if (queue == NULL) {
        fprintf(stderr, "count_primes: out of memory\n");
        return 0;
    }

    /* Copy the odd sieving primes into a contiguous array and release the
     * queue. The hot loop below walks all of them once per segment, which is
     * much cheaper on an array than on a linked list. */
    primes = malloc((queue->size > 0 ? queue->size : 1) * sizeof *primes);
    while (queue->size > 0) {
        const uint64_t p = dequeue(queue);

        if (primes != NULL && p > 2)
            primes[nprimes++] = (uint32_t)p;
    }
    free(queue);

    /* Segments of roughly sqrt(limit) numbers, with the floor above.
     * A single buffer is reused for every segment. */
    segment_bits = (uint64_t)sqrt((double)limit) / 2 + 1;
    if (segment_bits < min_segment_bits)
        segment_bits = min_segment_bits;
    segment_words = segment_bits / SIZE + 1;
    segment = malloc((size_t)segment_words * sizeof *segment);

    if (primes == NULL || segment == NULL) {
        fprintf(stderr, "count_primes: out of memory\n");
        free(primes);
        free(segment);
        return 0;
    }

#ifdef DEBUG
    printf("Limits: 3 -> %llu (segments of %llu odd numbers)\n",
           (unsigned long long)limit, (unsigned long long)segment_bits);
#endif

    for (low = 3; low <= limit; ) {
        /* This segment covers the odd numbers low, low + 2, ..., high. */
        const uint64_t remaining = (limit - low) / 2 + 1;
        const uint64_t nbits     = remaining < segment_bits ? remaining : segment_bits;
        const uint64_t high      = low + 2 * (nbits - 1);
        const uint64_t words     = (nbits + SIZE - 1) / SIZE;

        memset(segment, 0, (size_t)words * sizeof *segment);

        for (size_t i = 0; i < nprimes; i++) {
            const uint64_t p = primes[i];
            uint64_t       offset;  /* value distance from low to the first mark */

            if (p * p > high)
                break;  /* primes ascend, so no later prime marks this segment */

            if (p * p >= low) {
                offset = p * p - low;            /* both odd: offset is even */
            } else {
                offset = (p - low % p) % p;      /* next multiple of p >= low */
                if (offset & 1)
                    offset += p;                 /* low is odd: keep the multiple odd */
            }

            /* Bit k stands for low + 2k, so consecutive odd multiples of p
             * (which differ by 2p) are exactly p bits apart. */
            for (uint64_t bit = offset / 2; bit < nbits; bit += p)
                set_bit(segment, (uint32_t)bit);
        }

        /* Bits at or beyond nbits are never set, so whole words can be counted. */
        for (uint64_t w = 0; w < words; w++)
            composites += count_set_bits(segment[w]);
        odd_candidates += nbits;

        if (limit - high < 2)
            break;  /* no odd number left in (high, limit] */
        low = high + 2;
    }

    free(segment);
    free(primes);

    const uint64_t total = 1 + (odd_candidates - composites);

#ifdef DEBUG
    printf("%sTotal composites and initial size: %llu %llu %s\n", KCYN,
           (unsigned long long)composites, (unsigned long long)(nprimes + 1), RESET);
    printf("%sTotal primes: %llu %s\n", KCYN, (unsigned long long)total, RESET);
#endif

    return total;
}

/* Main */
/*
 * Entry point: count the primes up to the non-negative decimal integer given
 * as the first command-line argument, and report the CPU time spent.
 *
 * Exits with EXIT_FAILURE if the argument is missing or malformed.
 */
int main(int argc, char **argv) {
    clock_t            begin, end;
    double             elapsed;
    unsigned long long limit;
    uint64_t           result;
    char*              endptr;

    if (argc < 2) {
        fprintf(stderr, "Invalid number of parameters\n");
        fprintf(stderr, "Usage: %s <limit>\n", argv[0]);
        fprintf(stderr, "Program will exit now.\n");
        return EXIT_FAILURE;
    }

    /* strtoull reports overflow through errno and silently negates a
     * leading '-', so reject both explicitly, along with trailing junk. */
    errno = 0;
    limit = strtoull(argv[1], &endptr, 10);
    if (endptr == argv[1] || *endptr != '\0' || errno == ERANGE
            || strchr(argv[1], '-') != NULL) {
        fprintf(stderr, "Invalid limit: %s\n", argv[1]);
        return EXIT_FAILURE;
    }

    begin   = clock();
    result  = count_primes((uint64_t)limit);
    end     = clock();
    elapsed = (double)(end - begin) / CLOCKS_PER_SEC;

    /* Cast to unsigned long long so %llu matches whatever uint64_t is. */
    printf("%sNumber of primes found up to %s%s: %s%llu.\n%s", KWHT, KCYN, argv[1], KYEL,
           (unsigned long long)result, RESET);
    printf("%sTotal time elapsed since the starting of the program: %s%lf seconds.\n%s",
           KWHT, KYEL, elapsed, RESET);
    return 0;
}

谢谢,马库斯

你需要的只是素数的个数,而不必把它们逐一找出来(数量太多了)。这就是素数计数函数(prime-counting function):

In mathematics, the prime-counting function is the function counting the number of prime numbers less than or equal to some real number x. It is denoted by π(x).

计算这个函数的方法有很多(例如 Legendre 公式、Meissel–Lehmer 算法以及 Lagarias–Miller–Odlyzko 算法),各种方法的比较可以参见这个 Wolfram 页面。要在两分钟之内完成这项计算,似乎是一项艰巨的任务。

正如评论中提到的,math.stackexchange 上还有一个很棒的回答,我认为会对你有所帮助。