How to check if spark dataframe is empty in pyspark [duplicate] - apache-spark

Right now, I have to use df.count > 0 to check if the DataFrame is empty or not. But it is kind of inefficient. Is there any better way to do that?
PS: I want to check if it's empty so that I only save the DataFrame if it's not empty

For Spark 2.1.0, my suggestion would be to use head(n: Int) or take(n: Int) with isEmpty, whichever one has the clearest intent to you.
df.head(1).isEmpty
df.take(1).isEmpty
with Python equivalent:
len(df.head(1)) == 0 # equivalently: not df.head(1)
len(df.take(1)) == 0 # equivalently: not df.take(1)
Using df.first() or df.head() will throw a java.util.NoSuchElementException if the DataFrame is empty. first() calls head() directly, which calls head(1).head.
def first(): T = head()
def head(): T = head(1).head
head(1) returns an Array, so taking head on that Array causes the java.util.NoSuchElementException when the DataFrame is empty.
def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
So instead of calling head(), use head(1) directly to get the array and then you can use isEmpty.
take(n) is also equivalent to head(n)...
def take(n: Int): Array[T] = head(n)
And limit(1).collect() is equivalent to head(1) (notice limit(n).queryExecution in the head(n: Int) method), so the following are all equivalent, at least from what I can tell, and you won't have to catch a java.util.NoSuchElementException exception when the DataFrame is empty.
df.head(1).isEmpty
df.take(1).isEmpty
df.limit(1).collect().isEmpty
I know this is an older question so hopefully it will help someone using a newer version of Spark.

I would say to just grab the underlying RDD. In Scala:
df.rdd.isEmpty
in Python:
df.rdd.isEmpty()
That being said, all this does is call take(1).length, so it'll do the same thing as Rohan answered...just maybe slightly more explicit?

I had the same question, and I tested the 3 main solutions:
(df != null) && (df.count > 0)
df.head(1).isEmpty, as @hulin003 suggests
df.rdd.isEmpty(), as @Justin Pihony suggests
All 3 work, but in terms of performance, here is what I found when executing these methods on the same DataFrame on my machine (execution time):
the first takes ~9366 ms
the second takes ~5607 ms
the third takes ~1921 ms
Therefore I think the best solution is df.rdd.isEmpty(), as @Justin Pihony suggests.
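For reference, a rough PySpark sketch of how such a comparison can be reproduced; df is assumed to already exist, and the absolute numbers depend entirely on the data, cluster and caching:
import time

checks = {
    "count > 0": lambda: df.count() > 0,
    "head(1) empty": lambda: len(df.head(1)) == 0,
    "rdd.isEmpty()": lambda: df.rdd.isEmpty(),
}
for name, check in checks.items():
    start = time.time()
    check()  # trigger the job
    print(f"{name}: {(time.time() - start) * 1000:.0f} ms")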

Since Spark 2.4.0 there is Dataset.isEmpty.
Its implementation is:
def isEmpty: Boolean =
  withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan =>
    plan.executeCollect().head.getLong(0) == 0
  }
Note that a DataFrame is no longer a class in Scala, it's just a type alias (probably changed with Spark 2.0):
type DataFrame = Dataset[Row]
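For PySpark, where DataFrame.isEmpty only arrived in version 3.3.0 (see below), a hedged fallback sketch that also works on older versions:
def dataframe_is_empty(df):
    try:
        return df.isEmpty()          # PySpark 3.3.0+
    except AttributeError:           # older versions: fall back to head(1)
        return len(df.head(1)) == 0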

You can take advantage of the head() (or first()) functions to see if the DataFrame has at least one row. If so, it is not empty.
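A minimal PySpark sketch of that idea, assuming df already exists; in Python, head() with no argument returns a Row, or None when there are no rows:
if df.head() is not None:
    print("not empty")
else:
    print("empty")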

If you do df.count > 0, it takes the counts of all partitions across all executors and adds them up at the driver. This takes a while when you are dealing with millions of rows.
A cheaper way is df.take(1), which fetches at most one row and returns an empty array when the DataFrame has no rows. Note that df.first() and df.head() (with no argument) instead throw java.util.NoSuchElementException on an empty DataFrame in Scala, so wrap those in a try if you use them.
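A short PySpark sketch of the save-only-if-not-empty pattern asked about in the question, assuming df already exists; the output path and format are placeholders:
if len(df.take(1)) > 0:
    df.write.mode("overwrite").parquet("/tmp/output")  # hypothetical path
else:
    print("DataFrame is empty; nothing to save")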

If you are using Pyspark, you could also do:
len(df.head(1)) > 0

For Java users, you can use this on a Dataset:
public boolean isDatasetEmpty(Dataset<Row> ds) {
    boolean isEmpty;
    try {
        isEmpty = ((Row[]) ds.head(1)).length == 0;
    } catch (Exception e) {
        return true;
    }
    return isEmpty;
}
This checks all possible scenarios (empty, null).

PySpark 3.3.0+ / Spark 2.4.0+ (Scala API)
df.isEmpty()

In PySpark, you can also use bool(df.head(1)) to obtain a True or False value.
It returns False if the DataFrame contains no rows.

In Scala you can use implicits to add the methods isEmpty() and nonEmpty() to the DataFrame API, which will make the code a bit nicer to read.
object DataFrameExtensions {
  implicit def extendedDataFrame(dataFrame: DataFrame): ExtendedDataFrame =
    new ExtendedDataFrame(dataFrame)

  class ExtendedDataFrame(dataFrame: DataFrame) {
    def isEmpty(): Boolean = dataFrame.head(1).isEmpty // Any implementation can be used
    def nonEmpty(): Boolean = !isEmpty()
  }
}
Here, other methods can be added as well. To use the implicit conversion, use import DataFrameExtensions._ in the file you want to use the extended functionality. Afterwards, the methods can be used directly as so:
val df: DataFrame = ...
if (df.isEmpty) {
  // Do something
}

I found that in some cases:
>>> print(type(df))
<class 'pyspark.sql.dataframe.DataFrame'>
>>> df.take(1).isEmpty
'list' object has no attribute 'isEmpty'
The same happens for length, or if you replace take() with head(): in Python these return a plain list, which has no isEmpty attribute. As a workaround for the issue we can use a count on a limited DataFrame:
>>> df.limit(2).count() > 1
False
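In other words, a minimal sketch of the Python-side equivalents, since take() and head(n) return plain lists (df assumed to exist):
rows = df.take(1)
is_empty = len(rows) == 0  # or simply: not rows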

If you only want to find out whether the DataFrame is empty, then df.isEmpty, df.head(1).isEmpty() or df.rdd.isEmpty() should work; they all boil down to a limit(1) if you examine their query plans:
== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[count(1)], output=[count#52L])
+- *(2) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#60L])
+- *(2) GlobalLimit 1
+- Exchange SinglePartition
+- *(1) LocalLimit 1
+- *(1) InMemoryTableScan
+- InMemoryRelation [value#32L], StorageLevel(disk, memory, deserialized, 1 replicas)
... // the rest of the plan related to your computation
But if you are doing some other computation that requires a lot of memory and you don't want to cache your DataFrame just to check whether it is empty, then you can use an accumulator:
def accumulateRows(acc: LongAccumulator)(df: DataFrame): DataFrame =
  df.map { row => // we map to the same row, count during this map
    acc.add(1)
    row
  }(RowEncoder(df.schema))
val rowAccumulator = spark.sparkContext.longAccumulator("Row Accumulator")
val countedDF = df.transform(accumulateRows(rowAccumulator))
countedDF.write.saveAsTable(...) // main action
val isEmpty = rowAccumulator.isZero
Note that to see the row count, you should first perform the action. If we change the order of the last 2 lines, isEmpty will be true regardless of the computation.

df1.take(1).length > 0
The take method returns the array of rows, so if the array size is equal to zero, there are no records in df.

Let's suppose we have the following empty dataframe:
df = spark.sql("show tables").limit(0)
If you are using Spark 2.1, for pyspark, to check if this dataframe is empty, you can use:
df.count() > 0
Or
bool(df.head(1))

You can do it like:
val df = sqlContext.emptyDataFrame
if (df.eq(sqlContext.emptyDataFrame))
  println("empty df")
else
  println("normal df")

dataframe.limit(1).count > 0
This also triggers a job but since we are selecting single record, even in case of billion scale records the time consumption could be much lower.
From:
https://medium.com/checking-emptiness-in-distributed-objects/count-vs-isempty-surprised-to-see-the-impact-fa70c0246ee0

Related

Order of evaluation of predicates in Spark SQL where clause

I am trying to understand the order of predicate evaluation in Spark SQL in order to increase performance of a query.
Let's say I have the following query
"select * from tbl where pred1 and pred2"
and let's say that none of the predicates qualify as pushdown filters (for simplification).
Also let's assume that pred1 is computationally much more complex than pred2 (assume regex pattern matching vs negation).
Is there any way to verify that spark will evaluate pred2 before pred1?
Is this deterministic?
Is this controllable?
Is there any way to see the final execution plan?
General
Good question.
The answer below was inferred by testing a scenario and making deductions, as I could not find suitable documentation; it is a second attempt, since all sorts of statements on the web could not be backed up.
This question, I think, is not about AQE / Spark 3.x aspects, but about, say, a DataFrame that is part of Stage N of a Spark app, has already passed the stage of acquiring data from sources at rest, and is now subject to filtering with multiple predicates being applied.
Then the central point is: does it matter how the predicates are ordered, or does Spark (Catalyst) re-order the predicates to minimize the work to be done?
The premise here is that filtering the maximum amount of data out first makes more sense than evaluating a predicate that filters very little out.
This is a well-known RDBMS point, referring to sargable predicates (a definition that has evolved over time).
A lot of the discussion focuses on indexes; Spark and Hive do not have these, but DataFrames are columnar.
Point 1
In a %sql cell you can try:
EXPLAIN EXTENDED select k, sum(v) from values (1, 2), (1, 3) t(k, v) group by k;
From this you can see what's going on if there is any re-arranging of predicates, but I saw no such aspects in the Physical Plan in non-AQE mode on Databricks. Refer to https://docs.databricks.com/sql/language-manual/sql-ref-syntax-qry-explain.html.
Catalyst can re-arrange filtering, I read here and there. To what extent is a matter of research; I was not able to confirm this.
Also an interesting read:
https://www.waitingforcode.com/apache-spark-sql/catalyst-optimizer-in-spark-sql/read
Point 2
I ran the following pathetic contrived examples with the same functional query but with the predicates reversed, using a column that has high cardinality, tested for a value that does not in fact exist, and then compared the count of the accumulator used in a UDF when called.
Scenario 1
import org.apache.spark.sql.functions._
def randomInt1to1000000000 = scala.util.Random.nextInt(1000000000)+1
def randomInt1to10 = scala.util.Random.nextInt(10)+1
def randomInt1to1000000 = scala.util.Random.nextInt(1000000)+1
val df = sc.parallelize(Seq.fill(1000000){(randomInt1to1000000,randomInt1to1000000000,randomInt1to10)}).toDF("nuid","hc", "lc").withColumn("text", lpad($"nuid", 3, "0")).withColumn("literal",lit(1))
val accumulator = sc.longAccumulator("udf_call_count")
spark.udf.register("myUdf", (x: String) => {accumulator.add(1)
x.length}
)
accumulator.reset()
df.where("myUdf(text) = 3 and hc = -4").select(max($"text")).show(false)
println(s"Number of UDF calls ${accumulator.value}")
returns:
+---------+
|max(text)|
+---------+
|null |
+---------+
Number of UDF calls 1000000
Scenario 2
import org.apache.spark.sql.functions._
def randomInt1to1000000000 = scala.util.Random.nextInt(1000000000)+1
def randomInt1to10 = scala.util.Random.nextInt(10)+1
def randomInt1to1000000 = scala.util.Random.nextInt(1000000)+1
val dfA = sc.parallelize(Seq.fill(1000000){(randomInt1to1000000,randomInt1to1000000000,randomInt1to10)}).toDF("nuid","hc", "lc").withColumn("text", lpad($"nuid", 3, "0")).withColumn("literal",lit(1))
val accumulator = sc.longAccumulator("udf_call_count")
spark.udf.register("myUdf", (x: String) => {accumulator.add(1)
x.length}
)
accumulator.reset()
dfA.where("hc = -4 and myUdf(text) = 3").select(max($"text")).show(false)
println(s"Number of UDF calls ${accumulator.value}")
returns:
+---------+
|max(text)|
+---------+
|null |
+---------+
Number of UDF calls 0
My conclusion here is that there is left-to-right evaluation, in this case at least: there are 0 UDF calls in scenario 2 (the accumulator value stays at 0), as opposed to scenario 1 with 1M calls registered.
So the kind of predicate re-ordering that, say, ORACLE and DB2 may do for Stage 1 predicates does not apply here.
Point 3
I note, however, the following from the manual, https://docs.databricks.com/spark/latest/spark-sql/udf-scala.html:
Evaluation order and null checking
Spark SQL (including SQL and the DataFrame and Dataset APIs) does not guarantee the order of evaluation of subexpressions. In particular, the inputs of an operator or function are not necessarily evaluated left-to-right or in any other fixed order. For example, logical AND and OR expressions do not have left-to-right “short-circuiting” semantics.
Therefore, it is dangerous to rely on the side effects or order of evaluation of Boolean expressions, and the order of WHERE and HAVING clauses, since such expressions and clauses can be reordered during query optimization and planning. Specifically, if a UDF relies on short-circuiting semantics in SQL for null checking, there’s no guarantee that the null check will happen before invoking the UDF. For example,
spark.udf.register("strlen", (s: String) => s.length)
spark.sql("select s from test1 where s is not null and strlen(s) > 1") // no guarantee
This WHERE clause does not guarantee the strlen UDF to be invoked after filtering out nulls.
To perform proper null checking, we recommend that you do either of the following:
1. Make the UDF itself null-aware and do null checking inside the UDF itself.
2. Use IF or CASE WHEN expressions to do the null check and invoke the UDF in a conditional branch.
spark.udf.register("strlen_nullsafe", (s: String) => if (s != null) s.length else -1)
spark.sql("select s from test1 where s is not null and strlen_nullsafe(s) > 1") // ok
spark.sql("select s from test1 where if(s is not null, strlen(s), null) > 1") // ok
Slight contradiction.
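For PySpark users, a rough sketch of the same kind of experiment; the sizes, column names and the UDF are illustrative only, and explain(True) is used just to inspect where the predicates end up in the analyzed vs optimized plans:
from pyspark.sql import functions as F

@F.udf("int")
def expensive(s):
    # stand-in for an expensive predicate
    return len(s)

df = spark.range(1000000).withColumn("text", F.lpad(F.col("id"), 3, "0"))

# same filter, predicates written in both orders
df.where((expensive(F.col("text")) == 3) & (F.col("id") == -4)).explain(True)
df.where((F.col("id") == -4) & (expensive(F.col("text")) == 3)).explain(True)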

Error: Resolved attributes missing in join

I'm using pyspark to perform a join of two tables with a relatively complex join condition (using greater than/smaller than in the join conditions). This works fine, but breaks down as soon as I add a fillna command before the join.
The code looks something like this:
join_cond = [
    df_a.col1 == df_b.colx,
    df_a.col2 == df_b.coly,
    df_a.col3 >= df_b.colz
]
df = (
    df_a
    .fillna('NA', subset=['col1'])
    .join(df_b, join_cond, 'left')
)
This results in an error like this:
org.apache.spark.sql.AnalysisException: Resolved attribute(s) col1#4765 missing from col1#6488,col2#4766,col3#4768,colx#4823,coly#4830,colz#4764 in operator !Join LeftOuter, (((col1#4765 = colx#4823) && (col2#4766 = coly#4830)) && (col3#4768 >= colz#4764)). Attribute(s) with the same name appear in the operation: col1. Please check if the right attribute(s) are used.
It looks like spark no longer recognizes col1 after performing the fillna. (The error does not come up if I comment that out.) The problem is that I do need that statement. (And in general I've simplified this example a lot.)
I've looked at this question, but these answers do not work for me. Specifically, using .alias('a') after the fillna doesn't work because then spark does not recognize the a in the join condition.
Could someone:
Explain exactly why this is happening and how I can avoid it in the future?
Advise me on a way to solve it?
Thanks in advance for your help.
What is happening?
In order to "replace" empty values, a new dataframe is created that contains new columns. These new columns have the same names like the old ones but are effectively completely new Spark objects. In the Scala code you can see that the "changed" columns are newly created ones while the original columns are dropped.
A way to see this effect is to call explain on the dataframe before and after replacing the empty values:
df_a.explain()
prints
== Physical Plan ==
*(1) Project [_1#0L AS col1#6L, _2#1L AS col2#7L, _3#2L AS col3#8L]
+- *(1) Scan ExistingRDD[_1#0L,_2#1L,_3#2L]
while
df_a.fillna(42, subset=['col1']).explain()
prints
== Physical Plan ==
*(1) Project [coalesce(_1#0L, 42) AS col1#27L, _2#1L AS col2#7L, _3#2L AS col3#8L]
+- *(1) Scan ExistingRDD[_1#0L,_2#1L,_3#2L]
Both plans contain a column called col1, but in the first case the internal representation is called col1#6L while the second one is called col1#27L.
When the join condition df_a.col1 == df_b.colx now is associated with the column col1#6L the join will fail if only the column col1#27L is part of the left table.
How can the problem be solved?
The obvious way would be to move the `fillna` operation before the definition of the join condition:
df_a = df_a.fillna('NA', subset=['col1'])
join_cond = [
    df_a.col1 == df_b.colx,
    [...]
If this is not possible or wanted, you can change the join condition. Instead of using a column from the dataframe (df_a.col1), you can use a column that is not associated with any dataframe by using the col function. This column is resolved only by its name and is therefore not affected when the column is replaced in the dataframe:
from pyspark.sql import functions as F
join_cond = [
    F.col("col1") == df_b.colx,
    df_a.col2 == df_b.coly,
    df_a.col3 >= df_b.colz
]
The downside of this second approach is that the column name must be unique across both tables; otherwise col("col1") would be ambiguous.
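Putting the first fix together, a hedged sketch assuming df_a and df_b as in the question (fillna first, then build the condition from the updated df_a):
df_a = df_a.fillna('NA', subset=['col1'])

join_cond = [
    df_a.col1 == df_b.colx,
    df_a.col2 == df_b.coly,
    df_a.col3 >= df_b.colz,
]

df = df_a.join(df_b, join_cond, 'left')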

How to stop Spark resolving UDF column in conditional statement

I want to perform some conditional branching to avoid calculating unnecessary nodes but I am noticing that if the source column in the condition statement is a UDF then the otherwise is resolved regardless:
#pandas_udf("double", PandasUDFType.SCALAR)
def udf_that_throws_exception(*cols):
raise Exception('Error')
#pandas_udf("int", PandasUDFType.SCALAR)
def simple_mul_udf(*cols):
result = cols[0]
for c in cols[1:]:
result *= c
return result
df = spark.range(0,5)
df = df.withColumn('A', lit(1))
df = df.withColumn('B', lit(2))
df = df.withColumn('udf', simple_mul('A','B'))
df = df.withColumn('sql', expr('A*B'))
df = df.withColumn('res', when(df.sql < 100, lit(1)).otherwise(udf_that_throws(lit(0))))
The above code works as expected; the condition in this case is always true, so my UDF that throws an exception is never called.
However, if I change the condition to use df.udf instead, then all of a sudden the otherwise UDF is called and I get the exception, even though the condition result has not changed.
I thought I might be able to obfuscate it by removing the UDF from the condition, however the same result occurs regardless:
df = df.withColumn('cond', when(df.udf < 100, lit(1)).otherwise(lit(0)))
df = df.withColumn('res', when(df.cond == lit(1), lit(1)).otherwise(udf_that_throws_exception(lit(0))))
I imagine this has something to do with the way Spark optimizes, which is fine, but am looking for any way to do this without incurring the cost. Any ideas?
Edit
Per request for more information. We are writing a processing engine that can accept an arbitrary model, and the code generates the graph. Along the way there are junctures where we make decisions based on the state of values at runtime. We make heavy use of pandas UDFs. So imagine a situation where we have multiple paths in the graph and, depending on some condition at runtime, we wish to follow one of those paths, leaving all others untouched.
I would like to encode this logic into the graph so there is no point where I have to collect and branch in the code.
The sample code I have provided is for demonstration purposes only. The issue I am facing is that if the column used in the IF statement is a UDF or, it seems, if it is derived from a UDF, then the OTHERWISE branch is always executed, even if it is never actually used. If the IF/ELSE are cheap operations such as literals I wouldn't mind, but what if the column UDF (perhaps on both sides) results in a large aggregation or some other lengthy process which is then just thrown away?
In PySpark the UDFs are computed beforehand, and therefore you are getting this sub-optimal behaviour. You can see it also from the query plan:
== Physical Plan ==
*(2) Project [id#753L, 1 AS A#755, 2 AS B#758, pythonUDF1#776 AS udf#763, CASE WHEN (pythonUDF1#776 < 100) THEN 1.0 ELSE pythonUDF2#777 END AS res#769]
+- ArrowEvalPython [simple_mul_udf(1, 2), simple_mul_udf(1, 2), udf_that_throws_exception(0)], [id#753L, pythonUDF0#775, pythonUDF1#776, pythonUDF2#777]
+- *(1) Range (0, 5, step=1, splits=8)
The ArrowEvalPython operator is responsible for computing the UDFs and after that the condition will be evaluated in the Project operator.
The reason why you get different behaviour when you use df.sql in your condition (the optimal behaviour) is that this is a special case in which the value in this column is constant (both columns A and B are constant) and the Spark optimizer can evaluate it beforehand (in the driver during query plan processing, before the execution of the actual job on the cluster), and thus it knows that the otherwise branch of the condition will never have to be evaluated. If the value in this sql column is dynamic (for example, like the id column), the behaviour will be suboptimal again, because Spark does not know in advance that the otherwise part should never take place.
If you want to avoid this suboptimal behaviour (calling the udf in otherwise even though it is not needed), one possible solution is to evaluate the condition inside your udf, for example as follows:
#pandas_udf("int", PandasUDFType.SCALAR)
def udf_with_cond(*cols):
result = cols[0]
for c in cols[1:]:
result *= c
if((result < 100).any()):
return result
else:
raise Exception('Error')
df = df.withColumn('res', udf_with_cond('A', 'B'))
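To confirm that the throwing UDF no longer appears in the plan at all, a quick check (a sketch, using the df built above) is to look at which UDFs are listed in the ArrowEvalPython operator:
df.explain()  # udf_that_throws_exception should no longer show up in ArrowEvalPython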

Is .show() a Spark action? [duplicate]

I have the following code:
val df_in = sqlcontext.read.json(jsonFile) // the file resides in hdfs
//some operations in here to create df as df_in with two more columns "terms1" and "terms2"
val intersectUDF = udf( (seq1:Seq[String], seq2:Seq[String] ) => { seq1 intersect seq2 } ) //intersects two sequences
val symmDiffUDF = udf( (seq1:Seq[String], seq2:Seq[String] ) => { (seq1 diff seq2) ++ (seq2 diff seq1) } ) //compute the difference of two sequences
val df1 = (df.withColumn("termsInt", intersectUDF(df("terms1"), df1("terms2") ) )
.withColumn("termsDiff", symmDiffUDF(df("terms1"), df1("terms2") ) )
.where( size(col("termsInt")) >0 && size(col("termsDiff")) > 0 && size(col("termsDiff")) <= 2 )
.cache()
) // add the intersection and difference columns and filter the resulting DF
df1.show()
df1.count()
The app is working properly and fast until the show() but in the count() step, it creates 40000 tasks.
My understanding is that df1.show() should be triggering the full df1 creation and df1.count() should be very fast. What am I missing here? why is count() that slow?
Thank you very much in advance,
Roxana
show is indeed an action, but it is smart enough to know when it doesn't have to run everything. If you had an orderBy it would take very long too, but in this case all your operations are map operations and so there's no need to calculate the whole final table. However, count needs to physically go through the whole table in order to count it and that's why it's taking so long. You could test what I'm saying by adding an orderBy to df1's definition - then it should take long.
EDIT: Also, the 40k tasks are likely due to the amount of partitions your DF is partitioned into. Try using df1.repartition(<a sensible number here, depending on cluster and DF size>) and trying out count again.
show() by default displays only 20 rows, so if the first partition(s) return enough rows, the remaining partitions are not executed.
Note that show has several variations: show(n) controls how many rows are displayed, and show(false) (Scala) / show(truncate=False) (Python) disables truncation of long values; asking for many more rows forces more partitions to be executed and may take more time. So show() equals show(20), which is only a partial action.
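A minimal PySpark sketch of the effect described above; the size and partition count are arbitrary:
df = spark.range(0, 10000000, numPartitions=200)
df.show()   # only needs enough partitions to produce the first 20 rows
df.count()  # has to scan all 200 partitions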

Transforming Spark SQL AST with extraOptimizations

I'm wanting to take a SQL string as a user input, then transform it before execution. In particular, I want to modify the top-level projection (select clause), injecting additional columns to be retrieved by the query.
I was hoping to achieve this by hooking into Catalyst using sparkSession.experimental.extraOptimizations. I know that what I'm attempting isn't strictly speaking an optimisation (the transformation changes the semantics of the SQL statement), but the API still seems suitable. However, my transformation seems to be ignored by the query executor.
Here is a minimal example to illustrate the issue I'm having. First define a row case class:
case class TestRow(a: Int, b: Int, c: Int)
Then define an optimisation rule which simply discards any projection:
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

object RemoveProjectOptimisationRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    case x: Project => x.child
  }
}
Now create a dataset, register the optimisation, and run a SQL query:
// Create a dataset and register table.
val dataset = List(TestRow(1, 2, 3)).toDS()
val tableName: String = "testtable"
dataset.createOrReplaceTempView(tableName)
// Register "optimisation".
sparkSession.experimental.extraOptimizations =
Seq(RemoveProjectOptimisationRule)
// Run query.
val projected = sqlContext.sql("SELECT a FROM " + tableName + " WHERE a = 1")
// Print query result and the queryExecution object.
println("Query result:")
projected.collect.foreach(println)
println(projected.queryExecution)
Here is the output:
Query result:
[1]
== Parsed Logical Plan ==
'Project ['a]
+- 'Filter ('a = 1)
+- 'UnresolvedRelation `testtable`
== Analyzed Logical Plan ==
a: int
Project [a#3]
+- Filter (a#3 = 1)
+- SubqueryAlias testtable
+- LocalRelation [a#3, b#4, c#5]
== Optimized Logical Plan ==
Filter (a#3 = 1)
+- LocalRelation [a#3, b#4, c#5]
== Physical Plan ==
*Filter (a#3 = 1)
+- LocalTableScan [a#3, b#4, c#5]
We see that the result is identical to that of the original SQL statement, without the transformation applied. Yet, when printing the logical and physical plans, the projection has indeed been removed. I've also confirmed (through debug log output) that the transformation is indeed being invoked.
Any suggestions as to what's going on here? Maybe the optimiser simply ignores "optimisations" that change semantics?
If using the optimisations isn't the way to go, can anybody suggest an alternative? All I really want to do is parse the input SQL statement, transform it, and pass the transformed AST to Spark for execution. But as far as I can see, the APIs for doing this are private to the Spark sql package. It may be possible to use reflection, but I'd like to avoid that.
Any pointers would be much appreciated.
As you guessed, this is failing to work because we make assumptions that the optimizer will not change the results of the query.
Specifically, we cache the schema that comes out of the analyzer (and assume the optimizer does not change it). When translating rows to the external format, we use this schema and thus are truncating the columns in the result. If you did more than truncate (i.e. changed datatypes) this might even crash.
As you can see in this notebook, it is in fact producing the result you would expect under the covers. We are planning to open up more hooks at some point in the near future that would let you modify the plan at other phases of query execution. See SPARK-18127 for more details.
Michael Armbrust's answer confirmed that this kind of transformation shouldn't be done via optimisations.
I've instead used internal APIs in Spark to achieve the transformation I wanted for now. It requires methods that are package-private in Spark, so we can access them without reflection by putting the relevant logic in the appropriate package. In outline:
// Must be in the spark.sql package.
package org.apache.spark.sql

object SQLTransformer {
  def apply(sparkSession: SparkSession, ...) = {
    // Get the AST.
    val ast = sparkSession.sessionState.sqlParser.parsePlan(sql)

    // Transform the AST.
    val transformedAST = ast match {
      case node: Project => // Modify any top-level projection
        ...
    }

    // Create a dataset directly from the AST.
    Dataset.ofRows(sparkSession, transformedAST)
  }
}
Note that this of course may break with future versions of Spark.
