How do group by and window functions interact in Spark SQL? - apache-spark

From this question, I learned that window functions are evaluated after the GROUP BY clause in PostgreSQL.
I'd like to know what happens when you use a GROUP BY and a window function in the same query in Spark. I have the same questions as the poster of that question:
Are the selected rows grouped first, and then considered by the window function?
Or does the window function execute first, and then the resulting values are grouped by the GROUP BY?
Something else?

If a query contains both a window function and a GROUP BY, the GROUP BY is performed first and the window function is then applied to the grouped dataset.
You can check the query's explain plan to confirm this.
Example:
//sample data
spark.sql("select * from tmp").show()
//+-------+----+
//|trip_id|name|
//+-------+----+
//|      1|   a|
//|      2|   b|
//+-------+----+
spark.sql("select row_number() over(order by trip_id),trip_id,count(*) cnt from tmp group by trip_id").explain()
//== Physical Plan ==
//*(4) Project [row_number() OVER (ORDER BY trip_id ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#150, trip_id#10, cnt#140L]
//+- Window [row_number() windowspecdefinition(trip_id#10 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number() OVER (ORDER BY trip_id ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#150], [trip_id#10 ASC NULLS FIRST]
//   +- *(3) Sort [trip_id#10 ASC NULLS FIRST], false, 0
//      +- Exchange SinglePartition
//         +- *(2) HashAggregate(keys=[trip_id#10], functions=[count(1)])
//            +- Exchange hashpartitioning(trip_id#10, 200)
//               +- *(1) HashAggregate(keys=[trip_id#10], functions=[partial_count(1)])
//                  +- LocalTableScan [trip_id#10]
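Reading the plan bottom-up: the GROUP BY (the HashAggregate in stage *(2)) is executed first.
The Window operator, just below the final *(4) Project, is then applied to the result of the grouped dataset.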
If instead the window function is in a subquery and the outer query has the GROUP BY, the subquery (window) is executed first and the outer query (GROUP BY) afterwards.
Example:
spark.sql("select trip_id, count(*) from (select *, row_number() over (order by trip_id) rn from tmp) e group by trip_id").explain()
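To make the difference between the two orders concrete, here is a toy model in plain Python (a sketch, not how Spark executes anything) of the two queries above, on hypothetical data where trip_id 1 appears twice. The max(rn) column in the second function is added only for illustration, since the subquery example does not select rn.

```python
from collections import Counter
from itertools import groupby

# Hypothetical toy rows standing in for `tmp` as (trip_id, name).
# trip_id 1 appears twice so that the two evaluation orders differ.
ROWS = [(1, "a"), (1, "c"), (2, "b")]


def group_then_window(rows):
    """select row_number() over (order by trip_id), trip_id, count(*) cnt
    from tmp group by trip_id: aggregate first, then number the groups."""
    counts = Counter(trip_id for trip_id, _ in rows)
    grouped = sorted(counts.items())  # one (trip_id, cnt) row per group
    return [(rn, trip_id, cnt) for rn, (trip_id, cnt) in enumerate(grouped, start=1)]


def window_then_group(rows):
    """select trip_id, count(*), max(rn) from (select *, row_number()
    over (order by trip_id) rn from tmp) e group by trip_id:
    number the raw rows first, then aggregate (max(rn) is illustrative)."""
    ordered = sorted(rows, key=lambda r: r[0])
    numbered = [(rn, trip_id) for rn, (trip_id, _) in enumerate(ordered, start=1)]
    result = []
    for trip_id, grp in groupby(numbered, key=lambda x: x[1]):
        rns = [rn for rn, _ in grp]
        result.append((trip_id, len(rns), max(rns)))
    return result


print(group_then_window(ROWS))  # [(1, 1, 2), (2, 2, 1)]
print(window_then_group(ROWS))  # [(1, 2, 2), (2, 1, 3)]
```

In the first case row_number only ever sees two grouped rows; in the second it numbers all three raw rows before they are grouped.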

Related

Spark SQL view and partition column usage

I have a Databricks table (parquet not delta) "TableA" with a partition column "dldate", and it has ~3000 columns.
When I issue select * from TableA where dldate='2022-01-01', the query completes in seconds.
I have a view "view_tableA" which reads from "TableA" and performs some window functions on some of the columns.
When I issue select * from view_tableA where dldate='2022-01-01', the query runs forever.
Will the latter query effectively use the partition key of the table? If not, is there any optimization I can do to make sure the partition key is used?
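If the partitioning of every window function is aligned with the table's partitioning (i.e. each window's PARTITION BY includes dldate), the optimizer can push the predicate down to the table scan and apply partition pruning. A filter on a window-partition key removes whole window partitions, so it does not change the values computed for the rows that remain.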
For example:
SELECT *
FROM (SELECT *, sum(a) over (partition by dldate) FROM TableA)
WHERE dldate = '2022-01-01';
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Window [dldate#2932, a#2933, sum(a#2933) ...], [dldate#2932]
   +- Sort [dldate#2932 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(dldate#2932, 200), ...
         +- Project [dldate#2932, a#2933]
            +- FileScan parquet tablea PartitionFilters: [isnotnull(dldate#2932), (dldate#2932 = 2022-01-01)]
Compare this with a query whose window function is not partitioned by dldate:
SELECT *
FROM (SELECT *, sum(a) over (partition by a) FROM TableA)
WHERE dldate = '2022-01-01';
AdaptiveSparkPlan isFinalPlan=false
+- Filter (isnotnull(dldate#2968) AND (dldate#2968 = 2022-01-01))   << !!!
   +- Window [dldate#2968, a#2969, sum(a#2969) ...], [a#2969]
      +- Sort [a#2969 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(a#2969, 200), ...
            +- Project [dldate#2968, a#2969]
               +- FileScan parquet tablea PartitionFilters: []   << !!!
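Why the alignment matters can be checked with a toy model in plain Python (hypothetical rows, not Spark code): computing sum(a) over a window and then filtering on dldate gives the same rows as filtering first only when dldate is the window's partition key. That is the condition under which moving the filter below the Window, and on to the file scan, preserves the result.

```python
from collections import defaultdict
from operator import itemgetter

# Hypothetical rows as (dldate, a).
ROWS = [("2022-01-01", 1), ("2022-01-01", 2), ("2022-01-02", 1), ("2022-01-02", 5)]


def window_sum(rows, key):
    """sum(a) over (partition by <key>) attached to every row as (row, total)."""
    totals = defaultdict(int)
    for row in rows:
        totals[key(row)] += row[1]
    return [(row, totals[key(row)]) for row in rows]


def keep(row):
    return row[0] == "2022-01-01"  # the predicate dldate = '2022-01-01'


by_dldate = itemgetter(0)  # partition by dldate
by_a = itemgetter(1)       # partition by a

# Filter after the window (what the query says) vs. before it (pushed down).
after_dldate = [x for x in window_sum(ROWS, by_dldate) if keep(x[0])]
before_dldate = window_sum([r for r in ROWS if keep(r)], by_dldate)
after_a = [x for x in window_sum(ROWS, by_a) if keep(x[0])]
before_a = window_sum([r for r in ROWS if keep(r)], by_a)

print(after_dldate == before_dldate)  # True: pushdown is safe
print(after_a == before_a)            # False: pushdown would change the sums
```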

How to get the execution time of EACH operator in Spark SQL

For example, I have a table like:
name | height
mike | 80
dan  | 11
and a Spark SQL query like "select distinct name from table where height = 80".
In that case, I get a physical plan like this:
== Physical Plan ==
CollectLimit 21
+- HashAggregate(keys=[name#0], functions=[], output=[name#0])
   +- Exchange hashpartitioning(name#0, 200)
      +- HashAggregate(keys=[name#0], functions=[], output=[name#0])
         +- Project [name#0]
            +- Filter (isnotnull(height#1L) && (height#1L = 80))
               +- Scan ExistingRDD[name#0,height#1L]
Here is the problem: I'd like to get the execution time of EACH operator (such as Project, Filter...). Checking the Spark UI, it seems that Spark does record the execution time of each operator.
I would like to get those numbers in my code.
How can I do that?

Spark is sorting already sorted partitions resulting in performance loss

For a cached dataframe, partitioned and sorted within partitions, I get good performance when querying the key with a where clause but bad performance when performing a join with a small table on the same key.
See example dataset dftest below with 10Kx44K = 438M rows.
sqlContext.sql(f'set spark.sql.shuffle.partitions={32}')
sqlContext.clearCache()
sc.setCheckpointDir('/checkpoint/temp')
import datetime
from pyspark.sql.functions import *
from pyspark.sql import Row
start_date = datetime.date(1900, 1, 1)
end_date = datetime.date(2020, 1, 1)
dates = [ start_date + datetime.timedelta(n) for n in range(int ((end_date - start_date).days))]
dfdates=spark.createDataFrame(list(map(lambda x: Row(date=x), dates))) # some dates
dfrange=spark.createDataFrame(list(map(lambda x: Row(number=x), range(10000)))) # some number range
dfjoin = dfrange.crossJoin(dfdates)
dftest = dfjoin.withColumn("random1", round(rand()*(10-5)+5,0)).withColumn("random2", round(rand()*(10-5)+5,0)).withColumn("random3", round(rand()*(10-5)+5,0)).withColumn("random4", round(rand()*(10-5)+5,0)).withColumn("random5", round(rand()*(10-5)+5,0)).checkpoint()
dftest = dftest.repartition("number").sortWithinPartitions("number", "date").cache()
dftest.count() # 438,290,000 rows
The following query now takes roughly a second (on a small cluster with 2 workers):
dftest.where("number = 1000 and date = \"2001-04-04\"").count()
However, when I write a similar condition as a join, it takes 2 minutes:
dfsub = spark.createDataFrame([(10,"1900-01-02",1),
(1000,"2001-04-04",2),
(4000,"2002-05-05",3),
(5000,"1950-06-06",4),
(9875,"1980-07-07",5)],
["number","date", "dummy"]).repartition("number").sortWithinPartitions("number", "date").cache()
df_result = dftest.join(dfsub, ( dftest.number == dfsub.number ) & ( dftest.date == dfsub.date ), 'inner').cache()
df_result.count() # takes 2 minutes (result = 5)
I would have expected this to be roughly as fast, especially since I would hope that the larger dataframe is already clustered and cached. Looking at the plan:
== Physical Plan ==
InMemoryTableScan [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797, number#945L, date#946, dummy#947L]
   +- InMemoryRelation [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797, number#945L, date#946, dummy#947L], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- *(3) SortMergeJoin [number#771L, cast(date#769 as string)], [number#945L, date#946], Inner
            :- *(1) Sort [number#771L ASC NULLS FIRST, cast(date#769 as string) ASC NULLS FIRST], false, 0
            :  +- *(1) Filter (isnotnull(number#771L) && isnotnull(date#769))
            :     +- InMemoryTableScan [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797], [isnotnull(number#771L), isnotnull(date#769)]
            :           +- InMemoryRelation [number#771L, date#769, random1#775, random2#779, random3#784, random4#790, random5#797], StorageLevel(disk, memory, deserialized, 1 replicas)
            :                 +- Sort [number#771L ASC NULLS FIRST, date#769 ASC NULLS FIRST], false, 0
            :                    +- Exchange hashpartitioning(number#771L, 32)
            :                       +- *(1) Scan ExistingRDD[number#771L,date#769,random1#775,random2#779,random3#784,random4#790,random5#797]
            +- *(2) Filter (isnotnull(number#945L) && isnotnull(date#946))
               +- InMemoryTableScan [number#945L, date#946, dummy#947L], [isnotnull(number#945L), isnotnull(date#946)]
                     +- InMemoryRelation [number#945L, date#946, dummy#947L], StorageLevel(disk, memory, deserialized, 1 replicas)
                           +- Sort [number#945L ASC NULLS FIRST, date#946 ASC NULLS FIRST], false, 0
                              +- Exchange hashpartitioning(number#945L, 32)
                                 +- *(1) Scan ExistingRDD[number#945L,date#946,dummy#947L]
A lot of time seems to be spent sorting the larger dataframe by number and date (this line: Sort [number#771L ASC NULLS FIRST, date#769 ASC NULLS FIRST], false, 0). It leaves me with the following questions:
Within the partitions, the sort order for both the left and right side is exactly the same and optimal for the JOIN clause, so why is Spark sorting the partitions again?
As the 5 join records match (up to) 5 partitions, why are all partitions evaluated?
It seems Catalyst is not using the info of repartition and sortWithinPartitions of the cached dataframe. Does it make sense to use sortWithinPartitions in cases like these?
Let me try to answer your three questions:
Within the partitions, the sort order for both the left and right side is exactly the same and optimal for the JOIN clause, so why is Spark sorting the partitions again?
The sort order in the two DataFrames is NOT the same, because the date column has different data types: in dfsub it is StringType and in dftest it is DateType. To evaluate the join condition Spark casts dftest's date to string (see cast(date#769 as string) in the SortMergeJoin and Sort of your plan), and the cached ordering on date does not match the required ordering on cast(date as string), so Spark forces the Sort.
As the 5 join records match (up to) 5 partitions, why are all partitions evaluated?
While planning the query, Spark does not know how many partitions of the small DataFrame are non-empty, so it has to process all of them.
It seems Catalyst is not using the info of repartition and sortWithinPartitions of the cached dataframe. Does it make sense to use sortWithinPartitions in cases like these?
The Spark optimizer does use the information from repartition and sortWithinPartitions, but there are some caveats about how it works. To fix your query, it is also important to repartition by the same columns that you use in the join (both number and date, not just one column). In principle this should not be necessary, and there is a related JIRA in progress that is trying to solve that.
So here are my proposed changes to your query:
Change the type of the date column to StringType in dftest (or, alternatively, change it to DateType in dfsub), before the repartition/sortWithinPartitions step:
dftest = dftest.withColumn("date", col("date").cast("string"))
In both DataFrames change
.repartition("number")
to
.repartition("number", "date")
After these changes you should get a plan like this:
*(3) SortMergeJoin [number#1410L, date#1653], [number#1661L, date#1662], Inner
:- Sort [number#1410L ASC NULLS FIRST, date#1653 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(number#1410L, date#1653, 200)
:     +- *(1) Project [number#1410L, cast(date#1408 as string) AS date#1653, random1#1540, random2#1544, random3#1549, random4#1555, random5#1562]
:        +- *(1) Filter (isnotnull(number#1410L) && isnotnull(cast(date#1408 as string)))
:           +- *(1) Scan ExistingRDD[number#1410L,date#1408,random1#1540,random2#1544,random3#1549,random4#1555,random5#1562]
+- Sort [number#1661L ASC NULLS FIRST, date#1662 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(number#1661L, date#1662, 200)
      +- *(2) Filter (isnotnull(number#1661L) && isnotnull(date#1662))
         +- *(2) Scan ExistingRDD[number#1661L,date#1662,dummy#1663L]
So there is only one Exchange and one Sort in each branch of the plan, both coming from the repartition and sortWithinPartitions calls in your transformations, and the join does not induce any further sorting or shuffling. Also note that my plan has no InMemoryTableScan, since I did not use cache.
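The type mismatch can be illustrated with a simplified model in plain Python of the check the planner makes before a SortMergeJoin: an existing ordering is reused only if the required ordering is a prefix of it. This is an assumption-laden sketch: expressions are compared here as plain strings, whereas Spark compares them semantically, and the function name ordering_satisfies is hypothetical. The orderings themselves are taken from the plan in the question.

```python
def ordering_satisfies(existing, required):
    """True if `required` is a prefix of `existing`.

    Simplified model: sort keys are compared as strings, not as
    Catalyst expressions.
    """
    return len(required) <= len(existing) and all(
        e == r for e, r in zip(existing, required)
    )


# Ordering produced by sortWithinPartitions("number", "date") on dftest.
cached_order = ["number ASC", "date ASC"]

# Ordering the join requires while dftest.date (DateType) is compared with
# dfsub.date (StringType): the date side is wrapped in a cast.
required_mismatched = ["number ASC", "cast(date as string) ASC"]

# Ordering required once both date columns have the same type.
required_matched = ["number ASC", "date ASC"]

print(ordering_satisfies(cached_order, required_mismatched))  # False -> extra Sort
print(ordering_satisfies(cached_order, required_matched))     # True  -> no Sort
```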

spark.table() with limit seems to read the whole table

I am trying to read the first 20000 rows of a large table (10 billion+ rows) in Spark, so I use the following lines of code.
df = spark.table("large_table").limit(20000).repartition(1)
df.explain()
And the explain plan looks like this.
== Physical Plan ==
Exchange RoundRobinPartitioning(1)
+- *(2) GlobalLimit 20000
   +- Exchange SinglePartition
      +- *(1) LocalLimit 20000
         +- *(1) FileScan parquet large_table[...]
But when I try to write this df into a new table, it seems to kick off a huge number of tasks that read the whole table first and only then begin writing to the final table. Why doesn't Spark read only the first few files to satisfy the row limit?

Does Spark SQL optimize queries with repeated expressions?

Given the following
from pyspark.sql import functions, window
f = functions.rank()
w1 = window.Window.partitionBy("column")
w2 = window.Window.partitionBy("column")
col = functions.expr("column * 42")
and a dataframe df, is there any difference in performance for
df.select(f.over(w1), f.over(w2))
vs
df.select(f.over(w1), f.over(w1))
?
What about
df.select(col + 1, col + 2)
vs
df.select(functions.expr("column * 42 + 1"), functions.expr("column * 42 + 2"))
?
(Feel free to imagine arbitrarily complex expressions in place of column * 42)
I.e. is there any benefit in reusing Column- and Window-instances vs constructing these expressions on the fly?
I would expect Spark SQL to optimize this properly, but I couldn't find a conclusive answer.
Also, should I be able to answer this question myself by inspecting the result of df.explain() and if so, what should I be looking for?
Feel free to imagine arbitrarily complex expressions in place of column * 42
...or even non-deterministic expressions such as generating random numbers or the current timestamp.
Whenever you have such a question, use the explain operator to see what Spark SQL deals with under the covers (which should in fact be independent of the programming language and of the function or method in use, shouldn't it?).
So, what happens under the covers of the following non-deterministic query (or fully deterministic, but non-deterministic at first glance)?
import org.apache.spark.sql.functions._
val q = spark.range(1)
.select(
current_timestamp as "now", // <-- this should be the same as the following line?
current_timestamp as "now_2",
rand as "r1", // <-- what about this and the following lines?
rand as "r2",
rand as "r3")
scala> q.show(truncate = false)
+-----------------------+-----------------------+-------------------+------------------+------------------+
|now |now_2 |r1 |r2 |r3 |
+-----------------------+-----------------------+-------------------+------------------+------------------+
|2017-12-13 15:17:46.305|2017-12-13 15:17:46.305|0.33579358107333823|0.9478025260069644|0.5846726225651472|
+-----------------------+-----------------------+-------------------+------------------+------------------+
I was actually a bit surprised to notice that the rand columns all produced different values, as I had assumed they would be the same. The answer is in... the source code of rand, where you can see that it uses a different seed for each call if none is given explicitly (learnt it today, thanks!).
def rand(): Column = rand(Utils.random.nextLong)
The solution is to use the variant of rand with an explicit seed, which gives you the same Rand expression with the same seed across the query.
val seed = 1
val q = spark.range(1)
.select(
current_timestamp as "now", // <-- this should be the same as the following line?
current_timestamp as "now_2",
rand(seed) as "r1", // <-- what about this and the following lines?
rand(seed) as "r2",
rand(seed) as "r3")
scala> q.show(false)
+-----------------------+-----------------------+-------------------+-------------------+-------------------+
|now |now_2 |r1 |r2 |r3 |
+-----------------------+-----------------------+-------------------+-------------------+-------------------+
|2017-12-13 15:43:59.019|2017-12-13 15:43:59.019|0.06498948189958098|0.06498948189958098|0.06498948189958098|
+-----------------------+-----------------------+-------------------+-------------------+-------------------+
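The seeding behaviour can be mimicked with a toy stand-in in plain Python. This is only a sketch: rand_value is a hypothetical helper, and Python's generator is used, so the numbers will not match Spark's. What it models is the point above: without a seed every call draws its own seed, while an explicit seed makes every column identical.

```python
import random


def rand_value(seed=None):
    """Toy stand-in for Spark's rand() evaluated on a single row.

    Without an explicit seed, each call draws a fresh seed, mirroring
    rand() = rand(Utils.random.nextLong) in Spark's source.
    """
    if seed is None:
        seed = random.getrandbits(63)
    return random.Random(seed).random()


# Like rand as "r1", rand as "r2", rand as "r3": three independent seeds.
unseeded = [rand_value() for _ in range(3)]

# Like rand(seed) for all three columns: same seed, hence the same value.
seeded = [rand_value(1) for _ in range(3)]

print(unseeded)          # three (almost certainly) different numbers
print(len(set(seeded)))  # 1
```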
Spark SQL knows what you used in a structured query because its high-level API (DataFrame and Dataset) is just a wrapper around logical operators that are the same across languages (Python, Scala, Java, R, SQL).
Look at the source code of any function and you will find a Catalyst expression (e.g. rand) or a Dataset operator (e.g. select), backed by one logical operator or a tree of them.
In the end, Spark SQL uses a rule-based optimizer that applies rules to optimize your query and find repetitions.
So, let's have a look at your case (which is more deterministic than rand).
(I'm using Scala, but the differences are at the language level, not the optimization level.)
import org.apache.spark.sql.expressions.Window
val w1 = Window.partitionBy("column").orderBy("column")
val w2 = Window.partitionBy("column").orderBy("column")
In your case you used rank, which requires an ordered window, so I added an orderBy clause to make the window specification complete.
scala> w1 == w2
res1: Boolean = false
They are indeed different from Scala's point of view (WindowSpec does not override equals, so == compares references).
val df = spark.range(5).withColumnRenamed("id", "column")
scala> df.show
+------+
|column|
+------+
| 0|
| 1|
| 2|
| 3|
| 4|
+------+
With the dataset (which is pretty much irrelevant to our discussion), let's create a structured query and explain it to see the physical plan which is what Spark SQL executes.
val q = df.select(rank over w1, rank over w2)
scala> q.explain
== Physical Plan ==
*Project [RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194]
+- Window [rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194], [column#156L], [column#156L ASC NULLS FIRST]
   +- *Sort [column#156L ASC NULLS FIRST, column#156L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(column#156L, 200)
         +- *Project [id#153L AS column#156L]
            +- *Range (0, 5, step=1, splits=8)
Let's use the numbered output so we can reference every line in the description.
val plan = q.queryExecution.executedPlan
scala> println(plan.numberedTreeString)
00 *Project [RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194]
01 +- Window [rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#193, rank(column#156L) windowspecdefinition(column#156L, column#156L ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS RANK() OVER (PARTITION BY column ORDER BY column ASC NULLS FIRST unspecifiedframe$())#194], [column#156L], [column#156L ASC NULLS FIRST]
02    +- *Sort [column#156L ASC NULLS FIRST, column#156L ASC NULLS FIRST], false, 0
03       +- Exchange hashpartitioning(column#156L, 200)
04          +- *Project [id#153L AS column#156L]
05             +- *Range (0, 5, step=1, splits=8)
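With that you can see whether a query is similar to another and what the differences are, if any. Here, both rank expressions are computed by the same Window operator (01) over a single Sort and Exchange (02 and 03), so the separately constructed w2 costs no extra shuffle or sort. That's the most definitive answer you can get and... surprise... things may (and often will) change between Spark versions.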
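A toy model in plain Python shows why two separately built but identical window specifications end up in one Window operator: the planner works with expressions compared by content, not by object identity. ToyWindowSpec and plan_window_operators are hypothetical names, and the grouping is only a rough sketch of how the analyzer groups window functions by their partition and order specification, not Spark's actual rule.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyWindowSpec:
    """Hypothetical window specification compared by value, not identity."""
    partition_by: tuple
    order_by: tuple


def plan_window_operators(window_exprs):
    """Group (function, spec) pairs by spec: one Window operator (with its
    own Exchange and Sort) per distinct spec. A rough model only."""
    operators = {}
    for func, spec in window_exprs:
        operators.setdefault(spec, []).append(func)
    return operators


w1 = ToyWindowSpec(partition_by=("column",), order_by=("column",))
w2 = ToyWindowSpec(partition_by=("column",), order_by=("column",))
print(w1 is w2, w1 == w2)  # False True

ops = plan_window_operators([("rank", w1), ("rank", w2)])
print(len(ops))  # 1: both ranks share one Window, as in line 01 of the plan
```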
I.e. is there any benefit in reusing Column- and Window-instances vs constructing these expressions on the fly?
I would not worry much about it, as I'd expect Spark to handle it internally (although, as you may have noticed, I was surprised to see how rand behaves).
Just use explain to see the physical plan and you will be able to answer the question yourself.
