MATLAB 中的高效树实现

Efficient tree implementation in MATLAB

MATLAB 中的树 class

我正在 MATLAB 中实现树数据结构。向树中添加新的 child 节点,分配和更新与节点相关的数据值是我期望执行的典型操作。每个节点都有相同类型的 data 与之关联。删除节点对我来说不是必需的。到目前为止,我已经决定从 handle class 继承的 class 实现能够将对周围节点的引用传递给将修改树的函数。

编辑:12 月 2 日

首先,感谢迄今为止评论和回答中的所有建议。他们已经帮助我改进了我的树 class.

有人建议尝试 digraph 在 R2015b 中引入。我还没有探索这个,但看到它不能像从 handle 继承的 class 那样作为参考参数工作,我有点怀疑它在我的应用程序中的工作方式。在这一点上,我还不清楚使用自定义 data 节点和边来使用它会有多容易。

编辑:(12 月 3 日)有关主要应用程序的更多信息:MCTS

最初,我认为主要应用程序的细节只是无关紧要的,但自从阅读了@FirefoxMetzger 的评论和 后,我意识到它具有重要意义。

我正在实施一种 Monte Carlo tree search 算法。以迭代方式探索和扩展搜索树。维基百科提供了一个很好的过程图形概述:

在我的应用程序中,我执行了大量的搜索迭代。在每次搜索迭代中,我从根开始遍历当前树直到叶节点,然后通过添加新节点扩展树,并重复。由于该方法基于 运行dom 采样,在每次迭代开始时我 不知道每次迭代我将在哪个叶节点完成 。相反,这由树中当前的 data 个节点和 运行dom 样本的结果共同确定。我在单次迭代期间访问的任何节点都会更新 data

例子:我在节点n有几个children。我需要访问每个 children 中的数据并绘制一个 运行dom 样本,以确定我在搜索中移动到下一个 children 中的哪个。重复此操作直到到达叶节点。实际上,我是通过在根节点上调用一个 search 函数来决定接下来要扩展哪个 child,在该节点上递归调用 search,依此类推,最后返回一个值一次到达叶节点。从递归函数返回时使用此值来更新搜索迭代期间访问的节点的 data

树可能非常不平衡,因此一些 b运行 节点链很长,而其他节点在根级别后很快终止并且不会进一步扩展。

当前实施

下面是我当前实现的一个示例,其中包含一些用于添加节点、查询树中节点的深度或数量等的成员函数示例。

classdef stree < handle
    %   A class for a tree object that acts like a reference
    %   parameter.
    %   The tree can be traversed in both directions by using the parent
    %   and children information.
    %   New nodes can be added to the tree. The object will automatically
    %   keep track of the number of nodes in the tree and increment the
    %   storage space as necessary.

    properties (SetAccess = private)
        % Hold the data at each node
        Node = { [] };
        % Index of the parent node. The root of the tree as a parent index
        % equal to 0.
        Parent = 0;
        num_nodes = 0;
        size_increment = 1;
        maxSize = 1;
    end

    methods
        function [obj, root_ID] = stree(data, init_siz)
            % New object with only root content, with specified initial
            % size
            obj.Node = repmat({ data },init_siz,1);
            obj.Parent = zeros(init_siz,1);
            root_ID = 1;
            obj.num_nodes = 1;
            obj.size_increment = init_siz;
            obj.maxSize = numel(obj.Parent);
        end

        function ID = addnode(obj, parent, data)
            % Add child node to specified parent
            if obj.num_nodes < obj.maxSize
                % still have room for data
                idx = obj.num_nodes + 1;
                obj.Node{idx} = data;
                obj.Parent(idx) = parent;
                obj.num_nodes = idx;
            else
                % all preallocated elements are in use, reserve more memory
                obj.Node = [
                    obj.Node
                    repmat({data},obj.size_increment,1)
                    ];

                obj.Parent = [
                    obj.Parent
                    parent
                    zeros(obj.size_increment-1,1)];
                obj.num_nodes = obj.num_nodes + 1;

                obj.maxSize = numel(obj.Parent);

            end
            ID = obj.num_nodes;
        end

        function content = get(obj, ID)
            %% GET  Return the contents of the given node IDs.
            content = [obj.Node{ID}];
        end

        function obj = set(obj, ID, content)
            %% SET  Set the content of given node ID and return the modifed tree.
            obj.Node{ID} = content;
        end

        function IDs = getchildren(obj, ID)
            % GETCHILDREN  Return the list of ID of the children of the given node ID.
            % The list is returned as a line vector.
            IDs = find( obj.Parent(1:obj.num_nodes) == ID );
            IDs = IDs';
        end
        function n = nnodes(obj)
            % NNODES  Return the number of nodes in the tree.
            % Equal to root + those whose parent is not root.
            n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
            assert( obj.num_nodes == n);
        end

        function flag = isleaf(obj, ID)
            % ISLEAF  Return true if given ID matches a leaf node.
            % A leaf node is a node that has no children.
            flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
        end

        function depth = depth(obj,ID)
            % DEPTH return depth of tree under ID. If ID is not given, use
            % root.
            if nargin == 1
                ID = 0;
            end
            if obj.isleaf(ID)
                depth = 0;
            else
                children = obj.getchildren(ID);
                NC = numel(children);
                d = 0; % Depth from here on out
                for k = 1:NC
                    d = max(d, obj.depth(children(k)));
                end
                depth = 1 + d;
            end
        end
    end
