ML Decision Tree classifier keeps splitting on / asking about the same attribute

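I am currently building a decision tree classifier using Gini impurity and information gain, splitting the tree each time on the attribute that gives the largest gain. However, it sticks to the same attribute every time and only adjusts the value in its question. This leads to very low accuracy, usually around 30%, because only the first attribute is ever considered.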
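The post does not show `gini_index` or `gain`. For reference, a minimal sketch of what they might look like is below, assuming rows are lists of strings and that the label (beer style) sits in column 3, as the later "Skip the label column" comment suggests. `LABEL_COL` is a hypothetical name introduced here, not taken from the original code.

```python
from collections import Counter

LABEL_COL = 3  # assumption: the label (beer style) is stored in column 3


def gini_index(rows):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    if not rows:
        return 0.0
    counts = Counter(row[LABEL_COL] for row in rows)
    n = len(rows)
    return 1.0 - sum((k / n) ** 2 for k in counts.values())


def gain(true, false, curr_gini):
    """Information gain: parent impurity minus weighted child impurity."""
    n = len(true) + len(false)
    if n == 0:
        return 0.0
    p = len(true) / n
    return curr_gini - p * gini_index(true) - (1 - p) * gini_index(false)


rows = [["1", "a", "x", "IPA"], ["2", "b", "y", "Stout"]]
print(gini_index(rows))                 # 0.5
print(gain([rows[0]], [rows[1]], 0.5))  # 0.5 (a perfect split)
```

A gain of 0 means the split does not reduce impurity at all, which is the condition `rec_tree` uses to stop and create a leaf.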


# Used to find the best split for data among all attributes

def split(r):
    max_ig = 0
    max_att = 0
    max_att_val = 0
    i = 0

    curr_gini = gini_index(r)
    n_att = len(att)

    for c in range(n_att):
        if c == 3:
            continue  # skip the label column
        c_vals = get_column(r, c)

        while i < len(c_vals):
            # Value of the current attribute that is being tested
            curr_att_val = r[i][c]
            true, false = fork(r, c, curr_att_val)
            ig = gain(true, false, curr_gini)

            if ig > max_ig:
                max_ig = ig
                max_att = c
                max_att_val = r[i][c]
            i += 1

    return max_ig, max_att, max_att_val


# Used to compare and test if the current row is greater than or equal to the test value
# in order to split up the data

def compare(r, test_c, test_val):
    if r[test_c].isdigit():
        return r[test_c] == test_val

    elif float(r[test_c]) >= float(test_val):
        return True

    return False

# Splits the data into two lists for the true/false results of the compare test

def fork(r, c, test_val):
    true = []
    false = []

    for row in r:

        if compare(row, c, test_val):
            true.append(row)
        else:
            false.append(row)

    return true, false
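To see how `compare` and `fork` partition rows, the sketch below runs them on a few hypothetical toy rows. The functions repeat the post's logic so the block is self-contained. Note that `str.isdigit()` is True only for strings made entirely of digits, so a value like `"12"` takes the equality branch while `"4.5"` takes the `>=` branch.

```python
def compare(r, test_c, test_val):
    # Digit-only strings such as "12" are tested for equality...
    if r[test_c].isdigit():
        return r[test_c] == test_val
    # ...anything else (e.g. "4.5") is parsed as a float and thresholded with >=
    return float(r[test_c]) >= float(test_val)


def fork(r, c, test_val):
    # Send each row to the true or false list depending on the comparison
    true, false = [], []
    for row in r:
        (true if compare(row, c, test_val) else false).append(row)
    return true, false


rows = [["12", "4.5"], ["13", "5.0"], ["12", "3.9"]]
t, f = fork(rows, 0, "12")   # column 0 is digit-only: equality test
print(len(t), len(f))        # 2 1  ("13" is not equal to "12")
t, f = fork(rows, 1, "4.5")  # column 1 has decimals: >= test
print(len(t), len(f))        # 2 1
```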


def rec_tree(r):
    ig, att, curr_att_val = split(r)

    if ig == 0:
        return Leaf(r)

    true_rows, false_rows = fork(r, att, curr_att_val)

    true_branch = rec_tree(true_rows)
    false_branch = rec_tree(false_rows)

    return Node(att, curr_att_val, true_branch, false_branch)
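`Leaf` and `Node` are not defined in the post. A minimal sketch of how they might look, together with a hypothetical `classify` helper that walks the tree, is shown below. It assumes the label is in column 3 and takes the row-comparison function as a parameter so it does not depend on any particular `compare` implementation.

```python
from collections import Counter

LABEL_COL = 3  # assumption: label (beer style) column index


class Leaf:
    """Terminal node: stores the label counts of the rows that reached it."""
    def __init__(self, rows):
        self.predictions = Counter(row[LABEL_COL] for row in rows)


class Node:
    """Internal node: the question (attribute index, value) and two subtrees."""
    def __init__(self, att, val, true_branch, false_branch):
        self.att = att
        self.val = val
        self.true_branch = true_branch
        self.false_branch = false_branch


def classify(row, node, compare):
    # Follow the true/false branches until a Leaf is reached
    while isinstance(node, Node):
        if compare(row, node.att, node.val):
            node = node.true_branch
        else:
            node = node.false_branch
    # Predict the most common label among the training rows in that leaf
    return node.predictions.most_common(1)[0][0]
```

For example, with `root = Node(0, "1", Leaf(ipa_rows), Leaf(stout_rows))`, a row whose column 0 matches `"1"` ends up in the IPA leaf.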

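My working solution was to change the split function as follows. Honestly, I can't see what was going wrong, but it may be something obvious. The working function is: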

def split(r):
    max_ig = 0
    max_att = 0
    max_att_val = 0

    # calculates gini for the rows provided
    curr_gini = gini_index(r)
    no_att = len(r[0])

    # Goes through the different attributes

    for c in range(no_att):

        # Skip the label column (beer style)

        if c == 3:
            continue

        column_vals = get_column(r, c)

        i = 0
        while i < len(column_vals):
            # value we want to check
            att_val = r[i][c]

            # Use the attribute value to fork the data to true and false streams
            true, false = fork(r, c, att_val)

            # Calculate the information gain
            ig = gain(true, false, curr_gini)

            # If this gain is the highest found then mark this as the best choice
            if ig > max_ig:
                max_ig = ig
                max_att = c
                max_att_val = r[i][c]
            i += 1

    return max_ig, max_att, max_att_val
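One visible difference between the two versions is where `i = 0` happens: once before the `for` loop in the first, and at the start of every column in the second. The stripped-down sketch below (a hypothetical `columns_visited` helper, not part of the original code) keeps only that loop shape and records which columns the inner `while` loop actually evaluates.

```python
def columns_visited(n_cols, n_rows, reset_inside):
    """Return the (column, row) pairs that the nested loop evaluates."""
    visited = []
    i = 0                  # first version: counter set once, before the for loop
    for c in range(n_cols):
        if reset_inside:
            i = 0          # working version: counter reset for every column
        while i < n_rows:
            visited.append((c, i))
            i += 1
    return visited


print(sorted({c for c, _ in columns_visited(4, 3, reset_inside=False)}))  # [0]
print(sorted({c for c, _ in columns_visited(4, 3, reset_inside=True)}))   # [0, 1, 2, 3]
```

Without the reset, `i` is already equal to the number of rows after the first column, so the `while` condition is false for every later column and only the first attribute is ever scored.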