Spark arbitrary stateful stream aggregation, flatMapGroupsWithState API - apache-spark

I'm a Spark developer with 10 days of experience, trying to understand the flatMapGroupsWithState API of Spark.
As I understand:
We pass it 2 options. One is the timeout configuration: a possible value is GroupStateTimeout.ProcessingTimeTimeout, i.e. an instruction to Spark to consider processing time rather than event time. The other is the output mode.
We pass in a function, let's say myFunction, that is responsible for setting the state for each key. We also set a timeout duration with groupState.setTimeoutDuration(TimeUnit.HOURS.toMillis(4)), assuming groupState is the GroupState instance for a key.
As I understand it, as micro-batches of stream data keep coming in, Spark maintains an intermediate state as defined in the user-defined function. Let's say the intermediate state after processing n micro-batches of data is as follows:
State for key1:
{
key1: [v1, v2, v3, v4, v5]
}
State for key2:
{
key2: [v11, v12, v13, v14, v15]
}
For any new data that comes in, myFunction is called with the state for that particular key. E.g. for key1, myFunction is called with key1, the new key1 values and [v1, v2, v3, v4, v5], and it updates the key1 state according to the logic.
I read about the timeout and found that it dictates how long we should wait before timing out some intermediate state.
Questions:
If this process runs indefinitely, my intermediate states will keep piling up and hit the memory limits on the nodes. So when are these intermediate states cleared? I found that in the case of event-time aggregation, watermarks dictate when the intermediate states are cleared.
What does timing out the intermediate state mean in the context of processing time?

If this process runs indefinitely, my intermediate states will keep piling up and hit the memory limits on the nodes. So when are these intermediate states cleared? I found that in the case of event-time aggregation, watermarks dictate when the intermediate states are cleared.
Apache Spark will mark them as expired once the timeout has elapsed, so in your example after 4 hours of inactivity (processing time of the last update + 4 hours, where inactivity means no new event updating the state of that key). Timeouts are only checked when a micro-batch runs, so a state never times out before the duration has elapsed, but it may time out later than that.
What does timing out the intermediate state mean in the context of processing time?
It means that it will time out according to the real clock (processing time, the org.apache.spark.util.SystemClock class). You can check which clock is currently used by looking at the triggerClock parameter of org.apache.spark.sql.streaming.StreamingQueryManager#startQuery.
You will find more details in the FlatMapGroupsWithStateExec class, more particularly here:
// Generate a iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator =
  processor.processNewData(filteredIter) ++ processor.processTimedOutState()
And if you analyze these 2 methods, you will see that:
processNewData applies the mapping function to all active keys (those present in the micro-batch):
/**
 * For every group, get the key, values and corresponding state and call the function,
 * and return an iterator of rows
 */
def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
  val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
  groupedIter.flatMap { case (keyRow, valueRowIter) =>
    val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
    callFunctionAndUpdateState(
      stateManager.getState(store, keyUnsafeRow),
      valueRowIter,
      hasTimedOut = false)
  }
}
processTimedOutState calls the mapping function on all expired states:
def processTimedOutState(): Iterator[InternalRow] = {
  if (isTimeoutEnabled) {
    val timeoutThreshold = timeoutConf match {
      case ProcessingTimeTimeout => batchTimestampMs.get
      case EventTimeTimeout => eventTimeWatermark.get
      case _ =>
        throw new IllegalStateException(
          s"Cannot filter timed out keys for $timeoutConf")
    }
    val timingOutPairs = stateManager.getAllState(store).filter { state =>
      state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
    }
    timingOutPairs.flatMap { stateData =>
      callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
    }
  } else Iterator.empty
}
An important point to notice here is that Apache Spark will keep an expired state in the state store if you don't invoke the GroupState#remove method. Expired states won't be returned for processing again, though, because their timeout timestamp is set to NO_TIMESTAMP (unless you set a new timeout). However, they will still be stored in the state store delta files, which may slow down recovery when the most recent state has to be reloaded. If you analyze FlatMapGroupsWithStateExec again, you will see that the state is removed only when the hasRemoved flag is set to true:
def callFunctionAndUpdateState(...)
  // ...
  // When the iterator is consumed, then write changes to state
  def onIteratorCompletion: Unit = {
    if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
      stateManager.removeState(store, stateData.keyRow)
      numUpdatedStateRows += 1
    } else {
      val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
      val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
      val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged
      if (shouldWriteState) {
        val updatedStateObj = if (groupState.exists) groupState.get else null
        stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
        numUpdatedStateRows += 1
      }
    }
  }
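The branching in `onIteratorCompletion` can be isolated as a pure decision function. This is a plain-Scala sketch that mirrors the snippet above; the `Action` names are hypothetical. It shows the point made in the answer: a timed-out key whose function neither calls `remove()` nor sets a new timeout still gets written back (with `NoTimestamp`), rather than deleted.

```scala
// Plain-Scala model of the write/remove decision in onIteratorCompletion; not Spark code.
object StateWriteDecisionModel {
  val NoTimestamp: Long = -1L

  sealed trait Action
  case object RemoveFromStore extends Action // stateManager.removeState
  case object WriteToStore extends Action    // stateManager.putState
  case object LeaveUnchanged extends Action  // nothing written

  def decide(hasRemoved: Boolean, hasUpdated: Boolean,
             newTimeout: Long, oldTimeout: Long): Action =
    if (hasRemoved && newTimeout == NoTimestamp) RemoveFromStore
    else if (hasUpdated || hasRemoved || newTimeout != oldTimeout) WriteToStore
    else LeaveUnchanged
}
```

For a key that timed out with an old timeout of `14401000`, `decide(false, false, NoTimestamp, 14401000L)` yields `WriteToStore`, while the same call with `hasRemoved = true` yields `RemoveFromStore`.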

Related

How to introduce reset logic when aggregating/joining streaming dataframe with static dataframe for spark streaming

A good feature of Spark Structured Streaming is that it can join a static DataFrame with a streaming DataFrame. Take the example below: users is a static DataFrame read from a database, and transactionStream comes from a stream. With the join, we get the spending of each country, accumulated as new batches arrive.
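import org.apache.spark.sql.functions.sum
import spark.implicits._ // assumes a SparkSession named spark
val spendingByCountry = (transactionStream
  .join(users, users("id") === transactionStream("userid"))
  .groupBy($"country")
  .agg(sum($"cost") as "spending")) // alias the aggregated column, not the Dataset
spendingByCountry.writeStream
  .outputMode("complete")
  .format("console")
  .start()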
The sum of cost accumulates as new batches arrive, as shown below.
-------------------------------
Batch: 0
-------------------------------
Country Spending
EN 90.0
FR 50.0
-------------------------------
Batch: 1
-------------------------------
Country Spending
EN 190.0
FR 150.0
If I want to introduce notification and reset logic into the above example, what is the correct approach? The requirement is that if the spending exceeds some threshold, the country and spending should be stored in a table and the spending should be reset to 0 to accumulate again.
One way to achieve this is arbitrary stateful processing. Instead of groupBy, you can use groupByKey with a custom function passed to mapGroupsWithState, where you maintain all the business logic you need. Here is an example taken from the Spark docs:
// A mapping function that maintains an integer state for string keys and returns a string.
// Additionally, it sets a timeout to remove the state if it has not received data for an hour.
def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = {
  if (state.hasTimedOut) {                // If called when timing out, remove the state
    state.remove()
  } else if (state.exists) {              // If state exists, use it for processing
    val existingState = state.get         // Get the existing state
    val shouldRemove = ...                // Decide whether to remove the state
    if (shouldRemove) {
      state.remove()                      // Remove the state
    } else {
      val newState = ...
      state.update(newState)              // Set the new state
      state.setTimeoutDuration("1 hour")  // Set the timeout
    }
  } else {
    val initialState = ...
    state.update(initialState)            // Set the initial state
    state.setTimeoutDuration("1 hour")    // Set the timeout
  }
  ...
  // return something
}

dataset
  .groupByKey(...)
  .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction)
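The threshold-and-reset rule from the question can be kept as a small pure function and called from inside such a mapping function. The sketch below is plain Scala and assumes one possible wiring: `previous` would come from `state.getOption`, the returned total would be passed to `state.update`, and the optional alert row would be emitted as output (for example from flatMapGroupsWithState) for writing to the table. `SpendingResetModel` and its signature are hypothetical.

```scala
// Plain-Scala sketch of "accumulate, alert above threshold, reset to 0"; not Spark code.
object SpendingResetModel {
  // Returns (new running total, optional (country, spending) row to persist).
  def update(country: String, costs: Iterator[Double], previous: Option[Double],
             threshold: Double): (Double, Option[(String, Double)]) = {
    val total = previous.getOrElse(0.0) + costs.sum
    if (total > threshold) (0.0, Some(country -> total)) // alert and start again from 0
    else (total, None)                                   // keep accumulating
  }
}
```

With a threshold of 150, EN accumulates 90 in the first batch, then crosses the threshold at 190 in the next one, which emits `("EN", 190.0)` and resets the running total to 0.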

Read, update and save cached value atomically

I have multiple streams (N) which should update the same cache, so assume there are at least N threads. Each thread may process values with the same keys. The problem is that if I do the update as follows:
1. Read the old value from the cache (multiple threads get the same old value)
2. Merge the new value with the old value (each thread updates the old value)
3. Save the updated value back to the cache (only the last update is saved; the others are lost)
I can lose some updates if multiple threads try to update the same record simultaneously. At first glance, there is a solution: make all updates atomic, for example by using the Increment mutation in HBase or add in Aerospike (currently I'm considering these caches for my case). If the value consists only of numeric primitive types, this is fine, because both cache implementations support atomic inc/dec:
1. Inc/dec each value (the cache resolves the sequence of these operations by itself)
But what if the value does not consist only of primitives? Then I have to read the value and update it in my code, and in this case I can still lose some updates.
As I wrote, I'm currently considering HBase and Aerospike, but neither fully fits my case. In HBase, as far as I know, there is no way to lock a row from the client side (versions > ~0.98), so I have to use the checkAndPut operation for each complex type. In Aerospike I can achieve something like a row-based lock using Lua UDFs, but I want to avoid them. Redis allows you to watch a record: if it was updated by another client, the transaction fails, and I can catch this error and try again.
So my question is: how can I achieve something like a row-based lock for such updates, and is a row-based lock the correct approach? Maybe there is another approach?
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.streaming.{Duration, StreamingContext}

def main(args: Array[String]): Unit = {
  val sparkConf = new SparkConf().setMaster("local[2]").setAppName("sample")
  val sc = new SparkContext(sparkConf)
  val ssc = new StreamingContext(sc, Duration(500))
  val source = Source()
  val stream = source.stream(ssc)
  stream.foreachRDD(rdd => {
    if (!rdd.isEmpty()) {
      rdd.foreachPartition(partition => {
        if (partition.nonEmpty) {
          val cache = Cache()
          partition.foreach(entity => {
            // In this block, if 2 distributed workers (Spark executors, for example)
            // process entities with the same key, one of the updates can be lost:
            // worker1 and worker2 both read the same value
            val value = cache.get(entity.key)
            // both workers update this value, possibly with different results
            val updatedValue = ??? // some non-trivial update that depends on entity
            // e.g. worker1 puts its new value, then worker2 puts its new value:
            // only worker2's update is visible and worker1's update is lost
            cache.put(entity.key, updatedValue)
          })
        }
      })
    }
  })
  ssc.start()
  ssc.awaitTermination()
}
So, if I use Kafka as the source, I can work around this when messages are partitioned by key: then I can rely on the fact that only one worker processes a particular key at any point in time. But how do I handle the same situation when messages are partitioned randomly (the key is inside the message body)?
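The checkAndPut and Redis WATCH options mentioned above both follow the optimistic-concurrency pattern: read, merge, write only if the stored value is unchanged, otherwise retry. The sketch below models that loop locally, with `java.util.concurrent.ConcurrentHashMap.replace(key, expected, new)` standing in for the store's conditional write. It is not HBase, Aerospike or Redis client code, and `OptimisticUpdate.update` is a hypothetical helper.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.annotation.tailrec

object OptimisticUpdate {
  // Read-merge-compare-and-set loop: retries whenever another writer got in first.
  @tailrec
  def update[K, V <: AnyRef](cache: ConcurrentHashMap[K, V], key: K, initial: V)(merge: V => V): V = {
    val current = cache.get(key)
    if (current == null) {
      val fresh = merge(initial)
      // putIfAbsent returns null only if our insert won the race.
      if (cache.putIfAbsent(key, fresh) == null) fresh
      else update(cache, key, initial)(merge)
    } else {
      val merged = merge(current)
      // replace(key, expected, new) succeeds only if the value is still `current`.
      if (cache.replace(key, current, merged)) merged
      else update(cache, key, initial)(merge)
    }
  }
}
```

The merge function may run more than once per update, so it should be free of side effects. With a remote store, each retry costs a round trip, which is acceptable when contention on a single key is rare.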

How do I limit write operations to 1k records/sec?

Currently, I am able to write to the database in batches of 500. But due to memory shortage and delayed synchronization between the child aggregator and the leaf nodes of the database, I sometimes run into a Leaf Node Memory Error. The only solution I have found is to limit my writes to 1k records per second, which gets rid of the error.
dataStream
  .map(line => readJsonFromString(line))
  .grouped(memsqlBatchSize)
  .foreach { recordSet =>
    val dbRecords = recordSet.map(m => (m, Events.transform(m)))
    // foreach, not map: this loop is only run for its side effects
    dbRecords.foreach { record =>
      try {
        Events.setValues(eventInsert, record._2)
        eventInsert.addBatch
      } catch {
        case e: Exception =>
          logger.error(s"error adding batch: ${e.getMessage}")
          val error_event = Events.jm.writeValueAsString(mapAsJavaMap(record._1.asInstanceOf[Map[String, Object]]))
          logger.error(s"event: $error_event")
      }
    }
    // Bulk commit records
    try {
      eventInsert.executeBatch
    } catch {
      case e: java.sql.BatchUpdateException =>
        val updates = e.getUpdateCounts
        // Array#toString prints an object reference, so format the counts explicitly
        logger.error(s"failed commit: ${updates.mkString(",")}")
        updates.zipWithIndex.filter { case (v, _) => v == Statement.EXECUTE_FAILED }.foreach { case (_, i) =>
          val error = Events.jm.writeValueAsString(mapAsJavaMap(dbRecords(i)._1.asInstanceOf[Map[String, Object]]))
          logger.error(s"insert error: $error")
          logger.error(e.getMessage)
        }
    } finally {
      connection.commit
      eventInsert.clearBatch
      logger.debug(s"committed: ${dbRecords.length.toString}")
    }
  }
The reason for the 1k limit is that some of the data I am trying to write can contain tons of JSON records, so a batch size of 500 may result in 30k records per second. Is there any way to make sure that only 1000 records are written to the database per batch, irrespective of the number of records?
I don't think Thread.sleep is a good idea for this situation. It is generally not recommended in Scala, and we don't want to block the thread in any case.
One suggestion would be to use a streaming library such as Akka Streams or Monix Observable. Each has pros and cons that I won't go into here, but both support back pressure to control the producing rate when the consumer is slower than the producer. In your case, the consumer is the database writer and the producer is probably reading some JSON files and doing some aggregations.
The following code illustrates the idea; you will need to adapt it to your needs:
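import scala.concurrent.duration._
import akka.stream.ThrottleMode
import akka.stream.scaladsl.{Flow, Keep, Sink, Source}

val sourceLines = Source(dataStream) // assumes dataStream is an immutable Iterable[String]
val flowThrottle = Flow[String]
  .throttle(1000, 1.second, 1000, ThrottleMode.Shaping) // at most 1000 elements per second
val sinkDB = Sink.foreach[String] { line =>
  writeToDb(readJsonFromString(line)) // writeToDb is hypothetical: you will need to figure out the DB write
}
val runnable = sourceLines.via(flowThrottle).toMat(sinkDB)(Keep.right)
val result = runnable.run() // needs an implicit Materializer (an implicit ActorSystem in Akka 2.6+)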
The code block is already called by a thread, and there are multiple threads running in parallel. I could use either Thread.sleep(1000) or delay(1.0) in this Scala code, but delay() uses a promise, which might have to be handled outside the function. Thread.sleep() together with a batch size of 1000 looks like the best option. After testing, I benchmarked 120,000 records/thread/sec without any problem.
According to the MemSQL architecture, all loads into MemSQL go first into an in-memory rowstore, and from there MemSQL merges them into the columnstore on the leaf nodes. That caused the leaf error every time I pushed more data, creating a bottleneck. Reducing the batch size and introducing a Thread.sleep() let me write 120,000 records/sec; I tested with this benchmark.
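The batch-size-plus-pause approach above only guarantees 1000 records per batch if each input line has already been expanded into individual records (e.g. with flatMap) before grouping. The sketch below, in plain Scala, groups records into fixed-size batches and starts a new batch no sooner than `intervalMs` after the previous one started. `PacedBatchWriter` is a hypothetical helper; the clock and sleep functions are injected so the pacing can be tested, and in production they would be `System.currentTimeMillis` and `Thread.sleep` (which blocks the calling thread, as the earlier answer warns).

```scala
object PacedBatchWriter {
  // Writes `records` in chunks of at most `batchSize`, starting each chunk no sooner
  // than `intervalMs` after the previous one started. Returns the number of chunks.
  def writePaced[A](records: Iterator[A], batchSize: Int, intervalMs: Long,
                    nowMs: () => Long, sleepMs: Long => Unit)(writeBatch: Seq[A] => Unit): Int = {
    var batches = 0
    var lastStart = 0L
    records.grouped(batchSize).foreach { batch =>
      if (batches > 0) {
        val wait = lastStart + intervalMs - nowMs()
        if (wait > 0) sleepMs(wait) // pause only for the remainder of the interval
      }
      lastStart = nowMs()
      writeBatch(batch) // e.g. addBatch for each record, then executeBatch and commit
      batches += 1
    }
    batches
  }
}
```

With 2500 records, a batch size of 1000 and a 1-second interval, this writes batches of 1000, 1000 and 500; if each write takes 200 ms, it sleeps 800 ms before the second and third batches.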

How does Spark Structured Streaming flush in-memory state when state data is no longer being checked?

I am trying to build a sessionization application with Spark Structured Streaming (version 2.2.0).
When using mapGroupsWithState with Update mode, I understand that the executor will crash with an OOM exception if the state data grows large. Hence, I have to manage memory with the GroupStateTimeout option.
(Ref. How does Spark Structured Streaming handle in-memory state when state data is growing?)
However, I can't check whether a state has timed out and is ready to be removed if no more streaming data arrives for that particular key.
For example, let's say I have the following code.
myDataset
  .groupByKey(_.key)
  .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.EventTimeTimeout)(makeSession)
The makeSession() function checks whether the state has timed out and removes the timed-out state.
Now, let's say the key "foo" already has some stored state in memory, and no new data with the key "foo" is streaming into the application. As a result, makeSession() does not process any data with key "foo" and the stored state is never checked. This means the stored state for key "foo" persists in memory. If there are many keys like "foo", the stored states will never be flushed and the JVM will raise an OOM exception.
I might be misunderstanding mapGroupsWithState, but I suspect my OOM exception is caused by the above issue.
If I am correct, what would be the solution for this case?
I want to flush all stored states that have timed out and receive no more new streaming data.
Is there any good code example?
Now, let's say the key "foo" has some stored state in memory already,
and no new data with the key "foo" is streaming into the application.
As a result, makeSession() does not process the data with key "foo"
and the stored state is not being checked.
This is incorrect. As long as new data arrives for any key (so that a micro-batch runs), Spark makes sure that each batch checks the entire key set and invokes the function one last time for the timed-out keys.
As part of every call to flat/mapGroupsWithState, we have:
val outputIterator =
  updater.updateStateForKeysWithData(filteredIter) ++
    updater.updateStateForTimedOutKeys()
And this is updateStateForTimedOutKeys:
def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
  if (isTimeoutEnabled) {
    val timeoutThreshold = timeoutConf match {
      case ProcessingTimeTimeout => batchTimestampMs.get
      case EventTimeTimeout => eventTimeWatermark.get
      case _ =>
        throw new IllegalStateException(
          s"Cannot filter timed out keys for $timeoutConf")
    }
    val timingOutKeys = store.filter { case (_, stateRow) =>
      val timeoutTimestamp = getTimeoutTimestamp(stateRow)
      timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
    }
    timingOutKeys.flatMap { case (keyRow, stateRow) =>
      callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true)
    }
  } else Iterator.empty
}
The relevant part is the flatMap over the timed-out keys, which invokes the function one last time for each of them with hasTimedOut = true.
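A plain-Scala model of one micro-batch as described above (not Spark code): keys with data are updated first, then every stored key is compared with the watermark, including keys that received no data. `Session`, `runBatch` and the gap-based timeout are hypothetical stand-ins for `makeSession`, which is assumed to call `state.remove()` when `state.hasTimedOut`.

```scala
// Plain-Scala model of update-then-sweep for event-time timeouts; not Spark code.
object SessionSweepModel {
  val NoTimestamp: Long = -1L
  final case class Session(events: Int, timeoutTimestamp: Long)

  // data: key -> latest event time seen for that key in this batch.
  // Returns the new store and the keys whose sessions were closed and removed.
  def runBatch(store: Map[String, Session], data: Map[String, Long],
               watermarkMs: Long, gapMs: Long): (Map[String, Session], List[String]) = {
    // 1. Keys with data: count the update and push the timeout out from the latest event.
    val afterData = data.foldLeft(store) { case (acc, (key, latestEventMs)) =>
      val prev = acc.getOrElse(key, Session(0, NoTimestamp))
      acc.updated(key, Session(prev.events + 1, latestEventMs + gapMs))
    }
    // 2. Sweep all stored keys, not only the ones that had data in this batch.
    val expired = afterData.collect {
      case (key, s) if s.timeoutTimestamp != NoTimestamp && s.timeoutTimestamp < watermarkMs => key
    }.toList.sorted
    (afterData -- expired, expired)
  }
}
```

If "foo" and "bar" both receive events at time 1000 with a gap of 100, and the next batch only contains data for "bar" while the watermark has advanced to 1200, that batch closes and removes "foo" even though no new "foo" data arrived.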

Periodic Broadcast in Apache Spark Streaming

I am implementing a stream learner for text classification. There are some single-valued parameters in my implementation that need to be updated as new stream items arrive. For example, I want to change the learning rate as new predictions are made. However, I doubt there is a way to broadcast variables after the initial broadcast. So what happens if I need to broadcast a variable every time I update it? If there is a way to do it, or a workaround for what I want to accomplish in Spark Streaming, I'd be happy to hear about it.
Thanks in advance.
I got this working by creating a wrapper class over the broadcast variable. The updateAndGet method of the wrapper class returns the refreshed broadcast variable. I call this function inside dStream.transform, as per the Spark documentation:
http://spark.apache.org/docs/latest/streaming-programming-guide.html#transform-operation
The Transform Operation section states:
"the supplied function gets called in every batch interval. This allows you to do time-varying RDD operations, that is, RDD operations, number of partitions, broadcast variables, etc. can be changed between batches."
The BroadcastWrapper class looks like this:
import java.util.Calendar;
import java.util.Date;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class BroadcastWrapper {
    private Broadcast<ReferenceData> broadcastVar;
    private Date lastUpdatedAt = Calendar.getInstance().getTime();
    private static BroadcastWrapper obj = new BroadcastWrapper();

    private BroadcastWrapper() {}

    public static BroadcastWrapper getInstance() {
        return obj;
    }

    public JavaSparkContext getSparkContext(SparkContext sc) {
        return JavaSparkContext.fromSparkContext(sc);
    }

    public Broadcast<ReferenceData> updateAndGet(SparkContext sparkContext) {
        Date currentDate = Calendar.getInstance().getTime();
        long diff = currentDate.getTime() - lastUpdatedAt.getTime();
        if (broadcastVar == null || diff > 60000) { // Let's say we want to refresh every 1 min = 60000 ms
            if (broadcastVar != null)
                broadcastVar.unpersist();
            lastUpdatedAt = new Date(System.currentTimeMillis());
            // Your logic to refresh
            ReferenceData data = getRefData();
            broadcastVar = getSparkContext(sparkContext).broadcast(data);
        }
        return broadcastVar;
    }
}
You can call updateAndGet in the stream.transform method, which allows RDD-to-RDD transformations:
objectStream.transform(rdd -> {
    Broadcast<ReferenceData> ref = BroadcastWrapper.getInstance().updateAndGet(rdd.context());
    /** Your code to manipulate the RDD, reading ref.value() **/
    return rdd;
});
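The refresh rule inside updateAndGet (reload when missing or older than the TTL, releasing the old value first) can be modelled without Spark. In this plain-Scala sketch, `RefreshingHolder` is hypothetical; in the wrapper above, `load` would correspond to broadcasting `getRefData()` and `release` to calling `unpersist()` on the previous broadcast. The clock is injected so the behaviour can be checked deterministically.

```scala
// Plain-Scala model of the time-based refresh in updateAndGet; not Spark code.
final class RefreshingHolder[A](ttlMs: Long, nowMs: () => Long,
                                load: () => A, release: A => Unit) {
  private var current: Option[A] = None
  private var loadedAt: Long = 0L

  // Returns the cached value, reloading it when absent or older than ttlMs.
  def updateAndGet(): A = synchronized {
    val now = nowMs()
    current match {
      case Some(v) if now - loadedAt <= ttlMs => v
      case old =>
        old.foreach(release) // e.g. unpersist the previous broadcast
        val fresh = load()   // e.g. broadcast freshly loaded reference data
        current = Some(fresh)
        loadedAt = now
        fresh
    }
  }
}
```

As in the Java version, a value exactly 60000 ms old is still reused, and it is only replaced once the age exceeds the TTL.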
Refer to my full answer in this post: https://stackoverflow.com/a/41259333/3166245
Hope it helps
My understanding is that once a broadcast variable has been sent out, it is read-only. I believe you can update the broadcast variable on the local node, but not on remote nodes.
Maybe you need to consider doing this 'outside Spark'. How about using a NoSQL store (Cassandra, etc.) or even Memcache? You could then update the variable from one task and periodically check the store from other tasks.
I found an ugly hack, but it worked!
We can see how a broadcast value is obtained from a broadcast object using just the broadcast id: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala#L114
So I periodically rebroadcast using the same broadcast id.
// TorrentBroadcastFactory is private[spark], so this code has to live in a package under org.apache.spark
val broadcastFactory = new TorrentBroadcastFactory()
broadcastFactory.unbroadcast(BroadcastId, true, true)
// append some ids to initIds
val broadcastcontent = broadcastFactory.newBroadcast[Set[String]](initIds, false, BroadcastId)
And I get BroadcastId from the first broadcast value:
val ids = ssc.sparkContext.broadcast(initIds)
// broadcast id
val BroadcastId = ids.id
Then the workers use ids as a normal Broadcast.
def func(record: Array[Byte], bc: Broadcast[Set[String]]) = ???
bkc.unpersist(true)
bkc.destroy()
bkc = sc.broadcast(tableResultMap)
bkv = bkc.value
You may try this; I can't guarantee that it works.
It is best to collect the data to the driver and then broadcast it to all nodes.
Use DStream#foreachRDD to collect the computed RDDs at the driver, and once you know you need to change the learning rate, use SparkContext#broadcast(value) to send the new value to all nodes.
I would expect the code to look something like the following:
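dStreamContainingBroadcastValue.foreachRDD { rdd =>
  val valueToBroadcast = rdd.collect()
  // keep the returned Broadcast (for example in a driver-side var) so that later batches can use it
  sc.broadcast(valueToBroadcast)
}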
You may also find this thread from the Spark user mailing list useful. Let me know if that works.
