Background
When the client submits a Spark application on Hadoop YARN, a YARN Application Master (AM) for Spark is created, which starts a JVM that creates and schedules tasks for the application. This JVM is called the Spark driver, or driver container (where a container can be understood as a bundle of CPU and memory). Inside the driver, a Spark session is built, and a streaming job is then started in that session. Multiple streams can exist in the same session.
A graceful shutdown means completing the current data processing and stopping the reception of new data. This matters for Spark structured streaming because otherwise data may be saved while the corresponding offsets fail to be committed, leading to duplicate data when the stream restarts.
Streaming can run in continuous mode or in micro-batch mode. We’ll consider only the latter in this article.
A streaming job can be killed in the following way: some process sends SIGTERM (the termination signal) to the Spark driver, and the driver then stops the Spark session. To guarantee a graceful shutdown for micro-batch streaming, we have to ensure that (1) the batch being processed when SIGTERM arrives keeps running until completion, and (2) only after that are the Spark session and the driver closed, assuming there is only one stream in the session.
SIGTERM is sent by the “-stop” command of YARN 3.x.x, which enables the graceful shutdown of YARN applications. That covers the YARN side; on the Spark streaming side we still have to implement more to get a graceful shutdown.
Approach
Suppose we have the following streaming code:
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.streaming.Trigger

val spark = ... // Spark session variable

def processEachBatch = (dataset: Dataset[Row], batchId: Long) => {
  // processing logic here
}

val stream = spark
  .readStream
  .format("kafka")
  .option("option1", "value1")
  .load()
  .writeStream
  .trigger(Trigger.ProcessingTime("20 seconds"))
  .foreachBatch(processEachBatch)
  .option("checkpointLocation", "/path/to/checkpoint")
  .start()

// here: create a separate thread that uses stream.stop() to end streaming
stream.awaitTermination()
We could then create a thread that is triggered by the arrival of SIGTERM and calls the Spark method stop() to end the stream. However, stop() cancels all active jobs in the query’s job group, which is not necessarily graceful: the batch running at that moment gets cancelled halfway through.
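For illustration, a minimal sketch of this naive approach might look like the following. It reuses the stream variable from the snippet above and registers a plain JVM hook through Scala’s sys.addShutdownHook:

sys.addShutdownHook {
  // Naive shutdown: stop the query as soon as SIGTERM arrives.
  // Not graceful: stop() cancels the batch that is currently running.
  if (stream.isActive) {
    stream.stop()
    stream.awaitTermination()
  }
}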
A better solution is to employ a StreamingQueryListener to monitor and react to changes in the stream, such as startup, progress, and termination.
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryManager
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.SynchronousQueue

class GracefulShutdownListener(streams: StreamingQueryManager) extends StreamingQueryListener {

  // query ID -> (registered shutdown hook, queue used to sync with the hook thread)
  private val queryStore = new ConcurrentHashMap[UUID, (Runnable, SynchronousQueue[String])]()

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
    val stream = streams.get(event.id)
    val syncQueue = new SynchronousQueue[String]()
    val shutdownHook: Runnable = () => {
      if (stream.isActive) {
        val syncSignal = "block"
        syncQueue.put(syncSignal) // blocks until onQueryProgress() polls, i.e. until the current batch finishes
        stream.stop()
        stream.awaitTermination()
      }
    }
    ShutdownHookManager.get().addShutdownHook(shutdownHook, 95)
    queryStore.put(stream.id, (shutdownHook, syncQueue))
  }

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    val (runnable, syncQueue) = queryStore.get(event.progress.id)
    syncQueue.poll() // unblocks the shutdown hook if it is waiting
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    val (shutdownHook, syncQueue) = queryStore.remove(event.id)
    if (!ShutdownHookManager.get().isShutdownInProgress) ShutdownHookManager.get().removeShutdownHook(shutdownHook)
  }
}
Then let’s register the listener so it is attached to the stream.
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.streaming.Trigger

val spark = ... // Spark session variable

def processEachBatch = (dataset: Dataset[Row], batchId: Long) => {
  // processing logic here
}

// register the listener
spark.streams.addListener(new GracefulShutdownListener(spark.streams))

val stream = spark
  .readStream
  .format("kafka")
  .option("option1", "value1")
  .load()
  .writeStream
  .trigger(Trigger.ProcessingTime("20 seconds"))
  .foreachBatch(processEachBatch)
  .option("checkpointLocation", "/path/to/checkpoint")
  .start()

// the shutdown hook registered in onQueryStarted() will stop the stream on SIGTERM
stream.awaitTermination()
At this point three threads exist: the stream thread, the listener thread, and the shutdownHook thread.
Here are some notes on the listener.
override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
  val stream = streams.get(event.id)
  val syncQueue = new SynchronousQueue[String]()
  val shutdownHook: Runnable = () => {
    if (stream.isActive) {
      val syncSignal = "block"
      syncQueue.put(syncSignal)
      stream.stop()
      stream.awaitTermination()
    }
  }
  ShutdownHookManager.get().addShutdownHook(shutdownHook, 95)
  queryStore.put(stream.id, (shutdownHook, syncQueue))
}
onQueryStarted() runs when the streaming query starts (see the start() call when creating the stream above). The function registers a shutdown hook with the ShutdownHookManager; the hook will be invoked later, when the Spark driver receives SIGTERM. To ensure all data processing completes before the Spark session and the driver run their own hooks, the Hadoop ShutdownHookManager is used instead of a plain JVM hook thread, because it allows specifying the priority of hook execution (the higher the priority, the earlier the hook runs). I chose 95 because SparkShutdownHookManager has priority 40 and the Spark context shutdown hook has priority 50, so a value above 50 is needed to put our hook first in the execution order.
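To see the priority ordering in isolation (a standalone sketch, separate from the listener; the name PriorityDemo is mine), the hook registered with priority 95 below runs before the one registered with priority 50 when the JVM exits:

import org.apache.hadoop.util.ShutdownHookManager

object PriorityDemo extends App {
  // Hadoop's ShutdownHookManager runs higher-priority hooks first.
  ShutdownHookManager.get().addShutdownHook(() => println("priority 50: runs second"), 50)
  ShutdownHookManager.get().addShutdownHook(() => println("priority 95: runs first"), 95)
  println("main done; hooks run on JVM exit")
}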
The function also maps the query ID to a queue (syncQueue) that is responsible for synchronizing the listener thread and the newly registered shutdown hook thread. In a SynchronousQueue, the thread that puts a value (e.g. “block”, or any value you like) blocks until another thread takes that value, and only then can it proceed. syncQueue is the means of communication between the threads. This is important because we need the shutdownHook thread to proceed only after the batch that was in flight when SIGTERM arrived has finished, that is, only after syncQueue.poll(). That is exactly what happens in onQueryProgress().
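To see this handoff on its own (a standalone sketch, not part of the listener; the name HandoffDemo is mine), the putting thread below stays blocked until the main thread polls the value:

import java.util.concurrent.SynchronousQueue

object HandoffDemo extends App {
  val queue = new SynchronousQueue[String]()

  val producer = new Thread(() => {
    println("producer: putting a value, will block until it is taken")
    queue.put("block")
    println("producer: value was taken, proceeding")
  })
  producer.start()

  Thread.sleep(1000)                          // pretend the current batch is still running
  println(s"main: polled '${queue.poll()}'")  // takes the value and unblocks the producer
  producer.join()
}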
override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
  val (runnable, syncQueue) = queryStore.get(event.progress.id)
  syncQueue.poll()
}
Finally, there is onQueryTerminated():
override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
  val (shutdownHook, syncQueue) = queryStore.remove(event.id)
  if (!ShutdownHookManager.get().isShutdownInProgress) ShutdownHookManager.get().removeShutdownHook(shutdownHook)
}
When the stream has stopped, we want to remove it from the map of running queries as a way to clean things up. Imagine the case where multiple streams exist in the same Spark application. shutdownHook should also be removed from the list of hooks: if the streaming job stops because of an exception, the batch never finishes, so the shutdown hook would hang around indefinitely, leading to a memory leak or a deadlock. (Remember that the shutdown hook only proceeds after the batch execution completes.)
The execution flow can be visualized in the following diagram.
Notice that SIGTERM arrives in the middle of batch 3’s execution. The shutdown hook lies dormant until that batch execution is done.
Also note, however, that since the threads are asynchronous by nature, there may be cases where onQueryProgress() only fires when the following batch is starting, rather than right at the moment the previous batch execution has completed. Such cases should be rare.
Instead of the SynchronousQueue used above to synchronize the threads, we could take advantage of a Semaphore to lock access to the shared state. Here SynchronousQueue’s put() and poll() correspond to Semaphore’s acquire() and release(), respectively.
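A minimal sketch of that variant, assuming a Semaphore created with zero permits and showing only the parts of the listener that change (the name batchDone is mine):

import java.util.concurrent.Semaphore

// In onQueryStarted(): a semaphore with zero permits plays the role of syncQueue.
val batchDone = new Semaphore(0)
val shutdownHook: Runnable = () => {
  if (stream.isActive) {
    batchDone.acquire()       // blocks until onQueryProgress() releases a permit
    stream.stop()
    stream.awaitTermination()
  }
}

// In onQueryProgress(): hand a permit to the hook only if it is actually waiting,
// so that permits do not accumulate across normal batches.
if (batchDone.hasQueuedThreads) batchDone.release()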