Issues defining an Aggregator with case class input - apache-spark

I'm trying to define a custom aggregation function that takes a StructType field as input, using the Aggregator API with DataFrames. The Spark version is 3.1.2.
Here's a reduced example (a basic one-field case class, passed in as a Row and reduced with a homemade first() aggregation):
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoder, Row, functions => f}
case class MyClass(
  myField: String
)

class MyFirst extends Aggregator[MyClass, MyClass, MyClass] {
  override def bufferEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def outputEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def zero: MyClass = null
  override def reduce(b: MyClass, a: MyClass): MyClass = {
    if (b == null) return a
    b
  }
  override def merge(b1: MyClass, b2: MyClass): MyClass = {
    reduce(b1, b2)
  }
  override def finish(reduction: MyClass): MyClass = {
    reduction
  }
}
class UDFsTestSuite extends SparkFunSuite with SharedSparkSession {
  test("test myFirst") {
    val input_df = spark.createDataFrame(
      spark.sparkContext.parallelize(
        Seq(
          Row(Row("a")),
          Row(Row("a")),
          Row(Row("b"))
        )
      ),
      StructType(
        List(
          StructField(
            "myClass",
            StructType(
              List(
                StructField("myField", StringType)
              )
            )
          )
        )
      )
    )

    val myFirst = f.udaf(new MyFirst())
    input_df.select(myFirst(f.col("myClass")).as("result")).show()
  }
}
This test throws the analysis error below. For some reason, the analyzer says it expects a String input, seemingly because that is the type of the first field in MyClass (if I define MyClass with an Int field instead, the error message changes to expecting an Int):
cannot resolve 'MyFirst(myClass)' due to data type mismatch: argument 1 requires string type, however, '`myClass`' is of struct<myField:string> type.;
'Aggregate [myfirst(myClass#1, utilities.MyFirst#d08f85a, class[myField[0]: string], class[myField[0]: string], true, true, 0, 0) AS result#6]
+- LogicalRDD [myClass#1], false
org.apache.spark.sql.AnalysisException: cannot resolve 'MyFirst(myClass)' due to data type mismatch: argument 1 requires string type, however, '`myClass`' is of struct<myField:string> type.;
'Aggregate [myfirst(myClass#1, utilities.MyFirst#d08f85a, class[myField[0]: string], class[myField[0]: string], true, true, 0, 0) AS result#6]
+- LogicalRDD [myClass#1], false
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:161)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:152)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:342)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:342)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:408)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:406)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:359)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:408)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:406)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:359)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:339)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$transformExpressionsUp$1(QueryPlan.scala:104)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$1(QueryPlan.scala:116)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:116)
at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:127)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$3(QueryPlan.scala:132)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:132)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$4(QueryPlan.scala:137)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:137)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsUp(QueryPlan.scala:104)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1(CheckAnalysis.scala:152)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1$adapted(CheckAnalysis.scala:93)
at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:184)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis(CheckAnalysis.scala:93)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis$(CheckAnalysis.scala:90)
at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:155)
at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:176)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:228)
at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:173)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:73)
at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:143)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:143)
at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:73)
at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:71)
at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:63)
at org.apache.spark.sql.Dataset$.$anonfun$ofRows$1(Dataset.scala:90)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:88)
at org.apache.spark.sql.Dataset.withPlan(Dataset.scala:3715)
at org.apache.spark.sql.Dataset.select(Dataset.scala:1462)
at utilities.UDFsTestSuite.$anonfun$new$1(UDFsTestSuite.scala:70)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
at org.scalatest.Transformer.apply(Transformer.scala:22)
at org.scalatest.Transformer.apply(Transformer.scala:20)
at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:190)
at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:176)
at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:188)
at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:200)
at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:200)
at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:182)
at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:61)
at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:61)
at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:233)
at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
at scala.collection.immutable.List.foreach(List.scala:431)
at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:233)
at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:232)
at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1563)
at org.scalatest.Suite.run(Suite.scala:1112)
at org.scalatest.Suite.run$(Suite.scala:1094)
at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1563)
at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:237)
at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:237)
at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:236)
at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:61)
at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:61)
at org.scalatest.tools.SuiteRunner.run(SuiteRunner.scala:45)
at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13(Runner.scala:1320)
at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13$adapted(Runner.scala:1314)
at scala.collection.immutable.List.foreach(List.scala:431)
at org.scalatest.tools.Runner$.doRunRunRunDaDoRunRun(Runner.scala:1314)
at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24(Runner.scala:993)
at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24$adapted(Runner.scala:971)
at org.scalatest.tools.Runner$.withClassLoaderAndDispatchReporter(Runner.scala:1480)
at org.scalatest.tools.Runner$.runOptionallyWithPassFailReporter(Runner.scala:971)
at org.scalatest.tools.Runner$.run(Runner.scala:798)
at org.scalatest.tools.Runner.run(Runner.scala)
at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.runScalaTest2or3(ScalaTestRunner.java:38)
at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.main(ScalaTestRunner.java:25)
Oddly, if I just wrap the input objects in an array before passing them to the function, and adjust the function accordingly, it works without issue:
class MyFirst2 extends Aggregator[Seq[MyClass], MyClass, MyClass] {
  override def bufferEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def outputEncoder: Encoder[MyClass] = ExpressionEncoder[MyClass]
  override def zero: MyClass = null
  override def reduce(b: MyClass, a: Seq[MyClass]): MyClass = {
    if (b == null) return a.head
    b
  }
  override def merge(b1: MyClass, b2: MyClass): MyClass = {
    reduce(b1, Seq(b2))
  }
  override def finish(reduction: MyClass): MyClass = {
    reduction
  }
}
import org.apache.spark.sql.Column // needed for the Column parameter below

class UDFsTestSuiteArray extends SparkFunSuite with SharedSparkSession {
  test("test myFirst") {
    val input_df = spark.createDataFrame(
      spark.sparkContext.parallelize(
        Seq(
          Row(Row("a")),
          Row(Row("a")),
          Row(Row("b"))
        )
      ),
      StructType(
        List(
          StructField(
            "myClass",
            StructType(
              List(
                StructField("myField", StringType)
              )
            )
          )
        )
      )
    )

    def myFirst(col: Column) = {
      val func = f.udaf(new MyFirst2())
      func(f.array(col))
    }

    input_df.select(myFirst(f.col("myClass")).as("result")).show()
  }
}
Output:
+------+
|result|
+------+
| {a}| // returns first element as expected
+------+
Any ideas why the first example is not working? Is this a possible Spark bug, or am I just misunderstanding something about how the Aggregator API is meant to work?
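Not an authoritative answer, but the analyzer output above ("argument 1 requires string type") reads as though the UDAF produced by f.udaf flattens the MyClass input encoder into one argument per case class field, so it wants the field itself rather than the struct column. A quick way to test that hypothesis is to pass the nested field directly; this is an untested sketch using the same MyFirst aggregator and input_df from above:

import org.apache.spark.sql.{functions => f}

// Hypothesis check (untested): feed the UDAF the flattened field(s) of MyClass
// instead of the whole struct column, matching the single string argument the
// analyzer says it expects.
val myFirst = f.udaf(new MyFirst())

input_df
  .select(myFirst(f.col("myClass.myField")).as("result"))
  .show()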

Related

Spark SQL doesn't call my UDT equals/hashcode methods

I want to implement my comparison operators (equals, hashCode, ordering) in a data type defined by me in Spark SQL. Although Spark SQL UDTs still remain private, I followed some examples like this to work around the situation.
I have a class called MyPoint:
@SQLUserDefinedType(udt = classOf[MyPointUDT])
case class MyPoint(x: Double, y: Double) extends Serializable {
  override def hashCode(): Int = {
    println("hash code")
    31 * (31 * x.hashCode()) + y.hashCode()
  }
  override def equals(other: Any): Boolean = {
    println("equals")
    other match {
      case that: MyPoint => this.x == that.x && this.y == that.y
      case _ => false
    }
  }
}
Then, I have the UDT class:
private class MyPointUDT extends UserDefinedType[MyPoint] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
  override def serialize(obj: MyPoint): ArrayData = {
    obj match {
      case features: MyPoint =>
        new GenericArrayData2(Array(features.x, features.y))
    }
  }
  override def deserialize(datum: Any): MyPoint = {
    datum match {
      case data: ArrayData if data.numElements() == 2 => {
        val arr = data.toDoubleArray()
        new MyPoint(arr(0), arr(1))
      }
    }
  }
  override def userClass: Class[MyPoint] = classOf[MyPoint]
  override def asNullable: MyPointUDT = this
}
Then I create a simple DataFrame:
val p1 = new MyPoint(1.0, 2.0)
val p2 = new MyPoint(1.0, 2.0)
val p3 = new MyPoint(10.0, 20.0)
val p4 = new MyPoint(11.0, 22.0)
val points = Seq(
  ("P1", p1),
  ("P2", p2),
  ("P3", p3),
  ("P4", p4)
).toDF("label", "point")
points.registerTempTable("points")
spark.sql("SELECT Distinct(point) FROM points").show()
The problem is: why doesn't the SQL query execute the equals method of MyPoint? How are the comparisons being made? How can I implement my comparison operators in this example?
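As an aside, a common explanation (hedged here, since it is not confirmed in this thread) is that Spark SQL compares the UDT's serialized sqlType representation (the array of doubles), so DISTINCT never calls MyPoint.equals or hashCode at all. If the JVM object's semantics are needed, one workaround is to drop to the RDD API, where distinct() does use them; a minimal sketch assuming the points DataFrame above:

// Sketch: RDD.distinct relies on the object's hashCode/equals, so the overrides
// in MyPoint are actually exercised here (unlike SELECT DISTINCT, which works on
// the serialized ArrayType form).
val distinctPoints = points
  .select("point")
  .rdd                        // RDD[Row]; UDT values come back as MyPoint instances
  .map(_.getAs[MyPoint](0))   // pull the deserialized MyPoint out of each Row
  .distinct()                 // triggers the println in equals/hashCode

distinctPoints.collect().foreach(println)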

Spark AnalysisException when "flattening" DataFrame in Spark SQL

I'm using the approach given here to flatten a DataFrame in Spark SQL. Here is my code:
package com.acme.etl.xml
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, SparkSession}
object RuntimeError {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("FlattenSchema").getOrCreate()

    val rowTag = "idocData"
    val dataFrameReader =
      spark.read
        .option("rowTag", rowTag)
    val xmlUri = "bad_011_1.xml"
    val df =
      dataFrameReader
        .format("xml")
        .load(xmlUri)

    val schema: StructType = df.schema
    val columns: Array[Column] = flattenSchema(schema)
    val df2 = df.select(columns: _*)
  }

  def flattenSchema(schema: StructType, prefix: String = null): Array[Column] = {
    schema.fields.flatMap(f => {
      val colName: String = if (prefix == null) f.name else prefix + "." + f.name
      val dataType = f.dataType
      dataType match {
        case st: StructType => flattenSchema(st, colName)
        case _: StringType => Array(new org.apache.spark.sql.Column(colName))
        case _: LongType => Array(new org.apache.spark.sql.Column(colName))
        case _: DoubleType => Array(new org.apache.spark.sql.Column(colName))
        case arrayType: ArrayType => arrayType.elementType match {
          case structType: StructType => flattenSchema(structType, colName)
        }
        case _ => Array(new org.apache.spark.sql.Column(colName))
      }
    })
  }
}
Much of the time, this works fine. But for the XML given below:
<Receive xmlns="http://Microsoft.LobServices.Sap/2007/03/Idoc/3/ORDERS05/ZORDERS5/702/Receive">
  <idocData>
    <E2EDP01008GRP xmlns="http://Microsoft.LobServices.Sap/2007/03/Types/Idoc/3/ORDERS05/ZORDERS5/702">
      <E2EDPT1001GRP>
        <E2EDPT2001>
          <DATAHEADERCOLUMN_DOCNUM>0000000141036013</DATAHEADERCOLUMN_DOCNUM>
        </E2EDPT2001>
        <E2EDPT2001>
          <DATAHEADERCOLUMN_DOCNUM>0000000141036013</DATAHEADERCOLUMN_DOCNUM>
        </E2EDPT2001>
      </E2EDPT1001GRP>
    </E2EDP01008GRP>
    <E2EDP01008GRP xmlns="http://Microsoft.LobServices.Sap/2007/03/Types/Idoc/3/ORDERS05/ZORDERS5/702">
    </E2EDP01008GRP>
  </idocData>
</Receive>
this exception occurs:
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`E2EDP01008GRP`.`E2EDPT1001GRP`.`E2EDPT2001`['DATAHEADERCOLUMN_DOCNUM']' due to data type mismatch: argument 2 requires integral type, however, ''DATAHEADERCOLUMN_DOCNUM'' is of string type.;;
'Project [E2EDP01008GRP#0.E2EDPT1001GRP.E2EDPT2001[DATAHEADERCOLUMN_DOCNUM] AS DATAHEADERCOLUMN_DOCNUM#3, E2EDP01008GRP#0._VALUE AS _VALUE#4, E2EDP01008GRP#0._xmlns AS _xmlns#5]
+- Relation[E2EDP01008GRP#0] XmlRelation(<function0>,Some(/Users/paulreiners/s3/cdi-events-partition-staging/content_acme_purchase_order_json_v1/bad_011_1.xml),Map(rowtag -> idocData, path -> /Users/paulreiners/s3/cdi-events-partition-staging/content_acme_purchase_order_json_v1/bad_011_1.xml),null)
What is causing this?
Your document contains a multi-valued array, so you can't flatten it completely in one pass: you can't give both elements of the array the same column name.
Also, it's usually a bad idea to use a dot within a column name, since it can easily confuse the Spark parser and will need to be escaped at all times.
The usual way to flatten such a dataset is to create a new row for each element of the array.
You can use the explode function to do this, but you will need to call your flatten operation recursively, because explode calls can't be nested.
The following code works as expected, using '_' instead of '.' as the column-name separator:
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.{Dataset, Row}
object RuntimeError {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("FlattenSchema").getOrCreate()

    val rowTag = "idocData"
    val dataFrameReader = spark.read.option("rowTag", rowTag)
    val xmlUri = "bad_011_1.xml"
    val df = dataFrameReader.format("xml").load(xmlUri)

    val df2 = flatten(df)
  }

  def flatten(df: Dataset[Row], prefixSeparator: String = "_"): Dataset[Row] = {
    import org.apache.spark.sql.functions.{col, explode}

    def mustFlatten(sc: StructType): Boolean =
      sc.fields.exists(f => f.dataType.isInstanceOf[ArrayType] || f.dataType.isInstanceOf[StructType])

    def flattenAndExplodeOne(sc: StructType, parent: Column = null, prefix: String = null, cols: Array[(DataType, Column)] = Array[(DataType, Column)]()): Array[(DataType, Column)] = {
      val res = sc.fields.foldLeft(cols)((columns, f) => {
        val my_col = if (parent == null) col(f.name) else parent.getItem(f.name)
        val flat_name = if (prefix == null) f.name else s"${prefix}${prefixSeparator}${f.name}"
        f.dataType match {
          case st: StructType => flattenAndExplodeOne(st, my_col, flat_name, columns)
          case dt: ArrayType => {
            if (columns.exists(_._1.isInstanceOf[ArrayType])) {
              // only one explode per select: keep any further arrays for the next pass
              columns :+ ((dt, my_col.as(flat_name)))
            } else {
              columns :+ ((dt, explode(my_col).as(flat_name)))
            }
          }
          case dt => columns :+ ((dt, my_col.as(flat_name)))
        }
      })
      res
    }

    var flatDf = df
    while (mustFlatten(flatDf.schema)) {
      val newColumns = flattenAndExplodeOne(flatDf.schema, null, null).map(_._2)
      flatDf = flatDf.select(newColumns: _*)
    }
    flatDf
  }
}
The resulting df2 has the following schema and data:
df2.printSchema
root
|-- E2EDP01008GRP_E2EDPT1001GRP_E2EDPT2001_DATAHEADERCOLUMN_DOCNUM: long (nullable = true)
|-- E2EDP01008GRP__xmlns: string (nullable = true)
df2.show(true)
+--------------------------------------------------------------+--------------------+
|E2EDP01008GRP_E2EDPT1001GRP_E2EDPT2001_DATAHEADERCOLUMN_DOCNUM|E2EDP01008GRP__xmlns|
+--------------------------------------------------------------+--------------------+
| 141036013|http://Microsoft....|
| 141036013|http://Microsoft....|
+--------------------------------------------------------------+--------------------+

Spark AccumulatorV2 with HashMap

I am trying to create a custom AccumulatorV2 backed by a hash map: the input would be a HashMap and the output would be a HashMap of HashMaps.
My intention is to have K -> (K1, V), where the value is incremented. I am confused by the Scala syntax for overriding AccumulatorV2 for Map; has anyone had any luck with this?
class CustomAccumulator extends AccumulatorV2[java.util.Map[String, String], java.util.Map[String,java.util.Map[String, Double]]]
I'm assuming that this is the scenario that needs to be implemented.
Input:
HashMap<String, String>
Output:
Should output a HashMap<String, HashMap<String, Double>>, where the second hashmap contains the count of values corresponding to the keys.
Example:
Inputs:(Following HashMaps are added to the accumulator)
Input HashMap1 -> {"key1", "value1"}, {"key2", "value1"}, {"key3", "value3"}
Input HashMap2 -> {"key1", "value1"}, {"key2", "value1"}
Input HashMap3 -> {"key2", "value1"}
Output:
{"key1", {"value1", 2}}, {"key2", {"value1", 3}}, {"key3", {"value3", 1}}
Code below:
import java.util
import java.util.{HashMap, Map}
import java.util.function.BiFunction

import scala.collection.JavaConversions._

import org.apache.spark.util.AccumulatorV2

class CustomAccumulator extends AccumulatorV2[Map[String, String], Map[String, Map[String, Double]]] {

  private var hashmap: Map[String, Map[String, Double]] = new HashMap[String, Map[String, Double]]

  override def isZero: Boolean = {
    hashmap.size() == 0
  }

  override def copy(): AccumulatorV2[util.Map[String, String], util.Map[String, util.Map[String, Double]]] = {
    val customAccumulatorCopy = new CustomAccumulator()
    customAccumulatorCopy.merge(this)
    customAccumulatorCopy
  }

  override def reset(): Unit = {
    this.hashmap = new HashMap[String, Map[String, Double]]
  }

  override def add(v: util.Map[String, String]): Unit = {
    v.foreach(kv => {
      val unitValueDouble: Double = 1
      if (this.hashmap.containsKey(kv._1)) {
        val innerMap = this.hashmap.get(kv._1)
        innerMap.merge(kv._2, unitValueDouble, addFunction)
      } else {
        val innerMap: Map[String, Double] = new HashMap[String, Double]()
        innerMap.put(kv._2, unitValueDouble)
        this.hashmap.put(kv._1, innerMap)
      }
    })
  }

  override def merge(otherAccumulator: AccumulatorV2[util.Map[String, String], util.Map[String, util.Map[String, Double]]]): Unit = {
    otherAccumulator.value.foreach(kv => {
      this.hashmap.merge(kv._1, kv._2, mergeMapsFunction)
    })
  }

  override def value: util.Map[String, util.Map[String, Double]] = {
    this.hashmap
  }

  val mergeMapsFunction = new BiFunction[Map[String, Double], Map[String, Double], Map[String, Double]] {
    override def apply(oldMap: Map[String, Double], newMap: Map[String, Double]): Map[String, Double] = {
      newMap.foreach(kv => {
        oldMap.merge(kv._1, kv._2, addFunction)
      })
      oldMap
    }
  }

  val addFunction = new BiFunction[Double, Double, Double] {
    override def apply(oldValue: Double, newValue: Double): Double = oldValue + newValue
  }
}
Thanks!!!
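For completeness, here is a minimal driver-side usage sketch for an accumulator like this. It is illustrative only: the rdd variable, accumulator name, and key/value strings are made up, and it assumes the CustomAccumulator above compiles as-is.

// Hypothetical usage sketch for the accumulator above.
import java.util.{HashMap => JHashMap}

val acc = new CustomAccumulator()
spark.sparkContext.register(acc, "keyValueCounts") // register before running the action

rdd.foreach { _ =>                                 // rdd is any RDD to iterate over
  val m = new JHashMap[String, String]()
  m.put("key1", "value1")                          // illustrative key/value pair
  acc.add(m)                                       // accumulated per task, merged on the driver
}

println(acc.value)                                 // e.g. {key1={value1=<count>}}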

UnaryTransformer instance throwing ClassCastException

I have a requirement to create my own UnaryTransformer instance that accepts a DataFrame column of type Array[String] and should output the same type. In trying to do so, I encountered a ClassCastException on Spark version 2.1.0.
I've put together a sample test that shows my case.
import org.apache.spark.SparkConf
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
class MyTransformer(override val uid: String) extends UnaryTransformer[Array[String], Array[String], MyTransformer] {

  override protected def createTransformFunc: (Array[String]) => Array[String] = {
    param1 => {
      param1.foreach(println(_))
      param1
    }
  }

  override protected def outputDataType: DataType = ArrayType(StringType)

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == ArrayType(StringType), s"Data type mismatch between Array[String] and provided type $inputType.")
  }

  def this() = this(Identifiable.randomUID("tester"))
}

object Tester {

  def main(args: Array[String]): Unit = {
    val config = new SparkConf().setAppName("Tester")
    implicit val sparkSession = SparkSession.builder().config(config).getOrCreate()
    import sparkSession.implicits._

    val dataframe = Seq(
      Array("Firstly", "F1"),
      Array("Driving", "S1"),
      Array("Ran", "T3"),
      Array("Fourth", "F4"),
      Array("Running", "F5"),
      Array("Gone", "S6")
    ).toDF("input")

    val transformer = new MyTransformer().setInputCol("input").setOutputCol("output")
    val transformed = transformer.transform(dataframe)
    transformed.select("output").show()

    println("Complete....")
    sparkSession.close()
  }
}
Attaching the stack trace for reference
Exception in thread "main" org.apache.spark.SparkException: Failed to execute user defined function($anonfun$createTransformFunc$1: (array) => array)
  at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1072)
  at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:144)
  at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:48)
  at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:30)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.immutable.List.foreach(List.scala:392)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
  at scala.collection.immutable.List.map(List.scala:296)
  at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$$anonfun$apply$21.applyOrElse(Optimizer.scala:1078)
  at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$$anonfun$apply$21.applyOrElse(Optimizer.scala:1073)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:288)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:288)
  at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:287)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
  at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
  at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:277)
  at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$.apply(Optimizer.scala:1073)
  at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$.apply(Optimizer.scala:1072)
  at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:85)
  at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:82)
  at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57)
  at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66)
  at scala.collection.mutable.WrappedArray.foldLeft(WrappedArray.scala:35)
  at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:82)
  at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:74)
  at scala.collection.immutable.List.foreach(List.scala:392)
  at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:74)
  at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:73)
  at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:73)
  at org.apache.spark.sql.execution.QueryExecution.sparkPlan$lzycompute(QueryExecution.scala:79)
  at org.apache.spark.sql.execution.QueryExecution.sparkPlan(QueryExecution.scala:75)
  at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:84)
  at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:84)
  at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2791)
  at org.apache.spark.sql.Dataset.head(Dataset.scala:2112)
  at org.apache.spark.sql.Dataset.take(Dataset.scala:2327)
  at org.apache.spark.sql.Dataset.showString(Dataset.scala:248)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:636)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:595)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:604)
  at Tester$.main(Tester.scala:45)
  at Tester.main(Tester.scala)
Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [Ljava.lang.String;
  at MyTransformer$$anonfun$createTransformFunc$1.apply(Tester.scala:9)
  at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:89)
  at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:88)
  at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1069)
  ... 53 more
ArrayType is represented as a Seq, not an Array:
override protected def createTransformFunc: (Seq[String]) => Seq[String] = {
  param1 => {
    param1.foreach(println(_))
    param1
  }
}
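Presumably (not shown in the original answer) the transformer's type parameters need the same change so the Seq-based function signature lines up with the class; a sketch:

// Sketch: switch the UnaryTransformer type parameters to Seq[String] as well,
// so createTransformFunc's Seq-based signature compiles against the class.
class MyTransformer(override val uid: String)
    extends UnaryTransformer[Seq[String], Seq[String], MyTransformer] {
  // ... createTransformFunc / outputDataType / validateInputType as above,
  //     with Seq[String] wherever Array[String] appeared
}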

Exception when using UDT in Spark DataFrame

I'm trying to create a user-defined type in Spark SQL, but I receive:
com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
even when using their example. Has anyone made this work?
My code:
test("udt serialisation") {
val points = Seq(new ExamplePoint(1.3, 1.6), new ExamplePoint(1.3, 1.8))
val df = SparkContextForStdout.context.parallelize(points).toDF()
}
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
case class ExamplePoint(val x: Double, val y: Double)

/**
 * User-defined type for [[ExamplePoint]].
 */
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
  override def sqlType: DataType = ArrayType(DoubleType, false)
  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
  override def serialize(obj: Any): Seq[Double] = {
    obj match {
      case p: ExamplePoint =>
        Seq(p.x, p.y)
    }
  }
  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
      case values: Seq[_] =>
        val xy = values.asInstanceOf[Seq[Double]]
        assert(xy.length == 2)
        new ExamplePoint(xy(0), xy(1))
      case values: util.ArrayList[_] =>
        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
        new ExamplePoint(xy(0), xy(1))
    }
  }
  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
}
The useful part of the stack trace is this:
com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
java.lang.ClassCastException: com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
at org.apache.spark.sql.SQLContext.createDataFrame(SQLContext.scala:316)
at org.apache.spark.sql.SQLContext$implicits$.rddToDataFrameHolder(SQLContext.scala:254)
It seems that the UDT needs to be used inside another class (as the type of a field) to work. One solution for using it directly is to wrap it in a Tuple1:
test("udt serialisation") {
val points = Seq(new Tuple1(new ExamplePoint(1.3, 1.6)), new Tuple1(new ExamplePoint(1.3, 1.8)))
val df = SparkContextForStdout.context.parallelize(points).toDF()
df.collect().foreach(println(_))
}
