Lock Free stack 实现思路 - 目前已损坏

Lock Free stack implementation idea - currently broken

我想出了一个想法,我正在尝试实现一个不依赖引用计数来解决 ABA 问题的无锁堆栈,并且还可以正确处理内存回收。它在概念上类似于 RCU,并依赖于两个功能:将列表条目标记为已删除,以及跟踪遍历列表的读者。前者很简单,它只是使用指针的LSB。后者是我 "clever" 尝试实现无限制无锁堆栈的方法。

基本上,当任何线程尝试遍历列表时,一个原子计数器 (list.entries) 会递增。遍历完成后,第二个计数器 (list.exits) 递增。

节点分配由 push 处理,释放由 pop 处理。

push 和 pop 操作与朴素的无锁堆栈实现非常相似,但必须遍历标记为删除的节点才能到达未标记的条目。因此,推送基本上很像链表插入。

pop操作同样是遍历链表,只是在遍历时使用atomic_fetch_or将节点标记为移除,直到到达未标记的节点。

遍历0个或多个标记节点的列表后,正在弹出的线程将尝试CAS堆栈的头部。至少有一个线程并发出栈成功,此时所有入栈的读者将不再看到之前标记的节点。

成功更新列表的线程然后加载原子 list.entries,并且基本上自旋加载 atomic.exits 直到该计数器最终超过 list.entries。这应该意味着列表 "old" 版本的所有读者都已完成。该线程然后简单地释放它从列表顶部换出的标记节点列表。

所以 pop 操作的含义应该是(我认为)不存在 ABA 问题,因为释放的节点在使用它们的所有并发读取器完成之前不会返回到可用的指针池,并且显然,出于同样的原因,内存回收问题也得到了处理。

总之,这只是理论,但我仍在摸索实现,因为它目前无法正常工作(在多线程情况下)。似乎我得到了一些免费问题等,但我无法发现问题,或者我的假设有缺陷,它根本行不通。

如能提供有关概念和代码调试方法的任何见解,我们将不胜感激。

这是我当前的(损坏的)代码(用 gcc -D_GNU_SOURCE -std=c11 -Wall -O0 -g -pthread -o list list.c 编译):

#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

#include <sys/resource.h>

#include <stdio.h>
#include <unistd.h>

#define NUM_THREADS 8
#define NUM_OPS (1024 * 1024)

typedef uint64_t list_data_t;

typedef struct list_node_t {
    struct list_node_t * _Atomic next;
    list_data_t data;
} list_node_t;

typedef struct {
    list_node_t * _Atomic head;
    int64_t _Atomic size;
    uint64_t _Atomic entries;
    uint64_t _Atomic exits;
} list_t;

enum {
    NODE_IDLE    = (0x0),
    NODE_REMOVED = (0x1 << 0),
    NODE_FREED   = (0x1 << 1),
    NODE_FLAGS    = (0x3),
};

static __thread struct {
    uint64_t add_count;
    uint64_t remove_count;
    uint64_t added;
    uint64_t removed;
    uint64_t mallocd;
    uint64_t freed;
} stats;

#define NODE_IS_SET(p, f) (((uintptr_t)p & f) == f)
#define NODE_SET_FLAG(p, f) ((void *)((uintptr_t)p | f))
#define NODE_CLR_FLAG(p, f) ((void *)((uintptr_t)p & ~f))
#define NODE_POINTER(p) ((void *)((uintptr_t)p & ~NODE_FLAGS))

list_node_t * list_node_new(list_data_t data)
{
    list_node_t * new = malloc(sizeof(*new));
    new->data = data;
    stats.mallocd++;

    return new;
}

void list_node_free(list_node_t * node)
{
    free(node);
    stats.freed++;
}

