How to change the format of rules extracted from a decision tree?
I'm using sklearn.tree.export_text to extract the rules from a decision tree:
|--- sepal length (cm) <= 5.75
| |--- petal width (cm) <= 0.70
| | |--- class: 0.0
| |--- petal width (cm) > 0.70
| | |--- class: 1.0
|--- sepal length (cm) > 5.75
| |--- petal length (cm) <= 4.75
| | |--- class: 1.0
How can I write each rule on a single line, like this:
sepal length (cm) <= 5.75, petal width (cm) <= 0.70 -----> class: 0.0
Code:
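import re

tree = """
|--- sepal length (cm) <= 5.75
| |--- petal width (cm) <= 0.70
| | |--- class: 0.0
| |--- petal width (cm) > 0.70
| | |--- class: 1.0
|--- sepal length (cm) > 5.75
| |--- petal length (cm) <= 4.75
| | |--- class: 1.0
"""


def format_rule(stack):
    # Every entry except the last is a condition; the last is the leaf's class.
    rule = ", ".join(stack[:-1])
    clazz = stack[-1].replace(": ", " = ")
    return rule + " -> " + clazz


stack = []   # labels on the path from the root to the current node
result = []  # one formatted rule per leaf
for line in tree.split("\n"):
    if not line:
        continue
    # Each depth level is a "|" followed by whitespace. Using "\s+" accepts both
    # the single space shown above and export_text's default spacing ("|   ").
    match = re.fullmatch(r"((?:\|\s+)*)\|--- (.*)", line)
    depth = match.group(1).count("|")
    label = match.group(2)
    # Going back up (or to a sibling) means the previous path ended at a leaf.
    if len(stack) > depth:
        result.append(format_rule(stack))
        stack = stack[:depth]
    stack.append(label)
result.append(format_rule(stack))  # flush the last leaf

print("\n".join(result))
with open("output.txt", "w") as f:
    f.write("\n".join(result))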
Output:
sepal length (cm) <= 5.75, petal width (cm) <= 0.70 -> class = 0.0
sepal length (cm) <= 5.75, petal width (cm) > 0.70 -> class = 1.0
sepal length (cm) > 5.75, petal length (cm) <= 4.75 -> class = 1.0
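If you still have the fitted classifier, an alternative that avoids regex-parsing the text is to walk the tree's node arrays directly and build each root-to-leaf path. The sketch below assumes an object exposing the same arrays as sklearn's clf.tree_ (children_left, children_right, feature, threshold, value), where a leaf has -1 in both child arrays. The function name tree_to_rules and the hand-built demo tree (including its node values) are hypothetical, for illustration only. It emits the exact "-----> class: ..." format asked for in the question.

```python
from types import SimpleNamespace


def tree_to_rules(tree, feature_names, class_names, decimals=2):
    """Return one "cond, cond -----> class: X" string per leaf, left to right.

    `tree` is assumed to expose the same arrays as sklearn's `clf.tree_`.
    """
    rules = []

    def walk(node, conditions):
        left = tree.children_left[node]
        right = tree.children_right[node]
        if left == right:  # leaf: both child indices are -1
            per_class = tree.value[node][0]  # per-class totals at this leaf
            best = max(range(len(per_class)), key=lambda i: per_class[i])
            rules.append(", ".join(conditions) + f" -----> class: {class_names[best]}")
            return
        name = feature_names[tree.feature[node]]
        thr = f"{tree.threshold[node]:.{decimals}f}"
        walk(left, conditions + [f"{name} <= {thr}"])  # left child: feature <= threshold
        walk(right, conditions + [f"{name} > {thr}"])  # right child: feature > threshold

    walk(0, [])  # node 0 is the root
    return rules


# Hypothetical hand-built tree shaped like the one in the question (values are made up).
feature_names = ["sepal length (cm)", "sepal width (cm)",
                 "petal length (cm)", "petal width (cm)"]
demo = SimpleNamespace(
    children_left=[1, 2, -1, -1, 5, -1, -1],
    children_right=[4, 3, -1, -1, 6, -1, -1],
    feature=[0, 3, -2, -2, 2, -2, -2],
    threshold=[5.75, 0.70, -2.0, -2.0, 4.75, -2.0, -2.0],
    value=[[[13, 16]], [[10, 8]], [[10, 0]], [[0, 8]], [[3, 8]], [[0, 7]], [[3, 1]]],
)

for rule in tree_to_rules(demo, feature_names, ["0.0", "1.0"]):
    print(rule)
```

With a real model you would pass clf.tree_, the same feature names you gave export_text, and clf.classes_ as class_names. Thresholds are rounded to two decimals to mirror export_text's default decimals=2, but the strings are not guaranteed to match export_text's output character for character.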