I plan to read data from a very large BigQuery table and then output it as CSV with 61,000 sequential records per file. I've tried the code below:
TMP_BUCKET = "stg-gcs-bucket"
MAX_PARTITION_BYTES = str(512 * 1024 * 1024)
# 1k Account per file
# MAX_ROW_NUM_PER_FILE = "18300"
MAX_ROW_NUM_PER_FILE = "61000"
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .master('yarn') \
    .appName('crs-bq-export-csv') \
    .config('spark.sql.execution.arrow.pyspark.enabled', 'true') \
    .config('spark.jars', 'gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.23.2.jar') \
    .config("spark.sql.broadcastTimeout", "36000") \
    .config("spark.sql.files.maxRecordsPerFile", MAX_ROW_NUM_PER_FILE) \
    .config("spark.sql.files.maxPartitionBytes", MAX_PARTITION_BYTES) \
    .config("spark.files.maxPartitionBytes", MAX_PARTITION_BYTES) \
    .config("spark.driver.maxResultSize", "24g") \
    .getOrCreate()
# Try to read the full table from BigQuery
df = spark.read.format('bigquery') \
.option('table', TABLE_NAME) \
.load()
df.sort('colA').sort('colB').write.mode('overwrite').csv(OUTPUT_PATH, header=True)
but the final results are not sorted by colA and colB; the rows come out in no particular order:
Expected CSV:
colA colB
1. 1
2. 2
3. 3
....
60001 60001
But got:
colA colB
2. 1
3. 3
2. 2
1. 3
I checked the Spark docs: Spark shuffles the data to get better performance, but I need the final CSV in a specific order. How can I achieve this?
Any help would be greatly appreciated!
To reproduce, I create a dataframe like this:
data = [("2.", "1"),
("3.", "3"),
("2.", "2"),
("1.", "3")]
columns = ["colA", "colB"]
df = spark.createDataFrame(data, columns)
df.show()
+----+----+
|colA|colB|
+----+----+
|2. |1 |
|3. |3 |
|2. |2 |
|1. |3 |
+----+----+
If I run your code I get:
df.sort('colA').sort('colB').show()
+----+----+
|colA|colB|
+----+----+
| 2.| 1|
| 2.| 2|
| 1.| 3|
| 3.| 3|
+----+----+
Let's look at the execution plan: it sorts only by colB:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [colB#1 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(colB#1 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=94]
+- Scan ExistingRDD[colA#0,colB#1]
And that is in line with the way the sort function is implemented: it sorts the whole dataframe based on the values of the columns you pass to it. So the net effect of chaining sort calls is that the resulting dataframe is sorted only according to the last sort call.
Here is the correct approach for your use case:
df.sort('colA', 'colB').show()
df.sort('colA', 'colB').explain()
+----+----+
|colA|colB|
+----+----+
| 1.| 3|
| 2.| 1|
| 2.| 2|
| 3.| 3|
+----+----+
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [colA#0 ASC NULLS FIRST, colB#1 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(colA#0 ASC NULLS FIRST, colB#1 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=148]
+- Scan ExistingRDD[colA#0,colB#1]
As you can see in the output dataframe and in the execution plan, it sorts by both columns because I am passing both columns to the sort function, first by colA and then by colB.
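Applied to the original BigQuery-to-CSV question above, a minimal sketch would therefore be (reusing TABLE_NAME, OUTPUT_PATH and the session configuration from the question; this is the same idea, not a tested drop-in script):
# Read the table and write it out globally sorted by colA, then colB.
# orderBy/sort with both columns triggers a range repartitioning, so each
# output file should hold a contiguous, sorted slice of the data.
df = spark.read.format('bigquery') \
    .option('table', TABLE_NAME) \
    .load()

df.orderBy('colA', 'colB') \
    .write.mode('overwrite') \
    .csv(OUTPUT_PATH, header=True)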
Recently I was asked in an interview about the algorithm behind Spark's df.show() function.
How will Spark decide from which executor(s) the records will be fetched?
Without undermining @thebluephantom's and @Hristo Iliev's answers (each gives some insight into what's happening under the hood), I also wanted to add my answer to this list.
I came to the same conclusion(s), albeit by observing the behaviour of the underlying partitions.
Partitions have an index associated with them. This is seen in the code below.
(Taken from original spark source code here).
trait Partition extends Serializable {
def index: Int
:
So amongst partitions, there is an order.
And as already mentioned in other answers, the df.show() is the same as df.show(20) or the top 20 rows. So the underlying partition indexes determine which partition (and hence executor) those rows come from.
The partition indexes are assigned at the time of read, or (re-assigned) during a shuffle.
Here is some code to see this behaviour:
val df = Seq((5,5), (6,6), (7,7), (8,8), (1,1), (2,2), (3,3), (4,4)).toDF("col1", "col2")
// above sequence is defined out of order - to make behaviour visible
// see partition structure
df.rdd.glom().collect()
/* Array(Array([5,5]), Array([6,6]), Array([7,7]), Array([8,8]), Array([1,1]), Array([2,2]), Array([3,3]), Array([4,4])) */
df.show(4, false)
/*
+----+----+
|col1|col2|
+----+----+
|5 |5 |
|6 |6 |
|7 |7 |
|8 |8 |
+----+----+
only showing top 4 rows
*/
In the above code, we see 8 partitions (each inner Array is a partition) - this is because Spark defaulted to 8 partitions here (following the default parallelism) when we created the dataframe.
Now let's repartition the dataframe.
// Now let's repartition df
val df2 = df.repartition(2)
// lets see the partition structure
df2.rdd.glom().collect()
/* Array(Array([5,5], [6,6], [7,7], [8,8], [1,1], [2,2], [3,3], [4,4]), Array()) */
// lets see output
df2.show(4,false)
/*
+----+----+
|col1|col2|
+----+----+
|5 |5 |
|6 |6 |
|7 |7 |
|8 |8 |
+----+----+
only showing top 4 rows
*/
In the above code, the top 4 rows came from the first partition (which actually has all elements of the original data in it). Also note the skew in partition sizes, since no partitioning column was mentioned.
Now let's try to create 3 partitions:
val df3 = df.repartition(3)
// lets see partition structures
df3.rdd.glom().collect()
/*
Array(Array([8,8], [1,1], [2,2]), Array([5,5], [6,6]), Array([7,7], [3,3], [4,4]))
*/
// And lets see the top 4 rows this time
df3.show(4, false)
/*
+----+----+
|col1|col2|
+----+----+
|8 |8 |
|1 |1 |
|2 |2 |
|5 |5 |
+----+----+
only showing top 4 rows
*/
From the above code, we observe that Spark went to the first partition and tried to get 4 rows. Since only 3 were available, it grabbed those. It then moved on to the next partition and got one more row. Thus, the order you see from show(4, false) is due to the underlying data partitioning and the index ordering amongst partitions.
This example uses show(4), but this behaviour can be extended to show() or show(20).
It's simple.
In Spark 2+, show() calls showString() to format the data as a string and then prints it. showString() calls getRows() to get the top rows of the dataset as a collection of strings. getRows() calls take() to take the original rows and transforms them into strings. take() simply wraps head(). head() calls limit() to build a limit query and executes it. limit() adds a Limit(n) node at the front of the logical plan which is really a GlobalLimit(n, LocalLimit(n)). Both GlobalLimit and LocalLimit are subclasses of OrderPreservingUnaryNode that override its maxRows (in GlobalLimit) or maxRowsPerPartition (in LocalLimit) methods. The logical plan now looks like:
GlobalLimit n
+- LocalLimit n
+- ...
This goes through analysis and optimisation by Catalyst, where limits are removed if something down the tree produces fewer rows than the limit, and it ends up as CollectLimitExec(m) (where m <= n) in the execution strategy, so the physical plan looks like:
CollectLimit m
+- ...
CollectLimitExec executes its child plan, then checks how many partitions the RDD has. If none, it returns an empty dataset. If one, it runs mapPartitionsInternal(_.take(m)) to take the first m elements. If more than one, it applies take(m) on each partition in the RDD using mapPartitionsInternal(_.take(m)), builds a shuffle RDD that collects the results in a single partition, then again applies take(m).
In other words, it depends (because optimisation phase), but in the general case it takes the top rows of the concatenation of the top rows of each partition and so involves all executors holding a part of the dataset.
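If you want to see this chain for yourself, explaining a plain limit query is enough; here is a minimal PySpark sketch (the exact plan text varies by Spark version):
# show() boils down to a limit; explain(True) exposes the logical and physical plans.
df = spark.range(100).toDF('id')
df.limit(20).explain(True)
# Expect GlobalLimit 20 / LocalLimit 20 in the logical plans and a
# CollectLimit node in the physical plan (details differ between versions).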
OK, perhaps not so simple.
A shitty question, as it's not what you would use in prod.
show() is a smart action that looks at what you have in terms of transformations.
show() is in fact show(20). With a plain show, it looks at the 1st and consecutive partitions to get 20 rows. An order by is also optimized. A count, however, does need complete processing.
Many Google posts on this, btw.
I have the two following dataframes:
df1:
+---+---+
| a| b|
+---+---+
| 1| 2|
| 1| 3|
| 1| 4|
| 2| 5|
| 2| 6|
| 3| 7|
| 3| 8|
+---+---+
info:
+---+---+------------+
| a| b| i|
+---+---+------------+
| 1| 2|1 --> 2 info|
| 1| 3|1 --> 3 info|
| 7| 3|3 --> 7 info|
+---+---+------------+
For each row in 'df1' I want to find a corresponding row in 'info':
select df1.*, info.i from df1
join info
on
(df1.a = info.a and df1.b = info.b)
This works and generates the following explain plan:
*(5) Project [a#0L, b#1L, i#6]
+- *(5) SortMergeJoin [a#0L, b#1L], [a#4L, b#5L], Inner
:- *(2) Sort [a#0L ASC NULLS FIRST, b#1L ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(a#0L, b#1L, 200), ENSURE_REQUIREMENTS, [id=#37]
: +- *(1) Filter (isnotnull(a#0L) AND isnotnull(b#1L))
: +- *(1) Scan ExistingRDD[a#0L,b#1L]
+- *(4) Sort [a#4L ASC NULLS FIRST, b#5L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(a#4L, b#5L, 200), ENSURE_REQUIREMENTS, [id=#43]
+- *(3) Filter (isnotnull(a#4L) AND isnotnull(b#5L))
+- *(3) Scan ExistingRDD[a#4L,b#5L,i#6]
However, looking at the output:
+---+---+------------+
| a| b| i|
+---+---+------------+
| 1| 3|1 --> 3 info|
| 1| 2|1 --> 2 info|
+---+---+------------+
However, this is not good enough for me: the 'info' table attaches no meaning to the order of (a, b), so I want the record a=3, b=7 in df1 to be paired with the record a=7, b=3 in info. I therefore changed the query:
select df1.*, info.i from df1
join info
on
(df1.a = info.a and df1.b = info.b) or
(df1.a = info.b and df1.b = info.a)
Output is exactly as I wanted:
+---+---+------------+
| a| b| i|
+---+---+------------+
| 1| 2|1 --> 2 info|
| 1| 3|1 --> 3 info|
| 3| 7|3 --> 7 info|
+---+---+------------+
However, the explain plan worries me:
== Physical Plan ==
*(3) Project [a#0L, b#1L, i#6]
+- CartesianProduct (((a#0L = a#4L) AND (b#1L = b#5L)) OR ((a#0L = b#5L) AND (b#1L = a#4L)))
:- *(1) Scan ExistingRDD[a#0L,b#1L]
+- *(2) Scan ExistingRDD[a#4L,b#5L,i#6]
Questions:
Is adding the OR clause correct? We can assume 'df1' and 'info' tables are unique in (a,b). df1 is ordered but info is not.
Why did the plan change?
I am running on Spark 3.1.2
Adding an OR condition to a join clause makes it impossible to easily ensure that rows to be joined are regrouped onto the same executor and can be efficiently matched within that executor. Thus, in this case, Spark is forced to use the naive Cartesian product algorithm to join the dataframes.
To simplify, let's look at the first step of Spark's join algorithm: regrouping the rows that will be matched together onto the same executor. The second step (efficiently joining rows within the same executor) is a bit more complicated, but it has the same issue with OR conditions.
Spark regroups the rows of a dataframe onto executors (repartition) by applying a hash function to the join columns and sending rows with similar hashes to the same executor (for instance, "similar hashes" could be hashes starting with the same character).
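To make that regrouping idea concrete, here is a toy illustration in plain Python (this is not Spark's actual hash function, just the "same key, same executor" principle):
# Toy model: rows with the same join key always land on the same executor.
rows_df1 = [(1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]
rows_df2 = [(1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]
num_executors = 2

def target_executor(k1, k2):
    return hash((k1, k2)) % num_executors

for row in rows_df1 + rows_df2:
    print(row, '-> executor', target_executor(*row))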
Example
Let's take two dataframes to join, df1 and df2. Here is df1:
+---+---+---+---+
|row| k1| k2| v1|
+---+---+---+---+
|  1|  1|  A| X1|
|  2|  1|  B| X2|
|  3|  2|  A| X3|
|  4|  2|  B| X4|
+---+---+---+---+
And here is df2:
+---+---+---+---+
|row| k1| k2| v2|
+---+---+---+---+
|  5|  1|  A| Y1|
|  6|  1|  B| Y2|
|  7|  2|  A| Y3|
|  8|  2|  B| Y4|
+---+---+---+---+
We will join those two dataframes on a Spark cluster with 2 executors
AND join condition
We use df1.k1 = df2.k1 AND df1.k2 = df2.k2 as the join condition.
If we use the concatenation of k1 and k2 as a hash function (a very bad hash function, by the way) and regroup by its first character, the first two rows of df1 will be matched with the first two rows of df2, and the last two rows of df1 will be matched with the last two rows of df2.
So you can split your two dataframes, putting the first two rows of each on one executor and the last two rows on the other executor, and then perform the join independently on each executor.
OR join condition
However, if we change the join condition to an OR condition, df1.k1 = df2.k1 OR df1.k2 = df2.k2, hashing doesn't work anymore: row 1 of df1 will be matched to rows 5, 6 and 7 of df2, and row 2 of df1 will be matched to rows 5, 6 and 8 of df2. So the first two rows of df1 will be matched to all rows of df2, and the first two rows of df2 will be matched to all rows of df1.
So you can no longer split your dataframes into parts that can be sent to executors and treated independently.
Conclusion
As we can see, with an OR condition you can't use a hash to distribute rows for a join. So when Spark sees an OR condition in a join, it can't select hash-based join algorithms. For a similar reason, it can't select the Sort Merge Join algorithm.
That leaves two algorithms for Spark to choose from: Broadcast Nested Loop Join and Cartesian Product. If one of the dataframes is small, Spark will use Broadcast Nested Loop Join; otherwise it will use Cartesian Product.
That's why adding an OR condition makes Spark use a Cartesian Product plan.
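If info is small enough to be broadcast, you can usually steer Spark towards Broadcast Nested Loop Join instead of the Cartesian Product with a broadcast hint. A hedged PySpark sketch of that idea (df1 and info as in the question; verify the resulting plan with explain):
from pyspark.sql.functions import broadcast

# With the hint, Spark should plan a BroadcastNestedLoopJoin for the OR
# condition instead of a full CartesianProduct (assuming info is small).
cond = ((df1.a == info.a) & (df1.b == info.b)) | \
       ((df1.a == info.b) & (df1.b == info.a))
df1.join(broadcast(info), cond).select(df1['a'], df1['b'], info['i']).explain()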
Max and min can be evaluated and used for join (Scala):
val df1 = Seq(
(1, 2),
(1, 3),
(1, 4),
(2, 5),
(2, 6),
(3, 7),
(3, 8)
).toDF("a", "b")
.withColumn("maxValue", when($"a">$"b", $"a").otherwise($"b"))
.withColumn("minValue", when($"a">$"b", $"b").otherwise($"a"))
val info = Seq(
(1, 2, "1 --> 2 info"),
(1, 3, "1 --> 3 info"),
(7, 3, "3 --> 7 info")
).toDF("a", "b", "i")
.withColumn("maxValue", when($"a">$"b", $"a").otherwise($"b"))
.withColumn("minValue", when($"a">$"b", $"b").otherwise($"a"))
df1
.join(info, Seq("maxValue", "minValue"))
// drop unused columns
.drop("maxValue", "minValue")
.drop(info.col("a")).drop(info.col("b"))
Output:
+---+---+------------+
|a |b |i |
+---+---+------------+
|1 |2 |1 --> 2 info|
|1 |3 |1 --> 3 info|
|3 |7 |3 --> 7 info|
+---+---+------------+
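The same canonicalisation can also be expressed with the built-in least/greatest functions; here is a PySpark sketch of the equivalent idea (column names as in the question, not taken from the answer above):
from pyspark.sql import functions as F

# Build an order-insensitive key with least/greatest, then join on it.
def with_key(df):
    return (df.withColumn('minValue', F.least('a', 'b'))
              .withColumn('maxValue', F.greatest('a', 'b')))

result = (with_key(df1)
          .join(with_key(info).select('minValue', 'maxValue', 'i'),
                on=['minValue', 'maxValue'])
          .drop('minValue', 'maxValue'))
result.show()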
I went through the documentation here: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html
It says:
for repartition: resulting DataFrame is hash partitioned.
for repartitionByRange: resulting DataFrame is range partitioned.
And a previous question also mentions it. However, I still don't understand how exactly they differ and what the impact will be when choosing one over the other?
More importantly, if repartition does hash partitioning, what impact does providing columns as its argument have?
I think it is best to look into the difference with some experiments.
Test Dataframes
For this experiment, I am using the following two Dataframes (I am showing the code in Scala, but the concept is identical for the Python API):
// Dataframe with one column "value" containing the values ranging from 0 to 1000000
val df = Seq(0 to 1000000: _*).toDF("value")
// Dataframe with one column "value" containing 1000000 the number 0 in addition to the numbers 5000, 10000 and 100000
val df2 = Seq((0 to 1000000).map(_ => 0) :+ 5000 :+ 10000 :+ 100000: _*).toDF("value")
Theory
repartition applies the HashPartitioner when one or more columns are provided and the RoundRobinPartitioner when no column is provided. If one or more columns are provided (HashPartitioner), those values will be hashed and used to determine the partition number by calculating something like partition = hash(columns) % numberOfPartitions. If no column is provided (RoundRobinPartitioner) the data gets evenly distributed across the specified number of partitions.
repartitionByRange will partition the data based on a range of the column values. This is usually used for continuous (not discrete) values such as any kind of numbers. Note that due to performance reasons this method uses sampling to estimate the ranges. Hence, the output may not be consistent, since sampling can return different values. The sample size can be controlled by the config spark.sql.execution.rangeExchange.sampleSizePerPartition.
It is also worth mentioning that, for both methods, if numPartitions is not given, the Dataframe is by default partitioned into the number of partitions configured in spark.sql.shuffle.partitions for your Spark session, and the result can be coalesced further by Adaptive Query Execution (available since Spark 3.x).
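As a quick way to check those defaults, a small PySpark sketch (assuming an existing spark session; the exact partition count depends on your configuration and on AQE):
# With no numPartitions argument, the shuffle falls back to spark.sql.shuffle.partitions.
print(spark.conf.get('spark.sql.shuffle.partitions'))      # typically '200'
df = spark.range(0, 1000000).toDF('value')
print(df.repartition('value').rdd.getNumPartitions())      # ~200, possibly fewer with AQE

# Sample size used by repartitionByRange to estimate the range boundaries
spark.conf.set('spark.sql.execution.rangeExchange.sampleSizePerPartition', '200')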
Test Setup
Based on the given test data, I always apply the same code:
val testDf = df
// here I will insert the partition logic
.withColumn("partition", spark_partition_id()) // applying SQL built-in function to determine actual partition
.groupBy(col("partition"))
.agg(
count(col("value")).as("count"),
min(col("value")).as("min_value"),
max(col("value")).as("max_value"))
.orderBy(col("partition"))
testDf.show(false)
Test Results
df.repartition(4, col("value"))
As expected, we get 4 partitions, and because the values of df range from 0 to 1000000, their hashed values result in a well-distributed Dataframe.
+---------+------+---------+---------+
|partition|count |min_value|max_value|
+---------+------+---------+---------+
|0 |249911|12 |1000000 |
|1 |250076|6 |999994 |
|2 |250334|2 |999999 |
|3 |249680|0 |999998 |
+---------+------+---------+---------+
df.repartitionByRange(4, col("value"))
Also in this case we get 4 partitions, but this time the min and max values clearly show the ranges of values within each partition. The data is almost equally distributed, with around 250000 values per partition.
+---------+------+---------+---------+
|partition|count |min_value|max_value|
+---------+------+---------+---------+
|0 |244803|0 |244802 |
|1 |255376|244803 |500178 |
|2 |249777|500179 |749955 |
|3 |250045|749956 |1000000 |
+---------+------+---------+---------+
df2.repartition(4, col("value"))
Now we are using the other Dataframe, df2. Here, the hashing algorithm only hashes the values 0, 5000, 10000 or 100000. Of course, the hash of the value 0 will always be the same, so all zeros end up in the same partition (in this case partition 3). The other three partitions each contain only one value.
+---------+-------+---------+---------+
|partition|count |min_value|max_value|
+---------+-------+---------+---------+
|0 |1 |100000 |100000 |
|1 |1 |10000 |10000 |
|2 |1 |5000 |5000 |
|3 |1000001|0 |0 |
+---------+-------+---------+---------+
df2.repartition(4)
Without using the content of the column "value", the repartition method distributes the rows on a round-robin basis. All partitions have almost the same amount of data.
+---------+------+---------+---------+
|partition|count |min_value|max_value|
+---------+------+---------+---------+
|0 |250002|0 |5000 |
|1 |250002|0 |10000 |
|2 |249998|0 |100000 |
|3 |250002|0 |0 |
+---------+------+---------+---------+
df2.repartitionByRange(4, col("value"))
This case shows that the Dataframe df2 is not well suited to repartitioning by range, as almost all values are 0. Therefore, we end up with only two non-empty partitions, where partition 0 contains all the zeros.
+---------+-------+---------+---------+
|partition|count |min_value|max_value|
+---------+-------+---------+---------+
|0 |1000001|0 |0 |
|1 |3 |5000 |100000 |
+---------+-------+---------+---------+
By using df.explain, you can get a lot of information about these operations.
I'm using this DataFrame for the example :
df = spark.createDataFrame([(i, f"value {i}") for i in range(1, 22, 1)], ["id", "value"])
Repartition
Depending on whether a key expression (column) is specified or not, the partitioning method will be different. It is not always hash partitioning as you said.
df.repartition(3).explain(True)
== Parsed Logical Plan ==
Repartition 3, true
+- LogicalRDD [id#0L, value#1], false
== Analyzed Logical Plan ==
id: bigint, value: string
Repartition 3, true
+- LogicalRDD [id#0L, value#1], false
== Optimized Logical Plan ==
Repartition 3, true
+- LogicalRDD [id#0L, value#1], false
== Physical Plan ==
Exchange RoundRobinPartitioning(3)
+- Scan ExistingRDD[id#0L,value#1]
We can see in the generated physical plan that RoundRobinPartitioning is used:
Represents a partitioning where rows are distributed evenly across
output partitions by starting from a random target partition number
and distributing rows in a round-robin fashion. This partitioning is
used when implementing the DataFrame.repartition() operator.
When using repartition by column expression:
df.repartition(3, "id").explain(True)
== Parsed Logical Plan ==
'RepartitionByExpression ['id], 3
+- LogicalRDD [id#0L, value#1], false
== Analyzed Logical Plan ==
id: bigint, value: string
RepartitionByExpression [id#0L], 3
+- LogicalRDD [id#0L, value#1], false
== Optimized Logical Plan ==
RepartitionByExpression [id#0L], 3
+- LogicalRDD [id#0L, value#1], false
== Physical Plan ==
Exchange hashpartitioning(id#0L, 3)
+- Scan ExistingRDD[id#0L,value#1]
Now the picked partitioning method is hashpartitioning.
In the hash partitioning method, a hash is calculated for every key expression to determine the destination partition_id by taking a modulo: hash(key) % numPartitions.
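A small PySpark sketch to see this in action (it assumes the df defined above; the fact that pmod(hash(id), 3) matches spark_partition_id is an observation about current Spark versions and is worth verifying on yours):
from pyspark.sql import functions as F

# Predict the partition chosen by hashpartitioning(id, 3) and compare it with
# the partition each row actually lands in after the repartition.
check = (df.repartition(3, 'id')
           .withColumn('actual_partition', F.spark_partition_id())
           .withColumn('predicted_partition', F.expr('pmod(hash(id), 3)')))
check.show()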
RepartitionByRange
This partitioning method creates numPartitions consecutive and not overlapping ranges of values based on the partitioning key. Thus, at least one key expression is required and needs to be orderable.
df.repartitionByRange(3, "id").explain(True)
== Parsed Logical Plan ==
'RepartitionByExpression ['id ASC NULLS FIRST], 3
+- LogicalRDD [id#0L, value#1], false
== Analyzed Logical Plan ==
id: bigint, value: string
RepartitionByExpression [id#0L ASC NULLS FIRST], 3
+- LogicalRDD [id#0L, value#1], false
== Optimized Logical Plan ==
RepartitionByExpression [id#0L ASC NULLS FIRST], 3
+- LogicalRDD [id#0L, value#1], false
== Physical Plan ==
Exchange rangepartitioning(id#0L ASC NULLS FIRST, 3)
+- Scan ExistingRDD[id#0L,value#1]
Looking at the generated physical plan, we can see that rangepartitioning differs from the two others described above by the presence of the ordering clause in the partitioning expression. When no explicit sort order is specified in the expression, it uses ascending order by default.
Some interesting links:
Repartition Logical Operators — Repartition and RepartitionByExpression
Range partitioning in Apache SparkSQL
hash vs range partitioning
We have a Spark Streaming application (Spark 2.1 running on Hortonworks 2.6) and use DataSet.repartition (on a DataSet<Row> that's read from Kafka) in order to repartition the DataSet<Row>'s partitions according to a given column (called block_id).
We start with a DataSet<Row> containing 50 partitions and end up (after the call to DataSet.repartition) with a number of partitions equivalent to the number of unique block_ids.
The problem is that DataSet.repartition does not behave as we expected: when we look at the event timeline of the Spark job that runs the repartition, we see that several tasks handle 1 block_id and fewer tasks handle 2 block_ids, or even 3 or 4 block_ids.
It seems that DataSet.repartition ensures that all the rows with the same block_id end up inside a single partition, but not that each task that creates a partition handles only one block_id.
The result is that the repartition job (which runs inside the streaming application) takes as much time as its longest task (which is the task that handles the most block_ids).
We tried playing with the number of vcores given to the streaming app, from 10 to 25 to 50 (we have 50 partitions in the original RDD that's read from Kafka), but the result was the same: there is always at least one task that handles more than one block_id.
We even tried increasing the batch time; again, that didn't help us achieve the goal of one task handling one block_id.
To give an example, here are the event timeline and the tasks table describing a run of the repartition Spark job:
event timeline - the two tasks in red are the ones that handle two block_id's:
tasks table - the two tasks in red are the same two from above - notice that the duration of each of them is twice the duration of all the other tasks (which handle only one block_id)
This is a problem for us because the streaming application is delayed due to these long tasks and we need a solution that will enable us to perform repartition on a DataSet while having each task handling only one block_id.
And if that's not possible then maybe that's possible on an JavaRDD? Since in our case the DataSet<Row> we run the repartition on is created from a JavaRDD.
The 2 problems you need to consider:
Have a custom partitioner that ensures a uniform data distribution (1 block_id per partition)
Size the cluster so that you have enough executors to run all tasks (block_ids) simultaneously
As you've seen, a simple repartition on the DataFrame doesn't ensure you'll get a uniform distribution. When you repartition by block_id it uses the HashPartitioner, with the formula:
Utils.nonNegativeMod(key.hashCode, numPartitions)
See: https://github.com/apache/spark/blob/branch-2.2/core/src/main/scala/org/apache/spark/Partitioner.scala#L80-L88
It's very possible for 2+ keys to be assigned to the same partition_id, as the partition_id is the key's hashCode modulo numPartitions.
What you need can be achieved by using the RDD with a custom partitioner. The easiest approach is to extract the list of distinct block_ids before repartitioning.
Here's a simple example. Let's say you have 5 blocks (2, 3, 6, 8, 9) and your cluster has 8 executors (it can run up to 8 tasks simultaneously), so we're over-provisioned by 3 executors:
scala> spark.conf.get("spark.sql.shuffle.partitions")
res0: String = 8
scala> spark.conf.get("spark.default.parallelism")
res1: String = 8
// Basic class to store dummy records
scala> case class MyRec(block_id: Int, other: String)
defined class MyRec
// Sample DS
scala> val ds = List((2,"A"), (3,"X"), (3, "B"), (9, "Y"), (6, "C"), (9, "M"), (6, "Q"), (2, "K"), (2, "O"), (6, "W"), (2, "T"), (8, "T")).toDF("block_id", "other").as[MyRec]
ds: org.apache.spark.sql.Dataset[MyRec] = [block_id: int, other: string]
scala> ds.show
+--------+-----+
|block_id|other|
+--------+-----+
| 2| A|
| 3| X|
| 3| B|
| 9| Y|
| 6| C|
| 9| M|
| 6| Q|
| 2| K|
| 2| O|
| 6| W|
| 2| T|
| 8| T|
+--------+-----+
// Default partitioning gets data distributed as uniformly as possible (record count)
scala> ds.rdd.getNumPartitions
res3: Int = 8
// Print records distribution by partition
scala> ds.rdd.mapPartitionsWithIndex((idx, it) => Iterator((idx, it.toList))).toDF("partition_id", "block_ids").show
+------------+--------------+
|partition_id| block_ids|
+------------+--------------+
| 0| [[2,A]]|
| 1|[[3,X], [3,B]]|
| 2| [[9,Y]]|
| 3|[[6,C], [9,M]]|
| 4| [[6,Q]]|
| 5|[[2,K], [2,O]]|
| 6| [[6,W]]|
| 7|[[2,T], [8,T]]|
+------------+--------------+
// repartitioning by block_id leaves 4 partitions empty and assigns 2 block_ids (6,9) to same partition (1)
scala> ds.repartition('block_id).rdd.mapPartitionsWithIndex((idx, it) => Iterator((idx, it.toList))).toDF("partition_id", "block_ids").where(size('block_ids) > 0).show(false)
+------------+-----------------------------------+
|partition_id|block_ids |
+------------+-----------------------------------+
|1 |[[9,Y], [6,C], [9,M], [6,Q], [6,W]]|
|3 |[[3,X], [3,B]] |
|6 |[[2,A], [2,K], [2,O], [2,T]] |
|7 |[[8,T]] |
+------------+-----------------------------------+
// Create a simple mapping for block_id to partition_id to be used by our custom partitioner (logic may be more elaborate or static if the list of block_ids is static):
scala> val mappings = ds.map(_.block_id).dropDuplicates.collect.zipWithIndex.toMap
mappings: scala.collection.immutable.Map[Int,Int] = Map(6 -> 1, 9 -> 0, 2 -> 3, 3 -> 2, 8 -> 4)
//custom partitioner assigns partition_id according to the mapping arg
scala> class CustomPartitioner(mappings: Map[Int,Int]) extends org.apache.spark.Partitioner {
| override def numPartitions: Int = mappings.size
| override def getPartition(rec: Any): Int = { mappings.getOrElse(rec.asInstanceOf[Int], 0) }
| }
defined class CustomPartitioner
// Repartition DS using new partitioner
scala> val newDS = ds.rdd.map(r => (r.block_id, r)).partitionBy(new CustomPartitioner(mappings)).toDS
newDS: org.apache.spark.sql.Dataset[(Int, MyRec)] = [_1: int, _2: struct<block_id: int, other: string>]
// Display evenly distributed block_ids
scala> newDS.rdd.mapPartitionsWithIndex((idx, it) => Iterator((idx, it.toList))).toDF("partition_id", "block_ids").where(size('block_ids) > 0).show(false)
+------------+--------------------------------------------+
|partition_id|block_ids |
+------------+--------------------------------------------+
|0 |[[9,[9,Y]], [9,[9,M]]] |
|1 |[[6,[6,C]], [6,[6,Q]], [6,[6,W]]] |
|2 |[[3,[3,X]], [3,[3,B]]] |
|3 |[[2,[2,A]], [2,[2,K]], [2,[2,O]], [2,[2,T]]]|
|4 |[[8,[8,T]]] |
+------------+--------------------------------------------+
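For completeness, here is a rough PySpark equivalent of the same idea (this is my sketch, not part of the original answer; PySpark lets you pass a plain function as the partitioner of a pair RDD):
# ds is assumed to be a DataFrame with a block_id column, as in the question.
block_ids = [r[0] for r in ds.select('block_id').distinct().collect()]
mappings = {b: i for i, b in enumerate(block_ids)}

# Key the RDD by block_id and give each block_id its own partition.
partitioned = (ds.rdd
                 .map(lambda row: (row['block_id'], row))
                 .partitionBy(len(mappings), lambda key: mappings.get(key, 0)))

# Inspect the distribution: one block_id per partition.
print(partitioned.mapPartitionsWithIndex(
    lambda idx, it: [(idx, sorted({k for k, _ in it}))]).collect())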