计算每个 pyspark RDD 分区中的元素数

Count number of elements in each pyspark RDD partition

我正在寻找与此问题对应的 Pyspark:

具体来说,我想以编程方式计算 pyspark RDD 或数据帧的每个分区中的元素数量(我知道此信息可在 Spark Web UI 中获得)。

这次尝试:

df.foreachPartition(lambda iter: sum(1 for _ in iter))

结果:

AttributeError: 'NoneType' object has no attribute '_jvm'

我不想将迭代器的内容收集到内存中。

如果你问:我们可以在不遍历迭代器的情况下获取迭代器中元素的数量吗?答案是No.

但我们不必将其存储在内存中,如您提到的post:

def count_in_a_partition(idx, iterator):
  count = 0
  for _ in iterator:
    count += 1
  return idx, count

data = sc.parallelize([
    1, 2, 3, 4
], 4)

data.mapPartitionsWithIndex(count_in_a_partition).collect()

编辑

请注意,您的代码非常接近解决方案,只是 mapPartitions 需要 return 一个迭代器:

def count_in_a_partition(iterator):
  yield sum(1 for _ in iterator)

data.mapPartitions(count_in_a_partition).collect()