在 cython 类 上使用手动深度复制会导致内存 overflow.Why?
Using manual deepcopy on cython classes causes memory overflow.Why?
我正在使用 MCTS 算法 开发棋盘游戏的智能代理。
Monte carlo 树搜索 (MCTS) 是人工智能中的一种流行方法,主要用于游戏(如围棋、国际象棋等)。在这种方法中,代理根据状态构建树,这些状态是选择当前状态允许的移动的结果。允许代理在有限的时间内搜索树。在此期间,Agent 将树扩展到最有希望(赢得游戏)的节点。
下图为过程:
有关更多信息,您可以查看此 link:
1 - http://www.cameronius.com/research/mcts/about/index.html
在树的根节点中,会有一个变量rootstate
显示游戏的当前状态。当我们深入树时,rootstate
的深层副本用于模拟树状态(未来状态)。
我将此代码用于 gamestate
class 的 deepcopy
,因为 deepcopy
由于 pickle 协议的问题而无法与 cython 对象一起正常工作:
cdef class 游戏状态:
# ... other functions
def __deepcopy__(self,memo_dictionary):
res = gamestate(self.size)
res.PLAYERS = self.PLAYERS
res.size = int(self.size)
res.board = np.array(self.board, dtype=np.int32)
res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
# the black_groups and white_groups are also cython objects which the same deepcopy function is implemented for them
# .... etc
return res
每当 MCTS 迭代开始时,状态的深层副本就会存储在内存中。
出现的问题是在游戏开始,
每 1 秒的迭代次数在 2000 到 3000 之间,这是预期的,但是 随着游戏树扩展 ,每 1 秒的迭代次数 减少 到 1。当每次迭代花费更多时间时,情况会变得更糟
完成。
当我检查 内存使用情况 时,我注意到每次调用代理时它 从 0.6% 增加到 90%搜索。
我在 pure python 中实现了相同的算法,并且没有此类问题。所以我猜 __deepcopy__ 函数 导致了问题。我曾经被建议为 中的 cython 对象制作自己的 pickle
协议,但我对 pickle
模块不是很熟悉。
谁能建议我一些用于我的 cython 对象的协议来摆脱这个障碍。
编辑 2:
我添加了一些可能有帮助的代码部分。
下面的代码属于 class unionfind
的 deepcopy 用于 gamestate
中的 white_groups
和 black_groups
:
cdef class unionfind:
cdef public:
dict parent
dict rank
dict groups
list ignored
cdef __init__(self):
# initialize variables ...
def __deepcopy__(self, memo_dictionary):
res = unionfind()
res.parent = self.parent
res.rank = self.rank
res.groups = self.groups
res.ignored = self.ignored
return res
这是搜索功能,在允许的时间内运行:
cdef class mctsagent:
def search(time_budget):
cdef int num_rollouts = 0
while (num_rollouts < time_budget):
state_copy = deepcopy(self.rootstate)
node, state = self.select_node(state_copy) # expansion runs inside the select_node function
turn = state.turn()
outcome = self.roll_out(state)
self.backup(node, turn, outcome)
num_rollouts += 1
这个问题大概是行
res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
您应该做的是使用第二个参数 memo_dictionary
调用 deepcopy
。这是 deepcopy
是否已复制对象的记录。没有它 deepcopy
最终会多次复制同一个对象(因此会占用大量内存)
res.white_groups = deepcopy(self.white_groups, memo_dictionary) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups, memo_dictionary) # a module which checks if black player has won the game
(编辑:刚刚看到@Blckknght 已经在评论中指出了这一点)
(edit2: unionfind
看起来主要包含 Python 对象。作为 cdef class
而不仅仅是普通的 [=32= 可能没有很大的价值]。另外,你当前的 __deepcopy__
因为它实际上并没有复制那些字典——你应该做 res.parent = deepcopy(self.parent, memo_dictionary)
等等。如果你只是把它变成一个普通的 class 这会自动执行)
我正在使用 MCTS 算法 开发棋盘游戏的智能代理。 Monte carlo 树搜索 (MCTS) 是人工智能中的一种流行方法,主要用于游戏(如围棋、国际象棋等)。在这种方法中,代理根据状态构建树,这些状态是选择当前状态允许的移动的结果。允许代理在有限的时间内搜索树。在此期间,Agent 将树扩展到最有希望(赢得游戏)的节点。 下图为过程:
有关更多信息,您可以查看此 link:
1 - http://www.cameronius.com/research/mcts/about/index.html
在树的根节点中,会有一个变量rootstate
显示游戏的当前状态。当我们深入树时,rootstate
的深层副本用于模拟树状态(未来状态)。
我将此代码用于 gamestate
class 的 deepcopy
,因为 deepcopy
由于 pickle 协议的问题而无法与 cython 对象一起正常工作:
cdef class 游戏状态:
# ... other functions
def __deepcopy__(self,memo_dictionary):
res = gamestate(self.size)
res.PLAYERS = self.PLAYERS
res.size = int(self.size)
res.board = np.array(self.board, dtype=np.int32)
res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
# the black_groups and white_groups are also cython objects which the same deepcopy function is implemented for them
# .... etc
return res
每当 MCTS 迭代开始时,状态的深层副本就会存储在内存中。
出现的问题是在游戏开始,
每 1 秒的迭代次数在 2000 到 3000 之间,这是预期的,但是 随着游戏树扩展 ,每 1 秒的迭代次数 减少 到 1。当每次迭代花费更多时间时,情况会变得更糟
完成。
当我检查 内存使用情况 时,我注意到每次调用代理时它 从 0.6% 增加到 90%搜索。
我在 pure python 中实现了相同的算法,并且没有此类问题。所以我猜 __deepcopy__ 函数 导致了问题。我曾经被建议为 pickle
协议,但我对 pickle
模块不是很熟悉。
谁能建议我一些用于我的 cython 对象的协议来摆脱这个障碍。
编辑 2:
我添加了一些可能有帮助的代码部分。
下面的代码属于 class unionfind
的 deepcopy 用于 gamestate
中的 white_groups
和 black_groups
:
cdef class unionfind:
cdef public:
dict parent
dict rank
dict groups
list ignored
cdef __init__(self):
# initialize variables ...
def __deepcopy__(self, memo_dictionary):
res = unionfind()
res.parent = self.parent
res.rank = self.rank
res.groups = self.groups
res.ignored = self.ignored
return res
这是搜索功能,在允许的时间内运行:
cdef class mctsagent:
def search(time_budget):
cdef int num_rollouts = 0
while (num_rollouts < time_budget):
state_copy = deepcopy(self.rootstate)
node, state = self.select_node(state_copy) # expansion runs inside the select_node function
turn = state.turn()
outcome = self.roll_out(state)
self.backup(node, turn, outcome)
num_rollouts += 1
这个问题大概是行
res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
您应该做的是使用第二个参数 memo_dictionary
调用 deepcopy
。这是 deepcopy
是否已复制对象的记录。没有它 deepcopy
最终会多次复制同一个对象(因此会占用大量内存)
res.white_groups = deepcopy(self.white_groups, memo_dictionary) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups, memo_dictionary) # a module which checks if black player has won the game
(编辑:刚刚看到@Blckknght 已经在评论中指出了这一点)
(edit2: unionfind
看起来主要包含 Python 对象。作为 cdef class
而不仅仅是普通的 [=32= 可能没有很大的价值]。另外,你当前的 __deepcopy__
因为它实际上并没有复制那些字典——你应该做 res.parent = deepcopy(self.parent, memo_dictionary)
等等。如果你只是把它变成一个普通的 class 这会自动执行)