使用 sklearn 创建自定义转换器 - 缺少必需的位置参数错误
Creating custom transformer with sklearn - missing required positional argument error
我正在尝试创建一个自定义转换器,它将一列拆分为多列,我还想提供分隔符。
这是我创建转换器的代码
class StringSplitTransformer(BaseEstimator, TransformerMixin):
def __init__(self, cols = None):
self.cols = cols
def transform(self,df,delim):
X = df.copy()
for col in self.cols:
X = pd.concat([X,X[col].str.split(delim,expand = True)], axis = 1)
return X
def fit(self, *_):
return self
当我 运行 fit()
和 transform()
分开时,一切正常:
split_trans = StringSplitTransformer(cols = ['Cabin'])
split_trans.fit(df)
split_trans.transform(df, '/')
但是当我 运行 fit_transform()
它给我一个错误:
split_trans.fit_transform(X_train, '/')
TypeError: transform() missing 1 required positional argument: 'delim'
在我的 transform()
函数中,如果我没有 delim
参数,而是只提供定界符,那么 fit_transform()
可以工作。
我不明白为什么会这样。
fit
应该接受至少两个参数,位置 X
和可选 y=None
。当您调用 fit_transform
时,您的转换器分配了 y='\'
并错过了 delim
。好吧,我宁愿让 delim
成为 class 变量:
class StringSplitTransformer(BaseEstimator, TransformerMixin):
def __init__(self, delim, cols=None):
self.delim = delim
self.cols = cols
def fit(self, df, y=None):
return self
def transform(self, df):
X = df.copy()
for col in self.cols:
X = pd.concat([X, X[col].str.split(self.delim, expand=True)],
axis=1)
return X
我正在尝试创建一个自定义转换器,它将一列拆分为多列,我还想提供分隔符。
这是我创建转换器的代码
class StringSplitTransformer(BaseEstimator, TransformerMixin):
def __init__(self, cols = None):
self.cols = cols
def transform(self,df,delim):
X = df.copy()
for col in self.cols:
X = pd.concat([X,X[col].str.split(delim,expand = True)], axis = 1)
return X
def fit(self, *_):
return self
当我 运行 fit()
和 transform()
分开时,一切正常:
split_trans = StringSplitTransformer(cols = ['Cabin'])
split_trans.fit(df)
split_trans.transform(df, '/')
但是当我 运行 fit_transform()
它给我一个错误:
split_trans.fit_transform(X_train, '/')
TypeError: transform() missing 1 required positional argument: 'delim'
在我的 transform()
函数中,如果我没有 delim
参数,而是只提供定界符,那么 fit_transform()
可以工作。
我不明白为什么会这样。
fit
应该接受至少两个参数,位置 X
和可选 y=None
。当您调用 fit_transform
时,您的转换器分配了 y='\'
并错过了 delim
。好吧,我宁愿让 delim
成为 class 变量:
class StringSplitTransformer(BaseEstimator, TransformerMixin):
def __init__(self, delim, cols=None):
self.delim = delim
self.cols = cols
def fit(self, df, y=None):
return self
def transform(self, df):
X = df.copy()
for col in self.cols:
X = pd.concat([X, X[col].str.split(self.delim, expand=True)],
axis=1)
return X