Regex function in a loop runs slowly - python-3.x

I need to apply 15 regular expressions to a Spark DataFrame.
Here is a reduced version with a small df and 3 regexes:
from pyspark.sql import Row
from pyspark.sql.functions import col, lit, when

df = spark.createDataFrame([
    Row(a=1, val1="aaa_wwwwwww"),
    Row(a=2, val1="bwq_323"),
    Row(a=3, val1="haha_kdjk_ska")
])

reg_exps = [
    {"reg_val": "^aaa_[a-z]{5,12}$", "replace_with": "a"},
    {"reg_val": "^bwq_[0-9]{2,4}$", "replace_with": "b"},
    {"reg_val": "^haha_[0-9a-z_]{5,12}$", "replace_with": "c"},
]

for reg_exp in reg_exps:
    df = df.withColumn(
        "val1",
        when(
            col("val1").rlike(reg_exp["reg_val"]),
            lit(reg_exp["replace_with"])
        ).otherwise(col("val1"))
    )

df.show(truncate=False)
It should return the following dataframe:
+---+----+
|a |val1|
+---+----+
|1 |a |
|2 |b |
|3 |c |
+---+----+
The code works as expected, but it's really slow. Is there any way to speed it up?

Attempt 1
From what can be seen, you could create just one regexp_extract, without a loop.
For a, b, c:
df = df.withColumn("val1", regexp_extract("val1", r"^([a-c])_[\da-z]{5,12}$", 1))
For any letter that is in that position:
df = df.withColumn("val1", regexp_extract("val1", r"^([a-z])_[\da-z]{5,12}$", 1))
Attempt 2
Since you said that in your real case you cannot merge your regexes, there is still one thing you can simplify. Instead of several .withColumn calls, you can do just one by combining your .when() conditions into a single chained expression: F.when().when().when()...otherwise(). This can be done using reduce. In this form, I think, values that have already matched a regex would not go through several additional regex checks.
from pyspark.sql import functions as F
from functools import reduce

whens = reduce(
    lambda acc, x: acc.when(F.col("val1").rlike(x["reg_val"]), x["replace_with"]),
    reg_exps,
    F
).otherwise(F.col("val1"))

df = df.withColumn("val1", whens)
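For reference, with the three sample regexes above, the reduce builds roughly the equivalent of this hand-written chain:

whens_manual = (
    F.when(F.col("val1").rlike("^aaa_[a-z]{5,12}$"), "a")
     .when(F.col("val1").rlike("^bwq_[0-9]{2,4}$"), "b")
     .when(F.col("val1").rlike("^haha_[0-9a-z_]{5,12}$"), "c")
     .otherwise(F.col("val1"))
)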

Related

Does pyspark hash guarantee unique result for different input? [duplicate]

I am working with Spark 2.2.0 and PySpark 2.
I have created a DataFrame df and am now trying to add a new column "rowhash" that is the sha2 hash of specific columns in the DataFrame.
For example, say that df has the columns: (column1, column2, ..., column10)
I require sha2((column2||column3||column4||...... column8), 256) in a new column "rowhash".
For now, I have tried the following methods:
1) Used the hash() function, but since it gives an integer output it is not of much use.
2) Tried using the sha2() function, but it is failing.
Say columnarray has the array of columns I need.
def concat(columnarray):
    concat_str = ''
    for val in columnarray:
        concat_str = concat_str + '||' + str(val)
    concat_str = concat_str[2:]
    return concat_str

and then

df1 = df1.withColumn("row_sha2", sha2(concat(columnarray), 256))
This is failing with "cannot resolve" error.
Thanks @gaw for your answer. Since I have to hash only specific columns, I created a list of those column names (in hash_col) and changed your function as:
def sha_concat(row, columnarray):
    row_dict = row.asDict()  # transform row to a dict
    concat_str = ''
    for v in columnarray:
        concat_str = concat_str + '||' + str(row_dict.get(v))
    concat_str = concat_str[2:]
    # preserve concatenated value for testing (this can be removed later)
    row_dict["sha_values"] = concat_str
    row_dict["sha_hash"] = hashlib.sha256(concat_str).hexdigest()
    return Row(**row_dict)

Then passed it as:

df1.rdd.map(lambda row: sha_concat(row, hash_col)).toDF().show(truncate=False)
It is now, however, failing with this error:
UnicodeEncodeError: 'ascii' codec can't encode character u'\ufffd' in position 8: ordinal not in range(128)
I can see the value \ufffd in one of the columns, so I am not sure if there is a way to handle this?
You can use pyspark.sql.functions.concat_ws() to concatenate your columns and pyspark.sql.functions.sha2() to get the SHA256 hash.
Using the data from @gaw:
from pyspark.sql.functions import sha2, concat_ws
df = spark.createDataFrame(
[(1,"2",5,1),(3,"4",7,8)],
("col1","col2","col3","col4")
)
df.withColumn("row_sha2", sha2(concat_ws("||", *df.columns), 256)).show(truncate=False)
#+----+----+----+----+----------------------------------------------------------------+
#|col1|col2|col3|col4|row_sha2 |
#+----+----+----+----+----------------------------------------------------------------+
#|1 |2 |5 |1 |1b0ae4beb8ce031cf585e9bb79df7d32c3b93c8c73c27d8f2c2ddc2de9c8edcd|
#|3 |4 |7 |8 |57f057bdc4178b69b1b6ab9d78eabee47133790cba8cf503ac1658fa7a496db1|
#+----+----+----+----+----------------------------------------------------------------+
You can pass in either 0 or 256 as the second argument to sha2(), as per the docs:
Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, and SHA-512). The numBits indicates the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
The function concat_ws takes in a separator, and a list of columns to join. I am passing in || as the separator and df.columns as the list of columns.
I am using all of the columns here, but you can specify whatever subset of columns you'd like; in your case that would be columnarray. (You need to use the * to unpack the list.)
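For example, assuming columnarray is your list of column names to hash (column2 through column8 in your description), a sketch would be:

df1 = df1.withColumn("row_sha2", sha2(concat_ws("||", *columnarray), 256))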
If you want the hash for each value in the different columns of your dataset, you can apply a self-designed function via map to the RDD of your dataframe.
import hashlib
from pyspark.sql import Row

test_df = spark.createDataFrame([
    (1, "2", 5, 1), (3, "4", 7, 8),
], ("col1", "col2", "col3", "col4"))

def sha_concat(row):
    row_dict = row.asDict()        # transform row to a dict
    columnarray = row_dict.keys()  # get the column names
    concat_str = ''
    for v in row_dict.values():
        concat_str = concat_str + '||' + str(v)  # concatenate values
    concat_str = concat_str[2:]
    row_dict["sha_values"] = concat_str  # preserve concatenated value for testing (this can be removed later)
    row_dict["sha_hash"] = hashlib.sha256(concat_str.encode('utf-8')).hexdigest()  # calculate sha256 (encode to bytes so unicode values don't break)
    return Row(**row_dict)

test_df.rdd.map(sha_concat).toDF().show(truncate=False)
The results would look like:
+----+----+----+----+----------------------------------------------------------------+----------+
|col1|col2|col3|col4|sha_hash |sha_values|
+----+----+----+----+----------------------------------------------------------------+----------+
|1 |2 |5 |1 |1b0ae4beb8ce031cf585e9bb79df7d32c3b93c8c73c27d8f2c2ddc2de9c8edcd|1||2||5||1|
|3 |4 |7 |8 |cb8f8c5d9fd7165cf3c0f019e0fb10fa0e8f147960c715b7f6a60e149d3923a5|8||4||7||3|
+----+----+----+----+----------------------------------------------------------------+----------+
New in version 2.0 is the hash function.
from pyspark.sql.functions import hash
(
spark
.createDataFrame([(1,'Abe'),(2,'Ben'),(3,'Cas')], ('id','name'))
.withColumn('hashed_name', hash('name'))
).show()
which results in:
+---+----+-----------+
| id|name|hashed_name|
+---+----+-----------+
| 1| Abe| 1567000248|
| 2| Ben| 1604243918|
| 3| Cas| -586163893|
+---+----+-----------+
https://spark.apache.org/docs/latest/api/python/_modules/pyspark/sql/functions.html#hash
If you want to control what the IDs look like, you can use the code below.
import pyspark.sql.functions as F
from pyspark.sql import Window

SRIDAbbrev = "SOD"  # could be any abbreviation that identifies the table or object on the table name
max_ID = 0          # offset to start numbering from; the lpad width below (8) controls how long the numbers are

if max_ID is None:
    max_ID = 0  # helps identify where you start numbering from

dataframe_new = dataframe.orderBy(
    F.col('name')
).withColumn(
    'hashed_name',
    F.concat(
        F.lit(SRIDAbbrev),
        F.lpad(
            (
                F.dense_rank().over(
                    Window.orderBy('name')
                )
                + F.lit(max_ID)
            ),
            8,
            "0"
        )
    )
)
which results in:
+---+----+-----------+
| id|name|hashed_name|
+---+----+-----------+
| 1| Abe| SOD0000001|
| 2| Ben| SOD0000002|
| 3| Cas| SOD0000003|
| 3| Cas| SOD0000003|
+---+----+-----------+
Let me know if this helps :)

efficiently expand array of Row to separate columns

I have a Spark dataframe, and one of its fields is an array of Row structs. I need to expand them into their own columns. One of the problems is that in the array, a field is sometimes missing.
The following is an example:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import Row
from pyspark.sql import functions as udf
spark = SparkSession.builder.getOrCreate()
# data
rows = [{'status':'active','member_since':1990,'info':[Row(tag='name',value='John'),Row(tag='age',value='50'),Row(tag='phone',value='1234567')]},
{'status':'inactive','member_since':2000,'info':[Row(tag='name',value='Tom'),Row(tag='phone',value='1234567')]},
{'status':'active','member_since':2015,'info':[Row(tag='name',value='Steve'),Row(tag='age',value='28')]}]
# create dataframe
df = spark.createDataFrame(rows)
# transform info to dict
to_dict = udf.UserDefinedFunction(lambda s:dict(s),MapType(StringType(),StringType()))
df = df.withColumn("info_dict",to_dict("info"))
# extract name, NA if not exists
extract_name = udf.UserDefinedFunction(lambda s:s.get("name","NA"))
df = df.withColumn("name",extract_name("info_dict"))
# extract age, NA if not exists
extract_age = udf.UserDefinedFunction(lambda s:s.get("age","NA"))
df = df.withColumn("age",extract_age("info_dict"))
# extract phone, NA if not exists
extract_phone = udf.UserDefinedFunction(lambda s:s.get("phone","NA"))
df = df.withColumn("phone",extract_phone("info_dict"))
df.show()
You can see that for 'Tom', 'age' is missing, and for 'Steve', 'phone' is missing. As in the code snippet above, my current solution is to first transform the array into a dict and then parse each individual field into its own column. The result looks like this:
+--------------------+------------+--------+--------------------+-----+---+-------+
| info|member_since| status| info_dict| name|age| phone|
+--------------------+------------+--------+--------------------+-----+---+-------+
|[[name, John], [a...| 1990| active|[name -> John, ph...| John| 50|1234567|
|[[name, Tom], [ph...| 2000|inactive|[name -> Tom, pho...| Tom| NA|1234567|
|[[name, Steve], [...| 2015| active|[name -> Steve, a...|Steve| 28| NA|
+--------------------+------------+--------+--------------------+-----+---+-------+
I really just want the columns 'status', 'member_since', 'name', 'age' and 'phone'. This solution works but is rather slow because of the UDFs. Are there any faster alternatives? Thanks
I can think of 2 ways to do this using DataFrame functions. I believe the first one should be faster, but the code is much less elegant. The second is more compact, but probably slower.
Method 1: Create Map Dynamically
The heart of this method is to turn your Row into a MapType(). This can be achieved using pyspark.sql.functions.create_map() and some magic using functools.reduce() and operator.add().
from functools import reduce
from operator import add
import pyspark.sql.functions as f

f.create_map(
    *reduce(
        add,
        [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
         for k in range(3)]
    )
)
The problem is that there isn't a way (AFAIK) to dynamically determine the length of the WrappedArray or to iterate through it in an easy way. If a value is missing, this will cause an error because map keys cannot be null. However, since we know that the list can only contain 1, 2, or 3 elements, we can just test for each of these cases.
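The select at the end of the next block also assumes a keys list holding the tag names to extract, which here would be:

keys = ['name', 'age', 'phone']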
df.withColumn(
'map',
f.when(f.size(f.col('info')) == 1,
f.create_map(
*reduce(
add,
[[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
for k in range(1)]
)
)
).otherwise(
f.when(f.size(f.col('info')) == 2,
f.create_map(
*reduce(
add,
[[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
for k in range(2)]
)
)
).otherwise(
f.when(f.size(f.col('info')) == 3,
f.create_map(
*reduce(
add,
[[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
for k in range(3)]
)
)
)))
).select(
['member_since', 'status'] + [f.col("map").getItem(k).alias(k) for k in keys]
).show(truncate=False)
The last step turns the 'map' keys into columns using the method described in this answer.
This produces the following output:
+------------+--------+-----+----+-------+
|member_since|status |name |age |phone |
+------------+--------+-----+----+-------+
|1990 |active |John |50 |1234567|
|2000 |inactive|Tom |null|1234567|
|2015 |active |Steve|28 |null |
+------------+--------+-----+----+-------+
Method 2: Use explode, groupBy and pivot
First use pyspark.sql.functions.explode() on the column 'info', and then use the 'tag' and 'value' columns as arguments to create_map():
df.withColumn('id', f.monotonically_increasing_id())\
.withColumn('exploded', f.explode(f.col('info')))\
.withColumn(
'map',
f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
)\
.select('id', 'member_since', 'status', 'map')\
.show(truncate=False)
#+------------+------------+--------+---------------------+
#|id |member_since|status |map |
#+------------+------------+--------+---------------------+
#|85899345920 |1990 |active |Map(name -> John) |
#|85899345920 |1990 |active |Map(age -> 50) |
#|85899345920 |1990 |active |Map(phone -> 1234567)|
#|180388626432|2000 |inactive|Map(name -> Tom) |
#|180388626432|2000 |inactive|Map(phone -> 1234567)|
#|266287972352|2015 |active |Map(name -> Steve) |
#|266287972352|2015 |active |Map(age -> 28) |
#+------------+------------+--------+---------------------+
I also added a column 'id' using pyspark.sql.functions.monotonically_increasing_id() to make sure we can keep track of which rows belong to the same record.
Now we can explode the map column, groupBy(), and pivot(). We can use pyspark.sql.functions.first() as the aggregate function for the groupBy() because we know there will only be one 'value' in each group.
df.withColumn('id', f.monotonically_increasing_id())\
.withColumn('exploded', f.explode(f.col('info')))\
.withColumn(
'map',
f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
)\
.select('id', 'member_since', 'status', f.explode('map'))\
.groupBy('id', 'member_since', 'status').pivot('key').agg(f.first('value'))\
.select('member_since', 'status', 'age', 'name', 'phone')\
.show()
#+------------+--------+----+-----+-------+
#|member_since| status| age| name| phone|
#+------------+--------+----+-----+-------+
#| 1990| active| 50| John|1234567|
#| 2000|inactive|null| Tom|1234567|
#| 2015| active| 28|Steve| null|
#+------------+--------+----+-----+-------+

How to change case of whole pyspark dataframe to lower or upper

I am trying to apply the PySpark SQL hash function to every row in two dataframes to identify the differences. The hash algorithm is case sensitive, i.e. if a column contains 'APPLE' and 'Apple', they are considered two different values, so I want to change the case of both dataframes to either upper or lower. I am only able to achieve this for the dataframe headers, not for the dataframe values. Please help.
#Code for Dataframe column headers
self.df_db1 =self.df_db1.toDF(*[c.lower() for c in self.df_db1.columns])
Assuming df is your dataframe, this should do the job:
from pyspark.sql import functions as F

for col in df.columns:
    df = df.withColumn(col, F.lower(F.col(col)))
Both answers seem to be fine, with one exception: if you have a numeric column, it will be converted to a string column. To avoid this, try:
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

val fields = sourceDF.schema.fields
val stringFields = fields.filter(f => f.dataType == StringType)
val nonStringFields = fields.filter(f => f.dataType != StringType).map(f => f.name).map(f => col(f))
val stringFieldsTransformed = stringFields.map(f => f.name).map(f => upper(col(f)).as(f))
val df = sourceDF.select(stringFieldsTransformed ++ nonStringFields: _*)
Now the types are also correct when you have non-string (i.e. numeric) fields.
If you know that each column is of String type, use one of the other answers; they are correct in that case :)
Python code in PySpark:
from pyspark.sql.functions import *
from pyspark.sql.types import *
sourceDF = spark.createDataFrame([(1, "a")], ['n', 'n1'])
fields = sourceDF.schema.fields
stringFields = filter(lambda f: isinstance(f.dataType, StringType), fields)
nonStringFields = map(lambda f: col(f.name), filter(lambda f: not isinstance(f.dataType, StringType), fields))
stringFieldsTransformed = map(lambda f: upper(col(f.name)), stringFields)
allFields = [*stringFieldsTransformed, *nonStringFields]
df = sourceDF.select(allFields)
You can generate an expression using list comprehension:
from pyspark.sql import functions as psf
expression = [ psf.lower(psf.col(x)).alias(x) for x in df.columns ]
And then just use it on your existing dataframe:
>>> df.show()
+---+---+---+---+
| c1| c2| c3| c4|
+---+---+---+---+
| A| B| C| D|
+---+---+---+---+
>>> df.select(*expression).show()
+---+---+---+---+
| c1| c2| c3| c4|
+---+---+---+---+
| a| b| c| d|
+---+---+---+---+

How to find count of Null and Nan values for each column in a PySpark dataframe efficiently?

import numpy as np
data = [
(1, 1, None),
(1, 2, float(5)),
(1, 3, np.nan),
(1, 4, None),
(1, 5, float(10)),
(1, 6, float("nan")),
(1, 6, float("nan")),
]
df = spark.createDataFrame(data, ("session", "timestamp1", "id2"))
Expected output
dataframe with count of nan/null for each column
Note:
The previous questions I found on Stack Overflow only check for null, not NaN.
That's why I have created a new question.
I know I can use the isnull() function in Spark to find the number of null values in a Spark column, but how do I find NaN values in a Spark dataframe?
You can use the method shown here and replace isNull with isnan:
from pyspark.sql.functions import isnan, when, count, col
df.select([count(when(isnan(c), c)).alias(c) for c in df.columns]).show()
+-------+----------+---+
|session|timestamp1|id2|
+-------+----------+---+
| 0| 0| 3|
+-------+----------+---+
or
df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns]).show()
+-------+----------+---+
|session|timestamp1|id2|
+-------+----------+---+
| 0| 0| 5|
+-------+----------+---+
For null values in a PySpark dataframe:
Dict_Null = {col:df.filter(df[col].isNull()).count() for col in df.columns}
Dict_Null
# The output is a dict where the key is the column name and the value is the number of nulls in that column
{'#': 0,
'Name': 0,
'Type 1': 0,
'Type 2': 386,
'Total': 0,
'HP': 0,
'Attack': 0,
'Defense': 0,
'Sp_Atk': 0,
'Sp_Def': 0,
'Speed': 0,
'Generation': 0,
'Legendary': 0}
To make sure it does not fail for string, date and timestamp columns:
import pyspark.sql.functions as F

def count_missings(spark_df, sort=True):
    """
    Counts the number of nulls and nans in each column
    """
    df = spark_df.select([
        F.count(F.when(F.isnan(c) | F.isnull(c), c)).alias(c)
        for (c, c_type) in spark_df.dtypes
        if c_type not in ('timestamp', 'string', 'date')
    ]).toPandas()

    if len(df) == 0:
        print("There are no missing values!")
        return None

    if sort:
        return df.rename(index={0: 'count'}).T.sort_values("count", ascending=False)

    return df
If you want to see the columns sorted by the number of nans and nulls in descending order:
count_missings(spark_df)
# | Col_A | 10 |
# | Col_C | 2 |
# | Col_B | 1 |
If you don't want ordering and want to see them as a single row:
count_missings(spark_df, False)
# | Col_A | Col_B | Col_C |
# | 10 | 1 | 2 |
An alternative to the ways already provided is to simply filter on the column, like so:
import pyspark.sql.functions as F
df = df.where(F.col('columnNameHere').isNull())
This has the added benefit that you don't have to add another column to do the filtering and it's quick on larger data sets.
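To get an actual per-column count with this approach, you would just chain .count() on the filter (column name kept from the snippet above):

null_count = df.where(F.col('columnNameHere').isNull()).count()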
Here is my one-liner, where 'c' is the name of the column:

from pyspark.sql import functions as F

df.select('c').withColumn('isNull_c', F.col('c').isNull()).where('isNull_c = True').count()
I prefer this solution: since count(c) only counts non-null values, counter - count(c) gives the number of nulls in each column.
df = spark.table(selected_table).filter(condition)
counter = df.count()
df = df.select([(counter - count(c)).alias(c) for c in df.columns])
Use the following code to identify the null values in every column using PySpark.
import pandas as pd
from pyspark.sql.functions import count, isnull, when

def check_nulls(dataframe):
    '''
    Check null values and return the null counts in a pandas DataFrame
    INPUT: Spark DataFrame
    OUTPUT: Null values
    '''
    # Create pandas dataframe
    nulls_check = pd.DataFrame(
        dataframe.select([count(when(isnull(c), c)).alias(c) for c in dataframe.columns]).collect(),
        columns=dataframe.columns
    ).transpose()
    nulls_check.columns = ['Null Values']
    return nulls_check

# Check null values
null_df = check_nulls(raw_df)
null_df
from pyspark.sql import DataFrame
import pyspark.sql.functions as fn

# compatible with fn.isnan. Sourced from
# https://github.com/apache/spark/blob/13fd272cd3/python/pyspark/sql/functions.py#L4818-L4836
NUMERIC_DTYPES = (
    'decimal',
    'double',
    'float',
    'int',
    'bigint',
    'smallint',
    'tinyint',
)

def count_nulls(df: DataFrame) -> DataFrame:
    isnan_compat_cols = {c for (c, t) in df.dtypes if any(t.startswith(num_dtype) for num_dtype in NUMERIC_DTYPES)}
    return df.select(
        [fn.count(fn.when(fn.isnan(c) | fn.isnull(c), c)).alias(c) for c in isnan_compat_cols]
        + [fn.count(fn.when(fn.isnull(c), c)).alias(c) for c in set(df.columns) - isnan_compat_cols]
    )
This builds on gench's and user8183279's answers, but checks via isnull only for the columns where isnan is not possible, rather than just ignoring them.
The source code of pyspark.sql.functions seemed to be the only documentation I could really find enumerating these names; if others know of some public docs I'd be delighted.
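Usage on the sample dataframe from the question would simply be:

count_nulls(df).show()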
If you are writing Spark SQL, then the following will also work to find the null values and count them subsequently:
spark.sql('select * from table where isNULL(column_value)')
Yet another alternative (improving on Vamsi Krishna's solution above):
from pyspark.sql.functions import isnan, isnull

def check_for_null_or_nan(df):
    null_or_nan = lambda x: isnan(x) | isnull(x)
    func = lambda x: df.filter(null_or_nan(x)).count()
    print(*[f'{i} has {func(i)} nans/nulls' for i in df.columns if func(i) != 0], sep='\n')

check_for_null_or_nan(df)

id2 has 5 nans/nulls
Here is a readable solution because code is for people as much as computers ;-)
df.selectExpr('sum(int(isnull(<col_name>) or isnan(<col_name>))) as null_or_nan_count')
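A sketch that applies the same expression to every column at once (assuming the columns are numeric, so isnan is valid on them):

df.selectExpr(*[
    f"sum(int(isnull({c}) or isnan({c}))) as {c}_null_or_nan_count"
    for c in df.columns
]).show()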

Convert null values to empty array in Spark DataFrame

I have a Spark data frame where one column is an array of integers. The column is nullable because it is coming from a left outer join. I want to convert all null values to an empty array so I don't have to deal with nulls later.
I thought I could do it like so:
val myCol = df("myCol")
df.withColumn( "myCol", when(myCol.isNull, Array[Int]()).otherwise(myCol) )
However, this results in the following exception:
java.lang.RuntimeException: Unsupported literal type class [I [I#5ed25612
at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:49)
at org.apache.spark.sql.functions$.lit(functions.scala:89)
at org.apache.spark.sql.functions$.when(functions.scala:778)
Apparently array types are not supported by the when function. Is there some other easy way to convert the null values?
In case it is relevant, here is the schema for this column:
|-- myCol: array (nullable = true)
| |-- element: integer (containsNull = false)
You can use a UDF:
import org.apache.spark.sql.functions.udf
val array_ = udf(() => Array.empty[Int])
combined with WHEN or COALESCE:
df.withColumn("myCol", when(myCol.isNull, array_()).otherwise(myCol))
df.withColumn("myCol", coalesce(myCol, array_())).show
In the recent versions you can use array function:
import org.apache.spark.sql.functions.{array, lit}
df.withColumn("myCol", when(myCol.isNull, array().cast("array<integer>")).otherwise(myCol))
df.withColumn("myCol", coalesce(myCol, array().cast("array<integer>"))).show
Please note that it will work only if conversion from string to the desired type is allowed.
The same thing can of course be done in PySpark as well. For the legacy solution you can define a UDF:
from pyspark.sql.functions import coalesce, udf
from pyspark.sql.types import ArrayType, IntegerType

def empty_array(t):
    return udf(lambda: [], ArrayType(t()))()

coalesce(myCol, empty_array(IntegerType()))
and in the recent versions just use array:
from pyspark.sql.functions import array
coalesce(myCol, array().cast("array<integer>"))
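Wrapped in withColumn, that would look something like this (a sketch using the column name from the question):

from pyspark.sql.functions import array, coalesce, col

df = df.withColumn("myCol", coalesce(col("myCol"), array().cast("array<integer>")))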
With a slight modification to zero323's approach, I was able to do this without using a udf in Spark 2.3.1.
val df = Seq("a" -> Array(1,2,3), "b" -> null, "c" -> Array(7,8,9)).toDF("id","numbers")
df.show
+---+---------+
| id| numbers|
+---+---------+
| a|[1, 2, 3]|
| b| null|
| c|[7, 8, 9]|
+---+---------+
val df2 = df.withColumn("numbers", coalesce($"numbers", array()))
df2.show
+---+---------+
| id| numbers|
+---+---------+
| a|[1, 2, 3]|
| b| []|
| c|[7, 8, 9]|
+---+---------+
A UDF-free alternative, for when the data type you want your array elements in cannot be cast from StringType, is the following:
import pyspark.sql.types as T
import pyspark.sql.functions as F
df.withColumn(
"myCol",
F.coalesce(
F.col("myCol"),
F.from_json(F.lit("[]"), T.ArrayType(T.IntegerType()))
)
)
You can replace IntegerType() with whichever data type you need, including complex ones.
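For example, for an array of structs (hypothetical element fields), the empty default can be built the same way:

element_schema = T.StructType([
    T.StructField("x", T.IntegerType()),
    T.StructField("y", T.StringType()),
])
F.coalesce(F.col("myCol"), F.from_json(F.lit("[]"), T.ArrayType(element_schema)))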
