How to Test Spark RDD - apache-spark

I am not sure whether we can test RDDs in Spark.
I came across an article that says mocking an RDD is not a good idea. Is there another way, or a best practice, for testing RDDs?

Thank you for putting this outstanding question out there. For some reason, when it comes to Spark, everyone gets so caught up in the analytics that they forget about the great software engineering practices that emerged over the last 15 years or so. This is why we make it a point to discuss testing and continuous integration (among other things like DevOps) in our course.
A Quick Aside on Terminology
Before I go on, I have to express a minor disagreement with the KnolX presentation @himanshuIIITian cites. A true unit test means you have complete control over every component in the test. There can be no interaction with databases, REST calls, file systems, or even the system clock; everything has to be "doubled" (e.g. mocked, stubbed, etc.) as Gerard Meszaros puts it in xUnit Test Patterns. I know this seems like semantics, but it really matters. Failing to understand this is one major reason why you see intermittent test failures in continuous integration.
We Can Still Unit Test
So given this understanding, unit testing an RDD is impossible. However, there is still a place for unit testing when developing analytics.
(Note: I will be using Scala for the examples, but the concepts transcend languages and frameworks.)
Consider a simple operation:
rdd.map(foo).map(bar)
Here foo and bar are simple functions. They can be unit tested in the normal way, and they should be, with as many corner cases as you can muster. After all, why should they care whether their inputs come from a test fixture or an RDD?
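As a minimal sketch of that idea, the block below tests two plain functions with no Spark dependency at all. The bodies of foo and bar are hypothetical stand-ins (the answer does not define them); the only thing taken from the text is that they are ordinary functions that a job would pass to map.

```scala
// Hypothetical stand-ins for `foo` and `bar`; the real ones are whatever
// functions your job applies to its RDD with map.
object Transformations {
  // Splits a line into words, ignoring repeated or surrounding whitespace.
  def foo(line: String): Seq[String] =
    line.trim.split("\\s+").filter(_.nonEmpty).toSeq

  // Counts the words produced by foo.
  def bar(words: Seq[String]): Int = words.length
}

object TransformationsCheck {
  import Transformations._

  // Plain assertions; in a real project these would live in a ScalaTest suite.
  def run(): Unit = {
    assert(foo("a b  c") == Seq("a", "b", "c"))
    assert(foo("   ") == Seq.empty[String]) // corner case: blank input
    assert(bar(foo(" hello world ")) == 2)
    // The functions compose on a local collection the same way they would on an RDD.
    assert(Seq("x y", "").map(foo).map(bar) == Seq(2, 0))
  }
}
```

Because nothing here touches a SparkContext, these checks run in milliseconds and cannot fail for environmental reasons.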
Don't Forget the Spark Shell
This isn't testing per se, but in these early stages you also should be experimenting in the Spark shell to figure out your transformations and especially the consequences of your approach. For example, you can examine physical and logical query plans, partitioning strategy and preservation, and the state of your data with many different functions like toDebugString, explain, glom, show, printSchema, and so on. I will let you explore those.
You can also set your master to local[2] in the Spark shell and in your tests to identify any problems that may only arise once you start to distribute work.
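A minimal sketch of such an exploration session, assuming it is pasted into a spark-shell started with --master local[2], where sc and spark are already defined by the shell. The data is made up for illustration.

```scala
// Run inside spark-shell (sc and spark are provided by the shell).
import spark.implicits._

val rdd = sc.parallelize(1 to 8, numSlices = 4).map(_ * 2)

// Lineage of the RDD, including shuffle boundaries if there are any.
println(rdd.toDebugString)

// glom() turns each partition into an array, so you can see how records are spread.
rdd.glom().collect().foreach(part => println(part.mkString("[", ", ", "]")))

// The DataFrame side: schema, query plans, and a peek at the data.
val df = rdd.toDF("value")
df.printSchema()
df.explain(true) // parsed, analyzed, optimized, and physical plans
df.show(5)
```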
Integration Testing with Spark
Now for the fun stuff.
In order to integration test Spark after you feel confident in the quality of your helper functions and RDD/DataFrame transformation logic, it is critical to do a few things (regardless of build tool and test framework):
Increase JVM memory.
Enable forking but disable parallel execution.
Use your test framework to accumulate your Spark integration tests into suites, and initialize the SparkContext before all tests and stop it after all tests.
There are several ways to do this last one. One is available from the spark-testing-base cited by both @Pushkr and the KnolX presentation linked by @himanshuIIITian.
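Going back to the first two items in the list above (JVM memory and forking), here is a minimal sketch of how they might look in an sbt build. It assumes sbt 1.x slash syntax, and the heap sizes are illustrative values, not recommendations from the answer.

```scala
// build.sbt
Test / fork := true                                 // run tests in a separate JVM
Test / parallelExecution := false                   // one suite at a time, so SparkContexts do not collide
Test / javaOptions ++= Seq("-Xms512M", "-Xmx2048M") // javaOptions only take effect when forking
```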
The Loan Pattern
Another approach is to use the Loan Pattern.
For example (using ScalaTest):
class MySpec extends WordSpec with Matchers with SparkContextSetup {
  "My analytics" should {
    "calculate the right thing" in withSparkContext { (sparkContext) =>
      val data = Seq(...)
      val rdd = sparkContext.parallelize(data)
      val total = rdd.reduce(_ + _)

      total shouldBe 1000
    }
  }
}

trait SparkContextSetup {
  // Lends a fresh SparkContext to the test body and always stops it afterwards.
  def withSparkContext(testMethod: (SparkContext) => Any): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]") // a master is required; local[2] as suggested earlier
      .setAppName("Spark test")
    val sparkContext = new SparkContext(conf)
    try {
      testMethod(sparkContext)
    }
    finally sparkContext.stop()
  }
}
As you can see, the Loan Pattern makes use of higher-order functions to "loan" the SparkContext to the test and then to dispose of it after it's done.
Suffering-Oriented Programming (Thanks, Nathan)
It is totally a matter of preference, but I prefer to use the Loan Pattern and wire things up myself as long as I can before bringing in another framework. Aside from just trying to stay lightweight, frameworks sometimes add a lot of "magic" that makes debugging test failures hard to reason about. So I take a Suffering-Oriented Programming approach--where I avoid adding a new framework until the pain of not having it is too much to bear. But again, this is up to you.
Now one place where spark-testing-base really shines is with the Hadoop-based helpers like HDFSClusterLike and YARNClusterLike. Mixing those traits in can really save you a lot of setup pain. Another place where it shines is with the Scalacheck-like properties and generators. But again, I would personally hold off on using it until my analytics and my tests reach that level of sophistication.
Integration Testing with Spark Streaming
Finally, I would just like to present a snippet of what a Spark Streaming integration test setup with in-memory values would look like:
val sparkContext: SparkContext = ...
val data: Seq[(String, String)] = Seq(("a", "1"), ("b", "2"), ("c", "3"))
val rdd: RDD[(String, String)] = sparkContext.parallelize(data)
val strings: mutable.Queue[RDD[(String, String)]] = mutable.Queue.empty[RDD[(String, String)]]
val streamingContext = new StreamingContext(sparkContext, Seconds(1))
val dStream: InputDStream[(String, String)] = streamingContext.queueStream(strings)
strings += rdd
This is simpler than it looks. It really just turns a sequence of data into a queue to feed to the DStream. Most of it is really just boilerplate setup that works with the Spark APIs.
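The snippet above stops at the setup. A minimal sketch of how the rest of such a test might go, assuming a local[2] master, a one-second batch interval, and a three-second timeout chosen only for illustration (the object and method names are hypothetical):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext}
import scala.collection.mutable

object QueueStreamSketch {
  // Feeds one in-memory RDD through a DStream and returns what came out.
  def run(): Seq[(String, String)] = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("queue-stream-sketch")
    val sparkContext = new SparkContext(conf)
    val streamingContext = new StreamingContext(sparkContext, Seconds(1))
    try {
      val queue = mutable.Queue.empty[RDD[(String, String)]]
      val collected = mutable.ArrayBuffer.empty[(String, String)]

      // foreachRDD is an output operation, so it is what makes the stream run.
      val dStream = streamingContext.queueStream(queue)
      dStream.foreachRDD { rdd =>
        collected.synchronized { collected ++= rdd.collect() }
      }

      queue += sparkContext.parallelize(Seq(("a", "1"), ("b", "2"), ("c", "3")))

      streamingContext.start()
      // Wait long enough for at least one batch; real tests may prefer polling.
      streamingContext.awaitTerminationOrTimeout(3000)
      collected.synchronized(collected.toList)
    } finally {
      streamingContext.stop(stopSparkContext = true, stopGracefully = false)
    }
  }
}
```

The fixed timeout keeps the sketch short but makes the test timing-sensitive; on a slow machine a longer wait or a polling loop would be safer.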
This might be my longest post ever, so I will leave it here. I hope others chime in with other ideas to help improve the quality of our analytics with the same agile software engineering practices that have improved all other application development.
And with apologies for the shameless plug, you can check out our course Analytics with Apache Spark, where we address a lot of these ideas and more. We hope to have an online version soon.

There are 2 methods of testing Spark RDD/Applications. They are as follows:
For example:
Unit to Test:
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
class WordCount {
  def get(url: String, sc: SparkContext): RDD[(String, Int)] = {
    val lines = sc.textFile(url)
    lines.flatMap(_.split(" ")).map((_, 1)).reduceByKey(_ + _)
  }
}
Now Method 1 to Test it is as follows:
import org.scalatest.{ BeforeAndAfterAll, FunSuite }
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
class WordCountTest extends FunSuite with BeforeAndAfterAll {
  private var sparkConf: SparkConf = _
  private var sc: SparkContext = _

  // Create one SparkContext for the whole suite.
  override def beforeAll() {
    sparkConf = new SparkConf().setAppName("unit-testing").setMaster("local")
    sc = new SparkContext(sparkConf)
  }

  private val wordCount = new WordCount
  test("get word count rdd") {
    val result = wordCount.get("file.txt", sc)
    assert(result.take(10).length === 10)
  }

  // Stop the context so later suites can create their own.
  override def afterAll() {
    sc.stop()
  }
}
In Method 1 we are not mocking RDD. We are just checking the behavior of our WordCount class. But here we have to manage the creation and destruction of SparkContext on our own. So, if you do not want to write extra code for that, then you can use spark-testing-base, like this:
Method 2:
import org.scalatest.FunSuite
import com.holdenkarau.spark.testing.SharedSparkContext
class WordCountTest extends FunSuite with SharedSparkContext {
  private val wordCount = new WordCount
  test("get word count rdd") {
    val result = wordCount.get("file.txt", sc)
    assert(result.take(10).length === 10)
  }
}
import org.scalatest.FunSuite
import com.holdenkarau.spark.testing.SharedSparkContext
import com.holdenkarau.spark.testing.RDDComparisons
class WordCountTest extends FunSuite with SharedSparkContext with RDDComparisons {
  private val wordCount = new WordCount
  test("get word count rdd with comparison") {
    val expected = sc.textFile("file.txt")
      .flatMap(_.split(" "))
      .map((_, 1))
      .reduceByKey(_ + _)
    val result = wordCount.get("file.txt", sc)
    assert(compareRDD(expected, result).isEmpty)
  }
}
For more details on Spark RDD testing, refer to KnolX: Unit Testing of Spark Applications


Method to get number of cores for a executor on a task node?

E.g. I need to get a list of all available executors and their respective multithreading capacity (NOT the total multithreading capacity; sc.defaultParallelism already handles that).
This parameter is implementation-dependent (YARN and Spark standalone have different strategies for allocating cores) and situational (it may fluctuate because of dynamic allocation and long-running jobs), so I cannot estimate it by other means. Is there a way to retrieve this information using the Spark API in a distributed transformation? (E.g. TaskContext, SparkEnv)
UPDATE As of Spark 1.6, I have tried the following methods:
1) Run a 1-stage job with many partitions ( >> defaultParallelism ) and count the number of distinct thread IDs for each executor ID:
val n = sc.defaultParallelism * 16
sc.parallelize(1 to n, n).map(v => SparkEnv.get.executorId -> Thread.currentThread().getId)
This however leads to an estimation higher than actual multithreading capacity because each Spark executor uses an overprovisioned thread pool.
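The counting step of method 1 is not shown in the question. A minimal sketch of it, assuming the (executor ID, thread ID) pairs have been collected to the driver; the object and method names are hypothetical:

```scala
object ThreadCount {
  // Input: (executorId, threadId) pairs, one per task, collected to the driver.
  // Output: number of distinct thread IDs observed on each executor.
  def threadsPerExecutor(pairs: Seq[(String, Long)]): Map[String, Int] =
    pairs
      .groupBy { case (executorId, _) => executorId }
      .map { case (executorId, seen) => executorId -> seen.map(_._2).distinct.size }
}
```

For example, ThreadCount.threadsPerExecutor(Seq(("1", 10L), ("1", 11L), ("1", 10L), ("2", 7L))) yields Map("1" -> 2, "2" -> 1). As the question notes, this count can overestimate the real capacity when the executor's thread pool is overprovisioned.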
2) Similar to 1, except that n = defaultParallelism, and in every task I add a delay to prevent the resource negotiator from sharding unevenly (a fast node completes its task and asks for more before slow nodes can start running):
val n = sc.defaultParallelism
sc.parallelize(1 to n, n).map { v =>
  Thread.sleep(delayMs) // the per-task delay described above; delayMs is a placeholder, the original value is not shown
  SparkEnv.get.executorId -> Thread.currentThread().getId
}
It works most of the time, but it is much slower than necessary and may be broken by a very imbalanced cluster or by task speculation.
3) I haven't tried this: use Java reflection to read BlockManager.numUsableCores. This is obviously not a stable solution, since the internal implementation may change at any time.
Please tell me if you have found something better.
It is pretty easy with the Spark REST API. You have to get the application id:
val applicationId = spark.sparkContext.applicationId
the UI URL:
val baseUrl = spark.sparkContext.uiWebUrl
and build the query URL:
val url = baseUrl.map { url => s"${url}/api/v1/applications/${applicationId}/executors" }
With the Apache HTTP library (already in Spark's dependencies):
import org.apache.http.impl.client.DefaultHttpClient
import org.apache.http.client.methods.HttpGet
import scala.util.Try
val client = new DefaultHttpClient()
val response = url
  .flatMap(url => Try{client.execute(new HttpGet(url))}.toOption)
  .flatMap(response => Try{
    val s = response.getEntity().getContent()
    val json = scala.io.Source.fromInputStream(s).getLines.mkString
    s.close()
    json
  }.toOption)
and json4s:
import org.json4s._
import org.json4s.jackson.JsonMethods._
implicit val formats = DefaultFormats
case class ExecutorInfo(hostPort: String, totalCores: Int)
val executors: Option[List[ExecutorInfo]] = response.flatMap(json => Try {
  parse(json).extract[List[ExecutorInfo]]
}.toOption)
As long as you keep application id and ui URL at hand and open ui port to external connections you can do the same thing from any task.
I would try to implement SparkListener in a way similar to web UI does. This code might be helpful as an example.
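A minimal sketch of the listener idea, assuming it is enough to track executors on the driver as they come and go. ExecutorCoresListener is a hypothetical helper, not part of Spark; the event classes and their fields are Spark's own.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded, SparkListenerExecutorRemoved}
import scala.collection.concurrent.TrieMap

// Hypothetical helper: tracks total cores per live executor, on the driver.
class ExecutorCoresListener extends SparkListener {
  private val coresByExecutor = TrieMap.empty[String, Int]

  override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit =
    coresByExecutor.put(event.executorId, event.executorInfo.totalCores)

  override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit =
    coresByExecutor.remove(event.executorId)

  // Immutable snapshot of executorId -> totalCores.
  def snapshot: Map[String, Int] = coresByExecutor.toMap
}
```

It can be registered with sc.addSparkListener(new ExecutorCoresListener), but a listener added that way only sees executors that register afterwards; the spark.extraListeners setting registers it at startup instead. The map lives on the driver, so tasks would need it passed to them (for example through a broadcast) rather than reading it directly.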

How to eval model without DataFrames/SparkContext?

With Spark MLlib, I'd build a model (like RandomForest), and then it was possible to evaluate it outside of Spark by loading the model and calling predict on it with a vector of features.
It seems like with Spark ML, predict is now called transform and only acts on a DataFrame.
Is there any way to build a DataFrame outside of Spark since it seems like one needs a SparkContext to build a DataFrame?
Am I missing something?
Re: Is there any way to build a DataFrame outside of Spark?
It is not possible. DataFrames live inside a SQLContext, which in turn lives inside a SparkContext. Perhaps you could work around it somehow, but the connection between DataFrames and SparkContext is by design.
Here is my solution to use spark models outside of spark context (using PMML):
You create model with a pipeline like this:
SparkConf sparkConf = new SparkConf();
SparkSession session = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate();
Properties dbProperties = new Properties();
String tableName = "schema.table";
String simpleUrl = "jdbc:impala://host:21050/schema";
Dataset<Row> data = session.read().jdbc(simpleUrl, tableName, dbProperties);
String[] inputCols = {"column1"};
StringIndexer indexer = new StringIndexer().setInputCol("column1").setOutputCol("indexed_column1");
StringIndexerModel alphabet = indexer.fit(data);
data = alphabet.transform(data);
VectorAssembler assembler = new VectorAssembler().setInputCols(inputCols).setOutputCol("features");
Predictor p = new GBTRegressor();
PipelineStage[] stages = {indexer,assembler, p};
Pipeline pipeline = new Pipeline();
pipeline.setStages(stages);
PipelineModel pmodel = pipeline.fit(data);
PMML pmml = ConverterUtil.toPMML(data.schema(),pmodel);
FileOutputStream fos = new FileOutputStream("model.pmml");
JAXBUtil.marshalPMML(pmml,new StreamResult(fos));
Using PMML for predictions (locally, without a Spark context; it can be applied to a Map of arguments rather than a DataFrame):
PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
inputFieldMap = new HashMap<String, Field>();
Map<FieldName,String> args = new HashMap<FieldName, String>();
Field curField = evaluator.getInputFields().get(0);
args.put(curField.getName(), "1.0");
Map<FieldName, ?> result = evaluator.evaluate(args);
Spent days on this problem too. It's not straightforward. My third suggestion involves code I have written specifically for this purpose.
Option 1
As other commenters have said, predict(Vector) is now available. However, you need to know how to construct a vector. If you don't, see Option 3.
Option 2
If the goal is to avoid setting up a Spark server (standalone or cluster mode), then it's possible to start Spark in local mode. The whole thing will run inside a single JVM.
val spark = SparkSession.builder().config("spark.master", "local[*]").getOrCreate()
// create dataframe from file, or make it up from some data in memory
// use model.transform() to get predictions
But this brings unnecessary dependencies to your prediction module, and it consumes resources in your JVM at runtime. Also, if prediction latency is critical, for example making a prediction within a millisecond as soon as a request comes in, then this option is too slow.
Option 3
MLlib FeatureHasher's output can be used as an input to your learner. The class is good for one-hot encoding and also for fixing the size of your feature dimension. You can use it even when all your features are numerical. If you use it in your training, then all you need at prediction time is the hashing logic. It's implemented as a Spark transformer, so it's not easy to reuse outside of a Spark environment. So I have done the work of pulling out the hashing function into a lib. You apply FeatureHasher and your learner during training as normal. Then here's how you use the slimmed-down hasher at prediction time:
// Schema and hash size must stay consistent across training and prediction
val hasher = new FeatureHasherLite(mySchema, myHashSize)
// create sample data-point and hash it
val feature = Map("feature1" -> "value1", "feature2" -> 2.0, "feature3" -> 3, "feature4" -> false)
val featureVector = hasher.hash(feature)
// Make prediction
val prediction = model.predict(featureVector)
You can see details on my GitHub at tilayealemu/sparkmllite. If you'd rather copy my code, take a look at FeatureHasherLite.scala. There are code samples and unit tests too. Feel free to create an issue if you need help.


I know this question has been answered many times, but I have tried everything and cannot find a solution. I have the following code, which raises a NotSerializableException:
val ids : Seq[Long] = ...
ids.foreach{ id =>
  sc.sequenceFile("file", classOf[LongWritable], classOf[MyWritable]).lookup(new LongWritable(id))
}
With the following exception
Caused by:
Serialization stack:
at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:84)
at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:301)
When creating the SparkContext, I do
val sparkConfig = new SparkConf().setAppName("...").setMaster("...")
sparkConfig.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConfig.registerKryoClasses(Array(classOf[BitString[_]], classOf[MinimalBitString], classOf[]))
sparkConfig.set("spark.kryoserializer.classesToRegister", ",,")
and looking at the environment tab, I can see these entries. However, I do not understand why
the Kryo serializer does not seem to be used (the stack does not mention Kryo)
LongWritable is not serialized.
I'm using Apache Spark v. 1.5.1
Repeatedly loading the same data inside a loop is extremely inefficient. If you perform several actions against the same data, load it once and cache it:
val rdd = sc
.sequenceFile("file", classOf[LongWritable], classOf[MyWritable])
Spark doesn't consider Hadoop Writables to be serializable. There is an open JIRA (SPARK-2421) for this. To handle LongWritables, a simple get should be enough: rdd.map{case (k, v) => k.get()}
Regarding your custom class, it is your responsibility to deal with this problem.
Effective lookup requires a partitioned RDD. Otherwise it has to search every partition in your RDD.
import org.apache.spark.HashPartitioner
val numPartitions: Int = ???
val partitioned = rdd.partitionBy(new HashPartitioner(numPartitions))
Generally speaking, RDDs are not designed for random access. Even with a defined partitioner, lookup has to linearly search the candidate partition. With 5000 uniformly distributed keys and 10M objects in an RDD, it most likely means a repeated search over the whole RDD. You have a few options to avoid that:
filter against a broadcast set of ids:
val idsSet = sc.broadcast(ids.toSet)
rdd.filter{case (k, v) => idsSet.value.contains(k)}
join with an RDD of ids:
val idsRdd = sc.parallelize(ids).map((_, null))
idsRdd.join(rdd).map{case (k, (_, v)) => (k, v)}
IndexedRDD - it doesn't look like a particularly active project, though
With 10M entries you'll probably be better with searching locally in memory than using Spark. For a larger data you should consider using a proper key-value store.
I'm new to Apache Spark, but I tried to solve your problem; please evaluate whether it helps. The serialization problem occurs because Spark does not treat Hadoop's LongWritable and other Writables as serializable.
val temp_rdd = sc.parallelize( =>
sc.sequenceFile("file", classOf[LongWritable], classOf[LongWritable]).toArray.toSeq
ids.foreach(id =>temp_rdd.lookup(new LongWritable(id)))
Try this solution. It worked fine for me.
SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("SparkMapReduceApp");
conf.registerKryoClasses(new Class<?>[]{

Need some inputs in feature extraction in Apache Spark

I am new to Apache Spark and we are trying to use the MLlib utilities to do some analysis. I put together some code to convert my data into features and then apply a linear regression algorithm to them. I am facing some issues. Please help, and excuse me if it's a silly question.
My person data looks like
Just a simple example to get the code working
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
case class Person(rating: String, income: Double, age: Int)
val persondata = sc.textFile("D:/spark/mydata/persondata.txt").map(_.split(",")).map(p => Person(p(0), p(1).toDouble, p(2).toInt))
def prepareFeatures(people: Seq[Person]): Seq[org.apache.spark.mllib.linalg.Vector] = {
  val maxIncome = people.map(_.income).max
  val maxAge = people.map(_.age).max
  people.map(p =>
    Vectors.dense(
      if (p.rating == "A") 0.7 else if (p.rating == "B") 0.5 else 0.3,
      p.income / maxIncome,
      p.age.toDouble / maxAge))
}
def prepareFeaturesWithLabels(features: Seq[org.apache.spark.mllib.linalg.Vector]): Seq[LabeledPoint] =
(0d to 1 by (1d / features.length)) zip(features) map(l => LabeledPoint(l._1, l._2))
---It works up to here.
---It breaks in the code below:
val data = sc.parallelize(prepareFeaturesWithLabels(prepareFeatures(people)))
scala> val data = sc.parallelize(prepareFeaturesWithLabels(prepareFeatures(people)))
<console>:36: error: not found: value people
Error occurred in an application involving default arguments.
val data = sc.parallelize(prepareFeaturesWithLabels(prepareFeatures(people)))
Please advise
You seem to be going in roughly the right direction but there are a few minor problems. First off you are trying to reference a value (people) that you haven't defined. More generally you seem to be writing your code to work with sequences, and instead you should modify your code to work with RDDs (or DataFrames). Also you seem to be using parallelize to try and parallelize your operation, but parallelize is a helper method to take a local collection and make it available as a distributed RDD. I'd probably recommend looking at the programming guides or some additional documentation to get a better understanding of the Spark APIs. Best of luck with your adventures with Spark.
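To make the advice concrete, here is a minimal sketch of the feature preparation rewritten to work on an RDD instead of a Seq. It assumes the persondata RDD[Person] from the question as input; FeaturePrep is a hypothetical wrapper object, and the label step is left out because the question's labelling scheme is tied to a local sequence.

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

case class Person(rating: String, income: Double, age: Int)

object FeaturePrep {
  // Same scaling as the question's prepareFeatures, but distributed.
  def prepareFeatures(people: RDD[Person]): RDD[Vector] = {
    // Each max() is a Spark action, so cache `people` if it is expensive to recompute.
    val maxIncome = people.map(_.income).max()
    val maxAge = people.map(_.age).max().toDouble
    people.map { p =>
      val ratingScore = if (p.rating == "A") 0.7 else if (p.rating == "B") 0.5 else 0.3
      Vectors.dense(ratingScore, p.income / maxIncome, p.age / maxAge)
    }
  }
}
```

With this shape, FeaturePrep.prepareFeatures(persondata) stays distributed end to end, and sc.parallelize is no longer needed because the data never becomes a local collection.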

reading from several topics

I'm trying to develop an application that reads four different topics from a Kafka server and takes specific actions for each topic.
I have created a class that receives a DStream and has a method that should transform the DStream.
For example, the handler class:
class StreamHandler(val stream: DStream[String]) {
  def doActions(): DStream[String] = {
    // Do something to the DStream
    stream // placeholder: return the transformed DStream
  }
}
And now, imagine I call doActions() from the main class for each handler class I want, would it be repeated with each arriving DStream or only once?
val topicHandler1 = new StreamHandler(KafkaUtils.createStream(ssc, zkQuorum, "myGroup", Map("topic1"->1)).map(_._2))
val topicHandler2 = new OtherStreamHandler(KafkaUtils.createStream(ssc, zkQuorum, "myGroup", Map("topic2"->1)).map(_._2))
topicHandler1.doActions()
topicHandler2.doActions()
Is there a better approach?
The transformations declared on the StreamHandler will be applied to each batch of the DStream. The current code is quite incomplete to give you a certain answer. In the DStream transformation pipeline you will need an action that materializes the DStream, otherwise nothing will happen.
Regarding the approach, a function that takes a DStream and applies transformations to it would be sufficient and easy to test:
val pipeline: DStream[Data] => Unit = dstream =>
  dstream.print() // placeholder output operation; the original transformations are not shown
As it stands, it doesn't look like the class construction is buying much.
