Does Spark SQL optimize queries with repeated expressions? - apache-spark

Given the following
from pyspark.sql import functions, window
f = functions.rank()
w1 = window.Window.partitionBy("column")
w2 = window.Window.partitionBy("column")
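col = functions.expr("column * 42")  # expr parses a SQL expression; col("column * 42") would look up a column literally named "column * 42"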
and a dataframe df, is there any difference in performance for
df.select(f.over(w1), f.over(w2))
vs
df.select(f.over(w1), f.over(w1))
?
What about
df.select(col + 1, col + 2)
vs
df.select(functions.expr("column * 42 + 1"), functions.expr("column * 42 + 2"))
?
(Feel free to imagine arbitrarily complex expressions in place of column * 42)
I.e. is there any benefit in reusing Column- and Window-instances vs constructing these expressions on the fly?
I would expect Spark SQL to optimize this properly, but couldn't find a conclusive answer.
Also, should I be able to answer this question myself by inspecting the result of df.explain() and if so, what should I be looking for?

Feel free to imagine arbitrarily complex expressions in place of column * 42
...or even any non-deterministic expressions, like generating random numbers or the current timestamp.
Whenever you ask such a question, use the explain operator to see what Spark SQL does under the covers (which should in fact be independent of the programming language and of the function or method in use, shouldn't it?).
So, what happens under the covers of the following non-deterministic query (or one that is fully deterministic but looks non-deterministic at first glance)?
import org.apache.spark.sql.functions.{current_timestamp, rand}
val q = spark.range(1)
.select(
current_timestamp as "now", // <-- this should be the same as the following line?
current_timestamp as "now_2",
rand as "r1", // <-- what about this and the following lines?
rand as "r2",
rand as "r3")
scala> q.show(truncate = false)
+-----------------------+-----------------------+-------------------+------------------+------------------+
|now |now_2 |r1 |r2 |r3 |
+-----------------------+-----------------------+-------------------+------------------+------------------+
|2017-12-13 15:17:46.305|2017-12-13 15:17:46.305|0.33579358107333823|0.9478025260069644|0.5846726225651472|
+-----------------------+-----------------------+-------------------+------------------+------------------+
I was actually a bit surprised to notice that the rand columns all generated different results, as I had assumed they would be the same. The answer is in... the source code of rand, where you can see that it uses a different seed whenever one is not given explicitly (learnt it today! thanks).
def rand(): Column = rand(Utils.random.nextLong)
The fix is to use the version of rand with an explicit seed, as that gives you the same Rand expression with the same seed across the query.
val seed = 1
val q = spark.range(1)
.select(
current_timestamp as "now", // <-- this should be the same as the following line?
current_timestamp as "now_2",
rand(seed) as "r1", // <-- what about this and the following lines?
rand(seed) as "r2",
rand(seed) as "r3")
scala> q.show(false)
+-----------------------+-----------------------+-------------------+-------------------+-------------------+
|now |now_2 |r1 |r2 |r3 |
+-----------------------+-----------------------+-------------------+-------------------+-------------------+
|2017-12-13 15:43:59.019|2017-12-13 15:43:59.019|0.06498948189958098|0.06498948189958098|0.06498948189958098|
+-----------------------+-----------------------+-------------------+-------------------+-------------------+
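To see why the seed matters, here is a small pure-Python model (not Spark code) of the behaviour described above: an unseeded call draws a fresh seed, the way rand() calls Utils.random.nextLong, while rand(seed) reuses the seed it is given. The per-partition seeding (seed + partition index) is our reading of Spark's Rand implementation and should be treated as an assumption; make_rand_column and value_at are hypothetical names.

```python
import random

# Stand-in for Spark's Utils.random: the source of fresh seeds for rand() without arguments.
_seed_source = random.Random()


def make_rand_column(seed=None):
    """Return a function giving the value of a rand column for a given partition and row."""
    if seed is None:
        seed = _seed_source.getrandbits(63)  # every unseeded call draws a new seed

    def value_at(partition_index, row_index):
        # Assumption modeled on Spark's Rand: one generator per partition,
        # seeded with seed + partition index.
        rng = random.Random(seed + partition_index)
        for _ in range(row_index):
            rng.random()  # skip the values produced for earlier rows of this partition
        return rng.random()

    return value_at


# rand as "r1", rand as "r2": two independently drawn seeds, so (almost surely) different values.
r1, r2 = make_rand_column(), make_rand_column()
# rand(seed) as "r1", rand(seed) as "r2": the same seed, so identical values.
s1, s2 = make_rand_column(1), make_rand_column(1)
print(s1(0, 0) == s2(0, 0))  # True
```

The model only captures seed handling; in Spark the generator is an XORShiftRandom, not Python's Mersenne Twister, so the actual numbers differ.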
Spark SQL knows what you used in a structured query because its high-level API (DataFrame or Dataset) is just a wrapper around logical operators that are the same across languages (Python, Scala, Java, R, SQL).
Look at the source code of any function and you will see a Catalyst expression (e.g. rand); look at a Dataset operator (e.g. select) and you will see one logical operator or a tree of them.
In the end, Spark SQL uses a rule-based optimizer (Catalyst) that applies rules to optimize your query and find repetitions.
So, let's have a look at your case (which is more deterministic than rand).
(I'm using Scala, but the differences are at the language level, not the optimization level.)
import org.apache.spark.sql.expressions.Window
val w1 = Window.partitionBy("column").orderBy("column")
val w2 = Window.partitionBy("column").orderBy("column")
In your case you used rank, which requires the window to be ordered, so I added an orderBy clause to make the window specification complete.
scala> w1 == w2
res1: Boolean = false
They are indeed different from Scala's point of view (two distinct WindowSpec instances).
val df = spark.range(5).withColumnRenamed("id", "column")
scala> df.show
+------+
|column|
+------+
| 0|
| 1|
| 2|
| 3|
| 4|
+------+
With this dataset (which is pretty much irrelevant to our discussion), let's create a structured query and explain it to see the physical plan, which is what Spark SQL executes.
val q = df.select(rank over w1, rank over w2)
scala> q.explain
== Physical Plan ==
*Project [RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194]
+- Window [rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194], [column#156L], [column#156L ASC NULLS FIRST]
+- *Sort [column#156L ASC NULLS FIRST, column#156L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(column#156L, 200)
+- *Project [id#153L AS column#156L]
+- *Range (0, 5, step=1, splits=8)
Let's use the numbered output so we can reference every line in the description.
val plan = q.queryExecution.executedPlan
scala> println(plan.numberedTreeString)
00 *Project [RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194]
01 +- Window [rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194], [column#156L], [column#156L ASC NULLS FIRST]
02 +- *Sort [column#156L ASC NULLS FIRST, column#156L ASC NULLS FIRST], false, 0
03 +- Exchange hashpartitioning(column#156L, 200)
04 +- *Project [id#153L AS column#156L]
05 +- *Range (0, 5, step=1, splits=8)
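In this plan both rank expressions (#193 and #194) are computed by a single Window operator (line 01) that sits on top of one Sort (line 02) and one Exchange (line 03), so the fact that w1 and w2 are different objects in Scala makes no difference to the shuffle or the sort. With the numbered output you can see whether a query is similar to another one and what the differences are, if any. That's the most definitive answer you can get and... surprise... things may (and often will) change between Spark versions.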
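Comparing two explain outputs by eye gets tedious, and expression IDs such as #193 or #156L change from query to query even when the plans are otherwise identical. A small hypothetical Python helper (not part of Spark) that strips those IDs before comparing plan text could look like this; it assumes you already have each plan as a string.

```python
import re

# Expression IDs such as "#193" or "#156L" are assigned per query, so they differ
# between otherwise identical plans.
_EXPR_ID = re.compile(r"#\d+L?")
# Plan-node IDs such as ", [id=#168]" (seen on BroadcastExchange) vary as well.
_NODE_ID = re.compile(r",?\s*\[id=#\d+\]")


def normalize_plan(plan_text: str) -> str:
    """Strip run-specific IDs and trailing whitespace from an explain() string."""
    lines = []
    for line in plan_text.strip().splitlines():
        line = _NODE_ID.sub("", line)  # remove node IDs first: they also contain "#<digits>"
        line = _EXPR_ID.sub("", line)
        lines.append(line.rstrip())
    return "\n".join(lines)


def same_plan(plan_a: str, plan_b: str) -> bool:
    """True if two plans are identical once IDs are ignored."""
    return normalize_plan(plan_a) == normalize_plan(plan_b)
```

In PySpark, DataFrame.explain() prints from the Python side, so wrapping it in contextlib.redirect_stdout is one way to capture the text; check this on your Spark version. Codegen stage markers such as *(1) are left untouched, since a different stage layout is usually a real plan difference.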
I.e. is there any benefit in reusing Column- and Window-instances vs constructing these expressions on the fly?
I would not worry much about it, as I'd expect Spark to handle it internally (though, as you may have noticed, I was surprised to see that rand works differently).
Just use explain to see the physical plan and you will be able to answer the question yourself.

Related

Spark SQL view and partition column usage

I have a Databricks table (parquet not delta) "TableA" with a partition column "dldate", and it has ~3000 columns.
When I issue select * from TableA where dldate='2022-01-01', the query completes in seconds.
I have a view "view_tableA" which reads from "TableA" and performs some window functions on some of the columns.
When I issue select * from view_tableA where dldate='2022-01-01', the query runs forever.
Will the latter query effectively use the partition key of the table? If not, is there any optimization I can do to make sure the partition key is used?
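If the partitioning of all window functions is aligned with the table partitioning, the optimizer can push the predicate down to the table level and apply partition pruning. This is safe because a filter on a window-partition column removes whole window partitions, so it cannot change the window results of the rows that remain.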
For example:
SELECT *
FROM (SELECT *, sum(a) over (partition by dldate) FROM TableA)
WHERE dldate = '2022-01-01';
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Window [dldate#2932, a#2933, sum(a#2933) ...], [dldate#2932]
+- Sort [dldate#2932 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(dldate#2932, 200), ...
+- Project [dldate#2932, a#2933]
+- FileScan parquet tablea PartitionFilters: [isnotnull(dldate#2932), (dldate#2932 = 2022-01-01)]
Compare this with a query containing a window function that is not partitioned by dldate:
SELECT *
FROM (SELECT *, sum(a) over (partition by a) FROM TableA)
WHERE dldate = '2022-01-01';
AdaptiveSparkPlan isFinalPlan=false
+- Filter (isnotnull(dldate#2968) AND (dldate#2968 = 2022-01-01)) << !!!
+- Window [dldate#2968, a#2969, sum(a#2969) ...], [a#2969]
+- Sort [a#2969 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(a#2969, 200), ...
+- Project [dldate#2968, a#2969]
+- FileScan parquet tablea PartitionFilters: [] << !!!
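Why the Filter has to stay above the Window in the second plan can be checked on a toy example. The following pure-Python sketch (hypothetical helpers, not Spark's optimizer) computes sum(a) OVER (PARTITION BY ...) on three rows and compares filtering before and after the window:

```python
from collections import defaultdict


def window_sum(rows, partition_col):
    """Like sum(a) OVER (PARTITION BY partition_col): attach each row's partition total."""
    totals = defaultdict(int)
    for row in rows:
        totals[row[partition_col]] += row["a"]
    return [dict(row, total=totals[row[partition_col]]) for row in rows]


def keep_day(rows, day):
    """Like WHERE dldate = day."""
    return [row for row in rows if row["dldate"] == day]


def pushdown_is_safe(rows, partition_col, day):
    """True if filtering before the window gives the same result as filtering after it."""
    filter_after = keep_day(window_sum(rows, partition_col), day)
    filter_before = window_sum(keep_day(rows, day), partition_col)
    return filter_after == filter_before


rows = [
    {"dldate": "2022-01-01", "a": 1},
    {"dldate": "2022-01-01", "a": 2},
    {"dldate": "2022-01-02", "a": 1},
]
print(pushdown_is_safe(rows, "dldate", "2022-01-01"))  # True
print(pushdown_is_safe(rows, "a", "2022-01-01"))       # False
```

One matching dataset does not prove the rewrite safe in general; the point is the counterexample for PARTITION BY a, where the row with a = 1 gets a total of 2 if filtered late but 1 if filtered early.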

Apache Spark: broadcast join behaviour: filtering of joined tables and temp tables

I need to join 2 tables in Spark.
But instead of joining the 2 tables completely, I first filter the second table:
spark.sql("select * from a join b on a.key=b.key where b.value='xxx' ")
I want to use broadcast join in this case.
Spark has a parameter which defines the maximum table size for a broadcast join, spark.sql.autoBroadcastJoinThreshold:
Configures the maximum size in bytes for a table that will be
broadcast to all worker nodes when performing a join. By setting this
value to -1 broadcasting can be disabled. Note that currently
statistics are only supported for Hive Metastore tables where the
command ANALYZE TABLE COMPUTE STATISTICS noscan has been
run. http://spark.apache.org/docs/2.4.0/sql-performance-tuning.html
I have the following questions about this setup:
1. Which table size will Spark compare with the autoBroadcastJoinThreshold value: the FULL size, or the size AFTER applying the where clause?
2. I am assuming that Spark will apply the where clause BEFORE broadcasting, correct?
3. The doc says I need to run Hive's ANALYZE TABLE command beforehand. How will that work when I am using a temp view as a table? As far as I understand, I cannot run ANALYZE TABLE against a Spark temp view created via dataFrame.createOrReplaceTempView("b"). Can I broadcast temp view contents?
Your understanding in question 2 is correct.
You cannot analyze a temp view in Spark.
If you want to take the lead and specify which DataFrame to broadcast, instead of letting Spark decide, you can use the snippet below:
from pyspark.sql import functions as F
df = df1.join(F.broadcast(df2), df1.some_col == df2.some_col, "left")
I went ahead and did some small experiments to answer your first question.
Question 1 :
created a dataframe a with 3 rows [key,df_a_column]
created a dataframe b with 10 rows [key,value]
ran: spark.sql("SELECT * FROM a JOIN b ON a.key = b.key").explain()
== Physical Plan ==
*(1) BroadcastHashJoin [key#122], [key#111], Inner, BuildLeft, false
:- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#168]
: +- LocalTableScan [key#122, df_a_column#123]
+- *(1) LocalTableScan [key#111, value#112]
As expected, the smaller DataFrame a (3 rows) is broadcast.
Ran : spark.sql("SELECT * FROM a JOIN b ON a.key = b.key where b.value=\"bat\"").explain()
== Physical Plan ==
*(1) BroadcastHashJoin [key#122], [key#111], Inner, BuildRight, false
:- *(1) LocalTableScan [key#122, df_a_column#123]
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#152]
+- LocalTableScan [key#111, value#112]
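Here DataFrame b is broadcast! This means Spark evaluates the size AFTER applying the where clause when choosing which side to broadcast. Note, though, that no Filter node appears in this plan: because a and b are small local DataFrames, the optimizer appears to have applied the where clause at planning time, so the result may not carry over directly to file-based tables, where the size estimate of a filtered table depends on the statistics available.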
Question 2 :
Yes, you are right. It's evident from the previous output that the where clause is applied first.
Question 3 :
No, you cannot analyze it, but you can still broadcast a temp view by giving Spark a hint, even in SQL.
Example : spark.sql("SELECT /*+ BROADCAST(b) */ * FROM a JOIN b ON a.key = b.key")
And if you look at explain now:
== Physical Plan ==
*(1) BroadcastHashJoin [key#122], [key#111], Inner, BuildRight, false
:- *(1) LocalTableScan [key#122, df_a_column#123]
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [id=#184]
+- LocalTableScan [key#111, value#112]
Now DataFrame b is broadcast even though it has 10 rows.
In question 1, without the hint, a was broadcast.
Note: the broadcast hint in Spark SQL is available since Spark 2.2.
Tips for understanding the physical plan:
Identify each DataFrame by the column list in its LocalTableScan [...].
The DataFrame in the subtree under BroadcastExchange is the one being broadcast.
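To make the build and stream roles in these plans concrete, here is a simplified pure-Python model. choose_build_side and broadcast_hash_join are hypothetical names; the size rule (a side qualifies if 0 <= size <= threshold, so -1 disables broadcasting, and the smaller side wins when both qualify) is a simplification of Spark's planner that ignores hints, join types and how sizes are estimated.

```python
from collections import defaultdict


def choose_build_side(left_size, right_size, threshold):
    """Simplified broadcast choice. Returns "left", "right" or None (no broadcast)."""
    left_ok = 0 <= left_size <= threshold
    right_ok = 0 <= right_size <= threshold
    if left_ok and right_ok:
        return "right" if right_size <= left_size else "left"  # smaller side wins
    if left_ok:
        return "left"
    if right_ok:
        return "right"
    return None


def broadcast_hash_join(build_rows, stream_rows, key):
    """Inner join: hash the (small) build side once, then probe it with every stream row."""
    table = defaultdict(list)
    for row in build_rows:  # this hashed side is what BroadcastExchange ships to executors
        table[row[key]].append(row)
    joined = []
    for row in stream_rows:  # the other side is streamed partition by partition
        for match in table.get(row[key], []):
            joined.append({**match, **row})
    return joined


a = [{"key": k, "df_a_column": k * 10} for k in range(3)]
b = [{"key": k, "value": f"v{k}"} for k in range(10)]
print(choose_build_side(len(a), len(b), threshold=100))  # left, like BuildLeft in question 1
print(len(broadcast_hash_join(a, b, "key")))             # 3
```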

How do group by and window functions interact in Spark SQL?

From this question, I learned that window functions are evaluated after GROUP BY in PostgreSQL.
I'd like to know what happens when you use a group by and a window function in the same query in Spark. I have the same questions as the poster of the previous question:
Are the selected rows grouped first, then considered by the window function?
Or does the window function execute first, and then the resulting values are grouped by the group by?
Something else?
If you have a window function and a group by in the same query, the group by is performed first and the window function is then applied to the grouped dataset.
You can check the query's explain plan for details.
Example:
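//sample data (assumes spark-shell, where spark.implicits._ is already in scope)
Seq((1, "a"), (2, "b")).toDF("trip_id", "name").createOrReplaceTempView("tmp")
spark.sql("select * from tmp").show()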
//+-------+----+
//|trip_id|name|
//+-------+----+
//| 1| a|
//| 2| b|
//+-------+----+
spark.sql("select row_number() over(order by trip_id),trip_id,count(*) cnt from tmp group by trip_id").explain()
//== Physical Plan ==
//*(4) Project [row_number() OVER (ORDER BY trip_id ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#150, trip_id#10, cnt#140L]
//+- Window [row_number() windowspecdefinition(trip_id#10 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number() OVER (ORDER BY trip_id ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#150], [trip_id#10 ASC NULLS FIRST]
// +- *(3) Sort [trip_id#10 ASC NULLS FIRST], false, 0
// +- Exchange SinglePartition
// +- *(2) HashAggregate(keys=[trip_id#10], functions=[count(1)])
// +- Exchange hashpartitioning(trip_id#10, 200)
// +- *(1) HashAggregate(keys=[trip_id#10], functions=[partial_count(1)])
// +- LocalTableScan [trip_id#10]
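In this plan, the HashAggregate operators in codegen stages *(1) and *(2) perform the group by first.
The Window operator is then applied to the result of the grouped dataset, and *(4) Project selects the final columns.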
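The order shown in the plan can be mimicked in a few lines of pure Python. This is a toy model with hypothetical function names, not how Spark executes the query, but it shows that the window numbers the grouped rows rather than the original ones:

```python
from collections import Counter


def group_then_row_number(trip_ids):
    """Mimics: SELECT row_number() OVER (ORDER BY trip_id), trip_id, count(*) FROM tmp GROUP BY trip_id."""
    counts = Counter(trip_ids)        # HashAggregate: one row per trip_id with count(*)
    grouped = sorted(counts.items())  # Sort on trip_id required by the window's ORDER BY
    # Window: number the *grouped* rows, not the original ones
    return [(rn, trip_id, cnt) for rn, (trip_id, cnt) in enumerate(grouped, start=1)]


def row_number_before_group(trip_ids):
    """The subquery variant: row_number() is computed over the raw rows first."""
    return [(rn, trip_id) for rn, trip_id in enumerate(sorted(trip_ids), start=1)]


print(group_then_row_number([1, 2, 2]))    # [(1, 1, 1), (2, 2, 2)]
print(row_number_before_group([1, 2, 2]))  # [(1, 1), (2, 2), (3, 2)]
```

With three input rows but two distinct trip_ids, the row numbers stop at 2 when grouping comes first and reach 3 when the window is evaluated in a subquery before the group by.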
If the window function is in a subquery and the outer query has the group by, the subquery (window) is executed first and the outer query (group by) next.
Example:
spark.sql("select trip_id,count(*) from(select *,row_number() over(order by trip_id)rn from tmp)e group by trip_id ").explain()

Spark is sorting already sorted partitions resulting in performance loss

For a cached dataframe, partitioned and sorted within partitions, I get good performance when querying the key with a where clause but bad performance when performing a join with a small table on the same key.
See example dataset dftest below with 10Kx44K = 438M rows.
sqlContext.sql(f'set spark.sql.shuffle.partitions={32}')
sqlContext.clearCache()
sc.setCheckpointDir('/checkpoint/temp')
import datetime
from pyspark.sql.functions import *
from pyspark.sql import Row
start_date = datetime.date(1900, 1, 1)
end_date = datetime.date(2020, 1, 1)
dates = [ start_date + datetime.timedelta(n) for n in range(int ((end_date - start_date).days))]
dfdates=spark.createDataFrame(list(map(lambda x: Row(date=x), dates))) # some dates
dfrange=spark.createDataFrame(list(map(lambda x: Row(number=x), range(10000)))) # some number range
dfjoin = dfrange.crossJoin(dfdates)
dftest = dfjoin.withColumn("random1", round(rand()*(10-5)+5,0)).withColumn("random2", round(rand()*(10-5)+5,0)).withColumn("random3", round(rand()*(10-5)+5,0)).withColumn("random4", round(rand()*(10-5)+5,0)).withColumn("random5", round(rand()*(10-5)+5,0)).checkpoint()
dftest = dftest.repartition("number").sortWithinPartitions("number", "date").cache()
dftest.count() # 438,290,000 rows
The following query now takes roughly a second (on a small cluster with 2 workers):
dftest.where("number = 1000 and date = \"2001-04-04\"").count()
However, when I write a similar condition as a join, it takes 2 minutes:
dfsub = spark.createDataFrame([(10,"1900-01-02",1),
(1000,"2001-04-04",2),
(4000,"2002-05-05",3),
(5000,"1950-06-06",4),
(9875,"1980-07-07",5)],
["number","date", "dummy"]).repartition("number").sortWithinPartitions("number", "date").cache()
df_result = dftest.join(dfsub, ( dftest.number == dfsub.number ) & ( dftest.date == dfsub.date ), 'inner').cache()
df_result.count() # takes 2 minutes (result = 5)
I would have expected this to be roughly equally fast. Especially since I would hope that the larger dataframe is already clustered and cached. Looking at the plan:
== Physical Plan ==
InMemoryTableScan [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797, number#945L, date#946, dummy#947L]
+- InMemoryRelation [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797, number#945L, date#946, dummy#947L], StorageLevel(disk, memory, deserialized, 1 replicas)
+- *(3) SortMergeJoin [number#771L, cast(date#769 as string)], [number#945L, date#946], Inner
:- *(1) Sort [number#771L ASC NULLS FIRST, cast(date#769 as string) ASC NULLS FIRST], false, 0
: +- *(1) Filter (isnotnull(number#771L) && isnotnull(date#769))
: +- InMemoryTableScan [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797], [isnotnull(number#771L), isnotnull(date#769)]
: +- InMemoryRelation [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797], StorageLevel(disk, memory, deserialized, 1 replicas)
: +- Sort [number#771L ASC NULLS FIRST, date#769 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(number#771L, 32)
: +- *(1) Scan ExistingRDD[number#771L,date#769,random1#775,random2#779,random3#784,random4#790,random5#797]
+- *(2) Filter (isnotnull(number#945L) && isnotnull(date#946))
+- InMemoryTableScan [number#945L, date#946, dummy#947L], [isnotnull(number#945L), isnotnull(date#946)]
+- InMemoryRelation [number#945L, date#946, dummy#947L], StorageLevel(disk, memory, deserialized, 1 replicas)
+- Sort [number#945L ASC NULLS FIRST, date#946 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(number#945L, 32)
+- *(1) Scan ExistingRDD[number#945L,date#946,dummy#947L]
A lot of time seems to be spent sorting the larger dataframe by number and date (this line: Sort [number#771L ASC NULLS FIRST, date#769 ASC NULLS FIRST], false, 0). It leaves me with the following questions:
Within the partitions, the sort order for both the left and right side is exactly the same and optimal for the JOIN clause, so why is Spark sorting the partitions again?
As the 5 join records match (up to) 5 partitions, why are all partitions evaluated?
It seems Catalyst is not using the info from repartition and sortWithinPartitions of the cached dataframe. Does it make sense to use sortWithinPartitions in cases like these?
Let me try to answer your three questions:
Within the partitions, the sort order for both the left and right side is exactly the same and optimal for the JOIN clause, so why is Spark sorting the partitions again?
The sort order in both DataFrames is NOT the same because the date column has different data types: in dfsub it is StringType, while in dftest it is DateType. During the join, Spark casts the dftest date to string (see cast(date#769 as string) in the plan), so the ordering the join requires on that side no longer matches the existing ordering on [number, date], and Spark forces the Sort.
As the 5 join records match (up to) 5 partitions, why are all partitions evaluated?
During query planning, Spark does not know how many partitions of the small DataFrame are non-empty, so it has to process all of them.
It seems Catalyst is not using the info from repartition and sortWithinPartitions of the cached dataframe. Does it make sense to use sortWithinPartitions in cases like these?
The Spark optimizer does use the information from repartition and sortWithinPartitions, but with some caveats. To fix your query, it is also important to repartition by the same columns (both of them) that you use in the join, not just one. In principle this should not be necessary, and there is a related JIRA in progress that tries to solve it.
So here are my proposed changes to your query:
Change the type of date column to StringType in dftest (Or similarly change to DateType in dfsub):
dftest = dftest.withColumn("date", col("date").cast('string'))
In both DataFrames change
.repartition("number")
to
.repartition("number", "date")
After these changes you should get a plan like this:
*(3) SortMergeJoin [number#1410L, date#1653], [number#1661L, date#1662], Inner
:- Sort [number#1410L ASC NULLS FIRST, date#1653 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(number#1410L, date#1653, 200)
: +- *(1) Project [number#1410L, cast(date#1408 as string) AS date#1653, random1#1540, random2#1544, random3#1549, random4#1555, random5#1562]
: +- *(1) Filter (isnotnull(number#1410L) && isnotnull(cast(date#1408 as string)))
: +- *(1) Scan ExistingRDD[number#1410L,date#1408,random1#1540,random2#1544,random3#1549,random4#1555,random5#1562]
+- Sort [number#1661L ASC NULLS FIRST, date#1662 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(number#1661L, date#1662, 200)
+- *(2) Filter (isnotnull(number#1661L) && isnotnull(date#1662))
+- *(2) Scan ExistingRDD[number#1661L,date#1662,dummy#1663L]
So there is only one Exchange and one Sort in each branch of the plan; both come from the repartition and sortWithinPartitions calls in your transformations, and the join does not induce any additional sorting or shuffling. Also notice that my plan has no InMemoryTableScan, since I did not use cache.
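The reason the join insists on both sides being sorted by exactly the same keys is the merge step of a sort-merge join. A minimal pure-Python sketch of that step (a hypothetical helper, not Spark's SortMergeJoinExec) walks two already sorted inputs in lockstep:

```python
def sort_merge_join(left, right, key_cols):
    """Inner join of two lists of dicts that are BOTH already sorted by key_cols."""
    def key(row):
        return tuple(row[c] for c in key_cols)

    out = []
    i = j = 0
    while i < len(left) and j < len(right):
        lk, rk = key(left[i]), key(right[j])
        if lk < rk:
            i += 1
        elif lk > rk:
            j += 1
        else:
            # find the run of right rows sharing this key
            j_end = j
            while j_end < len(right) and key(right[j_end]) == lk:
                j_end += 1
            # pair every left row with this key against that whole run
            while i < len(left) and key(left[i]) == lk:
                out.extend({**left[i], **r} for r in right[j:j_end])
                i += 1
            j = j_end
    return out


left = [
    {"number": 10, "date": "1900-01-02", "x": 1},
    {"number": 1000, "date": "2001-04-04", "x": 2},
    {"number": 1000, "date": "2001-04-05", "x": 3},
]
right = [
    {"number": 1000, "date": "2001-04-04", "dummy": 2},
    {"number": 4000, "date": "2002-05-05", "dummy": 3},
]
print(sort_merge_join(left, right, ("number", "date")))
```

The merge is only correct if both inputs are ordered by the same key values of the same type. In Python, comparing a datetime.date with a str raises TypeError; Spark instead inserts cast(date as string) on one side, and that side then needs a fresh Sort on the cast expression.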

How does count distinct work in Apache spark SQL

I am trying to count the distinct number of entities over different date ranges.
I need to understand how Spark performs this operation.
val distinct_daily_cust_12month = sqlContext.sql(s"select distinct day_id,txn_type,customer_id from ${db_name}.fact_customer where day_id>='${start_last_12month}' and day_id<='${start_date}' and txn_type not in (6,99)")
val category_mapping = sqlContext.sql(s"select * from datalake.category_mapping");
val daily_cust_12month_ds =distinct_daily_cust_12month.join(broadcast(category_mapping),distinct_daily_cust_12month("txn_type")===category_mapping("id")).select("category","sub_category","customer_id","day_id")
daily_cust_12month_ds.createOrReplaceTempView("daily_cust_12month_ds")
val total_cust_metrics = sqlContext.sql(s"""select 'total' as category,
count(distinct(case when day_id='${start_date}' then customer_id end)) as yest,
count(distinct(case when day_id>='${start_week}' and day_id<='${end_week}' then customer_id end)) as week,
count(distinct(case when day_id>='${start_month}' and day_id<='${start_date}' then customer_id end)) as mtd,
count(distinct(case when day_id>='${start_last_month}' and day_id<='${end_last_month}' then customer_id end)) as ltd,
count(distinct(case when day_id>='${start_last_6month}' and day_id<='${start_date}' then customer_id end)) as lsm,
count(distinct(case when day_id>='${start_last_12month}' and day_id<='${start_date}' then customer_id end)) as ltm
from daily_cust_12month_ds
""")
No errors, but this takes a lot of time. I want to know if there is a better way to do this in Spark.
Count distinct works by hash-partitioning the data, counting the distinct elements in each partition, and finally summing the counts. In general it is a heavy operation because of the full shuffle, and there is no silver bullet for it in Spark or, most likely, in any fully distributed system: operations with distinct are inherently difficult in a distributed setting.
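The three steps above can be sketched in pure Python. This is a model of the idea, not Spark's implementation: distributed_count_distinct is a hypothetical name, the input "partitions" are plain lists, and nulls (None) are skipped because count(distinct ...) ignores them.

```python
def distributed_count_distinct(partitions, num_shuffle_partitions=4):
    """Exact count distinct over data split into partitions (lists of values)."""
    # 1. Map side: deduplicate inside each input partition (like the first HashAggregate).
    local = [set(v for v in p if v is not None) for p in partitions]
    # 2. Shuffle: route each value by hash so all copies of a value land in one partition.
    shuffled = [set() for _ in range(num_shuffle_partitions)]
    for values in local:
        for v in values:
            shuffled[hash(v) % num_shuffle_partitions].add(v)
    # 3. Shuffle partitions are now disjoint, so per-partition counts can simply be summed.
    return sum(len(s) for s in shuffled)


print(distributed_count_distinct([[1, 2, 2], [2, 3], [3, 4, None]]))  # 4
```

Step 2 is the expensive part in a real cluster: every distinct value (per input partition) crosses the network.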
In some cases there are faster ways to do it:
If approximate values are acceptable, approx_count_distinct will usually be much faster, as it is based on HyperLogLog and the amount of data to be shuffled is much smaller than with the exact implementation.
If you can design your pipeline in a way that the data source is already partitioned so that there can't be any duplicates between partitions, the slow step of hash-partitioning the data frame is not needed.
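Why HyperLogLog shuffles so little can be seen in a plain HyperLogLog sketch: each partition only has to ship a fixed-size array of registers, and partial sketches are combined with an element-wise max. Spark's approx_count_distinct uses HyperLogLog++, which adds bias correction and a sparse representation; the class below is a simplified illustration with hypothetical names, and hashing str(value) is an assumption made for brevity (so 1 and "1" collide).

```python
import hashlib
import math


class HyperLogLog:
    """Plain HyperLogLog (no HLL++ bias correction or sparse mode), for illustration only."""

    def __init__(self, p=12):
        self.p = p               # number of index bits
        self.m = 1 << p          # number of registers (4096 for p=12)
        self.registers = [0] * self.m

    def add(self, value):
        # 64-bit hash taken from the first 8 bytes of SHA-1 of the value's string form
        h = int.from_bytes(hashlib.sha1(str(value).encode()).digest()[:8], "big")
        idx = h >> (64 - self.p)                      # top p bits choose a register
        rest = h & ((1 << (64 - self.p)) - 1)         # remaining 64 - p bits
        rank = (64 - self.p) - rest.bit_length() + 1  # position of the leftmost 1-bit
        self.registers[idx] = max(self.registers[idx], rank)

    def merge(self, other):
        # Combining partial sketches is an element-wise max: only m small ints per partition.
        self.registers = [max(a, b) for a, b in zip(self.registers, other.registers)]

    def count(self):
        alpha = 0.7213 / (1 + 1.079 / self.m)
        raw = alpha * self.m * self.m / sum(2.0 ** -r for r in self.registers)
        zeros = self.registers.count(0)
        if raw <= 2.5 * self.m and zeros:
            return self.m * math.log(self.m / zeros)  # linear counting for small cardinalities
        return raw


# Two "partitions" with overlapping values, sketched separately and then merged.
part_a, part_b = HyperLogLog(), HyperLogLog()
for i in range(5000):
    part_a.add(i)
for i in range(3000, 10000):
    part_b.add(i)
part_a.merge(part_b)
print(round(part_a.count()))  # close to 10000
```

With p=12 the standard error is roughly 1.04 / sqrt(4096), about 1.6%, independent of how many rows each partition holds.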
P.S. To understand how count distinct works, you can always use explain:
df.select(countDistinct("foo")).explain()
Example output:
== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[count(distinct foo#3)])
+- Exchange SinglePartition
+- *(2) HashAggregate(keys=[], functions=[partial_count(distinct foo#3)])
+- *(2) HashAggregate(keys=[foo#3], functions=[])
+- Exchange hashpartitioning(foo#3, 200)
+- *(1) HashAggregate(keys=[foo#3], functions=[])
+- LocalTableScan [foo#3]

Resources