Cached dataframe flushed after truncating table - apache-spark

Here are the steps:
scala> val df = sql("select * from table")
df: org.apache.spark.sql.DataFrame = [num: int]
scala> df.cache
res13: df.type = [num: int]
scala> df.collect
res14: Array[org.apache.spark.sql.Row] = Array([10], [10])
scala> df
res15: org.apache.spark.sql.DataFrame = [num: int]
scala> df.show
+---+
|num|
+---+
| 10|
| 10|
+---+
scala> sql("truncate table table")
res17: org.apache.spark.sql.DataFrame = []
scala> df.show
+---+
|num|
+---+
+---+
My question is: why is the df flushed? I expected it to stay cached in memory, so that truncating the table wouldn't erase the data.
Any ideas would be much appreciated.
Thanks

The truncate table command deletes the table's data, refreshes its metadata, uncaches the table (which also drops the cached data behind df here), and resets its statistics. Because the cache entry is dropped, the next df.show re-executes the query against the now-empty table. The logic is in the TruncateTableCommand case class; at the bottom of it you'll see how the cache and table are handled when a table is truncated:
// After deleting the data, invalidate the table to make sure we don't keep around a stale
// file relation in the metastore cache.
spark.sessionState.refreshTable(tableName.unquotedString)
// Also try to drop the contents of the table from the columnar cache
try {
  spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier))
} catch {
  case NonFatal(e) =>
    log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e)
}
if (table.stats.nonEmpty) {
  // empty table after truncation
  val newStats = CatalogStatistics(sizeInBytes = 0, rowCount = Some(0))
  catalog.alterTableStats(tableName, Some(newStats))
}
Seq.empty[Row]
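The two steps that matter for the question (delete the data, then drop the cache entry) can be mimicked with a toy model in plain Python. Table, CacheManager and Catalog below are hypothetical classes, not Spark's API; the point is only that once the cache entry is gone, the next query falls back to the (now empty) table.

```python
# Toy model (hypothetical classes, not Spark's API) of how a TRUNCATE that
# also uncaches makes a previously cached query return no rows.
class Table:
    def __init__(self, rows):
        self.rows = list(rows)


class CacheManager:
    def __init__(self):
        self._cache = {}  # table name -> materialized copy of its rows

    def cache(self, name, table):
        self._cache[name] = list(table.rows)

    def uncache(self, name):
        self._cache.pop(name, None)

    def lookup(self, name):
        return self._cache.get(name)


class Catalog:
    def __init__(self):
        self.tables = {}
        self.cache_manager = CacheManager()

    def query(self, name):
        # Serve from the cache if an entry exists, otherwise re-read the table.
        cached = self.cache_manager.lookup(name)
        return cached if cached is not None else list(self.tables[name].rows)

    def truncate(self, name):
        # 1. delete the data, 2. drop the cached copy so stale rows are never served
        self.tables[name].rows.clear()
        self.cache_manager.uncache(name)


catalog = Catalog()
catalog.tables["t"] = Table([10, 10])
catalog.cache_manager.cache("t", catalog.tables["t"])
print(catalog.query("t"))  # [10, 10]
catalog.truncate("t")
print(catalog.query("t"))  # []
```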

You should never depend on the cache for correctness. The Spark cache is a performance optimization, and even the most defensive StorageLevel (MEMORY_AND_DISK_SER_2) is not guaranteed to preserve data in case of worker failure, executor decommissioning, or insufficient resources.
Code like the one in your question might work under some conditions, but don't assume that this behavior is guaranteed or deterministic.

Related

Why are some partitions processed twice when mapPartitions is used with toDF()?

I need to process the data partition by partition (long story).
Using mapPartitions works fine with RDDs: in the example, rdd.mapPartitions(mapper).collect() works as expected.
But when transforming to a DataFrame, one partition is processed twice.
Why is this happening, and how can I avoid it?
Below is the output of the simple example that follows. The function is executed three times even though there are only two partitions: the partition [Row(id=1), Row(id=2)] is processed twice.
Curiously, one of those executions is ignored, as the resulting DataFrame shows.
size: 2 > values: [Row(id=1), Row(id=2)]
size: 2 > values: [Row(id=1), Row(id=2)]
size: 2 > values: [Row(id=3), Row(id=4)]
+---+
| id|
+---+
|  1|
|  2|
|  3|
|  4|
+---+
> Mapper executions: 3
Simple example used:
from typing import Iterator
from pyspark import Row
from pyspark.sql import SparkSession
def gen_random_row(id: int):
    return Row(id=id)
if __name__ == '__main__':
    spark = SparkSession.builder.master("local[1]").appName("looking for the error").getOrCreate()
    executions_counter = spark.sparkContext.accumulator(0)
    rdd = spark.sparkContext.parallelize([
        gen_random_row(1),
        gen_random_row(2),
        gen_random_row(3),
        gen_random_row(4),
    ], 2)
    def mapper(iterator: Iterator[Row]) -> Iterator[Row]:
        executions_counter.add(1)  # count every mapper execution
        lst = list(iterator)
        print(f"size: {len(lst)} > values: {lst}")
        for r in lst:
            yield r
    # rdd.mapPartitions(mapper).collect()
    rdd.mapPartitions(mapper).toDF().show()
    print(f"> Mapper executions: {executions_counter.value}")
    spark.stop()
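The solution is to pass the schema to toDF.
Without a schema, PySpark infers one from the first row of the RDD (via rdd.first()), which runs a separate job over the first partition. That is the extra mapper execution, and its output is discarded once the schema is known.
To solve it: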
from pyspark.sql.types import StructType, StructField, IntegerType
schema = StructType([StructField("id", IntegerType(), True)])
rdd.mapPartitions(mapper).toDF(schema).show()
With this code, every partition is processed one time.

Spark-Scala Try Select Statement

I'm trying to incorporate a Try().getOrElse() statement in my select statement for a Spark DataFrame. The project I'm working on will be applied to multiple environments, and each environment names ONE field of the raw data slightly differently. I do not want to write several different functions to handle each variant of that field. Is there an elegant way to handle exceptions, like the one below, in a DataFrame select statement?
val dfFilter = dfRaw
  .select(
    Try($"some.field.nameOption1").getOrElse($"some.field.nameOption2"),
    $"some.field.abc",
    $"some.field.def"
  )
dfFilter.show(33, false)
However, I keep getting the following error. That makes sense, because the field does not exist in this environment's raw data, but I'd expect the getOrElse to catch the exception.
org.apache.spark.sql.AnalysisException: No such struct field nameOption1 in...
Is there a good way to handle exceptions in Scala Spark for select statements? Or will I need to code up different functions for each case?
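Try can't help here: $"some.field.nameOption1" only builds an unresolved column, and the AnalysisException is thrown later, when select analyzes the plan, outside the Try. Decide which column to use before calling select instead. Note that dfRaw.columns lists only top-level column names, so for a nested field the check has to use something that resolves nested paths, such as dfRaw(...):
val selectedColumn = if (Try(dfRaw("some.field.nameOption1")).isSuccess) $"some.field.nameOption1" else $"some.field.nameOption2"
val dfFilter = dfRaw
  .select(selectedColumn, ...)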
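Checking for the field before building the select, rather than wrapping the column in Try, can be sketched in plain Python. This is a hypothetical helper, not a Spark API: the schema is modeled as nested dicts (struct name mapping to its fields), and has_path and first_existing are illustrative names. In Spark the same walk could be done over dfRaw.schema.

```python
# Hypothetical helpers (plain Python, not a Spark API): pick the first dotted
# field path that exists in a nested schema, like the nameOption1/nameOption2 fallback.
from collections.abc import Mapping, Sequence
from typing import Optional


def has_path(schema: Mapping, path: str) -> bool:
    node = schema
    for part in path.split("."):
        # Stop as soon as we reach a leaf type or a missing field name.
        if not isinstance(node, Mapping) or part not in node:
            return False
        node = node[part]
    return True


def first_existing(schema: Mapping, candidates: Sequence[str]) -> Optional[str]:
    return next((c for c in candidates if has_path(schema, c)), None)


# Nested schema for one environment: struct "some" -> struct "field" -> leaves.
env_schema = {"some": {"field": {"nameOption2": "string", "abc": "string", "def": "string"}}}
top_level_columns = list(env_schema)  # like df.columns: only ["some"]

print("some.field.nameOption1" in top_level_columns)  # False in every environment
print(first_existing(env_schema, ["some.field.nameOption1", "some.field.nameOption2"]))
```

The first print shows why a check against the top-level column list alone cannot find a nested field.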
So I'm revisiting this question after a year. I believe this solution is much more elegant. Please let me know what others think:
import scala.util.Try
import spark.implicits._ // for toDF (already in scope in spark-shell)
// Generate a fake DataFrame
val df = Seq(
("1234", "A", "AAA"),
("1134", "B", "BBB"),
("2353", "C", "CCC")
).toDF("id", "name", "nameAlt")
// Extract the column names
val columns = df.columns
// Add a "new" column name that is NOT present in the above DataFrame
val columnsAdd = columns ++ Array("someNewColumn")
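// Let's then "try" to select all of the columns
// df(c) resolves the column immediately, so Try catches the AnalysisException for a missing column and flatMap drops it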
df.select(columnsAdd.flatMap(c => Try(df(c)).toOption): _*).show(false)
// Reduce the DF to fewer columns; the same select still works, skipping the missing ones
val dfNew = df.select("id", "name")
dfNew.select(columnsAdd.flatMap(c => Try(dfNew(c)).toOption): _*).show(false)
// Results
columns: Array[String] = Array(id, name, nameAlt)
columnsAdd: Array[String] = Array(id, name, nameAlt, someNewColumn)
+----+----+-------+
|id  |name|nameAlt|
+----+----+-------+
|1234|A   |AAA    |
|1134|B   |BBB    |
|2353|C   |CCC    |
+----+----+-------+
dfNew: org.apache.spark.sql.DataFrame = [id: string, name: string]
+----+----+
|id  |name|
+----+----+
|1234|A   |
|1134|B   |
|2353|C   |
+----+----+
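The "select whatever exists" idea above, flatMap over Try(df(c)).toOption, boils down to filtering the requested names against the available ones while keeping the requested order. A plain-Python analogue (rows as dicts, select_existing is a hypothetical helper, no Spark involved):

```python
# Plain-Python analogue (not Spark) of
#   df.select(columnsAdd.flatMap(c => Try(df(c)).toOption): _*)
# keep only the requested columns that exist, in the requested order.
from typing import Dict, List, Sequence


def select_existing(rows: List[Dict[str, str]], requested: Sequence[str]) -> List[Dict[str, str]]:
    available = set(rows[0]) if rows else set()
    keep = [c for c in requested if c in available]  # silently drop missing columns
    return [{c: row[c] for c in keep} for row in rows]


rows = [
    {"id": "1234", "name": "A", "nameAlt": "AAA"},
    {"id": "1134", "name": "B", "nameAlt": "BBB"},
    {"id": "2353", "name": "C", "nameAlt": "CCC"},
]
columns_add = ["id", "name", "nameAlt", "someNewColumn"]
print(select_existing(rows, columns_add)[0])  # {'id': '1234', 'name': 'A', 'nameAlt': 'AAA'}

reduced = [{"id": r["id"], "name": r["name"]} for r in rows]
print(select_existing(reduced, columns_add)[0])  # {'id': '1234', 'name': 'A'}
```

One trade-off of this design: a misspelled column name is dropped just as silently as a genuinely absent one.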

Spark DataFrame losing string data in yarn-client mode

For some reason, if I add a new column, append a string to existing data or a column, or create a new DataFrame from code, the string data is misinterpreted: show() doesn't display it properly, and transformations such as withColumn, where, and when don't work either.
Here is example code:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
object MissingValue {
  // Dump a string's UTF-8 bytes as uppercase hex, e.g. "ABC" -> "41-42-43"
  def hex(str: String): String = str.getBytes("UTF-8").map(f => Integer.toHexString((f & 0xFF)).toUpperCase).mkString("-")
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("MissingValue")
    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val list = List((101,"ABC"),(102,"BCD"),(103,"CDE"))
    val rdd = sc.parallelize(list).map(f => Row(f._1,f._2))
    val schema = StructType(StructField("COL1",IntegerType,true)::StructField("COL2",StringType,true)::Nil)
    val df = sqlContext.createDataFrame(rdd,schema)
    df.show()
    val str = df.first().getString(1)
    println(s"${str} == ${hex(str)}")
    sc.stop()
  }
}
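Reading the hex dump is what shows the corruption: the garbled value is the three bytes 0x05 0x00 0x00 rather than the bytes of "ABC". A Python port of the hex helper makes it easy to decode such dumps; it is named hex_bytes here (an illustrative name) to avoid shadowing Python's built-in hex.

```python
# Python port of the question's Scala hex helper: UTF-8 bytes as uppercase hex,
# without zero padding (matching Integer.toHexString(b & 0xFF).toUpperCase).
def hex_bytes(s: str) -> str:
    return "-".join(format(b, "X") for b in s.encode("utf-8"))


print(hex_bytes("ABC"))           # 41-42-43
print(hex_bytes("\x05\x00\x00"))  # 5-0-0
```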
If I run it in local mode then everything works as expected:
+----+----+
|COL1|COL2|
+----+----+
| 101| ABC|
| 102| BCD|
| 103| CDE|
+----+----+
ABC == 41-42-43
But when I run the same code in yarn-client mode it produces:
+----+----+
|COL1|COL2|
+----+----+
| 101| ^E^#^#|
| 102| ^E^#^#|
| 103| ^E^#^#|
+----+----+
^E^#^# == 5-0-0
This problem occurs only for string values; the first column (Integer) is fine.
Also, if I create an RDD from the DataFrame, everything is fine, e.g. df.rdd.take(1).apply(0).getString(1)
I'm using Spark 1.5.0 from CDH 5.5.2
EDIT:
It seems that this happens when the difference between driver memory and executor memory (--driver-memory xxG --executor-memory yyG) is too large: when I decrease executor memory or increase driver memory, the problem disappears.
This is a bug related to executor memory and the OOPs (ordinary object pointer) size, which differs between JVMs that do and don't use compressed OOPs:
https://issues.apache.org/jira/browse/SPARK-9725
https://issues.apache.org/jira/browse/SPARK-10914
https://issues.apache.org/jira/browse/SPARK-17706
It is fixed in Spark version 1.5.2
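Whether a given --driver-memory / --executor-memory pair is likely to hit this can be estimated with a rough heuristic. This is an assumption, not a Spark or JVM API: HotSpot JVMs normally enable compressed OOPs only for heaps below about 32 GB, so a driver and executors on opposite sides of that line end up with different pointer sizes. The sketch ignores explicit -XX:+UseCompressedOops / -XX:-UseCompressedOops flags and JVM-specific cutoffs.

```python
# Rough heuristic (an assumption, not a Spark or JVM API): flag driver/executor
# memory settings that fall on opposite sides of the ~32 GB compressed-OOPs limit.
import re

COMPRESSED_OOPS_LIMIT_GB = 32  # approximate; the exact cutoff depends on the JVM


def parse_gb(mem: str) -> float:
    """Parse Spark-style memory strings such as '4g', '40G' or '512m' into GB."""
    m = re.fullmatch(r"(\d+(?:\.\d+)?)([gGmM])", mem.strip())
    if not m:
        raise ValueError(f"unrecognised memory size: {mem!r}")
    value, unit = float(m.group(1)), m.group(2).lower()
    return value if unit == "g" else value / 1024


def uses_compressed_oops(heap_gb: float) -> bool:
    return heap_gb < COMPRESSED_OOPS_LIMIT_GB


def oop_size_mismatch(driver_memory: str, executor_memory: str) -> bool:
    return uses_compressed_oops(parse_gb(driver_memory)) != uses_compressed_oops(parse_gb(executor_memory))


print(oop_size_mismatch("4g", "40g"))  # True: only the executors exceed ~32 GB
print(oop_size_mismatch("8g", "16g"))  # False: both sides use compressed OOPs
```

On an affected version, a mismatch suggests either keeping both heaps on the same side of the limit or upgrading.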

Temp table caching with spark-sql

Is a table registered with registerTempTable (createOrReplaceTempView with spark 2.+) cached?
Using Zeppelin, I register a DataFrame in my scala code, after heavy computation, and then within %pyspark I want to access it, and further filter it.
Will it use a memory-cached version of the table? Or will it be rebuilt each time?
Registered tables are not cached in memory.
The registerTempTable (createOrReplaceTempView in Spark 2.x) method just creates or replaces a view of the given DataFrame with its query plan.
If we need to create a permanent view, it converts the query plan to a canonicalized SQL string and stores it as the view text in the metastore.
You'll need to cache your DataFrame explicitly, e.g.:
df.createOrReplaceTempView("my_table") // df.registerTempTable("my_table") for Spark < 2.0
spark.catalog.cacheTable("my_table") // sqlContext.cacheTable("my_table") for Spark < 2.0
EDIT:
Let's illustrate this with an example :
Using cacheTable :
scala> val df = Seq(("1",2),("b",3)).toDF
// df: org.apache.spark.sql.DataFrame = [_1: string, _2: int]
scala> sc.getPersistentRDDs
// res0: scala.collection.Map[Int,org.apache.spark.rdd.RDD[_]] = Map()
scala> df.createOrReplaceTempView("my_table")
scala> sc.getPersistentRDDs
// res2: scala.collection.Map[Int,org.apache.spark.rdd.RDD[_]] = Map()
scala> spark.catalog.cacheTable("my_table") // sqlContext.cacheTable("...") before Spark 2.0
scala> sc.getPersistentRDDs
// res4: scala.collection.Map[Int,org.apache.spark.rdd.RDD[_]] =
// Map(2 -> In-memory table my_table MapPartitionsRDD[2] at
// cacheTable at <console>:26)
Now the same example using cache.registerTempTable (cache.createOrReplaceTempView in Spark 2.x):
scala> sc.getPersistentRDDs
// res2: scala.collection.Map[Int,org.apache.spark.rdd.RDD[_]] = Map()
scala> val df = Seq(("1",2),("b",3)).toDF
// df: org.apache.spark.sql.DataFrame = [_1: string, _2: int]
scala> df.createOrReplaceTempView("my_table")
scala> sc.getPersistentRDDs
// res4: scala.collection.Map[Int,org.apache.spark.rdd.RDD[_]] = Map()
scala> df.cache.createOrReplaceTempView("my_table")
scala> sc.getPersistentRDDs
// res6: scala.collection.Map[Int,org.apache.spark.rdd.RDD[_]] =
// Map(2 -> ConvertToUnsafe
// +- LocalTableScan [_1#0,_2#1], [[1,2],[b,3]]
// MapPartitionsRDD[2] at cache at <console>:28)
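The difference between the two transcripts can be mimicked with a toy model in plain Python. Session and its methods are hypothetical, not Spark's API: registering a view stores only how to compute the result, so every query re-runs that computation until a cached copy exists. Unlike this toy, which fills the cache immediately, spark.catalog.cacheTable materializes the data lazily on the first action.

```python
# Toy model (hypothetical Session class, not Spark): a temp view stores only
# *how* to compute a result; queries re-run it unless a cached copy exists.
from typing import Callable, Dict, List


class Session:
    def __init__(self) -> None:
        self.views: Dict[str, Callable[[], List]] = {}
        self.cache: Dict[str, List] = {}
        self.evaluations = 0  # how many times any view's plan was executed

    def create_or_replace_temp_view(self, name: str, plan: Callable[[], List]) -> None:
        self.views[name] = plan  # nothing is computed or stored here

    def cache_table(self, name: str) -> None:
        self.cache[name] = self._run(name)

    def table(self, name: str) -> List:
        return self.cache[name] if name in self.cache else self._run(name)

    def _run(self, name: str) -> List:
        self.evaluations += 1
        return self.views[name]()


s = Session()
s.create_or_replace_temp_view("my_table", lambda: [("1", 2), ("b", 3)])
s.table("my_table")
s.table("my_table")
print(s.evaluations)  # 2: the plan ran once per query
s.cache_table("my_table")
s.table("my_table")
s.table("my_table")
print(s.evaluations)  # 3: one run to fill the cache, then served from it
```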
It is not. You should cache explicitly:
sqlContext.cacheTable("someTable")

How to get the number of elements in a partition?
Is there any way to get the number of elements in a Spark RDD partition, given the partition ID, without scanning the entire partition?
Something like this:
Rdd.partitions().get(index).size()
However, I don't see such an API in Spark. Any ideas or workarounds?
Thanks
The following gives you a new RDD with elements that are the sizes of each partition:
rdd.mapPartitions(iter => Array(iter.size).iterator, true)
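The idea behind this one-liner (walk each partition's iterator once and emit a single count per partition) can be sketched without Spark. The slicing below assumes the integer arithmetic that, as far as I know, Scala's parallelize uses to split a sequence (start = i*length/n, end = (i+1)*length/n); slice_bounds and partition_sizes are illustrative names.

```python
# Plain-Python sketch (not Spark): split a sequence into partitions, then count
# each partition by streaming through it, like mapPartitions(iter => Iterator(iter.size)).
from itertools import islice
from typing import List, Sequence, Tuple


def slice_bounds(length: int, num_slices: int) -> List[Tuple[int, int]]:
    # Assumed slicing: partition i covers [i*length//n, (i+1)*length//n).
    return [(i * length // num_slices, (i + 1) * length // num_slices) for i in range(num_slices)]


def partition_sizes(data: Sequence, num_slices: int) -> List[int]:
    sizes = []
    for start, end in slice_bounds(len(data), num_slices):
        # Count by iterating instead of building a list, unlike glom().
        sizes.append(sum(1 for _ in islice(data, start, end)))
    return sizes


print(partition_sizes(range(10), 3))  # [3, 3, 4]
```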
PySpark:
num_partitions = 20000
a = sc.parallelize(range(int(1e6)), num_partitions)
l = a.glom().map(len).collect() # get length of each partition
print(min(l), max(l), sum(l)/len(l), len(l)) # check if skewed
Spark/scala:
val numPartitions = 20000
val a = sc.parallelize(0 until 1e6.toInt, numPartitions)
val l = a.glom().map(_.length).collect() // get length of each partition
println((l.min, l.max, l.sum.toDouble / l.length, l.length)) // check if skewed
The same works for a DataFrame, not just an RDD: just use DF.rdd.glom... in the code above.
Note that glom() turns each partition into a list, so it's memory-intensive. A less memory-intensive version (PySpark only):
import statistics
def get_table_partition_distribution(table_name: str):
    def get_partition_len(iterator):
        yield sum(1 for _ in iterator)
    l = spark.table(table_name).rdd.mapPartitions(get_partition_len, True).collect()  # get length of each partition
    num_partitions = len(l)
    min_count = min(l)
    max_count = max(l)
    avg_count = sum(l)/num_partitions
    stddev = statistics.stdev(l)
    print(f"{table_name} each of {num_partitions} partition's counts: min={min_count:,} avg±stddev={avg_count:,.1f} ±{stddev:,.1f} max={max_count:,}")
get_table_partition_distribution('someTable')
outputs something like
someTable each of 1445 partition's counts: min=1,201,201 avg±stddev=1,202,811.6 ±21,783.4 max=2,030,137
I know I'm a little late here, but here is another approach to get the number of elements in a partition, using Spark's built-in function spark_partition_id. It works for Spark versions above 2.1.
Explanation:
We create a sample DataFrame (df), add the partition id as a column, group by partition id, and count the records in each group.
Pyspark:
>>> from pyspark.sql.functions import spark_partition_id, count as _count
>>> df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
>>> df.rdd.getNumPartitions()
4
>>> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(_count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
|           0|        48|
|           1|        44|
|           2|        32|
|           3|        48|
+------------+----------+
Scala:
scala> val df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
df: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [key: string, value: string ... 1 more field]
scala> df.rdd.getNumPartitions
res0: Int = 4
scala> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
|           0|        48|
|           1|        44|
|           2|        32|
|           3|        48|
+------------+----------+
pzecevic's answer works, but conceptually there's no need to construct an array and then convert it to an iterator. I would just construct the iterator directly and then get the counts with a collect call.
rdd.mapPartitions(iter => Iterator(iter.size), true).collect()
P.S. Not sure if his answer is actually doing more work since Iterator.apply will likely convert its arguments into an array.
