How to load only the data of the last partition - apache-spark

I have some data partitioned this way:
/data/year=2016/month=9/version=0
/data/year=2016/month=10/version=0
/data/year=2016/month=10/version=1
/data/year=2016/month=10/version=2
/data/year=2016/month=10/version=3
/data/year=2016/month=11/version=0
/data/year=2016/month=11/version=1
When using this data, I'd like to load the last version only of each month.
A simple way to do this is to do load("/data/year=2016/month=11/version=3") instead of doing load("/data").
The drawback of this solution is the loss of partitioning information such as year and month, which means it would not be possible to apply operations based on the year or the month anymore.
Is it possible to ask Spark to load the last version only of each month? How would you go about this?

Well, Spark supports partition pruning (predicate push-down applied to partition columns), so if you apply a filter right after the load, it will only read the partitions that satisfy the filter. Like this:
spark.read.option("basePath", "/data").load("/data").filter('version === 3)
And you get to keep the partitioning information :)
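If the latest version differs from month to month (3 for month 10 but 1 for month 11 in the layout from the question), a fixed filter on version is not enough. One possible approach, sketched here in plain Python under the assumption that you can list the partition directories yourself (for example from a file-system listing), is to pick the highest version per (year, month) and pass only those paths to the reader. The helper name latest_version_paths is hypothetical.

```python
import re
from typing import Dict, List, Tuple

# Matches the layout used in the question: .../year=Y/month=M/version=V
_PARTITION_RE = re.compile(r"year=(\d+)/month=(\d+)/version=(\d+)/?$")


def latest_version_paths(paths: List[str]) -> List[str]:
    """Return, for every (year, month), the path with the highest version."""
    best: Dict[Tuple[int, int], Tuple[int, str]] = {}
    for path in paths:
        match = _PARTITION_RE.search(path)
        if match is None:
            continue  # ignore anything that is not a version directory
        year, month, version = (int(g) for g in match.groups())
        key = (year, month)
        if key not in best or version > best[key][0]:
            best[key] = (version, path)
    # Sort by (year, month) so the result is deterministic
    return [best[key][1] for key in sorted(best)]


paths = [
    "/data/year=2016/month=9/version=0",
    "/data/year=2016/month=10/version=0",
    "/data/year=2016/month=10/version=1",
    "/data/year=2016/month=10/version=2",
    "/data/year=2016/month=10/version=3",
    "/data/year=2016/month=11/version=0",
    "/data/year=2016/month=11/version=1",
]
selected = latest_version_paths(paths)
print(selected)
```

The selected paths could then be loaded with something like spark.read.option("basePath", "/data").parquet(*selected) (assuming the files are Parquet), which keeps year, month and version as columns.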

Just an addition to the previous answers, for reference.
I have the ORC-format table below in Hive, partitioned on the year, month and day columns.
hive (default)> show partitions test_dev_db.partition_date_table;
OK
year=2019/month=08/day=07
year=2019/month=08/day=08
year=2019/month=08/day=09
If I set the properties below, I can read the data of a single partition in Spark SQL, as shown here:
spark.sql("SET spark.sql.orc.enabled=true");
spark.sql("SET spark.sql.hive.convertMetastoreOrc=true")
spark.sql("SET spark.sql.orc.filterPushdown=true")
spark.sql("""select * from test_dev_db.partition_date_table where year ='2019' and month='08' and day='07' """).explain(True)
We can see PartitionCount: 1 in the plan, so only the requested partition has been scanned.
== Physical Plan ==
*(1) FileScan orc test_dev_db.partition_date_table[emp_id#212,emp_name#213,emp_salary#214,emp_date#215,year#216,month#217,day#218] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxxx/dev/hadoop/database/test_dev..., **PartitionCount: 1**, PartitionFilters: [isnotnull(year#216), isnotnull(month#217), isnotnull(day#218), (year#216 = 2019), (month#217 = 0..., PushedFilters: [], ReadSchema: struct<emp_id:int,emp_name:string,emp_salary:int,emp_date:date>
The same does not work with the query below, which looks up the latest day in a subquery.
This also holds if we create a DataFrame using spark.read.format("orc").load(hdfs absolute path of table), register a temporary view and run Spark SQL on that: Spark still scans all the partitions available for the table unless we put a specific filter condition on top of it.
spark.sql("""select * from test_dev_db.partition_date_table where year ='2019' and month='08' and day in (select max(day) from test_dev_db.partition_date_table)""").explain(True)
It has still scanned all three partitions; here PartitionCount: 3
== Physical Plan ==
*(2) BroadcastHashJoin [day#282], [max(day)#291], LeftSemi, BuildRight
:- *(2) FileScan orc test_dev_db.partition_date_table[emp_id#276,emp_name#277,emp_salary#278,emp_date#279,year#280,month#281,day#282] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxx/dev/hadoop/database/test_dev..., PartitionCount: 3, PartitionFilters: [isnotnull(year#280), isnotnull(month#281), (year#280 = 2019), (month#281 = 08)], PushedFilters: [], ReadSchema: struct<emp_id:int,emp_name:string,emp_salary:int,emp_date:date>
To filter the data on the max partition using Spark SQL, we can use the approach below. It relies on partition pruning to limit the number of files and partitions that Spark reads when querying the Hive ORC table.
rdd=spark.sql("""show partitions test_dev_db.partition_date_table""").rdd.flatMap(lambda x:x)
newrdd=rdd.map(lambda x : x.replace("/","")).map(lambda x : x.replace("year=","")).map(lambda x : x.replace("month=","-")).map(lambda x : x.replace("day=","-")).map(lambda x : x.split('-'))
# Compare (year, month, day) together; separate max() calls could combine values from different partitions
max_year, max_month, max_day = newrdd.map(lambda x: tuple(x)).max()
Prepare your query to filter the Hive partitioned table using these max values.
query = "select * from test_dev_db.partition_date_table where year ='{0}' and month='{1}' and day ='{2}'".format(max_year,max_month,max_day)
>>> spark.sql(query).show();
+------+--------+----------+----------+----+-----+---+
|emp_id|emp_name|emp_salary| emp_date|year|month|day|
+------+--------+----------+----------+----+-----+---+
| 3| Govind| 810000|2019-08-09|2019| 08| 09|
| 4| Vikash| 5500|2019-08-09|2019| 08| 09|
+------+--------+----------+----------+----+-----+---+
spark.sql(query).explain(True)
If you look at the plan of this query, you can see that it has scanned only one partition of the given Hive table.
Here PartitionCount is 1.
== Optimized Logical Plan ==
Filter (((((isnotnull(day#397) && isnotnull(month#396)) && isnotnull(year#395)) && (year#395 = 2019)) && (month#396 = 08)) && (day#397 = 09))
+- Relation[emp_id#391,emp_name#392,emp_salary#393,emp_date#394,year#395,month#396,day#397] orc
== Physical Plan ==
*(1) FileScan orc test_dev_db.partition_date_table[emp_id#391,emp_name#392,emp_salary#393,emp_date#394,year#395,month#396,day#397] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxx/dev/hadoop/database/test_dev..., PartitionCount: 1, PartitionFilters: [isnotnull(day#397), isnotnull(month#396), isnotnull(year#395), (year#395 = 2019), (month#396 = 0..., PushedFilters: [], ReadSchema: struct<emp_id:int,emp_name:string,emp_salary:int,emp_date:date>
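When picking the latest partition from show partitions output, the levels need to be compared together rather than one by one. The plain-Python sketch below shows why; it assumes zero-padded values (so that string order matches date order), and the helper names parse_partition and latest_partition are hypothetical.

```python
from typing import List, Tuple


def parse_partition(spec: str) -> Tuple[str, ...]:
    """Turn 'year=2019/month=08/day=07' into ('2019', '08', '07')."""
    return tuple(part.split("=", 1)[1] for part in spec.split("/"))


def latest_partition(specs: List[str]) -> Tuple[str, ...]:
    """Pick the greatest partition, comparing all levels together."""
    return max(parse_partition(s) for s in specs)


specs = [
    "year=2018/month=12/day=31",
    "year=2019/month=01/day=02",
]
parsed = [parse_partition(s) for s in specs]
# Independent maxima combine values taken from different partitions
independent = tuple(max(p[i] for p in parsed) for i in range(3))
print(independent)              # ('2019', '12', '31'), not an existing partition
print(latest_partition(specs))  # ('2019', '01', '02')
```

With the three partitions of the example table both approaches happen to agree, because year and month are identical across all of them.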

I think you have to use Spark's window functions and then find and keep only the latest version.
import org.apache.spark.sql.functions.{col, first}
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy("year","month").orderBy(col("version").desc)
spark.read.load("/data")
.withColumn("maxVersion", first("version").over(windowSpec))
.select("*")
.filter(col("maxVersion") === col("version"))
.drop("maxVersion")
Let me know if this works for you.
Here's a Scala general function:
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{array, col, row_number}
/**
 * Given a DataFrame, use keys (e.g. last modified time) to show the most up-to-date record.
 *
 * @param dF DataFrame to be parsed
 * @param groupByKeys These are the columns you would like to groupBy and expect to be duplicated,
 * hence why you're trying to obtain records according to a latest value of keys.
 * @param keys The sequence of keys used to rank the records in the table
 * @return DataFrame with records that have rank 1, this means the most up-to-date version of those records
 */
def getLastUpdatedRecords(dF: DataFrame, groupByKeys: Seq[String], keys: Seq[String]): DataFrame = {
  // Rank rows within each group, newest first
  val part = Window.partitionBy(groupByKeys.head, groupByKeys.tail: _*).orderBy(array(keys.head, keys.tail: _*).desc)
  val rowDF = dF.withColumn("rn", row_number().over(part))
  // Keep only the top-ranked row of each group
  rowDF.filter(col("rn") === 1).drop("rn")
}
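To make the semantics of the window approach concrete without a Spark session, here is a plain-Python sketch that mirrors "row_number() over a window, keep rn == 1" on a list of dictionaries. It only illustrates the logic and is not a replacement for the Spark code; the function name latest_per_group and the sample rows are made up.

```python
from itertools import groupby
from operator import itemgetter
from typing import Dict, List


def latest_per_group(rows: List[Dict], group_keys: List[str], order_key: str) -> List[Dict]:
    """Keep, per group, the row with the greatest order_key (like rn == 1)."""
    key_of = itemgetter(*group_keys)
    result = []
    # groupby only merges adjacent items, so sort by the grouping key first
    for _, group in groupby(sorted(rows, key=key_of), key=key_of):
        result.append(max(group, key=itemgetter(order_key)))
    return result


rows = [
    {"year": 2016, "month": 9, "version": 0},
    {"year": 2016, "month": 10, "version": 0},
    {"year": 2016, "month": 10, "version": 3},
    {"year": 2016, "month": 11, "version": 0},
    {"year": 2016, "month": 11, "version": 1},
]
latest = latest_per_group(rows, ["year", "month"], "version")
print([(r["month"], r["version"]) for r in latest])  # [(9, 0), (10, 3), (11, 1)]
```

As with row_number(), when two rows tie on the ordering key only one of them is kept.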

Related

Does spark bring entire hive table to memory

I am in the process of learning the working of Apache Spark and have some basic queries. Let's say I have a Spark application running which connects to a Hive table.
My hive table is as follows:
Name  Age  Marks
A     50   100
B     50   100
C     75   200
When I run the following code snippets, which rows and columns will be loaded into memory during the execution? Will the filtering of rows/columns be done after the entire table is loaded into the memory?
1. spark_session.sql("SELECT name, age from table").collect()
2. spark_session.sql("SELECT * from table WHERE age=50").collect()
3. spark_session.sql("SELECT * from table").select("name", "age").collect()
4. spark_session.sql("SELECT * from table").filter("age = 50").collect()
If the data source supports predicate pushdown, Spark will not load the entire data set into memory when filtering.
Let's check the Spark plan for a Hive table with Parquet as the file format:
>>> df = spark.createDataFrame([('A', 25, 100),('B', 30, 100)], ['name', 'age', 'marks'])
>>> df.write.saveAsTable('table')
>>> spark.sql('select * from table where age=25').explain(True)
== Physical Plan ==
*(1) Filter (isnotnull(age#1389L) AND (age#1389L = 25))
+- *(1) ColumnarToRow
+- FileScan parquet default.table[name#1388,age#1389L,marks#1390L] Batched: true, DataFilters: [isnotnull(age#1389L), (age#1389L = 25)],
Format: Parquet, Location: InMemoryFileIndex[file:/Users/mohan/spark-warehouse/table],
PartitionFilters: [], PushedFilters: [IsNotNull(age), EqualTo(age,25)], ReadSchema: struct<name:string,age:bigint,marks:bigint>
You can verify whether the filter is pushed down to the underlying storage by looking at PushedFilters: [IsNotNull(age), EqualTo(age,25)]

how to get the partitions info of hive table in Spark

I want to execute SQL with Spark like this:
sparkSession.sql("select * from table")
But I want to run a partition check on the table before execution, to avoid a full scan.
If the table is partitioned, my program will force users to add a partition filter. If not, it's OK to run.
So my question is: how do I know whether a table is partitioned?
My idea is to read this information from the metastore, but how to access the metastore is another problem I've run into. Could someone help?
Assuming that your real goal is to restrict execution of unbounded queries, I think it would be easier to get the query's execution plan and look under its FileScan / HiveTableScan leaf nodes to see whether any partition filters are being applied. For partitioned tables, the number of partitions that the query is actually going to scan is also shown, by the way. So, something like this should do:
scala> val df_unbound = spark.sql("select * from hottab")
df_unbound: org.apache.spark.sql.DataFrame = [id: int, descr: string ... 1 more field]
scala> val plan1 = df_unbound.queryExecution.executedPlan.toString
plan1: String =
"*(1) FileScan parquet default.hottab[id#0,descr#1,loaddate#2] Batched: true, Format: Parquet,
Location: CatalogFileIndex[hdfs://ns1/user/hive/warehouse/hottab],
PartitionCount: 365, PartitionFilters: [],
PushedFilters: [], ReadSchema: struct<id:int,descr:string>
"
scala> val df_filtered = spark.sql("select * from hottab where loaddate='2019-07-31'")
df_filtered: org.apache.spark.sql.DataFrame = [id: int, descr: string ... 1 more field]
scala> val plan2 = df_filtered.queryExecution.executedPlan.toString
plan2: String =
"*(1) FileScan parquet default.hottab[id#17,descr#18,loaddate#19] Batched: true, Format: Parquet,
Location: PrunedInMemoryFileIndex[hdfs://ns1/user/hive/warehouse/hottab/loaddate=2019-07-31],
PartitionCount: 1, PartitionFilters: [isnotnull(loaddate#19), (loaddate#19 = 2019-07-31)],
PushedFilters: [], ReadSchema: struct<id:int,descr:string>
"
This way, you also don't have to deal with SQL parsing to find table name(s) from queries, and to interrogate metastore yourself.
As a bonus, you'll also be able to see whether "regular" filter pushdown occurs (for storage formats that support it) in addition to partition pruning.
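Once the plan is available as a string, the check itself is just text matching. Here is a minimal Python sketch of that step, assuming plan text shaped like the outputs shown above; the function names are hypothetical and only the first FileScan node found in the string is inspected.

```python
import re
from typing import Optional


def partition_count(plan: str) -> Optional[int]:
    """Return the PartitionCount reported in the plan, if any."""
    match = re.search(r"PartitionCount: (\d+)", plan)
    return int(match.group(1)) if match else None


def has_partition_filter(plan: str) -> bool:
    """True when the plan lists at least one entry in PartitionFilters."""
    match = re.search(r"PartitionFilters: \[(.*?)\]", plan)
    return bool(match and match.group(1).strip())


unbound_plan = (
    "*(1) FileScan parquet default.hottab[id#0,descr#1,loaddate#2] "
    "PartitionCount: 365, PartitionFilters: [], PushedFilters: []"
)
filtered_plan = (
    "*(1) FileScan parquet default.hottab[id#17,descr#18,loaddate#19] "
    "PartitionCount: 1, PartitionFilters: [isnotnull(loaddate#19), "
    "(loaddate#19 = 2019-07-31)], PushedFilters: []"
)
print(partition_count(unbound_plan), has_partition_filter(unbound_plan))
print(partition_count(filtered_plan), has_partition_filter(filtered_plan))
```

A program could reject the query when partition_count returns a number but has_partition_filter is False; a plan without PartitionCount would indicate a table that is not partitioned.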
You can use Scala's Try class and execute show partitions on the required table.
import scala.util.{Try, Success, Failure}
val numPartitions = Try(spark.sql("show partitions database.table").count) match {
case Success(v) => v
case Failure(e) => -1
}
Later you can check numPartitions. If the value is -1 then the table is not partitioned.
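Alternatively, you can list the partition names through the session catalog:
import org.apache.spark.sql.catalyst.TableIdentifier
val listPartitions = spark.sessionState.catalog.listPartitionNames(TableIdentifier("table_name", Some("db name")))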
listPartitions: Seq[String] = ArrayBuffer(partition1=value1, ... ) // partition table
listPartitions: Seq[String] = ArrayBuffer() // not partition table
I know this is late, but this might help someone:
spark.sql("describe detail database.table").select("partitionColumns").show(false)
This gives a row with the partition columns in an array. (As far as I know, describe detail is a Delta Lake command, so this applies to Delta tables.)

pyspark - getting Latest partition from Hive partitioned column logic

I am new to PySpark.
I am trying to get the latest partition (date partition) of a Hive table using PySpark DataFrames, and did it as shown below.
But I am sure there is a better way to do it using DataFrame functions (not by writing SQL). Could you please share some input on better ways?
This solution scans through the entire data of the Hive table to get it.
df_1 = sqlContext.table("dbname.tablename");
df_1_dates = df_1.select('partitioned_date_column').distinct().orderBy(df_1['partitioned_date_column'].desc())
lat_date_dict=df_1_dates.first().asDict()
lat_dt=lat_date_dict['partitioned_date_column']
I agree with what @philantrovert mentioned in the comment. You can use the approach below, which relies on partition pruning to limit the number of partitions scanned for your Hive table.
>>> spark.sql("""show partitions test_dev_db.newpartitiontable""").show();
+--------------------+
| partition|
+--------------------+
|tran_date=2009-01-01|
|tran_date=2009-02-01|
|tran_date=2009-03-01|
|tran_date=2009-04-01|
|tran_date=2009-05-01|
|tran_date=2009-06-01|
|tran_date=2009-07-01|
|tran_date=2009-08-01|
|tran_date=2009-09-01|
|tran_date=2009-10-01|
|tran_date=2009-11-01|
|tran_date=2009-12-01|
+--------------------+
>>> max_date=spark.sql("""show partitions test_dev_db.newpartitiontable""").rdd.flatMap(lambda x:x).map(lambda x : x.replace("tran_date=","")).max()
>>> print(max_date)
2009-12-01
>>> query = "select city,state,country from test_dev_db.newpartitiontable where tran_date ='{}'".format(max_date)
>>> spark.sql(query).show();
+--------------------+----------------+--------------+
| city| state| country|
+--------------------+----------------+--------------+
| Southampton| England|United Kingdom|
|W Lebanon ...| NH| United States|
| Comox|British Columbia| Canada|
| Gasperich| Luxembourg| Luxembourg|
+--------------------+----------------+--------------+
>>> spark.sql(query).explain(True)
== Parsed Logical Plan ==
'Project ['city, 'state, 'country]
+- 'Filter ('tran_date = 2009-12-01)
+- 'UnresolvedRelation `test_dev_db`.`newpartitiontable`
== Analyzed Logical Plan ==
city: string, state: string, country: string
Project [city#9, state#10, country#11]
+- Filter (tran_date#12 = 2009-12-01)
+- SubqueryAlias newpartitiontable
+- Relation[city#9,state#10,country#11,tran_date#12] orc
== Optimized Logical Plan ==
Project [city#9, state#10, country#11]
+- Filter (isnotnull(tran_date#12) && (tran_date#12 = 2009-12-01))
+- Relation[city#9,state#10,country#11,tran_date#12] orc
== Physical Plan ==
*(1) Project [city#9, state#10, country#11]
+- *(1) FileScan orc test_dev_db.newpartitiontable[city#9,state#10,country#11,tran_date#12] Batched: true, Format: ORC, Location: PrunedInMemoryFileIndex[hdfs://xxx.host.com:8020/user/xxx/dev/hadoop/database/test_dev..., PartitionCount: 1, PartitionFilters: [isnotnull(tran_date#12), (tran_date#12 = 2009-12-01)], PushedFilters: [], ReadSchema: struct<city:string,state:string,country:string>
You can see in the plan above that PartitionCount is 1: it has scanned only one partition out of the 12 available.
Building on Vikrant's answer, here is a more general way of extracting partition column values directly from the table metadata, which avoids Spark scanning through all the files in the table.
First, if your data isn't already registered in a catalog, you'll want to do that so Spark can see the partition details. Here, I'm registering a new table named data.
spark.catalog.createTable(
'data',
path='/path/to/the/data',
source='parquet',
)
spark.catalog.recoverPartitions('data')
partitions = spark.sql('show partitions data')
To show a self-contained answer, however, I'll manually create the partitions DataFrame so you can see what it would look like, along with the solution for extracting a specific column value from it.
from pyspark.sql.functions import (
col,
regexp_extract,
)
partitions = (
spark.createDataFrame(
[
('/country=usa/region=ri/',),
('/country=usa/region=ma/',),
('/country=russia/region=siberia/',),
],
schema=['partition'],
)
)
partition_name = 'country'
(
partitions
.select(
'partition',
regexp_extract(
col('partition'),
pattern=r'(\/|^){}=(\S+?)(\/|$)'.format(partition_name),
idx=2,
).alias(partition_name),
)
.show(truncate=False)
)
The output of this query is:
+-------------------------------+-------+
|partition |country|
+-------------------------------+-------+
|/country=usa/region=ri/ |usa |
|/country=usa/region=ma/ |usa |
|/country=russia/region=siberia/|russia |
+-------------------------------+-------+
The solution in Scala will look very similar to this, except the call to regexp_extract() will look slightly different:
.select(
regexp_extract(
col("partition"),
exp=s"(\\/|^)${partitionName}=(\\S+?)(\\/|$$)",
groupIdx=2
).alias(partitionName).as[String]
)
Again, the benefit of querying partition values in this way is that Spark will not scan all the files in the table to get you the answer. If you have a table with tens or hundreds of thousands of files in it, your time savings will be significant.
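For a quick check of the pattern outside Spark, the same regular expression can be tried with Python's re module. This is only a sketch for experimenting with the regex; the helper name partition_value is hypothetical.

```python
import re
from typing import Optional


def partition_value(partition: str, name: str) -> Optional[str]:
    """Extract the value of partition column `name` from a partition string."""
    # Same pattern as in the regexp_extract() call, with the name escaped
    pattern = r"(/|^){}=(\S+?)(/|$)".format(re.escape(name))
    match = re.search(pattern, partition)
    return match.group(2) if match else None


print(partition_value("/country=usa/region=ri/", "country"))    # usa
print(partition_value("year=2019/month=08/day=07", "day"))      # 07
print(partition_value("/country=usa/region=ri/", "continent"))  # None
```

The leading (/|^) group is what keeps a name such as region from matching inside a longer column name that merely ends in region.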

Why is my parquet partitioned data slower than non-partitioned one?

My understanding is: if I partition my data on a column I query by, queries should be faster. However, when I tried it, it seems to be slower instead. Why?
I have a users dataframe which I tried both partitioning by yearmonth and not partitioning.
So I have one dataset partitioned by creation_yearmonth.
questionsCleanedDf.repartition("creation_yearmonth") \
.write.partitionBy('creation_yearmonth') \
.parquet('wasb://.../parquet/questions.parquet')
I have another one that is not partitioned:
questionsCleanedDf \
.write \
.parquet('wasb://.../parquet/questions_nopartition.parquet')
Then I tried creating a dataframe from these 2 parquet files and running the same query
questionsDf = spark.read.parquet('wasb://.../parquet/questions.parquet')
and
questionsDf = spark.read.parquet('wasb://.../parquet/questions_nopartition.parquet')
The query
spark.sql("""
SELECT * FROM questions
WHERE creation_yearmonth = 201606
""")
It seems like the non-partitioned one is consistently faster or has similar times (~2 - 3s), while the partitioned one is slightly slower (~3 - 4s).
I tried to do an explain:
For the partitioned dataset:
== Physical Plan ==
*FileScan parquet [id#6404,title#6405,tags#6406,owner_user_id#6407,accepted_answer_id#6408,view_count#6409,answer_count#6410,comment_count#6411,creation_date#6412,favorite_count#6413,creation_yearmonth#6414] Batched: false, Format: Parquet, Location: InMemoryFileIndex[wasb://data@cs4225.blob.core.windows.net/parquet/questions.parquet], PartitionCount: 1, PartitionFilters: [isnotnull(creation_yearmonth#6414), (creation_yearmonth#6414 = 201606)], PushedFilters: [], ReadSchema: struct<id:int,title:string,tags:array<string>,owner_user_id:int,accepted_answer_id:int,view_count...
I see PartitionCount: 1. Since in this case it can go directly to the partition, shouldn't it be faster?
For the non-partitioned one:
== Physical Plan ==
*Project [id#6440, title#6441, tags#6442, owner_user_id#6443, accepted_answer_id#6444, view_count#6445, answer_count#6446, comment_count#6447, creation_date#6448, favorite_count#6449, creation_yearmonth#6450]
+- *Filter (isnotnull(creation_yearmonth#6450) && (creation_yearmonth#6450 = 201606))
+- *FileScan parquet [id#6440,title#6441,tags#6442,owner_user_id#6443,accepted_answer_id#6444,view_count#6445,answer_count#6446,comment_count#6447,creation_date#6448,favorite_count#6449,creation_yearmonth#6450] Batched: false, Format: Parquet, Location: InMemoryFileIndex[wasb://data@cs4225.blob.core.windows.net/parquet/questions_nopartition.parquet], PartitionFilters: [], PushedFilters: [IsNotNull(creation_yearmonth), EqualTo(creation_yearmonth,201606)], ReadSchema: struct<id:int,title:string,tags:array<string>,owner_user_id:int,accepted_answer_id:int,view_count...
Also very surprising: at first the dataset had dates as strings, so I needed to do a query like:
spark.sql("""
SELECT * FROM questions
WHERE CAST(creation_date AS date) BETWEEN '2017-06-01' AND '2017-07-01'
""").show(20, False)
I expected this to be even slower, but it turns out it performs the best (~1-2s). Why is that? I thought that in this case it needs to cast each row?
The explain output here:
== Physical Plan ==
*Project [id#6521, title#6522, tags#6523, owner_user_id#6524, accepted_answer_id#6525, view_count#6526, answer_count#6527, comment_count#6528, creation_date#6529, favorite_count#6530]
+- *Filter ((isnotnull(creation_date#6529) && (cast(cast(creation_date#6529 as date) as string) >= 2017-06-01)) && (cast(cast(creation_date#6529 as date) as string) <= 2017-07-01))
+- *FileScan parquet [id#6521,title#6522,tags#6523,owner_user_id#6524,accepted_answer_id#6525,view_count#6526,answer_count#6527,comment_count#6528,creation_date#6529,favorite_count#6530] Batched: false, Format: Parquet, Location: InMemoryFileIndex[wasb://data@cs4225.blob.core.windows.net/filtered/questions.parquet], PartitionFilters: [], PushedFilters: [IsNotNull(creation_date)], ReadSchema: struct<id:string,title:string,tags:array<string>,owner_user_id:string,accepted_answer_id:string,v...
Overpartitioning can actually reduce performance:
If a column has only a few rows matching each value, the number of
directories to process can become a limiting factor, and the data file
in each directory could be too small to take advantage of the Hadoop
mechanism for transmitting data in multi-megabyte blocks.
This excerpt was taken from the documentation of a different Hadoop component, Impala, but the argument should be valid for all components of the Hadoop stack.
I think that regardless of the partitioning scheme used, the advantages of partitioning will not be apparent until the table grows way beyond 900 MB.
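A rough back-of-the-envelope calculation illustrates the point. The number of partitions below is a made-up figure (the question does not state it), and 128 MB is assumed as the block size because it is the HDFS default in recent Hadoop versions; other storage back ends may behave differently.

```python
def avg_partition_size_mb(total_mb: float, partitions: int) -> float:
    """Average amount of data per partition directory, in MB."""
    return total_mb / partitions


BLOCK_MB = 128  # assumed block size (HDFS default in recent Hadoop versions)
size = avg_partition_size_mb(900, 100)  # 100 partitions is a hypothetical figure
print(size, size < BLOCK_MB)  # 9.0 True
```

Under these assumptions each partition directory holds only a few megabytes, so per-directory and per-file overhead can easily outweigh what partition pruning saves.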

Spark: does DataFrameWriter have to be a blocking step?

I have data partitioned by a column (say, id) and I have this dataset saved some place. Every now and then, I get a smaller incremental dataset with the same structure and I essentially have to upsert my existing data based on my id with a date column deciding which record is the newest. (I don't write it in the same place, I save the whole new blob some place else.)
There are two ways I've been doing this: either grouping in a window and taking the row with the highest date, or via dropDuplicates, relying on the fact that my data is ordered. (I'd rather use the former, but I've been trying various things.)
The one big issue is that each id group is not negligible (a few gigabytes), so I was hoping Spark (with n workers) would understand that since I'm reading id-partitioned data and writing id-partitioned data, it would process n ids at once and continually write them to my storage, taking new ids as it's finished with the previous ones.
Unfortunately, what seems to be happening, is that Spark processes all my id groups in one big job (and spills to disk, naturally) before writing anything to disk. It gets really really slow.
The question is thus: Is there a way to force Spark to process these groups and write them as soon as they're ready? Again, they are partitioned, so no other task will affect my partition.
Here's a bit of code that reproduces the problem:
# generate dummy data first
import random
from typing import List
from datetime import datetime, timedelta
from pyspark.sql.functions import desc, col, row_number
from pyspark.sql.window import Window
from pyspark.sql.dataframe import DataFrame
def gen_data(n: int) -> List[tuple]:
    names = 'foo, bar, baz, bak'.split(', ')
    return [(random.randint(1, 25), random.choice(names), datetime.today() - timedelta(days=random.randint(1, 100))) \
            for j in range(n)]
def get_df(n: int) -> DataFrame:
    return spark.createDataFrame(gen_data(n), ['id', 'name', 'date'])
n = 10_000
df = get_df(n)
dd = get_df(n*10)
df.write.mode('overwrite').partitionBy('id').parquet('outputs/first')
dd.write.mode('overwrite').partitionBy('id').parquet('outputs/second')
d1 and d2 are both partitioned by id and so is the resulting dataset, but it's not reflected in the plan:
w = Window().partitionBy('id').orderBy(desc('date'))
d1 = spark.read.parquet('outputs/first')
d2 = spark.read.parquet('outputs/second')
d1.union(d2).\
withColumn('rn', row_number().over(w)).filter(col('rn') == 1).drop('rn').\
write.mode('overwrite').partitionBy('id').parquet('outputs/window')
I also tried to explicitly state the partition key (otherwise the code is the same):
d1 = spark.read.parquet('outputs/first').repartition('id')
d2 = spark.read.parquet('outputs/second').repartition('id')
d1.union(d2).\
withColumn('rn', row_number().over(w)).filter(col('rn') == 1).drop('rn').\
write.mode('overwrite').partitionBy('id').parquet('outputs/window2')
Here's the same using dropDuplicates:
d1 = spark.read.parquet('outputs/first')
d2 = spark.read.parquet('outputs/second')
d1.union(d2).\
dropDuplicates(subset=['id']).\
write.mode('overwrite').partitionBy('id').parquet('outputs/window3')
I also tried emphasising that my union is still partitioned using something like this, but again to no avail:
df.union(d2).repartition('id').\
withColumn...
I could list all partitions (ids), load them one by one while leveraging partition pruning, deduplicating and writing. But that seems like extra boilerplate that shouldn't be necessary. Or is it possible to do this via foreach?
Update (2018-03-27):
Turns out, the information about partitioning is indeed present in the window functionality in one way or another, because when I filter at the very end, partition pruning on the inputs does take place:
d1 = spark.read.parquet('outputs/first')
d2 = spark.read.parquet('outputs/second')
w = Window().partitionBy('id', 'name').orderBy(desc('date'))
d1.union(d2).withColumn('rn', row_number().over(w)).filter(col('rn') == 1).filter(col('id') == 12).explain(True)
Results in
== Physical Plan ==
*(4) Filter (isnotnull(rn#387) && (rn#387 = 1))
+- Window [row_number() windowspecdefinition(id#187, name#185, date#186 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#387], [id#187, name#185], [date#186 DESC NULLS LAST]
+- *(3) Sort [id#187 ASC NULLS FIRST, name#185 ASC NULLS FIRST, date#186 DESC NULLS LAST], false, 0
+- Exchange hashpartitioning(id#187, name#185, 200)
+- Union
:- *(1) FileScan parquet [name#185,date#186,id#187] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/.../spark_perf_partitions/outputs..., PartitionCount: 1, PartitionFilters: [isnotnull(id#187), (id#187 = 12)], PushedFilters: [], ReadSchema: struct<name:string,date:timestamp>
+- *(2) FileScan parquet [name#191,date#192,id#193] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/.../spark_perf_partitions/outputs..., PartitionCount: 1, PartitionFilters: [isnotnull(id#193), (id#193 = 12)], PushedFilters: [], ReadSchema: struct<name:string,date:timestamp>
So it indeed only reads two partitions, one per each file. So I could, instead of looping, just run the code with one filter at a time (the filter being between the window function and .write). Tedious and not very practical, but potentially faster than spilling everything to disk.
Yes, this is exactly how Spark partitioning works: it computes the whole lineage and then writes to disk in partitioned form. There are several advantages to that. One important reason is parallel writes: when the computation is done, Spark can write all the partitions to disk in parallel, which significantly improves performance.
If you want to write as and when the data is ready, you might as well filter the dataframe by the different ids, run the computation in a loop and write each result. However, in my experience this approach requires several iterations over the same dataframe, resulting in a huge performance loss.
