Forcing pyspark join to occur sooner - apache-spark

PROBLEM: I have two tables that are vastly different in size. I want to join them on an id using a left outer join. Unfortunately, even after caching, the actions I run after the join are executed on all records, even though I only want the ones from the left table. See below:
MY QUESTIONS:
1. How can I set this up so only the records that match the left table get processed through the costly wrangling steps?
LARGE_TABLE => ~900M records
SMALL_TABLE => 500K records
CODE:
from pyspark.sql import Window
from pyspark.sql.functions import col, lag, lit, when

combined = SMALL_TABLE.join(LARGE_TABLE, SMALL_TABLE.id == LARGE_TABLE.id, 'left_outer')
print(combined.count())
...
...
# EXPENSIVE STUFF!
w = Window.partitionBy("id").orderBy(col("date_time"))
data = data.withColumn('diff_id_flag', when(lag('id').over(w) != col('id'), lit(1)).otherwise(lit(0)))
Unfortunately, my execution plan shows the expensive transformation operation above is being done on ~900M records. I find this odd since I ran df.count() to force the join to execute eagerly rather than lazily.
Any Ideas?
ADDITIONAL INFORMATION:
- Note that the expensive transformation in my code flow occurs after the join (at least that is how I interpret it), but my DAG shows the expensive transformation occurring as part of the join. This is exactly what I want to avoid, as the transformation is expensive. I want the join to execute and THEN the result of that join to be run through the expensive transformation.
- Assume the smaller table CANNOT fit into memory.

The best way to do this is to broadcast the tiny dataframe. Caching is good for multiple actions, which doesn't seem to be applicable to your particular use case.

df.count has no effect on the execution plan at all. It is just an expensive operation executed without any good reason.
Window function application here requires the same logic as the join. Because you join by id and partitionBy("id"), both stages require the same hash partitioning and a full data scan on both sides. There is no acceptable reason to separate the two.
In practice the join logic should be applied before the window, serving as a filter for the downstream transformations in the same stage.
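A minimal sketch putting those two points together, shown in the Scala DataFrame API (the same broadcast hint and window functions exist in pyspark.sql.functions). Table and column names mirror the question; the left_semi pre-filter is my adaptation, since a broadcast hash join generally cannot broadcast the preserved side of an outer join:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{broadcast, col, lag, lit, when}

// 1. Cut LARGE_TABLE down to ids that exist in SMALL_TABLE. A broadcast
//    left_semi join ships the ~500K ids to the executors instead of
//    shuffling ~900M records.
val matched = LARGE_TABLE.join(broadcast(SMALL_TABLE.select("id")), Seq("id"), "left_semi")

// 2. Run the expensive window only on the reduced data.
val w = Window.partitionBy("id").orderBy(col("date_time"))
val flagged = matched.withColumn(
  "diff_id_flag",
  when(lag("id", 1).over(w) =!= col("id"), lit(1)).otherwise(lit(0)))

// 3. If the unmatched SMALL_TABLE rows still need to appear in the result,
//    re-attach them with the original left-outer join, now on far less data.
val combined = SMALL_TABLE.join(flagged, Seq("id"), "left_outer")

// Check the physical plan: a BroadcastHashJoin should have replaced the
// full shuffle of the large table.
combined.explain()
```

Whether the hint can be honored depends on the join type, so it is worth confirming with explain() before relying on it.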

Related

Spark Dataset join performance

I receive a Dataset and I am required to join it with another table. The simplest solution that came to my mind was to create a second Dataset for the other table and perform the joinWith.
def joinFunction(dogs: Dataset[Dog]): Dataset[(Dog, Cat)] = {
val cats: Dataset[Cat] = spark.table("dev_db.cat").as[Cat]
dogs.joinWith(cats, ...)
}
My main concern here is with spark.table("dev_db.cat"), as it feels like we are referring to all of the cat data as
SELECT * FROM dev_db.cat
and then doing a join at a later stage. Or will the query optimizer directly perform the join without referring to the whole table? Is there a better solution?
Here are some suggestions for your case:
a. If you have where, filter, limit, take, etc. operations, try to apply them before joining the two datasets. Spark can't push these kinds of filters down for you, so you have to reduce the number of target records yourself as much as possible. Here is an excellent source of information on the Spark optimizer.
b. Try to co-locate the datasets and minimize the shuffled data by using the repartition function. The repartition should be based on the keys that participate in the join, i.e.:
// requires import org.apache.spark.sql.functions.col
val dogsRepartitioned = dogs.repartition(1024, col("key_col1"), col("key_col2"))
dogsRepartitioned.join(cats, Seq("key_col1", "key_col2"), "inner")
c. Try to use broadcast for the smaller dataset if you are sure that it can fit in memory (or increase the value of spark.sql.autoBroadcastJoinThreshold). This gives a certain boost to the performance of your Spark program, since it ensures that the two datasets are co-located on the same nodes.
If you can't apply any of the above then Spark doesn't have a way to know which records should be excluded and therefore will scan all the available rows from both datasets.
You need to run explain and see whether predicate pushdown is used. Then you can judge whether your concern is correct or not.
In general, though, if no complex datatypes are used and there are no obvious datatype mismatches, pushdown takes place. You can see that with a simple createOrReplaceTempView as well. See https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/3741049972324885/4201913720573284/4413065072037724/latest.html
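Putting suggestions (a) and (c) together for the joinFunction above, here is a rough sketch. It assumes the Dog and Cat case classes and a SparkSession named spark from the question; the color filter column and the key_col1 join column are hypothetical placeholders:

```scala
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.broadcast
import spark.implicits._

def joinFunction(dogs: Dataset[Dog]): Dataset[(Dog, Cat)] = {
  // (a) Apply your own filters before the join so fewer cat rows reach it.
  //     The `color` predicate is a made-up example.
  val cats: Dataset[Cat] = spark.table("dev_db.cat").as[Cat]
    .filter($"color" === "black")

  // (c) If the filtered cats fit comfortably in memory, hint a broadcast so
  //     the dogs side never has to be shuffled for this join.
  val catsB = broadcast(cats)
  dogs.joinWith(catsB, dogs("key_col1") === catsB("key_col1"), "inner")
}
```

Calling .explain() on the result then shows whether the filter was pushed down into the scan of dev_db.cat and whether a BroadcastHashJoin was chosen.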

Spark reuse broadcast DF

I would like to reuse my DataFrame (without falling back to doing this with a map function over an RDD/Dataset), which I mark as broadcast-able, but Spark seems to keep broadcasting it again and again.
I have a table "bank" (a test table), and I perform the following:
val cachedDf = spark.sql("select * from bank").cache
cachedDf.count
val dfBroadcasted = broadcast(cachedDf)
val dfNormal = spark.sql("select * from bank")
dfNormal.join(dfBroadcasted, List("age"))
.join(dfBroadcasted, List("age")).count
I'm caching beforehand just in case it makes a difference, but it's the same with or without.
If I execute the above code, I see the following SQL plan (screenshot not included):
As you can see, my broadcasted DF gets broadcast TWICE, with different timings as well (if I add more actions afterwards, they broadcast again too).
I care about this, because I actually have a long-running program which has a "big" DataFrame which I can use to filter out HUGE DataFrames, and I would like that "big" DataFrame to be reused.
Is there a way to force reusability? (Not only inside the same action, but between actions; I could survive with the same action, though.)
Thanks!
OK, updating the question.
Summarising:
INSIDE the same action, left_semi joins will reuse broadcasts, while normal/left joins won't. I'm not sure whether this is related to the fact that Spark already knows the columns of that DF won't affect the output at all (so it can reuse it), or whether it's just an optimization Spark is missing.
My problem seems mostly solved, although it would be great if someone knew how to keep the broadcast across actions.
If I use left_semi (which is the join I'm going to use in my real app), the broadcast is only performed once.
With:
dfNormalxx.join(dfBroadcasted, Seq("age"),"left_semi")
.join(dfBroadcasted, Seq("age"),"left_semi").count
The plan becomes the following (I also changed the size so it matches my real one, but this made no difference):
Also, the total wall time is much better than with the normal joins (I set 1 executor so it doesn't get parallelized; I just wanted to check whether the job was really being done twice):
Even though my collect takes 10 seconds, this will speed up table reads + groupBys, which are taking 6-7 minutes.
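One way to see the difference yourself is to compare the two physical plans and count the BroadcastExchange nodes; when Spark reuses the broadcast within an action, one of them should appear as a ReusedExchange instead. A short sketch using the dataframes defined above:

```scala
// Two plain joins on the broadcasted DF: in the runs described above this
// produced two separate BroadcastExchange nodes (two broadcasts).
dfNormal.join(dfBroadcasted, Seq("age"))
  .join(dfBroadcasted, Seq("age"))
  .explain()

// The left_semi variant: here the second join should reuse the exchange
// built for the first one, so only a single broadcast happens.
dfNormal.join(dfBroadcasted, Seq("age"), "left_semi")
  .join(dfBroadcasted, Seq("age"), "left_semi")
  .explain()
```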

Does joining additional columns in Spark scale horizontally?

I have a dataset with about 2.4M rows, with a unique key for each row. I have performed some complex SQL queries on some other tables, producing a dataset with two columns, a key and the value true. This dataset is about 500 rows. Now I would like to (outer) join this dataset with my original table.
This produces a new table with a very sparse set of values (true in about 500 rows, null elsewhere).
Finally, I would like to do this about 200 times, giving me a final table of about 201 columns (the key, plus the 200 sparse columns).
When I run this, I notice that as it runs it gets considerably slower. The first join takes 2 seconds, then 4s, then 6s, then 10s, then 20s and after about 30 joins the system never recovers. Of course, the actual numbers are irrelevant as that depends on the cluster I'm running, but I'm wondering:
Is this slowdown expected?
I am using parquet as a data storage format (columnar storage) so I was hopeful that adding more columns would scale horizontally, is that a correct assumption?
All the columns I've joined so far are not needed for the Nth join, can they be unloaded from memory?
Are there other things I can do when combining lots of columns in spark?
Calling explain on each join in the loop shows that each join is getting more complex (appears to include all previous joins and it also includes the complex sql queries, even though those have been checkpointed). Is there a way to really checkpoint so each join is just a join? I am actually calling show() after each join, so I assumed the join is actually happening at that point.
Is this slowdown expected?
Yes, to some extent it is. Joins are among the most expensive operations in data-intensive systems (it is not a coincidence that products which claim linear scalability usually take joins off the table). Join-like operations in a distributed system typically require data exchange between nodes, hitting a bunch of high-latency numbers.
In Spark SQL there is also the additional cost of computing the execution plan, which has greater-than-linear complexity.
I am using parquet as a data storage format (columnar storage) so I was hopeful that adding more columns would scale horizontally, is that a correct assumption?
No. Input format doesn't affect join logic at all.
All the columns I've joined so far are not needed for the Nth join, can they be unloaded from memory?
If they are truly excluded from the final output they will be pruned from the execution plan. But since you keep them for a reason, I assume that is not the case and they are required for the final output.
Is there a way to really checkpoint so each join is just a join? I am actually calling show() after each join, so I assumed the join is actually happening at that point.
show computes only a small subset of data required for the output. It doesn't cache, although shuffle files might be reused.
(appears to include all previous joins and it also includes the complex sql queries, even though those have been checkpointed).
Checkpoints are created only once the data is fully computed, and they don't remove stages from the execution plan. If you want to do it explicitly, write the partial result to persistent storage and read it back at the beginning of each iteration (it is probably overkill).
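A rough sketch of that explicit approach, assuming the ~200 small flag datasets sit in a Scala collection; flagTables, baseTable, the key column, the batch size and the path are all placeholders. Every few joins the accumulated result is written out and read back, so the plan the optimizer has to analyze stays short:

```scala
import org.apache.spark.sql.DataFrame

var acc: DataFrame = baseTable            // the 2.4M-row table keyed by `key`

flagTables.zipWithIndex.foreach { case (flags, i) =>
  acc = acc.join(flags, Seq("key"), "left_outer")

  // Every 20 joins, truncate the lineage: persist to disk and re-read,
  // so the next iterations start from a fresh, flat plan.
  if ((i + 1) % 20 == 0) {
    val path = s"/tmp/acc_checkpoint_$i"  // placeholder location
    acc.write.mode("overwrite").parquet(path)
    acc = spark.read.parquet(path)
  }
}
```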
Are there other things I can do when combining lots of columns in spark?
The best thing you can do is to find a way to avoid joins completely. If the key is always the same, then a single shuffle plus an operation on groups / partitions (with a *byKey method or window functions) might be a better choice.
However if you
have a dataset with about 2.4M rows
then using a non-distributed system that supports in-place modification might be a much better choice.
In the most naive implementation you can compute each aggregate separately, sort by key, and write to disk. The data can then be merged together line by line with a negligible memory footprint.

DataFrame orderBy followed by limit in Spark

I have a program that generates a DataFrame on which it will run something like
Select Col1, Col2...
orderBy(ColX) limit(N)
However, when I collect the data in the end, I find that it causes the driver to OOM if I take a large enough top N.
Another observation is that if I just do the sort or just take the top, this problem does not happen. So it happens only when there is a sort and a top at the same time.
I am wondering why this could be happening. In particular, what is really going on underneath this combination of the two transformations? How does Spark evaluate a query with both sorting and limiting, and what is the corresponding execution plan underneath?
Also, just curious: does Spark handle sort and top differently between DataFrame and RDD?
EDIT:
Sorry, I didn't mean collect.
What I originally meant is that this happens when I call any action to materialize the data, regardless of whether it is collect (or any other action sending data back to the driver) or not. (So the problem is definitely not the output size.)
While it is not clear why this fails in this particular case, there are multiple issues you may encounter:
When you use limit it simply puts all the data on a single partition, no matter how big n is. So while it doesn't explicitly collect, it is almost as bad.
On top of that, orderBy requires a full shuffle with range partitioning, which can result in different issues when the data distribution is skewed.
Finally, when you collect, the results can be larger than the amount of memory available on the driver.
If you collect anyway, there is not much you can improve here. At the end of the day driver memory will be the limiting factor, but there are still some possible improvements:
First of all don't use limit.
Replace collect with toLocalIterator.
Use either orderBy |> rdd |> zipWithIndex |> filter, or, if the exact number of values is not a hard requirement, filter the data directly based on an approximated distribution, as shown in Saving a spark dataframe in multiple parts without repartitioning (in Spark 2.0.0+ there is a handy approxQuantile method). A sketch of the first option follows.
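A sketch of the orderBy |> rdd |> zipWithIndex |> filter route, using the column names from the question and placeholder df and n. Unlike limit, it never funnels the whole result into one partition, and toLocalIterator streams rows to the driver instead of collecting them all at once:

```scala
import org.apache.spark.sql.Row

val n = 1000000L  // desired top N (placeholder)

// Distributed sort, then a global index, then keep only the first n rows.
val sorted = df.select("Col1", "Col2", "ColX").orderBy("ColX")
val topN = spark.createDataFrame(
  sorted.rdd
    .zipWithIndex()
    .filter { case (_, idx) => idx < n }
    .map { case (row, _) => row },
  sorted.schema)

// Stream the result to the driver one partition at a time instead of collect().
val it = topN.toLocalIterator()
while (it.hasNext) {
  val row: Row = it.next()
  // process row ...
}
```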

How does the filter operation of Spark work on GraphX edges?

I'm very new to Spark and don't really know the basics, I just jumped into it to solve a problem. The solution for the problem involves making a graph (using GraphX) where edges have a string attribute. A user may wish to query this graph and I handle the queries by filtering out only those edges that have the string attribute which is equal to the user's query.
Now, my graph has more than 16 million edges; it takes more than 10 minutes to create the graph when I'm using all 8 cores of my computer. However, when I query this graph (like I mentioned above), I get the results instantaneously (to my pleasant surprise).
So, my question is, how exactly does the filter operation search for my queried edges? Does it look at them iteratively? Are the edges being searched for on multiple cores and it just seems very fast? Or is there some sort of hashing involved?
Here is an example of how I'm using filter: Mygraph.edges.filter(_.attr(0).equals("cat")) which means that I want to retrieve edges that have the attribute "cat" in them. How are the edges being searched?
How can the filter results be instantaneous?
Running your statement returns so fast because it doesn't actually perform the filtering. Spark uses lazy evaluation: it doesn't actually perform transformations until you perform an action which actually gathers the results. Calling a transformation method, like filter just creates a new RDD that represents this transformation and its result. You will have to perform an action like collect or count to actually have it executed:
def myGraph: Graph[Int, String] = ???  // vertex attribute type is a placeholder; the edges carry the String attribute
// No filtering actually happens yet here; the results aren't needed yet, so Spark is lazy and doesn't do anything
val filteredEdges = myGraph.edges.filter(_.attr == "cat")
// Counting how many edges are left requires the results to actually be instantiated, so this fires off the actual filtering
println(filteredEdges.count)
// Actually gathering all results also requires the filtering to be done
val collectedFilteredEdges = filteredEdges.collect
Note that in these examples the filter results are not stored in between: due to the laziness, the filtering is repeated for both actions. To prevent that duplication, you should look into Spark's caching functionality, after reading up on the details of transformations and actions and what Spark actually does behind the scenes: https://spark.apache.org/docs/latest/programming-guide.html#rdd-operations.
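For example, continuing the illustrative myGraph snippet above (with the 'cat' predicate from the question adapted to a plain String attribute):

```scala
// Cache the filtered edges so the predicate is evaluated only once,
// even though two actions run on the result.
val catEdges = myGraph.edges.filter(_.attr == "cat").cache()

println(catEdges.count())                   // first action: the filter actually runs; the result is cached
val collectedCatEdges = catEdges.collect()  // second action: served from the cached partitions
```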
How exactly does the filter operation search for my queried edges (when I execute an action)?
In Spark GraphX the edges are stored in an RDD of type EdgeRDD[ED], where ED is the type of your edge attribute, in your case String. This special RDD does some special optimizations in the background, but for your purposes it behaves like its superclass RDD[Edge[ED]], and filtering works like filtering any RDD: it iterates through all items, applying the given predicate to each. An RDD, however, is split into a number of partitions, and Spark will filter multiple partitions in parallel; in your case, where you seem to run Spark locally, it will do as many in parallel as the number of cores you have, or however many you have specified explicitly with --master local[4], for instance.
The RDD with edges is partitioned based on the PartitionStrategy that is set, for instance if you create your graph with Graph.fromEdgeTuples or by calling partitionBy on your graph. All strategies are based on the edge's vertices, however, so they have no knowledge of your attribute and don't affect your filtering operation, except perhaps through some unbalanced network load if you run it on a cluster and all 'cat' edges end up in the same partition/executor and you do a collect or some shuffle operation. See the GraphX docs on Vertex and Edge RDDs for a bit more information on how graphs are represented and partitioned.
