做一个简单的 class 支持并发

Make an easy class support concurrency

我有一个简单的银行账户class需要通过以下测试

import sys
import threading
import time
import unittest

from bank_account import BankAccount


class BankAccountTest(unittest.TestCase):

    def test_can_handle_concurrent_transactions(self):
        account = BankAccount()
        account.open()
        account.deposit(1000)

        self.adjust_balance_concurrently(account)

        self.assertEqual(account.get_balance(), 1000)

    def adjust_balance_concurrently(self, account):
        def transact():
            account.deposit(5)
            time.sleep(0.001)
            account.withdraw(5)

        # Greatly improve the chance of an operation being interrupted
        # by thread switch, thus testing synchronization effectively
        try:
            sys.setswitchinterval(1e-12)
        except AttributeError:
            # For Python 2 compatibility
            sys.setcheckinterval(1)

        threads = [threading.Thread(target=transact) for _ in range(1000)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

我试图阅读线程文档,但我发现尝试将其应用于我的情况有点令人困惑。我试过的是这样的:

class BankAccount(Thread):
    def __init__(self):
        Thread.__init__(self)
        self.state = False
        self.balance = 0
        Thread().start()

    def get_balance(self):
        if self.state:
            return self.balance
        else:
            raise ValueError

    def open(self):
        self.state = True

    def deposit(self, amount):
        self.balance += amount

    def withdraw(self, amount):
        self.balance -= amount

显然是错误的。我的目的只是了解如何让 class 处理线程切换。如果我没有包含重要信息,请告诉我。

我们需要保证多个线程不能同时修改余额

deposit()函数,虽然看起来是单步操作,其实是多步操作。

old_balance = self.balance
new_balance = old_balance + deposit
self.balance = new_balance

如果线程切换发生在存款的中间,它可能会破坏数据。

例如假设线程 1 调用 deposit(10),线程 2 调用 deposit(20),初始余额为 100

# Inside thread 1
old_balance1 = self.balance      
new_balance1 = old_balance1 + 10 
# Thread switches to thread 2    
old_balance2 = self.balance      
new_balance2 = old_balance2 + 20 
self.balance = new_balance2      # balance = 120
# Thread switches back to thread 1
self.balance = new_balance1      # balance = 110

此处最终余额为 110,而本应为 130

解决方法是防止两个线程同时写入balance变量。我们可以利用 Locks 来完成这个。

import threading

class BankAccount:

    def open(self):
        self.balance = 0
        # initialize lock
        self.lock = threading.Lock()

    def deposit(self, val):
        # if another thread has acquired lock, block till it releases
        self.lock.acquire()
        self.balance += val
        self.lock.release()

    def withdraw(self, val):
        self.lock.acquire()
        self.balance -= val
        self.lock.release()

    def get_balance(self):
        return self.balance