Prune sklearn decision tree to ensure monotonicity
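I need to prune a sklearn decision tree classifier so that the indicated probability (the value on the right of the image) increases monotonically. For example, if you program a basic tree in Python, you have: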
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree._tree import TREE_LEAF
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data[:, 0].reshape(-1,1), np.where(iris.target==0,0,1)
tree = DecisionTreeClassifier(max_depth=3, random_state=123)
tree.fit(X,y)
percentages = tree.tree_.value[:,0,1]/np.sum(tree.tree_.value.reshape(-1,2), axis=1)
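Now the leaves that do not comply with the indicated monotonicity must be eliminated.
Leaving the following:
Although the example shown does not illustrate it, one rule to consider is that if the leaves have different parents, the leaf with the largest amount of data is kept. To deal with this problem I have been trying to write a brute-force algorithm, but it only performs the first iteration, and I need to apply the algorithm to bigger trees. The answer is probably to use recursion, but with the sklearn tree structure I don't really know how to do it.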
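The following meets the pruning requirements you describe: traverse the tree, identify the non-monotonic leaves, remove the leaves of the parent node that has the fewest members among the parents of those non-monotonic leaves, and repeat until monotonicity between the leaves is sustained. This remove-one-node-at-a-time approach adds time complexity, because the tree is traversed again after every removal, but trees generally have limited depth. The conference paper "Pruning for Monotone Classification Trees" helped me understand monotonicity in trees. I then derived this approach to fit your scenario.
Since the non-monotonic leaves need to be identified from left to right, the first step is a post-order traversal of the tree. If you are not familiar with tree traversal, that is completely normal. I suggest learning its mechanics from online resources before studying the function. You can also run the traversal function to see its results; the actual output will help you understand.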
# Traversal that scans the nodes and leaves from left to right.
# The traversal is recursive, so module-level lists collect values from each call.
traversal = []  # Node indices in visiting order
parents = []    # Parent of each visited node or leaf
is_leaves = []  # Whether each visited item is a leaf

# A function to do post-order tree traversal
def postOrderTraversal(tree, root, parent):
    if root != -1:
        # Recursion on left child
        postOrderTraversal(tree, tree.tree_.children_left[root], root)
        # Recursion on right child
        postOrderTraversal(tree, tree.tree_.children_right[root], root)
        traversal.append(root)  # Collect the index of the node or leaf
        parents.append(parent)  # Collect the parent of the collected node or leaf
        is_leaves.append(is_leaf(tree, root))  # Collect whether the collected object is a leaf
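Above, we call the traversal recursively on the left and right children of a node, which are available through the children_left and children_right arrays of the decision tree structure. The is_leaf() used there is a helper function, shown below.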
def is_leaf(tree, node):
    # A node without a left child is a leaf (a missing child is marked with -1)
    return tree.tree_.children_left[node] == -1
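An internal node of a decision tree always has two children. Therefore, checking only whether the left child exists tells us whether the object in question is a node or a leaf. The tree returns -1 if the child asked for does not exist.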
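To make the traversal order concrete, here is a minimal sketch on a hand-written toy tree. The two lists only imitate the layout of children_left and children_right; the values and the post_order helper are made up for illustration and are not taken from the iris tree.

```python
# Toy tree in the same array layout as tree_.children_left / tree_.children_right
# (hypothetical values; -1 marks "no child").
#
#         0
#        / \
#       1   4
#      / \
#     2   3
children_left = [1, 2, -1, -1, -1]
children_right = [4, 3, -1, -1, -1]


def post_order(node, parent, out):
    """Append (node, parent, is_leaf) tuples in post-order."""
    if node == -1:
        return
    post_order(children_left[node], node, out)
    post_order(children_right[node], node, out)
    out.append((node, parent, children_left[node] == -1))


visited = []
post_order(0, None, visited)

# Keep only the leaves; they come out in left-to-right order.
leaves = [node for node, _, leaf in visited if leaf]
leaf_parents = [parent for _, parent, leaf in visited if leaf]
print(leaves)        # [2, 3, 4]
print(leaf_parents)  # [1, 1, 0]
```

Internal nodes 1 and 0 are visited after their children, so filtering on the leaf flag leaves exactly the left-to-right leaf sequence.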
As you have defined the non-monotonicity condition, the ratio of class 1 within each leaf is needed. I call this positive_ratio()
(this is what you called "percentages").
def positive_ratio(tree):
    # Share of class 1 in every node and leaf of a binary classification tree:
    # class-1 entry of each node / total over both classes of that node
    return tree.tree_.value[:, 0, 1] / np.sum(tree.tree_.value.reshape(-1, 2), axis=1)
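As a rough illustration of the indexing, the sketch below applies the same expression to a small hand-made array. The array is only a stand-in with the shape of tree_.value (nodes, outputs, classes); its numbers are invented.

```python
import numpy as np

# Hypothetical stand-in for tree_.value: 3 nodes, 1 output, 2 classes.
# Each row holds the class-0 and class-1 entries of one node.
value = np.array([
    [[50.0, 100.0]],  # node 0
    [[45.0, 5.0]],    # node 1
    [[5.0, 95.0]],    # node 2
])

# Share of class 1 per node: class-1 entry divided by the row total.
ratio = value[:, 0, 1] / np.sum(value.reshape(-1, 2), axis=1)
print(ratio)  # approximately [0.667 0.1 0.95]

# Rescaling each row (for example to fractions) leaves the ratio unchanged,
# because the expression normalises by the row total.
fractions = value / value.sum(axis=2, keepdims=True)
ratio_from_fractions = fractions[:, 0, 1] / np.sum(fractions.reshape(-1, 2), axis=1)
```

Because of that normalisation, the result should be the same whether the entries are raw counts or per-node proportions.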
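The final helper function below returns the tree index (1, 2, 3, etc.) of the node with the fewest samples. It takes the list of nodes whose leaves exhibit non-monotonic behavior. Inside this helper we use the n_node_samples
attribute of the tree structure. The node it finds is the node whose leaves are to be removed.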
def min_samples_node(tree, nodes):
    # Return the node with the fewest samples among the provided list.
    # n_node_samples[node] is the number of training samples reaching that node;
    # if several nodes tie, the first one in the list is returned.
    return min(nodes, key=lambda node: tree.tree_.n_node_samples[node])
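A small worked case may help here. Two offending leaves often share a parent, so the list passed to the helper can contain duplicates. The sketch below uses an invented sample-count list in place of n_node_samples and an invented candidate list.

```python
# Hypothetical stand-in for tree_.n_node_samples (samples reaching each node).
n_node_samples = [150, 60, 40, 20, 90]

# Parents of the offending leaves as collected during traversal (made up);
# node 1 appears twice because both of its leaves were flagged.
candidate_parents = [1, 1, 0]

# Deduplicate, then pick the candidate reached by the fewest samples.
smallest = min(sorted(set(candidate_parents)), key=lambda node: n_node_samples[node])
print(smallest)  # 1
```

Node 1 is chosen because 60 samples reach it, against 150 for node 0, so its two leaves are the ones that get removed.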
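With the helper functions defined, the wrapper function that performs the pruning iterates until the monotonicity of the tree is sustained, and returns the desired monotonic tree.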
def prune_nonmonotonic(tree):  # Prune non-monotonic nodes of a binary classification tree
    while True:  # Repeat until monotonicity is sustained
        # Clear the traversal lists for a new scan
        traversal.clear()
        parents.clear()
        is_leaves.clear()
        # Do a post-order traversal of the tree so that the leaves are returned in order from left to right
        postOrderTraversal(tree, 0, None)
        # Filter the traversal outputs by keeping only leaves and leaving out the nodes
        leaves = [traversal[i] for i, leaf in enumerate(is_leaves) if leaf]
        leaves_parents = [parents[i] for i, leaf in enumerate(is_leaves) if leaf]
        pos_ratio = positive_ratio(tree)  # Positive sample ratio of every node of the tree
        leaves_pos_ratio = [pos_ratio[i] for i in leaves]  # Positive sample ratio of the traversed leaves
        # Detect the non-monotonic pairs by comparing the leaves side by side
        nonmonotone_pairs = [[leaves[i], leaves[i + 1]]
                             for i, ratio in enumerate(leaves_pos_ratio[:-1])
                             if ratio >= leaves_pos_ratio[i + 1]]
        # Make a flattened and unique list of leaves out of the pairs
        nonmonotone_leaves = []
        for pair in nonmonotone_pairs:
            for leaf in pair:
                if leaf not in nonmonotone_leaves:
                    nonmonotone_leaves.append(leaf)
        if len(nonmonotone_leaves) == 0:  # If all leaves show monotonic properties, then break
            break
        # List the parent nodes of the non-monotonic leaves
        nonmonotone_leaves_parents = [leaves_parents[leaves.index(leaf)] for leaf in nonmonotone_leaves]
        node_min = min_samples_node(tree, nonmonotone_leaves_parents)  # The node with the minimum number of samples
        # Prune the tree by removing the children of the detected node
        tree.tree_.children_left[node_min] = -1
        tree.tree_.children_right[node_min] = -1
    return tree
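The "while" loop keeps iterating until a traversal finds no more non-monotonicity among the leaves. min_samples_node()
identifies the node that contains non-monotonic leaves and has the fewest members among such nodes. When its left and right children are replaced with the value -1, the tree is pruned, and the next "while" iteration produces a completely different tree traversal in which the remaining non-monotonicity is identified and removed.
The figures below show the unpruned and pruned trees, respectively.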
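If you want to confirm the result without the plots, a check along the following lines could be used. This is only a sketch on a hand-written toy tree: the helper names leaves_left_to_right and is_strictly_increasing, the child arrays and the ratios are all made up for illustration.

```python
def leaves_left_to_right(children_left, children_right, node=0):
    """Return the leaf indices below `node` in left-to-right order."""
    if children_left[node] == -1:
        return [node]
    return (leaves_left_to_right(children_left, children_right, children_left[node])
            + leaves_left_to_right(children_left, children_right, children_right[node]))


def is_strictly_increasing(values):
    """True when every value is smaller than the one that follows it."""
    return all(a < b for a, b in zip(values, values[1:]))


# Hypothetical toy tree: node 0 splits into node 1 and leaf 4,
# node 1 splits into leaves 2 and 3. Ratios are invented class-1 shares.
children_left = [1, 2, -1, -1, -1]
children_right = [4, 3, -1, -1, -1]
ratio = [0.6, 0.3, 0.4, 0.2, 0.9]

leaves = leaves_left_to_right(children_left, children_right)
before = is_strictly_increasing([ratio[i] for i in leaves])  # 0.4, 0.2, 0.9

# Prune node 1 the same way as above: detach both of its children.
children_left[1] = -1
children_right[1] = -1

leaves = leaves_left_to_right(children_left, children_right)
after = is_strictly_increasing([ratio[i] for i in leaves])  # 0.3, 0.9
print(before, after)  # False True
```

With a fitted tree, the same check would presumably take tree.tree_.children_left, tree.tree_.children_right and the output of positive_ratio(tree) in place of the toy lists.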