My Question: Why does Spark calculate sum and count from each partition, do an unnecessary (IMHO) shuffle (Exchange hashpartitioning), and then calculate the mean in HashAggregate?
What could've been done: Calculate the mean for each partition and then combine (union) the results.
Details:
I am reading data from Hive table defined below, which is partitioned by date.
spark.sql("""Create External Table If Not Exists daily_temp.daily_temp_2014
(
state_name string,
...
) Partitioned By (
date_local string
)
Location "./daily_temp/"
Stored As ORC""")
It consists of daily temperature measurements for various points in the US, downloaded from the EPA website.
Using the code below, the data is loaded from the Hive table into a PySpark DataFrame:
spark = (
SparkSession.builder
.master("local")
.appName("Hive Partition Test")
.enableHiveSupport()
.config("hive.exec.dynamic.partition", "true")
.config("hive.exec.dynamic.partition.mode", "nonstrict")
.getOrCreate()
)
my_df = spark.sql("select * from daily_temp.daily_temp_2014")
I would like to calculate daily mean temperature per state.
daily_state_mean = (
my_df
.groupBy(
my_df.date_local,
my_df.state_name
)
.agg({"arithmetic_mean":"mean"})
)
And this is part of the physical (execution) plan:
+- *(2) HashAggregate(keys=[date_local#3003, state_name#2998], functions=[avg(cast(arithmetic_mean#2990 as double))], output=[date_local#3003, state_name#2998, avg(CAST(arithmetic_mean AS DOUBLE))#3014])
+- Exchange hashpartitioning(date_local#3003, state_name#2998, 365)
+- *(1) HashAggregate(keys=[date_local#3003, state_name#2998], functions=[partial_avg(cast(arithmetic_mean#2990 as double))], output=[date_local#3003, state_name#2998, sum#3021, count#3022L])
+- HiveTableScan [arithmetic_mean#2990, state_name#2998, date_local#3003], HiveTableRelation `daily_temp`.`daily_temp_2014`, org.apache.hadoop.hive.ql.io.orc.OrcSerde, [...], [date_local#3003]
Your advice and insights are highly appreciated.
There is nothing unexpected here. Spark SQL doesn't preserve partitioning information of the external source (yet).
If you want to optimize shuffles, you have to CLUSTER BY / bucketBy your data. If you do, the partitioning information will be used to optimize shuffles.
Reference: How to define partitioning of DataFrame?
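The plan in the question shows partial_avg emitting a sum and a count per key, with the final avg computed after the Exchange. The following is a minimal pure-Python sketch of that two-step aggregation, not Spark code: the function names and the sample rows are hypothetical. It assumes, as the planner has to when it has no partitioning information, that rows for one key may sit in more than one input split.

```python
from collections import defaultdict

def partial_avg(rows):
    """Map-side step: per-key (sum, count) for one input split."""
    acc = defaultdict(lambda: [0.0, 0])
    for key, value in rows:
        acc[key][0] += value
        acc[key][1] += 1
    return {k: tuple(v) for k, v in acc.items()}

def final_avg(partials):
    """Reduce-side step: merge the (sum, count) pairs per key, then divide."""
    merged = defaultdict(lambda: [0.0, 0])
    for part in partials:
        for key, (s, c) in part.items():
            merged[key][0] += s
            merged[key][1] += c
    return {k: s / c for k, (s, c) in merged.items()}

# Two hypothetical input splits that both hold rows for the same key.
key = ("2014-01-01", "Ohio")
split_1 = [(key, 10.0), (key, 20.0)]
split_2 = [(key, 60.0)]

partials = [partial_avg(split_1), partial_avg(split_2)]
result = final_avg(partials)

# Averaging the per-split means (15.0 and 60.0) weights both splits equally,
# which gives a different, incorrect answer when the splits differ in size.
mean_of_means = (15.0 + 60.0) / 2
print(result[key], mean_of_means)
```

Carrying (sum, count) instead of a mean is what makes the merge correct regardless of how the rows are spread across splits. The merge itself could only be skipped if every key were known to live in exactly one split.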
Related
We have two datasets that have been persisted as follows:
Dataset A:
datasetA.repartition(5, datasetA.col("region"))
.write().mode(saveMode)
.format("parquet")
.partitionBy("region")
.bucketBy(5,"studentId")
.sortBy("studentId")
.option("path", parquetFilesDirectory)
.saveAsTable("database.tableA");
Dataset B:
datasetB.repartition(5, datasetB.col("region"))
.write().mode(saveMode)
.format("parquet")
.partitionBy("region")
.bucketBy(5,"studentId")
.sortBy("studentId")
.option("path", parquetFilesDirectory)
.saveAsTable("database.tableB");
Joining on region and studentId causes a data shuffle. Below is the join query:
spark.sql("Select count(*) from database.tableA a, database.tableB b where a.studentId = b.studentId and a.region = b.region").show()
What could be the reason for the shuffle when we include the partition key, and how can we mitigate it?
Yes, you can mitigate shuffling by using pre-sorted and grouped (bucketed) tables.
I am new to PySpark.
I am trying to get the latest partition (date partition) of a Hive table using PySpark DataFrames, and have done it as shown below.
But I am sure there is a better way to do it using DataFrame functions (not by writing SQL). Could you please share your input on better ways?
This solution scans through the entire data of the Hive table to get it.
df_1 = sqlContext.table("dbname.tablename");
df_1_dates = df_1.select('partitioned_date_column').distinct().orderBy(df_1['partitioned_date_column'].desc())
lat_date_dict=df_1_dates.first().asDict()
lat_dt=lat_date_dict['partitioned_date_column']
I agree with what @philantrovert mentioned in the comment. You can use the approach below, which relies on partition pruning to limit the number of partitions scanned for your Hive table.
>>> spark.sql("""show partitions test_dev_db.newpartitiontable""").show();
+--------------------+
| partition|
+--------------------+
|tran_date=2009-01-01|
|tran_date=2009-02-01|
|tran_date=2009-03-01|
|tran_date=2009-04-01|
|tran_date=2009-05-01|
|tran_date=2009-06-01|
|tran_date=2009-07-01|
|tran_date=2009-08-01|
|tran_date=2009-09-01|
|tran_date=2009-10-01|
|tran_date=2009-11-01|
|tran_date=2009-12-01|
+--------------------+
>>> max_date=spark.sql("""show partitions test_dev_db.newpartitiontable""").rdd.flatMap(lambda x:x).map(lambda x : x.replace("tran_date=","")).max()
>>> print(max_date)
2009-12-01
>>> query = "select city,state,country from test_dev_db.newpartitiontable where tran_date ='{}'".format(max_date)
>>> spark.sql(query).show();
+--------------------+----------------+--------------+
| city| state| country|
+--------------------+----------------+--------------+
| Southampton| England|United Kingdom|
|W Lebanon ...| NH| United States|
| Comox|British Columbia| Canada|
| Gasperich| Luxembourg| Luxembourg|
+--------------------+----------------+--------------+
>>> spark.sql(query).explain(True)
== Parsed Logical Plan ==
'Project ['city, 'state, 'country]
+- 'Filter ('tran_date = 2009-12-01)
+- 'UnresolvedRelation `test_dev_db`.`newpartitiontable`
== Analyzed Logical Plan ==
city: string, state: string, country: string
Project [city#9, state#10, country#11]
+- Filter (tran_date#12 = 2009-12-01)
+- SubqueryAlias newpartitiontable
+- Relation[city#9,state#10,country#11,tran_date#12] orc
== Optimized Logical Plan ==
Project [city#9, state#10, country#11]
+- Filter (isnotnull(tran_date#12) && (tran_date#12 = 2009-12-01))
+- Relation[city#9,state#10,country#11,tran_date#12] orc
== Physical Plan ==
*(1) Project [city#9, state#10, country#11]
+- *(1) FileScan orc test_dev_db.newpartitiontable[city#9,state#10,country#11,tran_date#12] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxx/dev/hadoop/database/test_dev..., PartitionCount: 1, PartitionFilters: [isnotnull(tran_date#12), (tran_date#12 = 2009-12-01)], PushedFilters: [], ReadSchema: struct<city:string,state:string,country:string>
You can see from PartitionCount: 1 in the plan above that it has scanned only one of the 12 available partitions.
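The string handling behind max_date can be tried without a cluster. This is a small sketch under the assumption that the rows returned by show partitions look like the listing above; the helper name latest_partition_value is made up for illustration.

```python
def latest_partition_value(partition_rows, column):
    """Return the largest value of `column` among 'column=value' partition strings."""
    prefix = column + "="
    values = [row[len(prefix):] for row in partition_rows if row.startswith(prefix)]
    return max(values)

# Stand-in for the output of "show partitions" on a table with one partition column.
partition_rows = ["tran_date=2009-%02d-01" % month for month in range(1, 13)]

max_date = latest_partition_value(partition_rows, "tran_date")
query = (
    "select city,state,country from test_dev_db.newpartitiontable "
    "where tran_date ='{}'".format(max_date)
)
print(max_date)
print(query)
```

Taking max() over strings only gives the latest date because the values are in zero-padded yyyy-MM-dd form, where lexicographic order and chronological order agree.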
Building on Vikrant's answer, here is a more general way of extracting partition column values directly from the table metadata, which avoids Spark scanning through all the files in the table.
First, if your data isn't already registered in a catalog, you'll want to do that so Spark can see the partition details. Here, I'm registering a new table named data.
spark.catalog.createTable(
'data',
path='/path/to/the/data',
source='parquet',
)
spark.catalog.recoverPartitions('data')
partitions = spark.sql('show partitions data')
To show a self-contained answer, however, I'll manually create the partitions DataFrame so you can see what it would look like, along with the solution for extracting a specific column value from it.
from pyspark.sql.functions import (
col,
regexp_extract,
)
partitions = (
spark.createDataFrame(
[
('/country=usa/region=ri/',),
('/country=usa/region=ma/',),
('/country=russia/region=siberia/',),
],
schema=['partition'],
)
)
partition_name = 'country'
(
partitions
.select(
'partition',
regexp_extract(
col('partition'),
pattern=r'(\/|^){}=(\S+?)(\/|$)'.format(partition_name),
idx=2,
).alias(partition_name),
)
.show(truncate=False)
)
The output of this query is:
+-------------------------------+-------+
|partition |country|
+-------------------------------+-------+
|/country=usa/region=ri/ |usa |
|/country=usa/region=ma/ |usa |
|/country=russia/region=siberia/|russia |
+-------------------------------+-------+
The solution in Scala will look very similar to this, except the call to regexp_extract() will look slightly different:
.select(
regexp_extract(
col("partition"),
exp=s"(\\/|^)${partitionName}=(\\S+?)(\\/|$$)",
groupIdx=2
).alias(partitionName).as[String]
)
Again, the benefit of querying partition values in this way is that Spark will not scan all the files in the table to get you the answer. If you have a table with tens or hundreds of thousands of files in it, your time savings will be significant.
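To see what the pattern captures without starting Spark, here is a rough equivalent using Python's re module. The helper name extract_partition_value is hypothetical, and unlike regexp_extract() this sketch returns None rather than an empty string when nothing matches.

```python
import re

def extract_partition_value(partition, name):
    """Return the value of partition column `name` from a partition path, or None."""
    # Same shape as the pattern above: the name must start the string or follow
    # a slash, and the lazily matched value ends at the next slash or at the end.
    pattern = r'(\/|^){}=(\S+?)(\/|$)'.format(re.escape(name))
    match = re.search(pattern, partition)
    return match.group(2) if match else None

rows = [
    '/country=usa/region=ri/',
    '/country=usa/region=ma/',
    '/country=russia/region=siberia/',
]

for row in rows:
    print(row, extract_partition_value(row, 'country'), extract_partition_value(row, 'region'))
```

The lazy quantifier matters here: a greedy \S+ would run past the first slash and swallow the following partition columns as well.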
I have posted this question on spark user forum but received no response so asking it here again.
We have a use case where we need to do a Cartesian join, and for some reason we are not able to get it to work with the Dataset APIs.
We have two datasets:
One dataset with 2 string columns, say c1 and c2. It is a small dataset with ~1 million records. Both columns are strings of 32 characters, so it should be less than 500 MB.
We broadcast this dataset.
The other dataset is a little bigger, with ~10 million records.
val ds1 = spark.read.format("csv").option("header", "true").load(<s3-location>).select("c1", "c2")
ds1.count
val ds2 = spark.read.format("csv").load(<s3-location>).toDF("c11", "c12", "c13", "c14", "c15", "ts")
ds2.count
ds2.crossJoin(broadcast(ds1)).filter($"c1" <= $"c11" && $"c11" <= $"c2").count
If I implement it using RDD api where I broadcast data in ds1 and then filter data in ds2 it works fine.
I have confirmed the broadcast is successful.
2019-02-14 23:11:55 INFO CodeGenerator:54 - Code generated in 10.469136 ms
2019-02-14 23:11:55 INFO TorrentBroadcast:54 - Started reading broadcast variable 29
2019-02-14 23:11:55 INFO TorrentBroadcast:54 - Reading broadcast variable 29 took 6 ms
2019-02-14 23:11:56 INFO CodeGenerator:54 - Code generated in 11.280087 ms
Query Plan:
== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, Cross, ((c1#68 <= c11#13) && (c11#13 <= c2#69))
:- *Project []
: +- *Filter isnotnull(_c0#0)
: +- *FileScan csv [_c0#0,_c1#1,_c2#2,_c3#3,_c4#4,_c5#5] Batched: false, Format: CSV, Location: InMemoryFileIndex[], PartitionFilters: [], PushedFilters: [IsNotNull(_c0)], ReadSchema: struct<_c0:string,_c1:string,_c2:string,_c3:string,_c4:string,_c5:string>
+- BroadcastExchange IdentityBroadcastMode
+- *Project [c1#68, c2#69]
+- *Filter (isnotnull(c1#68) && isnotnull(c2#69))
+- *FileScan csv [c1#68,c2#69] Batched: false, Format: CSV, Location: InMemoryFileIndex[], PartitionFilters: [], PushedFilters: [IsNotNull(c1), IsNotNull(c2)], ReadSchema: struct
Then the stage does not progress.
I updated the code to use broadcast ds1 and then did the join in the mapPartitions for ds2.
val ranges = spark.read.format("csv").option("header", "true").load(<s3-location>).select("c1", "c2").collect
val rangesBC = sc.broadcast(ranges)
I then used this rangesBC in the mapPartitions method to identify the range each row in ds2 belongs to, and this job completes in 3 hours, while the other job does not complete even after 24 hours. This suggests that the query optimizer is not doing what I want it to do.
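For reference, the semantics of the crossJoin plus filter, and of the manual broadcast version, can be written as a plain nested loop. This is an illustrative pure-Python sketch with made-up sample values, not the code that was run.

```python
def count_matches(values, ranges):
    """Count (value, range) pairs with lo <= value <= hi, like the crossJoin + filter."""
    return sum(1 for v in values for lo, hi in ranges if lo <= v <= hi)

# Hypothetical stand-ins: `ranges` for the broadcast (c1, c2) rows,
# `values` for the c11 column of ds2. Both are compared as strings.
ranges = [("a", "c"), ("b", "d"), ("x", "z")]
values = ["b", "c", "m", "y"]

matches = count_matches(values, ranges)

# Every value is compared with every range, so the work grows with the product.
comparisons_at_scale = 10_000_000 * 1_000_000
print(matches, comparisons_at_scale)
```

Either implementation has to evaluate on the order of 10^13 range checks for the sizes given, so the difference in run time comes from how that work is scheduled rather than from the amount of work.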
What am I doing wrong? Any pointers will be helpful. Thank you!
I have run into this issue recently and found that Spark has a strange partitioning behavior when cross joining large DataFrames. If your input DataFrames contain a few million records, then the number of partitions of the cross joined DataFrame equals the product of the input DataFrames' partition counts, that is
Partitions of crossJoinDF = (Partitions of ds1) * (Partitions of ds2).
If ds1 or ds2 contains a few hundred partitions, then the cross joined DataFrame can easily end up with more than 10,000 partitions. That is far too many: the overhead of managing so many small tasks makes any computation (in your case, the filter) on the cross joined DataFrame very slow to run.
So how do you make the computation faster? First check if this is indeed the issue for your problem:
scala> val crossJoinDF = ds2.crossJoin(ds1)
// This should return immediately because of Spark's lazy evaluation
scala> val crossJoinDFPartitions = crossJoinDF.rdd.partitions.size
Check the number of partitions of the cross joined DataFrame. If crossJoinDFPartitions > 10,000, then you do indeed have the same issue, i.e. the cross joined DataFrame has far too many partitions.
To make your operations on cross joined dataframe faster, reduce the number of partitions on the input DataFrames. For example:
scala> val ds1Repartitioned = ds1.repartition(40)
scala> ds1Repartitioned.rdd.partitions.size
res80: Int = 40
scala> val ds2Repartitioned = ds2.repartition(40)
scala> ds2Repartitioned.rdd.partitions.size
res81: Int = 40
scala> val crossJoinDF = ds1Repartitioned.crossJoin(ds2Repartitioned)
scala> crossJoinDF.rdd.partitions.size
res82: Int = 1600
scala> crossJoinDF.count()
The count() action triggers execution of the cross join, and it should now return in a reasonable amount of time. The exact number of partitions to choose depends on the number of cores available in your cluster.
The key takeaway here is to make sure that your cross joined DataFrame has a reasonable number of partitions (<< 10,000). You might also find this post useful, which explains this issue in more detail.
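Taking the rule above at face value, the arithmetic is easy to check before running anything. This sketch only multiplies partition counts; the function name is made up and the 10,000 threshold is the one suggested in this answer.

```python
def cross_join_partitions(left_partitions, right_partitions):
    """Partition count of the cross join under the product rule described above."""
    return left_partitions * right_partitions

# A few hypothetical input partition counts, checked against the 10,000 guideline.
for left, right in [(40, 40), (200, 200), (100, 300)]:
    total = cross_join_partitions(left, right)
    verdict = "too many" if total > 10_000 else "ok"
    print(left, right, total, verdict)
```

With 40 partitions on each side the product is 1600, which matches the res82 value in the session above.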
I do not know whether you are on bare metal, on AWS with spot, on-demand, or dedicated instances, or on VMs in Azure, et al. My take:
Appreciate that 10M x 1M is a lot of work, even if the .filter is applied to the resulting cross join. It will take some time. What were your expectations?
Spark is, in general, all about scaling linearly.
Data centers with VMs do not give you dedicated resources and hence do not deliver the fastest performance.
Then:
I ran on Databricks 10M x 100K in a simulated set-up with .86 core and 6GB on Driver for Community Edition. That ran in 17 mins.
I ran the 10M x 1M in your example on a 4 node AWS EMR non-dedicated Cluster (with some EMR-oddities like reserving the Driver on a valuable instance!) it took 3 hours for partial completion. See the picture below.
So, to answer your question:
- You did nothing wrong.
You just need more resources, allowing more parallelisation.
I did add some explicit partitioning as you can see.
I am using the Java API for Apache Spark, and I have two Datasets, A and B.
The schema for both is the same: PhoneNumber, Name, Age, Address
There is one record in each Dataset with the same PhoneNumber, but the other columns in this record are different.
I run following SQL query on these two Datasets (by registering these as temporary Table):
A.createOrReplaceTempView("A");
B.createOrReplaceTempView("B");
String query = "Select * from A UNION Select * from B";
Dataset<Row> result = sparkSession.sql(query);
result.show();
Surprisingly, the result has only one record with the shared PhoneNumber; the other is removed.
I know UNION in a SQL query is intended to remove duplicates, but then it also needs to know the primary key on the basis of which it decides what is a duplicate.
How does this query infer the "Primary key" of my Dataset? (There is no concept of Primary key in Spark)
You can use either UNION ALL:
Seq((1L, "foo")).toDF.createOrReplaceTempView("a")
Seq((1L, "bar"), (1L, "foo")).toDF.createOrReplaceTempView("b")
spark.sql("SELECT * FROM a UNION ALL SELECT * FROM b").explain
== Physical Plan ==
Union
:- LocalTableScan [_1#152L, _2#153]
+- LocalTableScan [_1#170L, _2#171]
or Dataset.union method:
spark.table("a").union(spark.table("b")).explain
== Physical Plan ==
Union
:- LocalTableScan [_1#152L, _2#153]
+- LocalTableScan [_1#170L, _2#171]
How does this query infer the "Primary key" of my Dataset?
It doesn't, or at least not in the current version. It just applies HashAggregate using all available columns:
spark.sql("SELECT * FROM a UNION SELECT * FROM b").explain
== Physical Plan ==
*HashAggregate(keys=[_1#152L, _2#153], functions=[])
+- Exchange hashpartitioning(_1#152L, _2#153, 200)
+- *HashAggregate(keys=[_1#152L, _2#153], functions=[])
+- Union
:- LocalTableScan [_1#152L, _2#153]
+- LocalTableScan [_1#170L, _2#171]
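The effect of grouping on every column can be mimicked in plain Python, where a whole row (a tuple) acts as the key. This is only an illustration using the same sample rows as above; Spark itself gives no guarantee about the order of the result.

```python
# Same sample rows as the temp views a and b above.
a = [(1, "foo")]
b = [(1, "bar"), (1, "foo")]

# UNION ALL: plain concatenation, duplicates are kept.
union_all = a + b

# UNION: de-duplicate on the entire row, not on any single "key" column.
# dict.fromkeys keeps the first occurrence of each distinct tuple.
union_distinct = list(dict.fromkeys(a + b))

print(union_all)
print(union_distinct)
```

Rows (1, "foo") and (1, "bar") both survive the UNION because they differ in the second column; only the exact duplicate (1, "foo") is dropped.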
I have some data partitioned this way:
/data/year=2016/month=9/version=0
/data/year=2016/month=10/version=0
/data/year=2016/month=10/version=1
/data/year=2016/month=10/version=2
/data/year=2016/month=10/version=3
/data/year=2016/month=11/version=0
/data/year=2016/month=11/version=1
When using this data, I'd like to load the last version only of each month.
A simple way to do this is to do load("/data/year=2016/month=11/version=3") instead of doing load("/data").
The drawback of this solution is the loss of partitioning information such as year and month, which means it would not be possible to apply operations based on the year or the month anymore.
Is it possible to ask Spark to load the last version only of each month? How would you go about this?
Well, Spark supports predicate push-down, so if you provide a filter following the load, it will only read in the data fulfilling the criteria in the filter. Like this:
spark.read.option("basePath", "/data").load("/data").filter('version === 3)
And you get to keep the partitioning information :)
Just an addition to previous answers for reference
I have the ORC format table below in Hive, which is partitioned on the year, month, and day columns.
hive (default)> show partitions test_dev_db.partition_date_table;
OK
year=2019/month=08/day=07
year=2019/month=08/day=08
year=2019/month=08/day=09
If I set below properties, I can read the latest partition data in spark sql as shown below:
spark.sql("SET spark.sql.orc.enabled=true");
spark.sql("SET spark.sql.hive.convertMetastoreOrc=true")
spark.sql("SET spark.sql.orc.filterPushdown=true")
spark.sql("""select * from test_dev_db.partition_date_table where year ='2019' and month='08' and day='07' """).explain(True)
We can see PartitionCount: 1 in the plan, so Spark has read only the partition named in the filter.
== Physical Plan ==
*(1) FileScan orc test_dev_db.partition_date_table[emp_id#212,emp_name#213,emp_salary#214,emp_date#215,year#216,month#217,day#218] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxxx/dev/hadoop/database/test_dev..., PartitionCount: 1, PartitionFilters: [isnotnull(year#216), isnotnull(month#217), isnotnull(day#218), (year#216 = 2019), (month#217 = 0..., PushedFilters: [], ReadSchema: struct<emp_id:int,emp_name:string,emp_salary:int,emp_date:date>
The same does not hold for every query. Even if we create a DataFrame using spark.read.format("orc").load(hdfs absolute path of table), register a temporary view, and run Spark SQL on that, Spark will still scan all the partitions available for that table unless we apply a specific filter condition on top of it.
For example, pruning on day does not happen with the query below, which takes the partition value from a subquery:
spark.sql("""select * from test_dev_db.partition_date_table where year ='2019' and month='08' and day in (select max(day) from test_dev_db.partition_date_table)""").explain(True)
It has still scanned all three partitions; here PartitionCount is 3:
== Physical Plan ==
*(2) BroadcastHashJoin [day#282], [max(day)#291], LeftSemi, BuildRight
:- *(2) FileScan orc test_dev_db.partition_date_table[emp_id#276,emp_name#277,emp_salary#278,emp_date#279,year#280,month#281,day#282] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxx/dev/hadoop/database/test_dev..., PartitionCount: 3, PartitionFilters: [isnotnull(year#280), isnotnull(month#281), (year#280 = 2019), (month#281 = 08)], PushedFilters: [], ReadSchema: struct<emp_id:int,emp_name:string,emp_salary:int,emp_date:date>
To filter the data down to the max partition using Spark SQL, we can use the approach below. It relies on partition pruning to limit the number of files and partitions that Spark reads when querying the Hive ORC table.
rdd=spark.sql("""show partitions test_dev_db.partition_date_table""").rdd.flatMap(lambda x:x)
# 'year=2019/month=08/day=09' -> ('2019', '08', '09')
newrdd=rdd.map(lambda x : tuple(kv.split("=")[1] for kv in x.split("/")))
# Compare whole (year, month, day) tuples: separate per-column maxima can name a partition that does not exist.
# String comparison is only correct here because the values are zero-padded.
max_year,max_month,max_day=newrdd.max()
Prepare your query to filter the Hive partitioned table using these max values.
query = "select * from test_dev_db.partition_date_table where year ='{0}' and month='{1}' and day ='{2}'".format(max_year,max_month,max_day)
>>> spark.sql(query).show();
+------+--------+----------+----------+----+-----+---+
|emp_id|emp_name|emp_salary| emp_date|year|month|day|
+------+--------+----------+----------+----+-----+---+
| 3| Govind| 810000|2019-08-09|2019| 08| 09|
| 4| Vikash| 5500|2019-08-09|2019| 08| 09|
+------+--------+----------+----------+----+-----+---+
spark.sql(query).explain(True)
If you look at the plan of this query, you can see that it has scanned only one partition of the given Hive table.
Here PartitionCount is 1:
== Optimized Logical Plan ==
Filter (((((isnotnull(day#397) && isnotnull(month#396)) && isnotnull(year#395)) && (year#395 = 2019)) && (month#396 = 08)) && (day#397 = 09))
+- Relation[emp_id#391,emp_name#392,emp_salary#393,emp_date#394,year#395,month#396,day#397] orc
== Physical Plan ==
*(1) FileScan orc test_dev_db.partition_date_table[emp_id#391,emp_name#392,emp_salary#393,emp_date#394,year#395,month#396,day#397] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxx/dev/hadoop/database/test_dev..., PartitionCount: 1, PartitionFilters: [isnotnull(day#397), isnotnull(month#396), isnotnull(year#395), (year#395 = 2019), (month#396 = 0..., PushedFilters: [], ReadSchema: struct<emp_id:int,emp_name:string,emp_salary:int,emp_date:date>
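When picking the latest partition from multi-level partition names, it is safer to compare the whole (year, month, day) triple than to take each column's maximum separately. The sketch below uses two hypothetical partition names (not the ones in the table above) and assumes zero-padded values, so that string order matches date order.

```python
def parse_partition(partition):
    """'year=2019/month=01/day=05' -> ('2019', '01', '05')"""
    return tuple(kv.split("=", 1)[1] for kv in partition.split("/"))

# Hypothetical partitions spanning a year boundary.
partitions = ["year=2018/month=12/day=31", "year=2019/month=01/day=05"]
parsed = [parse_partition(p) for p in partitions]

# Tuples compare element by element, so this is the latest existing partition.
latest = max(parsed)

# Per-column maxima mix values from different partitions.
independent = tuple(max(column) for column in zip(*parsed))

print(latest, independent)
```

Here the per-column result is 2019-12-31, a partition that does not exist, while the tuple comparison returns 2019-01-05.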
I think you have to use Spark's window functions to find the latest version within each year and month, and then keep only the rows that carry it.
import org.apache.spark.sql.functions.{col, first}
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy("year","month").orderBy(col("version").desc)
spark.read.load("/data")
.withColumn("maxVersion", first("version").over(windowSpec))
.select("*")
.filter(col("maxVersion") === col("version"))
.drop("maxVersion")
Let me know if this works for you.
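The selection that the window performs (group by year and month, keep the highest version) can be checked against the directory listing from the question with plain Python. This is a sketch over path strings only; the helper name parse_path is made up, and no data is read.

```python
paths = [
    "/data/year=2016/month=9/version=0",
    "/data/year=2016/month=10/version=0",
    "/data/year=2016/month=10/version=1",
    "/data/year=2016/month=10/version=2",
    "/data/year=2016/month=10/version=3",
    "/data/year=2016/month=11/version=0",
    "/data/year=2016/month=11/version=1",
]

def parse_path(path):
    """Return (year, month, version) as integers from a partition path."""
    parts = dict(seg.split("=") for seg in path.strip("/").split("/") if "=" in seg)
    return int(parts["year"]), int(parts["month"]), int(parts["version"])

# Keep the highest version seen for each (year, month) group.
latest = {}
for year, month, version in map(parse_path, paths):
    key = (year, month)
    latest[key] = max(version, latest.get(key, version))

latest_paths = [
    "/data/year={}/month={}/version={}".format(y, m, v)
    for (y, m), v in sorted(latest.items())
]
print(latest_paths)
```

The values are converted to integers before comparing, since as strings "10" would sort before "9".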
Here's a general Scala function:
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{array, col, row_number}
/**
* Given a DataFrame, use keys (e.g. last modified time) to show the most up-to-date record
*
* @param dF DataFrame to be parsed
* @param groupByKeys These are the columns you would like to groupBy and expect to be duplicated,
* hence why you're trying to obtain records according to a latest value of keys.
* @param keys The sequence of keys used to rank the records in the table
* @return DataFrame with records that have rank 1, this means the most up-to-date version of those records
*/
def getLastUpdatedRecords(dF: DataFrame, groupByKeys: Seq[String], keys: Seq[String]): DataFrame = {
val part = Window.partitionBy(groupByKeys.head, groupByKeys.tail: _*).orderBy(array(keys.head, keys.tail: _*).desc)
val rowDF = dF.withColumn("rn", row_number().over(part))
val res = rowDF.filter(col("rn")===1).drop("rn")
res
}