What is the best way to remove accents with Apache Spark dataframes in PySpark?


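Based on the code provided in this post, I wrote a function that removes accents from strings. The problem is that the function is slow because it uses a UDF. I would like to know whether its performance can be improved so that results come back in less time: it works well for small dataframes but not for large ones.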



# Importing sql types
from pyspark.sql.types import StringType, IntegerType, StructType, StructField
from pyspark.sql.functions import udf, col
import unicodedata

# Building a simple dataframe:
schema = StructType([StructField("city", StringType(), True),
                     StructField("country", StringType(), True),
                     StructField("population", IntegerType(), True)])

countries = ['Venezuela', 'US@A', 'Brazil', 'Spain']
cities = ['Maracaibó', 'New York', '   São Paulo   ', '~Madrid']
population = [37800000,19795791,12341418,6489162]

# Dataframe:
df = sqlContext.createDataFrame(list(zip(cities, countries, population)), schema=schema)


class Test():
    def __init__(self, df):
        self.df = df

    def clearAccents(self, columns):
        """Remove accents from the string columns of the dataframe.
        Base characters are kept; only the diacritical marks are deleted.

        :param columns  String or a list of column names ("*" for all string columns).
        """
        # Filters all string columns in dataFrame
        validCols = [c for (c, t) in filter(lambda t: t[1] == 'string', self.df.dtypes)]

        # If "*" is provided as the columns parameter, use every string column:
        if columns == "*":
            columns = validCols[:]
        # A single column name is wrapped in a list so that `c in columns` is an exact match:
        elif isinstance(columns, str):
            columns = [columns]

        # Receives  a string as an argument
        def remove_accents(inputStr):
            # first, normalize strings:
            nfkdStr = unicodedata.normalize('NFKD', inputStr)
            # Keep chars that has no other char combined (i.e. accents chars)
            withOutAccents = u"".join([c for c in nfkdStr if not unicodedata.combining(c)])
            return withOutAccents

        function = udf(lambda x: remove_accents(x) if x is not None else x, StringType())
        exprs = [function(col(c)).alias(c) if (c in columns) and (c in validCols) else c for c in self.df.columns]
        self.df = self.df.select(*exprs)

foo = Test(df)
foo.clearAccents(columns="*")
foo.df.show()
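
For reference, the core of the UDF above can be exercised without Spark. This is a minimal sketch using only the standard library; the helper name `strip_accents` is hypothetical and simply mirrors the inner `remove_accents` function.

```python
import unicodedata

def strip_accents(text):
    """Decompose characters (NFKD) and drop the combining marks."""
    decomposed = unicodedata.normalize("NFKD", text)
    return "".join(ch for ch in decomposed if not unicodedata.combining(ch))

# U+00F3 (o with acute) decomposes into 'o' plus a combining acute accent
decomposed = unicodedata.normalize("NFKD", "\u00f3")
print([unicodedata.name(ch) for ch in decomposed])

for city in ["Maracaibó", "New York", "   São Paulo   ", "~Madrid"]:
    print(repr(strip_accents(city)))
```

Only the combining marks are removed, so whitespace and characters such as `~` pass through unchanged.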

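This solution is Python only, but it is only useful when the number of possible accents is small (for example, a single language such as Spanish) and the character replacements are specified manually.

There seems to be no built-in way to do what you asked for directly without UDFs, but you can chain many regexp_replace calls to replace each possible accented character. I tested the performance of this solution, and it turns out that it only runs faster when the set of accents to replace is very limited. In that case it can be faster than a UDF, because regexp_replace is evaluated inside the JVM rather than in Python.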

from pyspark.sql.functions import col, regexp_replace

accent_replacements_spanish = [
    (u'á', 'a'), (u'Á', 'A'),
    (u'é', 'e'), (u'É', 'E'),
    (u'í', 'i'), (u'Í', 'I'),
    (u'ó', 'o'), (u'Ó', 'O'),
    (u'ú|ü', 'u'), (u'Ú|Ü', 'U'),
    (u'ñ', 'n'), (u'Ñ', 'N'),
    # see  for other characters

    # this will convert other non ASCII characters to a question mark:
    ('[^\x00-\x7F]', '?')
]

def remove_accents(column):
    r = col(column)
    for a, b in accent_replacements_spanish:
        r = regexp_replace(r, a, b)
    return r.alias('remove_accents(' + column + ')')

df = sqlContext.createDataFrame([['Olà'], ['Olé'], ['Núñez']], ['str'])
df.select(remove_accents('str')).show()
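
The chained replacements can be checked locally without a cluster. The sketch below applies the same kind of pattern table with Python's `re` module. It assumes that these simple patterns (literal characters, alternation, one character class) behave the same way in Python's `re` as in the Java regex engine behind `regexp_replace`. The names `replacements` and `remove_accents_local` are illustrative only.

```python
import re

# Illustrative subset of a Spanish replacement table, plus the ASCII fallback
replacements = [
    ("á", "a"), ("é", "e"), ("í", "i"), ("ó", "o"),
    ("ú|ü", "u"), ("ñ", "n"),
    ("[^\x00-\x7F]", "?"),  # anything still non-ASCII becomes '?'
]

def remove_accents_local(text):
    """Apply every (pattern, replacement) pair in order."""
    for pattern, replacement in replacements:
        text = re.sub(pattern, replacement, text)
    return text

for word in ["Olà", "Olé", "Núñez"]:
    print(word, "->", remove_accents_local(word))
```

Because `à` is not in the table, it falls through to the final rule and is replaced by a question mark, which shows why this approach only suits a small, known set of accents.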

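I have not compared the performance with the other answers, and this function is not general, but it is at least worth considering because you do not need to add Scala or Java to your build process.

One possible improvement is to build a custom Transformer that handles Unicode normalization, together with a corresponding Python wrapper. It should reduce the overall overhead of passing data between the JVM and Python, and it requires neither modifications to Spark itself nor access to private APIs.

On the JVM side, you will need a transformer similar to this one: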

package net.zero323.spark.ml.feature

import java.text.Normalizer
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{DataType, StringType}

class UnicodeNormalizer (override val uid: String)
  extends UnaryTransformer[String, String, UnicodeNormalizer] {

  def this() = this(Identifiable.randomUID("unicode_normalizer"))

  private val forms = Map(
    "NFC" -> Normalizer.Form.NFC, "NFD" -> Normalizer.Form.NFD,
    "NFKC" -> Normalizer.Form.NFKC, "NFKD" -> Normalizer.Form.NFKD
  )

  val form: Param[String] = new Param(this, "form", "unicode form (one of NFC, NFD, NFKC, NFKD)",
    ParamValidators.inArray(forms.keys.toArray))

  def setForm(value: String): this.type = set(form, value)

  def getForm: String = $(form)

  setDefault(form -> "NFKD")

  override protected def createTransformFunc: String => String = {
    val normalizerForm = forms($(form))
    (s: String) => Normalizer.normalize(s, normalizerForm)
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be string type but got $inputType.")
  }

  override protected def outputDataType: DataType = StringType
}

The corresponding build definition (adjust the Spark and Scala versions to match your Spark deployment):

name := "unicode-normalization"

version := "1.0"

crossScalaVersions := Seq("2.11.12", "2.12.8")

organization := "net.zero323"

val sparkVersion = "2.4.0"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % sparkVersion,
  "org.apache.spark" %% "spark-sql" % sparkVersion,
  "org.apache.spark" %% "spark-mllib" % sparkVersion
)

On the Python side, you will need a wrapper similar to this one:

from pyspark.ml.param.shared import *
# from pyspark.ml.util import keyword_only  # in Spark < 2.0
from pyspark import keyword_only 
from pyspark.ml.wrapper import JavaTransformer

class UnicodeNormalizer(JavaTransformer, HasInputCol, HasOutputCol):

    @keyword_only
    def __init__(self, form="NFKD", inputCol=None, outputCol=None):
        super(UnicodeNormalizer, self).__init__()
        self._java_obj = self._new_java_obj(
            "net.zero323.spark.ml.feature.UnicodeNormalizer", self.uid)
        self.form = Param(self, "form",
            "unicode form (one of NFC, NFD, NFKC, NFKD)")
        # kwargs = self.__init__._input_kwargs  # in Spark < 2.0
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, form="NFKD", inputCol=None, outputCol=None):
        # kwargs = self.setParams._input_kwargs  # in Spark < 2.0
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setForm(self, value):
        return self._set(form=value)

    def getForm(self):
        return self.getOrDefault(self.form)

Build the Scala package:

sbt +package

Include it when you start the shell or submit a job. For example, for a Spark build with Scala 2.11:

bin/pyspark --jars path-to/target/scala-2.11/unicode-normalization_2.11-1.0.jar \
 --driver-class-path path-to/target/scala-2.11/unicode-normalization_2.11-1.0.jar


from pyspark.sql.functions import regexp_replace

normalizer = UnicodeNormalizer(form="NFKD",
    inputCol="text", outputCol="text_normalized")

df = sc.parallelize([
    (1, "Maracaibó"), (2, "New York"),
    (3, "   São Paulo   "), (4, "~Madrid")
]).toDF(["id", "text"])

(normalizer
    .transform(df)
    .select(regexp_replace("text_normalized", r"\p{M}", ""))
    .show())

## +--------------------------------------+
## |regexp_replace(text_normalized,\p{M},)|
## +--------------------------------------+
## |                             Maracaibo|
## |                              New York|
## |                          Sao Paulo   |
## |                               ~Madrid|
## +--------------------------------------+

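Please note that this follows the same conventions as the built-in text transformers and is not null safe. You can easily correct for that by checking for null in createTransformFunc.

Another approach, using the Python Unicode Database: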
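
As a rough illustration of what the `form` parameter and the `\p{M}` replacement do, the sketch below reproduces both steps outside Spark with Python's `unicodedata`. The standard `re` module does not support `\p{M}`, so the mark class is approximated here by testing for general categories that start with "M". The helper name `drop_marks` and the sample string are hypothetical.

```python
import unicodedata

def drop_marks(text):
    """Remove characters whose Unicode general category is a mark (Mn, Mc, Me)."""
    return "".join(ch for ch in text
                   if not unicodedata.category(ch).startswith("M"))

# 'São ' followed by the 'fi' ligature (U+FB01) and 'n'
sample = "S\u00e3o \ufb01n"

for form in ["NFC", "NFD", "NFKC", "NFKD"]:
    normalized = unicodedata.normalize(form, sample)
    print(form, len(normalized), ascii(drop_marks(normalized)))
```

Only the decomposed forms (NFD, NFKD) expose accents as separate mark characters, so removing marks after NFC or NFKC leaves the accented letters in place. The compatibility forms (NFKC, NFKD) additionally expand characters such as ligatures.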

import unicodedata
import sys

from pyspark.sql.functions import translate, regexp_replace

def make_trans():
    matching_string = ""
    replace_string = ""

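    for i in range(ord(" "), sys.maxunicode):
        name = unicodedata.name(chr(i), "")
        if "WITH" in name:
            try:
                # "LATIN SMALL LETTER O WITH ACUTE" -> "LATIN SMALL LETTER O"
                base = unicodedata.lookup(name.split(" WITH")[0])
                matching_string += chr(i)
                replace_string += base
            except KeyError:
                # The base name is not itself a character name; skip it
                pass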
    return matching_string, replace_string

def clean_text(c):
    matching_string, replace_string = make_trans()
    return translate(
        regexp_replace(c, r"\p{M}", ""),
        matching_string, replace_string
    ).alias(c)

df = sc.parallelize([
(1, "Maracaibó"), (2, "New York"),
(3, "   São Paulo   "), (4, "~Madrid"),
(5, "São Paulo"), (6, "Maracaibó")
]).toDF(["id", "text"])

df.select(clean_text("text")).show()

## +---------------+
## |           text|
## +---------------+
## |      Maracaibo|
## |       New York|
## |   Sao Paulo   |
## |        ~Madrid|
## |      Sao Paulo|
## |      Maracaibo|
## +---------------+
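
To see what the name-based table contains, the same idea can be tried with `str.translate` outside Spark. This is only a sketch: the function name `make_table` and the `limit` cut-off (which restricts the scan to the Latin blocks to keep the demo fast) are assumptions and not part of the answer above.

```python
import unicodedata

def make_table(limit=0x0250):
    """Map 'X WITH ...' characters to X, based on Unicode character names."""
    mapping = {}
    for code in range(ord(" "), limit):
        name = unicodedata.name(chr(code), "")
        if " WITH" in name:
            try:
                mapping[code] = unicodedata.lookup(name.split(" WITH")[0])
            except KeyError:
                pass  # the base name is not itself a character name
    return mapping

table = make_table()

print("São Paulo".translate(table))
print("Øresund".translate(table))
# Ø has no canonical or compatibility decomposition, so NFKD leaves it alone
print(unicodedata.normalize("NFKD", "\u00d8") == "\u00d8")
```

The name-based lookup also covers letters such as `Ø` that normalization does not decompose, which is one reason to combine it with the `\p{M}` removal rather than rely on either step alone.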


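Here is my implementation. In addition to accents, I also remove special characters, because I needed to pivot and save a table, and you cannot save a table whose column names contain any of the characters “,;{}()\n\t=\/”.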

import re

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructType, StructField
from unidecode import unidecode

spark = SparkSession.builder.getOrCreate()
data = [(1, "  \\ / \\ {____} aŠdá_ \t =  \n () asd ____aa 2134_ 23_"), (1, "N"), (2, "false"), (2, "1"), (3, "NULL"),
        (3, None)]
schema = StructType([StructField("id", IntegerType(), True), StructField("txt", StringType(), True)])
df = spark.createDataFrame(data, schema)

for col_name in ["txt"]:
    tmp_dict = {}
    for col_value in [row[0] for row in df.select(col_name).distinct().toLocalIterator()
                      if row[0] is not None]:
        # Raw string, so that the character class also matches a literal backslash
        new_col_value = re.sub(r"[ ,;{}()\n\t=\\/]", "_", col_value)
        new_col_value = re.sub('_+', '_', new_col_value)
        if new_col_value.startswith("_"):
            new_col_value = new_col_value[1:]
        if new_col_value.endswith("_"):
            new_col_value = new_col_value[:-1]
        new_col_value = unidecode(new_col_value)
        tmp_dict[col_value] = new_col_value.lower()
    df = df.na.replace(to_replace=tmp_dict, subset=[col_name])
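
The per-value cleanup inside the loop can be tested on its own before running it against a dataframe. The sketch below is a standalone approximation with two assumptions: it uses a raw-string pattern so that a literal backslash is replaced as well, and it swaps `unidecode` for NFKD decomposition plus mark removal, which handles accents but does not transliterate characters the way `unidecode` does. The function name `clean_value` is hypothetical.

```python
import re
import unicodedata

def clean_value(value):
    """Clean one string the way the loop does, without the Spark parts."""
    # Replace characters that are not allowed in column names with '_'
    value = re.sub(r"[ ,;{}()\n\t=\\/]", "_", value)
    # Collapse runs of underscores, then trim them from both ends
    value = re.sub("_+", "_", value).strip("_")
    # Assumption: NFKD decomposition + mark removal instead of unidecode
    value = unicodedata.normalize("NFKD", value)
    value = "".join(ch for ch in value if not unicodedata.combining(ch))
    return value.lower()

print(clean_value("  {__} aŠdá_ \t = asd  "))
print(clean_value("a\\b/c"))
```

Keeping the cleanup in a plain function makes it easy to unit test, and the same function can then be used to build the replacement dictionary passed to `df.na.replace`.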

If you do not have access to external libraries (as in my case), you can replace unidecode with

new_col_value = new_col_value.translate(str.maketrans(