static void list_add(list_t * list, list_data_t data)
{
    atomic_fetch_add_explicit(&list->entries, 1, memory_order_seq_cst);

    list_node_t * new = list_node_new(data);
    list_node_t * _Atomic * next = &list->head;
    list_node_t * current = atomic_load_explicit(next,  memory_order_seq_cst);
    do
    {
        stats.add_count++;
        while ((NODE_POINTER(current) != NULL) &&
                NODE_IS_SET(current, NODE_REMOVED))
        {
                stats.add_count++;
                current = NODE_POINTER(current);
                next = &current->next;
                current = atomic_load_explicit(next, memory_order_seq_cst);
        }
        atomic_store_explicit(&new->next, current, memory_order_seq_cst);
    }
    while(!atomic_compare_exchange_weak_explicit(
            next, &current, new,
            memory_order_seq_cst, memory_order_seq_cst));

    atomic_fetch_add_explicit(&list->exits, 1, memory_order_seq_cst);
    atomic_fetch_add_explicit(&list->size, 1, memory_order_seq_cst);
    stats.added++;
}

static bool list_remove(list_t * list, list_data_t * pData)
{
    uint64_t entries = atomic_fetch_add_explicit(
            &list->entries, 1, memory_order_seq_cst);

    list_node_t * start = atomic_fetch_or_explicit(
            &list->head, NODE_REMOVED, memory_order_seq_cst);
    list_node_t * current = start;

    stats.remove_count++;
    while ((NODE_POINTER(current) != NULL) &&
            NODE_IS_SET(current, NODE_REMOVED))
    {
        stats.remove_count++;
        current = NODE_POINTER(current);
        current = atomic_fetch_or_explicit(&current->next,
                NODE_REMOVED, memory_order_seq_cst);
    }

    uint64_t exits = atomic_fetch_add_explicit(
            &list->exits, 1, memory_order_seq_cst) + 1;

    bool result = false;
    current = NODE_POINTER(current);
    if (current != NULL)
    {
        result = true;
        *pData = current->data;

        current = atomic_load_explicit(
                &current->next, memory_order_seq_cst);

        atomic_fetch_add_explicit(&list->size,
                -1, memory_order_seq_cst);

        stats.removed++;
    }

    start = NODE_SET_FLAG(start, NODE_REMOVED);
    if (atomic_compare_exchange_strong_explicit(
            &list->head, &start, current,
            memory_order_seq_cst, memory_order_seq_cst))
    {
        entries = atomic_load_explicit(&list->entries, memory_order_seq_cst);
        while ((int64_t)(entries - exits) > 0)
        {
            pthread_yield();
            exits = atomic_load_explicit(&list->exits, memory_order_seq_cst);
        }

        list_node_t * end = NODE_POINTER(current);
        list_node_t * current = NODE_POINTER(start);
        while (current != end)
        {
            list_node_t * tmp = current;
            current = atomic_load_explicit(&current->next, memory_order_seq_cst);
            list_node_free(tmp);
            current = NODE_POINTER(current);
        }
    }

    return result;
}

static list_t list;

pthread_mutex_t ioLock = PTHREAD_MUTEX_INITIALIZER;

void * thread_entry(void * arg)
{
    sleep(2);
    int id = *(int *)arg;

    for (int i = 0; i < NUM_OPS; i++)
    {
        bool insert = random() % 2;

        if (insert)
        {
            list_add(&list, i);
        }
        else
        {
            list_data_t data;
            list_remove(&list, &data);
        }
    }

    struct rusage u;
    getrusage(RUSAGE_THREAD, &u);

    pthread_mutex_lock(&ioLock);
    printf("Thread %d stats:\n", id);
    printf("\tadded = %lu\n", stats.added);
    printf("\tremoved = %lu\n", stats.removed);
    printf("\ttotal added = %ld\n", (int64_t)(stats.added - stats.removed));
    printf("\tadded count = %lu\n", stats.add_count);
    printf("\tremoved count = %lu\n", stats.remove_count);
    printf("\tadd average = %f\n", (float)stats.add_count / stats.added);
    printf("\tremove average = %f\n", (float)stats.remove_count / stats.removed);
    printf("\tmallocd = %lu\n", stats.mallocd);
    printf("\tfreed = %lu\n", stats.freed);
    printf("\ttotal mallocd = %ld\n", (int64_t)(stats.mallocd - stats.freed));
    printf("\tutime = %f\n", u.ru_utime.tv_sec
            + u.ru_utime.tv_usec / 1000000.0f);
    printf("\tstime = %f\n", u.ru_stime.tv_sec
                    + u.ru_stime.tv_usec / 1000000.0f);
    pthread_mutex_unlock(&ioLock);

    return NULL;
}

