How to keep data unique within a certain range in a PySpark dataframe?

A company can select a section of road. Each section is represented by a start and an end.

The PySpark dataframe looks like this:

+--------------------+----------+--------+
|Road company        |start(km) |end(km) |
+--------------------+----------+--------+
|classA              |1         |3       |
|classA              |4         |7       |
|classA              |10        |15      |
|classA              |16        |20      |
|classB              |1         |3       |
|classB              |4         |7       |
|classB              |10        |15      |
+--------------------+----------+--------+

Class B companies get priority when selecting road sections. classA entries that overlap with classB entries should be dropped; in other words, a class A company cannot select a road section that a class B company has already chosen. The result should look like this:

    +--------------------+----------+--------+
    |Road company        |start(km) |end(km) |
    +--------------------+----------+--------+
    |classA              |16        |20      |
    |classB              |1         |3       |
    |classB              |4         |7       |
    |classB              |10        |15      |
    +--------------------+----------+--------+

The distinct() function doesn't support splitting the frame into parts and applying different operations to each part. What should I do to achieve this?

If you can rely on the fact that sections never partially overlap, you can solve this with the logic below. You could probably optimize it to rely on "start(km)" alone, but if you mean something more in-depth than exact matches, it gets more complicated (see the second strategy further down).

from pyspark.sql.functions import col, lit, when
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

def emptyDF():
    # Column order matches the join output below: join keys first, then the rest.
    schema = StructType([
        StructField("start(km)", IntegerType(), True),
        StructField("end(km)", IntegerType(), True),
        StructField("Road company", StringType(), True)
    ])
    return spark.createDataFrame(sc.emptyRDD(), schema)

def dummyData():
    # Note: this sample uses 8-15 where the question used 10-15.
    return sc.parallelize([
        ["classA", 1, 3], ["classA", 4, 7], ["classA", 8, 15], ["classA", 16, 20],
        ["classB", 1, 3], ["classB", 4, 7], ["classB", 8, 15]
    ]).toDF(['Road company', 'start(km)', 'end(km)'])

df = dummyData()
df.cache()
# Process companies in priority order: classB first, then classA, then classC.
# (distinct() must come before orderBy(), otherwise the ordering is not guaranteed.)
df_ordered = (df.select("Road company").distinct()
                .orderBy(when(col("Road company") == "classB", 1)
                        .when(col("Road company") == "classA", 2)
                        .when(col("Road company") == "classC", 3)))

whatsLeft = df.select(col("start(km)"), col("end(km)")).distinct()
result = emptyDF()

# Only use collect() on small, countable sets of data.
for company in df_ordered.collect():
    # Claim every still-available section that this company wants...
    taken = df.where(col("Road company") == lit(company[0])) \
              .join(whatsLeft, ["start(km)", "end(km)"])
    # ...then remove those sections from the pool and record the claim.
    whatsLeft = whatsLeft.subtract(taken.drop(col("Road company")))
    result = result.union(taken)

result.show()
+---------+-------+------------+                                                
|start(km)|end(km)|Road company|
+---------+-------+------------+
|        1|      3|      classB|
|        4|      7|      classB|
|        8|     15|      classB|
|       16|     20|      classA|
+---------+-------+------------+
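
As an aside: when there are only two priority classes and sections always match exactly, the driver-side loop can be replaced by a single left anti-join. This is only a sketch of that special case under those assumptions, not the general solution:

from pyspark.sql.functions import col

# Sketch for the two-class, exact-match special case only.
class_b = df.where(col("Road company") == "classB")

# Keep only the classA rows whose (start, end) pair classB did not already take.
class_a_free = df.where(col("Road company") == "classA") \
                 .join(class_b.select("start(km)", "end(km)"),
                       ["start(km)", "end(km)"], "left_anti")

class_b.unionByName(class_a_free).show()

The loop version above generalizes this to any number of priority classes.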

If road sections can be partially assigned, here is a different (but very similar) strategy:

start="start(km)"
end="end(km)"
def emptyDFr():
 schema = StructType([
   StructField(start,IntegerType(),True),
   StructField(end,IntegerType(),True),
   StructField("Road company",StringType(),True),
   StructField("ranged",IntegerType(),True)
 ])
 return spark.createDataFrame(sc.emptyRDD(), schema)
def dummyData():
  return sc.parallelize([["classA",1,3],["classA",4,7],["classA",8,15],["classA",16,20],["classB",1,3],["classB",4,7],["classB",8,17]]).toDF(['Road company','start(km)','end(km)'])

df = dummyData()
df.cache()
# Same priority ordering as before: classB first, then classA, then classC.
df_ordered = (df.select("Road company").distinct()
                .orderBy(when(col("Road company") == "classB", 1)
                        .when(col("Road company") == "classA", 2)
                        .when(col("Road company") == "classC", 3)))
# Expand each section into one row per kilometer point from 'start' to 'end'.
# (sequence() requires Spark 2.4+.)
ranged = df.withColumn("range", explode(sequence(col(start), col(end))))
whatsLeft = ranged.select(col("range")).distinct()
result = emptyDFr()

# Only use collect() on small, countable sets of data.
for company in df_ordered.collect():
    # Claim every kilometer point that is still available for this company.
    taken = ranged.where(col("Road company") == lit(company[0])) \
                  .join(whatsLeft, ["range"])
    whatsLeft = whatsLeft.subtract(taken.select(col("range")))
    # Reorder columns to match emptyDFr's schema before the positional union.
    result = result.union(taken.select(col(start), col(end), col("Road company"), col("range")))

# Convert our result back to the 'original style' of records with starts and ends.
# (The chain is parenthesized; a comment after a backslash continuation is a syntax error.)
(result.groupBy(start, end, "Road company")
       .agg(count("ranged").alias("count"))
       # Flag sections where the company got fewer points than it asked for.
       .withColumn("Partial", ((col(end) + lit(1)) - col(start)) != col("count"))
       # Not required; just shows how many points each section asked for.
       .withColumn("Maths", (col(end) + lit(1)) - col(start))
       .show())