Why do I get an error when fitting after dropping a column from a dataframe?

I am trying to make predictions on a dataset whose head looks like this:

     ID  Reason  Month  Day  ...  Season  Drinker   Age Group    TimeOff
0    28      23     10    2  ...     4.0      Yes  Middle Aged     Low
1    17      18      1    3  ...     2.0      Yes          NaN    High
2    25       1      7    3  ...     1.0      Yes        Adult    High
3    11      28     11    2  ...     4.0      Yes        Adult     Low
4    10      23      3    2  ...     2.0       No  Middle Aged     Low
..   ..     ...    ...  ...  ...     ...      ...          ...     ...
587  28      28      3    2  ...     3.0      NaN  Young Adult     Low
588  20      28     10    5  ...     4.0      NaN  Middle Aged     Low
589  14       8      3    5  ...     2.0       No  Middle Aged    High
590  28       0      5    4  ...     NaN       No        Adult     Low
591  34      25      5    6  ...     NaN       No  Middle Aged    High

Then, during preprocessing, I try to drop the 'Season' column, but later I get a long error. Here is the code:

import numpy as np
from sklearn.compose import ColumnTransformer
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import GaussianNB


def load_dataset(train_csv_path):
    data = pd.read_csv(train_csv_path, sep=',')
    return data


class DataPreprocessor(object):

    def __init__(self):
        self.transformer: Pipeline = None

    def fit(self, dataset_df):

        numerical_columns = ['ID', 'Transportation expense', 'Residence Distance', 'Service time', 'Weight', 'Height', 'Season', 'Pet', 'Son', 'Day', 'Month', 'Reason']

        categorical_columns = list(set(dataset_df.columns) - set(numerical_columns))

        num_pipeline = Pipeline([
            ('imputer', SimpleImputer(strategy="median"))
        ])

        categorical_transformer = OneHotEncoder(drop=None, sparse=False, handle_unknown='ignore')
        cat_pipeline = Pipeline([
            ('1hot', categorical_transformer)
        ])

        preprocessor = ColumnTransformer(
            transformers=[
                ("dropId", 'drop', 'ID'),
                ("num", num_pipeline, numerical_columns),
                ("cat", cat_pipeline, categorical_columns),
            ]
        )

        self.transformer = Pipeline(steps=[
            ("preprocessor", preprocessor)
        ])
        ### DROPPING HERE
        dataset = dataset_df.drop("Season", axis=1)

        self.transformer.fit(dataset)

    def transform(self, df):
        return self.transformer.transform(df)


def train_model(processed_X, y):
    model = GaussianNB()
    model.fit(processed_X, y)

    return model


if __name__ == '__main__':
    preprocessor = DataPreprocessor()
    train_csv_path = 'time_off_data_train.csv'
    train_dataset_df = load_dataset(train_csv_path)
    print(train_dataset_df.head)

    X_train = train_dataset_df.iloc[:, :-1]
    y_train = train_dataset_df['TimeOff']
    preprocessor.fit(X_train)
    model = train_model(preprocessor.transform(X_train), y_train)

I get this error:

Traceback (most recent call last):
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/pandas/core/indexes/base.py", line 3621, in get_loc
    return self._engine.get_loc(casted_key)
  File "pandas/_libs/index.pyx", line 136, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 163, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 5198, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 5206, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'Season'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/sklearn/utils/__init__.py", line 416, in _get_column_indices
    col_idx = all_columns.get_loc(col)
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/pandas/core/indexes/base.py", line 3623, in get_loc
    raise KeyError(key) from err
KeyError: 'Season'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/.../PycharmProjects/final_proj/main.py", line 200, in <module>
    preprocessor.fit(X_train)
  File "/Users/.../PycharmProjects/final_proj/main.py", line 149, in fit
    self.transformer.fit(dataset)
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/sklearn/pipeline.py", line 382, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/sklearn/compose/_column_transformer.py", line 640, in fit
    self.fit_transform(X, y=y)
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/sklearn/compose/_column_transformer.py", line 670, in fit_transform
    self._validate_column_callables(X)
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/sklearn/compose/_column_transformer.py", line 357, in _validate_column_callables
    transformer_to_input_indices[name] = _get_column_indices(X, columns)
  File "/Users/.../PycharmProjects/final_proj/venv/lib/python3.8/site-packages/sklearn/utils/__init__.py", line 424, in _get_column_indices
    raise ValueError("A given column is not a column of the dataframe") from e
ValueError: A given column is not a column of the dataframe

Why does fitting fail after I drop that column, and how is this different from dropping it in the ColumnTransformer?

There is a big difference between dropping a column inside the DataPreprocessor.fit method and dropping it in the ColumnTransformer.

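When the drop is placed inside the fit method, the column is dropped only while the preprocessor is being fitted (when you call preprocessor.fit(X_train)), not when you actually transform the training data (preprocessor.transform(X_train)). Notice that the fit method does not return a dataframe, so the dropped copy never leaves fit and dropping the column this way has no effect on the data you transform later (the transform method, by contrast, does return the preprocessed data).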
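
To make the point above concrete, here is a minimal sketch. The function name fit_like and the toy values are made up for illustration; the only real behaviour relied on is that DataFrame.drop returns a new frame by default and leaves the caller's frame untouched.

```python
import pandas as pd


def fit_like(dataset_df):
    # Mirrors the drop inside DataPreprocessor.fit: `dataset` is a new,
    # local dataframe; `dataset_df` itself is not modified.
    dataset = dataset_df.drop("Season", axis=1)
    # Nothing is returned, so the caller never sees `dataset`.
    return None


# Toy stand-in for X_train (made-up values).
X_train = pd.DataFrame({"Season": [4.0, 2.0, 1.0], "Weight": [90, 98, 89]})

fit_like(X_train)

# The frame that would later be passed to transform() still has 'Season'.
columns_after_fit = list(X_train.columns)
print(columns_after_fit)
```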

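Your script fails because, when self.transformer.fit(dataset) is called inside the preprocessor's fit method, the transformer expects your dataframe to have a 'Season' column, since you listed it in numerical_columns.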
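
The failure can be reproduced in a few lines. This is a sketch on a toy dataframe with made-up values, not the original CSV; it only assumes that the ColumnTransformer is told to select 'Season' by name and is then fitted on a frame without that column.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer

# Toy data standing in for the real CSV (values are made up).
df = pd.DataFrame({"Season": [4.0, 2.0, None], "Weight": [90, 98, 89]})

# The transformer is configured to select 'Season' by name...
ct = ColumnTransformer(
    transformers=[("num", SimpleImputer(strategy="median"), ["Season", "Weight"])]
)

# ...so fitting on a frame where that column was dropped fails
# during the column lookup, before any imputation happens.
error = None
try:
    ct.fit(df.drop("Season", axis=1))
except ValueError as exc:
    error = exc

print(type(error).__name__)
```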

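If you want to drop the 'Season' column properly, you can (and must) use the ColumnTransformer, just as you did for the 'ID' column (by declaring a drop transformer for the 'Season' column). Also remove 'Season' from numerical_columns, because a column listed there is still passed to the numeric pipeline. The preprocessor will then expect your data to have a 'Season' column at fit time and will know to drop it at transform time.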
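
A minimal sketch of that approach, under some assumptions: the dataframe is a toy stand-in with made-up values and only four of the original columns, and the names columns_to_drop and drop_cols are introduced here for illustration. OneHotEncoder is left on its default output setting so the snippet does not depend on the scikit-learn version.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

# Toy stand-in for the training features (made-up values).
X_train = pd.DataFrame({
    "ID": [28, 17, 25, 11],
    "Season": [4.0, 2.0, 1.0, None],
    "Weight": [90.0, None, 89.0, 68.0],
    "Drinker": ["Yes", "Yes", "No", "No"],
})

columns_to_drop = ["ID", "Season"]   # hypothetical name, for illustration
numerical_columns = ["Weight"]       # dropped columns are NOT listed here
categorical_columns = ["Drinker"]

preprocessor = ColumnTransformer(
    transformers=[
        # 'drop' discards these columns from the output.
        ("drop_cols", "drop", columns_to_drop),
        ("num", SimpleImputer(strategy="median"), numerical_columns),
        ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_columns),
    ]
)

# fit and transform both receive the full frame, 'Season' included;
# the ColumnTransformer is what removes it.
preprocessor.fit(X_train)
processed = preprocessor.transform(X_train)

# One imputed numeric column plus two one-hot columns for 'Drinker'.
print(processed.shape)
```

Because the drop is part of the fitted transformer, the same columns are removed from any later dataframe passed to transform, such as a test set, without extra code.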