end

但是,有时性能很慢,对树的操作占用了我的大部分计算时间。有哪些具体的方式可以提高执行效率?如果有性能提升,甚至可以将实现更改为 handle 继承类型以外的其他类型。

当前实施的分析结果

由于向树添加新节点是最典型的操作(连同更新节点的 data),我对此做了一些 profiling。 我 运行 使用 Nd=6, Ns=10.

以下基准测试代码的分析器
function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
    function add_layers(node_id, num_layers)
        if num_layers == 0
            return;
        end
        child_id = zeros(Ns,1);
        for s = 1:Ns
            % add child to current node
            child_id(s) = T.addnode(node_id, rand);

            % recursively increase depth under child_id(s)
            add_layers(child_id(s), num_layers-1);
        end
    end
end

探查器的结果:

R2015b 性能


发现R2015b improves the performance of MATLAB's OOP features。我重做了上面的基准测试,确实观察到性能有所提高:

所以这已经是个好消息,尽管当然可以接受进一步的改进 ;)

以不同方式保留内存

评论中也有人建议使用

obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];

保留更多内存而不是当前使用 repmat 的方法。这略微提高了性能。我应该注意到我的基准代码是针对虚拟数据的,并且由于实际的 data 更复杂,这可能会有所帮助。谢谢!探查器结果如下:

关于进一步提高性能的问题

  1. 也许有另一种更有效的方法来维护树的内存?遗憾的是,我通常无法提前知道树中会有多少个节点。
  2. 添加新节点和修改现有节点的data是我对树最典型的操作。截至目前,它们实际上占用了我 m 的大部分处理时间在申请中。欢迎对这些功能进行任何改进。

最后一点,我希望保持纯 MATLAB 的实现。但是,诸如 MEX 或使用某些集成 Java 功能之类的选项可能是可以接受的。

我知道这听起来很愚蠢...但是如何保留空闲节点数而不是节点总数?这将需要与一个常量(为零)进行比较,该常量是单个 属性 访问。

另一个巫术改进是将 .maxSize 移动到 .num_nodes 附近,并将这两个 放在 .Node 单元格之前。像这样,它们在内存中的位置不会因为 .Node 属性 的增长而相对于对象的开头发生变化(这里的巫术是我猜测 MATLAB 中对象的内部实现)。

稍后编辑 当我将 .Node 移动到 属性 列表末尾进行分析时,大部分执行时间都被扩展消耗了.Node 属性,符合预期(5.45 秒,而您提到的比较为 1.25 秒)。

您可以尝试分配与您实际填充的元素数量成正比的元素数量:这是 std::vector 在 c++

中的标准实现
obj.Node = [obj.Node; data; cell(q * obj.num_nodes,1)];

我记不太清了,但在 MSCC 中 q 是 1,而在 GCC 中是 .75。


这是使用 Java 的解决方案。我不太喜欢它,但它完成了它的工作。我实现了您从维基百科中提取的示例。

import javax.swing.tree.DefaultMutableTreeNode