int main(int argc, char ** argv)
{
    struct {
            pthread_t thread;
            int id;
    }
    threads[NUM_THREADS];
    for (int i = 0; i < NUM_THREADS; i++)
    {
        threads[i].id = i;
        pthread_create(&threads[i].thread, NULL, thread_entry, &threads[i].id);
    }

    for (int i = 0; i < NUM_THREADS; i++)
    {
        pthread_join(threads[i].thread, NULL);
    }

    printf("Size = %ld\n", atomic_load(&list.size));

    uint32_t count = 0;

    list_data_t data;
    while(list_remove(&list, &data))
    {
        count++;
    }
    printf("Removed %u\n", count);
}

您提到您正在尝试解决 ABA 问题,但描述和代码实际上是在尝试解决更难的问题:memory reclamation 问题。

这个问题通常出现在 "deletion" 无锁集合的功能中,这些功能是用没有垃圾收集的语言实现的。核心问题是从共享结构中删除节点的线程通常不知道何时可以安全地释放已删除的节点,因为其他读取可能仍然引用它。经常解决这个问题,作为副作用, 解决了 ABA 问题:特别是关于 CAS 操作成功,即使底层指针(和对象的状态)已经改变同时至少两次,以原始 value 结束,但呈现出完全不同的状态。

从某种意义上说,ABA 问题更容易解决,因为 ABA 问题有几个直接的解决方案,特别是不会导致 "memory reclamation" 问题的解决方案。从某种意义上说,可以检测到位置修改的硬件也更容易,例如,使用 LL/SC 或事务内存原语,可能根本不会出现问题。

所以说,你正在寻找内存回收问题的解决方案,它也将避免 ABA 问题。

你的问题的核心是这个陈述:

The thread that successfully updates the list then loads the atomic list.entries, and basically spin-loads atomic.exits until that counter finally exceeds list.entries. This should imply that all readers of the "old" version of the list have completed. The thread then simply frees the the list of marked nodes that it swapped off the top of the list.

这个逻辑不成立。等待 list.exits(你说 atomic.exits 但我认为这是一个错字,因为你只在其他地方谈论 list.exits)大于 list.entries 只告诉您现在 总出口 入口 多,此时变异线程捕获了入口计数。但是,这些出口可能是新读者来来去去产生的:这并不意味着所有老读者都已完成如您所说!

这是一个简单的例子。首先写线程T1和读线程T2大约同时访问链表,所以list.entries是2,list.exits是0。写线程弹出一个节点,并保存 list.entries 的当前值 (2) 并等待 lists.exits 大于 2。现在还有三个读取线程,T3T4T5 到达并快速阅读列表并离开。现在 lists.exits 为 3,满足您的条件并且 T1 释放节点。 T2 虽然没有去任何地方,但由于它正在读取一个已释放的节点而爆炸了!

你的基本想法是可行的,但你的两种反击方法肯定行不通。

这是一个经过充分研究的问题,因此您不必发明自己的算法(请参阅上面的 link),甚至不必编写自己的代码,因为 librcu and concurrencykit 之类的东西已经存在存在。

用于教育目的

如果您想要将此工作用于教育目的,一种方法是使用确保修改后进入的线程已经开始使用一组不同的 list.entry/exit 计数器。一种方法是生成计数器,当作者想要修改列表时,它会增加生成计数器,这会导致新读者切换到另一组 list.entry/exit 计数器。

现在笔者只需要等待list.entry[old] == list.exists[old],也就是说的读者都已经离开了。您也可以每代只使用一个计数器:您实际上并不是两个 entry/exit 计数器(尽管它可能有助于减少争用)。

当然,您知道有一个新问题来管理每代单独计数器的列表...这看起来有点像构建无锁列表的原始问题!不过这个问题要简单一些,因为你可能会对代数 "in flight" 设置一些合理的界限并预先分配它们,或者你可能会实现一种更容易推理的有限类型的无锁列表about 因为增删改查只发生在头部或尾部。