I've been playing around with Spark Structured Streaming and mapGroupsWithState (specifically following the StructuredSessionization example in the Spark source). I want to confirm some limitations I believe exist with mapGroupsWithState, given my use case.
A session, for my purposes, is a group of uninterrupted activity for a user such that no two consecutive events (ordered by event time, not processing time) are separated by more than some developer-defined duration (30 minutes is common).
An example will help before jumping into code:
{"event_time": "2018-01-01T00:00:00", "user_id": "mike"}
{"event_time": "2018-01-01T00:01:00", "user_id": "mike"}
{"event_time": "2018-01-01T00:05:00", "user_id": "mike"}
{"event_time": "2018-01-01T00:45:00", "user_id": "mike"}
For the stream above, a session is defined by a 30-minute period of inactivity. In a streaming context, we should end up with one completed session (the second has yet to complete):
[
{
"user_id": "mike",
"startTimestamp": "2018-01-01T00:00:00",
"endTimestamp": "2018-01-01T00:05:00"
}
]
Now consider the following Spark driver program:
import java.sql.Timestamp
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
object StructuredSessionizationV2 {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.master("local[2]")
.appName("StructredSessionizationRedux")
.getOrCreate()
spark.sparkContext.setLogLevel("WARN")
import spark.implicits._
implicit val ctx = spark.sqlContext
val input = MemoryStream[String]
val EVENT_SCHEMA = new StructType()
.add($"event_time".string)
.add($"user_id".string)
val events = input.toDS()
.select(from_json($"value", EVENT_SCHEMA).alias("json"))
.select($"json.*")
.withColumn("event_time", to_timestamp($"event_time"))
.withWatermark("event_time", "1 hours")
events.printSchema()
val sessionized = events
.groupByKey(row => row.getAs[String]("user_id"))
.mapGroupsWithState[SessionState, SessionOutput](GroupStateTimeout.EventTimeTimeout) {
case (userId: String, events: Iterator[Row], state: GroupState[SessionState]) =>
println(s"state update for user ${userId} (current watermark: ${new Timestamp(state.getCurrentWatermarkMs())})")
if (state.hasTimedOut) {
println(s"User ${userId} has timed out, sending final output.")
val finalOutput = SessionOutput(
userId = userId,
startTimestampMs = state.get.startTimestampMs,
endTimestampMs = state.get.endTimestampMs,
durationMs = state.get.durationMs,
expired = true
)
// Drop this user's state
state.remove()
finalOutput
} else {
val timestamps = events.map(_.getAs[Timestamp]("event_time").getTime).toSeq
println(s"User ${userId} has new events (min: ${new Timestamp(timestamps.min)}, max: ${new Timestamp(timestamps.max)}).")
val newState = if (state.exists) {
println(s"User ${userId} has existing state.")
val oldState = state.get
SessionState(
startTimestampMs = math.min(oldState.startTimestampMs, timestamps.min),
endTimestampMs = math.max(oldState.endTimestampMs, timestamps.max)
)
} else {
println(s"User ${userId} has no existing state.")
SessionState(
startTimestampMs = timestamps.min,
endTimestampMs = timestamps.max
)
}
state.update(newState)
state.setTimeoutTimestamp(newState.endTimestampMs, "30 minutes")
println(s"User ${userId} state updated. Timeout now set to ${new Timestamp(newState.endTimestampMs + (30 * 60 * 1000))}")
SessionOutput(
userId = userId,
startTimestampMs = state.get.startTimestampMs,
endTimestampMs = state.get.endTimestampMs,
durationMs = state.get.durationMs,
expired = false
)
}
}
val eventsQuery = sessionized
.writeStream
.queryName("events")
.outputMode("update")
.format("console")
.start()
input.addData(
"""{"event_time": "2018-01-01T00:00:00", "user_id": "mike"}""",
"""{"event_time": "2018-01-01T00:01:00", "user_id": "mike"}""",
"""{"event_time": "2018-01-01T00:05:00", "user_id": "mike"}"""
)
input.addData(
"""{"event_time": "2018-01-01T00:45:00", "user_id": "mike"}"""
)
eventsQuery.processAllAvailable()
}
case class SessionState(startTimestampMs: Long, endTimestampMs: Long) {
def durationMs: Long = endTimestampMs - startTimestampMs
}
case class SessionOutput(userId: String, startTimestampMs: Long, endTimestampMs: Long, durationMs: Long, expired: Boolean)
}
Output of that program is:
root
|-- event_time: timestamp (nullable = true)
|-- user_id: string (nullable = true)
state update for user mike (current watermark: 1969-12-31 19:00:00.0)
User mike has new events (min: 2018-01-01 00:00:00.0, max: 2018-01-01 00:05:00.0).
User mike has no existing state.
User mike state updated. Timeout now set to 2018-01-01 00:35:00.0
-------------------------------------------
Batch: 0
-------------------------------------------
+------+----------------+--------------+----------+-------+
|userId|startTimestampMs|endTimestampMs|durationMs|expired|
+------+----------------+--------------+----------+-------+
| mike| 1514782800000| 1514783100000| 300000| false|
+------+----------------+--------------+----------+-------+
state update for user mike (current watermark: 2017-12-31 23:05:00.0)
User mike has new events (min: 2018-01-01 00:45:00.0, max: 2018-01-01 00:45:00.0).
User mike has existing state.
User mike state updated. Timeout now set to 2018-01-01 01:15:00.0
-------------------------------------------
Batch: 1
-------------------------------------------
+------+----------------+--------------+----------+-------+
|userId|startTimestampMs|endTimestampMs|durationMs|expired|
+------+----------------+--------------+----------+-------+
| mike| 1514782800000| 1514785500000| 2700000| false|
+------+----------------+--------------+----------+-------+
Given my session definition, the single event in the second batch should trigger an expiry of the session state and thus start a new session. However, since the watermark (2017-12-31 23:05:00.0) has not passed the state's timeout (2018-01-01 00:35:00.0), the state isn't expired, and the event is erroneously added to the existing session even though more than 30 minutes have passed since the latest timestamp in the previous batch.
I think the only way for session state expiration to work as I'm hoping is if enough events from different users were received within the batch to advance the watermark past the state timeout for mike.
I suppose one could also mess with the stream's watermark, but I can't think of how I'd do that to accomplish my use case.
Is this accurate? Am I missing anything in how to properly do event time-based sessionization in Spark?
The implementation you have provided does not seem to work if the watermark interval is greater than session gap duration.
For the logic you have shown to work, you need to set the watermark interval to less than 30 minutes.
If you really want the watermark interval to be independent of (or more than) the session gap duration, you need to wait until the watermark passes (watermark + gap) to expire the state. The merging logic seems to blindly merge the windows. This should take the gap duration into account before merging.
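The gap-aware merging suggested above can be illustrated outside Spark. A minimal sketch, assuming event times are epoch milliseconds; `GapSessions` and `split` are hypothetical names, not part of Spark or of the answer's code. Applied to the four events from the question, it yields two sessions instead of one:

```scala
// Hypothetical helper: splits event times (epoch millis) into sessions,
// starting a new session whenever two consecutive events are more than gapMs apart.
object GapSessions {
  final case class Session(startMs: Long, endMs: Long)

  def split(timestamps: Seq[Long], gapMs: Long): List[Session] = {
    // Events are not guaranteed to arrive in order, so sort them first.
    val sorted = timestamps.sorted
    sorted.foldLeft(List.empty[Session]) {
      case (current :: closed, ts) if ts - current.endMs <= gapMs =>
        current.copy(endMs = ts) :: closed // within the gap: extend the open session
      case (acc, ts) =>
        Session(ts, ts) :: acc // first event, or gap exceeded: open a new session
    }.reverse
  }

  def main(args: Array[String]): Unit = {
    val minute = 60 * 1000L
    // Offsets (in minutes) of the question's events: 00:00, 00:01, 00:05, 00:45
    val events = Seq(0L, 1L, 5L, 45L).map(_ * minute)
    println(split(events, gapMs = 30 * minute))
  }
}
```

Because the timestamps are sorted first, the result does not depend on the order in which events arrive within a batch. A real state function would also have to merge these sessions with the ones already held in state.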
EDIT: I think I should answer the specific original question instead of providing a full solution.
To add to Arun’s answer: the state function of map/flatMapGroupsWithState is called with new events first, and only then with timed-out states. Because of this ordering, your code resets the timeout even though the state should time out in this batch.
So while you can leverage the timeout feature to get the state function called even when the events don’t contain that key, you still need to deal with the current watermark manually. That’s why I set the timeout to the earliest session’s end timestamp and handle all evictions whenever the function is called.
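A minimal sketch of the eviction step described above, in plain Scala. The names `SessionEviction` and `evict` are hypothetical, and the closing rule (a session ends once the watermark passes its last event plus the gap) is an assumption that mirrors Spark's `timeoutTimestamp < watermark` check. Inside a real state function the watermark would come from `state.getCurrentWatermarkMs()` and the returned timeout would be passed to `state.setTimeoutTimestamp(...)`:

```scala
// Hypothetical sketch: a state function evicting closed sessions itself,
// using the current watermark rather than relying on hasTimedOut alone.
object SessionEviction {
  final case class Session(startMs: Long, lastEventMs: Long)

  /** Splits sessions into (closed, open) and returns the next timeout to set. */
  def evict(sessions: List[Session], watermarkMs: Long, gapMs: Long)
      : (List[Session], List[Session], Option[Long]) = {
    // Assumed closing rule: the watermark has passed lastEvent + gap.
    val (closed, open) = sessions.partition(s => s.lastEventMs + gapMs < watermarkMs)
    // Next timeout: the earliest session end (lastEvent + gap) among open sessions,
    // so the function is invoked again once the watermark passes it.
    val nextTimeoutMs = open.map(_.lastEventMs + gapMs).reduceOption(_ min _)
    (closed, open, nextTimeoutMs)
  }

  def main(args: Array[String]): Unit = {
    val minute = 60 * 1000L
    val sessions = List(Session(0L, 5 * minute), Session(45 * minute, 45 * minute))
    // With the watermark at 00:40, the session ending at 00:35 closes; the other stays open.
    println(evict(sessions, watermarkMs = 40 * minute, gapMs = 30 * minute))
  }
}
```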
——
You can refer to the code linked below to see how to implement session windows with event time and watermark via flatMapGroupsWithState.
NOTE: I didn't clean up the code, and it tries to support both output modes; once you decide on an output mode, you can remove the unrelated code to make it simpler.
EDIT2: I had a wrong assumption regarding flatMapGroupsWithState: events are not guaranteed to be sorted.
Just updated the code: https://gist.github.com/HeartSaVioR/9a3aeeef0f1d8ee97516743308b14cd6#file-eventtimesessionwindowimplementationviaflatmapgroupswithstate-scala-L32-L189
As of Spark 3.2.0, Spark supports session windows natively.
https://databricks.com/blog/2021/10/12/native-support-of-session-window-in-spark-structured-streaming.html
A good feature of Spark Structured Streaming is that it can join a static DataFrame with a streaming DataFrame. Take the example below: users is a static DataFrame read from a database, and transactionStream comes from a stream. With the join, we get the spending of each country, accumulated as new batches arrive.
val spendingByCountry = (transactionStream
.join(users, users("id") === transactionStream("userid"))
.groupBy($"country")
.agg(sum($"cost") as "spending"))
spendingByCountry.writeStream
.outputMode("complete")
.format("console")
.start()
The sum of cost accumulates as new batches arrive, as shown below.
-------------------------------
Batch: 0
-------------------------------
Country Spending
EN 90.0
FR 50.0
-------------------------------
Batch: 1
-------------------------------
Country Spending
EN 190.0
FR 150.0
If I want to introduce notification and reset logic into the above example, what is the correct approach? The requirement is that if the spending is larger than some threshold, the country and spending records should be stored in a table, and the spending should be reset to 0 to accumulate again.
One way to achieve this is arbitrary stateful processing. Instead of groupBy, you can use groupByKey with a custom function passed to mapGroupsWithState, in which you maintain all the business logic needed. Here is an example taken from the Spark docs:
// A mapping function that maintains an integer state for string keys and returns a string.
// Additionally, it sets a timeout to remove the state if it has not received data for an hour.
def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = {
  if (state.hasTimedOut) {                // If called when timing out, remove the state
    state.remove()
  } else if (state.exists) {              // If state exists, use it for processing
    val existingState = state.get         // Get the existing state
    val shouldRemove = ...                // Decide whether to remove the state
    if (shouldRemove) {
      state.remove()                      // Remove the state
    } else {
      val newState = ...
      state.update(newState)              // Set the new state
      state.setTimeoutDuration("1 hour")  // Set the timeout
    }
  } else {
    val initialState = ...
    state.update(initialState)            // Set the initial state
    state.setTimeoutDuration("1 hour")    // Set the timeout
  }
  ...
  // return something
}
dataset
.groupByKey(...)
.mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction)
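The elided parts of the docs example could hold the notification-and-reset logic from the question. A minimal sketch of that logic as a pure function, assuming spending is tracked as a Double per country and that an alert fires when the total strictly exceeds the threshold; `SpendingReset`, `update` and `Alert` are hypothetical names. Inside the mapping function it could be called with `state.getOption.getOrElse(0.0)`, with the first element of the result stored via `state.update(...)`:

```scala
// Hypothetical sketch of the per-country logic from the question: accumulate
// spending and, once it exceeds a threshold, emit an alert and reset to 0.
object SpendingReset {
  final case class Alert(country: String, spending: Double)

  /** Returns the new accumulated spending (to keep in state) and an optional
    * alert (to write to the notification table). */
  def update(country: String, current: Double, newCosts: Iterator[Double], threshold: Double)
      : (Double, Option[Alert]) = {
    val total = current + newCosts.sum
    if (total > threshold) (0.0, Some(Alert(country, total))) // notify, then reset
    else (total, None)                                        // keep accumulating
  }

  def main(args: Array[String]): Unit = {
    println(update("EN", 90.0, Iterator(100.0), threshold = 150.0))
    println(update("FR", 50.0, Iterator(60.0), threshold = 150.0))
  }
}
```

Since a group may or may not produce an alert, flatMapGroupsWithState, which returns an iterator of zero or more outputs per key, may fit this requirement better than mapGroupsWithState.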
I've been a Spark developer for 10 days, and I'm trying to understand Spark's flatMapGroupsWithState API.
As I understand:
We pass it two options: the timeout configuration and the output mode. A possible timeout value is GroupStateTimeout.ProcessingTimeTimeout, i.e. an instruction to Spark to use processing time rather than event time.
We pass in a function, let's say myFunction, that is responsible for setting the state for each key. We also set a timeout duration with groupState.setTimeoutDuration(TimeUnit.HOURS.toMillis(4)), assuming groupState is the GroupState instance for a key.
As I understand it, as micro-batches of stream data keep coming in, Spark maintains an intermediate state as we define it in the user-defined function. Let's say the intermediate state after processing n micro-batches of data is as follows:
State for Key1:
{
key1: [v1, v2, v3, v4, v5]
}
State for key2:
{
key2: [v11, v12, v13, v14, v15]
}
For any new data that comes in, myFunction is called with the state for that particular key. E.g. for key1, myFunction is called with key1, the new key1 values, and [v1,v2,v3,v4,v5], and it updates the key1 state according to the logic.
I read about the timeout and found that it dictates how long we should wait before timing out some intermediate state.
Questions:
If this process runs indefinitely, my intermediate states will keep piling up and hit the memory limits on the nodes. So when are these intermediate states cleared? I found that in the case of event-time aggregation, watermarks dictate when the intermediate states are cleared.
What does timing out the intermediate state mean in the context of processing time?
If this process runs indefinitely, my intermediate states will keep piling up and hit the memory limits on the nodes. So when are these intermediate states cleared? I found that in the case of event-time aggregation, watermarks dictate when the intermediate states are cleared.
Apache Spark will mark them as expired after the expiration time, so in your example after 4 hours of inactivity (real time + 4 hours, inactivity = no new event updating the state).
What does timing out the intermediate state mean in the context of processing time?
It means that the state times out according to the real clock (processing time, the org.apache.spark.util.SystemClock class). You can check which clock is currently used by looking at the triggerClock parameter of org.apache.spark.sql.streaming.StreamingQueryManager#startQuery.
You will find more details in the FlatMapGroupsWithStateExec class, more particularly here:
// Generate a iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator =
processor.processNewData(filteredIter) ++ processor.processTimedOutState()
And if you analyze these 2 methods, you will see that:
processNewData applies the mapping function to all active keys (those present in the micro-batch):
/**
* For every group, get the key, values and corresponding state and call the function,
* and return an iterator of rows
*/
def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
groupedIter.flatMap { case (keyRow, valueRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
hasTimedOut = false)
}
}
processTimedOutState calls the mapping function on all expired states:
def processTimedOutState(): Iterator[InternalRow] = {
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
case EventTimeTimeout => eventTimeWatermark.get
case _ =>
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
val timingOutPairs = stateManager.getAllState(store).filter { state =>
state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
}
timingOutPairs.flatMap { stateData =>
callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
}
} else Iterator.empty
}
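The selection done in processTimedOutState can be modeled in a few lines of plain Scala to see which keys would be handed to the function as timed out. This is a simplified model, not Spark code: `TimeoutModel` is hypothetical, and -1 is assumed as the value of the NO_TIMESTAMP marker. With ProcessingTimeTimeout the threshold is the batch timestamp; with EventTimeTimeout it would be the watermark:

```scala
// Plain-Scala model (not Spark code) of processTimedOutState's filter: a key is
// timed out only if a timeout is set and it is strictly below the threshold.
object TimeoutModel {
  val NoTimestamp: Long = -1L // stands in for Spark's NO_TIMESTAMP (assumed value)

  final case class StateEntry(key: String, timeoutTimestamp: Long)

  def timedOutKeys(states: Seq[StateEntry], thresholdMs: Long): Seq[String] =
    states
      .filter(s => s.timeoutTimestamp != NoTimestamp && s.timeoutTimestamp < thresholdMs)
      .map(_.key)

  def main(args: Array[String]): Unit = {
    val hour = 60 * 60 * 1000L
    val states = Seq(
      StateEntry("key1", 4 * hour),   // updated at t=0 with a 4-hour timeout
      StateEntry("key2", 6 * hour),   // updated at t=2h with a 4-hour timeout
      StateEntry("key3", NoTimestamp) // no timeout set
    )
    // Processing-time timeout: the threshold is the batch timestamp (here t=5h).
    println(timedOutKeys(states, thresholdMs = 5 * hour))
  }
}
```

Note the strict `<`: a key whose timeout equals the threshold is not yet treated as timed out.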
An important point to notice here is that Apache Spark will keep expired state in the state store if you don't invoke the GroupState#remove method. The expired states won't be returned for processing, though, because they're flagged with NO_TIMESTAMP. However, they will be stored in the state store delta files, which may slow down reprocessing if you need to reload the most recent state. If you analyze FlatMapGroupsWithStateExec again, you will see that the state is removed only when the removed flag is set to true:
def callFunctionAndUpdateState(...)
// ...
// When the iterator is consumed, then write changes to state
def onIteratorCompletion: Unit = {
if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
stateManager.removeState(store, stateData.keyRow)
numUpdatedStateRows += 1
} else {
val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged
if (shouldWriteState) {
val updatedStateObj = if (groupState.exists) groupState.get else null
stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
numUpdatedStateRows += 1
}
}
}
I want to kill my Spark Streaming job when there is no activity (i.e. the receivers are not receiving messages) for a certain time. I tried doing this:
var counter = 0
myDStream.foreachRDD {
rdd =>
if (rdd.count() == 0L)
{
counter = counter + 1
if (counter == 40) {
ssc.stop(true, true)
}
} else {
counter = 0
}
}
Is there a better way of doing this? How would I make a variable available to all receivers and increment it by 1 whenever there is no activity?
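Regarding the second question: per the Spark Streaming programming guide, the function passed to foreachRDD is executed in the driver process, so a counter like the one above lives only on the driver and does not need to be shared with receivers. A minimal sketch that isolates the counting logic; `IdleTracker` is a hypothetical helper, and the limit on consecutive empty batches is an assumed parameter:

```scala
// Hypothetical driver-side helper: counts consecutive empty batches and
// reports when the idle limit has been reached.
final class IdleTracker(maxEmptyBatches: Int) {
  private var emptyBatches = 0

  /** Records one batch's record count; returns true when the job should stop. */
  def record(batchCount: Long): Boolean = {
    emptyBatches = if (batchCount == 0L) emptyBatches + 1 else 0 // any data resets the count
    emptyBatches >= maxEmptyBatches
  }
}

object IdleTrackerDemo {
  def main(args: Array[String]): Unit = {
    val tracker = new IdleTracker(maxEmptyBatches = 3)
    val counts = Seq(5L, 0L, 0L, 2L, 0L, 0L, 0L)
    println(counts.map(tracker.record))
  }
}
```

In the loop from the question, `tracker.record(rdd.count())` would replace the manual counter, and a true result would trigger the same shutdown as before.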
Use a NoSQL table such as Cassandra or HBase to keep the counter. You cannot handle stream polling inside a loop. Implement the same logic using a NoSQL store or MariaDB, and perform a graceful shutdown of your streaming job if no activity is happening.
The way I did it was to maintain a table in MariaDB for the streaming job, with a polling interval of 5 minutes. Every 5 minutes the job hits the database and writes the count of records it consumed; the method also returns the count of zero-record entries up to the latest timestamp. This helped me a lot in managing my streaming jobs. This table also helps me automatically trigger the streaming job based on logic written in a shell script.
The programming guide says that structured streaming guarantees end-to-end exactly once semantics using appropriate sources/sinks.
However, I don't understand how this works when the job crashes and we have a watermark applied.
Below is an example of how I currently imagine it working, please correct me on any points that I'm misunderstanding. Thanks in advance!
Example:
Spark Job: Count # events in each 1 hour window, with a 1 hour Watermark.
Messages:
A - timestamp 10am
B - timestamp 10:10am
C - timestamp 10:20am
X - timestamp 12pm
Y - timestamp 12:50pm
Z - timestamp 8pm
We start the job, read A, B, C from the Source and the job crashes at 10:30am before we've written them out to our Sink.
At 6pm the job comes back up and knows to re-process A, B, C using the saved checkpoint/WAL. The final count is 3 for the 10-11am window.
Next, it reads the new messages from Kafka, X, Y, Z in parallel since they belong to different partitions. Z is processed first, so the max event timestamp gets set to 8pm. When the job reads X and Y, they are now behind the watermark (8pm - 1 hour = 7pm), so they are discarded as old data. The final count is 1 for 8-9pm, and the job does not report anything for the 12-1pm window. We've lost data for X and Y.
---End example---
Is this scenario accurate?
If so, the 1-hour watermark may be sufficient to handle late/out-of-order data when data flows normally from Kafka to Spark, but not when the Spark job goes down or the Kafka connection is lost for a long period of time. Would the only option to avoid data loss be to use a watermark longer than you expect the job to ever be down?
The watermark is a fixed value during a micro-batch. In your example, since X, Y and Z are processed in the same micro-batch, the watermark used for these records would be 9:20am. After that micro-batch completes, the watermark would be updated to 7pm.
Below is a quote from the design doc for SPARK-18124, which implements the watermarking functionality:
To calculate the drop boundary in our trigger based execution, we have to do the following.
In every trigger, while aggregate the data, we also scan for the max value of event time in the trigger data
After trigger completes, compute watermark = MAX(event time before trigger, max event time in trigger) - threshold
A simulation is probably more descriptive:
import org.apache.hadoop.fs.Path
import java.sql.Timestamp
import org.apache.spark.sql.types._
import org.apache.spark.sql.streaming.ProcessingTime
val dir = new Path("/tmp/test-structured-streaming")
val fs = dir.getFileSystem(sc.hadoopConfiguration)
fs.mkdirs(dir)
val schema = StructType(StructField("vilue", StringType) ::
StructField("timestamp", TimestampType) ::
Nil)
val eventStream = spark
.readStream
.option("sep", ";")
.option("header", "false")
.schema(schema)
.csv(dir.toString)
// Watermarked aggregation
val eventsCount = eventStream
.withWatermark("timestamp", "1 hour")
.groupBy(window($"timestamp", "1 hour"))
.count
def writeFile(path: Path, data: String) {
val file = fs.create(path)
file.writeUTF(data)
file.close()
}
// Debug query
val query = eventsCount.writeStream
.format("console")
.outputMode("complete")
.option("truncate", "false")
.trigger(ProcessingTime("5 seconds"))
.start()
writeFile(new Path(dir, "file1"), """
|A;2017-08-09 10:00:00
|B;2017-08-09 10:10:00
|C;2017-08-09 10:20:00""".stripMargin)
query.processAllAvailable()
val lp1 = query.lastProgress
// -------------------------------------------
// Batch: 0
// -------------------------------------------
// +---------------------------------------------+-----+
// |window |count|
// +---------------------------------------------+-----+
// |[2017-08-09 10:00:00.0,2017-08-09 11:00:00.0]|3 |
// +---------------------------------------------+-----+
// lp1: org.apache.spark.sql.streaming.StreamingQueryProgress =
// {
// ...
// "numInputRows" : 3,
// "eventTime" : {
// "avg" : "2017-08-09T10:10:00.000Z",
// "max" : "2017-08-09T10:20:00.000Z",
// "min" : "2017-08-09T10:00:00.000Z",
// "watermark" : "1970-01-01T00:00:00.000Z"
// },
// ...
// }
writeFile(new Path(dir, "file2"), """
|Z;2017-08-09 20:00:00
|X;2017-08-09 12:00:00
|Y;2017-08-09 12:50:00""".stripMargin)
query.processAllAvailable()
val lp2 = query.lastProgress
// -------------------------------------------
// Batch: 1
// -------------------------------------------
// +---------------------------------------------+-----+
// |window |count|
// +---------------------------------------------+-----+
// |[2017-08-09 10:00:00.0,2017-08-09 11:00:00.0]|3 |
// |[2017-08-09 12:00:00.0,2017-08-09 13:00:00.0]|2 |
// |[2017-08-09 20:00:00.0,2017-08-09 21:00:00.0]|1 |
// +---------------------------------------------+-----+
// lp2: org.apache.spark.sql.streaming.StreamingQueryProgress =
// {
// ...
// "numInputRows" : 3,
// "eventTime" : {
// "avg" : "2017-08-09T14:56:40.000Z",
// "max" : "2017-08-09T20:00:00.000Z",
// "min" : "2017-08-09T12:00:00.000Z",
// "watermark" : "2017-08-09T09:20:00.000Z"
// },
// "stateOperators" : [ {
// "numRowsTotal" : 3,
// "numRowsUpdated" : 2
// } ],
// ...
// }
writeFile(new Path(dir, "file3"), "")
query.processAllAvailable()
val lp3 = query.lastProgress
// -------------------------------------------
// Batch: 2
// -------------------------------------------
// +---------------------------------------------+-----+
// |window |count|
// +---------------------------------------------+-----+
// |[2017-08-09 10:00:00.0,2017-08-09 11:00:00.0]|3 |
// |[2017-08-09 12:00:00.0,2017-08-09 13:00:00.0]|2 |
// |[2017-08-09 20:00:00.0,2017-08-09 21:00:00.0]|1 |
// +---------------------------------------------+-----+
// lp3: org.apache.spark.sql.streaming.StreamingQueryProgress =
// {
// ...
// "numInputRows" : 0,
// "eventTime" : {
// "watermark" : "2017-08-09T19:00:00.000Z"
// },
// "stateOperators" : [ ],
// ...
// }
query.stop()
fs.delete(dir, true)
Notice how Batch 0 started with watermark 1970-01-01 00:00:00 while Batch 1 started with watermark 2017-08-09 09:20:00 (max event time of Batch 0 minus 1 hour). Batch 2, while empty, used watermark 2017-08-09 19:00:00.
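The rule quoted from the design doc can also be replayed without Spark. A minimal plain-Scala sketch under stated simplifications: there is no watermark before the first batch, and a row counts as late if its event time is before the watermark fixed at the start of its batch (this ignores the finer details of how Spark filters input to windowed aggregations). `WatermarkSim` is a hypothetical name:

```scala
import java.time.{Duration, LocalDateTime}
import scala.collection.mutable.ListBuffer

// Plain-Scala replay of: watermark = max(previous max event time, batch max) - threshold,
// computed only after each batch completes.
object WatermarkSim {
  val threshold: Duration = Duration.ofHours(1)
  def t(hhmm: String): LocalDateTime = LocalDateTime.parse(s"2017-08-09T$hhmm")

  final case class BatchResult(watermarkInEffect: Option[LocalDateTime], lateRows: Seq[String])

  def run(batches: Seq[Seq[(String, LocalDateTime)]]): (List[BatchResult], Option[LocalDateTime]) = {
    var maxEventTime: Option[LocalDateTime] = None
    var watermark: Option[LocalDateTime] = None // none before the first batch
    val results = ListBuffer.empty[BatchResult]
    for (batch <- batches) {
      // Every row in the batch is checked against the same, fixed watermark.
      val late = batch.collect { case (name, ts) if watermark.exists(w => ts.isBefore(w)) => name }
      results += BatchResult(watermark, late)
      if (batch.nonEmpty) {
        // Only after the batch: advance using the max event time seen so far.
        val newMax = (maxEventTime.toSeq ++ batch.map(_._2)).reduce((a, b) => if (b.isAfter(a)) b else a)
        maxEventTime = Some(newMax)
        watermark = Some(newMax.minus(threshold))
      }
    }
    (results.toList, watermark)
  }

  def main(args: Array[String]): Unit = {
    val (results, finalWatermark) = run(Seq(
      Seq("A" -> t("10:00"), "B" -> t("10:10"), "C" -> t("10:20")),
      Seq("Z" -> t("20:00"), "X" -> t("12:00"), "Y" -> t("12:50"))
    ))
    results.zipWithIndex.foreach { case (r, i) => println(s"batch $i: $r") }
    println(s"watermark after last batch: $finalWatermark")
  }
}
```

On the question's data, X and Y are checked against 09:20 and are not late; the watermark moves to 19:00 only after that batch, in line with the console output above.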
Z is processed first, so the max event timestamp gets set to 8pm.
That's correct. Even though Z may be computed first, the watermark is subtracted from the maximum timestamp in the current query iteration. This means that 08:00 PM will be the time from which the watermark delay is subtracted, meaning 12:00 and 12:50 will be discarded.
From the documentation:
For a specific window starting at time T, the engine will maintain state and allow late data to update the state until (max event time seen by the engine - late threshold > T)
Would the only option to avoid data loss be to use a watermark longer than you expect the job to ever go down for?
Not necessarily. Let's assume you limit the amount of data read per Kafka query to 100 items (for the Kafka source, this is the maxOffsetsPerTrigger option). If you read small batches, and you're reading serially from each partition, the maximum timestamp of each batch may not be the timestamp of the latest message in the broker, meaning you won't lose these messages.
I am trying to run a Spark job (which talks to Cassandra) to read data, do some aggregation, and then write the aggregates to Cassandra.
I have 2 tables (monthly_active_users (MAU) , daily_user_metric_aggregates (DUMA))
For every record in MAU, there will be one or more records in DUMA
Get every record in MAU, take its user_id, then find records in DUMA for that user (with server-side filters applied, like metric_name in ('ms', 'md')).
If there are one or more records in DUMA for the specified where clause, then I need to increment the count in the appMauAggregate map (app-wise MAU counts).
I tested this algorithm and it works as expected, but I want to find out:
1) Is it an optimized algorithm, or is there a better way to do it? I have a sense that something is not right, and I am not seeing speedups. It looks like a Cassandra client is being created and shut down for each Spark action (collect). It takes a long time to process a small dataset.
2) Spark workers are not co-located with Cassandra, meaning each Spark worker runs on a different node (container) than the C* node (we may move Spark workers to the C* nodes for data locality).
3) I see that a Spark job is created and submitted for every Spark action (collect), and I believe that is expected behavior in Spark. Is there any way to cut down reads from C* and use joins so that data retrieval is fast?
4) What are the downsides of this algorithm? Can you recommend a better design approach with respect to partition strategy, loading C* partitions into Spark partitions, and executor/driver memory requirements?
5) As long as the algorithm and design approach are fine, I can play around with Spark tuning. I am using 5 workers (each with 16 CPUs and 64 GB RAM).
C* Schema :
MAU:
CREATE TABLE analytics.monthly_active_users (
month text,
app_id uuid,
user_id uuid,
PRIMARY KEY (month, app_id, user_id)
) WITH CLUSTERING ORDER BY (app_id ASC, user_id ASC)
data:
cqlsh:analytics> select * from monthly_active_users limit 2;
month | app_id | user_id
--------+--------------------------------------+--------------------------------------
2015-2 | 108eeeb3-7ff1-492c-9dcd-491b68492bf2 | 199c0a31-8e74-46d9-9b3c-04f67d58b4d1
2015-2 | 108eeeb3-7ff1-492c-9dcd-491b68492bf2 | 2c70a31a-031c-4dbf-8dbd-e2ce7bdc2bc7
DUMA:
CREATE TABLE analytics.daily_user_metric_aggregates (
metric_date timestamp,
user_id uuid,
metric_name text,
"count" counter,
PRIMARY KEY (metric_date, user_id, metric_name)
) WITH CLUSTERING ORDER BY (user_id ASC, metric_name ASC)
data:
cqlsh:analytics> select * from daily_user_metric_aggregates where metric_date='2015-02-08' and user_id=199c0a31-8e74-46d9-9b3c-04f67d58b4d1;
metric_date | user_id | metric_name | count
--------------------------+--------------------------------------+-------------------+-------
2015-02-08 | 199c0a31-8e74-46d9-9b3c-04f67d58b4d1 | md | 1
2015-02-08 | 199c0a31-8e74-46d9-9b3c-04f67d58b4d1 | ms | 1
Spark Job :
import java.net.InetAddress
import java.util.concurrent.atomic.AtomicLong
import java.util.{Date, UUID}
import com.datastax.spark.connector.util.Logging
import org.apache.spark.{SparkConf, SparkContext}
import org.joda.time.{DateTime, DateTimeZone}
import scala.collection.mutable.ListBuffer
object MonthlyActiveUserAggregate extends App with Logging {
val KeySpace: String = "analytics"
val MauTable: String = "mau"
val CassandraHostProperty = "CASSANDRA_HOST"
val CassandraDefaultHost = "127.0.0.1"
val CassandraHost = InetAddress.getByName(sys.env.getOrElse(CassandraHostProperty, CassandraDefaultHost))
val conf = new SparkConf().setAppName(getClass.getSimpleName)
.set("spark.cassandra.connection.host", CassandraHost.getHostAddress)
lazy val sc = new SparkContext(conf)
import com.datastax.spark.connector._
def now = new DateTime(DateTimeZone.UTC)
val metricMonth = now.getYear + "-" + now.getMonthOfYear
private val mauMonthSB: StringBuilder = new StringBuilder
mauMonthSB.append(now.getYear).append("-")
if (now.getMonthOfYear < 10) mauMonthSB.append("0")
mauMonthSB.append(now.getMonthOfYear).append("-")
if (now.getDayOfMonth < 10) mauMonthSB.append("0")
mauMonthSB.append(now.getDayOfMonth)
private val mauMonth: String = mauMonthSB.toString()
val dates = ListBuffer[String]()
for (day <- 1 to now.dayOfMonth().getMaximumValue) {
val metricDate: StringBuilder = new StringBuilder
metricDate.append(now.getYear).append("-")
if (now.getMonthOfYear < 10) metricDate.append("0")
metricDate.append(now.getMonthOfYear).append("-")
if (day < 10) metricDate.append("0")
metricDate.append(day)
dates += metricDate.toString()
}
private val metricName: List[String] = List("ms", "md")
val appMauAggregate = scala.collection.mutable.Map[String, scala.collection.mutable.Map[UUID, AtomicLong]]()
case class MAURecord(month: String, appId: UUID, userId: UUID) extends Serializable
case class DUMARecord(metricDate: Date, userId: UUID, metricName: String) extends Serializable
case class MAUAggregate(month: String, appId: UUID, total: Long) extends Serializable
private val mau = sc.cassandraTable[MAURecord]("analytics", "monthly_active_users")
.where("month = ?", metricMonth)
.collect()
mau.foreach { monthlyActiveUser =>
val duma = sc.cassandraTable[DUMARecord]("analytics", "daily_user_metric_aggregates")
.where("metric_date in ? and user_id = ? and metric_name in ?", dates, monthlyActiveUser.userId, metricName)
//.map(_.userId).distinct().collect()
.collect()
if (duma.length > 0) { // if user has `ms` for the given month
if (!appMauAggregate.isDefinedAt(mauMonth)) {
appMauAggregate += (mauMonth -> scala.collection.mutable.Map[UUID, AtomicLong]())
}
val monthMap: scala.collection.mutable.Map[UUID, AtomicLong] = appMauAggregate(mauMonth)
if (!monthMap.isDefinedAt(monthlyActiveUser.appId)) {
monthMap += (monthlyActiveUser.appId -> new AtomicLong(0))
}
monthMap(monthlyActiveUser.appId).incrementAndGet()
} else {
println(s"No message_sent in daily_user_metric_aggregates for user: $monthlyActiveUser")
}
}
for ((metricMonth: String, appMauCounts: scala.collection.mutable.Map[UUID, AtomicLong]) <- appMauAggregate) {
for ((appId: UUID, total: AtomicLong) <- appMauCounts) {
println(s"month: $metricMonth, app_id: $appId, total: $total");
val collection = sc.parallelize(Seq(MAUAggregate(metricMonth.substring(0, 7), appId, total.get())))
collection.saveToCassandra(KeySpace, MauTable, SomeColumns("month", "app_id", "total"))
}
}
sc.stop()
}
Thanks.
Your solution is the least efficient possible. You are performing a join by looking up each key one-by-one, avoiding any possible parallelization.
I've never used the Cassandra connector, but I understand it returns RDDs. So you could do this:
import java.util.UUID
import com.datastax.spark.connector._
import org.apache.spark.rdd.RDD

val mau: RDD[(UUID, MAURecord)] = sc
.cassandraTable[MAURecord]("analytics", "monthly_active_users")
.where("month = ?", metricMonth)
.map(u => u.userId -> u) // Key by user ID.
val duma: RDD[(UUID, DUMARecord)] = sc
.cassandraTable[DUMARecord]("analytics", "daily_user_metric_aggregates")
.where("metric_date in ? and metric_name in ?", dates, metricName)
.map(a => a.userId -> a) // Key by user ID.
// Count "duma" rows per user. countByKey is an action that returns a Map
// to the driver, so reduceByKey is used to keep the counts distributed.
val dumaCounts: RDD[(UUID, Long)] = duma.mapValues(_ => 1L).reduceByKey(_ + _)
// Join to "mau". This drops "mau" entries that have no count
// and "duma" entries that are not present in "mau".
val joined: RDD[(UUID, (MAURecord, Long))] = mau.join(dumaCounts)
// Get per-application counts (still an RDD, so it can be saved back to Cassandra).
val appCounts: RDD[(UUID, Long)] = joined
.map { case (_, (m, _)) => m.appId -> 1L }
.reduceByKey(_ + _)
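To check what the pipeline above computes, the same logic can be written on local collections. This is a simplified model, assuming DUMA rows have already been filtered by date and metric name; `MauCountModel` and its trimmed-down record types are hypothetical and not the connector's API:

```scala
import java.util.UUID

// Plain-collections model (no Spark, no Cassandra): for each app, the number of
// MAU users with at least one matching DUMA row.
object MauCountModel {
  final case class MAURecord(month: String, appId: UUID, userId: UUID)
  final case class DUMARecord(userId: UUID, metricName: String)

  def appCounts(mau: Seq[MAURecord], duma: Seq[DUMARecord]): Map[UUID, Long] = {
    // Users with at least one DUMA row (the join side, deduplicated).
    val activeUsers: Set[UUID] = duma.map(_.userId).toSet
    mau
      .filter(m => activeUsers.contains(m.userId)) // inner join on user_id
      .groupBy(_.appId)
      .map { case (appId, records) => appId -> records.size.toLong }
  }

  def main(args: Array[String]): Unit = {
    val app = UUID.randomUUID()
    val (u1, u2) = (UUID.randomUUID(), UUID.randomUUID())
    val mau = Seq(MAURecord("2015-2", app, u1), MAURecord("2015-2", app, u2))
    val duma = Seq(DUMARecord(u1, "ms"), DUMARecord(u1, "md"))
    println(appCounts(mau, duma)) // u2 has no DUMA rows, so only u1 is counted
  }
}
```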
There is a parameter, spark.cassandra.connection.keep_alive_ms, which controls how long the connection is kept open. Take a look at the documentation page.
If you co-locate Spark workers with Cassandra nodes, the connector will take advantage of this and create partitions so that each executor always fetches data from its local node.
I can see some design improvements you could make to the DUMA table: metric_date does not seem to be the best choice for the partition key. Consider making (user_id, metric_name) the partition key; then you will not have to generate dates for the query, and you will just need to put user_id and metric_name in the where clause. Moreover, you can add a month identifier to the primary key; then each partition will contain only the information related to what you want to fetch with each query.
Anyway, join functionality in the Spark-Cassandra-Connector is currently being implemented (see this ticket).