% Let's create our example tree
top = DefaultMutableTreeNode([11,21])
n1 = DefaultMutableTreeNode([7,10])
top.add(n1)
n2 = DefaultMutableTreeNode([2,4])
n1.add(n2)
n2 = DefaultMutableTreeNode([5,6])
n1.add(n2)
n3 = DefaultMutableTreeNode([2,3])
n2.add(n3)
n3 = DefaultMutableTreeNode([3,3])
n2.add(n3)
n1 = DefaultMutableTreeNode([4,8])
top.add(n1)
n2 = DefaultMutableTreeNode([1,2])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n1 = DefaultMutableTreeNode([0,3])
top.add(n1)

% Element to look for, your implementation will be recursive
searching = [0 1 1];
idx = 1;
node(idx) = top;
for item = searching,
    % Java transposes the matrices, remember to transpose back when you are reading
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

% We made a new test...
newdata = [0, 1]
newnode = DefaultMutableTreeNode(newdata)
% ...so we expand our tree at the last node we searched
node(idx).add(newnode)

% The change has to be propagated (this is where your recursion returns)
for it=length(node):-1:1,
    itnode=node(it);
    val = itnode.getUserObject()'
    newitemdata = val + newdata
    itnode.setUserObject(newitemdata)
end

% Let's see if the new values are correct
searching = [0 1 1 0];
idx = 1;
node(idx) = top;
for item = searching,
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

TL:DR 您深度复制存储在每次插入中的全部数据,初始化 parentNode 单元格,使其比您期望的要大.

您的数据确实具有树结构,但是您没有在实施中使用它。相反,实现的代码是查找 table(实际上是 2 tables)的计算饥饿版本,它存储树的数据和关系数据。

我这样说的原因如下:

  • 要插入你调用 stree.addnote(parent, data),它会将所有数据存储在树 object stree 的字段中 Node = {}Parent = []
  • 你似乎事先知道你想访问树中的哪个元素,因为没有给出搜索代码(如果你使用 stree.getchild(ID),我有一些坏消息)
  • 一旦你处理了一个节点,你就可以使用 find() 来追溯它,这是一个列表搜索

这绝不意味着数据的实现很笨拙,它甚至可能是最好的,具体取决于您在做什么。但是它确实解释了您的内存分配问题并给出了如何解决这些问题的提示。


将数据保留为查找 table

存储数据的方法之一是保持底层查找table。如果您知道要修改的第一个元素的 ID 而无需搜索它 ,我只会这样做。这种情况下,您可以通过两个步骤使您的结构更加高效。

首先初始化你的数组更大然后你期望你需要存储数据。如果超出查找table的容量,则初始化一个新的,即X字段更大,并生成旧数据的deep-copy。如果您需要一次或两次扩展容量(在所有插入期间),这可能不是问题,但在您的情况下,会为每次插入制作深拷贝!

其次,我将更改内部结构并合并两个 tables NodeParent。这样做的原因是您代码中的 back-propagation 需要 O(depth_from_root * n),其中 n 是您的 table 中的节点数。这是因为 find() 将为每个 parent.

遍历整个 table

相反,您可以实现类似于

的东西
table = cell(n,1) % n bigger then expected value
end_pointer = 1 % simple pointer to the first free value

function insert(data,parent_ID)
    if end_pointer < numel(table)
        content.data = data;
        content.parent = parent_ID;
        table{end_pointer} = content;
        end_pointer = end_pointer + 1;
    else
        % need more space, make sure its enough this time
        table = [table cell(end_pointer,1)];
        insert(data,parent_ID);
    end
end

function content = get_value(ID)
    content = table(ID);
end

这使您可以立即访问 parent 的 ID,而无需先 find(),每一步节省 n 次迭代,因此负担变为 O(depth)。如果您不知道您的初始节点,那么您必须 find() 那个,其成本为 O(n)。

请注意,此结构不需要 is_leaf()depth()nnodes()get_children()。如果您仍然需要那些,我需要更深入地了解您要对数据做什么,因为这会极大地影响正确的结构。


树结构

如果您永远不知道第一个节点的 ID 并且因此 总是必须搜索 ,那么这个结构是有意义的。

好处是搜索任意音符的复杂度为 O(depth),因此搜索为 O(depth) 而不是 O(n),反向传播为 O(depth^2) 而不是 O(depth + n)。请注意,深度可以是任何东西,从完美平衡树的 log(n) 到退化树的 n,这可能取决于您的数据。

