如何在 PySpark ML Pipeline 中只对部分列使用 StandardScaler?

How to use StandardScaler on subset of columns in pyspark ml pipeline?

在我的 DataFrame 中,一些列是连续值,而其他列只有 0/1 值。在用 Pipeline 训练逻辑回归之前,我只想对连续列使用 StandardScaler。应该如何实现?

我尝试了以下代码:

from pyspark.ml.feature import VectorAssembler,StandardScaler
from pyspark.ml import Pipeline,Transformer
from pyspark.sql.functions import udf,col
from pyspark.sql.types import FloatType, ArrayType
from pyspark.ml.util import DefaultParamsWritable, DefaultParamsReadable
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters

class StandardScalerSubset(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """
    A custom Transformer that applies StandardScaler to a subset of columns.

    The columns named in ``to_scale_cols`` are assembled into one vector,
    standardized (``withMean=True, withStd=True``) and unpacked back into
    columns that keep their original names. The returned DataFrame contains
    only ``remaining_cols`` followed by the scaled ``to_scale_cols``; every
    other column (e.g. a label column) is dropped unless listed in
    ``remaining_cols``.

    Both column lists are stored as Spark ML ``Param``s so that
    ``DefaultParamsWritable`` persists them and ``DefaultParamsReadable``
    restores them. The reader instantiates the class with no arguments,
    which is why every constructor argument is optional.

    NOTE: ``_transform`` fits a fresh StandardScaler on whatever DataFrame it
    receives, so the mean/std applied are those of *that* DataFrame, not of
    the training data. To reuse training statistics at inference time, put a
    VectorAssembler + StandardScaler pair directly into the Pipeline instead.

    :param to_scale_cols: names of the numeric columns to standardize.
    :param remaining_cols: names of the columns passed through unchanged.
    :param scale_cols: alias for ``to_scale_cols`` (accepted because callers
        use this keyword); ignored when ``to_scale_cols`` is given.
    """

    toScaleCols = Param(
        Params._dummy(), "toScaleCols",
        "names of the columns to standardize",
        typeConverter=TypeConverters.toListString,
    )
    remainingCols = Param(
        Params._dummy(), "remainingCols",
        "names of the columns passed through unscaled",
        typeConverter=TypeConverters.toListString,
    )

    def __init__(self, to_scale_cols=None, remaining_cols=None, scale_cols=None):
        super(StandardScalerSubset, self).__init__()
        self._setDefault(toScaleCols=[], remainingCols=[])
        if to_scale_cols is None:
            to_scale_cols = scale_cols
        # Only store explicitly supplied values: the list converter rejects None.
        if to_scale_cols is not None:
            self._set(toScaleCols=to_scale_cols)
        if remaining_cols is not None:
            self._set(remainingCols=remaining_cols)

    def getToScaleCols(self):
        """Return the names of the columns to standardize."""
        return self.getOrDefault(self.toScaleCols)

    def getRemainingCols(self):
        """Return the names of the columns passed through unscaled."""
        return self.getOrDefault(self.remainingCols)

    def _transform(self, data):
        from pyspark.sql.types import DoubleType

        to_scale_cols = self.getToScaleCols()
        remaining_cols = self.getRemainingCols()
        if not to_scale_cols:
            # Nothing to scale: skip the assembler/scaler entirely.
            return data.select(remaining_cols)

        va = VectorAssembler(inputCols=to_scale_cols, outputCol="to_scale_vector")
        data_va = va.transform(data)

        scaler = StandardScaler(inputCol="to_scale_vector", outputCol="scaled_vector",
                                withMean=True, withStd=True)
        data_scaled = scaler.fit(data_va).transform(data_va)

        # DoubleType keeps full precision; FloatType would truncate to 32 bits.
        vector2list = udf(lambda x: x.toArray().tolist(), ArrayType(DoubleType()))
        with_list = data_scaled.withColumn("scaled_list", vector2list("scaled_vector"))
        # Unpack the scaled vector back into columns named after their sources.
        return with_list.select(
            remaining_cols
            + [col("scaled_list").getItem(i).alias(c) for i, c in enumerate(to_scale_cols)]
        )

输入:

# +---+---+---+---+
# |  a|  b|  c|  d|
# +---+---+---+---+
# |  1|  5| 10|  0|
# |  0| 10| 20|  1|
# |  1| 15| 25|  0|
# |  0| 30| 40|  1|
# +---+---+---+---+

输出将是:

# +---+---+--------+-----+
# |  a|  d|       b|    c|
# +---+---+--------+-----+
# |  1|  0| -0.9258| -1.1|
# |  0|  1| -0.4629| -0.3|
# |  1|  0|     0.0|  0.1|
# |  0|  1|  1.3887|  1.3|
# +---+---+--------+-----+

可以这样使用:

from pyspark.ml.classification import LogisticRegression

# Placeholder column lists -- replace 'xxx' with real column names.
scalerFeatures = ['xxx']      # continuous columns to standardize
remainingFeatures = ['xxx']   # e.g. the 0/1 columns, passed through unchanged
# Assumed label column name (LogisticRegression's default labelCol) -- adjust.
labelCol = 'label'
# StandardScalerSubset keeps only remaining_cols plus the scaled columns (under
# their original names), so the label must be passed through explicitly, and
# the feature vector is built from the remaining + scaled columns.
featureArr = remainingFeatures + scalerFeatures
sss = StandardScalerSubset(to_scale_cols=scalerFeatures,
                           remaining_cols=remainingFeatures + [labelCol])
vectorAssembler = VectorAssembler().setInputCols(featureArr).setOutputCol("features")
lrModel = LogisticRegression(featuresCol="features", labelCol=labelCol,
                             regParam=0.1, maxIter=100, family="binomial")
pipeline = Pipeline().setStages([sss, vectorAssembler, lrModel])
# trainData and modelSavePath are assumed to be defined by the caller.
pipeline.fit(trainData).write().overwrite().save(modelSavePath)

当我使用 `PipelineModel.load(modelSavePath)` 加载模型时出现错误。我认为我应该同时实现 `fit` 和 `transform`,但我不知道该怎么做。谁能帮帮我?谢谢!

这个回答写在评论里太长了,不过您可以尝试以下方法:

from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

# Continuous columns are standardized; the 0/1 columns are kept as they are.
scalerFeatures = ['b', 'c']
remainingFeatures = ['a', 'd']

# Final feature order: untouched columns first, then each scaled column.
featureArr = list(remainingFeatures)
featureArr.extend('scaled_' + name for name in scalerFeatures)

# StandardScaler operates on vector columns, so wrap every continuous column in
# a one-element vector first and then scale each vector independently.
va1 = []
ss = []
for name in scalerFeatures:
    vec_col = 'vec_' + name
    va1.append(VectorAssembler(inputCols=[name], outputCol=vec_col))
    ss.append(StandardScaler(inputCol=vec_col, outputCol='scaled_' + name,
                             withMean=True, withStd=True))

va2 = VectorAssembler(inputCols=featureArr, outputCol="features")
lr = LogisticRegression()

# There is no label column in this example, so the lr stage is left out.
# With a label column, append it as well:
#     stages = va1 + ss + [va2, lr]
stages = list(va1)
stages.extend(ss)
stages.append(va2)

p = Pipeline().setStages(stages)
fitted = p.fit(df)  # df is assumed to be the input DataFrame shown in the question
# truncate=False prints full, left-aligned cell values, as in the table below.
fitted.transform(df).show(truncate=False)
+---+---+---+---+------+------+---------------------+----------------------+--------------------------------------------------+
|a  |b  |c  |d  |vec_b |vec_c |scaled_b             |scaled_c              |features                                          |
+---+---+---+---+------+------+---------------------+----------------------+--------------------------------------------------+
|1  |5  |10 |0  |[5.0] |[10.0]|[-0.9258200997725514]|[-1.0999999999999999] |[1.0,0.0,-0.9258200997725514,-1.0999999999999999] |
|0  |10 |20 |1  |[10.0]|[20.0]|[-0.4629100498862757]|[-0.29999999999999993]|[0.0,1.0,-0.4629100498862757,-0.29999999999999993]|
|1  |15 |25 |0  |[15.0]|[25.0]|[0.0]                |[0.09999999999999998] |[1.0,0.0,0.0,0.09999999999999998]                 |
|0  |30 |40 |1  |[30.0]|[40.0]|[1.3887301496588271] |[1.2999999999999998]  |[0.0,1.0,1.3887301496588271,1.2999999999999998]   |
+---+---+---+---+------+------+---------------------+----------------------+--------------------------------------------------+