Trying to apply GBT on a data set, getting ClassCastException - apache-spark

I am getting "Exception in thread "main" java.lang.ClassCastException: org.apache.spark.ml.attribute.UnresolvedAttribute$ cannot be cast to org.apache.spark.ml.attribute.NominalAttribute".
Source code
package com.spark.lograthmicregression;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Date;
import java.util.HashSet;
import java.util.Set;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import com.google.common.collect.ImmutableMap;
import scala.collection.mutable.Seq;
public class ClickThroughRateAnalytics {
private static SimpleDateFormat sdf = new SimpleDateFormat("yyMMddHH");
public static void main(String[] args) {
final SparkConf sparkConf = new SparkConf().setAppName("Click Analysis").setMaster("local");
try (JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf)) {
SQLContext sqlContext = new SQLContext(javaSparkContext);
DataFrame dataFrame = sqlContext.read().format("com.databricks.spark.csv").option("inferSchema", "true").option("header", "true")
.load("/splits/sub-suaa");
// This will keep data in memory
dataFrame.cache();
// This will describe the column
// dataFrame.describe("hour").show();
System.out.println("Rows before removing missing data : " + dataFrame.count());
// This will describe column details
// dataFrame.describe("click", "hour", "site_domain").show();
// This will calculate variance between columns +ve one increases
// second increases and -ve means one increases other decreases
// double cov = dataFrame.stat().cov("click", "hour");
// System.out.println("cov : " + cov);
// It provides quantitative measurements of the statistical
// dependence between two random variables
// double corr = dataFrame.stat().corr("click", "hour");
// System.out.println("corr : " + corr);
// Cross Tabulation provides a table of the frequency distribution
// for a set of variables
// dataFrame.stat().crosstab("site_id", "site_domain").show();
// For frequent items
// System.out.println("Frequest Items : " +
// dataFrame.stat().freqItems(new String[] { "site_id",
// "site_domain" }, 0.3).collectAsList());
// TODO we can also set maximum occurring item to categorical
// values.
// This will replace null values with average for numeric columns
dataFrame = modifiyDatFrame(dataFrame);
// Removing rows which have some missing values
dataFrame = dataFrame.na().replace(dataFrame.columns(), ImmutableMap.of("", "NA"));
dataFrame.na().fill(0.0);
dataFrame = dataFrame.na().drop();
System.out.println("Rows after removing missing data : " + dataFrame.count());
// TODO Binning and bucketing
// normalizer will take the column created by the VectorAssembler,
// normalize it and produce a new column
// Normalizer normalizer = new
// Normalizer().setInputCol("features_index").setOutputCol("features");
dataFrame = dataFrame.drop("app_category_index").drop("app_domain_index").drop("hour_index").drop("C20_index")
.drop("device_connection_type_index").drop("C1_index").drop("id").drop("device_ip_index").drop("banner_pos_index");
DataFrame[] splits = dataFrame.randomSplit(new double[] { 0.7, 0.3 });
DataFrame trainingData = splits[0];
DataFrame testData = splits[1];
StringIndexerModel labelIndexer = new StringIndexer().setInputCol("click").setOutputCol("indexedclick").fit(dataFrame);
// Here we will be sending all columns which will participate in
// prediction
VectorAssembler vectorAssembler = new VectorAssembler().setInputCols(findPredictionColumns("click", dataFrame))
.setOutputCol("features_index");
GBTClassifier gbt = new GBTClassifier().setLabelCol("indexedclick").setFeaturesCol("features_index").setMaxIter(10).setMaxBins(69000);
IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel");
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { labelIndexer, vectorAssembler, gbt, labelConverter });
trainingData.show(1);
PipelineModel model = pipeline.fit(trainingData);
DataFrame predictions = model.transform(testData);
predictions.select("predictedLabel", "label").show(5);
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel")
.setPredictionCol("prediction").setMetricName("precision");
double accuracy = evaluator.evaluate(predictions);
System.out.println("Test Error = " + (1.0 - accuracy));
GBTClassificationModel gbtModel = (GBTClassificationModel) (model.stages()[2]);
System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());
}
}
private static String[] findPredictionColumns(String outputCol, DataFrame dataFrame) {
String columns[] = dataFrame.columns();
String inputColumns[] = new String[columns.length - 1];
int count = 0;
for (String column : dataFrame.columns()) {
if (!column.equalsIgnoreCase(outputCol)) {
inputColumns[count++] = column;
}
}
return inputColumns;
}
/**
 * This will replace empty values with the column mean and index the non-numeric columns.
 *
 * @param dataFrame the input DataFrame
 * @return the modified DataFrame
 */
private static DataFrame modifiyDatFrame(DataFrame dataFrame) {
Set<String> numericColumns = new HashSet<String>();
if (dataFrame.numericColumns() != null && dataFrame.numericColumns().length() > 0) {
scala.collection.Iterator<Expression> iterator = ((Seq<Expression>) dataFrame.numericColumns()).toIterator();
while (iterator.hasNext()) {
Expression expression = iterator.next();
Double avgAge = dataFrame.na().drop().groupBy(((AttributeReference) expression).name()).avg(((AttributeReference) expression).name())
.first().getDouble(1);
dataFrame = dataFrame.na().fill(avgAge, new String[] { ((AttributeReference) expression).name() });
numericColumns.add(((AttributeReference) expression).name());
DataType dataType = ((AttributeReference) expression).dataType();
if (!"double".equalsIgnoreCase(dataType.simpleString())) {
dataFrame = dataFrame.withColumn("temp", dataFrame.col(((AttributeReference) expression).name()).cast(DataTypes.DoubleType))
.drop(((AttributeReference) expression).name()).withColumnRenamed("temp", ((AttributeReference) expression).name());
}
}
}
// Fit method of StringIndexer converts the column to StringType(if
// it is not of StringType) and then counts the occurrence of each
// word. It then sorts these words in descending order of their
// frequency and assigns an index to each word. StringIndexer.fit()
// method returns a StringIndexerModel which is a Transformer
StringIndexer stringIndexer = new StringIndexer();
String allColumns[] = dataFrame.columns();
for (String column : allColumns) {
if (!numericColumns.contains(column)) {
dataFrame = stringIndexer.setInputCol(column).setOutputCol(column + "_index").fit(dataFrame).transform(dataFrame);
dataFrame = dataFrame.drop(column);
}
}
dataFrame.printSchema();
return dataFrame;
}
@SuppressWarnings("unused")
private static void copyFile(DataFrame dataFrame) {
dataFrame
.select("id", "click", "hour", "C1", "banner_pos", "site_id", "site_domain", "site_category", "app_id", "app_domain", "app_category",
"device_id", "device_ip", "device_model", "device_type", "device_conn_type", "C14", "C15", "C16", "C17", "C18", "C19", "C20",
"C21")
.write().format("com.databricks.spark.csv").option("header", "true").option("codec", "org.apache.hadoop.io.compress.GzipCodec")
.save("/splits/sub-splitaa-optmized");
}
@SuppressWarnings("unused")
private static Integer parse(String sDate, int field) {
try {
if (sDate != null && !sDate.toString().equalsIgnoreCase("hour")) {
Date date = sdf.parse(sDate.toString());
Calendar cal = Calendar.getInstance();
cal.setTime(date);
return cal.get(field);
}
} catch (ParseException e) {
e.printStackTrace();
}
return 0;
}
}
I am using Spark with Java. A sample of the file is:
id,click,hour,C1,banner_pos,site_id,site_domain,site_category,app_id,app_domain,app_category,device_id,device_ip,device_model,device_type,device_conn_type,C14,C15,C16,C17,C18,C19,C20,C21
1000009418151094273,0,14102100,1005,0,1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,ddd2926e,44956a24,1,2,15706,320,50,1722,0,35,-1,79
10000169349117863715,0,14102100,1005,0,1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,96809ac8,711ee120,1,0,15704,320,50,1722,0,35,100084,79
10000371904215119486,0,14102100,1005,0,1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,b3cf8def,8a4875bd,1,0,15704,320,50,1722,0,35,100084,79
10000640724480838376,0,14102100,1005,0,1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,e8275b8f,6332421a,1,0,15706,320,50,1722,0,35,100084,79
10000679056417042096,0,14102100,1005,1,fe8cc448,9166c161,0569f928,ecad2386,7801e8d9,07d7df22,a99f214a,9644d0bf,779d90c2,1,0,18993,320,50,2161,0,35,-1,157
10000720757801103869,0,14102100,1005,0,d6137915,bb1ef334,f028772b,ecad2386,7801e8d9,07d7df22,a99f214a,05241af0,8a4875bd,1,0,16920,320,50,1899,0,431,100077,117
10000724729988544911,0,14102100,1005,0,8fda644b,25d4cfcd,f028772b,ecad2386,7801e8d9,07d7df22,a99f214a,b264c159,be6db1d7,1,0,20362,320,50,2333,0,39,-1,157

I am late in replying, but I was facing the same error while using GBT on a dataset from a CSV file.
Adding .setLabels(labelIndexer.labels()) to the labelConverter solved the problem: when no labels are set explicitly, IndexToString tries to read the label values from the ML attribute metadata of its input column, and because the GBT prediction column carries no such metadata the attribute resolves to UnresolvedAttribute and the cast to NominalAttribute fails.
IndexToString labelConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels());
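Separately, in the code posted in the question the evaluator is built with setLabelCol("indexedLabel") and the result is selected as "label", while the StringIndexer in the pipeline actually writes "indexedclick" and the raw target column is "click"; those names have to line up for the evaluation step to run. A minimal sketch of that adjustment, assuming the Spark 1.x DataFrame API used in the question (the class and method names here are only for illustration):
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.DataFrame;
class EvaluatorFix {
    // Evaluates the pipeline output against the indexed label column actually
    // produced by the StringIndexer above ("indexedclick"), not "indexedLabel".
    static double testError(DataFrame predictions) {
        // There is no "label" column in this dataset; the raw target column is "click".
        predictions.select("predictedLabel", "click").show(5);
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedclick")
                .setPredictionCol("prediction")
                .setMetricName("precision");
        return 1.0 - evaluator.evaluate(predictions);
    }
}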

Related

Write the final dataset output from Spark Java into S3

I am not able to find the correct way to write a Dataset to S3 from Spark Java. What additional configuration do I need? Do I have to put the AWS configuration in my code, or will it be picked up from the local .aws/ profile?
Please guide.
import java.util.Properties;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class sparkSqlMysql {
private static final org.apache.log4j.Logger LOGGER = org.apache.log4j.Logger.getLogger(sparkSqlMysql.class);
private static final SparkSession sparkSession = SparkSession.builder().master("local[*]").appName("Spark2JdbcDs")
.getOrCreate();
public static void main(String[] args) {
// JDBC connection properties
final Properties connectionProperties = new Properties();
connectionProperties.put("user", "root");
connectionProperties.put("password", "password");
connectionProperties.put("driver", "com.mysql.jdbc.Driver");
final String dbTable = "(select * from Fielding) t";
final String dbTable1 = "(select * from Salaries) m";
final String dbTable2 = "(select * from Pitching) n";
// Load MySQL query result as Dataset
Dataset<Row> jdbcDF2 = sparkSession.read().jdbc("jdbc:mysql://localhost:3306/lahman2016", dbTable,
connectionProperties);
Dataset<Row> jdbcDF3 = sparkSession.read().jdbc("jdbc:mysql://localhost:3306/lahman2016", dbTable1,
connectionProperties);
Dataset<Row> jdbcDF4 = sparkSession.read().jdbc("jdbc:mysql://localhost:3306/lahman2016", dbTable2,
connectionProperties);
jdbcDF2.createOrReplaceTempView("Fielding");
jdbcDF3.createOrReplaceTempView("Salaries");
jdbcDF4.createOrReplaceTempView("Pitching");
Dataset<Row> sqlDF = sparkSession.sql(
"select Salaries.yearID, avg(Salaries.salary) as Fielding from Salaries inner join Fielding ON Salaries.yearID = Fielding.yearID AND Salaries.playerID = Fielding.playerID group by Salaries.yearID limit 5");
Dataset<Row> sqlDF1 = sparkSession.sql(
"select Salaries.yearID, avg(Salaries.salary) as Pitching from Salaries inner join Pitching ON Salaries.yearID = Pitching.yearID AND Salaries.playerID = Pitching.playerID group by Salaries.yearID limit 5");
// sqlDF.show();
// sqlDF1.show();
sqlDF.createOrReplaceTempView("avg_fielding");
sqlDF1.createOrReplaceTempView("avg_pitching");
Dataset<Row> final_query_1_output = sparkSession.sql(
"select avg_fielding.yearID, avg_fielding.Fielding, avg_pitching.Pitching from avg_fielding inner join avg_pitching ON avg_pitching.yearID = avg_fielding.yearID");
final_query_1_output.show();
The output of the query is :
final_query_1_output.show();
+------+------------------+------------------+
|yearID| Fielding| Pitching|
+------+------------------+------------------+
| 1990| 507978.625320787| 485947.2487437186|
| 2003|2216200.9609838845|2133800.1867612293|
| 2007|2633213.0126475547|2617533.3393665156|
| 2015|3996199.5729421354| 3955581.121535181|
| 2006| 2565803.492487479| 2534756.866972477|
+------+------------------+------------------+
I want to write this dataset to S3: how can I do that?
final_query_1_output.write().mode("overwrite").save("s3n://druids3migration/data.csv");
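A minimal sketch of one common approach, assuming the hadoop-aws (s3a) connector is on the classpath; the bucket path and helper names are hypothetical. Credentials can be set on the Hadoop configuration in code, and if the fs.s3a.* keys are left unset the connector falls back to its configured credential providers (for example environment variables); whether the local .aws/ profile is picked up depends on which credential provider chain your hadoop-aws version is configured with.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class S3WriteSketch {
    // Writes a Dataset as CSV to an S3 path through the s3a connector.
    // Setting the keys from environment variables is only one option; omitting
    // these two lines delegates authentication to the connector's credential providers.
    static void writeToS3(SparkSession spark, Dataset<Row> df, String s3Path) {
        spark.sparkContext().hadoopConfiguration()
                .set("fs.s3a.access.key", System.getenv("AWS_ACCESS_KEY_ID"));
        spark.sparkContext().hadoopConfiguration()
                .set("fs.s3a.secret.key", System.getenv("AWS_SECRET_ACCESS_KEY"));
        df.coalesce(1)                      // single part file; drop for large outputs
          .write()
          .mode("overwrite")
          .option("header", "true")
          .csv(s3Path);                     // e.g. "s3a://my-bucket/output" (hypothetical)
    }
}
With something like this in place, the last line above becomes writeToS3(sparkSession, final_query_1_output, "s3a://druids3migration/data.csv"). Note that Spark writes a directory of part files rather than a single data.csv, and the s3a scheme is generally preferred over the older s3n.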

How to sort string characters alphabetically in Pig Latin

If I have a table with a single column (name: chararray), for example:
acab
bca
das
desac
How can I sort the characters of each value in the column 'name' alphabetically in Pig?
Like this:
aabc
abc
ads
acdes
Write your own UDF. Pass the name to the UDF. In the UDF, convert the string to a char array, sort it, and return the sorted string.
PIG
REGISTER ORDER_UDF.jar;
A = LOAD 'data.txt' USING PigStorage(',') AS (name: chararray);
B = FOREACH A GENERATE ORDER_UDF.ORDER(name);
DUMP B;
UDF
import java.io.IOException;
import java.util.Arrays;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.Tuple;
public class ORDER extends EvalFunc<String>
{
public String exec(Tuple input) throws IOException {
if (input == null || input.size() == 0)
return null;
try{
char tempArray[] = ((String) input.get(0)).toCharArray(); // the name is the first field of the tuple
Arrays.sort(tempArray);
return new String(tempArray);
}catch(Exception e){
throw new IOException("Caught exception processing input row ", e);
}
}
}

How to get the list of named ranges, sheet names and reference formulas using XSSF and SAX (Event API) for a large Excel file

I'm trying to read a large Excel file (~10 MB, .xlsx).
I'm using the code below:
Workbook xmlworkbook = WorkbookFactory.create(OPCPackage.openOrCreate(root_path_name_file));
But it fails with a heap memory error.
I have also seen other solutions on Stack Overflow; some of them suggest increasing the JVM heap, but I don't want to do that.
Issue 1) We can't use SXSSF (the streaming usermodel API) because it is only for writing or creating new workbooks.
My sole objective is to get the named ranges, the total number of sheets and their names for a large Excel file.
If the requirement is only to get the named ranges and sheet names, then only /xl/workbook.xml from the *.xlsx ZIP package needs to be parsed, since that information is all stored there.
This is possible by getting the appropriate PackagePart and parsing its XML. For parsing the XML my favorite is StAX.
Example code which gets all sheet names and defined named ranges:
import org.apache.poi.openxml4j.opc.OPCPackage;
import org.apache.poi.openxml4j.opc.PackagePart;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.EndElement;
import javax.xml.stream.events.Characters;
import javax.xml.stream.events.Attribute;
import javax.xml.stream.events.XMLEvent;
import javax.xml.namespace.QName;
import java.io.File;
import java.util.regex.Pattern;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.HashMap;
class StaxReadOPCPackageParts {
public static void main(String[] args) {
try {
File file = new File("file.xlsx");
OPCPackage opcpackage = OPCPackage.open(file);
//get the workbook package part
PackagePart workbookpart = opcpackage.getPartsByName(Pattern.compile("/xl/workbook.xml")).get(0);
//create reader for package part
XMLEventReader reader = XMLInputFactory.newInstance().createXMLEventReader(workbookpart.getInputStream());
List<String> sheetNames = new ArrayList<>();
Map<String, String> definedNames = new HashMap<>();
boolean isInDefinedName = false;
String sheetName = "";
String definedNameName = "";
StringBuffer definedNameFormula = new StringBuffer();
while(reader.hasNext()){ //loop over all XML in workbook.xml
XMLEvent event = (XMLEvent)reader.next();
if(event.isStartElement()) {
StartElement startElement = (StartElement)event;
QName startElementName = startElement.getName();
if(startElementName.getLocalPart().equalsIgnoreCase("sheet")) { //start element of sheet definition
Attribute attribute = startElement.getAttributeByName(new QName("name"));
sheetName = attribute.getValue();
sheetNames.add(sheetName);
} else if (startElementName.getLocalPart().equalsIgnoreCase("definedName")) { //start element of definedName
Attribute attribute = startElement.getAttributeByName(new QName("name"));
definedNameName = attribute.getValue();
isInDefinedName = true;
}
} else if(event.isCharacters() && isInDefinedName) { //character content of definedName == the formula
definedNameFormula.append(((Characters)event).getData());
} else if(event.isEndElement()) {
EndElement endElement = (EndElement)event;
QName endElementName = endElement.getName();
if(endElementName.getLocalPart().equalsIgnoreCase("definedName")) { //end element of definedName
definedNames.put(definedNameName, definedNameFormula.toString());
definedNameFormula = new StringBuffer();
isInDefinedName = false;
}
}
}
opcpackage.close();
System.out.println("Sheet names:");
for (String shName : sheetNames) {
System.out.println("Sheet name: " + shName);
}
System.out.println("Named ranges:");
for (String defName : definedNames.keySet()) {
System.out.println("Name: " + defName + ", Formula: " + definedNames.get(defName));
}
} catch (Exception ex) {
ex.printStackTrace();
}
}
}

Combining Two Spark Streams On Key

I have two Kafka streams that contain results from two parallel operations, and I need a way to combine both streams so I can process the results in a single Spark transform. Is this possible? (Illustration below.)
Stream 1 {id:1,result1:True}
Stream 2 {id:1,result2:False}
JOIN(Stream 1, Stream 2, On "id") -> Output Stream {id:1,result1:True,result2:False}
Current code that isn't working:
kvs1 = KafkaUtils.createStream(sparkstreamingcontext, ZOOKEEPER, NAME+"_stream", {"test_join_1": 1})
kvs2 = KafkaUtils.createStream(sparkstreamingcontext, ZOOKEEPER, NAME+"_stream", {"test_join_2": 1})
messages_RDDstream1 = kvs1.map(lambda x: x[1])
messages_RDDstream2 = kvs2.map(lambda x: x[1])
messages_RDDstream_Final = messages_RDDstream1.join(messages_RDDstream2)
When I pass two sample JSONs with the same id field to the two Kafka queues, nothing is returned in my final RDD stream. I imagine I am missing the stage of converting my Kafka JSON string message into a tuple?
I have also tried the following:
kvs1.map(lambda (key, value): json.loads(value))
and
kvs1.map(lambda x: json.loads(x))
To no avail
Cheers
Adam
A simple lookup in Spark's documentation would have given you the answer.
You can use the join operation.
join(otherStream, [numTasks]) :
When called on two DStreams of (K, V) and (K, W) pairs, return a new DStream of (K, (V, W)) pairs with all pairs of elements for each key.
For example : val streamJoined = stream1.join(stream2)
What you need can be done using the join() method of key-value pair DStreams:
// Test data
val input1 = List((1, true), (2, false), (3, false), (4, true), (5, false))
val input2 = List((1, false), (2, false), (3, true), (4, true), (5, true))
val input1RDD = sc.parallelize(input1)
val input2RDD = sc.parallelize(input2)
import org.apache.spark.streaming.{Seconds, StreamingContext}
val streamingContext = new StreamingContext(sc, Seconds(3))
// Creates a DStream from the test data
import scala.collection.mutable
val input1DStream = streamingContext.queueStream[(Int, Boolean)](mutable.Queue(input1RDD))
val input2DStream = streamingContext.queueStream[(Int, Boolean)](mutable.Queue(input2RDD))
// Join the two streams together by merging them into a single dstream
val joinedDStream = input1DStream.join(input2DStream)
// Print the result
joinedDStream.print()
// Start the context, time out after one batch, and then stop it
streamingContext.start()
streamingContext.awaitTerminationOrTimeout(5000)
streamingContext.stop()
Results in:
-------------------------------------------
Time: 1468313607000 ms
-------------------------------------------
(4,(true,true))
(2,(false,false))
(1,(true,false))
(3,(false,true))
(5,(false,true))
I have joined two queueStreams using Spark with Java. Please have a look at the code below.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Queue;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaInputDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import com.google.common.collect.Queues;
import scala.Tuple2;
public class SparkQueueStreamJoin {
public static void main(String[] args) throws InterruptedException {
// Test data
List<Pair<Integer, Boolean>> input1 = Arrays.asList(Pair.of(1,true), Pair.of(2,false), Pair.of(3,false), Pair.of(4,true), Pair.of(5,false));
List<Pair<Integer, Boolean>> input2 = Arrays.asList(Pair.of(1,false), Pair.of(2,false), Pair.of(3,true), Pair.of(4,true), Pair.of(5,true));
SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("SparkQueueStreamJoin ")
.set("spark.testing.memory", "2147480000");
//System.setProperty("hadoop.home.dir", "C:/Hadoop/hadoop-2.7.1");
JavaSparkContext sc = new JavaSparkContext(conf);
JavaRDD<Pair<Integer, Boolean>> input1RDD = sc.parallelize(input1);
JavaRDD<Pair<Integer, Boolean>> input2RDD = sc.parallelize(input2);
JavaStreamingContext streamingContext = new JavaStreamingContext(sc, Durations.seconds(3));
Queue<JavaRDD<Pair<Integer, Boolean>>> queue1RDD = Queues.newLinkedBlockingQueue();
queue1RDD.add(input1RDD);
Queue<JavaRDD<Pair<Integer, Boolean>>> queue2RDD = Queues.newLinkedBlockingQueue();
queue2RDD.add(input2RDD);
// Creates a DStream from the test data
JavaInputDStream<Pair<Integer, Boolean>> input1DStream = streamingContext.queueStream(queue1RDD, false);
JavaInputDStream<Pair<Integer, Boolean>> input2DStream = streamingContext.queueStream(queue2RDD, false);
JavaPairDStream<Integer, Boolean> pair1DStream = input1DStream.mapToPair(new PairFunction<Pair<Integer, Boolean>, Integer, Boolean>() {
@Override
public Tuple2<Integer, Boolean> call(Pair<Integer, Boolean> rawEvent) throws Exception {
return new Tuple2<>(rawEvent.getKey(), rawEvent.getValue());
}
});
JavaPairDStream<Integer, Boolean> pair2DStream = input2DStream.mapToPair(new PairFunction<Pair<Integer, Boolean>, Integer, Boolean>() {
@Override
public Tuple2<Integer, Boolean> call(Pair<Integer, Boolean> rawEvent) throws Exception {
return new Tuple2<>(rawEvent.getKey(), rawEvent.getValue());
}
});
// Union two streams together by merging them into a single dstream
//JavaDStream<Pair<Integer, Boolean>> joinedDStream = input1DStream.union(input2DStream);
// Join the two streams together by merging them into a single dstream
JavaPairDStream<Integer, Tuple2<Boolean, Boolean>> joinedDStream = pair1DStream.join(pair2DStream);
// Print the result
joinedDStream.print();
// Start the context, time out after one batch, and then stop it
streamingContext.start();
streamingContext.awaitTerminationOrTimeout(5000);
streamingContext.stop();
}
}
Output:
-------------------------------------------
Time: 1511444352000 ms
-------------------------------------------
(1,(true,false))
(2,(false,false))
(3,(false,true))
(4,(true,true))
(5,(false,true))
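For the original Kafka/JSON case in the question, the step that seems to be missing is keying each stream by the "id" field before the join: join() only matches (key, value) pairs with equal keys that arrive in the same batch. A rough Java sketch under those assumptions (Jackson for JSON parsing and the receiver-based KafkaUtils.createStream API, with the topic names taken from the question; class and method names are only for illustration):
import java.util.Collections;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka.KafkaUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import scala.Tuple2;
public class KafkaJsonJoinSketch {
    // Re-keys a (kafkaKey, jsonString) stream as (id, jsonString) so join() can match on "id".
    static JavaPairDStream<String, String> keyById(JavaPairReceiverInputDStream<String, String> stream) {
        return stream.mapToPair(record ->
                new Tuple2<>(new ObjectMapper().readTree(record._2()).get("id").asText(), record._2()));
    }
    static JavaPairDStream<String, Tuple2<String, String>> joinById(JavaStreamingContext jssc,
                                                                    String zookeeper, String group) {
        JavaPairReceiverInputDStream<String, String> stream1 =
                KafkaUtils.createStream(jssc, zookeeper, group, Collections.singletonMap("test_join_1", 1));
        JavaPairReceiverInputDStream<String, String> stream2 =
                KafkaUtils.createStream(jssc, zookeeper, group, Collections.singletonMap("test_join_2", 1));
        // Records sharing the same "id" within a batch come out as (id, (json1, json2)).
        return keyById(stream1).join(keyById(stream2));
    }
}
Note that a plain join only matches records that land in the same batch interval; if the two results can arrive in different batches, a windowed join or a stateful operation such as updateStateByKey is needed.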

Exception when using UDT in Spark DataFrame

I'm trying to create a user-defined type in Spark SQL, but I receive:
com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType even when following their example. Has anyone made this work?
My code:
test("udt serialisation") {
val points = Seq(new ExamplePoint(1.3, 1.6), new ExamplePoint(1.3, 1.8))
val df = SparkContextForStdout.context.parallelize(points).toDF()
}
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
case class ExamplePoint(val x: Double, val y: Double)
/**
* User-defined type for [[ExamplePoint]].
*/
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
override def sqlType: DataType = ArrayType(DoubleType, false)
override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
override def serialize(obj: Any): Seq[Double] = {
obj match {
case p: ExamplePoint =>
Seq(p.x, p.y)
}
}
override def deserialize(datum: Any): ExamplePoint = {
datum match {
case values: Seq[_] =>
val xy = values.asInstanceOf[Seq[Double]]
assert(xy.length == 2)
new ExamplePoint(xy(0), xy(1))
case values: util.ArrayList[_] =>
val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
new ExamplePoint(xy(0), xy(1))
}
}
override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
}
The useful part of the stack trace is this:
com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
java.lang.ClassCastException: com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
at org.apache.spark.sql.SQLContext.createDataFrame(SQLContext.scala:316)
at org.apache.spark.sql.SQLContext$implicits$.rddToDataFrameHolder(SQLContext.scala:254)
It seems that the UDT needs to be used inside another class to work (as the type of a field): createDataFrame expects the top-level schema to be a StructType, and the UDT on its own is not one. One way to use it directly is to wrap it in a Tuple1:
test("udt serialisation") {
val points = Seq(new Tuple1(new ExamplePoint(1.3, 1.6)), new Tuple1(new ExamplePoint(1.3, 1.8)))
val df = SparkContextForStdout.context.parallelize(points).toDF()
df.collect().foreach(println(_))
}
