Why does Spark run 5 jobs for a simple aggregation? - apache-spark

I use Spark in local mode from an IDE (Eclipse).
I can see the Spark UI creating many jobs for a simple aggregation. Why?
import org.apache.spark.sql.SparkSession

trait SparkSessionWrapper {
  lazy val spark: SparkSession = {
    SparkSession
      .builder()
      .master("local[2]")
      .appName("Spark Me")
      .getOrCreate()
  }

  spark.sparkContext.setLogLevel("WARN")
}
The Spark application is as follows:
object RowNumberCalc
  extends App
  with SparkSessionWrapper {

  import spark.implicits._

  val cityDf = Seq(
    ("London", "Harish", 5500, "2019-10-01"),
    ("NYC", "RAJA", 11121, "2019-10-01"),
    ("SFO", "BABU", 77000, "2019-10-01"),
    ("London", "Rick", 7500, "2019-09-01"),
    ("NYC", "Jenna", 6511, "2019-09-01"),
    ("SFO", "Richard", 234567, "2019-09-01"),
    ("London", "Harish", 999999, "2019-08-01"),
    ("NYC", "Sam", 1234, "2019-08-01"),
    ("SFO", "Dylan", 45678, "2019-08-01")).toDF("city", "name", "money", "month_id")

  cityDf.createOrReplaceTempView("city_table")

  val totalMoneySql =
    """
      |select city, sum(money) from city_table group by 1 """.stripMargin

  spark.sql(totalMoneySql).show(false)

  System.in.read // keep the app alive so the Spark UI can be inspected
  spark.stop()
}
As shown, this is a simple calculation of the sum of money for each city.
Yet the Spark UI shows 5 jobs, each with 2 stages, and the SQL tab also shows 5 jobs.
The physical plan, however, shows the expected stage division:
== Physical Plan ==
CollectLimit 21
+- *(2) LocalLimit 21
   +- *(2) HashAggregate(keys=[city#9], functions=[sum(cast(money#11 as bigint))], output=[city#9, sum(money)#24])
      +- Exchange hashpartitioning(city#9, 200)
         +- *(1) HashAggregate(keys=[city#9], functions=[partial_sum(cast(money#11 as bigint))], output=[city#9, sum#29L])
            +- LocalTableScan [city#9, money#11]
From where/how are these 5 jobs being triggered?

tl;dr You've got very few rows to work with (9 input rows and 3 aggregated rows) spread across the default 200 shuffle partitions, so Spark runs 5 jobs to satisfy Dataset.show's requirement to display up to 20 rows.
In other words, what you experience is specific to Dataset.show (which, by the way, is not meant for large datasets anyway).
By default Dataset.show displays 20 rows. It starts by scanning 1 partition and taking up to 20 rows. If there are not enough rows, it multiplies the number of partitions to scan by 4 (if I'm not mistaken) and scans the next partitions to find the missing rows. That repeats until 20 rows are collected or all partitions have been scanned.
The number of output rows of the last HashAggregate is 3.
Depending on which partitions those 3 rows land in, Spark could run one, two or more jobs. That strongly depends on the hash of the rows (per HashPartitioner).
If you really want a single Spark job for this number of rows (9 on input), start the Spark application with the spark.sql.shuffle.partitions configuration property set to 1.
That makes the computation use 1 partition after the aggregation, with all result rows in one partition.
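For example, a minimal sketch of that change, reusing the builder from the question (it could equally be set at runtime with spark.conf.set before running the query):

lazy val spark: SparkSession = {
  SparkSession
    .builder()
    .master("local[2]")
    .appName("Spark Me")
    .config("spark.sql.shuffle.partitions", "1") // a single shuffle partition, so show() needs only one job
    .getOrCreate()
}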

Related

Pushing down aggregation to Cassandra when querying from Spark

I've got a Cassandra table looking like this:
CREATE TABLE my_keyspace.my_table (
    part_key_col_1 text,
    clust_key_col_1 int,
    clust_key_col_2 text,
    value_col_1 text,
    PRIMARY KEY (part_key_col_1, clust_key_col_1, clust_key_col_2, value_col_1)
);
I'm looking to retrieve the maximum value of clust_key_col_1 for each part_key_col_1, where I also want a filter on clust_key_col_1. In CQL I can achieve this using:
SELECT
    part_key_col_1,
    max(clust_key_col_1)
FROM my_table
WHERE clust_key_col_1 < 123
GROUP BY part_key_col_1
ALLOW FILTERING;
Even though I need to use ALLOW FILTERING, the query is super fast: I have roughly 1 000 000 unique part_key_col_1 values, and for each part_key_col_1 fewer than 5000 unique clust_key_col_1 values.
My problem comes when I try to get the same data in Spark using Spark Cassandra Connector. I've tried the following in Spark:
from pyspark.sql import functions as f

cassandra_df = (
    spark.read
    .format("org.apache.spark.sql.cassandra")
    .options(table='my_table', keyspace='my_keyspace')
    .load()
    .filter(f.col('clust_key_col_1') < 123)
    .groupBy(f.col('part_key_col_1'))
    .agg(
        f.max('clust_key_col_1')
    )
)
But the physical plan ends up being:
== Physical Plan ==
*(2) HashAggregate(keys=[part_key_col_1#144], functions=[max(clust_key_col_1#145)])
+- Exchange hashpartitioning(part_key_col_1#144, 20)
+- *(1) HashAggregate(keys=[part_key_col_1#144], functions=[partial_max(clust_key_col_1#145)])
+- *(1) Scan org.apache.spark.sql.cassandra.CassandraSourceRelation [part_key_col_1#144,clust_key_col_1#145] PushedFilters: [*LessThan(clust_key_col_1,123)], ReadSchema: struct<part_key_col_1:string,clust_key_col_1:int>
Meaning the filter on clust_key_col_1 gets pushed down to Cassandra, but the grouping and the aggregation do not. Instead, all the data (with clust_key_col_1 < 123) gets loaded into Spark and aggregated there. Can I somehow "push down" the grouping/aggregation to Cassandra and only load max(clust_key_col_1) for each part_key_col_1, to reduce the load on Spark and on the network? Right now Spark will load 1 000 000 * 5000 rows instead of 1 000 000 rows.
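Not a connector feature, but one hedged workaround sketch (in Scala, assuming a SparkSession named spark and that the roughly 1 000 000 aggregated rows fit on the driver): run the aggregating CQL above directly through the connector's CassandraConnector and bring back only the pre-aggregated pairs instead of the raw table.

import com.datastax.spark.connector.cql.CassandraConnector
import scala.collection.JavaConverters._
import spark.implicits._

// Run the same ALLOW FILTERING / GROUP BY query in Cassandra itself and
// collect only the already-aggregated (part_key_col_1, max) pairs.
val connector = CassandraConnector(spark.sparkContext.getConf)
val maxPerKey = connector.withSessionDo { session =>
  session.execute(
    """SELECT part_key_col_1, max(clust_key_col_1) AS max_clust
      |FROM my_keyspace.my_table
      |WHERE clust_key_col_1 < 123
      |GROUP BY part_key_col_1
      |ALLOW FILTERING""".stripMargin)
    .all().asScala
    .map(row => (row.getString("part_key_col_1"), row.getInt("max_clust")))
    .toSeq
}
// A small dataframe with one row per part_key_col_1 instead of the raw table
val resultDf = maxPerKey.toDF("part_key_col_1", "max_clust_key_col_1")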

Apache Spark: broadcast join behaviour: filtering of joined tables and temp tables

I need to join 2 tables in spark.
But instead of joining the 2 tables completely, I first filter out a part of the second table:
spark.sql("select * from a join b on a.key=b.key where b.value='xxx' ")
I want to use broadcast join in this case.
Spark has a parameter which defines the max table size for a broadcast join, spark.sql.autoBroadcastJoinThreshold:
"Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command ANALYZE TABLE COMPUTE STATISTICS noscan has been run."
http://spark.apache.org/docs/2.4.0/sql-performance-tuning.html
I have the following questions about this setup:
Which table size will Spark compare with autoBroadcastJoinThreshold's value: the FULL size, or the size AFTER applying the where clause?
I am assuming that Spark will apply the where clause BEFORE broadcasting, correct?
The doc says I need to run Hive's ANALYZE TABLE command beforehand. How will that work when I am using a temp view as a table? As far as I understand, I cannot run ANALYZE TABLE against a Spark temp view created via dataFrame.createOrReplaceTempView("b"). Can I broadcast temp view contents?
Your understanding for question 2 is correct.
You cannot analyze a TEMP table in Spark. Read here.
In case you want to take the lead and specify yourself which dataframe to broadcast, instead of letting Spark decide, you can use the snippet below:
df = df1.join(F.broadcast(df2),df1.some_col == df2.some_col, "left")
I went ahead and did some small experiments to answer your 1st question.
Question 1:
Created a dataframe a with 3 rows: [key, df_a_column]
Created a dataframe b with 10 rows: [key, value]
Ran: spark.sql("SELECT * FROM a JOIN b ON a.key = b.key").explain()
== Physical Plan ==
*(1) BroadcastHashJoin [key#122], [key#111], Inner, BuildLeft, false
:- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#168]
: +- LocalTableScan [key#122, df_a_column#123]
+- *(1) LocalTableScan [key#111, value#112]
As expected, the smaller dataframe a with 3 rows is broadcast.
Ran: spark.sql("SELECT * FROM a JOIN b ON a.key = b.key where b.value=\"bat\"").explain()
== Physical Plan ==
*(1) BroadcastHashJoin [key#122], [key#111], Inner, BuildRight, false
:- *(1) LocalTableScan [key#122, df_a_column#123]
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#152]
+- LocalTableScan [key#111, value#112]
Here you can notice that dataframe b is broadcast, meaning Spark evaluates the size AFTER applying the where clause when choosing which side to broadcast.
Question 2:
Yes, you are right. It's evident from the previous output that the where clause is applied first.
Question 3:
No, you cannot analyze it, but you can broadcast a temp view by hinting Spark about it, even in SQL. ref
Example: spark.sql("SELECT /*+ BROADCAST(b) */ * FROM a JOIN b ON a.key = b.key")
And if you look at the explain now:
== Physical Plan ==
*(1) BroadcastHashJoin [key#122], [key#111], Inner, BuildRight, false
:- *(1) LocalTableScan [key#122, df_a_column#123]
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#184]
+- LocalTableScan [key#111, value#112]
Now you can see that dataframe b is broadcast even though it has 10 rows.
In question 1, without the hint, a was broadcast.
Note: the broadcast hint in Spark SQL is available since Spark 2.2.
Tips to understand the physical plan:
Figure out which dataframe each LocalTableScan [list of columns] corresponds to.
The dataframe that appears under the BroadcastExchange subtree is the one being broadcast.
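For completeness, a minimal Scala sketch of the DataFrame-side equivalent of the SQL hint above, using the two dataframes a and b from the experiment:

import org.apache.spark.sql.functions.broadcast

// Explicitly mark b for broadcasting regardless of which side Spark thinks is smaller
val joined = a.join(broadcast(b), a("key") === b("key"))
joined.explain()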

Pushdown filter in case of spark structured Delta streaming

I have a use case where we need to stream an open-source Delta table into multiple queries, each filtered on one of the partition columns.
E.g., given a Delta table partitioned on the year column:
Streaming query 1:
spark.readStream.format("delta").load("/tmp/delta-table/")
  .where("year = 2013")
Streaming query 2:
spark.readStream.format("delta").load("/tmp/delta-table/")
  .where("year = 2014")
The physical plan shows the filter after the streaming relation:
== Physical Plan ==
Filter (isnotnull(year#431) AND (year#431 = 2013))
+- StreamingRelation delta, []
My question is: does predicate pushdown work with streaming queries in Delta? Can we stream only a specific partition from the Delta table?
If the table is already partitioned on that column, only the required partitions will be scanned.
Let's create both a partitioned and a non-partitioned Delta table and perform structured streaming on them.
Partitioned Delta table streaming:
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
import spark.implicits._

// sample dataframe
val df = Seq((1,2020),(2,2021),(3,2020),(4,2020),
  (5,2020),(6,2020),(7,2019),(8,2019),(9,2018),(10,2020)).toDF("id","year")

// partitionBy year column and save as delta table
df.write.format("delta").partitionBy("year").save("delta-stream")

// streaming delta table
spark.readStream.format("delta").load("delta-stream")
  .where('year === 2020)
  .writeStream.format("console").start().awaitTermination()
The physical plan of the above streaming query shows partitionFilters on the year column.
Non-partitioned Delta table streaming:
df.write.format("delta").save("delta-stream")

spark.readStream.format("delta").load("delta-stream")
  .where('year === 2020)
  .writeStream.format("console").start().awaitTermination()
The physical plan of the above streaming query shows only pushedFilters, with no partition filters.
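If you want to print these plans yourself without the Spark UI, one option is to call explain() on the running StreamingQuery once a micro-batch has been processed; a minimal sketch reusing the stream above:

val query = spark.readStream.format("delta").load("delta-stream")
  .where('year === 2020)
  .writeStream.format("console").start()

query.processAllAvailable() // wait for the currently available data to be processed
query.explain()             // prints the physical plan, including partition/pushed filters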

How does count distinct work in Apache Spark SQL

I am trying to count the distinct number of entities over different date ranges.
I need to understand how Spark performs this operation.
val distinct_daily_cust_12month = sqlContext.sql(s"select distinct day_id,txn_type,customer_id from ${db_name}.fact_customer where day_id>='${start_last_12month}' and day_id<='${start_date}' and txn_type not in (6,99)")
val category_mapping = sqlContext.sql(s"select * from datalake.category_mapping");
val daily_cust_12month_ds =distinct_daily_cust_12month.join(broadcast(category_mapping),distinct_daily_cust_12month("txn_type")===category_mapping("id")).select("category","sub_category","customer_id","day_id")
daily_cust_12month_ds.createOrReplaceTempView("daily_cust_12month_ds")
val total_cust_metrics = sqlContext.sql(s"""select 'total' as category,
count(distinct(case when day_id='${start_date}' then customer_id end)) as yest,
count(distinct(case when day_id>='${start_week}' and day_id<='${end_week}' then customer_id end)) as week,
count(distinct(case when day_id>='${start_month}' and day_id<='${start_date}' then customer_id end)) as mtd,
count(distinct(case when day_id>='${start_last_month}' and day_id<='${end_last_month}' then customer_id end)) as ltd,
count(distinct(case when day_id>='${start_last_6month}' and day_id<='${start_date}' then customer_id end)) as lsm,
count(distinct(case when day_id>='${start_last_12month}' and day_id<='${start_date}' then customer_id end)) as ltm
from daily_cust_12month_ds
""")
There are no errors, but this is taking a lot of time. I want to know if there is a better way to do this in Spark.
Count distinct works by hash-partitioning the data, counting distinct elements per partition, and finally summing the counts. In general it is a heavy operation due to the full shuffle, and there is no silver bullet for that in Spark or, most likely, in any fully distributed system; operations with distinct are inherently difficult in a distributed setting.
In some cases there are faster ways to do it:
If approximate values are acceptable, approx_count_distinct will usually be much faster, as it is based on HyperLogLog and the amount of data to be shuffled is much smaller than with the exact implementation (see the sketch after this list).
If you can design your pipeline so that the data source is already partitioned such that there can't be any duplicates between partitions, the slow step of hash-partitioning the data frame is not needed.
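For the approximate route, a minimal sketch against the daily_cust_12month_ds dataframe from the question (the rsd argument is the allowed relative error and is optional):

import org.apache.spark.sql.functions.approx_count_distinct

daily_cust_12month_ds
  .select(approx_count_distinct("customer_id", rsd = 0.02).as("approx_customers"))
  .show()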
P.S. To understand how count distinct works, you can always use explain:
df.select(countDistinct("foo")).explain()
Example output:
== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[count(distinct foo#3)])
+- Exchange SinglePartition
+- *(2) HashAggregate(keys=[], functions=[partial_count(distinct foo#3)])
+- *(2) HashAggregate(keys=[foo#3], functions=[])
+- Exchange hashpartitioning(foo#3, 200)
+- *(1) HashAggregate(keys=[foo#3], functions=[])
+- LocalTableScan [foo#3]

Question about the count method in Spark Dataset

I was reading the book 'Spark: The Definitive Guide'.
It has an example like the one below.
val myRange = spark.range(1000).toDF("number")
val divisBy2 = myRange.where("number % 2 = 0")
divisBy2.count()
Below is the description for the three lines of code:
"we started a Spark job that runs our filter transformation (a narrow transformation), then an aggregation (a wide transformation) that performs the counts on a per partition basis, and then a collect, which brings our result to a native object in the respective language"
I know that count is an action, not a transformation, since it returns an actual value and I cannot call 'explain' on the return value of count.
But I was wondering why count causes a wide transformation, and how I can know the execution plan of this count in this case, since I cannot invoke 'explain' after count.
Thanks.
Updated:
This image is the Spark UI screenshot; I took it from a Databricks notebook.
I said there is a shuffle write and a shuffle read operation; does that mean there is a wide transformation?
Here is the execution plan:
== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[count(1)], output=[count#7L])
+- Exchange SinglePartition
+- *(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#10L])
+- *(1) Project
+- *(1) Filter ((id#0L % 2) = 0)
+- *(1) Range (0, 1000, step=1, splits=8)
What we can see here:
Counting is done inside each partition
All partitions are merged into a single one
The final count is computed
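Since count() returns a plain Long, you cannot call explain on its result, but you can inspect an equivalent plan by expressing the same global count as a DataFrame before triggering it; a minimal sketch with the names from the book's example:

// groupBy() with no keys followed by count gives the same count(1) plan that count() executes
divisBy2.groupBy().count().explain()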
