Single Task Taking Long Time in PySpark - apache-spark

I am running a PySpark application that reads several Parquet files into Spark dataframes and creates temporary views on them for use in my SQL query. I have about 18 views: some are ~1 TB, a few are several GB, and the rest are smaller. I join all of them and apply my business logic to get the desired output. The job takes an extremely long time to run (>3 hours) on this data. In the Spark History Server I can see one task that seems to be the culprit: its duration, spill (memory and disk) and shuffle read/write are all far higher than the median. This indicates data skew, so I applied salting to my large dataframes before creating the temp views. However, the execution time did not change at all. I checked the number of partitions and it is already 792 (the maximum I can have with my current Glue config). I have also enabled adaptive query execution and adaptive skew-join handling.
My original dataset is extremely large: the largest table is ~40 TB and holds 2.5 years of data. I am trying to do a one-time historical load and could not get it to run over the entire dataset. Through trial and error I had to reduce this to processing 1 TB of data at a time (for the largest table), which still takes 3+ hours. This approach does not scale, so I am looking for input on how to optimize it.
Below are my app details:
Number of workers = 792
Spark config:
from pyspark.sql import SparkSession

spark = (SparkSession
    .builder
    .appName("scmCaseAlertDatamartFullLoad")
    .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.broadcastTimeout", "900")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")
    .getOrCreate()
)
Code (only the key methods are included; the entry point is loadSCMCaseAlertTable()):
import sys

from pyspark.sql.functions import col, rand, to_date
from pyspark.sql.types import LongType, StringType


def getIncomingMatchesFullData(self):
    select_query_incoming_matches_full_data = """
    SELECT DISTINCT alrt.caseid AS case_id,
        alrt.alertid AS alert_id,
        alrt.accountid AS account_id,
        sc.created_time AS case_created_time,
        sc.last_updated_time AS case_last_updated_time,
        alrt.srccreatedtime AS alert_created_time,
        aud.last_updated_by AS case_last_updated_by,
        sc.closed_time AS case_last_close_time,
        lcs.status AS case_status,
        lcst.state AS case_state,
        lcra.responsive_action,
        sc.assigned_to AS case_assigned_to,
        cr1.team_name AS case_assigned_to_team,
        sc.resolved_by AS case_resolved_by,
        cr2.team_name AS case_resolved_by_team,
        aud.last_annotation AS case_last_annotation,
        ca.name AS case_approver,
        alrt.screeningdecision AS screening_decision,
        ap.accountpool AS division,
        lcd.decision AS case_current_decision,
        CASE
            WHEN sm.grylaclientid LIKE '%AddressService%' THEN 'Address Service'
            WHEN sm.grylaclientid LIKE '%GrylaOrderProcessingService%' THEN 'Retail Checkout Service'
            WHEN sm.grylaclientid = 'urn:cdo:GrylaBatchScreeningAAA:AWS:Default' THEN 'Batch Screening'
            WHEN sm.grylaclientid = 'urn:cdo:OfficerJennyBindle:AWS:Default' THEN 'API'
            ELSE 'Other'
        END AS channel,
        ap.businesstype AS business_type,
        ap.businessname AS business_name,
        ap.marketplaceid AS ap_marketplace_id,
        ap.region AS ap_region,
        ap.memberid AS ap_member_id,
        ap.secondaryaccountpool AS secondary_account_pool,
        sm.action AS client_action,
        acl.added_by,
        acl.lnb_id AS accept_list_lnb_id,
        acl.created_time AS accept_list_created_time,
        acl.source_case_id AS accept_list_source_case_id,
        acs.status AS accept_list_status,
        ap.street1 AS ap_line_1,
        ap.street2 AS ap_line_2,
        ap.street3 AS ap_line_3,
        ap.city AS ap_city,
        ap.state AS ap_state,
        ap.postalcode AS ap_postal_code,
        ap.country AS ap_country,
        ap.fullname AS ap_full_name,
        ap.email AS ap_email,
        sm.screening_match_id AS dp_screening_match_id,
        CASE
            WHEN sm.matchtype = 'name_only_matching_details' THEN 'Name Only'
            WHEN sm.matchtype = 'address_only_matching_details' THEN 'Address Only'
            WHEN sm.matchtype = 'address_matching_details' THEN 'Address'
            WHEN sm.matchtype = 'scr_matching_details' THEN 'SCR'
            WHEN sm.matchtype = 'hotkey_matching_details' THEN 'HotKey'
        END AS match_type,
        sm.matchaction AS match_action,
        alrt.batchfilename AS batch_file_id,
        REGEXP_REPLACE(dp.name, '\\n|\\r|\\t', ' ') AS dp_matched_add_full_name,
        dp.street AS dp_line1,
        '' AS dp_line2,
        dp.city AS dp_city,
        dp.state AS dp_state,
        dp.postalcode AS dp_postal_code,
        dp.country AS dp_country,
        dp.matchedplaces AS scr_value,
        dp.hotkeyvalues AS hotkey_value,
        sm.acceptlistid AS suppressed_by_accept_list_id,
        sm.suppresseddedupe AS is_deduped,
        sm.matchhash AS hash,
        sm.matchdecision AS match_decision,
        ap.addressid AS amazon_address_id,
        ap.dateofbirth AS date_of_birth,
        sm.grylaclientid AS gryla_client_id,
        cr1.name AS case_assigned_to_role,
        cr2.name AS case_resolved_by_role,
        alrt.screeningengine AS screening_engine,
        sm.srccreatedtime AS match_created_time,
        sm.srclastupdatedtime AS match_updated_time,
        to_date(sm.srclastupdatedtime,"yyyy-MM-dd") AS match_updated_date,
        sm.match_updated_time_msec,
        sm.suppressedby AS match_suppressed_by
    FROM
        cm_screening_match sm
    JOIN
        cm_screening_match_redshift smr ON sm.screening_match_id = smr.screening_match_id
    LEFT JOIN
        cm_case_alert alrt ON sm.screening_match_id = alrt.screening_match_id
    LEFT JOIN
        cm_amazon_party ap ON sm.screening_match_id = ap.screening_match_id
    LEFT JOIN
        cm_denied_party dp ON sm.screening_match_id = dp.screening_match_id
    LEFT JOIN
        cm_spectre_case sc ON alrt.caseid = sc.case_id
    LEFT JOIN
        cm_lookup_case_status lcs ON sc.status_id = lcs.status_id
    LEFT JOIN
        cm_lookup_case_state lcst ON sc.state_id = lcst.state_id
    LEFT JOIN
        cm_lookup_case_decision lcd ON sc.decision_id = lcd.decision_id
    LEFT JOIN
        cm_lookup_case_responsive_action lcra ON sc.responsive_action_id = lcra.responsive_action_id
    LEFT JOIN
        cm_user cu1 ON sc.assigned_to = cu1.alias
    LEFT JOIN
        cm_role cr1 ON cu1.current_role_id = cr1.role_id
    LEFT JOIN
        cm_user cu2 ON sc.resolved_by = cu2.alias
    LEFT JOIN
        cm_role cr2 ON cu2.current_role_id = cr2.role_id
    LEFT JOIN
        cm_accept_list acl ON acl.screening_match_id = sm.screening_match_id
    LEFT JOIN
        cm_lookup_accept_list_status acs ON acs.status_id = acl.status_id
    LEFT JOIN
        (
            SELECT case_id,
                last_value(username) OVER (PARTITION BY case_id ORDER BY created_time
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS last_updated_by,
                last_value(description) OVER (PARTITION BY case_id ORDER BY created_time
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS last_annotation
            FROM cm_spectre_case_audit
        ) aud ON sc.case_id = aud.case_id
    LEFT JOIN
        cm_approver ca ON sc.approver_id = ca.approver_id
    """
    print(select_query_incoming_matches_full_data)
    incomingMatchesFullDF = self.spark.sql(select_query_incoming_matches_full_data)
    return incomingMatchesFullDF

def getBaseTables(self,matchtime_lower_threshold,matchtime_upper_threshold,cursor):
    print('Fetching datalake data for matches created after: {}' .format(matchtime_lower_threshold))
    matchDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_match)
    matchDF = matchDF.select("screening_match_id","grylaclientid","action","matchtype","matchaction","acceptlistid","suppresseddedupe","matchhash","matchdecision","srccreatedtime","srclastupdatedtime","suppressedby","lastupdatedtime")
    #.withColumn("screentime",to_timestamp("screentime")) \
    matchDF = matchDF.withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())).drop("lastupdatedtime")
    #matchDF = matchDF.repartition(2400,"screening_match_id")
    matchDF = self.getLatestRecord(matchDF)
    matchDF = matchDF.withColumn("salt", rand())
    matchDF = matchDF.repartition("salt")
    matchDF.createOrReplaceTempView("cm_screening_match")
    print("Total from matchDF:",matchDF.count())
    print("Number of partitions in matchDF: " ,matchDF.rdd.getNumPartitions())
    alertDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_alert)
    alertDF = alertDF.select("screening_match_id","caseid","alertid","accountid","srccreatedtime","screeningdecision","batchfilename","screeningengine","lastupdatedtime")
    alertDF = alertDF.withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())).drop("lastupdatedtime")
    #alertDF = alertDF.repartition(2400,"screening_match_id")
    alertDF = self.getLatestRecord(alertDF)
    alertDF = alertDF.withColumn("salt", rand())
    alertDF = alertDF.repartition("salt")
    alertDF.createOrReplaceTempView("cm_case_alert")
    print("Total from alertDF:",alertDF.count())
    print("Number of partitions in alertDF: " ,alertDF.rdd.getNumPartitions())
    apDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_ap)
    apDF = apDF.select("screening_match_id","accountpool","businesstype","businessname","marketplaceid","region","memberid","secondaryaccountpool","street1","street2","street3","city","state","postalcode","country","fullname","email","addressid","dateofbirth","lastupdatedtime")
    apDF = apDF.withColumn("dateofbirth",to_date("dateofbirth","yyyy-MM-dd")) \
        .withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())) \
        .drop("lastupdatedtime")
    #apDF = apDF.repartition(2400,"screening_match_id")
    apDF = self.getLatestRecord(apDF)
    apDF = apDF.withColumn("salt", rand())
    apDF = apDF.repartition("salt")
    apDF.createOrReplaceTempView("cm_amazon_party")
    print("Total from apDF:",apDF.count())
    print("Number of partitions in apDF: " ,apDF.rdd.getNumPartitions())
    dpDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_dp)
    dpDF = dpDF.select("screening_match_id","name","street","city","state","postalcode","country","matchedplaces","hotkeyvalues","lastupdatedtime")
    dpDF = dpDF.withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())).drop("lastupdatedtime")
    #dpDF = dpDF.repartition(2400,"screening_match_id")
    dpDF = self.getLatestRecord(dpDF)
    dpDF = dpDF.withColumn("salt", rand())
    dpDF = dpDF.repartition("salt")
    dpDF.createOrReplaceTempView("cm_denied_party")
    print("Total from dpDF:",dpDF.count())
    print("Number of partitions in dpDF: " ,dpDF.rdd.getNumPartitions())
    print('Fetching data from Redshift Base tables...')
    self.getRedshiftData(matchtime_lower_threshold,matchtime_upper_threshold,cursor)
    caseAuditDF = self.spark.read.parquet(self.data_input_case_audit)
    caseAuditDF.createOrReplaceTempView("cm_spectre_case_audit")
    caseDF = self.spark.read.parquet(self.data_input_case)
    caseDF.createOrReplaceTempView("cm_spectre_case")
    caseStatusDF = self.spark.read.parquet(self.data_input_case_status)
    caseStatusDF.createOrReplaceTempView("cm_lookup_case_status")
    caseStateDF = self.spark.read.parquet(self.data_input_case_state)
    caseStateDF.createOrReplaceTempView("cm_lookup_case_state")
    caseDecisionDF = self.spark.read.parquet(self.data_input_case_decision)
    caseDecisionDF.createOrReplaceTempView("cm_lookup_case_decision")
    caseRespActDF = self.spark.read.parquet(self.data_input_case_responsive_action)
    caseRespActDF.createOrReplaceTempView("cm_lookup_case_responsive_action")
    userDF = self.spark.read.parquet(self.data_input_user)
    userDF.createOrReplaceTempView("cm_user")
    userSnapshotDF = self.spark.read.parquet(self.data_input_user_snapshot)
    userSnapshotDF.createOrReplaceTempView("v_cm_user_snapshot")
    roleDF = self.spark.read.parquet(self.data_input_role)
    roleDF.createOrReplaceTempView("cm_role")
    skillDF = self.spark.read.parquet(self.data_input_skill)
    skillDF.createOrReplaceTempView("cm_skill")
    lookupSkillDF = self.spark.read.parquet(self.data_input_lookup_skills)
    lookupSkillDF.createOrReplaceTempView("cm_lookup_skills")
    skillTypeDF = self.spark.read.parquet(self.data_input_skill_type)
    skillTypeDF.createOrReplaceTempView("cm_skill_type")
    acceptListDF = self.spark.read.parquet(self.data_input_accept_list)
    acceptListDF.createOrReplaceTempView("cm_accept_list")
    lookupAcceptListStatusDF = self.spark.read.parquet(self.data_input_lookup_accept_list_status)
    lookupAcceptListStatusDF.createOrReplaceTempView("cm_lookup_accept_list_status")
    approverDF = self.spark.read.parquet(self.data_input_approver)
    approverDF.createOrReplaceTempView("cm_approver")
    screeningMatchDF_temp = self.spark.read.parquet(self.data_input_screening_match_redshift)
    screeningMatchLookupDF_temp = self.spark.read.parquet(self.data_input_lookup_screening_match_redshift)
    screeningMatchLookupDF_temp_new = screeningMatchLookupDF_temp.withColumnRenamed("screening_match_id","lookupdf_screening_match_id")
    """
    The screening_match_id in datalake table is a mix of alphanumeric match IDs (the ones in cm_lookup_screening_match_id in Redshift) and numeric (the ones in cm_screening_match in Redshift). Hence we combine the match IDs from both the Redshift tables. Also, there are matches which were created in the past but updated recently. Since updated date is only present in cm_screening_match and not in cm_lookup_screening_match_id, we will only have the numeric match Ids. When we join this to datalake table, we won't be able to find these matches as they are present in the alphanumeric form in datalake. Hence what we do is read the entire table of cm_lookup_screening_match_id and join it with cm_screening_match to enrich cm_screening_match with the alphanumeric match Id. Finally we filter cm_lookup_screening_match_id only for newly created matches and combine with the matches from enriched version of cm_screening_match.
    """
    screeningMatchDF_enriched = screeningMatchDF_temp.join(screeningMatchLookupDF_temp_new,screeningMatchDF_temp.screening_match_id == screeningMatchLookupDF_temp_new.lookupdf_screening_match_id,"left")
    screeningMatchDF_enriched = screeningMatchDF_enriched.withColumn("screening_match_id",col("screening_match_id").cast(StringType()))
    screeningMatchDF = screeningMatchDF_enriched.select(col("screening_match_id")).union(screeningMatchDF_enriched.select(col("match_event_id")))
    screeningMatchLookupDF = screeningMatchLookupDF_temp_new.filter("created_time > '{}'" .format(matchtime_lower_threshold)).select(col("match_event_id"))
    screeningMatchRedshiftDF = screeningMatchDF.union(screeningMatchLookupDF)
    #screeningMatchRedshiftDF = screeningMatchRedshiftDF.repartition(792,"screening_match_id")
    screeningMatchRedshiftDF = screeningMatchRedshiftDF.withColumn("salt", rand())
    screeningMatchRedshiftDF = screeningMatchRedshiftDF.repartition("salt")
    screeningMatchRedshiftDF.createOrReplaceTempView("cm_screening_match_redshift")
    print("Total from screeningMatchRedshiftDF:",screeningMatchRedshiftDF.count())

def loadSCMCaseAlertTable(self):
    print('Getting the thresholds for data to be loaded')
    matchtime_lower_threshold = self.getLowerThreshold('scm_case_alert_data')
    print('Match time lower threshold is: {}' .format(matchtime_lower_threshold))
    matchtime_upper_threshold = self.default_upper_threshold
    print('Match time upper threshold is: {}' .format(matchtime_upper_threshold))
    print("Getting the required base tables")
    con = self.get_redshift_connection()
    cursor = con.cursor()
    self.getBaseTables(matchtime_lower_threshold,matchtime_upper_threshold,cursor)
    print("Getting the enriched dataset for incoming matches (the ones to be inserted or updated)")
    incomingMatchesFullDF = self.getIncomingMatchesFullData()
    print("Total records in incomingMatchesFullDF: ", incomingMatchesFullDF.count())
    print("Copying the incoming data to temp work dir")
    print("Clearing work directory: {}" .format(self.work_scad_path))
    self.deleteAllObjectsFromS3Prefix(self.dest_bucket,self.dest_work_prefix_scad)
    print("Writing data to work dir: {}" .format(self.work_scad_path))
    #.coalesce(1) \
    incomingMatchesFullDF.write \
        .partitionBy("match_updated_date") \
        .mode("overwrite") \
        .parquet(self.work_scad_path + self.work_dir_partitioned_table_scad)
    print("Data copied to work dir")
    print("Reading data from work dir in a temporary dataframe")
    incomingMatchesFullDF_copy = self.spark.read.parquet(self.work_scad_path + "scm_case_alert_data_work.parquet/")
    if self.update_mode == 'overwrite':
        print("Datamart update mode is overwrite. New data will replace existing data.")
        print("Publishing to Redshift")
        self.publishToRedshift(con,cursor)
        print("Publishing to Redshift complete")
    elif self.update_mode == 'upsert':
        print("Datamart update mode is upsert. New data will be loaded and existing data will be updated.")
        print("Checking for cases updated between {} and {}" .format(matchtime_lower_threshold,matchtime_upper_threshold))
        updatedCasesDF = self.getUpdatedCases(matchtime_lower_threshold,matchtime_upper_threshold)
        updatedCasesDF.createOrReplaceTempView("updated_cases")
        print("Getting updated case attributes")
        updatedCaseAttributesDF = self.getUpdatedCaseAttributes()
        print("Moving updated case data to temp work directory: {}".format(self.work_updated_cases_path))
        print("Clearing work directory")
        self.deleteAllObjectsFromS3Prefix(self.dest_bucket,self.dest_work_prefix_updated_cases)
        try:
            print("Writing data to work dir: {}" .format(self.work_updated_cases_path))
            updatedCaseAttributesDF.coalesce(1) \
                .write \
                .mode("overwrite") \
                .parquet(self.work_updated_cases_path + "updated_cases.parquet")
        except Exception as e:
            e = sys.exc_info()[0]
            print("No data to write to work dir")
        print("Starting the process to publish data to Redshift")
        self.publishToRedshift(con,cursor)
        print("Publishing to Redshift complete")
    print('Updating metadata table')
    matchtime_lower_threshold_new = incomingMatchesFullDF_copy.agg({'match_updated_time': 'max'}).collect()[0][0]
    if matchtime_lower_threshold_new is not None:
        matchtime_lower_threshold_new_formatted = matchtime_lower_threshold_new.strftime("%Y-%m-%d %H:%M:%S")
        print("Latest match time lower threshold with new load: {}" .format(matchtime_lower_threshold_new_formatted))
        self.updatePipelineMetadata('scm_case_alert_data','max_data_update_time',matchtime_lower_threshold_new_formatted)
    else:
        print("No new matches, leaving max_data_update_time for match as it is")
    print("Metadata table up to date")
    print("Committing the updates to Redshift and closing the connection")
    con.commit() #Committing after the metadata table is updated to ensure the datamart data and threshold are aligned
    cursor.close()
    con.close()
Spark History Server Screenshot:

As you correctly suspected, you're having data skew issues. This is really apparent from your last screenshot: have a look at the shuffle read/write sizes! What you have to find out is which shuffle operation (it looks like a join) is causing it.
Salting the large dataframes without knowing where your skew is won't solve the issue.
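A minimal, Spark-free sketch of why the salting in the question has no visible effect: withColumn("salt", rand()) followed by repartition("salt") spreads rows evenly, but a sort-merge join then shuffles both sides again by the join key, so every row of a hot key still ends up in one task. The hash function, NUM_PARTITIONS and the sample rows are stand-ins chosen for illustration (Spark itself uses Murmur3 hashing and spark.sql.shuffle.partitions).

```python
import random
import zlib
from collections import Counter

NUM_PARTITIONS = 8  # hypothetical stand-in for spark.sql.shuffle.partitions


def shuffle_partition(value, num_partitions=NUM_PARTITIONS):
    """Deterministic stand-in for Spark's hash partitioning (Spark uses Murmur3)."""
    return zlib.crc32(repr(value).encode()) % num_partitions


def rows_per_partition(rows, column):
    """How many rows each shuffle partition receives when partitioning by `column`."""
    return Counter(shuffle_partition(row[column]) for row in rows)


def make_skewed_rows(seed=0):
    """Hypothetical table: 9,000 rows share one join key, 1,000 rows are spread out,
    and every row carries a random salt, like withColumn("salt", rand())."""
    rng = random.Random(seed)
    keys = ["HOT"] * 9000 + [f"k{i}" for i in range(1000)]
    return [{"key": key, "salt": rng.random()} for key in keys]


if __name__ == "__main__":
    rows = make_skewed_rows()
    # repartition("salt"): rows are spread evenly over the partitions...
    print("largest partition by salt:", max(rows_per_partition(rows, "salt").values()))
    # ...but the join re-shuffles by key, so all 9,000 HOT rows meet in one task.
    print("largest partition by join key:", max(rows_per_partition(rows, "key").values()))
```

The salt only helps once it becomes part of the join key itself, which is what salting "that specific key" means in the plan below.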
So, my proposed plan of action:
1. From your screenshot, stage 112 is the problematic stage. Figure out which join operation it belongs to: in the SQL tab of the web UI you can find stage 112 and hover over it. That should give you enough information to figure out which shuffle/join key is skewed.
2. Once you know which key is skewed, look at the distribution of its values using spark-shell or something similar, and figure out which value is overly common. This will help in making future decisions. A simple df.groupBy("problematicKey").count will already be really interesting (in PySpark: df.groupBy("problematicKey").count().show()).
3. Once you know that, you can go ahead and salt that specific key.
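A pure-Python sketch of steps 2 and 3, so it runs without a cluster. top_keys mirrors df.groupBy(key).count().orderBy(F.desc("count")); like Spark, it counts NULL (None) as a single group, which matters because rows with a NULL join key in a LEFT JOIN are all hashed to the same shuffle partition. In the question's query, keys such as alrt.caseid, sc.assigned_to and sc.resolved_by could be NULL for many rows (for example, alerts without a case), and the aud subquery returns one row per audit record rather than one per case, so these are worth counting first. salted_left_join then salts only the hot keys: the large side gets a random salt in [0, buckets), the other side is replicated once per salt value (in PySpark this is typically done with F.explode over an array of salt literals), and the join runs on (key, salt). The table contents, column names and buckets value are hypothetical.

```python
import random
from collections import Counter, defaultdict


def top_keys(rows, key, n=3):
    """Like df.groupBy(key).count().orderBy(F.desc("count")).limit(n);
    None (NULL) is counted as one group, as in Spark."""
    return Counter(row.get(key) for row in rows).most_common(n)


def salted_left_join(left, right, key, hot_keys, buckets=4, seed=0):
    """LEFT JOIN left -> right on `key`, salting only the keys in `hot_keys`.
    Returns the joined rows and the row count per shuffle key (key, salt)."""
    rng = random.Random(seed)
    # Other side: replicate each hot-key row once per salt value; other keys get salt 0.
    # NULL keys are skipped because NULL never equals anything in SQL.
    index = defaultdict(list)
    for row in right:
        if row[key] is None:
            continue
        for salt in (range(buckets) if row[key] in hot_keys else [0]):
            index[(row[key], salt)].append(row)
    joined, per_shuffle_key = [], Counter()
    for row in left:
        # Large side: a random salt for hot keys, 0 for everything else.
        salt = rng.randrange(buckets) if row[key] in hot_keys else 0
        per_shuffle_key[(row[key], salt)] += 1
        matches = index.get((row[key], salt), [])
        joined.extend({**row, **m} for m in matches)
        if not matches:
            joined.append(dict(row))  # LEFT JOIN keeps unmatched rows
    return joined, per_shuffle_key


if __name__ == "__main__":
    left = [{"id": i, "caseid": "C1"} for i in range(900)] + \
           [{"id": 900 + i, "caseid": f"C{i + 2}"} for i in range(100)]
    right = [{"caseid": f"C{i}", "status": i} for i in range(1, 102)]
    print(top_keys(left, "caseid", 2))
    _, groups = salted_left_join(left, right, "caseid", {"C1"})
    print({k: v for k, v in groups.items() if k[0] == "C1"})  # ~225 rows per salt
```

Replicating the small side multiplies only the hot keys' rows by buckets, so buckets is a trade-off between evening out the hot task and inflating the replicated side.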
But you're absolutely on the right track! Keeping an eye on that Tasks page and the time it takes for each task is a great approach!
Hope this helps :)

Related

PySpark data skewness with Window Functions

I have a huge PySpark dataframe and I'm doing a series of Window functions over partitions defined by my key.
The issue with the key is that my partitions get skewed by it, which results in an event timeline that looks something like this:
I know that I can use the salting technique to solve this issue when I'm doing a join. But how can I solve it when I'm using window functions?
I'm using functions like lag, lead, etc. in the window functions. I can't do the processing with a salted key, because I'd get wrong results.
How can I solve the skew in this case?
I'm looking for a dynamic way of repartitioning my dataframe without skewness.
Update based on the answer from jxc
I created a sample df and ran the following code on it:
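import numpy as np
import pandas as pd
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = pd.DataFrame()
df['id'] = np.random.randint(1, 1000, size=150000)
# Map every even id to 100, so roughly half of all rows share one key
df['id'] = df['id'].map(lambda x: 100 if x % 2 == 0 else x)
df['timestamp'] = pd.date_range(start=pd.Timestamp('2020-01-01'), periods=len(df), freq='60s')
sdf = spark.createDataFrame(df)
sdf = sdf.withColumn("amt", F.rand()*100)
w = Window.partitionBy("id").orderBy("timestamp")
sdf = sdf.withColumn("new_col", F.lag("amt").over(w) + F.lead("amt").over(w))
x = sdf.toPandas()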
This gave me an event timeline like this:
Then I tried the code from jxc's answer:
sdf = spark.createDataFrame(df)
sdf = sdf.withColumn("amt", F.rand()*100)
N = 24*3600*365*2
sdf_1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))
w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
w2 = Window.partitionBy('id', 'pid')
sdf_2 = sdf_1.select(
'*',
F.count('*').over(w2).alias('cnt'),
F.row_number().over(w1).alias('rn'),
(F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_val')
)
sdf_3 = sdf_2.filter('rn in (1, 2, cnt-1, cnt)') \
.withColumn('new_val', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
.filter('rn in (1,cnt)')
df_new = sdf_2.filter('rn not in (1,cnt)').union(sdf_3)
x = df_new.toPandas()
I ended up with one additional stage, and the event timeline looked even more skewed.
The run time also increased slightly with the new code.
To process a large partition, you can try splitting it based on the orderBy column (most likely a numeric column, or a date/timestamp column that can be converted to numeric) so that all new sub-partitions maintain the correct order of rows. Process rows with the new partitioner; for calculations using the lag and lead functions, only the rows around the boundaries between sub-partitions need to be post-processed. (Task-2 below also discusses how to merge small partitions.)
Using your example sdf, assume we have the following WinSpec and a simple window calculation:
w = Window.partitionBy('id').orderBy('timestamp')
sdf.withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w))
Task-1: split large partitions:
Try the following:
Select an N to split the timestamp range and set up an additional partitionBy column pid (using ceil, int, floor, etc.):
# N covers 35-day intervals
N = 24*3600*35
df1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))
Add pid to partitionBy (see w1), then calculate row_number(), lag() and lead() over w1. Also find the number of rows (cnt) in each new partition to help identify the end of each partition (rn == cnt). The resulting new_amt will be correct for the majority of rows, except those on the boundaries of each partition.
w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
w2 = Window.partitionBy('id', 'pid')
df2 = df1.select(
'*',
F.count('*').over(w2).alias('cnt'),
F.row_number().over(w1).alias('rn'),
(F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_amt')
)
Below is an example df2 showing the boundary rows.
Process the boundary: select the rows on the boundaries, rn in (1, cnt), plus those whose values are used in the calculation, rn in (2, cnt-1); do the same calculation of new_amt over w and save the result for the boundary rows only.
df3 = df2.filter('rn in (1, 2, cnt-1, cnt)') \
.withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
.filter('rn in (1,cnt)')
Below is the resulting df3 from the above df2.
Merge df3 back into df2 to update the boundary rows rn in (1, cnt):
df_new = df2.filter('rn not in (1,cnt)').union(df3)
The screenshot below shows the final df_new around the boundary rows:
# drop columns which are used to implement logic only
df_new = df_new.drop('cnt', 'rn')
Some Notes:
The following 3 WindowSpecs are defined:
w = Window.partitionBy('id').orderBy('timestamp')          # fix boundary rows
w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')  # calculate internal rows
w2 = Window.partitionBy('id', 'pid')                       # find #rows in a partition
Strictly speaking, it is better to use the following w to fix boundary rows, to avoid issues with tied timestamps around the boundaries:
w = Window.partitionBy('id').orderBy('pid', 'rn')          # fix boundary rows
If you know which partitions are skewed, just divide those and skip the others. The existing method might split a small partition into 2 or even more pieces if its rows are sparsely distributed over time.
df1 = sdf.withColumn('pid', F.when(F.col('id').isin('a','b'), F.ceil(F.unix_timestamp('timestamp')/N)).otherwise(1))
If for each partition you can retrieve count (number of rows) and min_ts = min(timestamp), you can compute pid more dynamically (below, M is the threshold number of rows above which a partition is split):
F.expr(f"IF(count>{M}, ceil((unix_timestamp(timestamp)-unix_timestamp(min_ts))/{N}), 1)")
Note: if rows are unevenly distributed over time within a single partition, generating pid will require more complex functions.
If only the lag(1) function is used, just post-process the left boundaries: filter by rn in (1, cnt) and update only rn == 1:
df3 = df2.filter('rn in (1, cnt)') \
    .withColumn('new_amt', F.lag('amt',1).over(w)) \
    .filter('rn = 1')
Similarly, for lead we only need to fix the right boundaries and update rn == cnt.
If only lag(2) is used, filter and update more rows with df3:
df3 = df2.filter('rn in (1, 2, cnt-1, cnt)') \
    .withColumn('new_amt', F.lag('amt',2).over(w)) \
    .filter('rn in (1,2)')
You can extend the same method to mixed cases with both lag and lead having different offset.
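A pure-Python sketch (no Spark needed) that checks the Task-1 split-and-repair idea against a plain per-id window for lag(k) + lead(k). Assumptions: rows are (id, ts, amt) tuples with integer timestamps that are unique within an id (ties are what the orderBy('pid', 'rn') note above guards against), and pid = ceil(ts / n) plays the role of F.ceil(F.unix_timestamp('timestamp')/N). For k = 1 the filters reduce exactly to the answer's rn in (1, 2, cnt-1, cnt) and rn in (1, cnt); the rule for larger k (keep rn <= 2k or rn > cnt-2k, update rn <= k or rn > cnt-k) is our extrapolation, verified only by this simulation. The sketch also hints at why the update in the question saw no benefit: there N covers two years while the sample spans about 104 days (150,000 one-minute steps), so each id falls into at most two pid values and the big id 100 is effectively not split.

```python
import math
import random
from collections import defaultdict


def lag_plus_lead(values, k):
    """lag(k) + lead(k) over an ordered list; None (NULL) if a neighbour is missing."""
    return [values[i - k] + values[i + k] if i - k >= 0 and i + k < len(values) else None
            for i in range(len(values))]


def grouped(rows, key_fn):
    """Sort rows by (id, ts) and group them, keeping that order inside each group."""
    groups = defaultdict(list)
    for r in sorted(rows, key=lambda r: (r[0], r[1])):
        groups[key_fn(r)].append(r)
    return list(groups.values())


def full_window(rows, k):
    """Reference result over w = Window.partitionBy('id').orderBy('timestamp')."""
    out = {}
    for g in grouped(rows, lambda r: r[0]):
        out.update(zip([r[:2] for r in g], lag_plus_lead([r[2] for r in g], k)))
    return out


def split_window(rows, k, n):
    """Compute over w1 = partitionBy('id', 'pid') with pid = ceil(ts / n), then
    recompute only the boundary rows over w and overwrite them."""
    out, kept = {}, []
    for g in grouped(rows, lambda r: (r[0], math.ceil(r[1] / n))):
        cnt = len(g)
        for rn, (r, v) in enumerate(zip(g, lag_plus_lead([r[2] for r in g], k)), 1):
            out[r[:2]] = v                          # df2: value over w1
            if rn <= 2 * k or rn > cnt - 2 * k:     # k=1: rn in (1, 2, cnt-1, cnt)
                kept.append((r, rn, cnt))
    by_id = defaultdict(list)
    for item in kept:                               # kept is already in (id, ts) order
        by_id[item[0][0]].append(item)
    for g in by_id.values():
        fixed = lag_plus_lead([r[2] for r, _, _ in g], k)  # df3: value over w
        for (r, rn, cnt), v in zip(g, fixed):
            if rn <= k or rn > cnt - k:             # k=1: rn in (1, cnt)
                out[r[:2]] = v                      # union of df2 and df3
    return out


def make_rows(seed=0, ids=5, per_id=300):
    """Hypothetical (id, ts, amt) rows with unique integer timestamps per id."""
    rng = random.Random(seed)
    return [(i, ts, rng.randint(0, 100))
            for i in range(1, ids + 1)
            for ts in sorted(rng.sample(range(1, 10_000), per_id))]


if __name__ == "__main__":
    rows = make_rows()
    for k in (1, 2):
        print(k, split_window(rows, k, n=700) == full_window(rows, k))
```

Choosing n so that each sub-partition still holds many rows keeps the boundary subset (df3) small relative to the data, which is where the speed-up comes from.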
Task-2: merge small partitions:
Based on the number of records in each partition (count), we can set up a threshold M so that if count > M, the id keeps its own partition; otherwise we merge partitions so that the total number of records is less than M (the method below has an edge case of up to 2*M-2 records per merged partition).
M = 20000
# create pandas df with columns `id`, `count` and `f`, sort rows so that rows with count>=M are located on top
d2 = pd.DataFrame([ e.asDict() for e in sdf.groupby('id').count().collect() ]) \
.assign(f=lambda x: x['count'].lt(M)) \
.sort_values('f')
# add pid column to merge smaller partitions but the total row-count in partition should be less than or around M
# potentially there could be at most `2*M-2` records for the same pid; to make sure count<M strictly, use a for-loop to iterate d2 and set pid:
d2['pid'] = (d2.mask(d2['count'].gt(M),M)['count'].shift(fill_value=0).cumsum()/M).astype(int)
# add pid to sdf. In case join is too heavy, try using Map
sdf_1 = sdf.join(spark.createDataFrame(d2).alias('d2'), ["id"]) \
.select(sdf["*"], F.col("d2.pid"))
# check pid: # of records and # of distinct ids
sdf_1.groupby('pid').agg(F.count('*').alias('count'), F.countDistinct('id').alias('cnt_ids')).orderBy('pid').show()
+---+-----+-------+
|pid|count|cnt_ids|
+---+-----+-------+
| 0|74837| 1|
| 1|20036| 133|
| 2|20052| 134|
| 3|20010| 133|
| 4|15065| 100|
+---+-----+-------+
Now, the new Window should be partitioned by pid alone and move id to orderBy, see below:
w3 = Window.partitionBy('pid').orderBy('id','timestamp')
customize lag/lead functions based on the above w3 WinSpec, and then calculate new_val:
lag_w3 = lambda col,n=1: F.when(F.lag('id',n).over(w3) == F.col('id'), F.lag(col,n).over(w3))
lead_w3 = lambda col,n=1: F.when(F.lead('id',n).over(w3) == F.col('id'), F.lead(col,n).over(w3))
sdf_new = sdf_1.withColumn('new_val', lag_w3('amt',1) + lead_w3('amt',1))
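A pure-Python sketch of Task-2 that runs without Spark. assign_pids is the strict for-loop variant mentioned in the code comment above (every merged pid stays below M), not the cumsum one-liner; lag_plus_lead_w3 mimics lag_w3/lead_w3 over w3 (a neighbour only counts if it has the same id) and is checked against a plain per-id window. The sample rows imitate the question's data (about half the rows share id 100), and M = 2000 is an arbitrary choice.

```python
import random
from collections import Counter, defaultdict


def make_rows(n=15_000, seed=0):
    """Hypothetical (id, ts, amt) rows shaped like the question's sample."""
    rng = random.Random(seed)
    rows = []
    for ts in range(n):
        id_ = rng.randint(1, 999)
        rows.append((100 if id_ % 2 == 0 else id_, ts, rng.randint(0, 100)))
    return rows


def assign_pids(counts, m):
    """Ids with >= m rows get their own pid; smaller ids are packed in order so
    every shared pid stays below m rows."""
    pids, pid, total = {}, -1, m          # total = m forces a new pid for the first id
    for id_, c in sorted(counts.items(), key=lambda kv: kv[1] < m):  # large ids first
        if c >= m or total + c >= m:
            pid, total = pid + 1, 0
        pids[id_] = pid
        total += c
    return pids


def lag_plus_lead_w3(rows, pids):
    """lag(1) + lead(1) over w3 = partitionBy('pid').orderBy('id', 'timestamp'),
    where a neighbour only counts if it has the same id (lag_w3 / lead_w3)."""
    parts = defaultdict(list)
    for r in sorted(rows, key=lambda r: (pids[r[0]], r[0], r[1])):
        parts[pids[r[0]]].append(r)
    out = {}
    for g in parts.values():
        for i, r in enumerate(g):
            prev = g[i - 1][2] if i > 0 and g[i - 1][0] == r[0] else None
            nxt = g[i + 1][2] if i + 1 < len(g) and g[i + 1][0] == r[0] else None
            out[r[:2]] = None if prev is None or nxt is None else prev + nxt
    return out


def lag_plus_lead_by_id(rows):
    """Reference: Window.partitionBy('id').orderBy('timestamp')."""
    groups = defaultdict(list)
    for r in sorted(rows, key=lambda r: (r[0], r[1])):
        groups[r[0]].append(r)
    out = {}
    for g in groups.values():
        for i, r in enumerate(g):
            inner = 0 < i < len(g) - 1
            out[r[:2]] = g[i - 1][2] + g[i + 1][2] if inner else None
    return out


if __name__ == "__main__":
    rows = make_rows()
    pids = assign_pids(Counter(r[0] for r in rows), 2000)
    print("number of pids:", len(set(pids.values())))
    print("matches per-id window:", lag_plus_lead_w3(rows, pids) == lag_plus_lead_by_id(rows))
```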
To handle such skewed data, there are a couple of things you can try.
If you are running your jobs on Databricks and you know which column is skewed, you can try an option called the skew hint.
I recommend moving to Spark 3.0, since you will then have the option to use Adaptive Query Execution (AQE), which can handle many of these issues, improving your job health and potentially making your jobs run faster. Note that AQE's skew handling targets joins (it splits skewed partitions of sort-merge joins); it does not split a skewed window partition.
Usually, I suggest making your partitions more even-sized before any wide operation. Increasing the cluster size can also help, but I am not sure it will work for you.
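A sketch of collecting the AQE skew-join settings in one place and applying them through SparkSession.builder.config(key, value). The key names are real Spark 3.x settings; the values are illustrative assumptions rather than recommendations, and they only influence joins, not the window functions discussed above. As the Spark documentation describes it, a partition counts as skewed only if it is larger than both skewedPartitionFactor times the median partition size and skewedPartitionThresholdInBytes, so lowering these makes detection more aggressive.

```python
AQE_SKEW_CONF = {
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.skewJoin.enabled": "true",
    # Skewed = larger than factor * median partition size AND larger than the threshold.
    "spark.sql.adaptive.skewJoin.skewedPartitionFactor": "3",            # illustrative
    "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes": "128MB",  # illustrative
    # Target size AQE aims for when splitting skewed or coalescing small partitions.
    "spark.sql.adaptive.advisoryPartitionSizeInBytes": "64MB",           # illustrative
}


def apply_conf(builder, conf=AQE_SKEW_CONF):
    """Call builder.config(key, value) for every entry and return the builder,
    e.g. apply_conf(SparkSession.builder.appName("job")).getOrCreate()."""
    for key, value in conf.items():
        builder = builder.config(key, value)
    return builder
```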

Error while getting user input and using Pandas DataFrame to extract data from LEFT JOIN

I am trying to create an SQLite3 statement in Python 3 to collect data from two tables called FreightCargo & Train, where a train ID is the input value. I want to use Pandas since it's easy to read the tables.
I have created the code below, which works perfectly fine, but it's static and only looks for one hard-coded value in the statement.
import pandas as pd
SQL = '''SELECT F.Cargo_ID, F.Name, F.Weight, T.Train_ID, T.Assembly_date
FROM FreightCargo F LEFT JOIN [Train] T
ON F.Cargo_ID = T.Cargo_ID
WHERE Train_ID = 2;'''
cursor = conn.cursor()
cursor.execute( SQL )
names = [x[0] for x in cursor.description]
rows = cursor.fetchall()
Temp = pd.DataFrame( rows, columns=names)
Temp
I want to be able to create a variable from user input, so that the result is determined by what the user enters. For example, the user is asked for a train_id, which is a primary key in a table, and the cargo related to that train is listed.
I expanded the code, but I am getting an error: ValueError: operation parameter must be str
Train_ID = input('Train ID')
SQL = '''SELECT F.Cargo_ID, F.Name, F.Weight, T.Train_ID, T.Assembly_date
FROM FreightCargo F LEFT JOIN [Train] T
ON F.Cargo_ID = T.Cargo_ID
WHERE Train_ID = ?;''', (Train_ID)
cursor = conn.cursor()
cursor.execute( SQL )
names = [x[0] for x in cursor.description]
rows = cursor.fetchall()
Temp = pd.DataFrame( rows, columns=names)
Temp
The problem lies in your definition of the SQL variable.
You are creating a tuple of two elements. If you print(SQL), you will see something like this: ('''SELECT...?;''', 'your_user's_input'). Note that (Train_ID) is not a tuple, just a string in parentheses; a one-element tuple needs a trailing comma: (Train_ID,).
When you pass this to cursor.execute(sql[, parameters]), it expects a string as the first argument, optionally followed by parameters. Your parameters are not really optional here, since your query contains a ? placeholder that must be bound. Parameters must be a sequence, for example a tuple.
You can unpack your SQL tuple with cursor.execute(*SQL), which passes each element as a separate argument (this only works if the second element is a real tuple, (Train_ID,); a bare string is treated as a sequence of single-character parameters), or you can move the parameters to the execute call:
Train_ID = input('Train ID')
SQL = '''SELECT F.Cargo_ID, F.Name, F.Weight, T.Train_ID, T.Assembly_date
FROM FreightCargo F LEFT JOIN [Train] T
ON F.Cargo_ID = T.Cargo_ID
WHERE Train_ID = ?;'''
cursor = conn.cursor()
cursor.execute( SQL, (Train_ID,) )
names = [x[0] for x in cursor.description]
rows = cursor.fetchall()
Temp = pd.DataFrame( rows, columns=names)
Temp
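A self-contained version of the fix using an in-memory database, so it can be run as-is. The table layouts and rows in demo_connection are invented for illustration; only the query comes from the question. conn.execute(...) is the sqlite3 shortcut that creates a cursor and executes in one call.

```python
import sqlite3


def cargo_for_train(conn, train_id):
    """Return (column_names, rows) for one train, binding train_id as a parameter.
    Note the one-element tuple: (train_id,), not (train_id)."""
    sql = """SELECT F.Cargo_ID, F.Name, F.Weight, T.Train_ID, T.Assembly_date
             FROM FreightCargo F LEFT JOIN [Train] T
             ON F.Cargo_ID = T.Cargo_ID
             WHERE Train_ID = ?;"""
    cursor = conn.execute(sql, (train_id,))
    names = [d[0] for d in cursor.description]
    return names, cursor.fetchall()


def demo_connection():
    """In-memory database with made-up rows (hypothetical schema)."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE FreightCargo (Cargo_ID INTEGER PRIMARY KEY, Name TEXT, Weight REAL);
        CREATE TABLE Train (Train_ID INTEGER, Cargo_ID INTEGER, Assembly_date TEXT);
        INSERT INTO FreightCargo VALUES (1, 'Coal', 50.0), (2, 'Grain', 30.5), (3, 'Steel', 80.0);
        INSERT INTO Train VALUES (2, 1, '2021-05-01'), (2, 3, '2021-05-01'), (7, 2, '2021-06-12');
    """)
    return conn


if __name__ == "__main__":
    names, rows = cargo_for_train(demo_connection(), 2)
    print(names)
    print(rows)
```

Since input() returns a string, convert it first, e.g. cargo_for_train(conn, int(input('Train ID'))), and build the frame with pd.DataFrame(rows, columns=names). pandas can also bind the parameter itself via pd.read_sql_query(sql, conn, params=(train_id,)).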

Auto increment id in delta table while inserting

I have a problem merging CSV files into a Delta table using PySpark SQL. I managed to create an upsert function that updates if matched and inserts if not matched.
I want to add an ID column to the final Delta table and increment it each time we insert data. This column identifies each row in our Delta table. Is there any way to put that in place?
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lit

def Merge(dict1, dict2):
    res = {**dict1, **dict2}
    return res

def create_default_values_dict(correspondance_df, marketplace):
    dict_output = {}
    for field in get_nan_keys_values(get_mapping_dict(correspondance_df, marketplace)):
        dict_output[field] = 'null'
        # We want to increment the id row each time we perform an insertion (TODO TODO TODO)
        # if field == 'id':
        #     dict_output['id'] = col('id')+1
        # else:
    return dict_output

def create_matched_update_dict(mapping, products_table, updates_table):
    output = {}
    for k, v in mapping.items():
        if k == 'source_name':
            output['products.source_name'] = lit(v)
        else:
            # Keep the existing value when the update is null, otherwise take the update
            output[products_table + '.' + k] = (
                F.when(col(updates_table + '.' + v).isNull(), col(products_table + '.' + k))
                 .when(col(updates_table + '.' + v).isNotNull(), col(updates_table + '.' + v))
            )
    return output

insert_dict = create_not_matched_insert_dict(mapping, 'products', 'updates')
default_dict = create_default_values_dict(correspondance_df_products, 'Cdiscount')
insert_values = Merge(insert_dict, default_dict)
update_values = create_matched_update_dict(mapping, 'products', 'updates')
delta_table_products.alias('products').merge(
    updates_df_table.limit(20).alias('updates'),
    "products.barcode_ean == updates.ean") \
    .whenMatchedUpdate(set=update_values) \
    .whenNotMatchedInsert(values=insert_values) \
    .execute()
I tried to increment the id column in the function create_default_values_dict, but it doesn't seem to work: the id is not incremented by 1. Is there another way to solve this problem? Thanks in advance :)
Databricks has IDENTITY columns for hosted Spark
https://docs.databricks.com/sql/language-manual/sql-ref-syntax-ddl-create-table-using.html#parameters
GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY
[ ( [ START WITH start ] [ INCREMENT BY step ] ) ]
This works on Delta tables.
Example:
create table gen1 (
id long GENERATED ALWAYS AS IDENTITY
, t string
)
This requires Databricks Runtime 10.4 or above.
Delta does not support auto-increment column types.
In general, Spark doesn't use auto-increment IDs, instead favoring monotonically increasing IDs. See functions.monotonically_increasing_id().
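The PySpark documentation for `monotonically_increasing_id()` states that the current implementation puts the partition ID in the upper 31 bits and the record number within each partition in the lower 33 bits. That is why the IDs are unique and increasing but not consecutive. A pure-Python sketch of that arithmetic (it does not call Spark; the partition sizes are made up):

```python
def spark_style_id(partition_id: int, row_in_partition: int) -> int:
    """Mimic the documented layout: partition ID in the upper 31 bits,
    per-partition record number in the lower 33 bits."""
    return (partition_id << 33) + row_in_partition

# Illustrative only: three rows in partition 0, two rows in partition 1
ids = [spark_style_id(p, r) for p, n in [(0, 3), (1, 2)] for r in range(n)]
print(ids)  # [0, 1, 2, 8589934592, 8589934593]
```

The jump from 2 to 8589934592 (2^33) is expected: every partition starts its own numbering block.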
If you want auto-increment behavior, you will have to use multiple Delta operations, e.g., query the max value, add it to a row_number() column computed via a window function, then write. This is problematic for two reasons:
Unless you introduce an external locking mechanism or some other way to ensure that no updates to the table happen in-between finding the max value and writing, you can end up with invalid data.
A row_number() window without a partitionBy moves all rows into a single partition, reducing parallelism to 1 and forcing all the data through a single core, which will be very slow with large data.
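For illustration, the "max value + row_number()" idea looks roughly like the sketch below. It runs the SQL against an in-memory SQLite database (window functions need SQLite 3.25 or newer) purely so it can be executed; the `products`/`updates` tables and the `ean` key are hypothetical stand-ins. The same read-max, number, write sequence in Spark keeps both caveats above: nothing locks the table between the read and the write, and the window has no partition key.

```python
import sqlite3

# Hypothetical tables: `products` is the target, `updates` holds incoming rows keyed by ean
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE products (id INTEGER, ean TEXT);
INSERT INTO products VALUES (1, 'A'), (2, 'B');
CREATE TABLE updates (ean TEXT);
INSERT INTO updates VALUES ('B'), ('C'), ('D');
""")

# Step 1: read the current maximum id (0 for an empty table)
max_id = conn.execute("SELECT COALESCE(MAX(id), 0) FROM products").fetchone()[0]

# Step 2: number the rows that are not in the target yet and offset them by max_id.
# The window has no PARTITION BY, which in Spark means a single partition.
new_rows = conn.execute("""
    SELECT ? + ROW_NUMBER() OVER (ORDER BY u.ean) AS id, u.ean
    FROM updates u
    WHERE u.ean NOT IN (SELECT ean FROM products)
""", (max_id,)).fetchall()

# Step 3: write. A concurrent writer committing between step 1 and here would cause duplicate ids.
conn.executemany("INSERT INTO products (id, ean) VALUES (?, ?)", new_rows)
print(conn.execute("SELECT id, ean FROM products ORDER BY id").fetchall())
# [(1, 'A'), (2, 'B'), (3, 'C'), (4, 'D')]
```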
Bottom line, you really do not want to use auto-increment columns with Spark.
Hope this helps.

How to optimize searching RDD contents in another RDD

I have a problem where I need to search the contents of an RDD in another RDD.
This question is different from Efficient string matching in Apache Spark, as I am searching for an exact match and I don't need the overhead of using the ML stack.
I am new to Spark and I want to know which of these methods is more efficient, or whether there is another way.
I have a keyword file like the below sample (in production it might reach up to 200 lines)
Sample keywords file
0.47uF 25V X7R 10% -TDK C2012X7R1E474K125AA
20pF-50V NPO/COG - AVX- 08055A200JAT2A
and another (tab-separated) file in which I need to find matches (in production it has up to 80 million lines):
C2012X7R1E474K125AA Conn M12 Circular PIN 5 POS Screw ST Cable Mount 5 Terminal 1 Port
First method
I defined a UDF and looped through keywords for each line
import re
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
keywords = sc.textFile("keys")
part_description = sc.textFile("part_description")
def build_regex(keywords):
    # Build one alternation "(k1|k2|...)" from the tokens of a keyword line
    res = '('
    for key in keywords:
        res += '(?<!\\\s)%s(?!\\\s)|' % re.escape(key)
    res = res[0:len(res) - 1] + ')'
    return r'%s' % res
def get_matching_string(line, regex):
    matches = re.findall(regex, line, re.IGNORECASE)
    return list(set(matches)) if matches else None
def find_matching_regex(line):
    result = list()
    # keys_bc.value is the list of regexes shipped once to each executor
    for keyGroup in keys_bc.value:
        matches = get_matching_string(line, keyGroup)
        if matches:
            result.append(str(keyGroup) + '~~' + str(matches) + '~~' + str(len(matches)))
    if len(result) > 0:
        return result
def split_row(parts):
    try:
        return Row(parts[0], parts[1])
    except IndexError:
        # Lines without a tab separator are dropped by the filter below
        return None
keys_rdd = keywords.map(lambda keywords: build_regex(keywords.replace(',', ' ').replace('-', ' ').split(' ')))
keys = keys_rdd.collect()
# Keep the Broadcast handle so the UDF reads keys_bc.value instead of capturing the list in its closure
keys_bc = sc.broadcast(keys)
part_description = part_description.map(lambda item: item.split('\t'))
df = part_description.map(lambda parts: split_row(parts)).filter(lambda x: x).toDF(
    ["part_number", "description"])
find_regex = udf(lambda line: find_matching_regex(line), ArrayType(StringType()))
df = df.withColumn('matched', find_regex(df['part_number']))
df = df.filter(df.matched.isNotNull())
df.write.save(path=job_id, format='csv', mode='append', sep='\t')
Second method
I thought I could get more parallelism (instead of looping through the keys as above), so I took the Cartesian product of keys and lines, split and exploded the keys, then compared each key to the part_number column.
from pyspark.sql import functions as F
from pyspark.sql.functions import explode, split
df = part_description.cartesian(keywords)
df = df.map(lambda pair: (pair[0].split('\t'), pair[1])).map(
    lambda pair: (pair[0][0], pair[0][1], pair[1]))
df = df.toDF(['part_number', 'description', 'keywords'])
df = df.withColumn('single_keyword', explode(split(F.col("keywords"), r"\s+"))).where('keywords != ""')
df = df.withColumn('matched_part_number', (df['part_number'] == df['single_keyword']))
df = df.filter(df['matched_part_number'] == F.lit(True))
df.write.save(path='part_number_search', format='csv', mode='append', sep='\t')
Are these the correct ways to do this? Is there anything I can do to process these data faster?
These are both valid solutions, and I have used both in different circumstances.
You communicate less data by using your broadcast approach, sending only 200 extra lines to each executor as opposed to replicating each line of your >80m line file 200 times, so it is likely this one will end up being faster for you.
I have used the Cartesian approach when my lookup table was too large to broadcast (much, much larger than 200 lines).
In your situation, I would use broadcast.
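For the exact-match case, the per-task work of the broadcast approach can be sketched in plain Python: the ~200 keyword lines become one small set of tokens (split the way the first method does, with commas and hyphens turned into spaces, and upper-cased to mimic its case-insensitive matching), and each line of the large file needs only a hash lookup on its part number. The helper names `build_token_set` and `match_parts` are hypothetical, and the second description line is made up.

```python
def build_token_set(keyword_lines):
    """Split keyword lines into upper-cased tokens (commas and hyphens act as separators)."""
    tokens = set()
    for line in keyword_lines:
        for tok in line.replace(',', ' ').replace('-', ' ').split():
            tokens.add(tok.upper())
    return tokens

def match_parts(description_lines, tokens):
    """Yield (part_number, description) for tab-separated lines whose part number is a token."""
    for line in description_lines:
        parts = line.split('\t', 1)
        if len(parts) == 2 and parts[0].upper() in tokens:
            yield parts[0], parts[1]

keywords = ["0.47uF 25V X7R 10% -TDK C2012X7R1E474K125AA",
            "20pF-50V NPO/COG - AVX- 08055A200JAT2A"]
lines = ["C2012X7R1E474K125AA\tConn M12 Circular PIN 5 POS Screw ST Cable Mount 5 Terminal 1 Port",
         "XYZ123\tHypothetical line with no matching keyword"]
print(list(match_parts(lines, build_token_set(keywords))))
```

In Spark, the set would be shipped with `bc = sc.broadcast(build_token_set(...))` and applied per partition, e.g. `part_description.mapPartitions(lambda it: match_parts(it, bc.value))`, so each executor receives the small set once instead of the large file being replicated 200 times.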

How to implement Slowly Changing Dimensions (SCD2) Type 2 in Spark

We want to implement SCD2 in Spark using SQL joins. I found a reference on GitHub
https://gist.github.com/rampage644/cc4659edd11d9a288c1b
but it's not very clear.
Can anybody provide an example or reference for implementing SCD2 in Spark?
Regards,
Manish
A little outdated in terms of newer Spark SQL, but here is an example I trialed à la Ralph Kimball using Spark SQL; it worked and is thus reliable. You can run it and it works, but file logic and such needs to be added. This is the body of the ETL SCD2 logic, based on 1.6 syntax but run in 2.x. It is not that hard, but you will need to generate test data and trace through each step:
Some pre-processing is required before the script starts: save a copy of the existing dimension and copy it to DIM_CUSTOMER_EXISTING.
Write new output to DIM_CUSTOMER_NEW and then copy this to target, DIM_CUSTOMER_1 or DIM_CUSTOMER_2.
The feed can also be re-created or LOAD OVERWRITE.
^^^ NEED SOME BETTER SCRIPTING AROUND THIS. ^^^ The Type 2 dimension contains only Type 2 values, not a mix of Type 1 & Type 2.
Accumulative DUMPs can in fact be pre-processed to get the delta.
The use case assumes we can have N inputs for a person, each with a validity / extract date supplied.
SPARK 1.6 SQL based originally, not updated yet to SPARK 2.x SQL with nested correlated subquery support.
CUST_CODE must not change, i.e. it must be a stable primary key.
This approach handles no input, delta input, same input, all input, and can catch up and need not be run-date based.
^^^ Works best with deltas: if all data is passed and nothing has changed, a dummy entry with all the same values must still be made, otherwise there will be gaps in the key range, which means facts cannot be linked to dimensions in all cases. I.e. the discard logic works only for a pure delta feed. All data can be passed, but only the current delta. The problem becomes difficult to solve because we must then look for changes across different rows and expand the date range, which is a little too complicated imho.
The dummy entries in the dimensions are not a huge issue. The problem is a little more difficult in such a less mutable environment; in KUDU it is easier to solve.
Ideally there would be some sort of preprocessor that checks which fields have changed and only passes those on, but that may be a bridge too far.
HENCE IT IS NECESSARILY A COMPROMISE ALGORITHM. ^^^
No Deletions processed.
Multi-step processing for SQL required in some cases. Gaps in key ranges difficult to avoid with set processing.
No out of order processing, that would mean re-processing all. Works on a whole date/day basis, if run more than once per day in batch then would need timestamp instead.
0.1 Any difference analysis on existing dumps is only possible if the dumps are accumulative. If they are transactional deltas only, then this is not required.
Care to be taken here.
0.2 If we want only the last update for a given date, then do that here by method of Partitioning and Ranking and filtering out.
These are all pre-processing steps, as is deciding which table the dimension data is taken from.
0.3 Issue is that of small files, but that is not an issue here at xxx. RAW usage only as written to KUDU in a final step.
Actual coding:
import org.apache.spark.sql.SparkSession
val spark = SparkSession
.builder
.master("local") // Not a good idea
.appName("Type 2 dimension update")
.config("spark.sql.crossJoin.enabled", "true") // Needed to add this
.getOrCreate()
spark.sql("drop table if exists DIM_CUSTOMER_EXISTING")
spark.sql("drop table if exists DIM_CUSTOMER_NEW")
spark.sql("drop table if exists FEED_CUSTOMER")
spark.sql("drop table if exists DIM_CUSTOMER_TEMP")
spark.sql("drop table if exists DIM_CUSTOMER_WORK")
spark.sql("drop table if exists DIM_CUSTOMER_WORK_2")
spark.sql("drop table if exists DIM_CUSTOMER_WORK_3")
spark.sql("drop table if exists DIM_CUSTOMER_WORK_4")
spark.sql("create table DIM_CUSTOMER_EXISTING (DWH_KEY int, CUST_CODE String, CUST_NAME String, ADDRESS_CITY String, SALARY int, VALID_FROM_DT String, VALID_TO_DT String) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' STORED AS TEXTFILE LOCATION '/FileStore/tables/alhwkf661500326287094' ")
spark.sql("create table DIM_CUSTOMER_NEW (DWH_KEY int, CUST_CODE String, CUST_NAME String, ADDRESS_CITY String, SALARY int, VALID_FROM_DT String, VALID_TO_DT String) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' STORED AS TEXTFILE LOCATION '/FileStore/tables/DIM_CUSTOMER_NEW_3' ")
spark.sql("CREATE TABLE FEED_CUSTOMER (CUST_CODE String, CUST_NAME String, ADDRESS_CITY String, SALARY int, VALID_DT String) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' STORED AS TEXTFILE LOCATION '/FileStore/tables/mhiscfsv1500226290781' ")
// 1. Get maximum value in dimension. This differs from the other RDD approach; issues in parallel? There may be another way to do this! Check: we get a DF here, and this is the interchangeability
val max_val = spark.sql("select max(dwh_key) from DIM_CUSTOMER_EXISTING")
//max_val.show()
val null_count = max_val.filter("max(DWH_KEY) is null").count()
var max_Dim_Key = 0;
if ( null_count == 1 ) {
max_Dim_Key = 0
} else {
max_Dim_Key = max_val.head().getInt(0)
}
//2. Cannot do simple difference processing. The values of certain fields could be flip-flopping over time. A too simple MINUS will not work well. Need to process relative to
// youngest existing record etc. and roll the transactions forward. Hence we will not do any sort of difference analysis between new dimension data and existing dimension
// data in any way.
// DO NOTHING.
//3. Capture new stuff to be inserted.
// Some records for a given business key can be linea recta inserted as there have been no mutations to consider at all as there is nothing in current Staging. Does not mean
// delete.
// Also, the older mutations need not be re-processed, only the youngest! The younger one may need closing off or not, need to decide if it is now
// copied across or subject to updating in this cycle, depends on the requirements.
// Older mutations copied across immediately.
// DELTA not always strictly speaking needed, but common definitions. Some ranking required.
spark.sql("""insert into DIM_CUSTOMER_NEW select *
from DIM_CUSTOMER_EXISTING
where CUST_CODE not in (select distinct CUST_CODE FROM FEED_CUSTOMER) """) // This does not need RANKing, DWH Key retained.
spark.sql("""create table DIM_CUSTOMER_TEMP as select *, dense_rank() over (partition by CUST_CODE order by VALID_FROM_DT desc) as RANK
from DIM_CUSTOMER_EXISTING """)
spark.sql("""insert into DIM_CUSTOMER_NEW select DWH_KEY, CUST_CODE, CUST_NAME, ADDRESS_CITY, SALARY, VALID_FROM_DT, VALID_TO_DT
from DIM_CUSTOMER_TEMP
where CUST_CODE in (select distinct CUST_CODE from FEED_CUSTOMER)
and RANK <> 1 """)
// For updating of youngest record in terms of SLCD, we use AND RANK <> 1 to filter these out here as we want to close off the period in this record, but other younger
// records can be passed through immediately with their retained DWH Key.
//4. Combine Staging and those existing facts required. The result of this eventually will be stored in DIM_CUSTOMER_NEW which can be used for updating a final target.
// Issue here is that DWH Key not yet set and different columns. DWH key can be set last.
//4.1 Get records to process; these will have the status NEW.
spark.sql("""create table DIM_CUSTOMER_WORK (DWH_KEY int, CUST_CODE String, CUST_NAME String, ADDRESS_CITY String, SALARY int, VALID_FROM_DT String, VALID_TO_DT String, RECSTAT String) """)
spark.sql("""insert into DIM_CUSTOMER_WORK select 0, CUST_CODE, CUST_NAME, ADDRESS_CITY, SALARY, VALID_DT, '2099-12-31', "NEW"
from FEED_CUSTOMER """)
//4.2 Get youngest already existing dimension record to process in conjunction with newer values.
spark.sql("""insert into DIM_CUSTOMER_WORK select DWH_KEY, CUST_CODE, CUST_NAME, ADDRESS_CITY, SALARY, VALID_FROM_DT, VALID_TO_DT, "OLD"
from DIM_CUSTOMER_TEMP
where CUST_CODE in (select distinct CUST_CODE from FEED_CUSTOMER)
and RANK = 1 """)
// 5. ISSUE with first record in a set. It is not a delta or is used for making a delta, need to know what to do or bypass, depends on case.
// Here we are doing deltas, so first rec is a complete delta
// RECSTAT to be filtered out at end
// NEW, 1 = INSERT --> checked, is correct way, can do in others. No delta computation required
// OLD, 1 = DO NOTHING
// else do delta and INSERT
//5.1 RANK and JOIN to get before and after images in CDC format so that we can decide what needs to be closed off.
// Get the new DWH key values + offset, there may exist gaps eventually.
spark.sql(""" create table DIM_CUSTOMER_WORK_2 as select *, rank() over (partition by CUST_CODE order by VALID_FROM_DT asc) as rank FROM DIM_CUSTOMER_WORK """)
//DWH_KEY, CUST_CODE, CUST_NAME, ADDRESS_CITY, SALARY, VALID_FROM_DT, VALID_TO_DT, "OLD"
spark.sql(""" create table DIM_CUSTOMER_WORK_3 as
select T1.DWH_KEY as T1_DWH_KEY, T1.CUST_CODE as T1_CUST_CODE, T1.rank as CURR_RANK, T2.rank as NEXT_RANK,
T1.VALID_FROM_DT as CURR_VALID_FROM_DT, T2.VALID_FROM_DT as NEXT_VALID_FROM_DT,
T1.VALID_TO_DT as CURR_VALID_TO_DT, T2.VALID_TO_DT as NEXT_VALID_TO_DT,
T1.CUST_NAME as CURR_CUST_NAME, T2.CUST_NAME as NEXT_CUST_NAME,
T1.SALARY as CURR_SALARY, T2.SALARY as NEXT_SALARY,
T1.ADDRESS_CITY as CURR_ADDRESS_CITY, T2.ADDRESS_CITY as NEXT_ADDRESS_CITY,
T1.RECSTAT as CURR_RECSTAT, T2.RECSTAT as NEXT_RECSTAT
from DIM_CUSTOMER_WORK_2 T1 LEFT OUTER JOIN DIM_CUSTOMER_WORK_2 T2
on T1.CUST_CODE = T2.CUST_CODE AND T2.rank = T1.rank + 1 """)
//5.2 Get the data for computing new Dimension Surrogate DWH Keys; must execute a new query, or could use DFs or RDDs, but Spark SQL was chosen as easier to follow
spark.sql(s""" create table DIM_CUSTOMER_WORK_4 as
select *, row_number() OVER( ORDER BY T1_CUST_CODE) as ROW_NUMBER, $max_Dim_Key as DIM_OFFSET
from DIM_CUSTOMER_WORK_3 """)
//spark.sql("""SELECT * FROM DIM_CUSTOMER_WORK_4 """).show()
//Execute the above to see results, could not format here.
//5.3 Process accordingly and check if no change at all, if no change can get holes in the sequence numbers, that is not an issue. NB: NOT DOING THIS DUE TO COMPLICATIONS !!!
// See sample data above for decision-making on what to do. NOTE THE FACT THAT WE WOULD NEED A PRE-PROCESSOR TO CHECK IF FIELD OF INTEREST ACTUALLY CHANGED
// to get the best result.
// We could elaborate and record via an extra step if there were only two records per business key and if all the current and only next record fields were all the same,
// we could disregard the first and the second record. Will attempt that later as an extra optimization. As soon as there are more than two here, then this scheme packs up
// Some effort still needed.
//5.3.1 Records that just need to be closed off. The previous version gets an appropriate DATE - 1. Dates must not overlap.
// No check on whether data changed or not due to issues above.
spark.sql("""insert into DIM_CUSTOMER_NEW select T1_DWH_KEY, T1_CUST_CODE, CURR_CUST_NAME, CURR_ADDRESS_CITY, CURR_SALARY,
CURR_VALID_FROM_DT, cast(date_sub(cast(NEXT_VALID_FROM_DT as DATE), 1) as STRING)
from DIM_CUSTOMER_WORK_4
where CURR_RECSTAT = 'OLD' """)
//5.3.2 Records that are the last in the sequence must have high end 2099-12-31 set, which has already been done.
// No check on whether data changed or not due to issues above.
spark.sql("""insert into DIM_CUSTOMER_NEW select ROW_NUMBER + DIM_OFFSET, T1_CUST_CODE, CURR_CUST_NAME, CURR_ADDRESS_CITY, CURR_SALARY,
CURR_VALID_FROM_DT, CURR_VALID_TO_DT
from DIM_CUSTOMER_WORK_4
where NEXT_RANK is null """)
//5.3.3
spark.sql("""insert into DIM_CUSTOMER_NEW select ROW_NUMBER + DIM_OFFSET, T1_CUST_CODE, CURR_CUST_NAME, CURR_ADDRESS_CITY, CURR_SALARY,
CURR_VALID_FROM_DT, cast(date_sub(cast(NEXT_VALID_FROM_DT as DATE), 1) as STRING)
from DIM_CUSTOMER_WORK_4
where CURR_RECSTAT = 'NEW'
and NEXT_RANK is not null""")
spark.sql("""SELECT * FROM DIM_CUSTOMER_NEW """).show()
// So, the question is if we could have done without JOINing and just sorted due to gap processing. This was derived off the delta processing but it turned out a little
// different.
// Well we did need the JOIN for next date at least, so if we add some optimization it still holds.
// My logic applied here per different steps, may well be less steps, left as is.
//6. The copy / insert to get a new big target table version and re-compile views. Outside of this actual processing. Logic performed elsewhere.
// NOTE now that 2.x supports nested correlated sub-queries, we would need to re-visit this at a later point, but can leave as is.
// KUDU means no more restating.
Sample data so you know what to generate for the examples:
+-------+---------+----------------+------------+------+-------------+-----------+
|DWH_KEY|CUST_CODE| CUST_NAME|ADDRESS_CITY|SALARY|VALID_FROM_DT|VALID_TO_DT|
+-------+---------+----------------+------------+------+-------------+-----------+
| 230| E222222| Pete Saunders| Leeds| 75000| 2013-03-09| 2099-12-31|
| 400| A048901| John Alexander| Calgary| 22000| 2015-03-24| 2017-10-22|
| 402| A048901| John Alexander| Wellington| 47000| 2017-10-23| 2099-12-31|
| 403| B787555| Mark de Wit|Johannesburg| 49500| 2017-10-02| 2099-12-31|
| 406| C999666| Daya Dumar| Mumbai| 50000| 2016-12-16| 2099-12-31|
| 404| C999666| Daya Dumar| Mumbai| 49000| 2016-11-11| 2016-12-14|
| 405| C999666| Daya Dumar| Mumbai| 50000| 2016-12-15| 2016-12-15|
| 300| A048901| John Alexander| Calgary| 15000| 2014-03-24| 2015-03-23|
+-------+---------+----------------+------------+------+-------------+-----------+
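The heart of steps 5.1 and 5.3 is one rule: order each customer's versions by VALID_FROM_DT, close every version off on the day before the next version starts, and leave the youngest open until 2099-12-31. A small pure-Python sketch of just that rule (no surrogate keys, no Spark; `close_off` is a hypothetical helper), checked against the C999666 and A048901 rows of the sample data above:

```python
from datetime import date, timedelta
from itertools import groupby

HIGH_DATE = "2099-12-31"

def close_off(versions):
    """versions: iterable of (cust_code, valid_from 'YYYY-MM-DD').
    Returns (cust_code, valid_from, valid_to) with valid_to = next valid_from - 1 day."""
    out = []
    ordered = sorted(versions)  # by cust_code, then valid_from (ISO date strings sort correctly)
    for cust, group in groupby(ordered, key=lambda v: v[0]):
        starts = [v[1] for v in group]
        for i, start in enumerate(starts):
            if i + 1 < len(starts):
                nxt = date.fromisoformat(starts[i + 1])
                valid_to = (nxt - timedelta(days=1)).isoformat()
            else:
                valid_to = HIGH_DATE  # youngest version stays open
            out.append((cust, start, valid_to))
    return out

sample = [("C999666", "2016-12-16"), ("C999666", "2016-11-11"), ("C999666", "2016-12-15")]
for row in close_off(sample):
    print(row)
# ('C999666', '2016-11-11', '2016-12-14')
# ('C999666', '2016-12-15', '2016-12-15')
# ('C999666', '2016-12-16', '2099-12-31')
```

The printed ranges match DWH_KEY 404, 405 and 406 in the sample table, so the "DATE - 1" close-off never produces overlapping periods.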
Here's a detailed implementation of slowly changing dimension type 2 in Spark (DataFrame and SQL) using an exclusive join approach.
It assumes that the source sends a complete data file, i.e. old, updated and new records.
Steps:
Load the recent file data to STG table
Select all the expired records from HIST table
1. select * from HIST_TAB where exp_dt != '2099-12-31'
Select all current records that have not changed, using an inner join between HIST and STG and filtering on HIST.column = STG.column as below
2. select hist.* from HIST_TAB hist inner join STG_TAB stg on hist.key = stg.key where hist.exp_dt = '2099-12-31' and hist.column = stg.column
Select all the new and updated records which are changed from STG_TAB using exclusive left join with HIST_TAB and set expiry and effective date as below
3. select stg.*, eff_dt (yyyy-MM-dd), exp_dt (2099-12-31) from STG_TAB stg left join (select * from HIST_TAB where exp_dt = '2099-12-31') hist
on hist.key = stg.key where hist.key is null or hist.column != stg.column
Select all updated old records from the HIST table using an exclusive left join with the STG table and set their expiry date as shown below:
4. select hist.*, exp_dt(yyyy-MM-dd) from (select * from HIST_TAB where exp_dt = '2099-12-31') hist left join STG_TAB stg
on hist.key = stg.key where stg.key is null or hist.column != stg.column
UNION ALL the results of queries 1-4 and INSERT OVERWRITE them into the HIST table.
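The four steps can be tried end to end on toy data. The sketch below runs them as one `UNION ALL` against an in-memory SQLite database so it is executable without a cluster; the `hist`/`stg` tables, the `cust_key`/`city` columns, the run date, and closing expired versions on the day before the run date are all assumptions made for the example. Step 2 is limited to current rows so that an already-expired version whose values happen to match the staging row is not emitted twice, and step 4 tests `s.cust_key IS NULL`, so a current row whose key is missing from the full file is also expired.

```python
import sqlite3

OPEN = "2099-12-31"

def scd2_full_refresh(conn, run_dt, close_dt):
    """Return the new HIST content as the UNION ALL of the four steps."""
    sql = """
    -- 1. versions that were already expired
    SELECT cust_key, city, eff_dt, exp_dt FROM hist WHERE exp_dt != :open
    UNION ALL
    -- 2. current versions whose values did not change
    SELECT h.cust_key, h.city, h.eff_dt, h.exp_dt
    FROM hist h JOIN stg s ON h.cust_key = s.cust_key
    WHERE h.exp_dt = :open AND h.city = s.city
    UNION ALL
    -- 3. new keys and new versions of changed keys
    SELECT s.cust_key, s.city, :run_dt, :open
    FROM stg s LEFT JOIN (SELECT * FROM hist WHERE exp_dt = :open) h
      ON h.cust_key = s.cust_key
    WHERE h.cust_key IS NULL OR h.city != s.city
    UNION ALL
    -- 4. close off current versions that changed or vanished from the full file
    SELECT h.cust_key, h.city, h.eff_dt, :close_dt
    FROM (SELECT * FROM hist WHERE exp_dt = :open) h LEFT JOIN stg s
      ON h.cust_key = s.cust_key
    WHERE s.cust_key IS NULL OR h.city != s.city
    ORDER BY 1, 3
    """
    return conn.execute(sql, {"open": OPEN, "run_dt": run_dt, "close_dt": close_dt}).fetchall()

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE hist (cust_key INTEGER, city TEXT, eff_dt TEXT, exp_dt TEXT);
CREATE TABLE stg (cust_key INTEGER, city TEXT);
INSERT INTO hist VALUES (1, 'Calgary', '2020-01-01', '2020-06-30'),
                        (1, 'Wellington', '2020-07-01', '2099-12-31'),
                        (2, 'Mumbai', '2020-01-01', '2099-12-31'),
                        (3, 'Leeds', '2020-01-01', '2099-12-31');
INSERT INTO stg VALUES (1, 'Wellington'), (2, 'Pune'), (4, 'Cape Town');
""")
for row in scd2_full_refresh(conn, run_dt="2021-01-15", close_dt="2021-01-14"):
    print(row)
```

Key 1 is unchanged, key 2 gets a new version and its old one is closed, key 3 is expired because it is missing from the full file, and key 4 is inserted as new.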
A more detailed implementation of SCD type 2 in Scala and PySpark can be found here:
https://github.com/sahilbhange/spark-slowly-changing-dimension
Hope this helps!
Scala Spark: https://georgheiler.com/2020/11/19/sparkling-scd2/
NOTICE: this is not a full SCD2. It assumes one table of events and derives/deduplicates valid_from/valid_to from them, i.e. no merge/upsert is implemented.
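import spark.implicits._ // for toDF on a Seq (pre-imported in spark-shell and notebooks)
import org.apache.spark.sql.functions.{col, to_date}
val df = Seq(("k1","foo", "2020-01-01"), ("k1","foo", "2020-02-01"), ("k1","baz", "2020-02-01"),
("k2","foo", "2019-01-01"), ("k2","foo", "2019-02-01"), ("k2","baz", "2019-02-01")).toDF("key", "value_1", "date").withColumn("date", to_date(col("date")))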
df.show
+---+-------+----------+
|key|value_1| date|
+---+-------+----------+
| k1| foo|2020-01-01|
| k1| foo|2020-02-01|
| k1| baz|2020-02-01|
| k2| foo|2019-01-01|
| k2| foo|2019-02-01|
| k2| baz|2019-02-01|
+---+-------+----------+
df.printSchema
root
|-- key: string (nullable = true)
|-- value_1: string (nullable = true)
|-- date: date (nullable = true)
df.transform(deduplicateScd2(Seq("key"), Seq("date"), "date", Seq())).show
+---+-------+----------+----------+
|key|value_1|valid_from| valid_to|
+---+-------+----------+----------+
| k1| foo|2020-01-01|2020-02-01|
| k1| baz|2020-02-01|2020-11-18|
| k2| foo|2019-01-01|2019-02-01|
| k2| baz|2019-02-01|2020-11-18|
+---+-------+----------+----------+
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.lag
import org.apache.spark.sql.functions.lead
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.functions.current_date
def deduplicateScd2(
key: Seq[String],
sortChangingIgnored: Seq[String],
timeColumn: String,
columnsToIgnore: Seq[String]
)(df: DataFrame): DataFrame = {
val windowPrimaryKey = Window
.partitionBy(key.map(col): _*)
.orderBy(sortChangingIgnored.map(col): _*)
val columnsToCompare =
df.drop(key ++ sortChangingIgnored: _*).drop(columnsToIgnore: _*).columns
val nextDataChange = lead(timeColumn, 1).over(windowPrimaryKey)
val deduplicated = df
.withColumn(
"data_changes_start",
columnsToCompare
.map(e => {
val previous = lag(col(e), 1).over(windowPrimaryKey)
val self = col(e)
// 3 cases: 1.: start (previous is NULL), 2: in between, try to collapse 3: end (= next is null)
// first, filter to only start & end events (= updates/invalidations of records)
//self =!= previous or self =!= next or previous.isNull or next.isNull
self =!= previous or previous.isNull
})
.reduce(_ or _)
)
.withColumn(
"data_changes_end",
columnsToCompare
.map(e => {
val next = lead(col(e), 1).over(windowPrimaryKey)
val self = col(e)
// 3 cases: 1.: start (previous is NULL), 2: in between, try to collapse 3: end (= next is null)
// first, filter to only start & end events (= updates/invalidations of records)
self =!= next or next.isNull
})
.reduce(_ or _)
)
.filter(col("data_changes_start") or col("data_changes_end"))
.drop("data_changes")
deduplicated //.withColumn("valid_to", nextDataChange)
.withColumn(
"valid_to",
when(col("data_changes_end") === true, col(timeColumn))
.otherwise(nextDataChange)
)
.filter(col("data_changes_start") === true)
.withColumn(
"valid_to",
when(nextDataChange.isNull, current_date()).otherwise(col("valid_to"))
)
.withColumnRenamed(timeColumn, "valid_from")
.drop("data_changes_end", "data_changes_start")
}
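To make the collapsing rule easier to follow, here is a pure-Python reading of what `deduplicateScd2` does for a single value column, based on tracing the Scala above: consecutive events with the same value form one version, `valid_from` is the first event's date, `valid_to` is the date of the last event in that run, and the final version is left open with "today" (`current_date()` in the Scala code). It is a sketch for checking intuition on small inputs, not a port. The input must already be sorted by key and date, and ties on the date (like the two `2020-02-01` rows) are resolved by input order here, which Spark's window ordering does not guarantee.

```python
from itertools import groupby

def dedupe_scd2(events, today):
    """events: list of (key, value, date_str) sorted by key, then date.
    Returns (key, value, valid_from, valid_to) rows."""
    result = []
    for key, rows in groupby(events, key=lambda e: e[0]):
        # collapse consecutive rows that carry the same value into one run
        runs = [list(run) for _, run in groupby(rows, key=lambda e: e[1])]
        for i, run in enumerate(runs):
            valid_from = run[0][2]
            # the last run stays open; earlier runs end at their last observed date
            valid_to = today if i == len(runs) - 1 else run[-1][2]
            result.append((key, run[0][1], valid_from, valid_to))
    return result

events = [("k1", "foo", "2020-01-01"), ("k1", "foo", "2020-02-01"), ("k1", "baz", "2020-02-01"),
          ("k2", "foo", "2019-01-01"), ("k2", "foo", "2019-02-01"), ("k2", "baz", "2019-02-01")]
for row in dedupe_scd2(events, today="2020-11-18"):
    print(row)
# ('k1', 'foo', '2020-01-01', '2020-02-01')
# ('k1', 'baz', '2020-02-01', '2020-11-18')
# ('k2', 'foo', '2019-01-01', '2019-02-01')
# ('k2', 'baz', '2019-02-01', '2020-11-18')
```

With `today` fixed to the date the answer was written, the output reproduces the `df.transform(...).show` table above.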
Here is an updated answer using MERGE.
Note that it will not work with Spark Structured Streaming, but it can be used with the Spark Kafka batch integration.
// 0. Standard, start of program.
// Handles multiple business keys in a single run. DELTA tables.
// Schema evolution also handled.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val spark = SparkSession.builder
.master("local") // Not realistic
.appName("REF Zone History stuff and processing")
.enableHiveSupport() // Standard in Databricks.
.getOrCreate()
import spark.implicits._ // Needed for the $"col" syntax used below.
// 1. Read newer data to process in some way. Create tempView.
// In general we should have few rows to process, i.e. not at scale.
val dfA = spark.read.option("multiLine",false).json("/FileStore/tables/new_customers_json_multiple_alt3.txt") // New feed.
dfA.createOrReplaceTempView("newFeed")
// 2. First create the target for data at rest if it does not exist. Add an ASC col_key. Should only occur once.
val save_path = "/some_loc_fix/ref/atRest/data" // Make dynamic.
val table_name = "CUSTOMERS_AT_REST"
spark.sql("CREATE TABLE IF NOT EXISTS " + table_name + " LOCATION '" + save_path + "'" + " AS SELECT * from newFeed WHERE 1 = 0 " ) // Can also use limit 0 instead of WHERE 1 = 0.
// Add an ASC col_key column if it does not exist.
// I have in input valid_from_dt, but it could be different so we would need to add in reality as well. Mark to decide.
try {
spark.sql("ALTER TABLE " + table_name + " ADD COLUMNS (col_key BIGINT FIRST, valid_to_dt STRING) ")
} catch {
case unknown: Exception => {
None
}
}
// 3. Get maximum value for target. This is a necessity.
val max_val = spark.sql("select max(col_key) from " + table_name)
//max_val.show()
val null_count = max_val.filter("max(col_key) is null").count()
var max_Col_Key: BigInt = 0;
if ( null_count == 1 ) {
max_Col_Key = 0
} else {
max_Col_Key = max_val.head().getLong(0) // Long and BIGINT interoperable.
}
// 4.1 Create a temporary table for getting the youngest records from the existing data. table_name as variable, newFeed tempView as string. Then apply processing.
val dfB = spark.sql(" select O.* from (select A.cust_code, max(A.col_key) as max_col_key from " + table_name + " A where A.cust_code in (select B.cust_code from newFeed B ) group by A.cust_code ) Z, " + table_name + " O where O.col_key = Z.max_col_key ") // Most recent records.
// No tempView required.
// 4.2 Get the set of data to actually process. New feed + youngest records in feed.
val dfC =dfA.unionByName(dfB, true)
dfC.createOrReplaceTempView("cusToProcess")
// 4.3 RANK
val df1 = spark.sql("""select *, dense_rank() over (partition by CUST_CODE order by VALID_FROM_DT desc) as RANK from CusToProcess """)
df1.createOrReplaceTempView("CusToProcess2")
// 4.4 JOIN adjacent records & process closing off dates etc.
val df2 = spark.sql("""select A.*, B.rank as B_rank, cast(date_sub(cast(B.valid_from_dt as DATE), 1) as STRING) as untilMinus1
from CusToProcess2 A LEFT OUTER JOIN CusToProcess2 B
on A.cust_code = B.cust_code and A.RANK = B.RANK + 1 """)
val df3 = df2.drop("valid_to_dt").withColumn("valid_to_dt", $"untilMinus1").drop("untilMinus1").drop("B_rank")
val df4 = df3.withColumn("valid_to_dt", when($"valid_to_dt".isNull, lit("2099-12-31")).otherwise($"valid_to_dt")).drop("RANK")
df4.createOrReplaceTempView("CusToProcess3")
val df5 = spark.sql(s""" select *, row_number() OVER( ORDER BY cust_code ASC, valid_from_dt ASC) as ROW_NUMBER, $max_Col_Key as col_OFFSET
from CusToProcess3 """)
// Add a new ascending col_key; gaps can result, which is not an issue, but keys must always be ascending.
val df6 = df5.withColumn("col_key", when($"col_key".isNull, ($"ROW_NUMBER" + $"col_OFFSET")).otherwise($"col_key"))
val df7 = df6.withColumn("col_key", col("col_key").cast(LongType)).drop("ROW_NUMBER").drop("col_OFFSET")
// 5. ACTUAL MERGE, is very simple.
// More than one Merge key possible? Need then to have a col_key if only one such possible.
df7.createOrReplaceTempView("CUST_DELTA")
spark.sql("SET spark.databricks.delta.schema.autoMerge.enabled = true")
spark.sql(""" MERGE INTO CUSTOMERS_AT_REST
USING CUST_DELTA
ON CUSTOMERS_AT_REST.col_key = CUST_DELTA.col_key
WHEN MATCHED THEN
UPDATE SET *
WHEN NOT MATCHED THEN
INSERT *
""")
