How to broadcast a DataFrame? - apache-spark

I am using spark-sql 2.4.1.
I am creating a broadcast variable as below:
Broadcast<Map<String, Dataset>> bcVariable = javaSparkContext.broadcast(/* read dataset */);
I am passing the bcVariable to a function:
Service.calculateFunction(sparkSession, bcVariable.getValue());
public static class Service {
    public static void calculateFunction(
            SparkSession sparkSession,
            Map<String, Dataset> dataSet) {
        System.out.println("---> size : " + dataSet.size()); // prints size 1
        for (Entry<String, Dataset> aEntry : dataSet.entrySet()) {
            System.out.println(aEntry.getKey());  // prints the key
            aEntry.getValue().show();             // throws NullPointerException
        }
    }
}
What is wrong here? How do I pass a Dataset/DataFrame to the function?
Try 2:
Broadcast<Dataset> bcVariable = javaSparkContext.broadcast(/* read dataset */);
I am passing the bcVariable to a function:
Service.calculateFunction(sparkSession, bcVariable.getValue());
public static class Service {
    public static void calculateFunction(
            SparkSession sparkSession,
            Dataset dataSet) {
        System.out.println("---> size : " + dataSet.count()); // throws NullPointerException
    }
}
What is wrong here? How do I pass a Dataset/DataFrame to the function?
Try 3:
Dataset metaData = // read dataset from an Oracle table, i.e. meta-data
I am passing the metaData to a function:
Service.calculateFunction(sparkSession, metaData);
public static class Service {
    public static void calculateFunction(
            SparkSession sparkSession,
            Dataset metaData) {
        System.out.println("---> size : " + metaData.count()); // throws NullPointerException
    }
}
What is wrong here? How do I pass a Dataset/DataFrame to the function?

The value to be broadcast has to be a regular local object, not a DataFrame.
Service.calculateFunction(sparkSession, metaData) is executed on executors and hence metaData is null (as it was not serialized and sent over the wire from the driver to executors).
broadcast[T](value: T): Broadcast[T]
Broadcast a read-only variable to the cluster, returning a org.apache.spark.broadcast.Broadcast object for reading it in distributed functions. The variable will be sent to each cluster only once.
Think of the DataFrame abstraction as representing a distributed computation that is described in a SQL-like language (the Dataset API or SQL). It simply does not make sense to have it anywhere but on the driver, where computations can be submitted for execution (as tasks on executors).
You have to "convert" the data this computation represents into a local collection first, using DataFrame.collect (collectAsList in the Java API).
Once you have collected the data you can broadcast it and reference it via the .value method (getValue in Java).
The code could look as follows:
Dataset<Row> dataset = // read data
Broadcast<List<Row>> bcVariable =
    javaSparkContext.broadcast(dataset.collectAsList());
Service.calculateFunction(sparkSession, bcVariable.getValue());
The only change compared to your code is the collect step (collectAsList in the Java API) before broadcasting.
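For completeness, the receiving function then works with the collected rows rather than a Dataset. A minimal Java sketch (the List<Row> parameter and the printing logic are illustrative assumptions, not the original code):
import java.util.List;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public static class Service {
    public static void calculateFunction(SparkSession sparkSession, List<Row> metaRows) {
        System.out.println("---> size : " + metaRows.size()); // a plain local List, so no NPE
        for (Row row : metaRows) {
            System.out.println(row); // each Row is an ordinary, serializable value
        }
    }
}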

Related

Why does Spark Dataset.map require all parts of the query to be serializable?

I would like to use the Dataset.map function to transform the rows of my dataset. The sample looks like this:
val result = testRepository.readTable(db, tableName)
.map(testInstance.doSomeOperation)
.count()
where testInstance is an instance of a class that extends java.io.Serializable, but testRepository does not. The code throws the following error:
Job aborted due to stage failure.
Caused by: NotSerializableException: TestRepository
Question
I understand why testInstance.doSomeOperation needs to be serializable, since it's inside the map and will be distributed to the Spark workers. But why does testRepository need to be serialized? I don't see why that is necessary for the map. Changing the definition to class TestRepository extends java.io.Serializable solves the issue, but that is not desirable in the larger context of the project.
Is there a way to make this work without making TestRepository serializable, or why is it required to be serializable?
Minimal working example
Here's a full example with the code from both classes that reproduces the NotSerializableException:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
case class MyTableSchema(id: String, key: String, value: Double)
val db = "temp_autodelete"
val tableName = "serialization_test"
class TestRepository {
def readTable(database: String, tableName: String): Dataset[MyTableSchema] = {
spark.table(f"$database.$tableName")
.as[MyTableSchema]
}
}
val testRepository = new TestRepository()
class TestClass() extends java.io.Serializable {
def doSomeOperation(row: MyTableSchema): MyTableSchema = {
row
}
}
val testInstance = new TestClass()
val result = testRepository.readTable(db, tableName)
.map(testInstance.doSomeOperation)
.count()
The reason is that your map operation reads from something that already runs on the executors.
If you look at your pipeline:
val result = testRepository.readTable(db, tableName)
.map(testInstance.doSomeOperation)
.count()
The first thing you do is call testRepository.readTable(db, tableName). If we look inside the readTable method, we see that it performs a spark.table operation. The API docs give the following signature for that method:
def table(tableName: String): DataFrame
This is not an operation that takes place solely on the driver (imagine reading a file of more than 1 TB on the driver alone), and it creates a DataFrame, which is itself a distributed dataset. That means the testRepository.readTable(db, tableName) call needs to be distributed, and so your testRepository object needs to be serializable and distributed as well.
Hope this helps you!

How to get the taskID or mapperID (something like the partitionID in Spark) in a Hive UDF?

As the title says: how can I get the taskID or mapperID (something like the partitionID in Spark) in a Hive UDF?
You can access task information using TaskContext:
import org.apache.spark.TaskContext
sc.parallelize(Seq[Int](), 4).mapPartitions { _ =>
  val ctx = TaskContext.get
  val stageId = ctx.stageId
  val partId = ctx.partitionId
  val hostname = java.net.InetAddress.getLocalHost().getHostName()
  Iterator(s"Stage: $stageId, Partition: $partId, Host: $hostname")
}.collect.foreach(println)
A similar functionality has been added to PySpark in Spark 2.2.0 (SPARK-18576):
from pyspark import TaskContext
import socket
def task_info(*_):
    ctx = TaskContext()
    return ["Stage: {0}, Partition: {1}, Host: {2}".format(
        ctx.stageId(), ctx.partitionId(), socket.gethostname())]

for x in sc.parallelize([], 4).mapPartitions(task_info).collect():
    print(x)
I think this will give you the task information, including the map ID, that you are looking for.
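For a plain Java job (outside a UDF), a roughly equivalent sketch using the same TaskContext API; the JavaSparkContext variable jsc is an assumption, not from the original post:
import java.util.Collections;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaSparkContext;

JavaSparkContext jsc = ...; // assumed existing context
// run one task per partition and report where it executed (illustrative only)
jsc.parallelize(Collections.<Integer>emptyList(), 4)
   .mapPartitions(it -> {
       TaskContext ctx = TaskContext.get();
       String host = java.net.InetAddress.getLocalHost().getHostName();
       return Collections.singletonList(
           "Stage: " + ctx.stageId() + ", Partition: " + ctx.partitionId() + ", Host: " + host).iterator();
   })
   .collect()
   .forEach(System.out::println);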
I have found the answer on my own; we can get the taskID in a Hive UDF as below:
public class TestUDF extends GenericUDF {
    private Text result = new Text();
    private String tmpStr = "";

    @Override
    public void configure(MapredContext context) {
        // get the total number of map tasks
        int numTasks = context.getJobConf().getNumMapTasks();
        // get the current taskID
        String taskID = context.getJobConf().get("mapred.task.id");
        this.tmpStr = numTasks + "_h_xXx_h_" + taskID;
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments)
            throws UDFArgumentException {
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) {
        result.set(this.tmpStr);
        return this.result;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "RowSeq-func()";
    }
}
But this is effective only with the MapReduce execution engine; it does not work with the Spark SQL engine.
Test code as below:
add jar hdfs:///home/dp/tmp/shaw/my_udf.jar;
create temporary function seqx AS 'com.udf.TestUDF';

with core as (
    select device_id
    from test_table
    where p_date = '20210309'
      and product = 'google'
    distribute by device_id
)
select
    seqx() as seqs,
    count(1) as cc
from core
group by seqx()
order by seqs asc
Running on the MR engine, the result shows that we get the task count and taskID successfully.
Running the same SQL on the Spark engine, the UDF is not effective and we get nothing for the taskID.
If you run your HQL on the Spark engine, call the Hive UDF there, and really need to get the partitionId in Spark, see the code below:
import org.apache.spark.TaskContext;

public class TestUDF extends GenericUDF {
    private Text result = new Text();
    private String tmpStr = "";

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments)
            throws UDFArgumentException {
        // get the Spark partitionId
        this.tmpStr = TaskContext.getPartitionId() + "-initial-pid";
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) {
        // get the Spark partitionId
        this.tmpStr = TaskContext.getPartitionId() + "-evaluate-pid";
        result.set(this.tmpStr);
        return this.result;
    }
}
As above, you can get the Spark partitionId by calling TaskContext.getPartitionId() in the overridden initialize or evaluate methods of the UDF class.
Note: your UDF must take parameters, e.g. select my_udf(param); this causes the UDF to be initialized in multiple tasks. If your UDF takes no parameters, it is initialized on the driver, and the driver has no TaskContext or partitionId, so you would get nothing.
The UDF above, executed on the Spark engine, produces the partitionIds successfully.

Java Spark Dataset MapFunction - Task not serializable without any reference to class

I have the following class that reads CSV data into a Spark Dataset. Everything works fine if I simply read and return the data.
However, if I apply a MapFunction to the data before returning from the function, I get
Exception in thread "main" org.apache.spark.SparkException: Task not serializable
Caused by: java.io.NotSerializableException: com.Workflow
I know how Spark works and that it needs to serialize objects for distributed processing; however, I'm NOT using any reference to the Workflow class in my mapping logic, and I'm not calling any Workflow class function there either. So why is Spark trying to serialize the Workflow class? Any help will be appreciated.
public class Workflow {
    private final SparkSession spark;
    private final String dataPath; // path of the CSV input (used below)

    public Dataset<Row> readData() {
        final StructType schema = new StructType()
            .add("text", "string", false)
            .add("category", "string", false);

        Dataset<Row> data = spark.read()
            .schema(schema)
            .csv(dataPath);

        /*
         * works fine up to here if I just call
         * return data;
         */

        Dataset<Row> cleanedData = data.map(new MapFunction<Row, Row>() {
            public Row call(Row row) {
                /* some mapping logic */
                return row;
            }
        }, RowEncoder.apply(schema));

        cleanedData.printSchema();
        /* .... ERROR .... */
        cleanedData.show();

        return cleanedData;
    }
}
Anonymous inner classes hold a hidden/implicit reference to their enclosing class. Use a lambda expression instead, or go with Roma Anankin's solution.
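A minimal sketch of the lambda approach (assuming the same data and schema variables as in the question); a lambda only captures what it actually references, so the enclosing Workflow instance is not dragged into serialization:
Dataset<Row> cleanedData = data.map(
    (MapFunction<Row, Row>) row -> {
        /* some mapping logic */
        return row;
    },
    RowEncoder.apply(schema));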
You could also make Workflow implement Serializable and mark the SparkSession field as transient.
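A sketch of that second suggestion (assuming the rest of the class stays as in the question); transient keeps the SparkSession out of the serialized object, since it is only needed on the driver:
import java.io.Serializable;
import org.apache.spark.sql.SparkSession;

public class Workflow implements Serializable {
    // transient: not shipped to executors; the session is only used on the driver
    private transient SparkSession spark;
    private String dataPath;

    public Workflow(SparkSession spark, String dataPath) {
        this.spark = spark;
        this.dataPath = dataPath;
    }
    // readData() stays as in the question
}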

Spark closures behavior

I have a basic question about Spark closures. I cannot distinguish the code behavior between scenarios 2 and 3: both produce the same output, but based on my understanding scenario 3 should not work as expected.
The code below is common to all scenarios:
class A implements Serializable {
    String t;
    A(String t) {
        this.t = t;
    }
}

// Initialize the Spark context
JavaSparkContext context = ....
// Create the RDD
JavaRDD<String> rdd = context.parallelize(Arrays.asList("a", "b", "c", "d", "e"), 3);
Scenario 1: don't do this, because A is initialized on the driver and is not visible on the executors.
A a = new A("pqr");
rdd.map(i -> i + a.t).collect();
Scenario 2: the recommended way of sharing an object.
Broadcast<A> broadCast = context.broadcast(new A("pqr"));
rdd.map(i -> broadCast.getValue().t + i).collect();
// output: [pqra, pqrb, pqrc, pqrd, pqre]
Scenario 3: why does this code work as expected even though I instantiate A on the driver?
class TestFunction implements Function<String, String>, Serializable {
    private A val;

    public TestFunction() { }

    public TestFunction(A a) {
        this.val = a;
    }

    @Override
    public String call(String integer) throws Exception {
        return val.t + integer;
    }
}
TestFunction mapFunction = new TestFunction(new A("pqr"));
System.out.println(rdd.map(mapFunction).collect());
//output: [pqra, pqrb, pqrc, pqrd, pqre]
Note: I am running the program in cluster mode.
The generated Java bytecode for scenarios 1 and 3 is almost the same. The benefit of using a broadcast variable (scenario 2) is that the broadcast object is sent to each executor only once and reused by other tasks on that executor. Scenarios 1 and 3 send the object A to the executors with every task.
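If you want the function style of scenario 3 together with the single transfer of scenario 2, one option (a sketch, not from the original post) is to hold the Broadcast handle inside the function, so only the lightweight handle is serialized with each task:
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;

class BroadcastFunction implements Function<String, String> {
    private final Broadcast<A> bc; // only the broadcast handle travels with the task

    BroadcastFunction(Broadcast<A> bc) {
        this.bc = bc;
    }

    @Override
    public String call(String s) throws Exception {
        return bc.getValue().t + s;
    }
}

Broadcast<A> broadCast = context.broadcast(new A("pqr"));
System.out.println(rdd.map(new BroadcastFunction(broadCast)).collect());
// expected output: [pqra, pqrb, pqrc, pqrd, pqre]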

How do Spark tasks in the same executor share variables (NumberFormatException with SimpleDateFormat)?

The Spark docs say:
By default, when Spark runs a function in parallel as a set of tasks on different nodes, it ships a copy of each variable used in the function to each task.
If I create a Java SimpleDateFormat and use it in RDD operations, I get the exception NumberFormatException: multiple points.
I know SimpleDateFormat is not thread-safe. But as the docs say, this SimpleDateFormat object is copied to each task, so there should not be multiple threads accessing it.
I speculate that all tasks in one executor share the same SimpleDateFormat object. Am I right?
This program prints the same object, java.text.SimpleDateFormat@f82ede60:
object NormalVariable {
  // create dateFormat here doesn't change
  // val dateFormat = new SimpleDateFormat("yyyy.MM.dd")

  def main(args: Array[String]) {
    val dateFormat = new SimpleDateFormat("yyyy.MM.dd")
    val conf = new SparkConf().setAppName("Spark Test").setMaster("local[*]")
    val spark = new SparkContext(conf)
    val dates = Array[String]("1999.09.09", "2000.09.09", "2001.09.09", "2002.09.09", "2003.09.09")
    println(dateFormat)
    val resultes = spark.parallelize(dates).map { i =>
      println(dateFormat)
      dateFormat.parse(i)
    }.collect()
    println(resultes.mkString(" "))
    spark.stop()
  }
}
As you know, SimpleDateFormat is not thread-safe.
If Spark uses a single core per executor (--executor-cores 1), everything should work fine. But as soon as you configure more than one core per executor, your code runs multi-threaded, the SimpleDateFormat is shared by multiple Spark tasks concurrently, and it is likely to corrupt the data and throw various exceptions.
To fix this, you can use one of the same approaches as for non-Spark code, namely ThreadLocal, which ensures you get one copy of the SimpleDateFormat per thread.
In Java, this looks like:
public class DateFormatTest {
    private static final ThreadLocal<DateFormat> df = new ThreadLocal<DateFormat>() {
        @Override
        protected DateFormat initialValue() {
            return new SimpleDateFormat("yyyyMMdd");
        }
    };

    public Date convert(String source) throws ParseException {
        Date d = df.get().parse(source);
        return d;
    }
}
and the equivalent code in Scala works just the same - shown here as a spark-shell session:
import java.text.SimpleDateFormat
object SafeFormat extends ThreadLocal[SimpleDateFormat] {
  override def initialValue = {
    new SimpleDateFormat("yyyyMMdd HHmmss")
  }
}
sc.parallelize(Seq("20180319 162058")).map(SafeFormat.get.parse(_)).collect
res6: Array[java.util.Date] = Array(Mon Mar 19 16:20:58 GMT 2018)
So you would define the ThreadLocal at the top level of your job class or object, then call df.get to obtain the SimpleDateFormat within your RDD operations.
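In Java, the same usage inside an RDD operation might look like this sketch (assuming the DateFormatTest class above and an existing JavaSparkContext jsc; the input dates are made up):
import java.util.Arrays;
import java.util.Date;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

JavaSparkContext jsc = ...; // assumed existing context
JavaRDD<Date> parsed = jsc.parallelize(Arrays.asList("20180319", "20180320"))
    // the static ThreadLocal hands each executor thread its own SimpleDateFormat,
    // so concurrent tasks on the same executor no longer corrupt each other's parsing
    .map(s -> new DateFormatTest().convert(s));
System.out.println(parsed.collect());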
See:
http://fahdshariff.blogspot.co.uk/2010/08/dateformat-with-multiple-threads.html
"Java DateFormat is not threadsafe" what does this leads to?
