关于 std::condition_variable:为什么这段代码会卡住?

About `std::condition_variable`: why does this code snippet stall?

为什么这段代码会卡住(停滞不前)?

程序本应输出 firstsecondthird,但在打印 firstsecond 之后就卡住了,不再继续执行。

#include <condition_variable>
#include <mutex>
#include <thread>
#include <functional>
#include <iostream>
#include <vector>

// Question version. Intended to print "firstsecondthird" whatever order the
// three threads arrive in, but it stalls after printing "firstsecond".
//
// Root cause: second() and third() call lk.unlock() and then pass the same,
// now non-owning, `lk` to condition_variable::wait(). wait() requires that
// the lock owns the mutex, so this is undefined behavior (the wait predicate
// is then also evaluated without the mutex held, racing on `state`).
class Foo {
public:
    Foo() {
        
    }

    // Prints "first" once state == 1, then sets state = 2 and notifies cv2.
    // The cv1 wait never blocks in practice: state starts at 1 and only this
    // function moves it away from 1.
    void first(std::function<void()> printFirst) 
    {
        {
            std::unique_lock<std::mutex> lk(mutex);
                         
            cv1.wait(lk, [this](){return 1==state;});

            doing = 1;
            // printFirst() outputs "first". Do not change or remove this line.
            printFirst();
                
            state = 2;
        }

        cv2.notify_one();
    }

    // Prints "second" once state == 2, then sets state = 3 and notifies cv3.
    void second(std::function<void()> printSecond) 
    {
        {
            std::unique_lock<std::mutex> lk(mutex);
            if(state !=2 )
            {
                if((1 == state)&&(1 != doing))
                {
                    // BUG: after this unlock(), lk.owns_lock() is false, but
                    // lk is still handed to cv2.wait() below (undefined
                    // behavior). On common implementations (e.g. libstdc++ on
                    // pthreads; not guaranteed) the wait still returns with
                    // the mutex locked, while lk keeps believing it does not
                    // own it, so its destructor never unlocks the mutex. The
                    // next thread to lock it (third()) then blocks forever,
                    // which matches the observed stall after "firstsecond".
                    // Fix: delete this unlock() and keep lk locked.
                    lk.unlock();
                    cv1.notify_one();
                }
            }
                        
            cv2.wait(lk, [this](){return 2==state;});

            doing = 2;
            // printSecond() outputs "second". Do not change or remove this line.
            printSecond();
            
            state = 3;
        }

        cv3.notify_one();
    }

    // Prints "third" once state == 3.
    void third(std::function<void()> printThird) 
    {
        {
            std::unique_lock<std::mutex> lk(mutex);
            if(state !=3 )
            {
                if((1 == state)&&(1 != doing))
                {
                    // BUG: same as in second(): lk no longer owns the mutex
                    // when cv3.wait() is called below.
                    lk.unlock();
                    cv1.notify_one();
                }
                else if((2 == state)&&(2 != doing))
                {
                    // BUG: same as above.
                    lk.unlock();
                    cv2.notify_one();
                }
            }
                        
            cv3.wait(lk, [this](){return 3==state;});

            // printThird() outputs "third". Do not change or remove this line.
            printThird();
            
            // No-op: the wait predicate guarantees state is already 3 here.
            state = 3;
        }
    }

private:
    std::condition_variable cv1;  // waited on only by first()
    std::condition_variable cv2;  // first() -> second() hand-off
    std::condition_variable cv3;  // second() -> third() hand-off
    std::mutex mutex;             // guards state and doing
    int state{1};  // step allowed to print next: 1 = first, 2 = second, 3 = third
    int doing{0};  // set to 1 by first() / 2 by second() after their wait; 0 before
};

// Demo driver: launches the workers in the order second, first, third, with
// a 300 ms pause before each launch (presumably so each thread reaches Foo
// before the next one starts -- timing-based, not guaranteed), then joins
// them all. If Foo deadlocks, the join loop never returns.
int main()
{
    Foo foo;
    std::vector<std::thread> workers;

    const std::chrono::milliseconds launchGap(300);

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.second([]() { std::cout << "second" << std::endl; }); });

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.first([]() { std::cout << "first" << std::endl; }); });

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.third([]() { std::cout << "third" << std::endl; }); });

    // Extra grace period; join() below waits for completion regardless.
    std::this_thread::sleep_for(std::chrono::seconds(2));

    for (std::thread& worker : workers)
    {
        worker.join();
    }
}

以下内容引自 @Igor Tandetnik 的评论。

根据文档:

template <class Predicate> void wait (unique_lock<mutex>& lck, Predicate pred);

lck

A unique_lock object whose mutex object is currently locked by this thread. All concurrent calls to wait member functions of this object shall use the same underlying mutex object (as returned by lck.mutex()).

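因此,调用 cv2.wait(lk, ...) 时,lk 必须实际持有(锁定)mutex。原代码在 wait 之前调用了 lk.unlock(),违反了这一前提,属于未定义行为。在常见实现(例如 libstdc++)上,wait 返回时 mutex 会被重新锁定,但 lk 仍认为自己并未持有它,因此析构时不会解锁。mutex 于是一直处于锁定状态,third() 在加锁时永远阻塞——这正是程序打印 firstsecond 后卡住的原因。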

删除 lk.unlock(); 之后,这段代码即可按预期工作:

#include <condition_variable>
#include <mutex>
#include <thread>
#include <functional>
#include <iostream>
#include <vector>

// Working version: the same logic as the question's code, except that the
// lock (here `guard`) is never released before a wait(), so every
// condition_variable::wait() is entered with the mutex owned by the lock,
// as the standard requires.
class Foo {
public:
    Foo() {}

    // Step 1: print "first", move state to 2, then wake second().
    void first(std::function<void()> printFirst)
    {
        {
            std::unique_lock<std::mutex> guard(mutex);
            // Returns immediately in practice: state starts at 1 and only
            // first() moves it away from 1.
            cv1.wait(guard, [this] { return state == 1; });

            doing = 1;
            // printFirst() outputs "first". Do not change or remove this line.
            printFirst();
            state = 2;
        }
        cv2.notify_one();
    }

    // Step 2: wait for state 2, print "second", move state to 3, then wake
    // third().
    void second(std::function<void()> printSecond)
    {
        {
            std::unique_lock<std::mutex> guard(mutex);

            // Wakes a thread blocked on cv1, if any (first() never actually
            // blocks there). `guard` is deliberately kept locked: releasing
            // it before the wait below is what made the question's code stall.
            if (state == 1 && doing != 1)
                cv1.notify_one();

            cv2.wait(guard, [this] { return state == 2; });

            doing = 2;
            // printSecond() outputs "second". Do not change or remove this line.
            printSecond();
            state = 3;
        }
        cv3.notify_one();
    }

    // Step 3: wait for state 3, then print "third".
    void third(std::function<void()> printThird)
    {
        {
            std::unique_lock<std::mutex> guard(mutex);

            // Poke the condition variable of an earlier step that has not
            // passed its wait yet; `guard` stays locked throughout.
            switch (state)
            {
            case 1:
                if (doing != 1)
                    cv1.notify_one();
                break;
            case 2:
                if (doing != 2)
                    cv2.notify_one();
                break;
            default:
                break;
            }

            cv3.wait(guard, [this] { return state == 3; });

            // printThird() outputs "third". Do not change or remove this line.
            printThird();
            state = 3;  // no-op: the wait predicate guarantees state is already 3
        }
    }

private:
    std::condition_variable cv1;  // waited on only by first()
    std::condition_variable cv2;  // first() -> second() hand-off
    std::condition_variable cv3;  // second() -> third() hand-off
    std::mutex mutex;             // guards state and doing
    int state{1};  // step allowed to print next: 1 = first, 2 = second, 3 = third
    int doing{0};  // set to 1 by first() / 2 by second() after their wait; 0 before
};

// Demo driver: launches the workers in the order second, first, third, with
// a 300 ms pause before each launch (presumably so each thread reaches Foo
// before the next one starts -- timing-based, not guaranteed), then joins
// them all.
int main()
{
    Foo foo;
    std::vector<std::thread> workers;

    const std::chrono::milliseconds launchGap(300);

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.second([]() { std::cout << "second" << std::endl; }); });

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.first([]() { std::cout << "first" << std::endl; }); });

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.third([]() { std::cout << "third" << std::endl; }); });

    // Extra grace period; join() below waits for completion regardless.
    std::this_thread::sleep_for(std::chrono::seconds(2));

    for (std::thread& worker : workers)
    {
        worker.join();
    }
}

由于 first() 从不需要等待(state 初始为 1,且只有 first() 会把它从 1 改为其他值),cv1 和 doing 都是多余的,代码可以进一步简化为:

#include <condition_variable>
#include <mutex>
#include <thread>
#include <functional>
#include <iostream>
#include <vector>
   
// @lc code=start
// Prints "first", "second", "third" in that order no matter which thread
// calls in first, using one mutex-guarded step counter and one condition
// variable per step that has to wait. Assumes each method is called exactly
// once per Foo object (state never resets).
class Foo {
public:
    Foo() {
        
    }

    // Step 1 never has to wait: print "first", advance state to 2, then
    // wake second().
    void first(std::function<void()> printFirst) 
    {
        {
            // No condition-variable wait in this step, so a plain lock_guard
            // is enough to serialize the state update with the waiters.
            std::lock_guard<std::mutex> lk(mutex);

            // printFirst() outputs "first". Do not change or remove this line.
            printFirst();
                
            state = 2;
        }

        // Notified after the lock is released, so the woken thread does not
        // immediately block on the mutex again.
        cv2.notify_one();
    }

    // Step 2: block until first() has run (state == 2), print "second",
    // advance state to 3, then wake third().
    void second(std::function<void()> printSecond) 
    {
        {
            // wait() requires a unique_lock that currently owns the mutex; it
            // releases the mutex while blocked and re-acquires it on return.
            std::unique_lock<std::mutex> lk(mutex);
                 
            // The predicate covers both spurious wakeups and first() having
            // already finished before this thread got here.
            cv2.wait(lk, [this](){return 2==state;});

            // printSecond() outputs "second". Do not change or remove this line.
            printSecond();
            
            state = 3;
        }

        cv3.notify_one();
    }

    // Step 3: block until second() has run (state == 3), then print "third".
    // This is the last step, so state is left unchanged.
    void third(std::function<void()> printThird) 
    {
        {
            std::unique_lock<std::mutex> lk(mutex);
                        
            cv3.wait(lk, [this](){return 3==state;});

            // printThird() outputs "third". Do not change or remove this line.
            printThird();
        }
    }

private:
    std::condition_variable cv2;  // signalled by first(), waited on by second()
    std::condition_variable cv3;  // signalled by second(), waited on by third()
    std::mutex mutex;             // guards state
    int state{1};                 // step allowed to print next: 1, 2 or 3
};
// @lc code=end

// Demo driver: launches the workers in the order second, first, third, with
// a 300 ms pause before each launch (presumably so each thread reaches Foo
// before the next one starts -- timing-based, not guaranteed), then joins
// them all.
int main()
{
    Foo foo;
    std::vector<std::thread> workers;

    const std::chrono::milliseconds launchGap(300);

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.second([]() { std::cout << "second" << std::endl; }); });

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.first([]() { std::cout << "first" << std::endl; }); });

    std::this_thread::sleep_for(launchGap);
    workers.emplace_back([&]() { foo.third([]() { std::cout << "third" << std::endl; }); });

    // Extra grace period; join() below waits for completion regardless.
    std::this_thread::sleep_for(std::chrono::seconds(2));

    for (std::thread& worker : workers)
    {
        worker.join();
    }
}