Pyspark - Schema validation using Metadata - apache-spark

I am trying to validate data using a Spark schema. I would like to know which metadata tags I can pass to a StructField.
col1 accepts only "val1", "val2", .... (The column accepts only a few legal values. If any other value appears in that column, the record should be treated as a bad record.)
Is there any tag that I can use in the StructField to validate the data while reading the source file?

As of Spark 3.3.0, Spark does not provide column value validation during read.
The StructField metadata parameter serves more as a tagging feature.
You will need to write your own validation rules. One way is to use filter, as below.
def validate_val(negated: bool = False):
    negate = "not " if negated else ""
    return expr(f" {negate} startswith(c1, 'val') ")
all_records = spark.createDataFrame([["val1", "val2"], ["not_val1", "val2"]], ["c1", "c2"])
good_records = all_records.filter(validate_val())
bad_records = all_records.filter(validate_val(negated=True))
Also, you would probably need to call
all_records.cache()
to avoid reloading the data for each action, e.g. bad_records.count().
Full code:
from pyspark.sql.session import SparkSession
from pyspark.sql.functions import expr
if __name__ == '__main__':
    spark = SparkSession.builder.appName("test-app").master("local[1]").getOrCreate()
    def validate_val(negated: bool = False):
        negate = "not " if negated else ""
        return expr(f" {negate} startswith(c1, 'val') ")
    all_records = spark.createDataFrame([["val1", "val2"], ["not_val1", "val2"]], ["c1", "c2"])
    all_records = all_records.cache()
    good_records = all_records.filter(validate_val())
    bad_records = all_records.filter(validate_val(negated=True))
    good_records.show()
    bad_records.show()
    all_records.show()
+----+----+
| c1| c2|
+----+----+
|val1|val2|
+----+----+
+--------+----+
| c1| c2|
+--------+----+
|not_val1|val2|
+--------+----+
+--------+----+
| c1| c2|
+--------+----+
| val1|val2|
|not_val1|val2|
+--------+----+
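The startswith rule above only checks a prefix. For a column with a fixed set of legal values, as in the question, the check becomes a set-membership test (in PySpark typically Column.isin). The plain-Python sketch below mirrors the good/bad split under that assumption; ALLOWED_C1, is_valid and split_records are hypothetical names, and a missing (None) value is deliberately treated as bad.

```python
# Hypothetical set of legal values for column c1.
ALLOWED_C1 = {"val1", "val2"}


def is_valid(value):
    """Return True only for non-null values that appear in the allowed set."""
    return value is not None and value in ALLOWED_C1


def split_records(rows):
    """Split (c1, c2) tuples into good and bad lists, like the two filters above."""
    good, bad = [], []
    for row in rows:
        (good if is_valid(row[0]) else bad).append(row)
    return good, bad


records = [("val1", "val2"), ("not_val1", "val2"), (None, "val2")]
good_records, bad_records = split_records(records)
print(good_records)  # [('val1', 'val2')]
print(bad_records)   # [('not_val1', 'val2'), (None, 'val2')]
```

The null case matters in Spark: isin and startswith return null for a null input, and filter drops rows whose condition is null, so a null c1 lands in neither good_records nor bad_records unless the bad-record condition adds an explicit check such as F.col("c1").isNull() | ~F.col("c1").isin("val1", "val2").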

Related

Is there a way to use a map/dict in Pyspark to avoid CASE WHEN condition equals pairs?

I have a problem in Pyspark creating a column based on values in another column for a new dataframe.
It's tedious, and it doesn't seem like good practice, to write a lot of
CASE
WHEN column_a = 'value_1' THEN 'value_x'
WHEN column_a = 'value_2' THEN 'value_y'
...
WHEN column_a = 'value_289' THEN 'value_xwerwz'
END
In cases like this in Python, I'm used to using a dict or, even better, a configparser file, and avoiding if/else conditions. I just pass the key and Python returns the desired value. There is also a 'fallback' option for the ELSE clause.
The problem, as I see it, is that we are not processing a single row but all rows in one command, so a dict/map/configparser doesn't seem to be an option. I thought about looping over the dict, but that seems too slow and a waste of computation, since all the conditions are repeated.
I'm still looking for such a practice; if I find it, I'll post it here. Probably a lot of people already use it and I just don't know it yet. But if there is no other way, fine: I'll have no choice but to use many WHEN ... THEN conditions.
Thank you
I tried to use a dict and searched for solutions like this
You could create a function which converts a dict into a Spark F.when, e.g.:
import pyspark.sql.functions as F
def create_spark_when(column, conditions, default):
    when = None
    for key, value in conditions.items():
        current_when = F.when(F.col(column) == key, value)
        if when is None:
            when = current_when.otherwise(default)
        else:
            when = current_when.otherwise(when)
    return when
df = spark.createDataFrame([(0,), (1,), (2,)])
df.show()
my_conditions = {1: "a", 2: "b"}
my_default = "c"
df.withColumn(
"my_column",
create_spark_when("_1", my_conditions, my_default),
).show()
Output:
+---+
| _1|
+---+
| 0|
| 1|
| 2|
+---+
+---+---------+
| _1|my_column|
+---+---------+
| 0| c|
| 1| a|
| 2| b|
+---+---------+
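The loop above wraps each new F.when around the previous expression via otherwise, so the last key becomes the outermost test and default ends up innermost. Since every test is an equality on the same column and dict keys are unique, at most one branch can match, so the nesting order does not affect the result. A plain-Python sketch of the same fold, using nested closures instead of Column expressions (build_lookup is a hypothetical helper, not a Spark API):

```python
def build_lookup(conditions, default):
    """Fold a dict into nested conditionals, mirroring create_spark_when.

    Each step wraps the previous function, just as
    current_when.otherwise(when) wraps the previous Column expression.
    """
    def base(_value):
        return default

    lookup = base
    for key, result in conditions.items():
        # Bind key, result and the previous function at definition time.
        def step(value, key=key, result=result, inner=lookup):
            return result if value == key else inner(value)
        lookup = step
    return lookup


my_lookup = build_lookup({1: "a", 2: "b"}, "c")
print([my_lookup(v) for v in (0, 1, 2)])  # ['c', 'a', 'b']
```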
Another option is to create a DataFrame out of the dictionary and perform a join. This would work:
Creating the DataFrame:
import pyspark.sql.functions as F
mapping = {"value_1": "value_x", "value_2": "value_y"}  # named mapping to avoid shadowing the built-in dict
dict_df = spark.createDataFrame([(k, v) for k, v in mapping.items()], ["key", "value"])
Performing the join:
df.alias("df1")\
.join(F.broadcast(dict_df.alias("df2")), F.col("column_a")==F.col("key"))\
.selectExpr("df1.*","df2.value as newColumn")\
.show()
We can broadcast the dict_df as it is small.
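One thing to watch: the join above is an inner join, so rows whose column_a has no entry in the dictionary disappear instead of taking an ELSE value. Passing "left" as the join type keeps them (with a null newColumn), and wrapping the value in coalesce(...) with a default supplies the ELSE branch. The plain-Python sketch below (hypothetical helper names and sample values) contrasts the two behaviours:

```python
def inner_join_lookup(rows, mapping):
    """Inner join: keys missing from mapping are dropped."""
    return [(key, mapping[key]) for key in rows if key in mapping]


def left_join_lookup(rows, mapping, default=None):
    """Left join plus coalesce: missing keys get the default (None = SQL null)."""
    return [(key, mapping.get(key, default)) for key in rows]


mapping = {"value_1": "value_x", "value_2": "value_y"}
column_a = ["value_1", "value_2", "value_3"]

print(inner_join_lookup(column_a, mapping))
# [('value_1', 'value_x'), ('value_2', 'value_y')]
print(left_join_lookup(column_a, mapping, default="fallback"))
# [('value_1', 'value_x'), ('value_2', 'value_y'), ('value_3', 'fallback')]
```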
Alternatively, you can use a UDF - but that is not recommended.

PySpark UDF issues when referencing outside of function

I am facing an issue: I get the error
TypeError: cannot pickle '_thread.RLock' object
when I try to apply the following code:
from pyspark.sql.types import *
from pyspark.sql.functions import *
data_1 = [('James','Smith','M',30),('Anna','Rose','F',41),
('Robert','Williams','M',62),
]
data_2 = [('Junior','Smith','M',15),('Helga','Rose','F',33),
('Mike','Williams','M',77),
]
columns = ["firstname","lastname","gender","age"]
df_1 = spark.createDataFrame(data=data_1, schema = columns)
df_2 = spark.createDataFrame(data=data_2, schema = columns)
def find_n_people_with_higher_age(x):
return df_2.filter(df_2['age']>=x).count()
find_n_people_with_higher_age_udf = udf(find_n_people_with_higher_age, IntegerType())
df_1.select(find_n_people_with_higher_age_udf(col('age')))
Here's a good article on Python UDFs.
I used it as a reference because I suspected you were running into a serialization issue. I'm quoting the entire paragraph for context, but it's really the serialization that's the issue.
Performance Considerations
It’s important to understand the performance implications of Apache
Spark’s UDF features. Python UDFs for example (such as our CTOF
function) result in data being serialized between the executor JVM and
the Python interpreter running the UDF logic – this significantly
reduces performance as compared to UDF implementations in Java or
Scala. Potential solutions to alleviate this serialization bottleneck
include:
If you consider what you are asking, you may see why this doesn't work. You are asking for all the data in your DataFrame (df_2) to be shipped (serialized) to an executor, which then serializes it again and ships it to Python to be interpreted. DataFrames don't serialize, so that's your issue; and even if they did, you would be sending an entire DataFrame to each executor. Your sample data isn't a problem, but with trillions of records it would blow up the JVM.
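The TypeError itself comes from pickle: PySpark serializes the UDF's function together with the objects it references (here df_2), and an object that holds a thread lock cannot be pickled. The standard-library sketch below reproduces the message with a hypothetical HoldsLock class standing in for the DataFrame; the exact wording shown is what Python 3.8 and later print.

```python
import pickle
import threading


class HoldsLock:
    """Hypothetical stand-in for a driver-side object such as df_2."""

    def __init__(self):
        self._lock = threading.RLock()


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as err:
        return str(err)


print(try_pickle(30))           # None: a plain age value serializes fine
print(try_pickle(HoldsLock()))  # cannot pickle '_thread.RLock' object
```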
What you're asking for is doable; a window plus a group-by does the trick.
Adding more data:
from pyspark.sql import Window
from pyspark.sql.types import *
from pyspark.sql.functions import *
data_1 = [('James','Smith','M',30),('Anna','Rose','F',41),
('Robert','Williams','M',62),
]
# add more data to make it more interesting.
data_2 = [('Junior','Smith','M',15),('Helga','Rose','F',33),('Gia','Rose','F',34),
('Mike','Williams','M',77), ('John','Williams','M',77), ('Bill','Williams','F',79),
]
columns = ["firstname","lastname","gender","age"]
df_1 = spark.createDataFrame(data=data_1, schema = columns)
df_2 = spark.createDataFrame(data=data_2, schema = columns)
# dataframe to help fill in missing ages
ref = spark.range( 1, 110, 1).toDF("numbers").withColumn("count", lit(0)).withColumn("rolling_Count", lit(0))
countAges = df_2.groupby("age").count()
# this actually gives you the short list of ages
rollingCounts = countAges.withColumn("rolling_Count", sum(col("count")).over(Window.partitionBy().orderBy(col("age").desc())))
#fill in missing ages and remove duplicates
filled = rollingCounts.union(ref).groupBy("age").agg(sum("count").alias("count"))
#add a rolling count across all ages
allAgeCounts = filled.withColumn("rolling_Count", sum(col("count")).over(Window.partitionBy().orderBy(col("age").desc())))
#do inner join because we've filled in all ages.
df_1.join(allAgeCounts, df_1.age == allAgeCounts.age, "inner").show()
+---------+--------+------+---+---+-----+-------------+
|firstname|lastname|gender|age|age|count|rolling_Count|
+---------+--------+------+---+---+-----+-------------+
| Anna| Rose| F| 41| 41| 0| 3|
| Robert|Williams| M| 62| 62| 0| 3|
| James| Smith| M| 30| 30| 0| 5|
+---------+--------+------+---+---+-----+-------------+
I wouldn't normally use a window over an entire table, but here the data it iterates over is at most 110 rows (one per age), so this is reasonable.
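The same lookup is easier to follow as a loop. This plain-Python sketch (hypothetical names; the ages come from data_1 and the enlarged data_2) counts people per age, treats missing ages as zero, and keeps a running total from the oldest age downwards, so each entry is the number of people in df_2 at or above that age:

```python
from collections import Counter

df_1_rows = [("James", 30), ("Anna", 41), ("Robert", 62)]
df_2_ages = [15, 33, 34, 77, 77, 79]


def count_at_or_above(ages, max_age=109):
    """Map every age 1..max_age to how many ages are >= it.

    Mirrors the answer: count per age, fill missing ages with 0,
    then take a running sum from the oldest age downwards.
    """
    per_age = Counter(ages)
    rolling, running = {}, 0
    for age in range(max_age, 0, -1):
        running += per_age[age]  # Counter returns 0 for missing ages
        rolling[age] = running
    return rolling


rolling = count_at_or_above(df_2_ages)
print([(name, age, rolling[age]) for name, age in df_1_rows])
# [('James', 30, 5), ('Anna', 41, 3), ('Robert', 62, 3)]
```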

Convert string type to array type in spark sql

I have a table in Spark SQL in Databricks with a column of type string. I converted it into a new column of Array datatype, but the values are still one string. The datatype is array in the table schema.
Column as String
Data1
[2461][2639][2639][7700][7700][3953]
Converted to Array
Data_New
["[2461][2639][2639][7700][7700][3953]"]
String to array conversion
df_new = df.withColumn("Data_New", array(df["Data1"]))
Then I write it as Parquet and use it as a Spark SQL table in Databricks.
When I search for a string using the array_contains function, I get false:
select *
from table_name
where array_contains(Data_New,"[2461]")
When I search for the whole string, the query returns true.
Please suggest how I can split this string into an array so that I can find any element using the array_contains function.
Just remove leading and trailing brackets from the string then split by ][ to get an array of strings:
from pyspark.sql.functions import expr, split
df = df.withColumn("Data_New", split(expr("rtrim(']', ltrim('[', Data1))"), "\\]\\["))
df.show(truncate=False)
+------------------------------------+------------------------------------+
|Data1 |Data_New |
+------------------------------------+------------------------------------+
|[2461][2639][2639][7700][7700][3953]|[2461, 2639, 2639, 7700, 7700, 3953]|
+------------------------------------+------------------------------------+
Now use array_contains like this:
df.createOrReplaceTempView("table_name")
sql_query = "select * from table_name where array_contains(Data_New,'2461')"
spark.sql(sql_query).show(truncate=False)
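Actually, this is not an array but one full string, so you need a regex or something similar:
import re
pattern = re.escape("[2461]")  # \[2461\]: match the brackets literally instead of as a character class
df.filter(df["Data1"].rlike(pattern))  # apply rlike to the original string column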
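rlike treats its argument as a regular expression, and an unescaped "[2461]" is a character class that matches any single 2, 4, 6 or 1, not the literal text "[2461]". Java regex (used by rlike) and Python's re agree on this point, so the sketch below uses Python's re.escape to show the difference; the row "[1000][7700]" is a hypothetical example.

```python
import re

raw_pattern = "[2461]"                    # character class: any one of 2, 4, 6, 1
literal_pattern = re.escape(raw_pattern)  # \[2461\]: the literal bracketed token

full = "[2461][2639][2639][7700][7700][3953]"
other = "[1000][7700]"                    # hypothetical row without the token

print(bool(re.search(raw_pattern, other)))      # True: the '1' in 1000 matches
print(bool(re.search(literal_pattern, other)))  # False
print(bool(re.search(literal_pattern, full)))   # True
```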
Imports:
from pyspark.sql import functions as sf, types as st
Create the table:
a = [["[2461][2639][2639][7700][7700][3953]"], [None]]
sdf = sc.parallelize(a).toDF(["col1"])
sdf.show()
+--------------------+
| col1|
+--------------------+
|[2461][2639][2639...|
| null|
+--------------------+
Convert the type:
def spliter(x):
    if x is not None:
        return x[1:-1].split("][")
    else:
        return None
split_udf = sf.udf(spliter, st.ArrayType(st.StringType()))
sdf.withColumn("array_col1", split_udf("col1")).withColumn("check", sf.array_contains("array_col1", "2461")).show()
+--------------------+--------------------+-----+
| col1| array_col1|check|
+--------------------+--------------------+-----+
|[2461][2639][2639...|[2461, 2639, 2639...| true|
| null| null| null|
+--------------------+--------------------+-----+

Get examples for rows that are removed by a filter from a spark dataframe

Suppose I have a spark dataframe df with some columns (id,...) and a string sqlFilter with a SQL filter, e.g. "id is not null".
I want to filter the dataframe df based on sqlFilter, i.e.
val filtered = df.filter(sqlFilter)
Now, I want to have a list of 10 ids from df that were removed by the filter.
Currently, I'm using a "leftanti" join to achieve this, i.e.
val examples = df.select("id").join(filtered.select("id"), Seq("id"), "leftanti")
.take(10)
.map(row => Option(row.get(0)) match { case None => "null" case Some(x) => x.toString})
However, this is really slow.
My guess is that this can be implemented faster: Spark only has to keep a list per partition and add an id to it whenever the filter removes a row and the list has fewer than 10 elements. Once the action after the filter finishes, Spark has to collect the lists from the partitions until it has 10 ids.
I wanted to use accumulators as described here,
but I failed because I could not find out how to parse and use sqlFilter.
Does anybody have an idea how I can improve the performance?
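The per-partition idea from the question (keep at most 10 removed ids per partition, then let the driver merge the partial lists until it has 10) can be sketched sequentially in plain Python. Partitions are modelled as lists of dicts and the filter as a Python predicate rather than a SQL string; all names and sample rows are hypothetical. Following SQL semantics, a row counts as removed when the predicate is False or None (UNKNOWN).

```python
from itertools import islice

LIMIT = 10


def removed_sample(partition, keep, limit=LIMIT):
    """Collect up to `limit` ids from one partition that `keep` does not accept.

    A SQL filter keeps a row only when the predicate is True, so both
    False and None (SQL UNKNOWN) count as removed.
    """
    removed = (row["id"] for row in partition if keep(row) is not True)
    return list(islice(removed, limit))


def merge_samples(partitions, keep, limit=LIMIT):
    """Driver side: concatenate per-partition samples, stopping once `limit` ids exist."""
    result = []
    for sample in (removed_sample(p, keep, limit) for p in partitions):
        result.extend(sample)
        if len(result) >= limit:
            break
    return result[:limit]


def keep(row):
    """Hypothetical predicate "value > 0" with SQL-style null handling."""
    return None if row["value"] is None else row["value"] > 0


partitions = [
    [{"id": 1, "value": 5}, {"id": 2, "value": -1}, {"id": 3, "value": None}],
    [{"id": 4, "value": 0}, {"id": 5, "value": 7}],
]
print(merge_samples(partitions, keep))           # [2, 3, 4]
print(merge_samples(partitions, keep, limit=2))  # [2, 3]
```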
Update
Ramesh Maharjan suggested in the comments inverting the SQL filter, i.e.
df.filter(s"NOT ($sqlFilter)")
.select("id")
.take(10)
.map(row => Option(row.get(0)) match { case None => "null" case Some(x) => x.toString})
This indeed improves the performance but it is not 100% equivalent.
If there are multiple rows with the same id, the id ends up in the examples as soon as one of those rows is removed by the filter. With the leftanti join it does not, because the id is still present in filtered.
However, that is fine with me.
I'm still interested in whether it is possible to create the list "on the fly" with accumulators or something similar.
Update 2
Another issue with inverting the filter is SQL's logical value UNKNOWN: NOT UNKNOWN = UNKNOWN, so, e.g., NOT (null <> 1) evaluates to UNKNOWN, and such a row shows up neither in the filtered DataFrame nor in the inverted one.
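The UNKNOWN problem can be reproduced with a small model of SQL's three-valued logic, using None for UNKNOWN (a plain-Python sketch; sql_not, sql_neq, coalesce and the sample ids are hypothetical). One common way around it is to coalesce the predicate to false before negating, e.g. NOT coalesce(($sqlFilter), false) in Spark SQL, which the last list models.

```python
def sql_not(value):
    """SQL NOT: UNKNOWN (None) stays UNKNOWN."""
    return None if value is None else not value


def sql_neq(left, right):
    """SQL <>: a comparison involving NULL yields UNKNOWN."""
    return None if left is None or right is None else left != right


def coalesce(value, default):
    return default if value is None else value


ids = [1, 2, None]
# A filter keeps a row only when the predicate is exactly True.
kept = [i for i in ids if sql_neq(i, 1) is True]
inverted = [i for i in ids if sql_not(sql_neq(i, 1)) is True]
fixed = [i for i in ids if sql_not(coalesce(sql_neq(i, 1), False)) is True]

print(kept)      # [2]
print(inverted)  # [1]  (the null id is in neither list)
print(fixed)     # [1, None]
```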
You can use a custom accumulator (a longAccumulator won't help, because all the removed ids will be null), and you must formulate your filter statement as a function:
Suppose you have a DataFrame:
+----+--------+
| id| name|
+----+--------+
| 1|record 1|
|null|record 2|
| 3|record 3|
+----+--------+
Then you could do :
import org.apache.spark.sql.Row
import org.apache.spark.util.AccumulatorV2
class RowAccumulator(var value: Seq[Row]) extends AccumulatorV2[Row, Seq[Row]] {
  def this() = this(Seq.empty[Row])
  override def isZero: Boolean = value.isEmpty
  override def copy(): AccumulatorV2[Row, Seq[Row]] = new RowAccumulator(value)
  override def reset(): Unit = value = Seq.empty[Row]
  override def add(v: Row): Unit = value = value :+ v
  override def merge(other: AccumulatorV2[Row, Seq[Row]]): Unit = value = value ++ other.value
}
val filteredAccum = new RowAccumulator()
ss.sparkContext.register(filteredAccum, "Filter Accum")
val filterIdIsNotNull = (r: Row) => {
  if (r.isNullAt(r.fieldIndex("id"))) {
    filteredAccum.add(r) // remember the row the filter removes
    false
  } else {
    true
  }
}
df
.filter(filterIdIsNotNull)
.show()
println(filteredAccum.value)
gives
+---+--------+
| id| name|
+---+--------+
| 1|record 1|
| 3|record 3|
+---+--------+
List([null,record 2])
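But personally I would not do this: accumulator updates made inside a transformation such as filter can be applied more than once if a task is re-executed, and show() does not necessarily evaluate every partition, so the collected rows are not reliable. I would rather do something like what you've already suggested:
import org.apache.spark.sql.functions.expr
import ss.implicits._ // enables the $"column" syntax
val dfWithFilter = df
  .withColumn("keep", expr("id is not null"))
  .cache() // check whether caching is feasible
// show 10 records which we do not keep
dfWithFilter.filter(!$"keep").drop($"keep").show(10) // or use take(10)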
+----+--------+
| id| name|
+----+--------+
|null|record 2|
+----+--------+
// rows that we keep
val filteredDf = dfWithFilter.filter($"keep").drop($"keep")

Udf not working

Can you help me optimize this code and make it work?
This is the original data:
+--------------------+-------------+
| original_name|medicine_name|
+--------------------+-------------+
| Venlafaxine| Venlafaxine|
| Lacrifilm 5mg/ml| Lacrifilm|
| Lacrifilm 5mg/ml| null|
| Venlafaxine| null|
|Vitamin D10,000IU...| null|
| paracetamol| null|
| mucolite| null|
I expect to get data like this:
+--------------------+-------------+
| original_name|medicine_name|
+--------------------+-------------+
| Venlafaxine| Venlafaxine|
| Lacrifilm 5mg/ml| Lacrifilm|
| Lacrifilm 5mg/ml| Lacrifilm|
| Venlafaxine| Venlafaxine|
|Vitamin D10,000IU...| null|
| paracetamol| null|
| mucolite| null|
This is the code:
distinct_df = spark.sql("select distinct medicine_name as medicine_name from medicine where medicine_name is not null")
distinct_df.createOrReplaceTempView("distinctDF")
def getMax(num1, num2):
    pmax = (num1>=num2)*num1+(num2>num1)*num2
    return pmax
def editDistance(s1, s2):
    ed = (getMax(length(s1), length(s2)) - levenshtein(s1,s2))/getMax(length(s1), length(s2))
    return ed
editDistanceUdf = udf(lambda x,y: editDistance(x,y), FloatType())
def getSimilarity(str):
    res = spark.sql("select medicine_name, editDistanceUdf('str', medicine_name) from distinctDf where editDistanceUdf('str', medicine_name)>=0.85 order by 2")
    res['medicine_name'].take(1)
    return res
getSimilarityUdf = udf(lambda x: getSimilarity(x), StringType())
res_df = df.withColumn('m_name', when((df.medicine_name.isNull)|(df.medicine_name.=="null")),getSimilarityUdf(df.original_name)
.otherwise(df.medicine_name)).show()
now i'm getting error:
command_part = REFERENCE_TYPE + parameter._get_object_id()
AttributeError: 'function' object has no attribute '_get_object_id'
There are a bunch of problems with your code:
You cannot use SparkSession or distributed objects in a udf, so getSimilarity just cannot work. If you want to compare objects like this, you have to join.
If length and levenshtein come from pyspark.sql.functions, they cannot be used inside UserDefinedFunctions. They are designed to generate SQL expressions, mapping from *Column to Column.
Column.isNull is a method, not a property, so it should be called:
df.medicine_name.isNull()
The following
df.medicine_name.=="null"
is not valid Python syntax (it looks like a Scala calque) and raises a SyntaxError.
Even if SparkSession access were allowed in a UserDefinedFunction, this wouldn't substitute the value of str:
spark.sql("select medicine_name, editDistanceUdf('str', medicine_name) from distinctDf where editDistanceUdf('str', medicine_name)>=0.85 order by 2")
You should use string formatting instead (the quotes around {str} keep the substituted value a SQL string literal, and editDistanceUdf would also have to be registered with spark.udf.register to be callable from SQL):
spark.sql("select medicine_name, editDistanceUdf('{str}', medicine_name) from distinctDf where editDistanceUdf('{str}', medicine_name)>=0.85 order by 2".format(str=str))
There may be other problems, but since you didn't provide an MCVE, anything else would be pure guessing.
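On the length/levenshtein point: a UDF body must be ordinary Python working on plain values. A minimal sketch of the question's score, (max length - edit distance) / max length, with a hand-written Levenshtein distance in place of pyspark.sql.functions.levenshtein (levenshtein and edit_similarity are hypothetical helper names):

```python
def levenshtein(a, b):
    """Classic dynamic-programming edit distance (insert, delete, substitute)."""
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current = [i]
        for j, cb in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,               # deletion
                current[j - 1] + 1,            # insertion
                previous[j - 1] + (ca != cb),  # substitution
            ))
        previous = current
    return previous[-1]


def edit_similarity(s1, s2):
    """(max_len - distance) / max_len; None inputs give None (a null result)."""
    if s1 is None or s2 is None:
        return None
    max_len = max(len(s1), len(s2))
    if max_len == 0:
        return 1.0  # assumption: two empty strings count as identical
    return (max_len - levenshtein(s1, s2)) / max_len


print(edit_similarity("Venlafaxine", "Venlafaxine"))     # 1.0
print(edit_similarity("Lacrifilm 5mg/ml", "Lacrifilm"))  # 0.5625
```

The second line also shows that, with this score, "Lacrifilm 5mg/ml" vs "Lacrifilm" falls well below the 0.85 threshold used in the query. If a UDF is not required, the same score can be built from Column functions, for Column expressions a and b: (F.greatest(F.length(a), F.length(b)) - F.levenshtein(a, b)) / F.greatest(F.length(a), F.length(b)).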
Once you fix the smaller mistakes, you have two choices:
Use crossJoin:
combined = df.alias("left").crossJoin(spark.table("distinctDf").alias("right"))
Then apply the udf, filter, and use one of the methods listed in Find maximum row per group in Spark DataFrame to keep the closest match in each group.
Use built-in approximate matching tools as explained in Efficient string matching in Apache Spark
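The crossJoin route boils down to: pair every row with every distinct name, score each pair, drop pairs below the threshold, and keep the best pair per row. A plain-Python sketch of that pipeline follows; difflib.SequenceMatcher.ratio is swapped in as a stand-in score (it is not the Levenshtein formula from the question), and closest_matches and the sample names are hypothetical.

```python
from difflib import SequenceMatcher


def score(a, b):
    """Stand-in similarity in [0, 1]; replace with the Levenshtein-based score."""
    return SequenceMatcher(None, a, b).ratio()


def closest_matches(names, candidates, threshold=0.85):
    """Cross join names x candidates, filter by threshold, keep the best per name."""
    best = {}
    for name in names:                # left side of the cross join
        for candidate in candidates:  # right side: distinct medicine names
            s = score(name, candidate)
            if s >= threshold and (name not in best or s > best[name][1]):
                best[name] = (candidate, s)
    # Names with no candidate above the threshold map to None (null).
    return {name: best.get(name, (None, None))[0] for name in names}


distinct_names = ["Venlafaxine", "Lacrifilm"]
print(closest_matches(["Venlafaxine", "paracetamol"], distinct_names))
# {'Venlafaxine': 'Venlafaxine', 'paracetamol': None}
```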
