Spark Streaming SQS with checkpointing enabled

I went through several sites, such as https://spark.apache.org/docs/latest/streaming-programming-guide.html

https://data-flair.training/blogs/spark-streaming-checkpoint/

https://docs.databricks.com/spark/latest/rdd-streaming/developing-streaming-applications.html

Some of these links explain how to write the code, but they are quite abstract, and it took me a long time to work out how everything fits together.

After a long struggle I managed to get a streaming job with checkpointing running; I am adding it here to help others.
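For context, the job takes its configuration as positional arguments, in the order they are read in main and processMessage below: queue name, AWS region, fetchMaxMessage, visibilityTimeOutSeconds, waitTimeoutInMillis, isLocal, S3 bucket, wgPath, stagingPath and, optionally, a shutdown-marker path. A hypothetical submit command (jar name and argument values are placeholders) could look like:

spark-submit --class StreamingApp streaming-app.jar \
  my-queue us-east-1 10 1800 1000 false my-bucket wg/objects staging/output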

import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.Executors

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.regions.Regions
import com.amazonaws.services.sqs.model.Message
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.log4j.LogManager
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.{Duration, Seconds, StreamingContext}

object StreamingApp extends scala.Serializable {
  @transient private final val mapper = new ObjectMapper
  @transient private final val LOG = LogManager.getLogger(getClass)
  @transient private final val executor = Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors)
  var s3 = "s3" // switched to "s3a" when running locally (see initialiseSparkConf)
  private var shutdownMarker: String = _
  private var stopFlag: Boolean = false

  def main(args: Array[String]): Unit = {
    val queueName = args(0)
    val region = args(1)
    val fetchMaxMessage = args(2).toInt
    val visibilityTimeOutSeconds = args(3).toInt
    val waitTimeoutInMillis = args(4).toLong
    val isLocal = args(5).toBoolean
    val bucket = args(6)
    if (args.length >= 10)
      shutdownMarker = args(9)
    val sparkConf = initialiseSparkConf(isLocal)
    sparkConf.set(Constants.QUEUE_NAME, queueName)
    sparkConf.set(Constants.REGION, region)
    sparkConf.set(Constants.FETCH_MAX_MESSAGE, fetchMaxMessage.toString)
    sparkConf.set(Constants.VISIBILITY_TIMEOUT_SECONDS, visibilityTimeOutSeconds.toString)
    sparkConf.set(Constants.WAIT_TIMEOUT_IN_MILLIS, waitTimeoutInMillis.toString)

    if (shutdownMarker == null) // keep the marker passed in args(9), otherwise fall back to a default
      shutdownMarker = s"$s3://$bucket/streaming/shutdownmarker"
    val checkpointDirectory = s"$s3://$bucket/streaming/checkpoint/"
    var context: StreamingContext = null

    try {
      context = StreamingContext.getOrCreate(checkpointDirectory, () => createContext(sparkConf, waitTimeoutInMillis, checkpointDirectory, args))
      context.start()
      val checkIntervalMillis = 10000
      var isStopped = false

      while (!isStopped) {
        println("Calling awaitTerminationOrTimeout")
        isStopped = context.awaitTerminationOrTimeout(checkIntervalMillis)
        if (isStopped)
          println("Confirmed: the streaming context is stopped. Exiting application...")
        checkShutdownMarker(context.sparkContext)
        if (!isStopped && stopFlag) {
          println("Stopping the streaming context gracefully")
          context.stop(stopSparkContext = true, stopGracefully = true)
          println("The streaming context is stopped")
        }
      }
    }
    finally {
      LOG.info("Exiting the Application")
      if (context != null && org.apache.spark.streaming.StreamingContextState.STOPPED != context.getState) {
        context.stop(stopSparkContext = true, stopGracefully = true)
      }
      if (!executor.isShutdown)
        executor.shutdown()
    }
  }

  def checkShutdownMarker(sparkContext: SparkContext): Unit = {
    if (!stopFlag) {
      stopFlag = isFileExists(shutdownMarker, sparkContext)
    }
    println(s"Shutdown marker $shutdownMarker found: $stopFlag at ${System.currentTimeMillis()}")
  }

  def isFileExists(path: String, sparkContext: SparkContext): Boolean = {
    isValidPath(isDir = false, path, getFileSystem(path, sparkContext))
  }

  def getFileSystem(path: String, sparkContext: SparkContext): FileSystem = {
    FileSystem.get(URI.create(path), sparkContext.hadoopConfiguration)
  }

  def isValidPath(isDir: Boolean, path: String, fileSystem: FileSystem): Boolean = {
    LOG.info(s"Validating path $path")
    if (path.startsWith(Constants.S3) || path.startsWith(Constants.HDFS) || path.startsWith(Constants.FILE)) {
      val fsPath = new Path(path)
      if (isDir) fileSystem.isDirectory(fsPath)
      else fileSystem.isFile(fsPath)
    } else {
      Files.exists(Paths.get(path))
    }
  }

  def createContext(sparkConf: SparkConf, waitTimeoutInMillis: Long, checkpointDirectory: String, args: Array[String]): StreamingContext = {
    // Only invoked when no checkpoint exists; otherwise getOrCreate rebuilds the context from the checkpoint data
    val context = new StreamingContext(sparkConf, Duration(waitTimeoutInMillis + 1000)) // batch interval slightly above the receiver wait timeout
    processMessage(context, args)
    context.checkpoint(checkpointDirectory) // set checkpoint directory
    context
  }

  def processMessage(context: StreamingContext, args: Array[String]): Unit = {

    val bucket = args(6)
    val wgPath = args(7)
    var stagingPath = args(8)
    val waitTimeoutInMillis = args(4).toLong
    if (context != null) {

      if (!stagingPath.endsWith("/")) {
        stagingPath = s"$stagingPath/"
      }
      val outputPrefix = s"$s3://$bucket/$stagingPath"

      LOG.info(s"Number of cores for driver: ${Runtime.getRuntime.availableProcessors}")

      val sparkContext: SparkContext = context.sparkContext

      val broadcasts = BroadCaster.getInstance(sparkContext, s"$s3://$bucket/$wgPath")

      val input = context.receiverStream(broadcasts(Constants.SQS_RECEIVER).value.asInstanceOf[SQSReceiver])
      //input.checkpoint(interval = Seconds(60))
      LOG.info(s"Scheduling mode ${sparkContext.getSchedulingMode.toString}")
      input.foreachRDD(r => {
        val sparkSession = SparkSession.builder.config(r.sparkContext.getConf).getOrCreate()

        // collect() brings the JSON-serialised messages of this batch to the driver
        val messages = r.collect().map(message => mapper.readValue(message, classOf[Message]))

        // Re-acquire the broadcast singleton; after recovery from a checkpoint it is re-created lazily
        val broadcasts = BroadCaster.getInstance(r.sparkContext, s"$s3://$bucket/$wgPath")
        //Application logic
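        // Example application logic (a sketch; adapt as needed): persist the raw
        // message bodies under outputPrefix, then acknowledge the batch, e.g. by
        // calling deleteMessages(new DeleteMessageBatchRequest(...)) on the
        // broadcast SQSReceiver.
        import sparkSession.implicits._
        sparkSession.createDataset(messages.map(_.getBody).toSeq)
          .write.mode("append").text(outputPrefix)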
      })
    }
  }


  def initialiseSparkConf(local: Boolean): SparkConf = {
    val sparkConf = new SparkConf()
      .setAppName("Spark Streaming")
      .set("spark.scheduler.mode", "FAIR")
      .set("spark.sql.parquet.filterPushdown", "true")
      .set("spark.executor.heartbeatInterval", "20s")
      .set("spark.streaming.driver.writeAheadLog.closeFileAfterWrite", "true")
      .set("spark.streaming.receiver.writeAheadLog.closeFileAfterWrite", "true")
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")
      .set("spark.streaming.stopGracefullyOnShutdown", "true")
      .set("spark.streaming.backpressure.enabled", "true")
      .set("spark.streaming.backpressure.pid.minRate", "10") // SQS returns batches of at most 10 messages

    if (local) {
      s3 = "s3a"
      sparkConf.setMaster("local[*]")
    } else {
      sparkConf.set("hive.metastore.client.factory.class",
        "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")
    }
    sparkConf
  }
}
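A note on BroadCaster below: broadcast variables cannot be recovered from a checkpoint, so the Spark Streaming programming guide recommends holding them in a lazily instantiated singleton that gets re-created on first use after a driver restart. That is exactly what the double-checked locking in getInstance does.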

object BroadCaster {

  @volatile private var instance: Map[String, Broadcast[Any]] = _

  def getInstance(sparkContext: SparkContext, wgPath: String): Map[String, Broadcast[Any]] = {
    if (instance == null) {
      synchronized {
        if (instance == null) {
          instance = Utils.createBroadcastObjects(wgPath, sparkContext)
          instance += (Constants.SQS_RECEIVER -> sparkContext.broadcast[Any](getSQSReceiver(sparkContext.getConf)))
        }
      }
    }
    instance
  }

  private def getSQSReceiver(conf: SparkConf): SQSReceiver = {
    new SQSReceiver(conf.get(Constants.QUEUE_NAME))
      .withRegion(Regions.fromName(conf.get(Constants.REGION)))
      .withCredential(new DefaultAWSCredentialsProviderChain())
      .withFetchMaxMessage(conf.getInt(Constants.FETCH_MAX_MESSAGE, 10))
      .withVisibilityTimeOutSeconds(conf.getInt(Constants.VISIBILITY_TIMEOUT_SECONDS, 1800))
      .withWaitTimeoutinMillis(conf.getLong(Constants.WAIT_TIMEOUT_IN_MILLIS, 1000))
  }
}
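The code above also references a Constants object that is not shown; the exact key strings do not matter, so a minimal sketch like the one below works (Utils.createBroadcastObjects is application-specific and therefore omitted):

object Constants {
  val QUEUE_NAME = "sqs.queueName"
  val REGION = "sqs.region"
  val FETCH_MAX_MESSAGE = "sqs.fetchMaxMessage"
  val VISIBILITY_TIMEOUT_SECONDS = "sqs.visibilityTimeOutSeconds"
  val WAIT_TIMEOUT_IN_MILLIS = "sqs.waitTimeoutInMillis"
  val SQS_RECEIVER = "sqsReceiver"
  val S3 = "s3"
  val HDFS = "hdfs"
  val FILE = "file"
}

The custom receiver itself is a Java class: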

import java.util.List;

import org.apache.log4j.Logger;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.receiver.Receiver;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.DeleteMessageBatchRequest;
import com.amazonaws.services.sqs.model.DeleteMessageRequest;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

public class SQSReceiver extends Receiver<String> {

    private String queueName;
    private transient AWSCredentialsProvider credential;
    private Regions region = Regions.US_EAST_1;
    private Long waitTimeoutinMillis = 0L;
    private ObjectMapper mapper = new ObjectMapper();
    private transient Logger logger = Logger.getLogger(SQSReceiver.class);
    private boolean deleteOnReceipt = false;
    private int fetchMaxMessage = 10; // SQS returns at most 10 messages per receive call
    private int visibilityTimeOutSeconds = 60;

    private String sqsQueueUrl;
    private transient AmazonSQS amazonSQS;

    public SQSReceiver(String queueName) {
        this(queueName, false);
    }

    public SQSReceiver(String queueName, boolean deleteOnReceipt) {
        super(StorageLevel.MEMORY_AND_DISK_SER());
        this.queueName = queueName;
        this.deleteOnReceipt = deleteOnReceipt;
        setupSQS(queueName);
    }

    private void setupSQS(String queueName) {
        AmazonSQSClientBuilder amazonSQSClientBuilder = AmazonSQSClientBuilder.standard();

        if (credential != null) {
            amazonSQSClientBuilder.withCredentials(credential);
        }
        amazonSQSClientBuilder.withRegion(region);
        amazonSQS = amazonSQSClientBuilder.build();
        sqsQueueUrl = amazonSQS.getQueueUrl(queueName).getQueueUrl();
    }

    public void onStart() {
        new Thread(this::receive).start();
    }

    public void onStop() {
        // Nothing to do here: the thread running receive()
        // stops on its own once isStopped() returns true
    }

    private void receive() {
        try {
            setupSQS(queueName);
            ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(sqsQueueUrl).withMaxNumberOfMessages(fetchMaxMessage).withVisibilityTimeout(visibilityTimeOutSeconds)
                    .withWaitTimeSeconds(20); //https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/examples-sqs-long-polling.html
            receiveMessagesFromSQS(amazonSQS, sqsQueueUrl, receiveMessageRequest);
        } catch (Throwable e) {
            stop("Error encountered while initializing", e);
        }
    }

    private void receiveMessagesFromSQS(final AmazonSQS amazonSQS, final String sqsQueueUrl,
                                        ReceiveMessageRequest receiveMessageRequest) {
        try {
            while (!isStopped()) {
                List<Message> messages = amazonSQS.receiveMessage(receiveMessageRequest).getMessages();
                if (deleteOnReceipt) {
                    // Store each message body and delete the message from the queue immediately
                    for (Message m : messages) {
                        store(m.getBody());
                        amazonSQS.deleteMessage(new DeleteMessageRequest(sqsQueueUrl, m.getReceiptHandle()));
                    }
                } else {
                    messages.forEach(this::storeMessage);
                }
                if (waitTimeoutinMillis > 0L)
                    Thread.sleep(waitTimeoutinMillis);
            }
            restart("Trying to connect again");
        } catch (IllegalArgumentException | InterruptedException e) {
            restart("Could not connect", e);
        } catch (Throwable e) {
            restart("Error Receiving Data", e);
        }
    }

    private void storeMessage(Message m) {
        try {
            if (m != null)
                store(mapper.writeValueAsString(m));
        } catch (JsonProcessingException e) {
            logger.error("Unable to write message to streaming context", e);
        }
    }

    public SQSReceiver withVisibilityTimeOutSeconds(int visibilityTimeOutSeconds) {
        this.visibilityTimeOutSeconds = visibilityTimeOutSeconds;
        return this;
    }

    public SQSReceiver withFetchMaxMessage(int fetchMaxMessage) {
        if (fetchMaxMessage > 10) {
            throw new IllegalArgumentException("FetchMaxMessage can't be greater than 10");
        }
        this.fetchMaxMessage = fetchMaxMessage;
        return this;
    }

    public SQSReceiver withWaitTimeoutinMillis(long waitTimeoutinMillis) {
        this.waitTimeoutinMillis = waitTimeoutinMillis;
        return this;
    }

    public SQSReceiver withRegion(Regions region) {
        this.region = region;
        return this;
    }

    public SQSReceiver withCredential(AWSCredentialsProvider credential) {
        this.credential = credential;
        return this;
    }

    public void deleteMessages(DeleteMessageBatchRequest request) {
        request.withQueueUrl(sqsQueueUrl);
        amazonSQS.deleteMessageBatch(request);
    }
}
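To stop the job gracefully, create an empty object at the shutdown-marker location, for example with aws s3api put-object --bucket <bucket> --key streaming/shutdownmarker (matching the default path built in main). The driver polls for the marker every 10 seconds and, once it appears, stops the StreamingContext gracefully so in-flight batches can finish. Delete the marker before the next run, otherwise the job will stop again right after starting.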