使用 sklearn 创建自定义转换器 - 缺少必需的位置参数错误

Creating custom transformer with sklearn - missing required positional argument error

我正在尝试创建一个自定义转换器,它将一列拆分为多列,我还想提供分隔符。

这是我创建转换器的代码

class StringSplitTransformer(BaseEstimator, TransformerMixin):
def __init__(self, cols = None):
    self.cols = cols
def transform(self,df,delim):
    X = df.copy()
    for col in self.cols:
        X = pd.concat([X,X[col].str.split(delim,expand = True)], axis = 1)
    return X
def fit(self, *_):
    return self

当我 运行 fit()transform() 分开时,一切正常:

split_trans = StringSplitTransformer(cols = ['Cabin'])
split_trans.fit(df)
split_trans.transform(df, '/')

但是当我 运行 fit_transform() 它给我一个错误:

split_trans.fit_transform(X_train, '/')

TypeError: transform() missing 1 required positional argument: 'delim'

在我的 transform() 函数中,如果我没有 delim 参数,而是只提供定界符,那么 fit_transform() 可以工作。 我不明白为什么会这样。

fit 应该接受至少两个参数,位置 X 和可选 y=None。当您调用 fit_transform 时,您的转换器分配了 y='\' 并错过了 delim。好吧,我宁愿让 delim 成为 class 变量:

class StringSplitTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, delim, cols=None):
        self.delim = delim
        self.cols = cols

    def fit(self, df, y=None):
        return self

    def transform(self, df):
        X = df.copy()
        for col in self.cols:
            X = pd.concat([X, X[col].str.split(self.delim, expand=True)],
                          axis=1)
        return X