然而,为了提出正确的建议,我需要更多的洞察力,因为每种树结构都有自己的壁龛。从目前我所看到的情况来看,我建议使用一个不平衡的树,即 'sorted' 由想要的节点 parent 给出的简单顺序。这可能会根据

进一步优化
  • 是否可以根据您的数据定义总订单
  • 你如何处理双重值(相同的数据出现两次)
  • 你的数据规模是多少(千,百万,...)
  • 是一个始终与反向传播配对的查找/搜索
  • 'parent-child' 的链在您的数据上有多长(或者使用这个简单顺序的树的平衡度和深度)
  • 总是只有一个 parent 还是同一个元素用不同的 parents
  • 插入两次

我很乐意为上面的树提供示例代码,请给我留言。

编辑: 在您的情况下,不平衡的树(与执行 MCTS 并行构造)似乎是最佳选择。下面的代码假定数据在 statescore 中拆分,并且 state 是唯一的。如果不是,这仍然有效,但是可以进行优化以提高 MCTS 性能。

classdef node < handle
    % A node for a tree in a MCTS
    properties
        state = {}; %some state of the search space that identifies the node
        score = 0;
        childs = cell(50,1);
        num_childs = 0;
    end
    methods
        function obj = node(state)
            % for a new node simulate a score using MC
            obj.score = simulate_from(state); % TODO implement simulation state -> finish
            obj.state = state;
        end
        function value = update(obj)
            % update the this node using MC recursively
            if obj.num_childs == numel(obj.childs)
                % there are to many childs, we have to expand the table
                obj.childs = [obj.childs cell(obj.num_childs,1)];
            end
            if obj.do_exploration() || obj.num_childs == 0
                % explore a potential state
                state_to_explore = obj.explore();

                %check if state has already been visited
                terminate = false;
                idx = 1;
                while idx <= obj.num_childs && ~terminate
                    if obj.childs{idx}.state_equals(state_to_explore)
                        terminate = true;
                    end
                    idx = idx + 1;
                end

                %preform the according action based on search
                if idx > obj.num_childs
                    % state has never been visited
                    % this action terminates the update recursion 
                    % and creates a new leaf
                    obj.num_childs = obj.num_childs + 1;
                    obj.childs{obj.num_childs} = node(state_to_explore);
                    value = obj.childs{obj.num_childs}.calculate_value();
                    obj.update_score(value);
                else
                    % state has been visited at least once
                    value = obj.childs{idx}.update();
                    obj.update_score(value);
                end
            else
                % exploit what we know already
                best_idx = 1;
                for idx = 1:obj.num_childs
                    if obj.childs{idx}.score > obj.childs{best_idx}.score
                        best_idx = idx;
                    end
                end
                value = obj.childs{best_idx}.update();
                obj.update_score(value);
            end
            value = obj.calculate_value();
        end
        function state = explore(obj)
            %select a next state to explore, that may or may not be visited
            %TODO
        end
        function bool = do_exploration(obj)
            % decide if this node should be explored or exploited
            %TODO
        end
        function bool = state_equals(obj, test_state)
            % returns true if the nodes state is equal to test_state
            %TODO
        end
        function update_score(obj, value)
            % updates the score based on some value
            %TODO
        end
        function calculate_value(obj)
            % returns the value of this node to update previous nodes
            %TODO
        end
    end
end

对代码的一些评论:

  • 根据设置的不同,可能不需要 obj.calculate_value()。例如。如果它是某个值,可以通过单独评估 child 的分数来计算
  • 如果一个 state 可以有多个 parent,那么重用注释 object 并将其覆盖在结构
  • 中是有意义的
  • 因为每个 node 都知道它的所有 children 可以使用 node 作为根节点
  • 轻松生成子树
  • 搜索树(没有任何更新)是一个简单的递归贪婪搜索
  • 根据您搜索的分支因素,可能值得访问每个可能的 child 一次(在节点初始化时),然后再进行 randsample(obj.childs,1) 探索,因为这样可以避免复制 /重新分配 child 数组
  • parent 属性 被编码为递归更新树,在完成节点更新后将 value 传递给 parent
  • 我唯一一次重新分配内存是当单个节点有超过 50 childs any 我只为那个单独的节点重新分配

这应该 运行 快很多,因为它只关心选择树的任何部分,而不会触及任何其他部分。