How to change the format of rules extracted from a tree?

I am using sklearn.tree.export_text to extract the rules from a decision tree:

|--- sepal length (cm) <= 5.75
|   |--- petal width (cm) <= 0.70
|   |   |--- class: 0.0
|   |--- petal width (cm) > 0.70
|   |   |--- class: 1.0
|--- sepal length (cm) > 5.75
|   |--- petal length (cm) <= 4.75
|   |   |--- class: 1.0
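In this text, each indentation level is the four-character prefix "|   " and every node label follows "|--- ", so the depth of a node can be read off the prefix length. A minimal sketch of decoding one line, assuming exactly that layout (parse_node_line and NODE_LINE are hypothetical names, not part of sklearn):

```python
import re

# Assumed layout of export_text output: zero or more "|   " prefixes
# (one per level below the root), then "|--- " and the node label.
NODE_LINE = re.compile(r"((?:\|   )*)\|--- (.*)")


def parse_node_line(line):
    """Split one export_text line into (depth, label); None if it is not a node."""
    match = NODE_LINE.fullmatch(line)
    if match is None:
        return None
    depth = len(match.group(1)) // len("|   ")
    return depth, match.group(2)


for line in [
    "|--- sepal length (cm) <= 5.75",
    "|   |--- petal width (cm) <= 0.70",
    "|   |   |--- class: 0.0",
]:
    print(parse_node_line(line))
```

The root's children are at depth 0 and each extra "|   " adds one level, so the leaf "class: 0.0" above is at depth 2.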

How can I write each rule on a single line, like this:

sepal length (cm) <= 5.75, petal width (cm) <= 0.70 -----> class: 0.0

Code:

import re

tree = """
|--- sepal length (cm) <= 5.75
|   |--- petal width (cm) <= 0.70
|   |   |--- class: 0.0
|   |--- petal width (cm) > 0.70
|   |   |--- class: 1.0
|--- sepal length (cm) > 5.75
|   |--- petal length (cm) <= 4.75
|   |   |--- class: 1.0
"""

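def format_rule(stack):
    # Every entry except the last is a condition on the path;
    # the last entry is the leaf, e.g. "class: 0.0".
    rule = ", ".join(stack[:-1])
    clazz = stack[-1].replace(": ", " = ")
    return rule + " -> " + clazz

stack = []   # labels on the path from the root to the current node
result = []  # finished one-line rules

for line in tree.split("\n"):
    if not line:
        continue

    # Each indentation level is the 4-character prefix "|   ";
    # the node label follows the "|--- " marker.
    match = re.fullmatch(r"((?:\|   )*)\|--- (.*)", line)
    if match is None:
        continue  # skip any line that is not a tree node
    depth = len(match.group(1)) // len("|   ")
    label = match.group(2)

    # If this node is not deeper than the previous one, the previous
    # path ended at a leaf, so emit it as a finished rule.
    if len(stack) > depth:
        result.append(format_rule(stack))

    # Drop labels below the current level, then push this node.
    stack = stack[:depth]
    stack.append(label)

# The last path in the tree is still on the stack.
result.append(format_rule(stack))
print("\n".join(result))

with open("output.txt", "w") as f:
    f.write("\n".join(result))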

Output:

sepal length (cm) <= 5.75, petal width (cm) <= 0.70 -> class = 0.0
sepal length (cm) <= 5.75, petal width (cm) > 0.70 -> class = 1.0
sepal length (cm) > 5.75, petal length (cm) <= 4.75 -> class = 1.0
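The output above uses " -> " and "class = ..." rather than the "----->" separator and "class: 0.0" asked for in the question. A minimal sketch of a variant that keeps the requested format, assuming every leaf label starts with "class:" (as in a classifier's default export_text output; regression leaves look different and would need another check). It detects leaves directly instead of comparing stack lengths, so no final flush is needed after the loop. tree_to_rules and its arrow parameter are hypothetical names:

```python
import re

# Sample input copied from the question.
TREE = """
|--- sepal length (cm) <= 5.75
|   |--- petal width (cm) <= 0.70
|   |   |--- class: 0.0
|   |--- petal width (cm) > 0.70
|   |   |--- class: 1.0
|--- sepal length (cm) > 5.75
|   |--- petal length (cm) <= 4.75
|   |   |--- class: 1.0
"""


def tree_to_rules(text, arrow=" -----> "):
    """Turn export_text output into one 'conditions -> leaf' string per leaf.

    Assumes leaf labels start with "class:"; every other node line is
    treated as a split condition.
    """
    path = []   # conditions from the root down to the current node
    rules = []
    for line in text.splitlines():
        match = re.fullmatch(r"((?:\|   )*)\|--- (.*)", line)
        if match is None:
            continue  # blank or non-node line
        depth = len(match.group(1)) // len("|   ")
        label = match.group(2)
        # Keep only the conditions above this node's level.
        del path[depth:]
        if label.startswith("class:"):
            rules.append(", ".join(path) + arrow + label)
        else:
            path.append(label)
    return rules


print("\n".join(tree_to_rules(TREE)))
```

With the sample tree this prints:

sepal length (cm) <= 5.75, petal width (cm) <= 0.70 -----> class: 0.0
sepal length (cm) <= 5.75, petal width (cm) > 0.70 -----> class: 1.0
sepal length (cm) > 5.75, petal length (cm) <= 4.75 -----> class: 1.0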