Convert bigrams to N-grams in Pyspark dataframe

I have a dataframe:

```python
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

sdf = spark.createDataFrame([
    ('first', 'apple edible', '23'),
    ('first', 'edible fruit', '34'),
    ('second', 'flowering plant', '11'),
    ('second', 'plant green', '7'),
    ('third', 'citrus fruit', '16'),
    ('third', 'soft sweet', '9'),
], ['group', 'bigram', 'count'])
```

```
+------+---------------+-----+
|group |bigram         |count|
+------+---------------+-----+
|first |apple edible   |23   |
|first |edible fruit   |34   |
|second|flowering plant|11   |
|second|plant green    |7    |
|third |citrus fruit   |16   |
|third |soft sweet     |9    |
+------+---------------+-----+
```

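I want to combine bigrams into an N-gram (n=3) if:

  1. the bigrams are completely contained in the N-gram;
  2. the last word of one bigram is the same as the first word of the other.

As a result, the bigrams of the first and second groups should each be merged into one N-gram, but those of the third group should not (I want to keep the `count` column):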

```
+------+---------------+-----+---------------------+
|group |bigram         |count|Ngram                |
+------+---------------+-----+---------------------+
|first |apple edible   |23   |apple edible fruit   |
|first |edible fruit   |34   |apple edible fruit   |
|second|flowering plant|11   |flowering plant green|
|second|plant green    |7    |flowering plant green|
+------+---------------+-----+---------------------+
```
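For a single pair of bigrams, the two conditions can be written as a small plain-Python check. This is a sketch, not PySpark code: it models bigrams as space-separated strings and reads condition 2 as "the last word of the first bigram equals the first word of the second". The helpers `can_chain` and `merge` are hypothetical names introduced only for illustration.

```python
def can_chain(first: str, second: str) -> bool:
    """Condition 2: the last word of `first` is the first word of `second`."""
    return first.split()[-1] == second.split()[0]


def merge(first: str, second: str) -> str:
    """Build the trigram by dropping the shared word once.

    Both input bigrams are then fully contained in the result (condition 1).
    """
    return " ".join(first.split() + second.split()[1:])


print(can_chain("apple edible", "edible fruit"))  # True
print(merge("apple edible", "edible fruit"))      # apple edible fruit
print(can_chain("citrus fruit", "soft sweet"))    # False
```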

So far I have only written the n-gram conversion, and I don't really understand how to implement the conditions on top of it:

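```python
from pyspark.ml.feature import NGram

# Split each bigram into an array of words. split() already returns an array,
# so no bracket stripping with regexp_replace is needed (it only accepts strings).
# NGram expects array<string>, so flatten the per-group collection of word arrays.
sdf_collect = sdf.withColumn('collect_bigram', F.split(F.col("bigram"), " ")) \
    .groupby('group') \
    .agg(F.flatten(F.collect_set(F.col('collect_bigram'))).alias('collect_bigram'))

ngram_bigram = NGram(n=3)
ngram_bigram.setInputCol("collect_bigram")
ngram_bigram.setOutputCol("Ngrams")

sdf_ngram3 = ngram_bigram.transform(sdf_collect)
```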
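`NGram` on its own cannot express the conditions, because it only slides a fixed-size window over whatever token array it receives. The sketch below is a plain-Python imitation of that sliding window (not Spark's implementation; `sliding_ngrams` is a hypothetical helper), applied to the words of group `first` under the assumption that both bigrams were split and flattened into one array in row order:

```python
def sliding_ngrams(tokens: list[str], n: int) -> list[str]:
    # Join every window of n consecutive tokens with a space
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


# Assumed words of group "first" after splitting and flattening both bigrams
tokens = ["apple", "edible", "edible", "fruit"]
print(sliding_ngrams(tokens, 3))
# ['apple edible edible', 'edible edible fruit']
```

The shared word `edible` appears twice, so no window equals `apple edible fruit`; the bigrams have to be merged pairwise on the shared word instead of being windowed.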

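A self-join can help here, with the second condition implemented as the join condition. The n-grams are then created by combining the arrays from both sides; when combining them, the element shared by both arrays is omitted: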

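```python
# Split each bigram into an array of its two words
sdf2 = sdf.withColumn('collect_bigram', F.split(F.col("bigram"), " "))

# Join a with b where a's last word (SQL index 1, 0-based) equals b's first word,
# then append b's words from position 2 on (slice is 1-based) and join with spaces
sdf2.alias("a").join(sdf2.alias("b"), F.expr("a.collect_bigram[1] == b.collect_bigram[0]")) \
    .selectExpr("array_join(array_union(a.collect_bigram, slice(b.collect_bigram, 2, size(b.collect_bigram))), ' ') as n") \
    .show(truncate=False)
```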

Result:

```
+---------------------+
|n                    |
+---------------------+
|flowering plant green|
|apple edible fruit   |
+---------------------+
```
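To make the join easier to follow, here is a plain-Python model of what it computes (a sketch, not Spark code; `self_join_ngrams` is a hypothetical name): every pair of rows `(a, b)` whose words satisfy `a[1] == b[0]` yields `a` followed by `b` without its first word. The Spark version uses `array_union`, which additionally drops any duplicate words; the model simply concatenates, which gives the same result for this data.

```python
bigrams = [
    ("first", "apple edible"),
    ("first", "edible fruit"),
    ("second", "flowering plant"),
    ("second", "plant green"),
    ("third", "citrus fruit"),
    ("third", "soft sweet"),
]


def self_join_ngrams(rows):
    words = [bigram.split() for _, bigram in rows]
    result = []
    for a in words:          # left side of the join ("a")
        for b in words:      # right side of the join ("b")
            if a[1] == b[0]:  # join condition; groups are not compared
                result.append(" ".join(a + b[1:]))  # drop b's shared first word
    return result


print(self_join_ngrams(bigrams))
# ['apple edible fruit', 'flowering plant green']
```

Because groups are never compared, appending a `fruit apple` row to `bigrams` makes this model return the same five N-grams as in the next table.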

If `fruit apple` is added to the original list, the result is:

```
+---------------------+
|n                    |
+---------------------+
|fruit apple edible   | --> fruit apple + apple edible
|flowering plant green| --> flowering plant + plant green
|edible fruit apple   | --> edible fruit + fruit apple
|citrus fruit apple   | --> citrus fruit + fruit apple
|apple edible fruit   | --> apple edible + edible fruit
+---------------------+
```
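Rows such as `citrus fruit apple` appear because the join condition ignores `group`. If matches should stay within a group, as in the desired output in the question, one option (an assumption about the intended semantics, not part of the answer above) is to extend the join expression to something like `F.expr("a.group = b.group AND a.collect_bigram[1] = b.collect_bigram[0]")`. The plain-Python sketch below models that stricter join and also keeps `group` and `count` by emitting the N-gram once for each contributing bigram; `ngrams_with_counts` is a hypothetical helper, not a PySpark API.

```python
rows = [
    ("first", "apple edible", "23"),
    ("first", "edible fruit", "34"),
    ("second", "flowering plant", "11"),
    ("second", "plant green", "7"),
    ("third", "citrus fruit", "16"),
    ("third", "soft sweet", "9"),
]


def ngrams_with_counts(rows):
    out = []
    for g_a, bigram_a, count_a in rows:
        for g_b, bigram_b, count_b in rows:
            a, b = bigram_a.split(), bigram_b.split()
            # Same group AND last word of a == first word of b
            if g_a == g_b and a[-1] == b[0]:
                ngram = " ".join(a + b[1:])
                # One output row per contributing bigram, each with its own count
                out.append((g_a, bigram_a, count_a, ngram))
                out.append((g_b, bigram_b, count_b, ngram))
    return out


for row in ngrams_with_counts(rows):
    print(row)
```

On this data it yields the four rows of the desired table in the question; the third group produces nothing because neither of its bigrams chains with the other.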