Periodic Broadcast in Apache Spark Streaming - apache-spark

I am implementing a stream learner for text classification. There are some single-valued parameters in my implementation that need to be updated as new stream items arrive. For example, I want to change the learning rate as new predictions are made. However, I doubt that there is a way to broadcast variables after the initial broadcast. So what happens if I need to broadcast a variable every time I update it? If there is a way to do it, or a workaround for what I want to accomplish in Spark Streaming, I'd be happy to hear about it.
Thanks in advance.

I got this working by creating a wrapper class around the broadcast variable. The wrapper's updateAndGet method returns the refreshed broadcast variable. I call this method inside dStream.transform, as described in the Spark documentation:
http://spark.apache.org/docs/latest/streaming-programming-guide.html#transform-operation
The Transform Operation section states:
"the supplied function gets called in every batch interval. This allows you to do time-varying RDD operations, that is, RDD operations, number of partitions, broadcast variables, etc. can be changed between batches."
The BroadcastWrapper class looks like this:
import java.util.Calendar;
import java.util.Date;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class BroadcastWrapper {
    private Broadcast<ReferenceData> broadcastVar;
    private Date lastUpdatedAt = Calendar.getInstance().getTime();
    private static BroadcastWrapper obj = new BroadcastWrapper();

    private BroadcastWrapper() {}

    public static BroadcastWrapper getInstance() {
        return obj;
    }

    public JavaSparkContext getSparkContext(SparkContext sc) {
        return JavaSparkContext.fromSparkContext(sc);
    }

    public Broadcast<ReferenceData> updateAndGet(SparkContext sparkContext) {
        Date currentDate = Calendar.getInstance().getTime();
        long diff = currentDate.getTime() - lastUpdatedAt.getTime();
        // Refresh on first use, or when the value is older than 1 min = 60000 ms
        if (broadcastVar == null || diff > 60000) {
            if (broadcastVar != null)
                broadcastVar.unpersist();
            lastUpdatedAt = new Date(System.currentTimeMillis());
            // Your logic to refresh
            ReferenceData data = getRefData();
            broadcastVar = getSparkContext(sparkContext).broadcast(data);
        }
        return broadcastVar;
    }
}
You can call updateAndGet inside the stream.transform method, which allows RDD-to-RDD transformations:
objectStream.transform(rdd -> {
    Broadcast<ReferenceData> broadcastVar = BroadcastWrapper.getInstance().updateAndGet(rdd.context());
    /** Your code to manipulate the RDD using broadcastVar.value(); return the resulting RDD **/
});
Refer to my full answer in this post: https://stackoverflow.com/a/41259333/3166245
Hope it helps
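The refresh-if-stale logic of the wrapper can be tried out without a cluster. The following is only a minimal sketch in plain Java, not Spark code: RefreshingHolder, its loader and its clock are hypothetical names introduced here, with the loader standing in for "fetch the reference data and broadcast it".

```java
import java.util.function.LongSupplier;
import java.util.function.Supplier;

/** Holds a value and reloads it once it is older than ttlMillis. */
final class RefreshingHolder<T> {
    private final Supplier<T> loader;   // hypothetical: stands in for getRefData() + broadcast
    private final LongSupplier clock;   // injectable clock, so the logic can be tested without sleeping
    private final long ttlMillis;
    private T current;
    private long lastUpdatedAt;
    private int reloads;

    RefreshingHolder(Supplier<T> loader, LongSupplier clock, long ttlMillis) {
        this.loader = loader;
        this.clock = clock;
        this.ttlMillis = ttlMillis;
    }

    /** Returns the current value, reloading it first if it is missing or stale. */
    synchronized T updateAndGet() {
        long now = clock.getAsLong();
        if (current == null || now - lastUpdatedAt > ttlMillis) {
            current = loader.get();
            lastUpdatedAt = now;
            reloads++;
        }
        return current;
    }

    /** Number of times the loader has been invoked so far. */
    synchronized int reloadCount() {
        return reloads;
    }
}
```

With a 60000 ms TTL, calls made within the same minute reuse the held value, and the first call after the minute has passed triggers exactly one reload. The method is synchronized because, in the Spark case, the function passed to transform runs on the driver once per batch and could overlap with other driver-side callers.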

My understanding is that once a broadcast variable is initially sent out, it is read-only. I believe you can update the broadcast variable on the local node, but not on remote nodes.
Maybe you need to consider doing this outside Spark. How about using a NoSQL store (Cassandra etc.) or even Memcached? You can then update the variable from one task and periodically check this store from other tasks.

I got an ugly hack, but it worked!
The source shows how a broadcast value is looked up from a broadcast object: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala#L114
It is looked up just by broadcast id, so I periodically rebroadcast under the same broadcast id.
val broadcastFactory = new TorrentBroadcastFactory()
broadcastFactory.unbroadcast(BroadcastId, true, true)
// append some ids to initIds
val broadcastcontent = broadcastFactory.newBroadcast[Set[String]](initIds, false, BroadcastId)
I can get BroadcastId from the first broadcast value:
val ids = ssc.sparkContext.broadcast(initIds)
// broadcast id
val BroadcastId = ids.id
Then the workers use ids as a normal Broadcast:
def func(record: Array[Byte], bc: Broadcast[Set[String]]) = ???

bkc.unpersist(true)
bkc.destroy()
bkc = sc.broadcast(tableResultMap)
bkv = bkc.value
You may try this; I can't guarantee that it is effective.

It is best to collect the data to the driver and then broadcast it to all nodes.
Use DStream#foreachRDD to collect the computed RDDs at the driver and, once you know you need to change the learning rate, use SparkContext#broadcast(value) to send the new value to all nodes.
I would expect the code to look something like the following:
dStreamContainingBroadcastValue.foreachRDD{ rdd =>
val valueToBroadcast = rdd.collect()
sc.broadcast(valueToBroadcast)
}
You may also find this thread useful, from the spark user mailing list. Let me know if that works.

Related

How do you transform a `FixedSqlAction` into a `StreamingDBIO` in Slick?

I'm creating an akka-stream using Alpakka and the Slick module but I'm stuck in a type mismatch problem.
One branch is about getting the total number of invoices in their table:
def getTotal(implicit session: SlickSession) = {
import session.profile.api._
val query = TableQuery[Tables.Invoice].length.result
Slick.source(query)
}
But the end line doesn't compile because Alpakka is expecting a StreamingDBIO but I'm providing a FixedSqlAction[Int,slick.dbio.NoStream,slick.dbio.Effect.Read].
How can I move from the non-streaming result to the streaming one?
Taking the length of a table results in a single value, not a stream. So the simplest way to get a Source to feed a stream is
def getTotal(implicit session: SlickSession): Source[Int, NotUsed] =
Source.lazyFuture { () =>
// Don't actually run the query until the stream has materialized and
// demand has reached the source
val query = TableQuery[Tables.Invoice].length.result
session.db.run(query)
}
Alpakka's Slick connector is more oriented towards streaming (including managing pagination etc.) results of queries that have a lot of results. For a single result, converting the Future of the result that vanilla Slick gives you into a stream is sufficient.
If you want to start executing the query as soon as you call getTotal (note that this happens whether or not the downstream ever runs or demands data from the source), you can have
def getTotal(implicit session: SlickSession): Source[Int, NotUsed] = {
val query = TableQuery[Tables.Invoice].length.result
Source.future(session.db.run(query))
}
Would something like this work for you?
def getTotal() = {
// Doc Expressions (Scalar values)
// https://scala-slick.org/doc/3.2.0/queries.html
val query = TableQuery[Tables.Invoice].length.result
val res = Await.result(session.db.run(query), 60.seconds)
println(s"Result: $res")
res
}

NoNodeAvailableException after some insert request to cassandra

I am trying to insert data into a local Cassandra cluster using async execution and version 4 of the driver (the same version as my Cassandra instance).
I have instantiated the cql session in this way:
CqlSession cqlSession = CqlSession.builder()
.addContactEndPoint(new DefaultEndPoint(
InetSocketAddress.createUnresolved("localhost",9042))).build();
Then I create a statement in an async way:
return session.prepareAsync(
    "insert into table (p1,p2,p3, p4) values (?, ?,?, ?)")
  .thenComposeAsync(
    (ps) -> {
      // One executeAsync call per element; nothing limits how many are in flight at once
      CompletableFuture<AsyncResultSet>[] result = data.stream().map(
        (d) -> session.executeAsync(
          ps.bind(d.p1,d.p2,d.p3,d.p4)
        ).toCompletableFuture()
      ).toArray(CompletableFuture[]::new);
      return CompletableFuture.allOf(result);
    }
  );
data is a dynamic list filled with user data.
When I exec the code I get the following exception:
Caused by: com.datastax.oss.driver.api.core.NoNodeAvailableException: No node was available to execute the query
at com.datastax.oss.driver.api.core.AllNodesFailedException.fromErrors(AllNodesFailedException.java:53)
at com.datastax.oss.driver.internal.core.cql.CqlPrepareHandler.sendRequest(CqlPrepareHandler.java:210)
at com.datastax.oss.driver.internal.core.cql.CqlPrepareHandler.onThrottleReady(CqlPrepareHandler.java:167)
at com.datastax.oss.driver.internal.core.session.throttling.PassThroughRequestThrottler.register(PassThroughRequestThrottler.java:52)
at com.datastax.oss.driver.internal.core.cql.CqlPrepareHandler.<init>(CqlPrepareHandler.java:153)
at com.datastax.oss.driver.internal.core.cql.CqlPrepareAsyncProcessor.process(CqlPrepareAsyncProcessor.java:66)
at com.datastax.oss.driver.internal.core.cql.CqlPrepareAsyncProcessor.process(CqlPrepareAsyncProcessor.java:33)
at com.datastax.oss.driver.internal.core.session.DefaultSession.execute(DefaultSession.java:210)
at com.datastax.oss.driver.api.core.cql.AsyncCqlSession.prepareAsync(AsyncCqlSession.java:90)
The node is active, and some data is inserted before the exception is raised. I have also tried to set a data center name on the session builder, without any result.
Why is this exception raised if the node is up and running? I actually have only one local node; could that be the problem?
The biggest thing that I don't see is a way to limit the number of active async requests.
If that (mapped) data stream gets hit hard enough, it will issue all of these new requests at once and then await them. If the writes coming in from those requests create more back-pressure than the node can keep up with or catch up to, the node will become overwhelmed and stop accepting requests.
Take a look at this post by Ryan Svihla of DataStax:
Cassandra: Batch Loading Without the Batch — The Nuanced Edition
Its code is from the 3.x version of the driver, but the concepts are the same. Basically, provide some way to throttle-down the writes, or limit the number of "in flight threads" running at any given time, and that should help greatly.
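One common way to cap the number of in-flight requests is a counting semaphore. The following is a minimal sketch in plain Java under stated assumptions: ThrottledAsync and runAll are hypothetical names, and asyncCall stands in for something like item -> session.executeAsync(...).toCompletableFuture(). It is not taken from the driver or from the linked post.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Semaphore;
import java.util.function.Function;

final class ThrottledAsync {
    /**
     * Starts one async call per item, but never lets more than maxInFlight
     * calls be outstanding at the same time.
     */
    static <T, R> CompletableFuture<Void> runAll(
            List<T> items, int maxInFlight, Function<T, CompletableFuture<R>> asyncCall) {
        Semaphore permits = new Semaphore(maxInFlight);
        List<CompletableFuture<R>> futures = new ArrayList<>();
        for (T item : items) {
            // Blocks the submitting thread while maxInFlight calls are pending
            permits.acquireUninterruptibly();
            CompletableFuture<R> future;
            try {
                future = asyncCall.apply(item);
            } catch (RuntimeException e) {
                permits.release(); // the call never started, so give the permit back
                throw e;
            }
            // Free the permit when the call completes, successfully or not
            future.whenComplete((result, error) -> permits.release());
            futures.add(future);
        }
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
    }
}
```

The submitting thread blocks in acquireUninterruptibly once the limit is reached, so this should be called from an application thread and not from a driver I/O thread. The limit itself (for example 128) is a tuning value to find by experiment.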
Finally, I found a solution using BatchStatement and a little custom code to create a chunked list.
final int chunkSize = 100; // statements per batch
return session.prepareAsync(
    "insert into table (p1,p2,p3, p4) values (?, ?,?, ?)")
  .thenComposeAsync(
    (ps) -> {
      AtomicInteger counter = new AtomicInteger();
      final List<CompletionStage<AsyncResultSet>> batchInsert = data.stream()
        .map(
          (d) -> ps.bind(d.p1,d.p2,d.p3,d.p4)
        )
        // Integer division by the chunk size puts consecutive statements into groups of up to 100
        .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / chunkSize))
        .values().stream().map(
          boundedStatements -> BatchStatement.newInstance(BatchType.LOGGED, boundedStatements.toArray(new BatchableStatement[0]))
        ).map(
          session::executeAsync
        ).collect(Collectors.toList());
      return CompletableFutures.allSuccessful(batchInsert);
    }
  );
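The chunking step can also be written as a small standalone helper, which makes the arithmetic easier to check. This is a sketch in plain Java; Chunks, chunkCount and partition are hypothetical names introduced here and are not part of the driver.

```java
import java.util.ArrayList;
import java.util.List;

final class Chunks {
    /** Number of chunks needed for size items, i.e. ceil(size / chunkSize). */
    static int chunkCount(int size, int chunkSize) {
        return (size + chunkSize - 1) / chunkSize;
    }

    /** Splits items into consecutive chunks of at most chunkSize elements. */
    static <T> List<List<T>> partition(List<T> items, int chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive");
        }
        List<List<T>> chunks = new ArrayList<>();
        for (int start = 0; start < items.size(); start += chunkSize) {
            int end = Math.min(start + chunkSize, items.size());
            // Copy the view so each chunk is independent of the source list
            chunks.add(new ArrayList<>(items.subList(start, end)));
        }
        return chunks;
    }
}
```

For 250 items and a chunk size of 100 this gives three chunks of 100, 100 and 50 elements. Each chunk could then be turned into one batch.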

Spark arbitrary stateful stream aggregation, flatMapGroupsWithState API

I'm ten days into Spark development and trying to understand the flatMapGroupsWithState API.
As I understand it:
We pass 2 options to it. One is the timeout configuration; a possible value is GroupStateTimeout.ProcessingTimeTimeout, i.e. an instruction to Spark to consider processing time and not event time. The other is the output mode.
We pass in a function, let's say myFunction, that is responsible for setting the state for each key. We also set a timeout duration with groupState.setTimeoutDuration(TimeUnit.HOURS.toMillis(4)), assuming groupState is the GroupState instance for a key.
As micro-batches of stream data keep coming in, Spark maintains an intermediate state as defined in the user-defined function. Let's say the intermediate state after processing n micro-batches of data is as follows:
State for Key1:
{
key1: [v1, v2, v3, v4, v5]
}
State for key2:
{
key2: [v11, v12, v13, v14, v15]
}
For any new data that come in, myFunction is called with state for the particular key. Eg. for key1, myFunction is called with key1, new key1 values, [v1,v2,v3,v4,v5] and it updates the key1 state as per the logic.
I read about the timeout and I found Timeout dictates how long we should wait before timing out some intermediate state.
Questions:
If this process runs indefinitely, my intermediate states will keep piling up and hit the memory limits on the nodes. So when are these intermediate states cleared? I found that in the case of event-time aggregation, watermarks dictate when the intermediate states are cleared.
What does timing out the intermediate state mean in the context of processing time?
If this process runs indefinitely, my intermediate states will keep piling up and hit the memory limits on the nodes. So when are these intermediate states cleared? I found that in the case of event-time aggregation, watermarks dictate when the intermediate states are cleared.
Apache Spark will mark them as expired after the expiration time, so in your example after 4 hours of inactivity (real time + 4 hours, inactivity = no new event updating the state).
What does timing out the intermediate state mean in the context of processing time?
It means that it will time out according to the real clock (processing time, org.apache.spark.util.SystemClock class). You can check which clock is currently used by analyzing the triggerClock parameter of org.apache.spark.sql.streaming.StreamingQueryManager#startQuery.
You will find more details in FlatMapGroupsWithStateExec class, and more particularly here:
// Generate a iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator =
processor.processNewData(filteredIter) ++ processor.processTimedOutState()
And if you analyze these 2 methods, you will see that:
processNewData applies mapping function to all active keys (present in the micro-batch)
/**
* For every group, get the key, values and corresponding state and call the function,
* and return an iterator of rows
*/
def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
groupedIter.flatMap { case (keyRow, valueRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
hasTimedOut = false)
}
}
processTimedOutState calls the mapping function on all expired states
def processTimedOutState(): Iterator[InternalRow] = {
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
case EventTimeTimeout => eventTimeWatermark.get
case _ =>
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
val timingOutPairs = stateManager.getAllState(store).filter { state =>
state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
}
timingOutPairs.flatMap { stateData =>
callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
}
} else Iterator.empty
}
An important point to notice here is that Apache Spark will keep expired state in the state store if you don't invoke GroupState#remove method. The expired states won't be returned for processing though because they're flagged with NO_TIMESTAMP field. However, they will be stored in the state store delta files which may slow down the reprocessing if you need to reload the most recent state. If you analyze FlatMapGroupsWithStateExec again, you will see that the state is removed only when the state removed flag is set to true:
def callFunctionAndUpdateState(...)
// ...
// When the iterator is consumed, then write changes to state
def onIteratorCompletion: Unit = {
if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
stateManager.removeState(store, stateData.keyRow)
numUpdatedStateRows += 1
} else {
val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged
if (shouldWriteState) {
val updatedStateObj = if (groupState.exists) groupState.get else null
stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
numUpdatedStateRows += 1
}
}
}
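The filter in processTimedOutState can be mimicked in a few lines of plain Java to see which keys would be handed to the function with hasTimedOut = true. This is only an illustrative sketch: TimeoutSketch and timedOutKeys are hypothetical names, the map stands in for the state store, and the value -1 chosen here for NO_TIMESTAMP is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final class TimeoutSketch {
    // Assumed sentinel meaning "no timeout set for this key"
    static final long NO_TIMESTAMP = -1L;

    /**
     * Returns the keys whose timeout timestamp is set and lies strictly before
     * the threshold (the batch timestamp for processing time, or the watermark
     * for event time).
     */
    static List<String> timedOutKeys(Map<String, Long> timeoutByKey, long threshold) {
        List<String> expired = new ArrayList<>();
        for (Map.Entry<String, Long> entry : timeoutByKey.entrySet()) {
            long timeoutTimestamp = entry.getValue();
            if (timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < threshold) {
                expired.add(entry.getKey());
            }
        }
        return expired;
    }
}
```

A key without a timeout is never selected, and a key whose timeout lies in the future is left alone until a later batch. Whether the state is then actually deleted still depends on the function calling remove, as described above.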

Read, update and save cached value atomically

I have multiple streams (N) which should update the same cache. So assume that there are at least N threads. Each thread may process values with the same keys. The problem is that if I do the update as follows:
1. Read old value from cache (multiple threads get the same old value)
2. Merge new value with old value (each thread updates the old value)
3. Save updated value back to the cache (only the last update is saved, the others are lost)
I can lose some updates if multiple threads simultaneously try to update the same record. At first glance, there is a solution: make all updates atomic, for example with the Increment mutation in HBase or add in Aerospike (currently, I'm considering these caches for my case). If the value consists only of numeric primitive types, then that is fine, because both cache implementations support atomic inc/dec.
1. Inc/dec each value (the cache resolves the sequence of these ops by itself)
But what if the value consists not only of primitives? Then I have to read the value and update it in my code. In this case I can still lose some updates.
As I wrote, I'm currently considering HBase and Aerospike, but neither fully fits my case. In HBase, as far as I know, there is no way to lock a row from the client side (> ~0.98), so I have to use the checkAndPut operation for each complex type. In Aerospike I can achieve something like a row-based lock using Lua UDFs, but I want to avoid them. Redis allows me to watch a record; if there was an update from another thread, the transaction fails, and I can catch this error and try again.
So, my question is how to achieve something like a row-based lock for such updates, and is a row-based lock the correct way? Maybe there is another approach?
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf().setMaster("local[2]").setAppName("sample")
val sc = new SparkContext(sparkConf)
val ssc = new StreamingContext(sc, Duration(500))
val source = Source()
val stream = source.stream(ssc)
stream.foreachRDD(rdd => {
if (!rdd.isEmpty()) {
rdd.foreachPartition(partition => {
if (partition.nonEmpty) {
val cache = Cache()
partition.foreach(entity=> {
// in this block if 2 distributed workers (in case of apache spark, for example)
//will process entities with the same keys i can lose one of this update
// worker1 and worker2 will get the same value
val value = cache.get(entity.key)
// both workers will update this value but may get different results
val updatedValue = ??? // some non-trivial update depends on entity
// for example, worker1 put new value, then worker2 put new value. In this case only updates from worker2 are visible and updates from worker1 are lost
cache.put(entity.key, updatedValue)
})
}
})
}
})
ssc.start()
ssc.awaitTermination()
}
So, if I use Kafka as the source, I can work around this when messages are partitioned by key. In that case I can rely on the fact that only 1 worker will process a particular record at any point in time. But how do I handle the same situation when messages are partitioned randomly (the key is inside the message body)?
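The check-and-put or watch-and-retry idea mentioned in the question amounts to an optimistic compare-and-set loop. Here is a minimal sketch in plain Java, with an in-memory ConcurrentMap standing in for the remote cache. OptimisticUpdate and mergeWithRetry are hypothetical names, and a real store would use its own conditional write (for example a check-and-put) in place of replace.

```java
import java.util.concurrent.ConcurrentMap;
import java.util.function.BinaryOperator;

final class OptimisticUpdate {
    /**
     * Merges delta into the value stored under key without taking a lock.
     * The write succeeds only if the value is still the one that was read;
     * otherwise the read-merge-write cycle is repeated.
     * The merge function must return a new value and must not mutate its inputs.
     */
    static <K, V> V mergeWithRetry(ConcurrentMap<K, V> cache, K key, V delta, BinaryOperator<V> merge) {
        while (true) {
            V old = cache.get(key);
            if (old == null) {
                // No value yet: insert only if nobody else inserted in the meantime
                if (cache.putIfAbsent(key, delta) == null) {
                    return delta;
                }
            } else {
                V updated = merge.apply(old, delta);
                // Conditional write: fails if another writer changed the value
                if (cache.replace(key, old, updated)) {
                    return updated;
                }
            }
            // Lost the race; loop and try again with the fresh value
        }
    }
}
```

No update is lost, because a writer that loses the race re-reads and merges again. The trade-off is that heavily contended keys cause many retries, which is where partitioning by key becomes attractive.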

How to create Spark broadcast variable from Java String array?

I have a Java String array which contains 45 strings, which are basically column names:
String[] fieldNames = {"colname1","colname2",...};
Currently I am storing the above String array in a static field in the Spark driver. My job is running slowly, so I am trying to refactor the code. I use the above String array when creating a DataFrame:
DataFrame dfWithColNames = sourceFrame.toDF(fieldNames);
I want to do the above using a broadcast variable so that it doesn't ship the huge String array to every executor. I believe we can do something like the following to create the broadcast:
String[] brArray = sc.broadcast(fieldNames,String[].class);//gives compilation error
DataFrame df = sourceFrame.toDF(???);//how do I use the broadcast above? Can I pass brArray as is?
I am new to Spark.
This is a rather old question; however, I hope my solution will help somebody.
In order to broadcast any object (a single POJO or a collection) with Spark 2+, you first need the following method, which creates a ClassTag for you:
private static <T> ClassTag<T> classTag(Class<T> clazz) {
return scala.reflect.ClassManifestFactory.fromClass(clazz);
}
Next, you use the SparkContext obtained from a SparkSession to broadcast your object:
sparkSession.sparkContext().broadcast(
yourObject,
classTag(YourObject.class)
)
In case of a collection, say, java.util.List, you use the following:
sparkSession.sparkContext().broadcast(
yourObject,
classTag(List.class)
)
sc.broadcast returns a Broadcast<String[]>, not a String[]. To access the value, simply call value() on the variable. For your example it would be:
Broadcast<String[]> broadcastedFieldNames = sc.broadcast(fieldNames);
DataFrame df = sourceFrame.toDF(broadcastedFieldNames.value());
Note that if you are writing this in Java, you probably want to wrap the SparkContext in a JavaSparkContext. It makes everything easier, and you can then avoid having to pass a ClassTag to the broadcast function.
You can read more on broadcasting variables on http://spark.apache.org/docs/latest/programming-guide.html#broadcast-variables
ArrayList<String> dataToBroadcast = new ArrayList<>();
dataToBroadcast.add("string1");
...
dataToBroadcast.add("stringn");
// Creating the broadcast variable.
// No need to write the classTag code by hand; use akka.japi.Util, which is available
Broadcast<ArrayList<String>> strngBrdCast = spark.sparkContext().broadcast(
    dataToBroadcast,
    akka.japi.Util.classTag(ArrayList.class));
// Here is the catch. When you iterate over a Dataset,
// Spark actually runs it in distributed mode. So if you try to access
// your object directly (e.g. dataToBroadcast), it would be null,
// because you didn't explicitly ask Spark to send that outside variable to each
// machine on which the foreach runs in parallel.
// So you need to use a broadcast variable (the most common use of Broadcast).
someSparkDataSetWhere.foreach((row) -> {
    ArrayList<String> stringlist = strngBrdCast.value();
    ...
    ...
})
