PySpark DataFrame 中一行与其前导 3 行之间的区别

Difference between a Row and its lead by 3 Rows in a PySpark DataFrame

我有一个 CSV 文件,它已通过以下代码作为数据框导入:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.read.csv("name of file.csv", inferSchema = True, header = True)
df.show()

输出

    +-----+------+-----+
    |col1 | col2 | col3|
    +-----+------+-----+    
    |  A  |  2   |  4  |
    +-----+------+-----+    
    |  A  |  4   |  5  | 
    +-----+------+-----+    
    |  A  |  7   |  7  | 
    +-----+------+-----+    
    |  A  |  3   |  8  | 
    +-----+------+-----+    
    |  A  |  7   |  3  | 
    +-----+------+-----+    
    |  B  |  8   |  9  |
    +-----+------+-----+    
    |  B  |  10  |  10 | 
    +-----+------+-----+    
    |  B  |  8   |  9  |
    +-----+------+-----+    
    |  B  |  20  |  15 |
    +-----+------+-----+

我想创建另一个 col4,其中包含 col1 中每个组的 col2[n+3]/col2-1

输出应该是

   +-----+------+-----+-----+
   |col1 | col2 | col3| col4|
   +-----+------+-----+-----+    
   | A   |    2 |   4 |  0.5|  #(3/2-1)
   +-----+------+-----+-----+    
   | A   |    4 |   5 | 0.75| #(7/4-1)
   +-----+------+-----+-----+    
   | A   |    7 |   7 |  NA |
   +-----+------+-----+-----+    
   | A   |    3 |   8 |  NA |
   +-----+------+-----+-----+    
   | A   |    7 |   3 |  NA |
   +-----+------+-----+-----+    
   | B   |    8 |   9 | 1.5 |
   +-----+------+-----+-----+    
   | B   |   10 |  10 |  NA |
   +-----+------+-----+-----+    
   | B   |    8 |  9  |  NA |
   +-----+------+-----+-----+    
   | B   |   20 |  15 |  NA |
   +-----+------+-----+-----+

我知道如何在 pandas 中执行此操作,但我不确定如何在 PySpark 中对分组列进行一些计算。

目前,我的 PySpark 版本是 2.4

我的 Spark 版本是 2.2lead() and Window() have been used. For reference.

from pyspark.sql.window import Window
from pyspark.sql.functions import lead, col    
my_window = Window.partitionBy('col1').orderBy('col1')
df = df.withColumn('col2_lead_3', lead(col('col2'),3).over(my_window))\
       .withColumn('col4',(col('col2_lead_3')/col('col2'))-1).drop('col2_lead_3')
df.show()
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|   B|   8|   9| 1.5|
|   B|  10|  10|null|
|   B|   8|   9|null|
|   B|  20|  15|null|
|   A|   2|   4| 0.5|
|   A|   4|   5|0.75|
|   A|   7|   7|null|
|   A|   3|   8|null|
|   A|   7|   3|null|
+----+----+----+----+