Optimizing Spark code with filtering and UDF
I'm processing a dataset of 20 million XML documents with Spark. I was originally processing all of them, but I actually only need about a third. In a separate Spark workflow, I created a DataFrame keyfilter whose first column is the key of each XML and whose second column is a boolean: True if the XML corresponding to that key should be processed, False otherwise.
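Schematically, keyfilter has just two columns, key and filter. An illustrative version with made-up keys (the real one is built in a separate workflow and read from Parquet) would look like this:

# Illustrative only: a tiny stand-in for the real 20-million-row keyfilter.
keyfilter = spark.createDataFrame(
    [('key-000001', True), ('key-000002', False), ('key-000003', True)],
    ['key', 'filter'],
)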
The XMLs themselves are processed with a Pandas UDF, which I can't share.
My notebook on Databricks works basically like this:
import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd

DATE = '20200314'

<define UDF pandas_xml_convert_string()>

keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()

def process_part(part, fraction=1, filter=True, return_df=False):
    try:
        df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-%05d*' % (DATE, part))
    # Sometimes, the file part-xxxxx doesn't exist
    except AnalysisException:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        # Left join against keyfilter; keys missing from keyfilter become False
        df_with_filter = df.join(keyfilter, on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/part-%05d_%s_%d' % (part, DATE, time.time()))
    if return_df:
        return mod_df

n_cores = 6
i = 0
while n_cores*i < 1024:
    with ThreadPool(n_cores) as p:
        p.map(process_part, range(n_cores*i, min(1024, n_cores*i + n_cores)))
    i += 1
The reason I'm posting this question is that even though the Pandas UDF should be the most expensive operation, adding the filtering actually made my code run much slower than not filtering at all. I'm new to Spark, and I'm wondering whether I'm doing something silly here that makes the joins against keyfilter very slow, and if so, whether there's a way to make them fast (for example, is there a way to make keyfilter behave like a hash table from keys to booleans, similar to creating an index in SQL?). I suspect the large size of keyfilter plays a role here: it has 20 million rows, while the df in process_part only contains a small fraction of those rows (although df is much larger in bytes, since it holds the XML documents). Should I instead combine all the parts into one giant DataFrame rather than processing them one at a time (roughly as in the sketch below)? Or is there a way to tell Spark that the keys are unique in both DataFrames?
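For concreteness, by "combining all the parts" I mean something like the following sketch (a single glob read over every part for the date; placeholder paths, and I haven't benchmarked this):

# Rough sketch only: read every part file for the date in one job,
# instead of looping over part numbers with a thread pool.
all_parts_df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-*' % DATE)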
The key to getting the join to run in a reasonable amount of time was to use broadcast on keyfilter so that Spark performs a broadcast hash join instead of a standard shuffle join. I also merged several parts per task and lowered the parallelism (for some reason, too many threads sometimes seemed to crash the engine). My new, much faster code looks like this:
import pyspark
import time
from functools import reduce
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col, broadcast
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd

DATE = '20200314'

<define UDF pandas_xml_convert_string()>

keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()

def process_parts(part_pair, fraction=1, return_df=False, filter=True):
    dfs = []
    parts_start, parts_end = part_pair
    parts = range(parts_start, parts_end)
    for part in parts:
        try:
            df = spark.read.parquet('/input/path/on/s3/%s/part-%05d*' % (DATE, part))
            dfs.append(df)
        except AnalysisException:
            print("There is no part %05d!" % part)
            continue
    if len(dfs) >= 2:
        df = reduce(lambda x, y: x.union(y), dfs)
    elif len(dfs) == 1:
        df = dfs[0]
    else:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        # Broadcast keyfilter so the join becomes a broadcast hash join instead of a shuffle join
        df_with_filter = df.join(broadcast(keyfilter), on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/parts-%05d-%05d_%s_%d' % (parts_start, parts_end-1, DATE, time.time()))
    if return_df:
        return mod_df

start_time = time.time()
pairs = [(i*4, i*4+4) for i in range(256)]
with ThreadPool(3) as p:
    batch_start_time = time.time()
    for i, _ in enumerate(p.imap_unordered(process_parts, pairs, chunksize=1)):
        batch_end_time = time.time()
        batch_len = batch_end_time - batch_start_time
        cum_len = batch_end_time - start_time
        print('Processed group %d/256 %d minutes and %d seconds after previous group.' % (i+1, batch_len // 60, batch_len % 60))
        print('%d hours, %d minutes, %d seconds since start.' % (cum_len // 3600, (cum_len % 3600) // 60, cum_len % 60))
        batch_start_time = time.time()
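As a sanity check that the broadcast actually kicks in, the physical plan of the join should show a BroadcastHashJoin rather than a SortMergeJoin. Something along these lines (placeholder path, reusing keyfilter and DATE from above):

# Inspect the physical plan for one part; with broadcast() applied it should
# contain BroadcastHashJoin instead of SortMergeJoin.
sample_df = spark.read.parquet('/input/path/on/s3/%s/part-00000*' % DATE)
sample_df.join(broadcast(keyfilter), on='key', how='left').explain()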