Explode map column in Pyspark without losing null values - apache-spark

Is there any elegant way to explode a map column in PySpark 2.2 without losing null values? explode_outer was only introduced in PySpark 2.3.
The schema of the affected column is:
|-- foo: map (nullable = true)
| |-- key: string
| |-- value: struct (valueContainsNull = true)
| | |-- first: long (nullable = true)
| | |-- last: long (nullable = true)
I would like to replace empty maps with some dummy values so that I can explode the whole DataFrame without losing null values. I have tried something like this, but I get an error:
from pyspark.sql.functions import when, size, col
df = spark.read.parquet("path").select(
    when(size(col("foo")) == 0, {"key": [0, 0]}).alias("bar")
)
And the error:
Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.when.
: java.lang.RuntimeException: Unsupported literal type class java.util.HashMap {key=[0, 0]}
at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:77)
at org.apache.spark.sql.catalyst.expressions.Literal$$anonfun$create$2.apply(literals.scala:163)
at org.apache.spark.sql.catalyst.expressions.Literal$$anonfun$create$2.apply(literals.scala:163)
at scala.util.Try.getOrElse(Try.scala:79)
at org.apache.spark.sql.catalyst.expressions.Literal$.create(literals.scala:162)
at org.apache.spark.sql.functions$.typedLit(functions.scala:112)
at org.apache.spark.sql.functions$.lit(functions.scala:95)
at org.apache.spark.sql.functions$.when(functions.scala:1256)
at org.apache.spark.sql.functions.when(functions.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:280)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:748)

I finally got it working: I replaced empty maps with dummy values, then used explode and dropped the original column.
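from pyspark.sql.functions import udf, explode
from pyspark.sql.types import MapType, StringType, StructType, StructField, LongType
# The UDF's return type must match foo: map<string, struct<first: long, last: long>>
foo_type = MapType(StringType(), StructType([StructField("first", LongType()), StructField("last", LongType())]))
# Replace null or empty maps with one dummy entry (tuple = struct) so explode() keeps the row
replace_empty_map = udf(lambda x: {"key": (0, 1)} if not x else x, foo_type)
df = df.withColumn("foo_replaced", replace_empty_map(df["foo"])).drop("foo")
df = df.select('*', explode('foo_replaced').alias('foo_key', 'foo_val')).drop("foo_replaced")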
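To see why the dummy-value trick is needed, here is a plain-Python model of the two generators (no Spark involved; explode_map and explode_map_outer are hypothetical names, and the map values are tuples standing in for the struct). explode emits one row per map entry, so a row whose map is null or empty produces nothing and disappears; explode_outer (Spark 2.3+) emits a single row with a null key and value instead.

```python
from typing import Any, Dict, Iterator, List, Optional, Tuple

# (row id, foo map); the map may be null (None) or empty
InRow = Tuple[int, Optional[Dict[str, Any]]]
OutRow = Tuple[int, Optional[str], Any]

def explode_map(rows: List[InRow]) -> Iterator[OutRow]:
    """Model of explode(): one output row per map entry; null/empty maps vanish."""
    for row_id, m in rows:
        for key, value in (m or {}).items():
            yield (row_id, key, value)

def explode_map_outer(rows: List[InRow]) -> Iterator[OutRow]:
    """Model of explode_outer(): a null/empty map yields a single (None, None) entry."""
    for row_id, m in rows:
        if not m:
            yield (row_id, None, None)
        else:
            for key, value in m.items():
                yield (row_id, key, value)

rows = [
    (1, {"a": (1, 2), "b": (3, 4)}),  # tuples stand in for struct<first, last>
    (2, {}),                          # empty map
    (3, None),                        # null map
]
print(list(explode_map(rows)))        # rows 2 and 3 are lost
print(list(explode_map_outer(rows)))  # rows 2 and 3 survive with nulls
```

Replacing null or empty maps with a one-entry dummy map before calling explode makes the first function behave like the second, except that the dummy key and value appear instead of nulls.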


AnalysisException: ambiguous reference error when trying to replace withColumn() with select()

I get a StackOverflowException when I call withColumn() multiple times to update the values of a column in PySpark.
The code that raised the StackOverflowException was:
from pyspark.sql.functions import when
df = df.withColumn("element", when(df["element"] == 1,"first").otherwise(df["element"]))
df = df.withColumn("element", when(df["element"] == 2,"second").otherwise(df["element"]))
df = df.withColumn("element", when(df["element"] == 3,"third").otherwise(df["element"]))
df = df.withColumn("element", when(df["element"] == 4,"fourth").otherwise(df["element"]))
The Spark documentation suggests using select() instead, so I tried:
df = df.select("*", (when(df["element"] == 1,"first")).alias("element"))
df = df.select("*", (when(df["element"] == 2,"second")).alias("element"))
df = df.select("*", (when(df["element"] == 3,"third")).alias("element"))
df = df.select("*", (when(df["element"] == 4,"fourth")).alias("element"))
But I get an error, because select() doesn't update the column "element"; it creates another column with the same name. The error is:
Py4JJavaError: An error occurred while calling o3723.apply.
: org.apache.spark.sql.AnalysisException: Reference 'element' is ambiguous, could be: element, element.;
at org.apache.spark.sql.catalyst.expressions.package$AttributeSeq.resolve(package.scala:259)
at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveQuoted(LogicalPlan.scala:121)
at org.apache.spark.sql.Dataset.resolve(Dataset.scala:229)
at org.apache.spark.sql.Dataset.col(Dataset.scala:1282)
at org.apache.spark.sql.Dataset.apply(Dataset.scala:1249)
at sun.reflect.GeneratedMethodAccessor36.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:750)
How could I do it?
Thank you in advance!
I think you can chain .when() multiple times and finish with a single .otherwise(). Also, give the new column a different name so that you don't get an ambiguous column error:
df = df.withColumn("element_new", when(df["element"] == 1,"first").when(df["element"] == 2,"second").when(df["element"] == 3,"third").when(df["element"] == 4,"fourth").otherwise(df["element"]))
Using .select:
df = df.select("*",when(df["element"] == 1,"first").when(df["element"] == 2,"second").when(df["element"] == 3,"third").when(df["element"] == 4,"fourth").otherwise(df["element"]).alias("element_new"))
Example output:
+-------+-----------+
|element|element_new|
+-------+-----------+
|      1|      first|
|      2|     second|
|      3|      third|
|      4|     fourth|
|      5|          5|
+-------+-----------+
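A chain of when(...).when(...).otherwise(...) builds one conditional expression (Spark's CaseWhen), so the column is rewritten in a single projection instead of four stacked withColumn calls. The same shape in plain SQL is CASE WHEN ... ELSE ... END. The sketch below runs it with Python's built-in sqlite3 (not Spark) on a hypothetical table t with an integer column element, to show the first-match-wins semantics.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (element INTEGER)")
conn.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(1, 6)])

# One CASE expression: branches are tried in order and the first match wins.
# ELSE plays the role of .otherwise(df["element"]); the CAST keeps the result
# a string, matching the example output above.
rows = conn.execute(
    """
    SELECT element,
           CASE WHEN element = 1 THEN 'first'
                WHEN element = 2 THEN 'second'
                WHEN element = 3 THEN 'third'
                WHEN element = 4 THEN 'fourth'
                ELSE CAST(element AS TEXT)
           END AS element_new
    FROM t
    ORDER BY element
    """
).fetchall()
for r in rows:
    print(r)
```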

Modify nested property inside Struct column with PySpark

I want to modify/filter on a property inside a struct.
Let's say I have a DataFrame with the following column, of type struct<a:string, b:array<int>>:
#| arrayCol |
#| {"a" : "some_value", "b" : [1, 2, 3]} |
I want to filter out the values in the 'b' property that are equal to 1.
The desired result is:
#| arrayCol |
#| {"a" : "some_value", "b" : [2, 3]} |
Is it possible to do this without extracting the property, filtering the values, and rebuilding another struct?
For Spark 3.1+, Column.withField can be used to update a field of a struct column without recreating the whole struct. In your case, you can update field b with the filter function, which filters the array values:
import pyspark.sql.functions as F
df1 = df.withColumn(
    "arrayCol",
    F.col("arrayCol").withField("b", F.filter(F.col("arrayCol.b"), lambda x: x != 1))
)
#| arrayCol|
#|{some_value, [2, 3]}|
For older versions, Spark doesn’t support adding/updating fields in nested structures. To update a struct column, you'll need to create a new struct using the existing fields and the updated ones:
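import pyspark.sql.functions as F
# The SQL filter() higher-order function needs Spark 2.4+
df1 = df.withColumn("arrayCol", F.struct(
    F.col("arrayCol.a").alias("a"),
    F.expr("filter(arrayCol.b, x -> x != 1)").alias("b")))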
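Both approaches above compute the same thing per row: a copy of the struct in which only b is replaced by its filtered array, while a is carried over unchanged. A plain-Python model of that transformation, with the struct represented as a dict (an assumption for illustration; with_field and drop_value are hypothetical helpers, not Spark APIs):

```python
from typing import Any, Dict, List

def with_field(struct: Dict[str, Any], name: str, value: Any) -> Dict[str, Any]:
    """Return a copy of `struct` with field `name` set to `value` (like Column.withField)."""
    return {**struct, name: value}

def drop_value(values: List[int], unwanted: int) -> List[int]:
    """Keep only elements != unwanted, like filter(b, x -> x != 1)."""
    return [x for x in values if x != unwanted]

row = {"a": "some_value", "b": [1, 2, 3]}
updated = with_field(row, "b", drop_value(row["b"], 1))
print(updated)  # {'a': 'some_value', 'b': [2, 3]}
print(row)      # the input is not modified: {'a': 'some_value', 'b': [1, 2, 3]}
```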
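Another way is to define a UDF. Note that this example represents arrayCol as a map<string,string> rather than a struct, with b stored as a string such as "[1, 2, 3]":
import ast
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StringType, MapType
def remove_value(col):
    # Parse the string "[1, 2, 3]", drop the 1s, and store the result back as a string
    return {**col, "b": str([x for x in ast.literal_eval(col["b"]) if x != 1])}
if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [{"arrayCol": {"a": "some_value", "b": "[1, 2, 3]"}}]
    )
    remove_value_udf = spark.udf.register(
        "remove_value_udf", remove_value, MapType(StringType(), StringType())
    )
    df = df.withColumn("result", remove_value_udf(F.col("arrayCol")))
    df.printSchema()
    df.show(truncate=False)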
|-- arrayCol: map (nullable = true)
| |-- key: string
| |-- value: string (valueContainsNull = true)
|-- result: map (nullable = true)
| |-- key: string
| |-- value: string (valueContainsNull = true)
|arrayCol |result |
|{a -> some_value, b -> [1, 2, 3]}|{a -> some_value, b -> [2, 3]}|

How to use Spark SQL SPLIT function to pass input to Spark SQL IN parameter [duplicate]

I have a DataFrame with two columns (one string and one array of strings):
|-- user: string (nullable = true)
|-- users: array (nullable = true)
| |-- element: string (containsNull = true)
How can I filter the DataFrame so that the result only contains rows where user is in users?
Quick and simple:
import org.apache.spark.sql.functions.expr
df.where(expr("array_contains(users, user)"))
Sure, it's possible and not hard. One way is to use a UDF:
import org.apache.spark.sql.functions._
import spark.implicits._ // in spark-shell (Spark 2.x+); needed for toDF and the $"col" syntax
val df = sc.parallelize(Array(
  ("1", Array("1", "2", "3")),
  ("2", Array("1", "2", "2", "3")),
  ("3", Array("1", "2"))
)).toDF("user", "users")
// Array columns reach Scala UDFs as a Seq (WrappedArray), so declare the parameter as Seq[String]
val inArray = udf((id: String, array: Seq[String]) => array.contains(id))
df.where(inArray($"user", $"users")).show()
The output is:
|user| users|
| 1| [1, 2, 3]|
| 2|[1, 2, 2, 3]|
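The where(array_contains(users, user)) filter keeps a row only when the predicate is true. A simplified plain-Python model (array_contains and where_user_in_users here are hypothetical helpers, not Spark APIs), assuming Spark's behavior that array_contains returns null for a null array and that where() drops rows whose predicate is null:

```python
from typing import List, Optional, Tuple

def array_contains(arr: Optional[List[str]], value: str) -> Optional[bool]:
    """Simplified model: a null array gives null (None), otherwise a membership test."""
    if arr is None:
        return None
    return value in arr

def where_user_in_users(rows: List[Tuple[str, Optional[List[str]]]]):
    # where() keeps only rows whose predicate is exactly True; None is dropped like False
    return [r for r in rows if array_contains(r[1], r[0]) is True]

rows = [
    ("1", ["1", "2", "3"]),
    ("2", ["1", "2", "2", "3"]),
    ("3", ["1", "2"]),
    ("4", None),  # hypothetical row with a null array
]
print(where_user_in_users(rows))
```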

scala.collection.mutable.WrappedArray$ofRef cannot be cast to Integer

I'm fairly new to Spark and Scala. I'm trying to call a function as a Spark UDF but I run into this error that I can't seem to resolve.
I understand that in Scala, Array and Seq aren't the same. WrappedArray is a subtype of Seq, and there are implicit conversions between WrappedArray and Array, but I'm not sure why they aren't applied in the case of the UDF.
Any pointers to help me understand and resolve this are much appreciated.
Here's a snippet of the code:
def filterMapKeysWithSet(m: Map[Int, Int], a: Array[Int]): Map[Int, Int] = {
  val seqToArray = a.toArray
  val s = seqToArray.toSet
  m filterKeys s
}
val myUDF = udf((m: Map[Int, Int], a: Array[Int]) => filterMapKeysWithSet(m, a))
case class myType(id: Int, m: Map[Int, Int])
val mapRDD = Seq(myType(1, Map(1 -> 100, 2 -> 200)), myType(2, Map(1 -> 100, 2 -> 200)), myType(3, Map(3 -> 300, 4 -> 400)))
val mapDF = mapRDD.toDF
mapDF: org.apache.spark.sql.DataFrame = [id: int, m: map<int,int>]
|-- id: integer (nullable = false)
|-- m: map (nullable = true)
| |-- key: integer
| |-- value: integer (valueContainsNull = false)
case class myType2(id: Int, a: Array[Int])
val idRDD = Seq(myType2(1, Array(1,2,100,200)), myType2(2, Array(100,200)), myType2(3, Array(1,2)) )
val idDF = idRDD.toDF
idDF: org.apache.spark.sql.DataFrame = [id: int, a: array<int>]
|-- id: integer (nullable = false)
|-- a: array (nullable = true)
| |-- element: integer (containsNull = false)
import sqlContext.implicits._
/* Hive context is exposed as sqlContext */
val j = mapDF.join(idDF, idDF("id") === mapDF("id")).drop(idDF("id"))
val k = j.withColumn("filteredMap",myUDF(j("m"), j("a")))
Looking at the Dataframe "j" & "k", the map and array columns have the right data types.
j: org.apache.spark.sql.DataFrame = [id: int, m: map<int,int>, a: array<int>]
|-- id: integer (nullable = false)
|-- m: map (nullable = true)
| |-- key: integer
| |-- value: integer (valueContainsNull = false)
|-- a: array (nullable = true)
| |-- element: integer (containsNull = false)
k: org.apache.spark.sql.DataFrame = [id: int, m: map<int,int>, a: array<int>, filteredMap: map<int,int>]
|-- id: integer (nullable = false)
|-- m: map (nullable = true)
| |-- key: integer
| |-- value: integer (valueContainsNull = false)
|-- a: array (nullable = true)
| |-- element: integer (containsNull = false)
|-- filteredMap: map (nullable = true)
| |-- key: integer
| |-- value: integer (valueContainsNull = false)
However, an action on the DataFrame "k" that calls the UDF fails with the following error:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 4 times, most recent failure: Lost task 0.3 in stage 1.0 (TID 6, ip-100-74-42-194.ec2.internal): java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [I
at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(<console>:60)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
at org.apache.spark.sql.execution.Project$$anonfun$1$$anonfun$apply$1.apply(basicOperators.scala:51)
at org.apache.spark.sql.execution.Project$$anonfun$1$$anonfun$apply$1.apply(basicOperators.scala:49)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:312)
at scala.collection.Iterator$class.foreach(Iterator.scala:727)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
at scala.collection.AbstractIterator.to(Iterator.scala:1157)
at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:212)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:212)
at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1865)
at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1865)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
at org.apache.spark.scheduler.Task.run(Task.scala:89)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
Changing the datatype from Array[Int] to Seq[Int] in the function filterMapKeysWithSet seems to resolve the above issue.
def filterMapKeysWithSet(m: Map[Int, Int], a: Seq[Int]): Map[Int, Int] = {
  val seqToArray = a.toArray
  val s = seqToArray.toSet
  m filterKeys s
}
val myUDF = udf((m: Map[Int, Int], a: Seq[Int]) => filterMapKeysWithSet(m, a))
k: org.apache.spark.sql.DataFrame = [id: int, m: map<int,int>, a: array<int>, filteredMap: map<int,int>]
|-- id: integer (nullable = false)
|-- m: map (nullable = true)
| |-- key: integer
| |-- value: integer (valueContainsNull = false)
|-- a: array (nullable = true)
| |-- element: integer (containsNull = false)
|-- filteredMap: map (nullable = true)
| |-- key: integer
| |-- value: integer (valueContainsNull = false)
| id| m| a| filteredMap|
| 1|Map(1 -> 100, 2 -...|[1, 2, 100, 200]|Map(1 -> 100, 2 -...|
| 2|Map(1 -> 100, 2 -...| [100, 200]| Map()|
| 3|Map(3 -> 300, 4 -...| [1, 2]| Map()|
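So an ArrayType column such as a in "idDF" is passed to the UDF as a WrappedArray (a Seq), not as an Array. The UDF's declared parameter type Array[Int] is only enforced by a cast at runtime, so Scala's compile-time implicit conversion from WrappedArray to Array never gets a chance to apply, and the cast fails with the ClassCastException above. Declaring the parameter as Seq[Int] matches what Spark actually passes.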

PySpark : Filter based on resultant query without additional dataframe

Consider the example below:
>>> l = [("US","City1",125),("US","City2",123),("Europe","CityX",23),("Europe","CityY",17)]
>>> print l
[('US', 'City1', 125), ('US', 'City2', 123), ('Europe', 'CityX', 23), ('Europe', 'CityY', 17)]
>>> sc = SparkContext(appName="N")
>>> sqlsc = SQLContext(sc)
>>> df = sqlsc.createDataFrame(l)
>>> df.printSchema()
|-- _1: string (nullable = true)
|-- _2: string (nullable = true)
|-- _3: long (nullable = true)
>>> df.registerTempTable("t1")
>>> rdf = sqlsc.sql("Select _1,sum(_3) from t1 group by _1")
>>> rdf.show()
| _1|_c1|
| US|248|
|Europe| 40|
>>> rdf.printSchema()
|-- _1: string (nullable = true)
|-- _c1: long (nullable = true)
>>> rdf.registerTempTable("t2")
>>> sqlsc.sql("Select * from t2 where _c1 > 200").show()
| _1|_c1|
| US|248|
Basically, I am trying to sum _3 (for example, the population subscribed to some service) per country and keep only the countries whose total is above a threshold. The code above creates an additional DataFrame (rdf) to do this.
How do I eliminate the rdf DataFrame and express the whole query against df itself?
I tried this, but PySpark throws an error:
>>> sqlsc.sql("Select _1,sum(_3) from t1 group by _1").show()
| _1|_c1|
| US|248|
|Europe| 40|
>>> sqlsc.sql("Select _1,sum(_3) from t1 group by _1 where _c1 > 200").show()
Traceback (most recent call last):
File "/ghostcache/kimanjun/spark-1.6.0/python/lib/py4j-0.9-src.zip/py4j/protocol.py", line 308, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o28.sql.
: java.lang.RuntimeException: [1.39] failure: ``union'' expected but `where' found
Here is a solution with no kind of temp tables:
# Import Spark's sum under an alias so it doesn't shadow Python's built-in sum
from pyspark.sql.functions import sum as _sum
gDf = df.groupBy(df._1).agg(_sum(df._3).alias('sum'))
gDf.filter(gDf.sum > 200).show()
This groups by _1, sums _3 into a column named sum, and then filters the aggregated DataFrame on that column, all without registering temp tables. Importing Spark's sum as _sum avoids a clash with Python's built-in sum.
I recommend this link for more useful DataFrame operations, which are often more powerful than writing SQL strings directly.
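The original SQL failed because WHERE filters rows before grouping and cannot appear after GROUP BY; a filter on an aggregate belongs in a HAVING clause, which is what the groupBy/agg/filter chain expresses. A sketch with Python's built-in sqlite3 (not Spark) on the question's data, using a table t1 with the question's column names; the alias total is an illustrative choice:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t1 (_1 TEXT, _2 TEXT, _3 INTEGER)")
conn.executemany(
    "INSERT INTO t1 VALUES (?, ?, ?)",
    [("US", "City1", 125), ("US", "City2", 123),
     ("Europe", "CityX", 23), ("Europe", "CityY", 17)],
)

# WHERE filters rows before grouping; HAVING filters groups after aggregation.
result = conn.execute(
    "SELECT _1, SUM(_3) AS total FROM t1 GROUP BY _1 HAVING SUM(_3) > 200"
).fetchall()
print(result)  # [('US', 248)]
```

Spark SQL accepts the same GROUP BY ... HAVING form, so the query can also be run through sqlsc.sql against t1 without creating t2.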
