Apache Spark: pass Column as Transformer parameter

I defined a pipeline Transformer like this:
class MyTransformer(condition: Column) extends SparkTransformer {
  override def transform(dataset: Dataset[_]): DataFrame = {...}
}
which is then used in a pipeline:
val pipeline = new Pipeline()
pipeline.setStages(Array(new MyTransformer(col("test") === lit("value"))))
In my transformer, I want to apply a transformation only to rows that satisfy the condition.
It results in a serialization issue:
Serialization stack:
- object not serializable (class: org.apache.spark.sql.Column, value: (test = value))
- field (class: my.project.MyTransformer, name: condition, type: class org.apache.spark.sql.Column)
- ...
As I understand it, Transformers are serialized to be dispatched to executors, so every parameter must be serializable.
How can I bypass it? Is there a workaround?

This question seems a bit old...
I don't know if my (untested) idea matches your needs.
One solution could be to pass the condition as a SQL expression (a String instance):
val pipeline = new Pipeline()
pipeline.setStages(Array(new MyTransformer("test = 'value'")))
and then use functions.expr() inside the Transformer.transform method to convert the expression String into a Column instance.
This way, the only thing the transformer holds is a serializable String, and the non-serializable Column is created only when it is needed, as a local value inside transform().
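You can check why this helps without Spark. The following is a minimal Java sketch, not Spark code. `FakeColumn` is a hypothetical stand-in for the non-serializable `org.apache.spark.sql.Column`, and both stage classes are hypothetical as well. Java serialization fails for the class that stores the column as a field. It succeeds for the class that stores only the SQL string and builds the column on demand.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical stand-in for a non-serializable object such as org.apache.spark.sql.Column.
class FakeColumn {
    final String sql;
    FakeColumn(String sql) { this.sql = sql; }
}

// Stores the condition as a field whose type is not Serializable.
class ColumnHoldingStage implements Serializable {
    private final FakeColumn condition;
    ColumnHoldingStage(FakeColumn condition) { this.condition = condition; }
}

// Stores only the SQL text and builds the "column" on demand, as a local value.
class StringHoldingStage implements Serializable {
    private final String condition;
    StringHoldingStage(String condition) { this.condition = condition; }
    FakeColumn buildCondition() { return new FakeColumn(condition); }
}

class SerializationCheck {
    // Returns true if Java serialization of obj succeeds, false if some field is not serializable.
    static boolean canSerialize(Object obj) {
        try (ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
            out.writeObject(obj);
            return true;
        } catch (NotSerializableException e) {
            return false;
        } catch (IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }
}
```

In the real Scala transformer, buildCondition() corresponds to calling functions.expr(condition) inside transform().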


How to create an update statement where a UDT value need to be updated using QueryBuilder

I have the following UDT type:
CREATE TYPE tag_partitions(
year bigint,
month bigint);
and the following table
CREATE TABLE ${tableName} (
tag text,
partition_info set<FROZEN<tag_partitions>>,
The table schema is mapped to the following model:
case class TagPartitionsInfo(year:Long, month:Long)
case class TagPartitions(tag:String, partition_info:Set[TagPartitionsInfo])
I have written a function that should create an Update.IfExists query, but I don't know how to update the UDT value. I tried using set, but it isn't working.
def updateValues(tableName:String, model:TagPartitions, id:TagPartitionKeys):Update.IfExists = {
val partitionInfoType:UserType = session.getCluster().getMetadata
//create value
//the logic below assumes that there is only one element in the set
val partitionsInfoSet:Set[UDTValue] = model.partition_info.map((partitionInfo:TagPartitionsInfo) =>{
println("partition info converted to UDTValue: "+partitionsInfoSet)
.where(QueryBuilder.eq("tag", id.tag)).ifExists()
The mistake was that I was passing partitionsInfoSet to the query, but it is a Scala Set. I needed to convert it into a Java Set using setAsJavaSet.
.where(QueryBuilder.eq("tag", id.tag))
Although this doesn't answer your exact question, wouldn't it be easier to use the Object Mapper for this? Something like this (I haven't adapted it much to match your code):
import com.datastax.driver.mapping.annotations.{Field, PartitionKey, Table, UDT}
import scala.annotation.meta.field
@UDT(name = "scala_udt")
case class UdtCaseClass(id: Integer, @(Field @field)(name = "t") text: String) {
  def this() {
    this(0, "")
  }
}
@Table(name = "scala_test_udt")
case class TableObjectCaseClassWithUDT(@(PartitionKey @field) id: Integer,
                                       udts: java.util.Set[UdtCaseClass]) {
  def this() {
    this(0, new java.util.HashSet[UdtCaseClass]())
  }
}
Then just create an instance of the case class and call mapper.save on it. (Also note that you need to use Java collections unless you have imported Scala codecs.)
The main reasons to use the Object Mapper are ease of use and better performance, because it uses prepared statements under the hood instead of built statements, which are much less efficient.
You can find more information about the Object Mapper with Scala in an article I wrote recently.

Need help in filtering records according to set of rules with Apache Spark

I need help with a use case I have run into: filtering records against a set of rules with Apache Spark.
The actual data has too many fields to show, so for simplicity here is an example in JSON format:
records: [{
  "recordId": 1,
  "messages": [{"name": "Tom", "city": "Mumbai"},
               {"name": "Jhon", "address": "Chicago"}, ...]
}, ...]
rules: [{
  "ruleId": 1,
  "ruleName": "rule1",
  "criterias": {
    "name": "xyz",
    "address": "Chicago, Boston"
  }
}, ...]
I want to match all records against all rules. Here is the pseudocode:
var matchedRecords = []
for(record <- records)
  for(rule <- rules)
    for(message <- record.messages)
      if(!isMatch(message, rule.criterias))
        break
    if(allMessagesMatched) // if the loop completed without a break
      matchedRecords.put((record.recordId, rule.ruleId))
def isMatch(message, criteria) =
  for(each field in criteria)
    if(field.value contains comma)
      if(!(message.field containsAny field.value))
        return false
    else if(!(message.field equals field.value)) // value doesn't contain a comma
      return false
  return true // if the loop completed, all criteria matched
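A plain Java sketch of the pseudocode above, without Spark, can help pin down the matching semantics before you distribute the work. Several assumptions apply:

- Messages and criteria are modelled as `Map<String, String>`; the real records have many more fields.
- `containsAny` is read as "the message field contains at least one of the comma-separated values".
- A message that lacks a criterion field counts as non-matching.
- `RuleMatcher` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

class RuleMatcher {
    // Mirrors isMatch(): every field listed in the criteria must match the message.
    static boolean isMatch(Map<String, String> message, Map<String, String> criteria) {
        for (Map.Entry<String, String> criterion : criteria.entrySet()) {
            String actual = message.get(criterion.getKey());
            if (actual == null) {
                return false; // assumption: a missing field never matches
            }
            String expected = criterion.getValue();
            if (expected.contains(",")) {
                // Comma-separated value: the field must contain at least one alternative.
                boolean anyMatch = Arrays.stream(expected.split(","))
                        .map(String::trim)
                        .filter(s -> !s.isEmpty())
                        .anyMatch(actual::contains);
                if (!anyMatch) {
                    return false;
                }
            } else if (!actual.equals(expected)) {
                return false; // single value: exact match required
            }
        }
        return true; // loop completed, so all criteria matched
    }

    // Mirrors the inner message loop: a record matches a rule only if every message matches.
    static boolean allMessagesMatch(List<Map<String, String>> messages, Map<String, String> criteria) {
        return messages.stream().allMatch(m -> isMatch(m, criteria));
    }
}
```

Each (record, rule) check is independent of the others. With the rules broadcast, you could call the same logic inside an RDD or Dataset filter.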
There are thousands of records containing thousands of messages, and there are hundreds of such rules.
What approaches are there for this kind of problem? Would a specific module help (Spark SQL, Spark MLlib, Spark GraphX)? Do I need a third-party library?
Approach 1:
Have a List[Rules] and an RDD[Records].
Broadcast the List[Rules], since there are only a few of them.
Match each record against all the rules.
In this case, however, matching each message against the criteria is still not parallelized.
I think your suggested approach is a good direction. If I had to solve this task, I would start by implementing a generic trait with a method responsible for matching:
trait FilterRule extends Serializable {
  def matches(record: Record): Boolean // "match" is a reserved word in Scala
}
Then I would implement specific filters, e.g.:
case class EqualsRule(...) extends FilterRule
case class RegexRule(...) extends FilterRule
Then I would implement composite filters, e.g.:
case class AndRule(...) extends FilterRule
case class OrRule(...) extends FilterRule
Then you can filter your RDD or Dataset with:
// constructing the rule - in reality: read JSON from configuration, parse it and create the FilterRule object
val rule = AndRule(EqualsRule(...), EqualsRule(...), ...)
// applying the rule
rdd.filter(record => rule.matches(record))
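The same composite-rule idea can be sketched in Java. This is an illustration under assumptions, not the answerer's code:

- A message is a `Map<String, String>` rather than the `Record` type.
- The constructor arguments of `EqualsRule`, `RegexRule`, `AndRule` and `OrRule` are hypothetical, since the answer leaves them as `...`.

`FilterRule` extends `Serializable` because Spark has to ship the rule object to the executors inside the filter closure.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

// Generic rule over a single message (field name -> value).
interface FilterRule extends java.io.Serializable {
    boolean matches(Map<String, String> message);
}

// Leaf rule: a field equals a fixed value.
class EqualsRule implements FilterRule {
    private final String field;
    private final String value;
    EqualsRule(String field, String value) { this.field = field; this.value = value; }
    @Override public boolean matches(Map<String, String> m) { return value.equals(m.get(field)); }
}

// Leaf rule: the whole field value matches a regular expression.
class RegexRule implements FilterRule {
    private final String field;
    private final Pattern pattern;
    RegexRule(String field, String regex) { this.field = field; this.pattern = Pattern.compile(regex); }
    @Override public boolean matches(Map<String, String> m) {
        String v = m.get(field);
        return v != null && pattern.matcher(v).matches();
    }
}

// Composite rule: all children must match.
class AndRule implements FilterRule {
    private final List<FilterRule> rules;
    AndRule(FilterRule... rules) { this.rules = Arrays.asList(rules); }
    @Override public boolean matches(Map<String, String> m) { return rules.stream().allMatch(r -> r.matches(m)); }
}

// Composite rule: at least one child must match.
class OrRule implements FilterRule {
    private final List<FilterRule> rules;
    OrRule(FilterRule... rules) { this.rules = Arrays.asList(rules); }
    @Override public boolean matches(Map<String, String> m) { return rules.stream().anyMatch(r -> r.matches(m)); }
}
```

Because composites take other rules as children, a JSON rule definition can be parsed recursively into a single FilterRule tree. That one object is what gets passed to filter.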
A second option is to use existing Spark SQL functions and DataFrames for filtering, where you can build fairly complex expressions using and, or, and multiple columns. The drawback of this approach is that it is not type-safe, and unit testing would be more complex.

How to customize column mappings with Spark Cassandra Connector in Java?

I wanted to change a column mapping so that values are appended to the collection. Is there a better way to customize column mappings with the Spark Cassandra Connector in Java than the following?
ColumnName song_id = new ColumnName("song_id", Option.empty());
CollectionColumnName key_codes = new ColumnName("key_codes", Option.empty()).append();
List<ColumnRef> collectionColumnNames = Arrays.asList(song_id, key_codes);
scala.collection.Seq<ColumnRef> columnRefSeq = JavaApiHelper.toScalaSeq(collectionColumnNames);
.writerBuilder("demo", "song", mapToRow(PianoSong.class))
.withColumnSelector(new SomeColumns(columnRefSeq))
This is taken from this Spark Streaming code sample.
Just create your column refs directly with CollectionColumnName,
which has this constructor:
case class CollectionColumnName(
columnName: String,
alias: Option[String] = None,
collectionBehavior: CollectionBehavior = CollectionOverwrite) extends ColumnRef
You can rename a column by setting alias, and you can change the insert behavior with collectionBehavior, which takes one of the following objects:
/** Insert behaviors for Collections. */
sealed trait CollectionBehavior
case object CollectionOverwrite extends CollectionBehavior
case object CollectionAppend extends CollectionBehavior
case object CollectionPrepend extends CollectionBehavior
case object CollectionRemove extends CollectionBehavior
This means you can just write:
CollectionColumnName appendColumn =
    new CollectionColumnName("ColumnName", Option.empty(), CollectionAppend$.MODULE$);
This looks a bit more Java-y and is more explicit. Did you have any other goals for this code?

Spark custom estimator including persistence

I want to develop a custom estimator for Spark that also supports the persistence features of the great Pipeline API. But, as "How to Roll a Custom Estimator in PySpark mllib" puts it, there is not a lot of documentation out there (yet).
I have some data cleansing code written in Spark and would like to wrap it in a custom estimator. It includes some NA substitutions, column deletions, filtering and basic feature generation (e.g. birthdate to age).
transformSchema will use the case class of the dataset: ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]
fit will only compute values such as the mean age to use as NA substitutes
What is still pretty unclear to me:
transform in the custom pipeline model will be used to apply the "fitted" Estimator to new data. Is this correct? If so, how should I transfer the fitted values, e.g. the mean age from above, into the model?
How should I handle persistence? I found a generic loadImpl method within private Spark components, but I am unsure how to pass my own parameters, e.g. the mean age, to the MLReader / MLWriter used for serialization.
It would be great if you could help me with a custom estimator - especially with the persistence part.
First of all, I believe you're mixing up two different things a bit:
Estimators - which represent stages that can be fitted. The Estimator fit method takes a Dataset and returns a Transformer (model).
Transformers - which represent stages that can transform data.
When you fit a Pipeline, it fits all Estimators and returns a PipelineModel. A PipelineModel transforms data by sequentially calling transform on all Transformers in the model.
how should I transfer the fitted values
There is no single answer to this question. In general you have two options:
Pass parameters of the fitted model as the arguments of the Transformer.
Make parameters of the fitted model Params of the Transformer.
The first approach is typically used by the built-in Transformers, but the second one should work in some simple cases.
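To make the first option concrete, here is a Spark-free Java sketch of the fit/transform split. The "estimator" computes the mean of the non-null values and passes it to the "model" as a constructor argument. The model then uses that mean to fill nulls in new data. The class names are hypothetical, plain lists stand in for DataFrame columns, and the 0.0 fallback for an all-null column is an arbitrary assumption.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Hypothetical "estimator": learns the mean of the non-null values.
class MeanFillerEstimator {
    MeanFillerModel fit(List<Double> column) {
        double mean = column.stream()
                .filter(Objects::nonNull)
                .mapToDouble(Double::doubleValue)
                .average()
                .orElse(0.0); // assumption: fall back to 0.0 for an all-null column
        // The fitted value is handed to the model as a constructor argument.
        return new MeanFillerModel(mean);
    }
}

// Hypothetical "model": applies the fitted mean to new data.
class MeanFillerModel {
    private final double mean;
    MeanFillerModel(double mean) { this.mean = mean; }
    double getMean() { return mean; }
    List<Double> transform(List<Double> column) {
        List<Double> out = new ArrayList<>(column.size());
        for (Double v : column) {
            out.add(v == null ? mean : v); // replace missing values with the fitted mean
        }
        return out;
    }
}
```

In Spark, a model built this way needs its own MLWriter/MLReader for the constructor data. Storing the mean in a Param instead lets the default params writer and reader handle it.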
how to handle persistence
If Transformer is defined only by its Params you can extend DefaultParamsReadable.
If you use more complex arguments you should extend MLWritable and implement MLWriter that makes sense for your data. There are multiple examples in Spark source which show how to implement data and metadata reading / writing.
If you're looking for an easy-to-comprehend example, take a look at CountVectorizer(Model), where:
Estimator and Transformer share common Params.
Model vocabulary is a constructor argument, model parameters are inherited from the parent.
Metadata (parameters) is written and read using DefaultParamsWriter / DefaultParamsReader.
Custom implementation handles data (vocabulary) writing and reading.
The following uses the Scala API but you can easily refactor it to Python if you really want to...
First things first:
Estimator: implements .fit() that returns a Transformer
Transformer: implements .transform() and manipulates the DataFrame
Serialization/Deserialization: Do your best to use built-in Params and leverage the simple DefaultParamsWritable trait plus a companion object extending DefaultParamsReadable[T]. In other words, stay away from MLReader / MLWriter and keep your code simple.
Parameter passing: Use a common trait extending Params and share it between your Estimator and Model (a.k.a. Transformer).
Skeleton code:
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap, Params, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
// Common Parameters
trait MyCommonParams extends Params {
  final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
    new StringArrayParam(this, "inputCols", "doc...")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
  def getInputCols: Array[String] = $(inputCols)
  final val meanValues: DoubleArrayParam =
    new DoubleArrayParam(this, "meanValues", "doc...")
  // more setters and getters
}
// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {
  override def copy(extra: ParamMap): MyMeanValueStuff = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
    // your logic here. I can't do all the work for you! ;)
    copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]
// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {
  override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def transform(dataset: Dataset[_]): DataFrame = {
    // your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
    // you have access to both inputCols and meanValues here!
    dataset.toDF() // placeholder so the skeleton compiles; return the transformed DataFrame instead
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]
With the code above, you can serialize/deserialize a Pipeline containing a MyMeanValueStuff stage.
Want to look at a really simple, real implementation of an Estimator? MinMaxScaler! (My example is actually simpler, though...)

How to create Spark broadcast variable from Java String array?

I have a Java String array containing 45 strings, which are basically column names:
String[] fieldNames = {"colname1","colname2",...};
Currently I store the above String array in a static field on the Spark driver. My job is running slowly, so I am trying to refactor the code. I use this String array when creating a DataFrame:
DataFrame dfWithColNames = sourceFrame.toDF(fieldNames);
I want to do this using a broadcast variable so that the huge string array isn't shipped to every executor. I believe we can do something like the following to create the broadcast:
String[] brArray = sc.broadcast(fieldNames,String[].class);//gives compilation error
DataFrame df = sourceFrame.toDF(???);//how do I use above broadcast can I use it as is by passing brArray
I am new to Spark.
This is a bit of an old question, but I hope my solution helps somebody.
To broadcast any object (a single POJO or a collection) with Spark 2+, you first need the following method, which creates a ClassTag for you:
import scala.reflect.ClassTag;
private static <T> ClassTag<T> classTag(Class<T> clazz) {
  return scala.reflect.ClassManifestFactory.fromClass(clazz);
}
Next, you use a JavaSparkContext from a SparkSession to broadcast your object as before:
In the case of a collection, say a java.util.List, you use the following:
The return value of sc.broadcast is of type Broadcast<String[]>, not String[]. When you want to access the value, you simply call value() on the variable. For your example, it would look like this:
Broadcast<String[]> broadcastedFieldNames = sc.broadcast(fieldNames);
DataFrame df = sourceFrame.toDF(broadcastedFieldNames.value());
Note that if you are writing this in Java, you probably want to wrap the SparkContext in a JavaSparkContext. It makes everything easier, and you can then avoid having to pass a ClassTag to the broadcast function.
You can read more about broadcast variables at http://spark.apache.org/docs/latest/programming-guide.html#broadcast-variables
ArrayList<String> dataToBroadcast = new ArrayList<>();
dataToBroadcast.add("string1");
dataToBroadcast.add("stringn");
//Creating the broadcast variable
// No need to write the classTag code by hand: use akka.japi.Util, which is available
Broadcast<ArrayList<String>> strngBrdCast = spark.sparkContext().broadcast(
// Here is the catch: when you iterate over a Dataset,
// Spark actually runs the code in distributed mode on the executors.
// A driver-side object that is not shipped along (for example, one kept
// in a static field that is only populated on the driver) will be null there,
// because you didn't ask Spark to explicitly send that outside variable
// to each machine running this code in parallel.
// So you need to use a broadcast variable (its most common use).
someSparkDataSetWhere.foreach((row) -> {
ArrayList<String> stringlist = strngBrdCast.value();
