Spark SQL .withColumn() vs Column expressions - apache-spark

I would like to know if there is any performance/scalability difference when using intermediate steps/columns in pyspark between:
Using .withColumn() for example:
df = df.withColumn('bar', df.foo + 1)
df = df.withColumn('baz', df.bar + 2)
then calling df.select('baz').collect()
versus
Declaring a Spark column as a Python variable:
bar = df.foo + 1
baz = bar + 2
then calling
df.select(baz.alias('baz')).collect()
Question: If many intermediate steps/columns such as bar are required, would the two options differ in space/time complexity?

My original post was deleted. In hindsight that may well have been correct, barring the lack of communication: that example used foldLeft, which is not your use case. Your question is about the fusing of the data pipeline.
To answer it: because Catalyst fuses the data pipeline operations, there is no performance difference either way, as the Physical Plans show:
df = spark.createDataFrame([(x,x) for x in range(7)], ['foo', 'bar',])
df = df.withColumn('bar', df.foo + 1)
df = df.withColumn('baz', df.bar + 2)
df.select('baz').explain(extended=True)
== Physical Plan ==
*(1) Project [(foo#276L + 3) AS baz#283L]
+- *(1) Scan ExistingRDD[foo#276L,bar#277L]
and likewise:
df = spark.createDataFrame([(x,x) for x in range(7)], ['foo', 'bar',])
bar = df.foo + 1
baz = bar + 2
df.select(baz.alias('baz')).explain(extended=True)
== Physical Plan ==
*(1) Project [(foo#288L + 3) AS baz#292L]
+- *(1) Scan ExistingRDD[foo#288L,bar#289L]
The two plans are the same apart from the generated attribute IDs. Notice that the +1 and +2 have been folded into a single +3.
In addition, I draw your attention to this post on using foldLeft with .withColumn: https://manuzhang.github.io/2018/07/11/spark-catalyst-cost.html
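To make the folding concrete, here is a toy sketch in plain Python. It is not Catalyst's code: the Attr, Lit, Add and simplify names are made up for illustration, and it only models the arithmetic folding, not the collapsing of adjacent Project nodes. It shows that both ways of writing the pipeline build the same expression tree, which then reduces to foo + 3.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical mini expression tree; Catalyst's real classes are far richer.
@dataclass(frozen=True)
class Attr:
    name: str


@dataclass(frozen=True)
class Lit:
    value: int


@dataclass(frozen=True)
class Add:
    left: "Expr"
    right: "Expr"


Expr = Union[Attr, Lit, Add]


def simplify(e: Expr) -> Expr:
    """Fold literals and re-associate (x + a) + b into x + (a + b)."""
    if not isinstance(e, Add):
        return e
    left, right = simplify(e.left), simplify(e.right)
    if isinstance(left, Lit) and isinstance(right, Lit):
        return Lit(left.value + right.value)
    if isinstance(left, Add) and isinstance(left.right, Lit) and isinstance(right, Lit):
        return Add(left.left, Lit(left.right.value + right.value))
    return Add(left, right)


# "withColumn" style: bar is defined first, then baz refers to bar.
bar = Add(Attr("foo"), Lit(1))
baz_via_columns = Add(bar, Lit(2))

# "Python variable" style ends up with exactly the same tree.
baz_via_variables = Add(Add(Attr("foo"), Lit(1)), Lit(2))

fused = simplify(baz_via_columns)
print(fused)
```

Since both styles hand the optimizer the same tree, the number of named intermediate steps does not matter in this toy model.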

Related

PySpark data skewness with Window Functions

I have a huge PySpark dataframe and I'm doing a series of Window functions over partitions defined by my key.
The issue with the key is that my partitions get skewed by it, which results in an Event Timeline that looks something like this (screenshot omitted).
I know that I can use the salting technique to solve this issue when I'm doing a join. But how can I solve it when I'm using Window functions?
I'm using functions like lag and lead in the Window functions. I can't run the process with a salted key, because I'd get wrong results.
How can I solve the skewness in this case?
I'm looking for a dynamic way of repartitioning my dataframe without skewness.
Updates based on the answer from @jxc
I tried creating a sample df and running the code over it:
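import numpy as np
import pandas as pd
from pyspark.sql import Window
from pyspark.sql import functions as F
# Skewed sample: every even id is remapped to 100, so id 100 holds about half of the rows
df = pd.DataFrame()
df['id'] = np.random.randint(1, 1000, size=150000)
df['id'] = df['id'].map(lambda x: 100 if x % 2 == 0 else x)
df['timestamp'] = pd.date_range(start=pd.Timestamp('2020-01-01'), periods=len(df), freq='60s')
# sc must be a SparkSession (or SQLContext) here; a SparkContext has no createDataFrame
sdf = sc.createDataFrame(df)
sdf = sdf.withColumn("amt", F.rand()*100)
w = Window.partitionBy("id").orderBy("timestamp")
sdf = sdf.withColumn("new_col", F.lag("amt").over(w) + F.lead("amt").over(w))
x = sdf.toPandas()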
This gave me an event timeline like this (screenshot omitted).
I tried the code from @jxc's answer:
sdf = sc.createDataFrame(df)
sdf = sdf.withColumn("amt", F.rand()*100)
N = 24*3600*365*2
sdf_1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))
w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
w2 = Window.partitionBy('id', 'pid')
sdf_2 = sdf_1.select(
'*',
F.count('*').over(w2).alias('cnt'),
F.row_number().over(w1).alias('rn'),
(F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_val')
)
sdf_3 = sdf_2.filter('rn in (1, 2, cnt-1, cnt)') \
.withColumn('new_val', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
.filter('rn in (1,cnt)')
df_new = sdf_2.filter('rn not in (1,cnt)').union(sdf_3)
x = df_new.toPandas()
I ended up with one additional stage, and the event timeline looked more skewed (screenshot omitted).
The run time also increased a bit with the new code.
To process a large partition, you can try to split it based on the orderBy column (most likely a numeric column, or a date/timestamp column that can be converted to numeric) so that all new sub-partitions maintain the correct order of rows. Process the rows with the new partitioner; for calculations using the lag and lead functions, only the rows around the boundaries between sub-partitions need to be post-processed. (How to merge small partitions is discussed below in Task-2.)
Use your example sdf and assume we have the following WinSpec and a simple window calculation:
w = Window.partitionBy('id').orderBy('timestamp')
sdf.withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w))
Task-1: split large partitions:
Try the following:
Select an N to split the timestamp and set up an additional partitionBy column pid (using ceil, int, floor, etc.):
# N to cover 35-days' intervals
N = 24*3600*35
df1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))
Add pid to partitionBy (see w1), then calculate row_number(), lag() and lead() over w1. Also find the number of rows (cnt) in each new partition to help identify the end of each partition (rn == cnt). The resulting new_amt will be correct for the majority of rows, except those on the boundaries of each partition.
w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
w2 = Window.partitionBy('id', 'pid')
df2 = df1.select(
'*',
F.count('*').over(w2).alias('cnt'),
F.row_number().over(w1).alias('rn'),
(F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_amt')
)
Below is an example df2 showing the boundary rows.
Process the boundary: select the rows that are on the boundaries (rn in (1, cnt)) plus those whose values are used in the calculation (rn in (2, cnt-1)), do the same calculation of new_amt over w, and save the result for the boundary rows only.
df3 = df2.filter('rn in (1, 2, cnt-1, cnt)') \
.withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
.filter('rn in (1,cnt)')
Below shows the resulting df3 from the above df2
Merge df3 back into df2 to update the boundary rows (rn in (1, cnt)):
df_new = df2.filter('rn not in (1,cnt)').union(df3)
Below screenshot shows the final df_new around the boundary rows:
# drop columns which are used to implement logic only
df_new = df_new.drop('cnt', 'rn')
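The split-and-fix idea can be checked without Spark. The sketch below is plain Python over a list of (timestamp, amt) rows for a single id. The helper names (lag_plus_lead, split_and_fix) and the sample data are made up for illustration, timestamps are assumed to be unique, and the sub-partition id uses floor division rather than ceil. It computes lag + lead inside each sub-partition first, then recomputes only the boundary rows from the rows with rn in (1, 2, cnt-1, cnt), and compares the result with the single-partition calculation.

```python
from itertools import groupby


def lag_plus_lead(amts):
    """lag(amt, 1) + lead(amt, 1) over an ordered list; None where a neighbour is missing."""
    out = []
    for i in range(len(amts)):
        if i == 0 or i == len(amts) - 1:
            out.append(None)  # Spark would yield NULL here
        else:
            out.append(amts[i - 1] + amts[i + 1])
    return out


def split_and_fix(rows, n):
    """rows: list of (ts, amt) sorted by unique ts; n: width of a sub-partition in ts units."""
    result = {}
    boundary = []  # rows with rn in (1, 2, cnt-1, cnt), kept in ts order
    for _pid, grp in groupby(rows, key=lambda r: r[0] // n):
        part = list(grp)
        cnt = len(part)
        inner = lag_plus_lead([amt for _, amt in part])
        for rn, ((ts, amt), val) in enumerate(zip(part, inner), start=1):
            result[ts] = val
            if rn in (1, 2, cnt - 1, cnt):
                boundary.append((ts, amt, rn in (1, cnt)))
    # Recompute over the boundary rows only, and keep the value for rn in (1, cnt).
    fixed = lag_plus_lead([amt for _, amt, _ in boundary])
    for (ts, _, is_edge), val in zip(boundary, fixed):
        if is_edge:
            result[ts] = val
    return [result[ts] for ts, _ in rows]


rows = [(ts, float(ts * ts)) for ts in range(20)]  # made-up sample data
expected = lag_plus_lead([amt for _, amt in rows])  # one big partition
actual = split_and_fix(rows, n=6)                   # sub-partitions plus boundary fix
print(actual == expected)
```

Only four rows per sub-partition take part in the second pass, which is why the extra window over the unsplit key stays cheap.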
Some Notes:
The following 3 WindowSpecs are defined:
w = Window.partitionBy('id').orderBy('timestamp')  # fix boundary rows
w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')  # calculate internal rows
w2 = Window.partitionBy('id', 'pid')  # find the number of rows in a partition
Note: strictly speaking, it is better to use the following w to fix boundary rows, to avoid issues with tied timestamps around the boundaries:
w = Window.partitionBy('id').orderBy('pid', 'rn')  # fix boundary rows
If you know which partitions are skewed, just divide those and skip the others. The method above might split a small partition into 2 or even more if its rows are sparsely distributed:
df1 = df.withColumn('pid', F.when(F.col('id').isin('a','b'), F.ceil(F.unix_timestamp('timestamp')/N)).otherwise(1))
If you can retrieve count (the number of rows) and min_ts = min(timestamp) for each partition, then try something more dynamic for pid (below, M is the row-count threshold above which a partition is split):
F.expr(f"IF(count>{M}, ceil((unix_timestamp(timestamp)-unix_timestamp(min_ts))/{N}), 1)")
Note: skewness inside a partition will require more complex functions to generate pid.
If only the lag(1) function is used, just post-process the left boundaries: filter by rn in (1, cnt) and update only rn == 1.
df3 = df2.filter('rn in (1, cnt)') \
.withColumn('new_amt', F.lag('amt',1).over(w)) \
.filter('rn = 1')
The lead function is similar: we only need to fix the right boundaries and update rn == cnt.
If only lag(2) is used, then filter and update more rows with df3:
df3 = df2.filter('rn in (1, 2, cnt-1, cnt)') \
.withColumn('new_amt', F.lag('amt',2).over(w)) \
.filter('rn in (1,2)')
You can extend the same method to mixed cases where lag and lead have different offsets.
Task-2: merge small partitions:
Based on the number of records in a partition (count), we can set up a threshold M so that if count > M the id keeps its own partition; otherwise we merge partitions so that the total number of records stays below M (the method below has an edge case of up to 2*M-2 records).
M = 20000
# create pandas df with columns `id`, `count` and `f`, sort rows so that rows with count>=M are located on top
d2 = pd.DataFrame([ e.asDict() for e in sdf.groupby('id').count().collect() ]) \
.assign(f=lambda x: x['count'].lt(M)) \
.sort_values('f')
# add a pid column to merge smaller partitions; the total row count per pid should be less than or around M
# potentially there could be at most `2*M-2` records for the same pid; to make sure count<M strictly, use a for-loop to iterate over d2 and set pid:
d2['pid'] = (d2.mask(d2['count'].gt(M),M)['count'].shift(fill_value=0).cumsum()/M).astype(int)
# add pid to sdf. In case join is too heavy, try using Map
sdf_1 = sdf.join(spark.createDataFrame(d2).alias('d2'), ["id"]) \
.select(sdf["*"], F.col("d2.pid"))
# check pid: # of records and # of distinct ids
sdf_1.groupby('pid').agg(F.count('*').alias('count'), F.countDistinct('id').alias('cnt_ids')).orderBy('pid').show()
+---+-----+-------+
|pid|count|cnt_ids|
+---+-----+-------+
| 0|74837| 1|
| 1|20036| 133|
| 2|20052| 134|
| 3|20010| 133|
| 4|15065| 100|
+---+-----+-------+
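The comment in the code above notes that a for-loop can keep each merged pid strictly below M rows. One possible version is sketched here in plain Python. The function name assign_pids and the sample counts are hypothetical, and it assumes the per-id counts have already been collected to the driver (as they are for d2). Ids with at least M rows get a pid of their own, and the remaining ids are packed greedily in descending order of size.

```python
def assign_pids(counts, m):
    """counts: dict of id -> row count. Returns dict of id -> pid.

    Ids with at least m rows keep a pid to themselves; smaller ids are packed
    greedily so that every merged pid holds fewer than m rows in total.
    """
    pids = {}
    pid = -1       # no pid opened yet
    running = m    # forces a new pid for the very first id
    # Largest ids first, so the ids that need their own pid come before the small ones.
    for key, cnt in sorted(counts.items(), key=lambda kv: -kv[1]):
        if cnt >= m or running + cnt >= m:
            pid += 1       # open a new pid
            running = 0
        pids[key] = pid
        running += cnt
    return pids


counts = {"a": 50, "b": 7, "c": 4, "d": 3, "e": 2, "f": 1}  # made-up per-id row counts
pids = assign_pids(counts, m=10)
print(pids)
```

The resulting mapping could then be turned into a small DataFrame and joined back to sdf in the same way as d2.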
Now, the new Window should be partitioned by pid alone and move id to orderBy, see below:
w3 = Window.partitionBy('pid').orderBy('id','timestamp')
Customize the lag/lead functions based on the above w3 WinSpec, and then calculate new_val:
lag_w3 = lambda col,n=1: F.when(F.lag('id',n).over(w3) == F.col('id'), F.lag(col,n).over(w3))
lead_w3 = lambda col,n=1: F.when(F.lead('id',n).over(w3) == F.col('id'), F.lead(col,n).over(w3))
sdf_new = sdf_1.withColumn('new_val', lag_w3('amt',1) + lead_w3('amt',1))
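Why the id comparison in lag_w3 and lead_w3 matters can be seen in a small plain-Python sketch. The function name guarded_lag and the sample rows are made up for illustration. Once several ids share one pid, the previous row in the window may belong to a different id, and its value must not leak across.

```python
def guarded_lag(rows, n=1):
    """rows: list of (id, ts, amt) sorted by (id, ts) inside one merged pid.

    Returns the amt from n rows back, but only when that row has the same id.
    """
    out = []
    for i, (key, _ts, _amt) in enumerate(rows):
        j = i - n
        if j >= 0 and rows[j][0] == key:
            out.append(rows[j][2])
        else:
            out.append(None)  # no earlier row for this id: NULL in Spark
    return out


# Two ids merged into the same pid (made-up values).
rows = [("a", 1, 10.0), ("a", 2, 20.0), ("b", 1, 30.0), ("b", 2, 40.0)]
lagged = guarded_lag(rows)
print(lagged)
```

Without the check, the first row of id "b" would pick up 20.0 from the last row of id "a".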
To handle such skewed data, there are a couple of things you can try out.
If you are using Databricks to run your jobs and you know which column has the skew, you can try an option called the skew hint.
I recommend moving to Spark 3.0, since you will then have the option to use Adaptive Query Execution (AQE), which can handle most of these issues, improving your job health and potentially making your jobs run faster.
Usually, I suggest repartitioning your data into more evenly sized partitions before any wide operation. Increasing the cluster size does help too, but I am not sure whether that will work for you.

Error: Resolved attributes missing in join

I'm using pyspark to perform a join of two tables with a relatively complex join condition (using greater than/smaller than in the join conditions). This works fine, but breaks down as soon as I add a fillna command before the join.
The code looks something like this:
join_cond = [
df_a.col1 == df_b.colx,
df_a.col2 == df_b.coly,
df_a.col3 >= df_b.colz
]
df = (
df_a
.fillna('NA', subset=['col1'])
.join(df_b, join_cond, 'left')
)
This results in an error like this:
org.apache.spark.sql.AnalysisException: Resolved attribute(s) col1#4765 missing from col1#6488,col2#4766,col3#4768,colx#4823,coly#4830,colz#4764 in operator !Join LeftOuter, (((col1#4765 = colx#4823) && (col2#4766 = coly#4830)) && (col3#4768 >= colz#4764)). Attribute(s) with the same name appear in the operation: col1. Please check if the right attribute(s) are used.
It looks like spark no longer recognizes col1 after performing the fillna. (The error does not come up if I comment that out.) The problem is that I do need that statement. (And in general I've simplified this example a lot.)
I've looked at this question, but these answers do not work for me. Specifically, using .alias('a') after the fillna doesn't work because then spark does not recognize the a in the join condition.
Could someone:
Explain exactly why this is happening and how I can avoid it in the future?
Advise me on a way to solve it?
Thanks in advance for your help.
What is happening?
In order to "replace" empty values, a new dataframe is created that contains new columns. These new columns have the same names as the old ones but are effectively completely new Spark objects. In the Scala code you can see that the "changed" columns are newly created ones while the original columns are dropped.
A way to see this effect is to call explain on the dataframe before and after replacing the empty values:
df_a.explain()
prints
== Physical Plan ==
*(1) Project [_1#0L AS col1#6L, _2#1L AS col2#7L, _3#2L AS col3#8L]
+- *(1) Scan ExistingRDD[_1#0L,_2#1L,_3#2L]
while
df_a.fillna(42, subset=['col1']).explain()
prints
== Physical Plan ==
*(1) Project [coalesce(_1#0L, 42) AS col1#27L, _2#1L AS col2#7L, _3#2L AS col3#8L]
+- *(1) Scan ExistingRDD[_1#0L,_2#1L,_3#2L]
Both plans contain a column called col1, but in the first case its internal representation is col1#6L, while in the second it is col1#27L.
The join condition df_a.col1 == df_b.colx is bound to the column col1#6L, so the join fails when only the column col1#27L is part of the left table.
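The effect can be mimicked with a toy model in plain Python. This is only an analogy, not how Spark represents attributes internally: the names new_attr and fillna and the dictionaries are invented for illustration. A reference captured before the replacement points at the old id, while a lookup by name still succeeds.

```python
import itertools

_ids = itertools.count(1)


def new_attr(name):
    """Toy attribute: a name plus a unique id, like col1#6L."""
    return (name, next(_ids))


def fillna(frame, column):
    """Toy fillna: the affected column is re-created with a fresh id."""
    out = dict(frame)
    out[column] = new_attr(column)
    return out


df_a = {"col1": new_attr("col1"), "col2": new_attr("col2")}
ref = df_a["col1"]  # like df_a.col1: bound to one specific id

filled = fillna(df_a, "col1")

resolved_by_id = ref in filled.values()  # the old id is gone
resolved_by_name = "col1" in filled      # the name is still there
print(resolved_by_id, resolved_by_name)
```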
How can the problem be solved?
The obvious way would be to move the fillna operation before the definition of the join condition:
df_a = df_a.fillna('NA', subset=['col1'])
join_cond = [
df_a.col1 == df_b.colx,
[...]
If this is not possible or not wanted, you can change the join condition. Instead of using a column from the dataframe (df_a.col1), you can use a column that is not associated with any dataframe by using the col function. Such a column is resolved by name only and is therefore unaffected when the column is replaced in the dataframe:
from pyspark.sql import functions as F
join_cond = [
F.col("col1") == df_b.colx,
df_a.col2 == df_b.coly,
df_a.col3 >= df_b.colz
]
The downside of this second approach is that the column names in both tables must be unique.

Modifying query using Spark Catalyst Logical Plan

Is it possible to add or replace an existing column expression in the DataFrame API/SQL using an extension point?
Example: assume we inject a resolution rule that checks the Project node of the plan and, on finding the column "name", replaces it with upper(name).
Is such a thing possible using extension points? The examples I have found are mostly simple and do not manipulate the input expressions in the manner I need.
Kindly let me know if this is possible.
Yes this is possible.
Let's take an example. Suppose we want to write a rule that checks for the Project operator and, if the projection includes a particular column (say 'column2'), multiplies it by 2.
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.Column
import org.apache.spark.sql.types._
object DoubleColumn2OptimizationRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case p: Project =>
      if (p.projectList.filter(_.name == "column2").size >= 1) {
        val newList = p.projectList.map { case x =>
          if (x.name == "column2") {
            Alias(Multiply(Literal(2, IntegerType), x), "column2_doubled")()
          } else {
            x
          }
        }
        p.copy(projectList = newList)
      } else {
        p
      }
  }
}
Say we have a table "table1" which has two columns, column1 and column2.
Without this rule:
> spark.sql("select column2 from table1 limit 10").collect()
Array([1], [2], [3], [4], [5], [6], [7], [8], [9], [10])
With this rule:
> spark.experimental.extraOptimizations = Seq(DoubleColumn2OptimizationRule)
> spark.sql("select column2 from table1 limit 10").collect()
Array([2], [4], [6], [8], [10], [12], [14], [16], [18], [20])
You can also call explain on the DataFrame to check the plan:
> spark.sql("select column2 from table1 limit 10").explain
== Physical Plan ==
CollectLimit 10
+- *(1) LocalLimit 10
   +- *(1) Project [(2 * column2#213) AS column2_doubled#214]
      +- HiveTableScan [column2#213], HiveTableRelation `default`.`table1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [column1#212, column2#213]

How to check if spark dataframe is empty in pyspark [duplicate]

Right now, I have to use df.count > 0 to check if the DataFrame is empty or not. But it is kind of inefficient. Is there any better way to do that?
PS: I want to check if it's empty so that I only save the DataFrame if it's not empty
For Spark 2.1.0, my suggestion would be to use head(n: Int) or take(n: Int) with isEmpty, whichever one has the clearest intent to you.
df.head(1).isEmpty
df.take(1).isEmpty
with Python equivalent:
len(df.head(1)) == 0  # or: not df.head(1)
len(df.take(1)) == 0  # or: not df.take(1)
Using df.first() and df.head() will both throw a java.util.NoSuchElementException if the DataFrame is empty. first() calls head() directly, which calls head(1).head.
def first(): T = head()
def head(): T = head(1).head
head(1) returns an Array, so taking head on that Array causes the java.util.NoSuchElementException when the DataFrame is empty.
def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
So instead of calling head(), use head(1) directly to get the array and then you can use isEmpty.
take(n) is also equivalent to head(n)...
def take(n: Int): Array[T] = head(n)
And limit(1).collect() is equivalent to head(1) (notice limit(n).queryExecution in the head(n: Int) method), so the following are all equivalent, at least from what I can tell, and you won't have to catch a java.util.NoSuchElementException exception when the DataFrame is empty.
df.head(1).isEmpty
df.take(1).isEmpty
df.limit(1).collect().isEmpty
I know this is an older question so hopefully it will help someone using a newer version of Spark.
I would say to just grab the underlying RDD. In Scala:
df.rdd.isEmpty
in Python:
df.rdd.isEmpty()
That being said, all this does is call take(1).length, so it'll do the same thing as Rohan answered...just maybe slightly more explicit?
I had the same question, and I tested 3 main solutions:
1. (df != null) && (df.count > 0)
2. df.head(1).isEmpty() as @hulin003 suggests
3. df.rdd.isEmpty() as @Justin Pihony suggests
All 3 work, of course. In terms of performance, however, here is what I found when executing these methods on the same DF on my machine, in terms of execution time:
1. it takes ~9366ms
2. it takes ~5607ms
3. it takes ~1921ms
Therefore I think the best solution is df.rdd.isEmpty(), as @Justin Pihony suggests.
Since Spark 2.4.0 there is Dataset.isEmpty.
Its implementation is:
def isEmpty: Boolean =
  withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan =>
    plan.executeCollect().head.getLong(0) == 0
  }
Note that a DataFrame is no longer a class in Scala, it's just a type alias (probably changed with Spark 2.0):
type DataFrame = Dataset[Row]
You can take advantage of the head() (or first()) functions to see if the DataFrame has a single row. If so, it is not empty.
If you do df.count > 0, it takes the counts of all partitions across all executors and adds them up at the driver. This takes a while when you are dealing with millions of rows.
The best way to do this is to perform df.take(1) and check if its null. This will return java.util.NoSuchElementException so better to put a try around df.take(1).
The dataframe return an error when take(1) is done instead of an empty row. I have highlighted the specific code lines where it throws the error.
If you are using Pyspark, you could also do:
len(df.head(1)) > 0
Java users can use this on a dataset:
public boolean isDatasetEmpty(Dataset<Row> ds) {
    boolean isEmpty;
    try {
        isEmpty = ((Row[]) ds.head(1)).length == 0;
    } catch (Exception e) {
        return true;
    }
    return isEmpty;
}
This checks all possible scenarios (empty, null).
PySpark 3.3.0+ / Spark (Scala API) 2.4.0+
df.isEmpty()
On PySpark, you can also use bool(df.head(1)) to obtain a True or False value.
It returns False if the dataframe contains no rows.
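For PySpark, the options from this thread can be wrapped in one small helper. This is a sketch, not an official API: the name is_empty is made up, and it relies only on DataFrame.isEmpty() (PySpark 3.3.0+ according to the answer above) and on head(1) returning a list. It is written against those two methods only, so it can be tried without a Spark session.

```python
def is_empty(df):
    """Return True if df has no rows, fetching at most one row.

    Uses df.isEmpty() when the object provides it and falls back to
    len(df.head(1)) == 0 otherwise.
    """
    native = getattr(df, "isEmpty", None)
    if callable(native):
        return bool(native())
    return len(df.head(1)) == 0
```

A typical use, given the original goal of only saving non-empty DataFrames, would be a guard such as `if not is_empty(df):` in front of the write call.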
In Scala you can use implicits to add the methods isEmpty() and nonEmpty() to the DataFrame API, which will make the code a bit nicer to read.
object DataFrameExtensions {
  implicit def extendedDataFrame(dataFrame: DataFrame): ExtendedDataFrame =
    new ExtendedDataFrame(dataFrame: DataFrame)
  class ExtendedDataFrame(dataFrame: DataFrame) {
    def isEmpty(): Boolean = dataFrame.head(1).isEmpty // Any implementation can be used
    def nonEmpty(): Boolean = !isEmpty
  }
}
Here, other methods can be added as well. To use the implicit conversion, use import DataFrameExtensions._ in the file you want to use the extended functionality. Afterwards, the methods can be used directly as so:
val df: DataFrame = ...
if (df.isEmpty) {
// Do something
}
I found that in some cases:
>>>print(type(df))
<class 'pyspark.sql.dataframe.DataFrame'>
>>>df.take(1).isEmpty
'list' object has no attribute 'isEmpty'
The same happens with "length", or when take() is replaced by head(), because in PySpark both return a plain Python list.
[Solution] For this issue we can use:
>>>df.limit(2).count() > 1
False
If you only want to find out whether the DataFrame is empty, then df.isEmpty, df.head(1).isEmpty() or df.rdd.isEmpty() should work; these take a limit(1), as you can see if you examine the plan:
== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[count(1)], output=[count#52L])
+- *(2) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#60L])
   +- *(2) GlobalLimit 1
      +- Exchange SinglePartition
         +- *(1) LocalLimit 1
            +- *(1) InMemoryTableScan
               +- InMemoryRelation [value#32L], StorageLevel(disk, memory, deserialized, 1 replicas)
                  ... // the rest of the plan related to your computation
But if you are doing some other computation that requires a lot of memory and you don't want to cache your DataFrame just to check whether it is empty, then you can use an accumulator:
def accumulateRows(acc: LongAccumulator)(df: DataFrame): DataFrame =
  df.map { row => // we map to the same row, count during this map
    acc.add(1)
    row
  }(RowEncoder(df.schema))
val rowAccumulator = spark.sparkContext.longAccumulator("Row Accumulator")
val countedDF = df.transform(accumulateRows(rowAccumulator))
countedDF.write.saveAsTable(...) // main action
val isEmpty = rowAccumulator.isZero
Note that to see the row count, you should first perform the action. If we change the order of the last 2 lines, isEmpty will be true regardless of the computation.
df1.take(1).length>0
The take method returns the array of rows, so if the array size is equal to zero, there are no records in df.
Let's suppose we have the following empty dataframe:
df = spark.sql("show tables").limit(0)
If you are using Spark 2.1 with pyspark, you can check whether this dataframe has any rows (both expressions below evaluate to False when it is empty) with:
df.count() > 0
Or
bool(df.head(1))
You can do it like:
val df = sqlContext.emptyDataFrame
if( df.eq(sqlContext.emptyDataFrame) )
println("empty df ")
else
println("normal df")
dataframe.limit(1).count > 0
This also triggers a job but since we are selecting single record, even in case of billion scale records the time consumption could be much lower.
From:
https://medium.com/checking-emptiness-in-distributed-objects/count-vs-isempty-surprised-to-see-the-impact-fa70c0246ee0

Transforming Spark SQL AST with extraOptimizations

I want to take a SQL string as user input, then transform it before execution. In particular, I want to modify the top-level projection (select clause), injecting additional columns to be retrieved by the query.
I was hoping to achieve this by hooking into Catalyst using sparkSession.experimental.extraOptimizations. I know that what I'm attempting isn't strictly speaking an optimisation (the transformation changes the semantics of the SQL statement), but the API still seems suitable. However, my transformation seems to be ignored by the query executor.
Here is a minimal example to illustrate the issue I'm having. First define a row case class:
case class TestRow(a: Int, b: Int, c: Int)
Then define an optimisation rule which simply discards any projection:
object RemoveProjectOptimisationRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    case x: Project => x.child
  }
}
Now create a dataset, register the optimisation, and run a SQL query:
// Create a dataset and register table.
val dataset = List(TestRow(1, 2, 3)).toDS()
val tableName: String = "testtable"
dataset.createOrReplaceTempView(tableName)
// Register "optimisation".
sparkSession.experimental.extraOptimizations =
Seq(RemoveProjectOptimisationRule)
// Run query.
val projected = sqlContext.sql("SELECT a FROM " + tableName + " WHERE a = 1")
// Print query result and the queryExecution object.
println("Query result:")
projected.collect.foreach(println)
println(projected.queryExecution)
Here is the output:
Query result:
[1]
== Parsed Logical Plan ==
'Project ['a]
+- 'Filter ('a = 1)
   +- 'UnresolvedRelation `testtable`
== Analyzed Logical Plan ==
a: int
Project [a#3]
+- Filter (a#3 = 1)
   +- SubqueryAlias testtable
      +- LocalRelation [a#3, b#4, c#5]
== Optimized Logical Plan ==
Filter (a#3 = 1)
+- LocalRelation [a#3, b#4, c#5]
== Physical Plan ==
*Filter (a#3 = 1)
+- LocalTableScan [a#3, b#4, c#5]
We see that the result is identical to that of the original SQL statement, without the transformation applied. Yet, when printing the logical and physical plans, the projection has indeed been removed. I've also confirmed (through debug log output) that the transformation is indeed being invoked.
Any suggestions as to what's going on here? Maybe the optimiser simply ignores "optimisations" that change semantics?
If using the optimisations isn't the way to go, can anybody suggest an alternative? All I really want to do is parse the input SQL statement, transform it, and pass the transformed AST to Spark for execution. But as far as I can see, the APIs for doing this are private to the Spark sql package. It may be possible to use reflection, but I'd like to avoid that.
Any pointers would be much appreciated.
As you guessed, this is failing to work because we make assumptions that the optimizer will not change the results of the query.
Specifically, we cache the schema that comes out of the analyzer (and assume the optimizer does not change it). When translating rows to the external format, we use this schema and thus are truncating the columns in the result. If you did more than truncate (i.e. changed datatypes) this might even crash.
As you can see in this notebook, it is in fact producing the result you would expect under the covers. We are planning to open up more hooks at some point in the near future that would let you modify the plan at other phases of query execution. See SPARK-18127 for more details.
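As a rough analogy for the truncation described above (plain Python, not Spark's actual row conversion; the variable names are made up): if the output column list is fixed after analysis while the rewritten plan produces wider rows, pairing the two silently drops the extra values, which matches the [1] seen in the question.

```python
# Column names fixed by the analyzer for "SELECT a ...": only column a.
cached_schema = ["a"]

# Row produced by the rewritten plan, which no longer has the projection: a, b, c.
internal_row = (1, 2, 3)

# Converting with the cached schema keeps one value per cached column.
external_row = tuple(value for _name, value in zip(cached_schema, internal_row))
print(external_row)
```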
Michael Armbrust's answer confirmed that this kind of transformation shouldn't be done via optimisations.
For now, I've instead used internal APIs in Spark to achieve the transformation I wanted. This requires methods that are package-private in Spark, so we can access them without reflection by putting the relevant logic in the appropriate package. In outline:
// Must be in the spark.sql package.
package org.apache.spark.sql
object SQLTransformer {
  def apply(sparkSession: SparkSession, ...) = {
    // Get the AST.
    val ast = sparkSession.sessionState.sqlParser.parsePlan(sql)
    // Transform the AST.
    val transformedAST = ast match {
      case node: Project => // Modify any top-level projection
      ...
    }
    // Create a dataset directly from the AST.
    Dataset.ofRows(sparkSession, transformedAST)
  }
}
Note that this of course may break with future versions of Spark.

Resources