How to get the taskID or mapperID (something like partitionID in Spark) in a Hive UDF?

As the title says: how can I get the taskID or mapperID (something like the partitionID in Spark) in a Hive UDF?

You can use TaskContext to access the task information:

import org.apache.spark.TaskContext

sc.parallelize(Seq[Int](), 4).mapPartitions(_ => {
  val ctx = TaskContext.get
  val stageId = ctx.stageId
  val partId = ctx.partitionId
  val hostname = java.net.InetAddress.getLocalHost().getHostName()
  Iterator(s"Stage: $stageId, Partition: $partId, Host: $hostname")
}).collect.foreach(println)

PySpark added equivalent functionality in Spark 2.2.0 (SPARK-18576):

from pyspark import TaskContext
import socket

def task_info(*_):
    ctx = TaskContext()
    return ["Stage: {0}, Partition: {1}, Host: {2}".format 
    (ctx.stageId(), ctx.partitionId(), socket.gethostname())]

for x in sc.parallelize([], 4).mapPartitions(task_info).collect():
    print(x)

I think it will give you the information about the task, including the map ID you are looking for.

I found the right answer on my own: we can get the taskID in a Hive UDF like this:

import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;

public class TestUDF extends GenericUDF {
    private Text result = new Text();
    private String tmpStr = "";

    @Override
    public void configure(MapredContext context) {
        // get the total number of map tasks
        int numTasks = context.getJobConf().getNumMapTasks();
        // get the ID of the current task attempt
        String taskID = context.getJobConf().get("mapred.task.id");
        this.tmpStr = numTasks + "_h_xXx_h_" + taskID;
    }
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments)
            throws UDFArgumentException {
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }
    
    @Override
    public Object evaluate(DeferredObject[] arguments) {
        result.set(this.tmpStr);
        return this.result;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "RowSeq-func()";
    }
}

However, this only works with the MapReduce execution engine; it does not work with the Spark SQL engine.
The test SQL is as follows:

add jar hdfs:///home/dp/tmp/shaw/my_udf.jar;
create temporary function seqx AS 'com.udf.TestUDF';

with core as (
select 
    device_id
from
    test_table
where
    p_date = '20210309' 
    and product = 'google'
distribute by
    device_id
)
select
    seqx() as seqs,
    count(1) as cc
from
    core
group by
    seqx()
order by
    seqs asc 

The result on the MR engine is shown below; you can see that we successfully got the number of tasks and the taskID:

Running the same SQL on the Spark engine, the UDF does not take effect and we learn nothing about the taskID:

If you run your HQL on the Spark engine, call a Hive UDF in it, and really need to get the partitionId in Spark, see the code below:

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;
import org.apache.spark.TaskContext;

public class TestUDF extends GenericUDF {
    private Text result = new Text();
    private String tmpStr = "";

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments)
            throws UDFArgumentException {
        // get the Spark partitionId at initialization time
        this.tmpStr = TaskContext.getPartitionId() + "-initial-pid";
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) {
        // get the Spark partitionId for each evaluated row
        this.tmpStr = TaskContext.getPartitionId() + "-evaluate-pid";
        result.set(this.tmpStr);
        return this.result;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "RowSeq-func()";
    }
}

As above, call TaskContext.getPartitionId() in the overridden initialize or evaluate method of your UDF class to get the Spark partitionId.

Note: your UDF must take at least one argument, e.g. select my_udf(param); this makes your UDF get initialized inside multiple tasks. If your UDF has no argument, it is initialized on the Driver, and the Driver has no taskContext or partitionId, so you will get nothing. See the sketch right below.
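
To illustrate that note, here is a minimal invocation sketch. It reuses the seqx registration and the test_table/device_id columns from the MR test above purely as an example, not as a verified result:

-- with a column argument: the UDF is initialized inside each Spark task,
-- so TaskContext.getPartitionId() can return that task's partition id
select seqx(device_id) as seqs, count(1) as cc
from test_table
where p_date = '20210309' and product = 'google'
group by seqx(device_id);

-- without an argument: the UDF may be initialized on the Driver,
-- which has no TaskContext, so no partitionId is available
select seqx() as seqs, count(1) as cc
from test_table
where p_date = '20210309' and product = 'google'
group by seqx();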

The picture below shows the result of running the UDF above on the Spark engine; you can see that we successfully got the partitionIds: