I am trying to build a dictionary dynamically using pyspark, by reading the table structure on the oracle database. Here's a simplified version of my code
predefined dictionary (convert_dict.py)
def _parse_date(c):
    """Return column *c* parsed with ``dateFormat`` and cast to a date."""
    # NOTE(review): `f` and `dateFormat` are not defined in the visible part of
    # this module; presumably pyspark.sql.functions and a format string.
    return f.from_unixtime(f.unix_timestamp(c, dateFormat)).cast("date")


# Maps a column name to a callable that takes that column name and returns the
# converted Column expression.
conversions = {
    "COL1": lambda c: f.col(c).cast("string"),
    "COL2": _parse_date,
    "COL3": _parse_date,
    "COL4": lambda c: f.col(c).cast("float"),
}
Main program
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import StructType, StructField, StringType
from convert_dict import conversions
# Read a CSV with every column typed as string, flag values that fail the
# per-column conversion, then apply the conversions.
spark = SparkSession.builder.appName("file_testing").getOrCreate()
table_name = "TEST_TABLE"
# Use a file URI with forward slashes: the backslash form contained "\f",
# which Python turns into a form feed, and other invalid escape sequences.
input_file_path = "file:///c:/Desktop/foo.txt"
sql_query = (
    "(select listagg(column_name,',') within group(order by column_id) col from user_tab_columns where "
    "table_name = '" + table_name + "' and column_name not in ('COL10', 'COL11','COL12') order by column_id) table_columns"
)
# Everything is read as a nullable string so conversion failures can be detected.
struct_schema = StructType(
    [StructField(name, StringType(), True) for name in ("COL1", "COL2", "COL3", "COL4")]
)
data_df = (
    spark.read.schema(struct_schema)
    .option("sep", ",")
    .option("header", "true")
    .csv(input_file_path)
)
# A value is invalid when the conversion yields null although the raw value
# was not null. (The source frame is data_df; `lines` was never defined.)
validdateData = data_df.withColumn(
    "dataTypeValidations",
    f.concat_ws(
        ",",
        *[
            f.when(
                v(k).isNull() & f.col(k).isNotNull(),
                f.lit(k + " not valid"),
            ).otherwise(f.lit("None"))
            for k, v in conversions.items()
        ]
    ),
)
data_temp = validdateData
for k, v in conversions.items():
    data_temp = data_temp.withColumn(k, v(k))
# Show the converted frame; the original referenced an undefined `validateData`.
data_temp.show()
spark.stop()
If I am to change the above code to dynamically generate the dictionary from database
# Build the column -> converter dictionary from the Oracle data dictionary.
DATEFORMAT = "yyyyMMdd"
dict_sql = """
(select column_name,case when data_type = 'VARCHAR2' then 'string' when data_type in ( 'DATE','TIMESTAMP(6)') then 'date' when data_type = 'NUMBER' and NVL(DATA_SCALE,0) <> 0 then 'float' when data_type = 'NUMBER' and NVL(DATA_SCALE,0) = 0 then 'int'
end d_type from user_tab_columns where table_name = 'TEST_TABLE' and column_name not in ('COL10', 'COL11','COL12')) dict
"""
column_df = (
    spark.read.format("jdbc")
    .option("url", url)
    .option("dbtable", dict_sql)
    .option("user", user)
    .option("password", password)
    .option("driver", driver)
    .load()
)


def _pass_through(c):
    """Return column *c* unchanged (used for strings and unmapped types)."""
    return f.col(c)


# One converter per logical type produced by dict_sql. An earlier if/elif
# chain tested "date" twice, so the first (pass-through) branch shadowed the
# real date parser and date columns were never converted.
_CONVERTERS_BY_TYPE = {
    "float": lambda c: f.col(c).cast("float"),
    "date": lambda c: f.from_unixtime(f.unix_timestamp(c, DATEFORMAT)).cast("date"),
    "int": lambda c: f.col(c).cast("int"),
}

# Types without a dedicated converter (including "string" and a null d_type)
# fall back to the pass-through.
conversions = {
    row.COLUMN_NAME: _CONVERTERS_BY_TYPE.get(row.D_TYPE, _pass_through)
    for row in column_df.collect()
}
The data-type conversion doesn't work when the dynamically generated dictionary above is used. For example, if "COL2" contains "20210731", the value stays the same in the resulting data, i.e. it doesn't get converted to a date, whereas the predefined dictionary works correctly.
Am I missing something here or is there a better way to implement dynamically generated dictionaries in pyspark?
Had a rookie mistake in my code, in the if-then-else block, I had two separate statements for column_type == "date"
Related
I have a Spark streaming job that listens to a Kinesis stream and writes the data to a Hudi table. For example, say I added these two records to the Hudi table:
| user_id | name | timestamp
| -------- | -------------- |----------
| 1 | name1 |1.1.1
| 2 | name2 |1.1.1
this should be reflected in the hudi table initially, but what if I edited in the second record to have the name name2_new
what I expect is that I will have three records each with their timestamp like this:
| user_id | name | timestamp
| -------- | -------------- |----------
| 1 | name1 |1.1.1
| 2 | name2 |1.1.1
| 2 | name2_new |2.2.2
when I query this in athena I get as expected the three records, but what if I want extra analysis in which I want the last version of the records which must be something like this (because I updated in the record 2):
| user_id | name | timestamp
| -------- | -------------- |----------
| 1 | name1 |1.1.1
| 2 | name2_new |2.2.2
is there any idea how I can get the latest version only?
here's the glue job I used to create the hudi table:
import sys
from awsglue.transforms import *
from awsglue.utils import getResolvedOptions
from pyspark.sql.session import SparkSession
from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.job import Job
from pyspark.sql import DataFrame, Row
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.functions import col, to_timestamp, monotonically_increasing_id, to_date, when
import datetime
from awsglue import DynamicFrame
import boto3
## @params: [JOB_NAME]
# Resolve the named job arguments from the command line.
args = getResolvedOptions(sys.argv,
                          ["JOB_NAME", "database_name", "kinesis_table_name", "starting_position_of_kinesis_iterator",
                           "hudi_table_name", "window_size", "s3_path_hudi", "s3_path_spark"])
# Spark session configured with the Kryo serializer and with
# spark.sql.hive.convertMetastoreParquet disabled.
spark = SparkSession.builder.config('spark.serializer', 'org.apache.spark.serializer.KryoSerializer').config(
    'spark.sql.hive.convertMetastoreParquet', 'false').getOrCreate()
sc = spark.sparkContext
glueContext = GlueContext(sc)
job = Job(glueContext)
job.init(args['JOB_NAME'], args)
# Copy the arguments used below into module-level names.
database_name = args["database_name"]
kinesis_table_name = args["kinesis_table_name"]
hudi_table_name = args["hudi_table_name"]
s3_path_hudi = args["s3_path_hudi"]
s3_path_spark = args["s3_path_spark"]
# Echo the resolved configuration to the job log.
print("***********")
print(f"""
database_name {database_name}
kinesis_table_name = {kinesis_table_name}
hudi_table_name ={hudi_table_name}
s3_path_hudi = {s3_path_hudi}
s3_path_spark = {s3_path_spark}
""")
# can be set to "latest", "trim_horizon" or "earliest"
starting_position_of_kinesis_iterator = args["starting_position_of_kinesis_iterator"]
# The amount of time to spend processing each batch
window_size = args["window_size"]
# Data frame over the catalog table named by kinesis_table_name, with schema
# inference enabled and the configured starting position.
data_frame_DataSource0 = glueContext.create_data_frame.from_catalog(
    database=database_name,
    table_name=kinesis_table_name,
    transformation_ctx="DataSource0",
    additional_options={"inferSchema": "true", "startingPosition": starting_position_of_kinesis_iterator}
)
# config
# Target location of the Hudi table.
commonConfig = dict(path=s3_path_hudi)

# Hudi writer options: upsert into a MERGE_ON_READ table keyed on
# (user_id, timestamp), with `timestamp` as the precombine field.
hudiWriteConfig = {
    'className': 'org.apache.hudi',
    'hoodie.table.name': hudi_table_name,
    'hoodie.datasource.write.operation': 'upsert',
    'hoodie.datasource.write.table.type': 'MERGE_ON_READ',
    'hoodie.datasource.write.precombine.field': 'timestamp',
    'hoodie.datasource.write.recordkey.field': 'user_id,timestamp',
    # Partitioning options, currently disabled:
    # 'hoodie.datasource.write.partitionpath.field': 'year:SIMPLE,month:SIMPLE,day:SIMPLE',
    # 'hoodie.datasource.write.keygenerator.class': 'org.apache.hudi.keygen.CustomKeyGenerator',
    # 'hoodie.deltastreamer.keygen.timebased.timestamp.type': 'DATE_STRING',
    # 'hoodie.deltastreamer.keygen.timebased.input.dateformat': 'yyyy-mm-dd',
    # 'hoodie.deltastreamer.keygen.timebased.output.dateformat': 'yyyy/MM/dd'
}

# Hive sync options for the same database and table.
hudiGlueConfig = {
    'hoodie.datasource.hive_sync.enable': 'true',
    'hoodie.datasource.hive_sync.sync_as_datasource': 'false',
    'hoodie.datasource.hive_sync.database': database_name,
    'hoodie.datasource.hive_sync.table': hudi_table_name,
    'hoodie.datasource.hive_sync.use_jdbc': 'false',
    # Partition sync options, currently disabled:
    # 'hoodie.datasource.write.hive_style_partitioning': 'false',
    # 'hoodie.datasource.hive_sync.partition_extractor_class': 'org.apache.hudi.hive.MultiPartKeysValueExtractor',
    # 'hoodie.datasource.hive_sync.partition_fields': 'year,month,day'
}

# Merge the three sections in order; later sections win on duplicate keys.
combinedConf = {}
for _section in (commonConfig, hudiWriteConfig, hudiGlueConfig):
    combinedConf.update(_section)
# ensure the incomong record has the correct current schema, new fresh columns are fine, if a column exists in current schema but not in incoming record then manually add before inserting
def evolveSchema(kinesis_df, table, forcecast=False):
    """Align an incoming batch with the schema of the existing table.

    Reads the schema of ``table`` with a zero-row query, drops Hudi system
    columns from it, and, if it differs from the schema of ``kinesis_df``,
    returns the union-by-name of both frames with missing columns allowed.

    Args:
        kinesis_df: incoming batch as a Spark DataFrame.
        table: table name passed to ``SELECT * FROM <table>``.
        forcecast: accepted for compatibility; not used by this function.

    Returns:
        The merged DataFrame when the schemas differ. Otherwise, or when any
        step raises, ``kinesis_df`` unchanged.
    """
    try:
        print("in evolve schema")
        kinesis_df.show(truncate=False)
        # LIMIT 0 fetches the existing table's schema without any rows.
        glue_catalog_df = spark.sql("SELECT * FROM " + table + " LIMIT 0")
        # Drop Hudi system columns; '_hoodie_partition_path' is deliberately
        # left out of this list.
        columns_to_drop = ['_hoodie_commit_time', '_hoodie_commit_seqno', '_hoodie_record_key',
                           '_hoodie_file_name']
        glue_catalog_df_sanitized = glue_catalog_df.drop(*columns_to_drop)
        if kinesis_df.schema != glue_catalog_df_sanitized.schema:
            return kinesis_df.unionByName(glue_catalog_df_sanitized, allowMissingColumns=True)
        # Schemas already match: return the batch explicitly instead of
        # falling off the end of the function.
        return kinesis_df
    except Exception as e:
        # Best effort: if the table cannot be read, write the batch as-is.
        print(e)
        return kinesis_df
def processBatch(data_frame, batchId):
    """Parse one micro-batch of JSON records and write it with the Hudi config.

    The column "$json$data_infer_schema$_temporary$" is parsed as JSON into
    `data` and `metadata` strings, which are parsed again with their own
    schemas. The output keeps every `data` field plus `metadata.timestamp`.
    Non-empty batches are passed through evolveSchema and written using
    combinedConf. `batchId` is not used in the body.
    """
    print("data frame is")
    data_frame.show(truncate=False)
    # Outer envelope: two JSON documents carried as strings.
    schema_main = StructType(
        [
            StructField('data', StringType(), True),
            StructField('metadata', StringType(), True)
        ]
    )
    # Payload fields.
    schema_data = StructType(
        [
            StructField("user_id",IntegerType(),True),
            StructField("firstname",StringType(),True),
            StructField("lastname",StringType(),True),
            StructField("address", StringType(), True),
            StructField("email", StringType(), True)
        ]
    )
    # Metadata fields; only `timestamp` is selected below.
    schema_metadata = StructType(
        [
            StructField("timestamp", StringType(), True),
            StructField("record-type", StringType(), True),
            StructField("operation", StringType(), True),
            StructField("partition-key-type", StringType(), True),
            StructField("schema-name", StringType(), True),
            StructField("table-name" , StringType(), True),
            StructField("transaction-id", StringType(), True)
        ]
    )
    # Parse the envelope, then both nested documents, and flatten the result.
    data_frame = data_frame.withColumn("$json$data_infer_schema$_temporary$", from_json("$json$data_infer_schema$_temporary$", schema_main)).select(col("$json$data_infer_schema$_temporary$.*")).\
        withColumn("data",from_json("data", schema_data)).\
        withColumn("metadata",from_json("metadata", schema_metadata)).select(col("data.*"),col("metadata.timestamp"))
    print("data frame is")
    data_frame.show(truncate=False)
    #column_headers = list(data_frame.columns)
    #print("The Column Header :", column_headers)
    #data_frame.printSchema()
    # Skip empty batches.
    if (data_frame.count() > 0):
        kinesis_dynamic_frame = DynamicFrame.fromDF(data_frame, glueContext, "from_kinesis_data_frame")
        #print("dynamoic frame is")
        #kinesis_dynamic_frame.show(truncate=False)
        #print('d')
        kinesis_data_frame = kinesis_dynamic_frame.toDF()
        print("kinesis is")
        kinesis_data_frame.show(truncate=False)
        # Align the batch with the existing table's schema before writing.
        kinesis_data_frame = evolveSchema(kinesis_data_frame, database_name + '.' + hudi_table_name, False)
        glueContext.write_dynamic_frame.from_options(
            frame=DynamicFrame.fromDF(kinesis_data_frame, glueContext, "kinesis_data_frame"),
            connection_type="custom.spark",
            connection_options=combinedConf
        )
# Hand every micro-batch of the source frame to processBatch, using the
# configured window size and checkpoint location.
glueContext.forEachBatch(
    frame=data_frame_DataSource0,
    batch_function=processBatch,
    options={
        "windowSize": window_size,
        "checkpointLocation": s3_path_spark
    }
)
job.commit()
Usually in this kind of application, we upsert the table and we choose the record with the biggest timestamp when the keys match between the new coming records and the existing records. But since you need the history of all the updates, you have two options to achieve what you are looking for:
Create a new table containing the last version of each record by duplicating your job sinks (kinesis -> full_history_table, kinesis -> last_state_table), or create a stream (mini batches) on the first table using incremental queries (kinesis -> full_history_table -> last_state_table), in this case you will have the result for your two queries without the need to aggregate the data
Aggregate the table on each query: if your query is not frequent, you can aggregate the data on the fly using a window function. You can create a view and apply your query on it:
-- Return one row per user_id: the row with the greatest timestamp.
SELECT user_id, name, timestamp
FROM (
    SELECT t.*,
           -- rn numbers the rows of each user_id from newest (1) to oldest
           ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY timestamp DESC) AS rn
    FROM TABLE_NAME t
) t
WHERE rn=1;
I am trying to use a customized accumulator within Palantir Foundry to aggregate Data within
a user defined function which is applied to each row of a dataframe within a statement df.withColumn(...).
From the resulting dataframe, I see, that the incrementation of the accumulator-value happens as expected. However, the value of the accumulator variable itself in the script does not change during the execution.
I see, that the Python-ID of the accumulator variable in the script differs from the Python-ID of the accumulator within the user defined function. But that might be expected...
How do I access, from the calling script after execution, the accumulator value whose incrementation can be observed in the resulting dataframe column? This is the information I am looking for.
from transforms.api import transform_df, Input, Output
import numpy as np
from pyspark.accumulators import AccumulatorParam
from pyspark.sql.functions import udf, struct
global accum


@transform_df(
    Output("ri.foundry.main.dataset.xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"),
)
def compute(ctx):
    """Build a small demo DataFrame and count UDF invocations in an accumulator.

    A NumPy-array accumulator is created on the Spark context and incremented
    by a UDF applied to every row. ``new.show(2)`` is the only action in this
    function that evaluates the UDF. The returned frame is ``df``, which does
    not contain the ``processed`` column.
    """
    from pyspark.sql.types import StructType, StringType, IntegerType, StructField

    data2 = [
        ("James", "", "Smith", "36636", "M", 3000),
        ("Michael", "Rose", "", "40288", "M", 4000),
        ("Robert", "", "Williams", "42114", "M", 4000),
        ("Maria", "Anne", "Jones", "39192", "F", 4000),
        ("Jen", "Mary", "Brown", "", "F", -1),
    ]
    schema = StructType([
        StructField("firstname", StringType(), True),
        StructField("middlename", StringType(), True),
        StructField("lastname", StringType(), True),
        StructField("id", StringType(), True),
        StructField("gender", StringType(), True),
        StructField("salary", IntegerType(), True),
    ])
    df = ctx.spark_session.createDataFrame(data=data2, schema=schema)

    ####################################
    class AccumulatorNumpyArray(AccumulatorParam):
        """Accumulator parameter that sums NumPy arrays element-wise."""

        def zero(self, zero: np.ndarray):
            # Return an all-zero array with the same shape and dtype as the
            # initial value, not the initial value itself, so a non-zero
            # start value is not counted again.
            return np.zeros_like(zero)

        def addInPlace(self, v1, v2):
            return v1 + v2

    # from pyspark.context import SparkContext
    # sc = SparkContext.getOrCreate()
    sc = ctx.spark_session.sparkContext
    shape = 3
    global accum
    accum = sc.accumulator(
        np.zeros(shape, dtype=np.int64),
        AccumulatorNumpyArray(),
    )

    def func(row):
        """Increment the accumulator and return its string form plus its id()."""
        global accum
        accum += np.ones(shape)
        return str(accum) + '_' + str(id(accum))

    user_defined_function = udf(func, StringType())
    new = df.withColumn("processed", user_defined_function(struct([df[col] for col in df.columns])))
    new.show(2)
    print(accum)
    return df
results in
+---------+----------+--------+-----+------+------+--------------------+
|firstname|middlename|lastname| id|gender|salary| processed|
+---------+----------+--------+-----+------+------+--------------------+
| James| | Smith|36636| M| 3000|[1. 1. 1.]_140388...|
| Michael| Rose| |40288| M| 4000|[2. 2. 2.]_140388...|
+---------+----------+--------+-----+------+------+--------------------+
only showing top 2 rows
and
> accum
Accumulator<id=0, value=[0 0 0]>
> id(accum)
140574405092256
If the Foundry-Boiler-Plate is removed, resulting in
import numpy as np
from pyspark.accumulators import AccumulatorParam
from pyspark.sql.functions import udf, struct
from pyspark.sql.types import StructType, StringType, IntegerType, StructField
from pyspark.sql import SparkSession
from pyspark.context import SparkContext
# Stand-alone version of the accumulator experiment (no Foundry wrapper).
spark = (
    SparkSession.builder.appName("Python Spark SQL basic example")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)
# ctx = spark.sparkContext.getOrCreate()
data2 = [
    ("James", "", "Smith", "36636", "M", 3000),
    ("Michael", "Rose", "", "40288", "M", 4000),
    ("Robert", "", "Williams", "42114", "M", 4000),
    ("Maria", "Anne", "Jones", "39192", "F", 4000),
    ("Jen", "Mary", "Brown", "", "F", -1),
]
schema = StructType(
    [
        StructField("firstname", StringType(), True),
        StructField("middlename", StringType(), True),
        StructField("lastname", StringType(), True),
        StructField("id", StringType(), True),
        StructField("gender", StringType(), True),
        StructField("salary", IntegerType(), True),
    ]
)
# df = ctx.spark_session.createDataFrame(data=data2, schema=schema)
df = spark.createDataFrame(data=data2, schema=schema)


####################################
class AccumulatorNumpyArray(AccumulatorParam):
    """Accumulator parameter that sums NumPy arrays element-wise."""

    def zero(self, zero: np.ndarray):
        # Return an all-zero array with the same shape and dtype as the
        # initial value, not the initial value itself, so a non-zero start
        # value is not counted again.
        return np.zeros_like(zero)

    def addInPlace(self, v1, v2):
        return v1 + v2


sc = SparkContext.getOrCreate()
shape = 3
# `accum` is a module-level name here, so no `global` statement is needed.
accum = sc.accumulator(
    np.zeros(shape, dtype=np.int64),
    AccumulatorNumpyArray(),
)


def func(row):
    """Increment the accumulator and return its string form plus its id()."""
    global accum
    accum += np.ones(shape)
    return str(accum) + "_" + str(id(accum))


