I have a pyspark dataframe like this:
+--------------------+------------+-----------+
|                name|segment_list|  rung_list|
+--------------------+------------+-----------+
|          Campaign 1|  [1.0, 5.0]|   [L2, L3]|
|          Campaign 1|       [1.1]|       [L1]|
|          Campaign 2|       [1.2]|       [L2]|
|          Campaign 2|       [1.1]|   [L4, L5]|
+--------------------+------------+-----------+
I have another pyspark dataframe that has segment and rung for every customer:
+-----------+---------------+---------+
|customer_id| segment |rung |
+-----------+---------------+---------+
| 124001823| 1.0| L2|
| 166001989| 5.0| L2|
| 768002266| 1.1| L1|
+-----------+---------------+---------+
What I want is a final output that identifies the customers based on the segment and rung lists. The final output should be something like the following:
+--------------------+------------+
|                name| customer_id|
+--------------------+------------+
|          Campaign 1|   124001823|
|          Campaign 1|   166001989|
|          Campaign 1|   768002266|
+--------------------+------------+
I tried using a UDF, but that approach didn't quite work. I would like to avoid a for loop over a collect(), or going row by row, so I am primarily looking for a groupby-style operation on the name column. In short, I want a better way to do the following:
for row in x.collect():
    y = eligible.filter(eligible.segment.isin(row['segment_list'])).filter(eligible.rung.isin(row['rung_list']))
You could try using array_contains for the join conditions. Here's an example:
from pyspark.sql import functions as func

data1_sdf. \
    join(data2_sdf,
         func.expr('array_contains(segment_list, segment)') & func.expr('array_contains(rung_list, rung)'),
         'left'
         ). \
    select('name', 'customer_id'). \
    dropDuplicates(). \
    show(truncate=False)
# +----------+-----------+
# |name |customer_id|
# +----------+-----------+
# |Campaign 1|166001989 |
# |Campaign 1|124001823 |
# |Campaign 1|768002266 |
# |Campaign 2|null |
# +----------+-----------+
Pasting the query plan Spark produced:
== Parsed Logical Plan ==
Deduplicate [name#123, customer_id#129]
+- Project [name#123, customer_id#129]
   +- Join LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
      :- LogicalRDD [name#123, segment_list#124, rung_list#125], false
      +- LogicalRDD [customer_id#129, segment#130, rung#131], false
== Analyzed Logical Plan ==
name: string, customer_id: string
Deduplicate [name#123, customer_id#129]
+- Project [name#123, customer_id#129]
   +- Join LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
      :- LogicalRDD [name#123, segment_list#124, rung_list#125], false
      +- LogicalRDD [customer_id#129, segment#130, rung#131], false
== Optimized Logical Plan ==
Aggregate [name#123, customer_id#129], [name#123, customer_id#129]
+- Project [name#123, customer_id#129]
   +- Join LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
      :- LogicalRDD [name#123, segment_list#124, rung_list#125], false
      +- Filter (isnotnull(segment#130) AND isnotnull(rung#131))
         +- LogicalRDD [customer_id#129, segment#130, rung#131], false
== Physical Plan ==
*(4) HashAggregate(keys=[name#123, customer_id#129], functions=[], output=[name#123, customer_id#129])
+- Exchange hashpartitioning(name#123, customer_id#129, 200), ENSURE_REQUIREMENTS, [id=#267]
   +- *(3) HashAggregate(keys=[name#123, customer_id#129], functions=[], output=[name#123, customer_id#129])
      +- *(3) Project [name#123, customer_id#129]
         +- BroadcastNestedLoopJoin BuildRight, LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
            :- *(1) Scan ExistingRDD[name#123,segment_list#124,rung_list#125]
            +- BroadcastExchange IdentityBroadcastMode, [id=#261]
               +- *(2) Filter (isnotnull(segment#130) AND isnotnull(rung#131))
                  +- *(2) Scan ExistingRDD[customer_id#129,segment#130,rung#131]
It seems this is not well optimized; I'm wondering whether there are other, more optimized approaches.
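One alternative worth trying (a sketch only, reusing the data1_sdf / data2_sdf names from above) is to explode the array columns first, so the join condition becomes a plain equality and Spark can pick a hash or sort-merge join instead of BroadcastNestedLoopJoin:

from pyspark.sql import functions as func

# Explode both arrays; each campaign row becomes one row per (segment, rung) pair.
exploded_sdf = data1_sdf \
    .withColumn('segment', func.explode('segment_list')) \
    .withColumn('rung', func.explode('rung_list'))

# Equi-join on the exploded columns; 'left' keeps campaigns with no matching customers.
exploded_sdf \
    .join(data2_sdf, ['segment', 'rung'], 'left') \
    .select('name', 'customer_id') \
    .dropDuplicates() \
    .show(truncate=False)

Whether this is actually faster depends on the array sizes and data volumes, so it is worth comparing the resulting plan against the one above.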
Can anyone explain the behaviour below in a Spark SQL join? It does not matter whether I use full join, full outer, left, or left outer; the physical plan always shows that an Inner join is being used.
q1 = spark.sql("select count(*) from table_t1 t1 full join table_t1 t2 on t1.anchor_page_id = t2.anchor_page_id and t1.item_id = t2.item_id and t1.store_id = t2.store_id where t1.date_id = '20220323' and t2.date_id = '20220324'")
q1.explain()
== Physical Plan ==
*(6) HashAggregate(keys=[], functions=[count(1)])
+- Exchange SinglePartition
+- *(5) HashAggregate(keys=[], functions=[partial_count(1)])
+- *(5) Project
+- *(5) SortMergeJoin [anchor_page_id#1, item_id#2, store_id#5], [anchor_page_id#19, item_id#20, store_id#23], Inner
:- *(2) Sort [anchor_page_id#1 ASC NULLS FIRST, item_id#2 ASC NULLS FIRST, store_id#5 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(anchor_page_id#1, item_id#2, store_id#5, 200)
: +- *(1) Project [anchor_page_id#1, item_id#2, store_id#5]
: +- *(1) Filter ((isnotnull(item_id#2) && isnotnull(anchor_page_id#1)) && isnotnull(store_id#5))
: +- *(1) FileScan parquet table_t1[anchor_page_id#1,item_id#2,store_id#5,date_id#18] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#18), (date_id#18 = 20220323)], PushedFilters: [IsNotNull(item_id), IsNotNull(anchor_page_id), IsNotNull(store_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
+- *(4) Sort [anchor_page_id#19 ASC NULLS FIRST, item_id#20 ASC NULLS FIRST, store_id#23 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(anchor_page_id#19, item_id#20, store_id#23, 200)
+- *(3) Project [anchor_page_id#19, item_id#20, store_id#23]
+- *(3) Filter ((isnotnull(anchor_page_id#19) && isnotnull(item_id#20)) && isnotnull(store_id#23))
+- *(3) FileScan parquet table_t1[anchor_page_id#19,item_id#20,store_id#23,date_id#36] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#36), (date_id#36 = 20220324)], PushedFilters: [IsNotNull(anchor_page_id), IsNotNull(item_id), IsNotNull(store_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
q2 = spark.sql("select count(*) from table_t1 t1 full outer join table_t1 t2 on t1.anchor_page_id = t2.anchor_page_id and t1.item_id = t2.item_id and t1.store_id = t2.store_id where t1.date_id = '20220323' and t2.date_id = '20220324'")
q2.explain()
== Physical Plan ==
*(6) HashAggregate(keys=[], functions=[count(1)])
+- Exchange SinglePartition
+- *(5) HashAggregate(keys=[], functions=[partial_count(1)])
+- *(5) Project
+- *(5) SortMergeJoin [anchor_page_id#1, item_id#2, store_id#5], [anchor_page_id#42, item_id#43, store_id#46], Inner
:- *(2) Sort [anchor_page_id#1 ASC NULLS FIRST, item_id#2 ASC NULLS FIRST, store_id#5 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(anchor_page_id#1, item_id#2, store_id#5, 200)
: +- *(1) Project [anchor_page_id#1, item_id#2, store_id#5]
: +- *(1) Filter ((isnotnull(item_id#2) && isnotnull(anchor_page_id#1)) && isnotnull(store_id#5))
: +- *(1) FileScan parquet table_t1[anchor_page_id#1,item_id#2,store_id#5,date_id#18] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#18), (date_id#18 = 20220323)], PushedFilters: [IsNotNull(item_id), IsNotNull(anchor_page_id), IsNotNull(store_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
+- *(4) Sort [anchor_page_id#42 ASC NULLS FIRST, item_id#43 ASC NULLS FIRST, store_id#46 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(anchor_page_id#42, item_id#43, store_id#46, 200)
+- *(3) Project [anchor_page_id#42, item_id#43, store_id#46]
+- *(3) Filter ((isnotnull(store_id#46) && isnotnull(anchor_page_id#42)) && isnotnull(item_id#43))
+- *(3) FileScan parquet table_t1[anchor_page_id#42,item_id#43,store_id#46,date_id#59] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#59), (date_id#59 = 20220324)], PushedFilters: [IsNotNull(store_id), IsNotNull(anchor_page_id), IsNotNull(item_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
q3 = spark.sql("select count(*) from table_t1 t1 left join table_t1 t2 on t1.anchor_page_id = t2.anchor_page_id and t1.item_id = t2.item_id and t1.store_id = t2.store_id where t1.date_id = 20220323 and t2.date_id = 20220324")
q3.explain()
== Physical Plan ==
*(6) HashAggregate(keys=[], functions=[count(1)])
+- Exchange SinglePartition
+- *(5) HashAggregate(keys=[], functions=[partial_count(1)])
+- *(5) Project
+- *(5) SortMergeJoin [anchor_page_id#1, item_id#2, store_id#5], [anchor_page_id#65, item_id#66, store_id#69], Inner
:- *(2) Sort [anchor_page_id#1 ASC NULLS FIRST, item_id#2 ASC NULLS FIRST, store_id#5 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(anchor_page_id#1, item_id#2, store_id#5, 200)
: +- *(1) Project [anchor_page_id#1, item_id#2, store_id#5]
: +- *(1) Filter ((isnotnull(item_id#2) && isnotnull(anchor_page_id#1)) && isnotnull(store_id#5))
: +- *(1) FileScan parquet table_t1[anchor_page_id#1,item_id#2,store_id#5,date_id#18] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#18), (cast(date_id#18 as int) = 20220323)], PushedFilters: [IsNotNull(item_id), IsNotNull(anchor_page_id), IsNotNull(store_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
+- *(4) Sort [anchor_page_id#65 ASC NULLS FIRST, item_id#66 ASC NULLS FIRST, store_id#69 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(anchor_page_id#65, item_id#66, store_id#69, 200)
+- *(3) Project [anchor_page_id#65, item_id#66, store_id#69]
+- *(3) Filter ((isnotnull(item_id#66) && isnotnull(store_id#69)) && isnotnull(anchor_page_id#65))
+- *(3) FileScan parquet table_t1[anchor_page_id#65,item_id#66,store_id#69,date_id#82] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#82), (cast(date_id#82 as int) = 20220324)], PushedFilters: [IsNotNull(item_id), IsNotNull(store_id), IsNotNull(anchor_page_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
q4 = spark.sql("select count(*) from table_t1 t1 left outer join table_t1 t2 on t1.anchor_page_id = t2.anchor_page_id and t1.item_id = t2.item_id and t1.store_id = t2.store_id where t1.date_id = 20220323 and t2.date_id = 20220324")
q4.explain()
== Physical Plan ==
*(6) HashAggregate(keys=[], functions=[count(1)])
+- Exchange SinglePartition
+- *(5) HashAggregate(keys=[], functions=[partial_count(1)])
+- *(5) Project
+- *(5) SortMergeJoin [anchor_page_id#1, item_id#2, store_id#5], [anchor_page_id#88, item_id#89, store_id#92], Inner
:- *(2) Sort [anchor_page_id#1 ASC NULLS FIRST, item_id#2 ASC NULLS FIRST, store_id#5 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(anchor_page_id#1, item_id#2, store_id#5, 200)
: +- *(1) Project [anchor_page_id#1, item_id#2, store_id#5]
: +- *(1) Filter ((isnotnull(item_id#2) && isnotnull(anchor_page_id#1)) && isnotnull(store_id#5))
: +- *(1) FileScan parquet table_t1[anchor_page_id#1,item_id#2,store_id#5,date_id#18] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#18), (cast(date_id#18 as int) = 20220323)], PushedFilters: [IsNotNull(item_id), IsNotNull(anchor_page_id), IsNotNull(store_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
+- *(4) Sort [anchor_page_id#88 ASC NULLS FIRST, item_id#89 ASC NULLS FIRST, store_id#92 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(anchor_page_id#88, item_id#89, store_id#92, 200)
+- *(3) Project [anchor_page_id#88, item_id#89, store_id#92]
+- *(3) Filter ((isnotnull(store_id#92) && isnotnull(item_id#89)) && isnotnull(anchor_page_id#88))
+- *(3) FileScan parquet table_t1[anchor_page_id#88,item_id#89,store_id#92,date_id#105] Batched: true, Format: Parquet, Location: PrunedInMemoryFileIndex[gs://abc..., PartitionCount: 1, PartitionFilters: [isnotnull(date_id#105), (cast(date_id#105 as int) = 20220324)], PushedFilters: [IsNotNull(store_id), IsNotNull(item_id), IsNotNull(anchor_page_id)], ReadSchema: struct<anchor_page_id:string,item_id:string,store_id:string>
full join is just shorthand for full outer join (and left for left outer), so each pair of queries above is identical.
A WHERE clause on the outer (null-producing) side of an 'outer join' lets the optimizer convert it into an 'inner join'.
In other words, a WHERE predicate on any 'outer' table makes it an 'inner' table: only rows for which that predicate actually evaluates to true pass the filter, and the null-extended rows produced by the outer join never do.
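If the goal is to actually keep the outer-join semantics, the date predicates have to be applied before the join (for example in subqueries) rather than in the WHERE clause, since the null-rejecting WHERE predicate is what triggers the conversion. A minimal sketch, assuming the same table_t1 as above; note it computes something different from the original query, which really is an inner join:

# Sketch: pre-filter each side so the full outer join survives optimization.
q1b = spark.sql("""
    select count(*)
    from (select * from table_t1 where date_id = '20220323') t1
    full outer join (select * from table_t1 where date_id = '20220324') t2
      on t1.anchor_page_id = t2.anchor_page_id
     and t1.item_id = t2.item_id
     and t1.store_id = t2.store_id
""")
q1b.explain()  # the SortMergeJoin should now show FullOuter instead of Inner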
Let's say we have two partitioned datasets
val partitionedPersonDS = personDS.repartition(200, personDS("personId"))
val partitionedTransactionDS = transactionDS.repartition(200, transactionDS("personId"))
And we try to join them using joinWith on the same key over which they are partitioned
val transactionPersonDS: Dataset[(Transaction, Person)] = partitionedTransactionDS
.joinWith(
partitionedPersonDS,
partitionedTransactionDS.col("personId") === partitionedPersonDS.col("personId")
)
The physical plan shows that the already-partitioned Datasets were repartitioned again as part of the SortMergeJoin:
InMemoryTableScan [_1#14, _2#15]
+- InMemoryRelation [_1#14, _2#15], StorageLevel(disk, memory, deserialized, 1 replicas)
   +- *(5) SortMergeJoin [_1#14.personId], [_2#15.personId], Inner
      :- *(2) Sort [_1#14.personId ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(_1#14.personId, 200)
      :     +- *(1) Project [named_struct(transactionId, transactionId#8, personId, personId#9, itemList, itemList#10) AS _1#14]
      :        +- Exchange hashpartitioning(personId#9, 200)
      :           +- LocalTableScan [transactionId#8, personId#9, itemList#10]
      +- *(4) Sort [_2#15.personId ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(_2#15.personId, 200)
            +- *(3) Project [named_struct(personId, personId#2, name, name#3) AS _2#15]
               +- Exchange hashpartitioning(personId#2, 200)
                  +- LocalTableScan [personId#2, name#3]
But when we perform the join using join, the already-partitioned Datasets were NOT repartitioned and only sorted as part of the SortMergeJoin:
val transactionPersonDS: DataFrame = partitionedTransactionDS
.join (
partitionedPersonDS,
partitionedTransactionDS("personId") === partitionedPersonDS("personId")
)
InMemoryTableScan [transactionId#8, personId#9, itemList#10, personId#2, name#3]
+- InMemoryRelation [transactionId#8, personId#9, itemList#10, personId#2, name#3], StorageLevel(disk, memory, deserialized, 1 replicas)
   +- *(3) SortMergeJoin [personId#9], [personId#2], Inner
      :- *(1) Sort [personId#9 ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(personId#9, 200)
      :     +- LocalTableScan [transactionId#8, personId#9, itemList#10]
      +- *(2) Sort [personId#2 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(personId#2, 200)
            +- LocalTableScan [personId#2, name#3]
Why does joinWith fail to honor a pre-partitioned Dataset, unlike join?
My input data is stored in Cassandra, and I use a table whose primary key is (year, month, day, hour) as the source for Spark aggregations.
My Spark application does the following:
1. Join two tables
2. Take the joined table and select the data by hour
3. Union the selected chunks by hour
4. Do aggregations on the resulting Dataset and save it to Cassandra
Simplifying:

val ds1 = spark.read.cassandraFormat(table1, keyspace).load().as[T]
val ds2 = spark.read.cassandraFormat(table2, keyspace).load().as[T]

val dsInput = ds1.join(ds2).coalesce(150)

val dsUnion = for (x <- hours) yield dsInput.where(col("hour") === x)

val dsResult = mySparkAggregation(dsUnion.reduce(_.union(_)).coalesce(10))

dsResult.saveToCassandra
The resulting job diagram (for 3 hours/unions) looks like this [diagram omitted].
Everything works OK when I do only a couple of unions, e.g. 24 (for one day), but when I started running that Spark job for one month (720 unions), I started getting this error:
Total size of serialized results of 1126 tasks (1024.8 MB) is bigger than spark.driver.maxResultSize (1024.0 MB)
The other alarming thing is that the job creates ~100k tasks; one of the stages (the one that caused the error above) contains 74,400 tasks, and it crashes because of maxResultSize after processing about 1,125 of them. What is more, it seems it has to shuffle the data for each hour (union).
I tried to coalesce the number of partitions after the union, but then it complains that the task is too big.
I would be very grateful for any help or suggestions. I have a feeling that I am doing something wrong.
I did some investigation and came to some conclusions.
Let's say we have two tables
cb.people
CREATE TABLE cb.people (
id text PRIMARY KEY,
name text
)
and
cb.address
CREATE TABLE cb.address (
people_id text PRIMARY KEY,
name text
)
with the following data
cassandra#cqlsh> select * from cb.people;
id | name
----+---------
3 | Mariusz
2 | Monica
1 | John
cassandra#cqlsh> select * from cb.address;
people_id | name
-----------+--------
3 | POLAND
2 | USA
1 | USA
Now I would like to get the joined result for ids 1 and 2. There are two possible solutions.
Union two selects (for id 1 and id 2) from the people table and then join with the address table:
scala> val people = spark.read.cassandraFormat("people", "cb").load()
scala> val usPeople = people.where(col("id") === "1") union people.where(col("id") === "2")
scala> val address = spark.read.cassandraFormat("address", "cb").load()
scala> val joined = usPeople.join(address, address.col("people_id") === usPeople.col("id"))
Join the two tables first and then union two selects for id 1 and id 2:
scala> val peopleAddress = address.join(usPeople, address.col("people_id") === usPeople.col("id"))
scala> val joined2 = peopleAddress.where(col("id") === "1") union peopleAddress.where(col("id") === "2")
Both return the same result:
+---------+----+---+------+
|people_id|name| id| name|
+---------+----+---+------+
| 1| USA| 1| John|
| 2| USA| 2|Monica|
+---------+----+---+------+
But looking at the explain output, I can see a big difference:
scala> joined.explain
== Physical Plan ==
*SortMergeJoin [people_id#10], [id#0], Inner
:- *Sort [people_id#10 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(people_id#10, 200)
: +- *Filter (((people_id#10 = 1) || (people_id#10 = 2)) && isnotnull(people_id#10))
: +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#3077e4aa [people_id#10,name#11] PushedFilters: [Or(EqualTo(people_id,1),EqualTo(people_id,2)), IsNotNull(people_id)], ReadSchema: struct<people_id:string,name:string>
+- *Sort [id#0 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#0, 200)
+- Union
:- *Filter isnotnull(id#0)
: +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#6846e4e8 [id#0,name#1] PushedFilters: [IsNotNull(id), *EqualTo(id,1)], ReadSchema: struct<id:string,name:string>
+- *Filter isnotnull(id#0)
+- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#6846e4e8 [id#0,name#1] PushedFilters: [IsNotNull(id), *EqualTo(id,2)], ReadSchema: struct<id:string,name:string>
scala> joined2.explain
== Physical Plan ==
Union
:- *SortMergeJoin [people_id#10], [id#0], Inner
: :- *Sort [people_id#10 ASC NULLS FIRST], false, 0
: : +- Exchange hashpartitioning(people_id#10, 200)
: : +- *Filter isnotnull(people_id#10)
: : +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#3077e4aa [people_id#10,name#11] PushedFilters: [*EqualTo(people_id,1), IsNotNull(people_id)], ReadSchema: struct<people_id:string,name:string>
: +- *Sort [id#0 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(id#0, 200)
: +- Union
: :- *Filter isnotnull(id#0)
: : +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#6846e4e8 [id#0,name#1] PushedFilters: [IsNotNull(id), *EqualTo(id,1)], ReadSchema: struct<id:string,name:string>
: +- *Filter (isnotnull(id#0) && (id#0 = 1))
: +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#6846e4e8 [id#0,name#1] PushedFilters: [IsNotNull(id), *EqualTo(id,2), EqualTo(id,1)], ReadSchema: struct<id:string,name:string>
+- *SortMergeJoin [people_id#10], [id#0], Inner
:- *Sort [people_id#10 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(people_id#10, 200)
: +- *Filter isnotnull(people_id#10)
: +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#3077e4aa [people_id#10,name#11] PushedFilters: [IsNotNull(people_id), *EqualTo(people_id,2)], ReadSchema: struct<people_id:string,name:string>
+- *Sort [id#0 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#0, 200)
+- Union
:- *Filter (isnotnull(id#0) && (id#0 = 2))
: +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#6846e4e8 [id#0,name#1] PushedFilters: [IsNotNull(id), *EqualTo(id,1), EqualTo(id,2)], ReadSchema: struct<id:string,name:string>
+- *Filter isnotnull(id#0)
+- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation#6846e4e8 [id#0,name#1] PushedFilters: [IsNotNull(id), *EqualTo(id,2)], ReadSchema: struct<id:string,name:string>
Now it's quite clear to me that what I did was the joined2 version: inside the loop, a join was effectively performed for each union. I thought Spark would be smart enough to reduce that to the first version...
Now the current graph looks much better [diagram omitted].
I hope other people will not make the same mistake I made :) Unfortunately, I had wrapped Spark in my own abstraction layer, which hid this simple problem, so spark-shell helped a lot in modelling it.
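For readers working in PySpark rather than Scala, a rough sketch of the first (better) pattern, filter and union the source first and join only once, could look like the following; it mirrors the cb.people / cb.address example above and assumes the Spark Cassandra Connector is available:

from pyspark.sql import functions as F

# Read the same Cassandra tables via the Spark Cassandra Connector.
people = (spark.read.format("org.apache.spark.sql.cassandra")
          .options(table="people", keyspace="cb").load())
address = (spark.read.format("org.apache.spark.sql.cassandra")
           .options(table="address", keyspace="cb").load())

# Union the per-id selections on the source table first...
us_people = people.where(F.col("id") == "1").union(people.where(F.col("id") == "2"))

# ...then join once, so the plan contains a single SortMergeJoin.
joined = us_people.join(address, address["people_id"] == us_people["id"])
joined.explain()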
Using SQLTransformers we can create new columns in a dataframe, and we can chain these SQLTransformers in a Pipeline as well. We can do the same thing using multiple calls to selectExpr on dataframes.
But are the performance optimizations that are applied to the selectExpr calls also applied to a pipeline of SQLTransformers?
For example, consider the two snippets of code below:
#Method 1
df = spark.table("transactions")
df = df.selectExpr("*","sum(amt) over (partition by account) as acc_sum")
df = df.selectExpr("*","sum(amt) over (partition by dt) as dt_sum")
df.show(10)
#Method 2
from pyspark.ml import Pipeline
from pyspark.ml.feature import SQLTransformer

df = spark.table("transactions")
trans1 = SQLTransformer(statement="SELECT *, sum(amt) over (partition by account) as acc_sum from __THIS__")
trans2 = SQLTransformer(statement="SELECT *, sum(amt) over (partition by dt) as dt_sum from __THIS__")
pipe = Pipeline(stages=[trans1, trans2])
transPipe = pipe.fit(df)
transPipe.transform(df).show(10)
Will the performance for both of these ways of computing the same thing be the same?
Or will there be some extra optimizations that are applied to method 1 that are not used in method 2?
No additional optimizations. As always, when in doubt, check execution plan:
df = spark.createDataFrame([(1, 1, 1)], ("amt", "account", "dt"))
(df
.selectExpr("*","sum(amt) over (partition by account) as acc_sum")
.selectExpr("*","sum(amt) over (partition by dt) as dt_sum")
.explain(True))
generates:
== Parsed Logical Plan ==
'Project [*, 'sum('amt) windowspecdefinition('dt, unspecifiedframe$()) AS dt_sum#165]
+- AnalysisBarrier Project [amt#22L, account#23L, dt#24L, acc_sum#158L]
== Analyzed Logical Plan ==
amt: bigint, account: bigint, dt: bigint, acc_sum: bigint, dt_sum: bigint
Project [amt#22L, account#23L, dt#24L, acc_sum#158L, dt_sum#165L]
+- Project [amt#22L, account#23L, dt#24L, acc_sum#158L, dt_sum#165L, dt_sum#165L]
+- Window [sum(amt#22L) windowspecdefinition(dt#24L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS dt_sum#165L], [dt#24L]
+- Project [amt#22L, account#23L, dt#24L, acc_sum#158L]
+- Project [amt#22L, account#23L, dt#24L, acc_sum#158L]
+- Project [amt#22L, account#23L, dt#24L, acc_sum#158L, acc_sum#158L]
+- Window [sum(amt#22L) windowspecdefinition(account#23L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS acc_sum#158L], [account#23L]
+- Project [amt#22L, account#23L, dt#24L]
+- LogicalRDD [amt#22L, account#23L, dt#24L], false
== Optimized Logical Plan ==
Window [sum(amt#22L) windowspecdefinition(dt#24L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS dt_sum#165L], [dt#24L]
+- Window [sum(amt#22L) windowspecdefinition(account#23L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS acc_sum#158L], [account#23L]
+- LogicalRDD [amt#22L, account#23L, dt#24L], false
== Physical Plan ==
Window [sum(amt#22L) windowspecdefinition(dt#24L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS dt_sum#165L], [dt#24L]
+- *Sort [dt#24L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(dt#24L, 200)
+- Window [sum(amt#22L) windowspecdefinition(account#23L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS acc_sum#158L], [account#23L]
+- *Sort [account#23L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(account#23L, 200)
+- Scan ExistingRDD[amt#22L,account#23L,dt#24L]
while
trans2.transform(trans1.transform(df)).explain(True)
generates
== Parsed Logical Plan ==
'Project [*, 'sum('amt) windowspecdefinition('dt, unspecifiedframe$()) AS dt_sum#150]
+- 'UnresolvedRelation `SQLTransformer_4318bd7007cefbf17a97_826abb6c003c`
== Analyzed Logical Plan ==
amt: bigint, account: bigint, dt: bigint, acc_sum: bigint, dt_sum: bigint
Project [amt#22L, account#23L, dt#24L, acc_sum#120L, dt_sum#150L]
+- Project [amt#22L, account#23L, dt#24L, acc_sum#120L, dt_sum#150L, dt_sum#150L]
+- Window [sum(amt#22L) windowspecdefinition(dt#24L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS dt_sum#150L], [dt#24L]
+- Project [amt#22L, account#23L, dt#24L, acc_sum#120L]
+- SubqueryAlias sqltransformer_4318bd7007cefbf17a97_826abb6c003c
+- Project [amt#22L, account#23L, dt#24L, acc_sum#120L]
+- Project [amt#22L, account#23L, dt#24L, acc_sum#120L, acc_sum#120L]
+- Window [sum(amt#22L) windowspecdefinition(account#23L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS acc_sum#120L], [account#23L]
+- Project [amt#22L, account#23L, dt#24L]
+- SubqueryAlias sqltransformer_4688bba599a7f5a09c39_f5e9d251099e
+- LogicalRDD [amt#22L, account#23L, dt#24L], false
== Optimized Logical Plan ==
Window [sum(amt#22L) windowspecdefinition(dt#24L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS dt_sum#150L], [dt#24L]
+- Window [sum(amt#22L) windowspecdefinition(account#23L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS acc_sum#120L], [account#23L]
+- LogicalRDD [amt#22L, account#23L, dt#24L], false
== Physical Plan ==
Window [sum(amt#22L) windowspecdefinition(dt#24L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS dt_sum#150L], [dt#24L]
+- *Sort [dt#24L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(dt#24L, 200)
+- Window [sum(amt#22L) windowspecdefinition(account#23L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS acc_sum#120L], [account#23L]
+- *Sort [account#23L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(account#23L, 200)
+- Scan ExistingRDD[amt#22L,account#23L,dt#24L]
As you can see, the optimized and physical plans are identical.