在 cython 类 上使用手动深度复制会导致内存 overflow.Why?

Using manual deepcopy on cython classes causes memory overflow.Why?

我正在使用 MCTS 算法 开发棋盘游戏的智能代理。 Monte carlo 树搜索 (MCTS) 是人工智能中的一种流行方法,主要用于游戏(如围棋、国际象棋等)。在这种方法中,代理根据状态构建树,这些状态是选择当前状态允许的移动的结果。允许代理在有限的时间内搜索树。在此期间,Agent 将树扩展到最有希望(赢得游戏)的节点。 下图为过程:

有关更多信息,您可以查看此 link:

1 - http://www.cameronius.com/research/mcts/about/index.html

在树的根节点中,会有一个变量rootstate显示游戏的当前状态。当我们深入树时,rootstate 的深层副本用于模拟树状态(未来状态)。

我将此代码用于 gamestate class 的 deepcopy,因为 deepcopy 由于 pickle 协议的问题而无法与 cython 对象一起正常工作:

cdef class 游戏状态:

# ... other functions

def __deepcopy__(self,memo_dictionary):
    res = gamestate(self.size)
    res.PLAYERS = self.PLAYERS
    res.size = int(self.size)
    res.board = np.array(self.board, dtype=np.int32)
    res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
    res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
    # the black_groups and white_groups are also cython objects which the same deepcopy function is implemented for them
    # .... etc
    return res

每当 MCTS 迭代开始时,状态的深层副本就会存储在内存中。 出现的问题是在游戏开始, 每 1 秒的迭代次数在 2000 到 3000 之间,这是预期的,但是 随着游戏树扩展 ,每 1 秒的迭代次数 减少 到 1。当每次迭代花费更多时间时,情况会变得更糟 完成。
当我检查 内存使用情况 时,我注意到每次调用代理时它 从 0.6% 增加到 90%搜索。
我在 pure python 中实现了相同的算法,并且没有此类问题。所以我猜 __deepcopy__ 函数 导致了问题。我曾经被建议为 中的 cython 对象制作自己的 pickle 协议,但我对 pickle 模块不是很熟悉。 谁能建议我一些用于我的 cython 对象的协议来摆脱这个障碍。

编辑 2:

我添加了一些可能有帮助的代码部分。 下面的代码属于 class unionfinddeepcopy 用于 gamestate 中的 white_groupsblack_groups:

cdef class unionfind:
    cdef public:
        dict parent
        dict rank
        dict groups
        list ignored
    cdef __init__(self):
    # initialize variables ...

   def __deepcopy__(self, memo_dictionary):
       res = unionfind()
       res.parent = self.parent
       res.rank = self.rank
       res.groups = self.groups
       res.ignored = self.ignored
       return res

这是搜索功能,在允许的时间内运行:

cdef class mctsagent:
    def search(time_budget):
        cdef int num_rollouts = 0
        while (num_rollouts < time_budget):
          state_copy = deepcopy(self.rootstate)
          node, state = self.select_node(state_copy) # expansion runs inside the select_node function
          turn = state.turn()
          outcome = self.roll_out(state)
          self.backup(node, turn, outcome)
          num_rollouts += 1

这个问题大概是行

res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game

您应该做的是使用第二个参数 memo_dictionary 调用 deepcopy。这是 deepcopy 是否已复制对象的记录。没有它 deepcopy 最终会多次复制同一个对象(因此会占用大量内存)

res.white_groups = deepcopy(self.white_groups, memo_dictionary) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups, memo_dictionary) # a module which checks if black player has won the game

If the __deepcopy__() implementation needs to make a deep copy of a component, it should call the deepcopy() function with the component as first argument and the memo dictionary as second argument.

(编辑:刚刚看到@Blckknght 已经在评论中指出了这一点)

(edit2: unionfind 看起来主要包含 Python 对象。作为 cdef class 而不仅仅是普通的 [=32= 可能没有很大的价值]。另外,你当前的 __deepcopy__ 因为它实际上并没有复制那些字典——你应该做 res.parent = deepcopy(self.parent, memo_dictionary) 等等。如果你只是把它变成一个普通的 class 这会自动执行)