Count number of attributes in vector of features - Spark Scala

I have a dataframe with a column of normalized features that looks like this:

+--------------------+
|        normFeatures|
+--------------------+
|(17412,[0,1,2,5,1...|
|(17412,[0,1,2,5,9...|
|(17412,[0,1,2,5,1...|
|(17412,[0,1,2,5,9...|
|(17412,[0,1,2,5,1...|
|(17412,[0,1,2,5,1...|
+--------------------+

These vectors were obtained after applying StringIndexer, OneHotEncoder and VectorAssembler to the original attribute columns.

I would like to know whether it is possible to count the number of new attributes. I am not sure whether that number is the size of the vector, or whether the numbers nested inside the [] and () also count as attributes.

Dataframe before vectorization:

+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+
|DayOfWeek|  DepTime|UniqueCarrier|FlightNum|TailNum|ArrDelay|DepDelay|Distance|TaxiOut|    Date| Flight|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+
| Thursday|Afternoon|           WN|      588| N240WN|      16|      18|     393|      9|2008/1/3|HOU-LIT|
| Thursday|  Morning|           WN|     1343| N523SW|       2|       5|     441|      8|2008/1/3|HOU-MAF|
| Thursday|    Night|           WN|     3841| N280WN|      -4|      -6|     441|     14|2008/1/3|HOU-MAF|
| Thursday|  Morning|           WN|        3| N308SA|      -2|       8|     848|      7|2008/1/3|HOU-MCO|
| Thursday|Afternoon|           WN|       25| N462WN|      16|      23|     848|     10|2008/1/3|HOU-MCO|
| Thursday|    Night|           WN|       51| N483WN|       0|       4|     848|      7|2008/1/3|HOU-MCO|
| Thursday|  Evening|           WN|      940| N493WN|       3|       8|     848|      7|2008/1/3|HOU-MCO|
| Thursday|  Morning|           WN|     2621| N266WN|       5|       2|     848|     19|2008/1/3|HOU-MCO|
| Thursday|  Evening|           WN|      389| N266WN|      -5|      -1|     937|     15|2008/1/3|HOU-MDW|
| Thursday|Afternoon|           WN|      519| N514SW|      26|      28|     937|     13|2008/1/3|HOU-MDW|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+

Dataframe after vectorization and normalization:

+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+---------+------------+------------------+-----------+------------+--------------+--------------+-------------+-------------------+-------------+-------------------+---------------+----------------+-------------------+--------------------+--------------------+
|DayOfWeek|  DepTime|UniqueCarrier|FlightNum|TailNum|ArrDelay|DepDelay|Distance|TaxiOut|    Date| Flight|DateIndex|DepTimeIndex|UniqueCarrierIndex|FlightIndex|TailNumIndex|FlightNumIndex|DayOfWeekIndex| DayOfWeekVec|       FlightNumVec|   DepTimeVec|          FlightVec|        DateVec|UniqueCarrierVec|         TailNumVec|            features|        normFeatures|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+---------+------------+------------------+-----------+------------+--------------+--------------+-------------+-------------------+-------------+-------------------+---------------+----------------+-------------------+--------------------+--------------------+
| Thursday|Afternoon|           WN|      588| N240WN|      16|      18|     393|      9|2008/1/3|HOU-LIT|      9.0|         1.0|               0.0|     3631.0|       554.0|         399.0|           2.0|(6,[2],[1.0])| (7262,[399],[1.0])|(3,[1],[1.0])|(4974,[3631],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[554],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Morning|           WN|     1343| N523SW|       2|       5|     441|      8|2008/1/3|HOU-MAF|      9.0|         0.0|               0.0|     3060.0|      1256.0|        3961.0|           2.0|(6,[2],[1.0])|(7262,[3961],[1.0])|(3,[0],[1.0])|(4974,[3060],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])|(5025,[1256],[1.0])|(17412,[0,1,2,5,9...|(17412,[0,1,2,5,9...|
| Thursday|    Night|           WN|     3841| N280WN|      -4|      -6|     441|     14|2008/1/3|HOU-MAF|      9.0|         3.0|               0.0|     3060.0|       463.0|        1520.0|           2.0|(6,[2],[1.0])|(7262,[1520],[1.0])|    (3,[],[])|(4974,[3060],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[463],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Morning|           WN|        3| N308SA|      -2|       8|     848|      7|2008/1/3|HOU-MCO|      9.0|         0.0|               0.0|     1285.0|        93.0|          76.0|           2.0|(6,[2],[1.0])|  (7262,[76],[1.0])|(3,[0],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])|  (5025,[93],[1.0])|(17412,[0,1,2,5,9...|(17412,[0,1,2,5,9...|
| Thursday|Afternoon|           WN|       25| N462WN|      16|      23|     848|     10|2008/1/3|HOU-MCO|      9.0|         1.0|               0.0|     1285.0|       497.0|         213.0|           2.0|(6,[2],[1.0])| (7262,[213],[1.0])|(3,[1],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[497],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|    Night|           WN|       51| N483WN|       0|       4|     848|      7|2008/1/3|HOU-MCO|      9.0|         3.0|               0.0|     1285.0|       282.0|         204.0|           2.0|(6,[2],[1.0])| (7262,[204],[1.0])|    (3,[],[])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[282],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Evening|           WN|      940| N493WN|       3|       8|     848|      7|2008/1/3|HOU-MCO|      9.0|         2.0|               0.0|     1285.0|       342.0|        1455.0|           2.0|(6,[2],[1.0])|(7262,[1455],[1.0])|(3,[2],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[342],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|  Morning|           WN|     2621| N266WN|       5|       2|     848|     19|2008/1/3|HOU-MCO|      9.0|         0.0|               0.0|     1285.0|       555.0|        2051.0|           2.0|(6,[2],[1.0])|(7262,[2051],[1.0])|(3,[0],[1.0])|(4974,[1285],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[555],[1.0])|(17412,[0,1,2,5,9...|(17412,[0,1,2,5,9...|
| Thursday|  Evening|           WN|      389| N266WN|      -5|      -1|     937|     15|2008/1/3|HOU-MDW|      9.0|         2.0|               0.0|     1081.0|       555.0|        1016.0|           2.0|(6,[2],[1.0])|(7262,[1016],[1.0])|(3,[2],[1.0])|(4974,[1081],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[555],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
| Thursday|Afternoon|           WN|      519| N514SW|      26|      28|     937|     13|2008/1/3|HOU-MDW|      9.0|         1.0|               0.0|     1081.0|       133.0|         309.0|           2.0|(6,[2],[1.0])| (7262,[309],[1.0])|(3,[1],[1.0])|(4974,[1081],[1.0])|(120,[9],[1.0])|  (19,[0],[1.0])| (5025,[133],[1.0])|(17412,[0,1,2,5,1...|(17412,[0,1,2,5,1...|
+---------+---------+-------------+---------+-------+--------+--------+--------+-------+--------+-------+---------+------------+------------------+-----------+------------+--------------+--------------+-------------+-------------------+-------------+-------------------+---------------+----------------+-------------------+--------------------+--------------------+

Code:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{Normalizer, OneHotEncoder, StringIndexer, VectorAssembler}

    val numeric_columns = Array("ArrDelay", "DepDelay", "TaxiOut", "Distance")
    val string_columns = df.columns.diff(numeric_columns)
    println("Getting vector of normalized features")
    val index_columns = string_columns.map(col => col + "Index")

    // StringIndexer (the multi-column API used here is available from Spark 3.0)
    val indexer = new StringIndexer()
      .setInputCols(string_columns)
      .setOutputCols(index_columns)

    val vec_columns = string_columns.map(col => col + "Vec")

    // OneHotEncoder
    val encoder = new OneHotEncoder()
      .setInputCols(index_columns)
      .setOutputCols(vec_columns)

    // VectorAssembler (ArrDelay is left out of the assembled features)
    val num_vec_columns: Array[String] = numeric_columns.filter(!_.contains("ArrDelay")) ++ vec_columns
    val assembler = new VectorAssembler()
      .setInputCols(num_vec_columns)
      .setOutputCol("features")

    // Normalizer (L1 norm)
    val normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normFeatures")
      .setP(1.0)

    // All together in a pipeline
    val pipeline = new Pipeline()
      .setStages(Array(indexer, encoder, assembler, normalizer))
    df = pipeline.fit(df).transform(df) // df is a var holding the original dataframe
    df.printSchema()
    df.show(10)
    println("Done")
    println("-------------------")

Thanks in advance.

A few notes here:

What you have here is a SparseVector representation, coming from several sparse vector transformations. These are created when you transform with OneHotEncoder (note that the legacy single-column OneHotEncoder is deprecated; the multi-column encoder you are using replaced it). So when you have something like:

(7262,[399],[1.0])

it means you have a vector with 7262 positions, where position 399 holds the value 1.0. The length here is 7262; the vector is simply stored in a sparse representation rather than a dense one.
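
To make the notation concrete, here is a minimal sketch that rebuilds that exact vector with the standard spark.ml Vectors factory (the variable names are just for illustration). The size field is what counts as the number of attributes; the two bracketed lists are only the indices and values of the non-zero entries:

    import org.apache.spark.ml.linalg.Vectors

    // Rebuild the vector from the example: 7262 positions, a single 1.0 at index 399
    val v = Vectors.sparse(7262, Array(399), Array(1.0))
    println(v.size)        // 7262 -> the total number of positions (attributes)
    println(v.numNonzeros) // 1    -> only the stored non-zero entries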

The VectorAssembler concatenates those sparse representations, so you end up with the final sparse representation of length 17412 (the sizes of the one-hot vectors, 6 + 3 + 7262 + 4974 + 120 + 19 + 5025, plus the three numeric columns DepDelay, TaxiOut and Distance). If you print the dataframe without truncation, you will see the positions and values in the normFeatures column.
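
If you prefer to get that count without collecting any rows, the column metadata that VectorAssembler writes on the features column also records the assembled attributes. A sketch, assuming the pipeline from the question has already been run and df is the transformed dataframe:

    import org.apache.spark.ml.attribute.AttributeGroup

    // Print the full, untruncated sparse vectors
    df.select("normFeatures").show(truncate = false)

    // Number of attributes declared in the metadata written by VectorAssembler
    val numAttrs = AttributeGroup.fromStructField(df.schema("features")).size
    println(numAttrs) // expected to be 17412 here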

If you want to extract the length of this field, you can do it like this:

    import org.apache.spark.ml.linalg.SparseVector

    // df2 is the dataframe produced by the pipeline (the one holding normFeatures)
    val row = df2.select("normFeatures").head
    val vector = row(0).asInstanceOf[SparseVector]
    val size = vector.size
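
In case some rows end up stored as a DenseVector instead (which can also live in a vector column), reading the generic Vector type avoids a failed cast. A small variation on the same idea, again using df2 for the transformed dataframe:

    import org.apache.spark.ml.linalg.Vector

    val size = df2.select("normFeatures").head.getAs[Vector](0).size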

However, the length of the sparse representation is not fixed for the whole dataframe. You could have rows with different lengths, although with the same transformations applied that should not happen. If you cannot keep track of the transformations that were applied, be careful when performing a join with another dataframe.
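
To check that on a concrete dataframe, you can collect the set of distinct vector sizes across all rows; a sketch along those lines (df2 again being the transformed dataframe, so anything other than a single size would signal mixed transformations):

    import org.apache.spark.ml.linalg.Vector

    val distinctSizes = df2.select("normFeatures")
      .rdd
      .map(_.getAs[Vector](0).size)
      .distinct()
      .collect()
    println(distinctSizes.mkString(", ")) // a single value means every row agrees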