user_defined_function = udf(func, StringType())
new = df.withColumn(
    "processed", user_defined_function(struct([df[col] for col in df.columns]))
)
# show() is the action that evaluates the UDF; the prints below run after it.
new.show(2, False)
print(id(accum))
print(accum)
the output obtained within a regular Python environment with pyspark version 3.3.1 on Ubuntu meets the expectations and is
+---------+----------+--------+-----+------+------+--------------------------+
|firstname|middlename|lastname|id |gender|salary|processed |
+---------+----------+--------+-----+------+------+--------------------------+
|James | |Smith |36636|M |3000 |[1. 1. 1.]_139642682452576|
|Michael |Rose | |40288|M |4000 |[1. 1. 1.]_139642682450224|
+---------+----------+--------+-----+------+------+--------------------------+
only showing top 2 rows
140166944013424
[3. 3. 3.]
The code that runs outside of the transform is run in a different environment than the code within your transform. When you commit, your checks run; they execute the code outside the transform to generate the jobspec, which is technically your executable transform. You can find these within the "details" of your dataset after the checks pass.
The logic within your transform is then detached and runs in isolation each time you hit build. The global accum you define outside the transform is never run and doesn't exist when the code inside compute is running.
global accum <-- runs in checks
@transform_df(
Output("ri.foundry.main.dataset.c0d4fc0c-bb1d-4c7b-86ce-a13ec6666490"),
)
def compute(ctx):
bla bla some logic <-- runs during build
The prints in your second code example happen after the df is processed, because new.show(2, False) asks Spark to compute it. The print in your first example happens before the df is processed, since the computation only happens after your return df.
If you want to print after your df is computed, you can use @transform(... instead of @transform_df(... and print after writing the dataframe contents. It should be something like this:
@transform(
output=Output("ri.foundry.main.dataset.c0d4fc0c-bb1d-4c7b-86ce-a13ec6666490"),
)
def compute(ctx, output):
df = ... some logic ...
output.write_dataframe(df) # please check the function name; I think it was write_dataframe, but I may be wrong
print(accum)
I have list of string in python as follows :
['start_column=column123;to_3=2020-09-07 10:29:24;to_1=2020-09-07 10:31:08;to_0=2020-09-07 10:31:13;',
'start_column=column475;to_3=2020-09-07 10:29:34;']
I am trying to convert it into dataframe in following way :
# Schema with a single nullable field 'Rows' holding an array of strings.
schema = StructType([
    StructField('Rows', ArrayType(StringType()), True)
])
# NOTE(review): test_list is presumably the list of plain strings shown in the
# question. Each RDD element would then be a bare str rather than a row with
# one array field, which is consistent with the reported TypeError.
rdd = sc.parallelize(test_list)
query_data = spark.createDataFrame(rdd,schema)
print(query_data.schema)
query_data.show()
I am getting following error:
TypeError: StructType can not accept object
You just need to pass that as a list while creating the dataframe as below ...
a_list = [
    'start_column=column123;to_3=2020-09-07 10:29:24;to_1=2020-09-07 10:31:08;to_0=2020-09-07 10:31:13;',
    'start_column=column475;to_3=2020-09-07 10:29:34;',
]
# Wrapping the list makes it one row with two columns.
single_row = [a_list]
column_names = ["col1", "col2"]
sparkdf = spark.createDataFrame(single_row, column_names)
sparkdf.show(truncate=False)
+--------------------------------------------------------------------------------------------------+------------------------------------------------+
|col1 |col2 |
+--------------------------------------------------------------------------------------------------+------------------------------------------------+
|start_column=column123;to_3=2020-09-07 10:29:24;to_1=2020-09-07 10:31:08;to_0=2020-09-07 10:31:13;|start_column=column475;to_3=2020-09-07 10:29:34;|
+--------------------------------------------------------------------------------------------------+------------------------------------------------+
You should use schema = StringType() because your rows contains strings rather than structs of strings.
I have two possible solutions for you.
SOLUTION 1: Assuming you wanted a dataframe with just one row
I was able to make it work by wrapping the values in test_list in Parentheses and using StringType.
# One row holding both strings, one per column.
first_value = 'start_column=column123;to_3=2020-09-07 10:29:24;to_1=2020-09-07 10:31:08;to_0=2020-09-07 10:31:13;'
second_value = 'start_column=column475;to_3=2020-09-07 10:29:34;'
v = [(first_value, second_value)]
# Two nullable string columns, built incrementally.
schema = (
    StructType()
    .add('col_1', StringType(), True)
    .add('col_2', StringType(), True)
)
rdd = sc.parallelize(v)
query_data = spark.createDataFrame(rdd, schema)
print(query_data.schema)
query_data.show(truncate=False)
SOLUTION 2: Assuming you wanted a dataframe with just one column
from pyspark.sql.types import StringType

# One row per string, in a single string column.
v = [
    'start_column=column123;to_3=2020-09-07 10:29:24;to_1=2020-09-07 10:31:08;to_0=2020-09-07 10:31:13;',
    'start_column=column475;to_3=2020-09-07 10:29:34;',
]
df = spark.createDataFrame(v, schema=StringType())
df.show(truncate=False)
I have a Glue job (running on Spark) that simply converts a CSV file to Parquet. I don't have control over the CSV data, so I want to capture any inconsistency between the data and the table schema during the conversion to Parquet. For instance, if a column is defined as Integer, I want the job to raise an error if there is any string value in that column. Currently, DynamicFrame resolves this by producing a choice type (string and integer) in the resulting Parquet file, which is helpful for some use cases, but I'm wondering if there is any way to enforce the schema and have the Glue job throw an error if there is any inconsistency. Here is my code:
# Read the catalog table, convert it and write it out as Parquet.
# The keyword names are `database` and `table_name`; the garbled
# `databasem=mdbName, table_namem=mtable` did not match the reader's
# parameters or the dbName/table names used further down.
datasource0 = glueContext.create_dynamic_frame.from_catalog(
    database=dbName, table_name=table, transformation_ctx="datasource0"
)
df = datasource0.toDF()
# Limit the number of partitions to parquetFileCount.
df = df.coalesce(parquetFileCount)
df = convertColDataType(df, "timestamp", "timestamp", dbName, table)
applymapping1 = DynamicFrame.fromDF(df, glueContext, "finalDF")
datasink4 = glueContext.write_dynamic_frame.from_options(
    frame=applymapping1,
    connection_type="s3",
    connection_options={"path": path},
    format="parquet",
    transformation_ctx="datasink4",
)
job.commit()
You can solve this by using spark native lib instead of using glue lib
Instead of reading from catalog, read from the corresponding s3 path with custom schema and failfast mode
# Explicit schema for the CSV. The closing parenthesis of StructType(...) was
# missing, which made the snippet a syntax error.
schema = StructType([
    StructField('id', IntegerType(), True),
    StructField('name', StringType(), True),
])
# Read with mode=FAILFAST and the explicit schema instead of the catalog.
df = spark.read.option('mode', 'FAILFAST').csv(s3Path, schema=schema)
I had a similar data type issue that I was able to solve by importing another df that I knew was in the correct format. Then I looped over the columns of the two dfs and compared their data types. In this example I reformatted the data types where necessary:
# Compare column types of df1 and df2 and cast df2's columns to df1's type
# where they differ and the type is timestamp, double or int.
df1 = inputfile
df2 = target
if df1.schema != df2.schema:
    for column in df2.schema.names:
        wanted_type = get_dtype(df1, column)
        current_type = get_dtype(df2, column)
        if wanted_type == current_type:
            continue
        # Pick the Spark type to cast to; None means "leave the column alone".
        if wanted_type == 'timestamp':
            cast_type = TimestampType()
        elif wanted_type == 'double':
            cast_type = DoubleType()
        elif wanted_type == 'int':
            cast_type = IntegerType()
        else:
            cast_type = None
        if cast_type is None:
            prefix = 'not '
        else:
            prefix = ''
            df2 = df2.withColumn(column, df2[column].cast(cast_type))
        print(prefix + 'updating: ' + column + ' - from ' + current_type + ' to ' + wanted_type)
target = df2
(EDIT) changed field names (from foo, bar,... to name and city) because old naming was confusing
I need to use a single function in multiple UDFs and return different Structs depending on the input.
This simplified version of my implementation basically does what I am looking for:
from pyspark.sql.types import IntegerType, StructType, StringType
from pyspark.sql.functions import when, col
# Single-column frame with ids 1, 2, 3.
df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')
# Return schema for id == 1: one nullable string field 'name'.
struct_one = StructType().add('name', StringType(), True)
# Return schema for any other id: one nullable string field 'city'.
struct_not_one = StructType().add('city', StringType(), True)
def select(id):
    """Return the 'name' record for id 1 and the 'city' record for any other id."""
    return {'name': 'Alice'} if id == 1 else {'city': 'Seattle'}
# Register the same Python function twice, once per declared return schema.
one_udf = udf(select, struct_one)
not_one_udf = udf(select, struct_not_one)
# Each column is filled only where its `when` condition holds; elsewhere it is
# null (see the output below).
df = df.withColumn('one', when((col('id') == 1), one_udf(col('id'))))\
    .withColumn('not_one', when((col('id') != 1), not_one_udf(col('id'))))
display(df)
(EDIT) Output:
id one not_one
1 {"name":"Alice"} null
2 null {"city":"Seattle"}
3 null {"city":"Seattle"}
But the same code, when returning an ArrayType of StructType, unfortunately fails:
from pyspark.sql.types import IntegerType, StructType, StringType, ArrayType
from pyspark.sql.functions import when, col
# Single-column frame with ids 1, 2, 3.
df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')
# Return schema for id == 1: a struct with one nullable string field 'name'.
struct_one = StructType().add('name', StringType(), True)
# Return schema for any other id: an array of structs with a 'city' field.
struct_not_one = ArrayType(StructType().add('city', StringType(), True))
def select(id):
    """Return the 'name' record for id 1 and a list of 'city' records otherwise."""
    if id != 1:
        return [{'city': 'Seattle'}, {'city': 'Milan'}]
    return {'name': 'Alice'}
# Register the same Python function twice, once per declared return type.
one_udf = udf(select, struct_one)
not_one_udf = udf(select, struct_not_one)
# NOTE(review): per the Spark documentation quoted further down, `when` does
# not guarantee that the UDF is only invoked for matching rows. one_udf may
# therefore be called with id != 1 and get a list back where a struct is
# declared, which would be consistent with the reported ValueError.
df = df.withColumn('one', when((col('id') == 1), one_udf(col('id'))))\
    .withColumn('not_one', when((col('id') != 1), not_one_udf(col('id'))))
display(df)
The error message is:
ValueError: Unexpected tuple 'name' with StructType
(EDIT) Desired Output would be:
id one not_one
1 {"name":"Alice"} null
2 null [{"city":"Seattle"},{"city":"Milan"}]
3 null [{"city":"Seattle"},{"city":"Milan"}]
Returning an ArrayType of other types (StringType, IntegerType, ...), for example, works, though.
Also returning an Array of StructType when not using a single function in multiple UDFs works:
from pyspark.sql.types import IntegerType, StructType, StringType, ArrayType
from pyspark.sql.functions import when, col
# Single-column frame with ids 1, 2, 3.
df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')
# Return type: an array of structs with one nullable string field 'city'.
struct_not_one = ArrayType(StructType().add('city', StringType(), True))
def select(id):
    """Return the fixed list of 'city' records, regardless of id."""
    cities = ('Seattle', 'Milan')
    return [{'city': city} for city in cities]
# Only one UDF is registered here, and its function always returns a list.
not_one_udf = udf(select, struct_not_one)
# Fill 'not_one' where id != 1; other rows are null (see the output below).
df = df.withColumn('not_one', when((col('id') != 1), not_one_udf(col('id'))))
display(df)
(EDIT) Output:
id not_one
1 null
2 [{"city":"Seattle"},{"city":"Milan"}]
3 [{"city":"Seattle"},{"city":"Milan"}]
Why is returning an ArrayType of StructType and using multiple UDFs with one single function not working?
Thanks!
"Spark SQL (including SQL and the DataFrame and Dataset API) does not guarantee the order of evaluation of subexpressions...
Therefore, it is dangerous to rely on the side effects or order of evaluation of Boolean expressions, and the order of WHERE and HAVING clauses, since such expressions and clauses can be reordered during query optimization and planning. Specifically, if a UDF relies on short-circuiting semantics in SQL for null checking, there’s no guarantee that the null check will happen before invoking the UDF."
See Evaluation order and null checking
To keep your udf generic you could push the 'when filter' into your udf:
from pyspark.sql.types import IntegerType, StructType, StringType, ArrayType
from pyspark.sql.functions import when, col, lit
# Single-column frame with ids 1, 2, 3.
df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')
# Return schema for id == 1: a struct with one nullable string field 'name'.
struct_one = StructType().add('name', StringType(), True)
# Return schema for any other id: an array of structs with a 'city' field.
struct_not_one = ArrayType(StructType().add('city', StringType(), True))
def select(id, test):
    """Return the record for *id*, or None when the *test* expression is False.

    *test* is a format string with one placeholder (e.g. '{} == 1'). The id is
    substituted into it and the result is evaluated as a Python expression.
    """
    # NOTE(security): eval() executes arbitrary Python. It is only acceptable
    # while `test` comes from literals in this code, never from external input.
    outcome = eval(test.format(id))
    if outcome is False:
        return None
    return {'name': 'Alice'} if id == 1 else [{'city': 'Seattle'}, {'city': 'Milan'}]
# Register the same Python function twice, once per declared return type.
one_udf = udf(select, struct_one)
not_one_udf = udf(select, struct_not_one)
# The filter is passed into the UDF as a literal expression string instead of
# wrapping the call in `when`.
df = df.withColumn('one', one_udf(col('id'), lit('{} == 1')))\
    .withColumn('not_one', not_one_udf(col('id'), lit('{} != 1')))
display(df)