Java 带节点过滤的通用树遍历

Java generic tree traversal with node filtering

我有一个通用的树结构。 我需要一种算法来遍历它,如果给定列表中不包含一些叶子,则将其删除。如果一个子树的所有叶子都被移除,那么整个子树也被移除。

示例树:

         0
       / |  \
      1  2   3
    /  \ |  / \
   4   5 6  7  8

要保留的叶子:{4, 6}

结果树:

         0
       / | 
      1  2 
    /    | 
   4     6  

输入数据结构包含在一个 HashMap 中,其中键是节点的父节点 ID,值是父节点下的节点列表(但不是递归所有子节点)。根节点的父id为空字符串。

Map<String, List<Node>> nodes;

class Node {
    private String id;
    //other members
    //getters, setters
}

我想,某种递归 DFS 遍历算法应该可以工作,但我找不到它是如何工作的。

I suppose, some kind of recursive DFS traversal algorithm should work

完全正确。以下是构建此算法的方法:

  • 观察到该任务具有递归结构:将其应用于树的任何分支对分支所做的事情与您想对整棵树所做的事情相同
  • 可以修剪或完全删除分支
  • 你的递归实现会return一个修剪过的分支;它将通过 returning null
  • 发出分支移除信号
  • 递归函数会检查传入的Node
  • 如果节点代表一片叶子,它的内容将根据我们希望保留的项目列表进行检查
  • 如果叶子不在"keep list",return null;否则 return 叶子
  • 对于非叶子分支调用递归方法,并检查其结果
  • 如果结果是null,从地图中移除相应的分支;否则,将分支替换为从调用 return 中删除的分支
  • 如果在检查所有分支时子节点的映射为空,return null

请注意,如果叶节点的 none 与 "keep" 列表匹配,则算法可以 return null。如果这是不可取的,请在递归实现的顶部添加一个额外的级别,并将顶层的 null return 替换为 return 空树。

我建议您尝试以下方法:

方法boolean removeRecursively(String id, Set<String> leavesToKeep)将从具有给定id的节点向下遍历到该分支叶子。

首先我们检查当前节点是否是叶子节点。如果叶子不在 leavesToKeep 集合中,我们将其删除并 return true,否则 return false。这是我们递归的基本情况。

如果节点不是叶节点,那么我们做这样的事情:

children.removeIf(n -> removeRecursively(n.id, leavesToKeep));

removeIf 是一种方便的 Java8 方法,用于删除满足给定谓词的所有元素。这意味着只有当它的所有子项也被删除时,该子项才会从列表中删除。因此,如果在 children.removeIf 调用后 children 列表为空,我们应该使 removeRecursively return 为真:

if (children.isEmpty()) {
    tree.remove(id);
    return true;
} else return false;

完整方法可能如下所示:

public static boolean removeRecursively(Map<String, List<Node>> tree, String id, Set<String> leavesToKeep) {
    List<Node> children = tree.get(id);
    if (children == null || children.isEmpty()) {
        if (!leavesToKeep.contains(id)) {
            tree.remove(id);
            return true;
        } else return false;
    }
    children.removeIf(n -> removeRecursively(tree, n.id, leavesToKeep));
    if (children.isEmpty()) {
        tree.remove(id);
        return true;
    } else return false;
}

其中 tree 是您描述的地图,id 是起始节点 ID,leavesToKeep 是要保留的一组 ID。

有界面树:

public static interface Tree<T> {
    public T getValue();

    public List<Tree<T>> children();

    public default boolean isLeaf() {
        return children().isEmpty();
    }

    public default boolean removeDeadBranches(Predicate<T> testLiveLeaf) {
        if (isLeaf()) {
            return testLiveLeaf.test(getValue());
        }
        boolean remainLife = false;
        for (Iterator<Tree<T>> it = children().iterator(); it.hasNext();) {
            if (it.next().removeDeadBranches(testLiveLeaf)) {
                remainLife = true;
            } else {
                it.remove();
            }
        }
        return remainLife;
    }
}
import com.google.common.collect.Lists;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

public class FilterTreeNode {
    private Logger logger = LoggerFactory.getLogger(FilterTreeNode.class);
    private TreeNode node0;
    private List<String> targetNode = Lists.newArrayList("B1","D1");

    @Before
    public void init(){
        node0 = TreeNode.builder().nodeCode("0").nodeName("A").build();
        TreeNode node1 = TreeNode.builder().nodeCode("1").nodeName("B").build();
        TreeNode node2 = TreeNode.builder().nodeCode("2").nodeName("C").build();
        TreeNode node3 = TreeNode.builder().nodeCode("3").nodeName("D").build();

        TreeNode node4 = TreeNode.builder().nodeCode("4").nodeName("B1").build();
        TreeNode node5 = TreeNode.builder().nodeCode("5").nodeName("B2").build();
        TreeNode node6 = TreeNode.builder().nodeCode("6").nodeName("C1").build();
        TreeNode node7 = TreeNode.builder().nodeCode("7").nodeName("D1").build();
        TreeNode node8 = TreeNode.builder().nodeCode("8").nodeName("D2").build();

        node1.setChildren(Lists.newArrayList(node4,node5));
        node2.setChildren(Lists.newArrayList(node6));
        node3.setChildren(Lists.newArrayList(node7,node8));

        node0.setChildren(Lists.newArrayList(node1,node2,node3));
    }

    @Test
    public void filterTest(){
        logger.info("before filter node0: {}",node0);
        List<TreeNode> retNodes = filterNode(node0);
        if (retNodes.size() >0){
            node0.setChildren(retNodes);
        }
        logger.info("after filter node0: {}",node0);
    }

    private List<TreeNode> filterNode(TreeNode treeNode){

        List<TreeNode> nodes = treeNode.getChildren();
        List<TreeNode> newNodes = Lists.newArrayList();
        List<TreeNode> tagNodes = Lists.newArrayList();

        for(TreeNode node : nodes){
            if (targetNode.contains(node.getNodeName())){
                newNodes.add(node);
            }
            if (node.getChildren() != null && node.getChildren().size() >0){
                List<TreeNode> retNodes = filterNode(node);
                if (retNodes.size()>0){
                    node.setChildren(retNodes);
                }else {
                    node.setChildren(null);
                    tagNodes.add(node);
                }
            }
        }
        nodes.removeAll(tagNodes);
        return newNodes;
    }
}

对于任何寻找与 @Kiril 的树而不是地图类似答案的人:

tree/node class:

@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class Tree {

    private TreeNode node;

    @Getter
    @Setter
    @AllArgsConstructor
    @NoArgsConstructor
    @Builder
    public static class TreeNode {
        private int id;
        private boolean isMatch;
        private List<TreeNode> nodes;
    }
}

和过滤方法:

public Tree filter(Tree tree, List<Integer> ids) {
    if (tree.getNode() == null) {
        return tree;
    }
    filterNode(tree.getNode(), ids);

    return tree;
}

private boolean filterNode(Tree.TreeNode node, List<Integer> idsToFilter) {
    boolean isMatch = idsToFilter.contains(node.getId());
    node.setMatch(isMatch);
    if (CollectionUtils.isEmpty(node.getNodes()) && !isMatch) {
        return true;
    }

    node.getNodes().removeIf(treeNode -> filterNode(treeNode, idsToFilter));

    return node.getNodes().size() == 0 && !isMatch;
}