Spark SQL: How to call UDF from DataFrame operation using JAVA - apache-spark

I would like to know how to call a UDF from the domain-specific language (DSL) functions of Spark SQL using Java.
I have a UDF (just as an example):
UDF2 equals = new UDF2<String, String, Boolean>() {
    @Override
    public Boolean call(String first, String second) throws Exception {
        return first.equals(second);
    }
};
I've registered it with the sqlContext:
sqlContext.udf().register("equals", equals, DataTypes.BooleanType);
When I run the following query, my UDF is called and I get a result.
sqlContext.sql("SELECT p0.value FROM values p0 WHERE equals(p0.value, 'someString')");
I would like to transform this query using the domain-specific-language functions of Spark SQL, but I am not sure how to do it.
valuesDF.select("value").where(???);
I found that there is a callUDF() function, but one of its parameters is a Function2 fnctn, not a UDF2.
How can I use my UDF with the DSL functions?

I found a solution with which I am half-satisfied.
It is possible to call the UDF as a column condition, such as:
valuesDF.filter("equals(columnName, 'someString')").select("columnName");
But I still wonder whether it is possible to call the UDF directly.
Edit:
By the way, it is possible to call the UDF directly, e.g.:
df.where(callUdf("equals", scala.collection.JavaConversions.asScalaBuffer(
Arrays.asList(col("columnName"), col("otherColumnName"))
).seq())).select("columnName");
An import of org.apache.spark.sql.functions is required.
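As a side note, a simpler hedged variant: from Spark 1.5 on, functions.callUDF should accept the column arguments as Java varargs, which avoids the JavaConversions dance (the column names here are just placeholders):
import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;

DataFrame result = df
        .where(callUDF("equals", col("columnName"), col("otherColumnName")))
        .select("columnName");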

When querying a dataframe, you should just be able to execute the UDF using something like this:
sourceDf.filter(equals(col("columnName"), "someString")).select("columnName")
where col("columnName") is the column you want to compare.

Here is a working code example. It works with Spark 1.5.x and 1.6.x. The trick to calling UDFs from within a pipeline transformer is to use the sqlContext() of the DataFrame to register your UDF:
@Test
public void test() {
    // https://issues.apache.org/jira/browse/SPARK-12484
    logger.info("BEGIN");
    DataFrame df = createData();
    final String tableName = "myTable";
    sqlContext.registerDataFrameAsTable(df, tableName);
    logger.info("print schema");
    df.printSchema();
    logger.info("original data before we applied UDF");
    df.show();
    MyUDF udf = new MyUDF();
    final String udfName = "myUDF";
    sqlContext.udf().register(udfName, udf, DataTypes.StringType);
    String fmt = "SELECT *, %s(%s) as transformedByUDF FROM %s";
    String stmt = String.format(fmt, udfName, tableName + ".labelStr", tableName);
    logger.info("AEDWIP stmt:{}", stmt);
    DataFrame udfDF = sqlContext.sql(stmt);
    Row[] results = udfDF.head(3);
    for (Row row : results) {
        logger.info("row returned by applying UDF {}", row);
    }
    logger.info("AEDWIP udfDF schema");
    udfDF.printSchema();
    logger.info("AEDWIP udfDF data");
    udfDF.show();
    logger.info("END");
}
DataFrame createData() {
    Features f1 = new Features(1, category1);
    Features f2 = new Features(2, category2);
    ArrayList<Features> data = new ArrayList<Features>(2);
    data.add(f1);
    data.add(f2);
    //JavaRDD<Features> rdd = javaSparkContext.parallelize(Arrays.asList(f1, f2));
    JavaRDD<Features> rdd = javaSparkContext.parallelize(data);
    DataFrame df = sqlContext.createDataFrame(rdd, Features.class);
    return df;
}
class MyUDF implements UDF1<String, String> {
    private static final long serialVersionUID = 1L;

    @Override
    public String call(String s) throws Exception {
        logger.info("AEDWIP s:{}", s);
        String ret = s.equalsIgnoreCase(category1) ? category1 : category3;
        return ret;
    }
}
public class Features implements Serializable {
    private static final long serialVersionUID = 1L;
    int id;
    String labelStr;

    Features(int id, String l) {
        this.id = id;
        this.labelStr = l;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    public String getLabelStr() {
        return labelStr;
    }

    public void setLabelStr(String labelStr) {
        this.labelStr = labelStr;
    }
}
This is the output:
+---+--------+
| id|labelStr|
+---+--------+
| 1| noise|
| 2| ack|
+---+--------+
root
|-- id: integer (nullable = false)
|-- labelStr: string (nullable = true)
|-- transformedByUDF: string (nullable = true)
+---+--------+----------------+
| id|labelStr|transformedByUDF|
+---+--------+----------------+
| 1| noise| noise|
| 2| ack| signal|
+---+--------+----------------+

Related

Transform PCollection<KV> to custom class

My goal is to read a file from GCS and write it to Cassandra.
I am new to Apache Beam/Dataflow, and most of the hands-on examples I could find are built with Python. Unfortunately, CassandraIO in Beam is Java-only.
I used the word count example as a template and tried to replace the TextIO.write() with a CassandraIO.<Words>write().
Here is my Java class for the Cassandra table:
package org.apache.beam.examples;

import java.io.Serializable;

import com.datastax.driver.mapping.annotations.Column;
import com.datastax.driver.mapping.annotations.PartitionKey;
import com.datastax.driver.mapping.annotations.Table;

@Table(keyspace = "test", name = "words", readConsistency = "ONE", writeConsistency = "QUORUM",
        caseSensitiveKeyspace = false, caseSensitiveTable = false)
public class Words implements Serializable {
    // private static final long serialVersionUID = 1L;
    @PartitionKey
    @Column(name = "word")
    public String word;

    @Column(name = "count")
    public long count;

    public Words() {
    }

    public Words(String word, int count) {
        this.word = word;
        this.count = count;
    }

    @Override
    public boolean equals(Object obj) {
        Words other = (Words) obj;
        return this.word.equals(other.word) && this.count == other.count;
    }
}
And here is the pipeline part of the main code:
static void runWordCount(WordCount.WordCountOptions options) {
    Pipeline p = Pipeline.create(options);
    // Concepts #2 and #3: Our pipeline applies the composite CountWords transform, and passes the
    // static FormatAsTextFn() to the ParDo transform.
    p.apply("ReadLines", TextIO.read().from(options.getInputFile()))
        .apply(new WordCountToCassandra.CountWords())
        // Here I'm not sure how to transform PCollection<KV> into PCollection<Words>
        .apply(MapElements.into(TypeDescriptor.of(Words.class)).via(PCollection<KV<String, Long>>)
        }))
        .apply(CassandraIO.<Words>write()
            .withHosts(Collections.singletonList("my_ip"))
            .withPort(9142)
            .withKeyspace("test")
            .withEntity(Words.class));
    p.run().waitUntilFinish();
}
My understanding is that I should use a PTransform to go from PCollection<T1> to PCollection<T2>, but I don't know how to write that mapping.
If it's 1:1 mapping, MapElements.into is the right choice.
You can either specify a class that implements SerializableFunction<FromType, ToType>, or simply use a lambda, for example:
.apply(MapElements.into(TypeDescriptor.of(Words.class)).via(kv -> new Words(kv.getKey(), kv.getValue().intValue())));
Please check MapElements for more information.
If the transformation is not one-to-one, there are other available options such as FlatMapElements or ParDo.
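For completeness, a hedged sketch of the class-based variant mentioned above (KvToWords is an assumed name; it uses the Words(String, int) constructor from the question, hence the intValue()):
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TypeDescriptor;

// inside your pipeline class:
static class KvToWords implements SerializableFunction<KV<String, Long>, Words> {
    @Override
    public Words apply(KV<String, Long> kv) {
        return new Words(kv.getKey(), kv.getValue().intValue());
    }
}

// ...and in the pipeline:
.apply(MapElements.into(TypeDescriptor.of(Words.class)).via(new KvToWords()))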

Spark CSV load - custom schema - custom object

Encoder<Transaction> encoder = Encoders.bean(Transaction.class);
Dataset<Transaction> transactionDS = sparkSession
    .read()
    .format("csv")
    .option("header", true)
    .option("delimiter", ",")
    .option("enforceSchema", false)
    .option("multiLine", false)
    .schema(encoder.schema())
    .load("s3a://xxx/testSchema.csv")
    .as(encoder);
System.out.println("==============schema starts============");
transactionDS.printSchema();
System.out.println("==============schema ends============");
transactionDS.show(10, true); // this is the line that bombs.
My CSV is this -
transactionId,accountId
1,2
10,44
I'm printing my schema in the logs - (you see, the columns are now flipped, or sorted - Ah!)
==============schema starts============
root
|-- accountId: integer (nullable = true)
|-- transactionId: long (nullable = true)
==============schema ends============
I'm getting the error below:
Caused by: java.lang.IllegalArgumentException: CSV header does not conform to the schema.
Header: transactionId, accounted
Schema: accountId, transactionId
Expected: accountId but found: transactionId
This is what my Transaction class looks like.
public class Transaction implements Serializable {
    private static final long serialVersionUID = 7648268336292069686L;
    private Long transactionId;
    private Integer accountId;

    public Long getTransactionId() {
        return transactionId;
    }

    public void setTransactionId(Long transactionId) {
        this.transactionId = transactionId;
    }

    public Integer getAccountId() {
        return accountId;
    }

    public void setAccountId(Integer accountId) {
        this.accountId = accountId;
    }
}
Question - Why is Spark not able to match my schema? The ordering is messed up. In my CSV I'm passing transactionId, accountId, but Spark reads my schema as accountId, transactionId. Ah!
Do not use encoder.schema() to load the CSV file; its column order may not match the CSV.
Unlike Parquet, CSV does not carry a schema, so Spark will not apply the correct column order. What you can do is read the CSV without:
.schema(encoder.schema())
Then apply the schema to the dataset that you just created.
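For illustration, a minimal sketch of that suggestion (assuming Spark 2.x, and that the CSV header names match the Transaction bean properties; the explicit casts are mine):
import static org.apache.spark.sql.functions.col;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.types.DataTypes;

Dataset<Transaction> transactionDS = sparkSession
        .read()
        .option("header", true)
        .csv("s3a://xxx/testSchema.csv")                      // read using the file's own header order
        .withColumn("transactionId", col("transactionId").cast(DataTypes.LongType))
        .withColumn("accountId", col("accountId").cast(DataTypes.IntegerType))
        .as(Encoders.bean(Transaction.class));                // then map onto the bean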
This is what I ended up doing -
Encoder<Transaction> encoder = Encoders.bean(Transaction.class);
// read data from S3
System.out.println("Going to read file......................................");
Dataset<Transaction> transactionDS = sparkSession
    .read()
    .format("csv")
    .option("header", true)
    .option("delimiter", ",")
    //.option("enforceSchema", false)
    .option("inferSchema", false)
    .option("dateFormat", "yyyy-MM-dd")
    //.option("multiLine", false)
    //.schema(encoder.schema())
    .schema(_createSchema())
    .csv("s3a://xxx/transactions_4_with_column_names.csv")
    .as(encoder);
The _createSchema() function is below -
private static StructType _createSchema() {
    List<StructField> list = new ArrayList<StructField>() {
        private static final long serialVersionUID = -4953991596584287923L;
        {
            add(DataTypes.createStructField("transactionId", DataTypes.LongType, true));
            add(DataTypes.createStructField("accountId", DataTypes.IntegerType, true));
            add(DataTypes.createStructField("destAccountId", DataTypes.IntegerType, true));
            add(DataTypes.createStructField("destPostDate", DataTypes.DateType, true));
        }
    };
    return new StructType(list.toArray(new StructField[0]));
}

How to get the taskID or mapperID(something like partitionID in Spark) in a hive UDF?

As the title asks: how can I get the taskID or mapperID (something like the partitionID in Spark) in a Hive UDF?
You can access task information using TaskContext:
import org.apache.spark.TaskContext
sc.parallelize(Seq[Int](), 4).mapPartitions(_ => {
    val ctx = TaskContext.get
    val stageId = ctx.stageId
    val partId = ctx.partitionId
    val hostname = java.net.InetAddress.getLocalHost().getHostName()
    Iterator(s"Stage: $stageId, Partition: $partId, Host: $hostname")
}).collect.foreach(println)
A similar functionality has been added to PySpark in Spark 2.2.0 (SPARK-18576):
from pyspark import TaskContext
import socket

def task_info(*_):
    ctx = TaskContext()
    return ["Stage: {0}, Partition: {1}, Host: {2}".format(
        ctx.stageId(), ctx.partitionId(), socket.gethostname())]

for x in sc.parallelize([], 4).mapPartitions(task_info).collect():
    print(x)
I think this will provide you the task information, including the map id, that you are looking for.
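For reference, a hedged Java sketch of the same idea (assuming Spark 2.x, where mapPartitions takes a FlatMapFunction returning an Iterator; jsc is an existing JavaSparkContext):
import java.util.Arrays;
import java.util.Collections;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaSparkContext;

// jsc is an existing JavaSparkContext (assumed)
jsc.parallelize(Arrays.asList(1, 2, 3, 4), 4)
   .mapPartitions(it -> {
       TaskContext ctx = TaskContext.get();
       return Collections.singletonList(
               "Stage: " + ctx.stageId() + ", Partition: " + ctx.partitionId()).iterator();
   })
   .collect()
   .forEach(System.out::println);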
I have found the correct answer on my own; we can get the taskID in a Hive UDF as below:
public class TestUDF extends GenericUDF {
    private Text result = new Text();
    private String tmpStr = "";

    @Override
    public void configure(MapredContext context) {
        // get the total number of tasks
        int numTasks = context.getJobConf().getNumMapTasks();
        // get the current taskID
        String taskID = context.getJobConf().get("mapred.task.id");
        this.tmpStr = numTasks + "_h_xXx_h_" + taskID;
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments)
            throws UDFArgumentException {
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) {
        result.set(this.tmpStr);
        return this.result;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "RowSeq-func()";
    }
}
But this is effective only with the MapReduce execution engine; it does not work with the Spark SQL engine.
Test code is below:
add jar hdfs:///home/dp/tmp/shaw/my_udf.jar;
create temporary function seqx AS 'com.udf.TestUDF';
with core as (
select
device_id
from
test_table
where
p_date = '20210309'
and product = 'google'
distribute by
device_id
)
select
seqx() as seqs,
count(1) as cc
from
core
group by
seqx()
order by
seqs asc
Result in the MR engine: we get the task count and the taskID successfully.
Result in the Spark engine with the same SQL: the UDF does not take effect and we get nothing about the taskID.
If you run your HQL on the Spark engine, call the Hive UDF there, and really need to get the partitionId in Spark, see the code below:
import org.apache.spark.TaskContext;

public class TestUDF extends GenericUDF {
    private Text result = new Text();
    private String tmpStr = "";

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments)
            throws UDFArgumentException {
        // get the Spark partitionId
        this.tmpStr = TaskContext.getPartitionId() + "-initial-pid";
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    public Object evaluate(DeferredObject[] arguments) {
        // get the Spark partitionId
        this.tmpStr = TaskContext.getPartitionId() + "-evaluate-pid";
        result.set(this.tmpStr);
        return this.result;
    }
}
As shown above, you can get the Spark partitionId by calling TaskContext.getPartitionId() in the overridden initialize or evaluate method of the UDF class.
Notice: your UDF must take a parameter, e.g. select my_udf(param); this makes the UDF get initialized inside multiple tasks. If your UDF takes no parameter, it is initialized on the Driver, and the Driver has no TaskContext or partitionId, so you would get nothing.
The result produced by the above UDF executed on the Spark engine shows that we get the partitionIds successfully.

How to create a Spark DataFrame from an Integer RDD

How can I create a DataFrame from a JavaRDD containing Integers? I have tried something like the code below, but it does not work.
List<Integer> input = Arrays.asList(101, 103, 105);
JavaRDD<Integer> inputRDD = sc.parallelize(input);
DataFrame dataframe = sqlcontext.createDataFrame(inputRDD, Integer.class);
I got a ClassCastException saying org.apache.spark.sql.types.IntegerType$ cannot be cast to org.apache.spark.sql.types.StructType.
How can I achieve this?
Apparently (although not intuitively), this createDataFrame overload can only work for "Bean" types, which means types that do not correspond to any built-in Spark SQL type.
You can see in the source code that the class you pass is matched with a Spark SQL type in JavaTypeInference.inferDataType, and the result is cast into a StructType (see dataType.asInstanceOf[StructType] in SQLContext.getSchema) - but the built-in "primitive" types (like IntegerType) are NOT StructTypes... Looks like a bug or undocumented behavior to me.
WORKAROUNDS:
Wrap your Integers with a "bean" class (that's ugly, I know):
public static class MyBean {
    final int value;

    MyBean(int value) {
        this.value = value;
    }

    public int getValue() {
        return value;
    }
}

List<MyBean> input = Arrays.asList(new MyBean(101), new MyBean(103), new MyBean(105));
JavaRDD<MyBean> inputRDD = sc.parallelize(input);
DataFrame dataframe = sqlcontext.createDataFrame(inputRDD, MyBean.class);
dataframe.show(); // this works...
Convert to RDD<Row> yourself:
// convert to Rows:
JavaRDD<Row> rowRdd = inputRDD.map(new Function<Integer, Row>() {
    @Override
    public Row call(Integer v1) throws Exception {
        return RowFactory.create(v1);
    }
});

// create schema (this looks nicer in Scala...):
StructType schema = new StructType(new StructField[]{
    new StructField("number", IntegerType$.MODULE$, false, Metadata.empty())});

DataFrame dataframe = sqlcontext.createDataFrame(rowRdd, schema);
dataframe.show(); // this works...
Now, in Spark 2.2, you can do the following to create a Dataset:
Dataset<Integer> dataSet = sqlContext().createDataset(javardd.rdd(), Encoders.INT());
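Alternatively, a hedged sketch with SparkSession (Spark 2.x): if the data starts as a local list, createDataset can take it directly, and toDF can give the single column a name:
import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

SparkSession spark = SparkSession.builder().appName("int-dataset").master("local[*]").getOrCreate();
Dataset<Integer> ds = spark.createDataset(Arrays.asList(101, 103, 105), Encoders.INT());
ds.toDF("number").show();   // DataFrame with one named integer column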

How to execute ranged query in cassandra with astyanax and composite column

I am developing a blog using Cassandra and Astyanax. It is only an exercise, of course.
I have modelled the CF_POST_INFO column family in this way:
private static class PostAttribute {
    @Component(ordinal = 0)
    UUID postId;
    @Component(ordinal = 1)
    String category;
    @Component
    String name;

    public PostAttribute() {}

    private PostAttribute(UUID postId, String category, String name) {
        this.postId = postId;
        this.category = category;
        this.name = name;
    }

    public static PostAttribute of(UUID postId, String category, String name) {
        return new PostAttribute(postId, category, name);
    }
}

private static AnnotatedCompositeSerializer<PostAttribute> postSerializer = new AnnotatedCompositeSerializer<>(PostAttribute.class);

private static final ColumnFamily<String, PostAttribute> CF_POST_INFO =
        ColumnFamily.newColumnFamily("post_info", StringSerializer.get(), postSerializer);
And a post is saved in this way:
MutationBatch m = keyspace().prepareMutationBatch();
ColumnListMutation<PostAttribute> clm = m.withRow(CF_POST_INFO, "posts")
        .putColumn(PostAttribute.of(post.getId(), "author", "id"), post.getAuthor().getId().get())
        .putColumn(PostAttribute.of(post.getId(), "author", "name"), post.getAuthor().getName())
        .putColumn(PostAttribute.of(post.getId(), "meta", "title"), post.getTitle())
        .putColumn(PostAttribute.of(post.getId(), "meta", "pubDate"), post.getPublishingDate().toDate());

for (String tag : post.getTags()) {
    clm.putColumn(PostAttribute.of(post.getId(), "tags", tag), (String) null);
}
for (String category : post.getCategories()) {
    clm.putColumn(PostAttribute.of(post.getId(), "categories", category), (String) null);
}
The idea is to have rows acting as time buckets (one row per month or per year, for example).
Now, if I want to get the last 5 posts for example, how can I do a range query for that? I can execute a range query based on the post id (UUID), but I don't know the available post ids without doing another query to get them. What is the Cassandra best practice here?
Any suggestion about the data model is welcome, of course; I'm a complete newbie to Cassandra.
If your use case works the way I think it does, you could modify your PostAttribute so that the first component is a TimeUUID; that way you can store the posts as time-series data and easily pull the oldest or newest 5 using the standard techniques. Anyway, here's a sample of what it would look like to me, since you don't really need to make multiple columns if you're already using composites.
public class PostInfo {
    @Component(ordinal = 0)
    protected UUID timeUuid;
    @Component(ordinal = 1)
    protected UUID postId;
    @Component(ordinal = 2)
    protected String category;
    @Component(ordinal = 3)
    protected String name;
    @Component(ordinal = 4)
    protected UUID authorId;
    @Component(ordinal = 5)
    protected String authorName;
    @Component(ordinal = 6)
    protected String title;
    @Component(ordinal = 7)
    protected Date published;

    public PostInfo() {}

    private PostInfo(final UUID postId, final String category, final String name, final UUID authorId, final String authorName, final String title, final Date published) {
        this.timeUuid = TimeUUIDUtils.getUniqueTimeUUIDinMillis();
        this.postId = postId;
        this.category = category;
        this.name = name;
        this.authorId = authorId;
        this.authorName = authorName;
        this.title = title;
        this.published = published;
    }

    public static PostInfo of(final UUID postId, final String category, final String name, final UUID authorId, final String authorName, final String title, final Date published) {
        return new PostInfo(postId, category, name, authorId, authorName, title, published);
    }
}

private static AnnotatedCompositeSerializer<PostInfo> postInfoSerializer = new AnnotatedCompositeSerializer<>(PostInfo.class);

private static final ColumnFamily<String, PostInfo> CF_POSTS_TIMELINE =
        ColumnFamily.newColumnFamily("post_info", StringSerializer.get(), postInfoSerializer);
You should save it like this:
MutationBatch m = keyspace().prepareMutationBatch();
ColumnListMutation<PostInfo> clm = m.withRow(CF_POSTS_TIMELINE, "all" /* or whatever makes sense for you, such as year or month */)
        .putColumn(PostInfo.of(post.getId(), post.getCategory(), post.getName(), post.getAuthor().getId(), post.getAuthor().getName(), post.getTitle(), post.getPublishedOn()), (String) null /* maybe just null bytes as the column value */);
m.execute();
Then you could query like this:
OperationResult<ColumnList<PostInfo>> result = getKeyspace()
        .prepareQuery(CF_POSTS_TIMELINE)
        .getKey("all" /* or whatever makes sense like month, year, etc */)
        .withColumnRange(new RangeBuilder()
                .setLimit(5)
                .setReversed(true)
                .build())
        .execute();

ColumnList<PostInfo> columns = result.getResult();
for (Column<PostInfo> column : columns) {
    // do what you need here
}
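As a follow-up, a hedged sketch of what that loop body might look like: each composite column name carries the post metadata, so it can be read back directly (direct field access assumes the loop runs in the same enclosing class or package as PostInfo):
for (Column<PostInfo> column : columns) {
    PostInfo info = column.getName();   // the composite column name holds the post attributes
    System.out.println(info.title + " by " + info.authorName + " published on " + info.published);
}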
