Best practice to reduce the number of Spark jobs started. Use union? - apache-spark

With Spark, starting a job takes time.
In a complex workflow, a job may be launched on every iteration of a loop.
But we pay that startup cost on each iteration.
def test_loop(spark):
    all_datas = []
    for i in ['CZ12905K01', 'CZ12809WRH', 'CZ129086RP']:
        all_datas.extend(spark.sql(f"""
            select * from data where id=='{i}'
        """).collect())  # Starts a job
    return all_datas
Sometimes it's possible to unroll the loop into one big job with 'union'.
def test_union(spark):
    full_request = None
    for i in ['CZ12905K01', 'CZ12809WRH', 'CZ129086RP']:
        q = f"""
            select '{i}' ID,* from data where leh_be_lot_id=='{i}'
        """
        partial_df = spark.sql(q)
        if full_request is None:
            full_request = partial_df
        else:
            full_request = full_request.union(partial_df)
    return full_request.collect()  # Starts a single job
For clarity, my samples are elementary (I know I could use in (...)). The real requests will be more complex.
Is this a good idea?
With the union approach, I can drastically reduce the number of jobs submitted, but each job is more complex.

My tests show that:
It's possible to union more than 1000 requests for better performance
For 950 requests with local[6]:
With 0 unions: 1h53m
With 10 unions: 20m01s
With 100 unions: 7m12s
With 200 unions: 6m02s
With 500 unions: 6m25s
Sometimes the union version has to "broadcast big data", or fails with an "Out of memory" error
My final approach: set a level_of_union
Merge some requests, start the job, get the data
Continue the loop with another batch
def test_union(spark, level_of_union):
    full_request = None
    all_datas = []
    todo = ['CZ12905K01', 'CZ12809WRH', 'CZ129086RP']
    for idx, i in enumerate(todo):
        q = f"""
            select '{i}' ID,* from leh_be where leh_be_lot_id=='{i}'
        """
        partial_df = spark.sql(q)
        if full_request is None:
            full_request = partial_df
        else:
            full_request = full_request.union(partial_df)
        # Flush when the batch is full or when the last element is reached
        if idx % level_of_union == level_of_union - 1 or idx == len(todo) - 1:
            all_datas.extend(full_request.collect())  # Starts a job
            full_request = None
    return all_datas
Run a test to tune the meta-parameter level_of_union.
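The batching logic can also be kept apart from the Spark calls. A minimal sketch, assuming a hypothetical helper named `chunked` (it is not part of Spark) that splits the list of ids into groups of at most `level_of_union` elements; each group would then become one unioned query and therefore one job:

```python
from typing import Iterator, List, Sequence


def chunked(items: Sequence[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of `items` holding at most `size` elements."""
    if size < 1:
        raise ValueError("size must be >= 1")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


todo = ['CZ12905K01', 'CZ12809WRH', 'CZ129086RP']
batches = list(chunked(todo, 2))
# One Spark job per batch: 2 jobs here instead of 3.
print(len(batches), batches)
```

With this split, tuning level_of_union only changes the `size` argument, and the last, possibly shorter batch needs no special case.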

Related

Single Task Taking Long Time in PySpark

I am running a PySpark application that reads several Parquet files into Spark dataframes and creates temporary views on them for use in my SQL query. I have about 18 views: some are ~1 TB, a few are several GB, and the rest are smaller. I join all of these and run my business logic to get the desired outcome. My code takes an extremely long time to run (>3 hours) on this data. Looking at the Spark History Server, I can see one task that seems to be the culprit: its duration, the data it spills to memory and disk, and its shuffle read/write are all far higher than the median. This indicates data skew, so I salted my large dataframes before creating the temp views. However, there is still no difference in the execution time. I checked the number of partitions and it's already 792 (the maximum I can have with my current Glue config). I have also enabled adaptive query execution and adaptive skew-join handling.
My original dataset is extremely large: the largest table is ~40 TB and holds 2.5 years of data. I am trying to do a one-time historical load and could not run it over the entire dataset. Through trial and error, I had to reduce this to processing 1 TB of data at a time (for the largest table), which still takes 3+ hours. This is not a scalable approach, so I am looking for input on how to optimize it.
Below are my app details:
Number of workers = 792
Spark config:
spark = (SparkSession
    .builder
    .appName("scmCaseAlertDatamartFullLoad")
    .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
    .config("spark.sql.adaptive.enabled","true")
    .config("spark.sql.broadcastTimeout","900")
    .config("spark.sql.adaptive.skewJoin.enabled","true")
    .getOrCreate()
)
Code (only the key relevant methods are included; the starting point is loadSCMCaseAlertTable()):
def getIncomingMatchesFullData(self):
select_query_incoming_matches_full_data = """
SELECT DISTINCT alrt.caseid AS case_id,
alrt.alertid AS alert_id,
alrt.accountid AS account_id,
sc.created_time AS case_created_time,
sc.last_updated_time AS case_last_updated_time,
alrt.srccreatedtime AS alert_created_time,
aud.last_updated_by AS case_last_updated_by,
sc.closed_time AS case_last_close_time,
lcs.status AS case_status,
lcst.state AS case_state,
lcra.responsive_action,
sc.assigned_to AS case_assigned_to,
cr1.team_name AS case_assigned_to_team,
sc.resolved_by AS case_resolved_by,
cr2.team_name AS case_resolved_by_team,
aud.last_annotation AS case_last_annotation,
ca.name AS case_approver,
alrt.screeningdecision AS screening_decision,
ap.accountpool AS division,
lcd.decision AS case_current_decision,
CASE
WHEN sm.grylaclientid LIKE '%AddressService%' THEN 'Address Service'
WHEN sm.grylaclientid LIKE '%GrylaOrderProcessingService%' THEN 'Retail Checkout Service'
WHEN sm.grylaclientid = 'urn:cdo:GrylaBatchScreeningAAA:AWS:Default' THEN 'Batch Screening'
WHEN sm.grylaclientid = 'urn:cdo:OfficerJennyBindle:AWS:Default' THEN 'API'
ELSE 'Other'
END AS channel,
ap.businesstype AS business_type,
ap.businessname AS business_name,
ap.marketplaceid AS ap_marketplace_id,
ap.region AS ap_region,
ap.memberid AS ap_member_id,
ap.secondaryaccountpool AS secondary_account_pool,
sm.action AS client_action,
acl.added_by,
acl.lnb_id AS accept_list_lnb_id,
acl.created_time AS accept_list_created_time,
acl.source_case_id AS accept_list_source_case_id,
acs.status AS accept_list_status,
ap.street1 AS ap_line_1,
ap.street2 AS ap_line_2,
ap.street3 AS ap_line_3,
ap.city AS ap_city,
ap.state AS ap_state,
ap.postalcode AS ap_postal_code,
ap.country AS ap_country,
ap.fullname AS ap_full_name,
ap.email AS ap_email,
sm.screening_match_id AS dp_screening_match_id,
CASE
WHEN sm.matchtype = 'name_only_matching_details' THEN 'Name Only'
WHEN sm.matchtype = 'address_only_matching_details' THEN 'Address Only'
WHEN sm.matchtype = 'address_matching_details' THEN 'Address'
WHEN sm.matchtype = 'scr_matching_details' THEN 'SCR'
WHEN sm.matchtype = 'hotkey_matching_details' THEN 'HotKey'
END AS match_type,
sm.matchaction AS match_action,
alrt.batchfilename AS batch_file_id,
REGEXP_REPLACE(dp.name, '\\n|\\r|\\t', ' ') AS dp_matched_add_full_name,
dp.street AS dp_line1,
'' AS dp_line2,
dp.city AS dp_city,
dp.state AS dp_state,
dp.postalcode AS dp_postal_code,
dp.country AS dp_country,
dp.matchedplaces AS scr_value,
dp.hotkeyvalues AS hotkey_value,
sm.acceptlistid AS suppressed_by_accept_list_id,
sm.suppresseddedupe AS is_deduped,
sm.matchhash AS hash,
sm.matchdecision AS match_decision,
ap.addressid AS amazon_address_id,
ap.dateofbirth AS date_of_birth,
sm.grylaclientid AS gryla_client_id,
cr1.name AS case_assigned_to_role,
cr2.name AS case_resolved_by_role,
alrt.screeningengine AS screening_engine,
sm.srccreatedtime AS match_created_time,
sm.srclastupdatedtime AS match_updated_time,
to_date(sm.srclastupdatedtime,"yyyy-MM-dd") AS match_updated_date,
sm.match_updated_time_msec,
sm.suppressedby AS match_suppressed_by
FROM
cm_screening_match sm
JOIN
cm_screening_match_redshift smr ON sm.screening_match_id = smr.screening_match_id
LEFT JOIN
cm_case_alert alrt ON sm.screening_match_id = alrt.screening_match_id
LEFT JOIN
cm_amazon_party ap ON sm.screening_match_id = ap.screening_match_id
LEFT JOIN
cm_denied_party dp ON sm.screening_match_id = dp.screening_match_id
LEFT JOIN
cm_spectre_case sc ON alrt.caseid = sc.case_id
LEFT JOIN
cm_lookup_case_status lcs ON sc.status_id = lcs.status_id
LEFT JOIN
cm_lookup_case_state lcst ON sc.state_id = lcst.state_id
LEFT JOIN
cm_lookup_case_decision lcd ON sc.decision_id = lcd.decision_id
LEFT JOIN
cm_lookup_case_responsive_action lcra ON sc.responsive_action_id = lcra.responsive_action_id
LEFT JOIN
cm_user cu1 ON sc.assigned_to = cu1.alias
LEFT JOIN
cm_role cr1 ON cu1.current_role_id = cr1.role_id
LEFT JOIN
cm_user cu2 ON sc.resolved_by = cu2.alias
LEFT JOIN
cm_role cr2 ON cu2.current_role_id = cr2.role_id
LEFT JOIN
cm_accept_list acl ON acl.screening_match_id = sm.screening_match_id
LEFT JOIN
cm_lookup_accept_list_status acs ON acs.status_id = acl.status_id
LEFT JOIN
(
SELECT case_id,
last_value(username) OVER (PARTITION BY case_id ORDER BY created_time
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS last_updated_by,
last_value(description) OVER (PARTITION BY case_id ORDER BY created_time
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS last_annotation
FROM cm_spectre_case_audit
) aud ON sc.case_id = aud.case_id
LEFT JOIN
cm_approver ca ON sc.approver_id = ca.approver_id
"""
print(select_query_incoming_matches_full_data)
incomingMatchesFullDF = self.spark.sql(select_query_incoming_matches_full_data)
return incomingMatchesFullDF
def getBaseTables(self,matchtime_lower_threshold,matchtime_upper_threshold,cursor):
print('Fetching datalake data for matches created after: {}' .format(matchtime_lower_threshold))
matchDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_match)
matchDF = matchDF.select("screening_match_id","grylaclientid","action","matchtype","matchaction","acceptlistid","suppresseddedupe","matchhash","matchdecision","srccreatedtime","srclastupdatedtime","suppressedby","lastupdatedtime")
#.withColumn("screentime",to_timestamp("screentime")) \
matchDF = matchDF.withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())).drop("lastupdatedtime")
#matchDF = matchDF.repartition(2400,"screening_match_id")
matchDF = self.getLatestRecord(matchDF)
matchDF = matchDF.withColumn("salt", rand())
matchDF = matchDF.repartition("salt")
matchDF.createOrReplaceTempView("cm_screening_match")
print("Total from matchDF:",matchDF.count())
print("Number of paritions in matchDF: " ,matchDF.rdd.getNumPartitions())
alertDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_alert)
alertDF = alertDF.select("screening_match_id","caseid","alertid","accountid","srccreatedtime","screeningdecision","batchfilename","screeningengine","lastupdatedtime")
alertDF = alertDF.withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())).drop("lastupdatedtime")
#alertDF = alertDF.repartition(2400,"screening_match_id")
alertDF = self.getLatestRecord(alertDF)
alertDF = alertDF.withColumn("salt", rand())
alertDF = alertDF.repartition("salt")
alertDF.createOrReplaceTempView("cm_case_alert")
print("Total from alertDF:",alertDF.count())
print("Number of paritions in alertDF: " ,alertDF.rdd.getNumPartitions())
apDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_ap)
apDF = apDF.select("screening_match_id","accountpool","businesstype","businessname","marketplaceid","region","memberid","secondaryaccountpool","street1","street2","street3","city","state","postalcode","country","fullname","email","addressid","dateofbirth","lastupdatedtime")
apDF = apDF.withColumn("dateofbirth",to_date("dateofbirth","yyyy-MM-dd")) \
.withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())) \
.drop("lastupdatedtime")
#apDF = apDF.repartition(2400,"screening_match_id")
apDF = self.getLatestRecord(apDF)
apDF = apDF.withColumn("salt", rand())
apDF = apDF.repartition("salt")
apDF.createOrReplaceTempView("cm_amazon_party")
print("Total from apDF:",apDF.count())
print("Number of paritions in apDF: " ,apDF.rdd.getNumPartitions())
dpDF = self.getDatalakeData(matchtime_lower_threshold,matchtime_upper_threshold,self.data_input_dp)
dpDF = dpDF.select("screening_match_id","name","street","city","state","postalcode","country","matchedplaces","hotkeyvalues","lastupdatedtime")
dpDF = dpDF.withColumn("match_updated_time_msec",col("lastupdatedtime").cast(LongType())).drop("lastupdatedtime")
#dpDF = dpDF.repartition(2400,"screening_match_id")
dpDF = self.getLatestRecord(dpDF)
dpDF = dpDF.withColumn("salt", rand())
dpDF = dpDF.repartition("salt")
dpDF.createOrReplaceTempView("cm_denied_party")
print("Total from dpDF:",dpDF.count())
print("Number of paritions in dpDF: " ,dpDF.rdd.getNumPartitions())
print('Fetching data from Redshift Base tables...')
self.getRedshiftData(matchtime_lower_threshold,matchtime_upper_threshold,cursor)
caseAuditDF = self.spark.read.parquet(self.data_input_case_audit)
caseAuditDF.createOrReplaceTempView("cm_spectre_case_audit")
caseDF = self.spark.read.parquet(self.data_input_case)
caseDF.createOrReplaceTempView("cm_spectre_case")
caseStatusDF = self.spark.read.parquet(self.data_input_case_status)
caseStatusDF.createOrReplaceTempView("cm_lookup_case_status")
caseStateDF = self.spark.read.parquet(self.data_input_case_state)
caseStateDF.createOrReplaceTempView("cm_lookup_case_state")
caseDecisionDF = self.spark.read.parquet(self.data_input_case_decision)
caseDecisionDF.createOrReplaceTempView("cm_lookup_case_decision")
caseRespActDF = self.spark.read.parquet(self.data_input_case_responsive_action)
caseRespActDF.createOrReplaceTempView("cm_lookup_case_responsive_action")
userDF = self.spark.read.parquet(self.data_input_user)
userDF.createOrReplaceTempView("cm_user")
userSnapshotDF = self.spark.read.parquet(self.data_input_user_snapshot)
userSnapshotDF.createOrReplaceTempView("v_cm_user_snapshot")
roleDF = self.spark.read.parquet(self.data_input_role)
roleDF.createOrReplaceTempView("cm_role")
skillDF = self.spark.read.parquet(self.data_input_skill)
skillDF.createOrReplaceTempView("cm_skill")
lookupSkillDF = self.spark.read.parquet(self.data_input_lookup_skills)
lookupSkillDF.createOrReplaceTempView("cm_lookup_skills")
skillTypeDF = self.spark.read.parquet(self.data_input_skill_type)
skillTypeDF.createOrReplaceTempView("cm_skill_type")
acceptListDF = self.spark.read.parquet(self.data_input_accept_list)
acceptListDF.createOrReplaceTempView("cm_accept_list")
lookupAcceptListStatusDF = self.spark.read.parquet(self.data_input_lookup_accept_list_status)
lookupAcceptListStatusDF.createOrReplaceTempView("cm_lookup_accept_list_status")
approverDF = self.spark.read.parquet(self.data_input_approver)
approverDF.createOrReplaceTempView("cm_approver")
screeningMatchDF_temp = self.spark.read.parquet(self.data_input_screening_match_redshift)
screeningMatchLookupDF_temp = self.spark.read.parquet(self.data_input_lookup_screening_match_redshift)
screeningMatchLookupDF_temp_new = screeningMatchLookupDF_temp.withColumnRenamed("screening_match_id","lookupdf_screening_match_id")
"""
The screening_match_id in datalake table is a mix of alphanumeric match IDs (the ones in cm_lookup_screening_match_id in Redshift) and numeric (the ones in cm_screening_match in Redshift). Hence we combine the match IDs from both the Redshift tables. Also, there are matches which were created in the past but updated recently. Since updated date is only present in cm_screening_match and not in cm_lookup_screening_match_id, we will only have the numeric match Ids. When we join this to datalake table, we won't be able to find these matches as they are present in the alphanumeric form in datalake. Hence what we do is read the entire table of cm_lookup_screening_match_id and join it with cm_screening_match to enrich cm_screening_match with the alphanumeric match Id. Finally we filter cm_lookup_screening_match_id only for newly created matches and combine with the matches from enriched version of cm_screening_match.
"""
screeningMatchDF_enriched = screeningMatchDF_temp.join(screeningMatchLookupDF_temp_new,screeningMatchDF_temp.screening_match_id == screeningMatchLookupDF_temp_new.lookupdf_screening_match_id,"left")
screeningMatchDF_enriched = screeningMatchDF_enriched.withColumn("screening_match_id",col("screening_match_id").cast(StringType()))
screeningMatchDF = screeningMatchDF_enriched.select(col("screening_match_id")).union(screeningMatchDF_enriched.select(col("match_event_id")))
screeningMatchLookupDF = screeningMatchLookupDF_temp_new.filter("created_time > '{}'" .format(matchtime_lower_threshold)).select(col("match_event_id"))
screeningMatchRedshiftDF = screeningMatchDF.union(screeningMatchLookupDF)
#screeningMatchRedshiftDF = screeningMatchRedshiftDF.repartition(792,"screening_match_id")
screeningMatchRedshiftDF = screeningMatchRedshiftDF.withColumn("salt", rand())
screeningMatchRedshiftDF = screeningMatchRedshiftDF.repartition("salt")
screeningMatchRedshiftDF.createOrReplaceTempView("cm_screening_match_redshift")
print("Total from screeningMatchRedshiftDF:",screeningMatchRedshiftDF.count())
def loadSCMCaseAlertTable(self):
    print('Getting the thresholds for data to be loaded')
    matchtime_lower_threshold = self.getLowerThreshold('scm_case_alert_data')
    print('Match time lower threshold is: {}' .format(matchtime_lower_threshold))
    matchtime_upper_threshold = self.default_upper_threshold
    print('Match time upper threshold is: {}' .format(matchtime_upper_threshold))
    print("Getting the required base tables")
    con = self.get_redshift_connection()
    cursor = con.cursor()
    self.getBaseTables(matchtime_lower_threshold,matchtime_upper_threshold,cursor)
    print("Getting the enriched dataset for incoming matches (the ones to be inserted or updated)")
    incomingMatchesFullDF = self.getIncomingMatchesFullData()
    print("Total records in incomingMatchesFullDF: ", incomingMatchesFullDF.count())
    print("Copying the incoming data to temp work dir")
    print("Clearing work directory: {}" .format(self.work_scad_path))
    self.deleteAllObjectsFromS3Prefix(self.dest_bucket,self.dest_work_prefix_scad)
    print("Writing data to work dir: {}" .format(self.work_scad_path))
    #.coalesce(1) \
    incomingMatchesFullDF.write \
        .partitionBy("match_updated_date") \
        .mode("overwrite") \
        .parquet(self.work_scad_path + self.work_dir_partitioned_table_scad)
    print("Data copied to work dir")
    print("Reading data from work dir in a temporary dataframe")
    incomingMatchesFullDF_copy = self.spark.read.parquet(self.work_scad_path + "scm_case_alert_data_work.parquet/")
    if self.update_mode == 'overwrite':
        print("Datamart update mode is overwrite. New data will replace existing data.")
        print("Publishing to Redshift")
        self.publishToRedshift(con,cursor)
        print("Publishing to Redshift complete")
    elif self.update_mode == 'upsert':
        print("Datamart update mode is upsert. New data will be loaded and existing data will be updated.")
        print("Checking for cases updated between {} and {}" .format(matchtime_lower_threshold,matchtime_upper_threshold))
        updatedCasesDF = self.getUpdatedCases(matchtime_lower_threshold,matchtime_upper_threshold)
        updatedCasesDF.createOrReplaceTempView("updated_cases")
        print("Getting updated case attributes")
        updatedCaseAttributesDF = self.getUpdatedCaseAttributes()
        print("Moving updated case data to temp work directory: {}".format(self.work_updated_cases_path))
        print("Clearing work directory")
        self.deleteAllObjectsFromS3Prefix(self.dest_bucket,self.dest_work_prefix_updated_cases)
        try:
            print("Writing data to work dir: {}" .format(self.work_updated_cases_path))
            updatedCaseAttributesDF.coalesce(1) \
                .write \
                .mode("overwrite") \
                .parquet(self.work_updated_cases_path + "updated_cases.parquet")
        except Exception as e:
            e = sys.exc_info()[0]
            print("No data to write to work dir")
        print("Starting the process to publish data to Redshift")
        self.publishToRedshift(con,cursor)
        print("Publishing to Redshift complete")
    print('Updating metadata table')
    matchtime_lower_threshold_new = incomingMatchesFullDF_copy.agg({'match_updated_time': 'max'}).collect()[0][0]
    if matchtime_lower_threshold_new is not None:
        matchtime_lower_threshold_new_formatted = matchtime_lower_threshold_new.strftime("%Y-%m-%d %H:%M:%S")
        print("Latest match time lower threshold with new load: {}" .format(matchtime_lower_threshold_new_formatted))
        self.updatePipelineMetadata('scm_case_alert_data','max_data_update_time',matchtime_lower_threshold_new_formatted)
    else:
        print("No new matches, leaving max_data_update_time for match as it is")
    print("Metadata table up to date")
    print("Committing the updates to Redshift and closing the connection")
    con.commit() #Committing after the metadata table is updated to ensure the datamart data and threshold are aligned
    cursor.close()
    con.close()
Spark History Server Screenshot:
As you correctly suspected, you have a data skew issue. This is really apparent from your last screenshot: have a look at the shuffle read/write sizes! What you have to find out is which shuffle operation (it looks like a join) is causing the issue.
Salting only the large dataframes without knowing where your skew is won't solve the issue.
So, my proposed plan of action:
You can see that stage 112 in your picture is the problematic stage. Figure out which join operation it belongs to. In the SQL tab of the web UI you can find stage 112 and hover over it. That should give you enough info to figure out which shuffle/join key is skewed.
Once you know which key is skewed, look at the statistical distribution of that key using spark-shell or something similar. Figure out which value is overly common. This will help in making future decisions. A simple df.groupBy("problematicKey").count will already be really interesting.
Once you know that, you can go ahead and salt that specific key.
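The effect of these steps can be illustrated without a cluster. The following is a plain-Python simulation, not Spark code: the keys are made up, and `partition_of` is a hypothetical, simplified routing rule (a stable hash of the key shifted by the salt) standing in for the shuffle partitioner. It first finds the over-represented value, then compares the fullest partition with and without a salt on the skewed key:

```python
import random
import zlib
from collections import Counter

random.seed(0)  # reproducible salts
NUM_PARTITIONS = 8
SALT_BUCKETS = 8

# Made-up join keys: one hot value dominates, the rest are unique.
keys = ["HOT"] * 9000 + [f"k{i}" for i in range(1000)]

# Find the over-represented value; the Spark counterpart is
# df.groupBy("problematicKey").count().
hot_key, hot_count = Counter(keys).most_common(1)[0]
print(hot_key, hot_count)


def partition_of(key: str, salt: int = 0) -> int:
    """Simplified routing rule: stable hash of the key, shifted by the salt."""
    return (zlib.crc32(key.encode()) + salt) % NUM_PARTITIONS


def largest_partition(assignments) -> int:
    """Row count of the fullest partition."""
    return max(Counter(assignments).values())


unsalted = largest_partition(partition_of(k) for k in keys)
salted = largest_partition(
    partition_of(k, random.randrange(SALT_BUCKETS)) for k in keys
)
print(unsalted, salted)
```

Without a salt, every "HOT" row lands in the same partition; with a salt that takes part in the routing, those rows spread over several partitions. In an actual join the salt must be part of the join key, and the other side has to be replicated once per salt value so that matching rows still meet.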
But you're absolutely on the right track! Keeping an eye on that Tasks page and the time it takes for each task is a great approach!
Hope this helps :)

How to get maximum/minimum duration of all DagRun instances in Airflow?

Is there a way to find the maximum/minimum or even an average duration of all DagRun instances in Airflow? - That is all dagruns from all dags not just one single dag.
I can't find anywhere to do this on the UI or even a page with a programmatic/command line example.
You can use the Airflow REST API to get all dag_runs for each DAG and calculate statistics.
An example that gets all dag_runs per DAG and computes the duration of each run:
import datetime
import requests
from requests.auth import HTTPBasicAuth

airflow_server = "http://localhost:8080/api/v1/"
auth = HTTPBasicAuth("airflow", "airflow")

get_dags_url = f"{airflow_server}dags"
get_dag_params = {
    "limit": 100,
    "only_active": "true"
}
response = requests.get(get_dags_url, params=get_dag_params, auth=auth)
dags = response.json()["dags"]

get_dag_run_params = {
    "limit": 100,
    "state": "success",
}
for dag in dags:
    dag_id = dag["dag_id"]
    # airflow_server already ends with "/", so no extra slash is added here
    dag_run_url = f"{airflow_server}dags/{dag_id}/dagRuns"
    response = requests.get(dag_run_url, params=get_dag_run_params, auth=auth)
    dag_runs = response.json()["dag_runs"]
    for dag_run in dag_runs:
        # Measure from the actual start of the run, not the logical execution date
        start_date = datetime.datetime.fromisoformat(dag_run['start_date'])
        end_date = datetime.datetime.fromisoformat(dag_run['end_date'])
        duration = end_date - start_date
        duration_in_s = duration.total_seconds()
        print(duration_in_s)
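To get from individual durations to the minimum, maximum and average asked for, the values can be collected and aggregated. A minimal sketch, assuming each entry of the `dag_runs` list carries ISO-formatted `start_date` and `end_date` fields; `sample_dag_runs` is made-up data standing in for the API response, and `run_durations` is a hypothetical helper:

```python
from datetime import datetime
from statistics import mean

# Made-up sample in the shape of the "dag_runs" list returned by the REST API.
sample_dag_runs = [
    {"start_date": "2022-01-01T00:00:00+00:00", "end_date": "2022-01-01T00:01:30+00:00"},
    {"start_date": "2022-01-02T00:00:00+00:00", "end_date": "2022-01-02T00:00:45+00:00"},
    {"start_date": "2022-01-03T00:00:00+00:00", "end_date": None},  # still running
]


def run_durations(dag_runs):
    """Return the duration in seconds of every run that has both timestamps."""
    durations = []
    for run in dag_runs:
        if run.get("start_date") and run.get("end_date"):
            start = datetime.fromisoformat(run["start_date"])
            end = datetime.fromisoformat(run["end_date"])
            durations.append((end - start).total_seconds())
    return durations


durations = run_durations(sample_dag_runs)
if durations:
    print(min(durations), max(durations), mean(durations))
```

Extending one list with the runs of every DAG before aggregating gives the statistics across all DAGs rather than per DAG.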
The easiest way will be to query your Airflow metastore. All the scheduling, DAG runs, and task instances are stored there and Airflow can't operate without it. I do recommend filtering on DAG/execution date if your use-case allows. It's not obvious to me what one can do with just these three overarching numbers alone.
select
    min(runtime_seconds) min_runtime,
    max(runtime_seconds) max_runtime,
    avg(runtime_seconds) avg_runtime
from (
    select extract(epoch from (d.end_date - d.start_date)) runtime_seconds
    from public.dag_run d
    where d.execution_date between '2022-01-01' and '2022-06-30' and d.state = 'success'
) runs
You might also consider joining to the task_instance table to get some task-level data, and perhaps use the min start and max end times for DAG tasks within a DAG run for your timestamps.

PySpark apply function on 2 dataframes and write to csv for billions of rows on small hardware

I am trying to apply a Levenshtein function for each string in dfs against each string in dfc and write the resulting dataframe to CSV. The issue is that the cross join followed by the function creates so many rows that my machine struggles to write anything (it takes forever to execute).
What I have tried to improve write performance:
I'm filtering the result of the cross join, keeping only rows where the LevenshteinDistance is less than 15% of the target word's length.
I'm bucketing on the first letter of each target word (a, b, c, etc.). Still no luck: the job runs for hours and doesn't generate any results.
from datetime import datetime
from config import config
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql import Window
def fuzzy_match(dfs, dfc, path_summary):
    """
    Implements the Levenshtein and Soundex algorithms and returns a fuzzy matched DataFrame.
    Keeps only rows where the resulting LS distance is less than 15% of the SF name length.
    """
    # Apply Levenshtein and Soundex functions
    dfs = dfs.withColumn("OrganisationNameKeyLen", F.length("OrganisationNameKey"))
    df = dfc\
        .crossJoin(dfs)\
        .withColumn( "LevenshteinDistance", F.levenshtein( F.lower("OrganisationNameKey") , F.lower("CompanyNameKey") ) )\
        .withColumn( "HasSameSoundex", F.soundex("OrganisationNameKey") == F.soundex("CompanyNameKey") )\
        .where("LevenshteinDistance < OrganisationNameKeyLen * 0.15")\
        .orderBy("OrganisationName", "CompanyName")
    # main() uses the result, so it has to be returned
    return df
def fuzzy_match_approve(df, path_fuzzy_match_approved, path_fuzzy_match_rejected, path_summary):
    """
    Filters fuzzy matching DataFrame results on approved/rejected based on set of conditions:
    - If there is only 1 match against the SF name
    - If more than 1 match then take that with LS distance of 1
    - If more than 1 match and multiple LS distances of 1, then take the one where Soundex codes are the same
    Writes results and summary to CSV.
    """
    def write_with_bucket(df, bucket_col, path):
        df.write\
            .mode("overwrite")\
            .bucketBy(26, bucket_col)\
            .option("path", path)\
            .option("header", True)\
            .saveAsTable("bucket", format="csv")

    # Add window function columns:
    # OrganisationNameMatchCount: Count AccountID per OrganisationName
    # LevenshteinDistance1Count: Count AccountID per OrganisationName where LevenshteinDistance = 1
    windowSpec = Window.partitionBy("OrganisationName")
    df = df\
        .select("AccountID", "OrganisationName", "OrganisationNameKey", "CompanyNumber", "CompanyName", "LevenshteinDistance", "HasSameSoundex")\
        .withColumn("OrganisationNameMatchCount", F.count("AccountID").over(windowSpec))\
        .withColumn("LevenshteinDistance1Count", F.count(F.when(F.col("LevenshteinDistance")==1, F.col("AccountID"))).over(windowSpec))
    # Add bucket key column
    df = df.withColumn( "OrganisationNameBucketKey", F.substring( col("OrganisationNameKey"),0,1) )
    # Define fuzzy match approved condition
    is_approved_1 = ( F.col("OrganisationNameMatchCount") == 1 )
    is_approved_2 = ( (F.col("OrganisationNameMatchCount") > 1) & (F.col("LevenshteinDistance1Count") == 1) & (F.col("LevenshteinDistance") == 1) )
    is_approved_3 = ( (F.col("OrganisationNameMatchCount") > 1) & (F.col("LevenshteinDistance1Count") > 1) & (F.col("HasSameSoundex") == 'true') )
    is_approved = is_approved_1 | is_approved_2 | is_approved_3
    # Split fuzzy match results into approved and rejected
    df_approved = df.filter(is_approved)
    df_rejected = df.filter(~is_approved)
    # Export results
    # df_approved.write.csv(path_fuzzy_match_approved, mode="overwrite", header=True, quoteAll=True)
    # df_rejected.write.csv(path_fuzzy_match_rejected, mode="overwrite", header=True, quoteAll=True)
    write_with_bucket(df_approved, "OrganisationNameBucketKey", path_fuzzy_match_approved)
    write_with_bucket(df_rejected, "OrganisationNameBucketKey", path_fuzzy_match_rejected)

def main():
    spark = SparkSession...
    # Apply fuzzy match
    dfs = spark.read...
    dfc = spark.read...
    path_summary = ...
    df_fuzzy_match = fuzzy_match(dfs, dfc, path_summary)
    # Export results
    path_fuzzy_match_approved = ...
    path_fuzzy_match_rejected = ...
    fuzzy_match_approve(df_fuzzy_match, path_fuzzy_match_approved, path_fuzzy_match_rejected, path_summary)

main()
Other info:
df.rdd.getNumPartitions() is 2
dfs.count() is 12,515
dfc.count() is 5,110,430
Jobs:
How can I improve performance here and get the results into a CSV successfully?
There are a couple of things you can do to improve your computation:
Improve parallelism
As Nithish mentioned in the comments, you don't have enough partitions in your input data frames to make use of all your CPU cores. You're not using all your CPU capability and this will slow you down.
To increase your parallelism, repartition dfc to at least your number of cores:
dfc = dfc.repartition(dfc.sql_ctx.sparkContext.defaultParallelism)
You need to do this because your crossJoin is run as a BroadcastNestedLoopJoin which doesn't reshuffle your large input dataframe.
Separate your computation stages
A Spark dataframe/RDD is conceptually just a directed acyclic graph (DAG) of operations to run on your input data; it does not hold data itself. One consequence of this behavior is that, by default, you'll rerun your computations as many times as you reuse your dataframe.
In your fuzzy_match_approve function, you run 2 separate filters on your df, which means you rerun the whole cross-join operation twice. You really don't want this!
One easy way to avoid this is to use cache() on your fuzzy_match result which should be fairly small given your inputs and matching criteria.
def fuzzy_match_running(dfs, dfc, path_summary):
    """
    Implements the Levenshtein and Soundex algorithms and returns a fuzzy matched DataFrame.
    Filters out those where resulting LS distance is less than 15% of SF name length.
    """
    # Apply Levenshtein and Soundex functions
    dfs = dfs.withColumn("OrganisationNameKeyLen", F.length("OrganisationNameKey")).cache()
    dfc = dfc.repartition(dfc.sql_ctx.sparkContext.defaultParallelism).cache()
    df = dfc.crossJoin(dfs) \
        .withColumn( "LevenshteinDistance", F.levenshtein( F.lower("OrganisationNameKey") , F.lower("CompanyNameKey") ) ) \
        .withColumn( "HasSameSoundex", F.soundex("OrganisationNameKey") == F.soundex("CompanyNameKey") ) \
        .where("LevenshteinDistance < OrganisationNameKeyLen * 0.15") \
        .orderBy("OrganisationName", "CompanyName") \
        .cache()
    return df
If I run my fuzzy_match_running on some example data frames on my 8-core/16-thread i9-9980HK laptop (Spark in local[*] mode with 8GB driver memory), I get:
dfc rowcount : 572494
dfs rowcount : 17728
fuzzy_match rowcount: 7228499
Duration: 679.5572581291199 seconds
Matches/core/sec: 933436.210726889
The job takes about 12 min doing 572494*17728 ~ 10 billion row comparisons at 933k comparisons/second/core. Since your job does 64 billion row comparisons, I would expect it to take roughly 70 to 80 min on my laptop.
You should run a similar experiment on your computer with a smaller sample to get an idea of your actual computing speed.
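The extrapolation can be reproduced with a few lines of arithmetic. This sketch uses only the figures quoted in this thread; dividing the overall rate by the 16 hardware threads (an assumption about how the per-core figure was derived) reproduces the quoted 933k value:

```python
# Figures from the benchmark run.
dfc_rows, dfs_rows = 572_494, 17_728
duration_s = 679.5572581291199
threads = 16  # 8 cores / 16 threads

comparisons = dfc_rows * dfs_rows
rate_per_thread = comparisons / duration_s / threads

# Row counts from the question: dfc.count() * dfs.count().
target = 5_110_430 * 12_515
estimate_min = target / (rate_per_thread * threads) / 60

print(f"{comparisons:,} comparisons, {rate_per_thread:,.0f}/s per thread")
print(f"estimated duration: {estimate_min:.0f} min")
```

The same three lines of arithmetic, fed with the duration of a run on a small sample, give an estimate for any other machine.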
Going further: maximizing matches/sec
To go faster, we need to adjust the computation and increase the number of comparisons that can be done per second.
A few things stand out in the function:
You filter your output by comparing the Levenshtein distance, an integer, to a decimal calculation. This means Spark will cast your integer to a decimal and operate on decimals. Comparing decimals is much slower than comparing integers and it's unnecessary here: you can cast the bound to an int beforehand.
Your levenshtein call operates on the lowercased versions of your keys. This means that, for each row comparison, Spark converts the column values to lowercase again and again, wasting CPU cycles on redundant work. You can precompute this before your join.
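Replacing the fractional bound with an integer one does not change which rows pass, as long as the bound is rounded up: for an integer distance d and any real bound r, d < r holds exactly when d < ceil(r). A small plain-Python check of that claim (the two helper names are hypothetical, and the ranges are arbitrary):

```python
import math


def keep_decimal(distance: int, key_len: int) -> bool:
    """Original filter: integer distance compared with a fractional bound."""
    return distance < key_len * 0.15


def keep_int(distance: int, key_len: int) -> bool:
    """Rewritten filter: bound rounded up to an integer once per key."""
    return distance < math.ceil(key_len * 0.15)


# Collect every (distance, key length) pair on which the two filters disagree.
mismatches = [
    (d, n)
    for n in range(1, 201)
    for d in range(0, 41)
    if keep_decimal(d, n) != keep_int(d, n)
]
print(mismatches)
```

Rounding down instead would drop rows that the original filter keeps, which is why the bound uses a ceiling.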
I updated the function like this:
from pyspark.sql import DataFrame  # needed for the type hints

def fuzzy_match(dfs: DataFrame, dfc: DataFrame, path_summary: str) -> DataFrame:
    dfs = dfs.withColumn("OrganisationNameKeyLower", F.lower("OrganisationNameKey"))\
        .withColumn("MatchingTolerance", F.ceil(F.length("OrganisationNameKey") * 0.15).cast("int"))\
        .cache()
    dfc = dfc.repartition(dfc.sql_ctx.sparkContext.defaultParallelism)\
        .withColumn("CompanyNameKeyLower", F.lower("CompanyNameKey"))\
        .cache()
    df = dfc.crossJoin(dfs)\
        .withColumn("LevenshteinDistance", F.levenshtein(F.col("OrganisationNameKeyLower"), F.col("CompanyNameKeyLower")).cast("int")) \
        .where("LevenshteinDistance < MatchingTolerance")\
        .drop("MatchingTolerance")\
        .cache()
    # clean unnecessary caches before returning
    dfs.unpersist()
    dfc.unpersist()
    return df
When running the updated version on the same inputs and on the same computer, I get nearly twice the performance of the first implementation:
dfc rowcount : 572494
dfs rowcount : 17728
fuzzy_match rowcount: 7228499
Duration: 356.23311281204224 seconds
Matches/core/sec: 1780641.1846241967
If that is still too slow for your needs, you'll need to find conditions on your data that you can use as a join condition, but that is highly specific to your data and use case.
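One condition of this kind that does not depend on the data, offered as an illustration rather than as part of the original answer: the Levenshtein distance between two strings is never smaller than the difference of their lengths, so a pair whose lengths differ by at least the tolerance cannot match and can be skipped before the expensive distance is computed. The plain-Python sketch below checks this on made-up names. `levenshtein`, `tolerance` and the two matcher functions are hypothetical helpers that mirror the `LevenshteinDistance < MatchingTolerance` filter above.

```python
import math


def levenshtein(a, b):
    """Classic dynamic-programming edit distance, one row at a time."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]


def tolerance(name):
    """Same bound as MatchingTolerance: ceil(15% of the key length)."""
    return math.ceil(len(name) * 0.15)


def matches_naive(orgs, companies):
    """Compute the distance for every pair."""
    return {(o, c) for o in orgs for c in companies
            if levenshtein(o.lower(), c.lower()) < tolerance(o)}


def matches_prefiltered(orgs, companies):
    """Skip pairs whose length difference already rules out a match."""
    found = set()
    for o in orgs:
        tol = tolerance(o)
        for c in companies:
            if abs(len(o) - len(c)) >= tol:
                continue  # distance >= length difference >= tol, cannot match
            if levenshtein(o.lower(), c.lower()) < tol:
                found.add((o, c))
    return found


# Made-up sample data, for illustration only.
orgs = ["acme holdings ltd", "globex"]
companies = ["acme holdings ltd", "acme holding ltd", "globex corporation", "globex"]
naive = matches_naive(orgs, companies)
prefiltered = matches_prefiltered(orgs, companies)
print(sorted(prefiltered))
```

In Spark, the same idea would be an extra, cheap predicate on the two string lengths evaluated ahead of the levenshtein call. Whether it pays off depends on how much the name lengths vary in your data.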

How to optimize searching RDD contents in another RDD

I have a problem where I need to search the contents of an RDD in another RDD.
This question is different from Efficient string matching in Apache Spark, as I am searching for an exact match and I don't need the overhead of using the ML stack.
I am new to Spark and I want to know which of these methods is more efficient, or whether there is another way.
I have a keyword file like the sample below (in production it might reach up to 200 lines):
Sample keywords file
0.47uF 25V X7R 10% -TDK C2012X7R1E474K125AA
20pF-50V NPO/COG - AVX- 08055A200JAT2A
and I have another file (tab-separated) in which I need to find matches (in production I have up to 80 million lines):
C2012X7R1E474K125AA Conn M12 Circular PIN 5 POS Screw ST Cable Mount 5 Terminal 1 Port
First method
I defined a UDF and looped through keywords for each line
import re
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
keywords = sc.textFile("keys")
part_description = sc.textFile("part_description")
def build_regex(keywords):
    # build one alternation group from the tokens of a keyword line
    res = '('
    for key in keywords:
        res += '(?<!\\\s)%s(?!\\\s)|' % re.escape(key)
    res = res[0:len(res) - 1] + ')'
    return r'%s' % res
def get_matching_string(line, regex):
    matches = re.findall(regex, line, re.IGNORECASE)
    matches = list(set(matches))
    return matches if matches else None
def find_matching_regex(line):
    result = list()
    for keyGroup in keys:
        matches = get_matching_string(line, keyGroup)
        if matches:
            result.append(str(keyGroup) + '~~' + str(matches) + '~~' + str(len(matches)))
    if len(result) > 0:
        return result
def split_row(list):
    try:
        return Row(list[0], list[1])
    except:
        return None
keys_rdd = keywords.map(lambda keywords: build_regex(keywords.replace(',', ' ').replace('-', ' ').split(' ')))
keys = keys_rdd.collect()
# note: the returned broadcast handle is not used below; the UDF captures keys through its closure
sc.broadcast(keys)
part_description = part_description.map(lambda item: item.split('\t'))
df = part_description.map(lambda list: split_row(list)).filter(lambda x: x).toDF(
    ["part_number", "description"])
find_regex = udf(lambda line: find_matching_regex(line), ArrayType(StringType()))
df = df.withColumn('matched', find_regex(df['part_number']))
df = df.filter(df.matched.isNotNull())
df.write.save(path=job_id, format='csv', mode='append', sep='\t')
Second method
I thought I could do more parallel processing (instead of looping through the keys as above), so I did a Cartesian product between keys and lines, split and exploded the keys, then compared each key to the part column:
from pyspark.sql import functions as F
from pyspark.sql.functions import explode, split
# pair every description line with every keyword line
df = part_description.cartesian(keywords)
df = df.map(lambda tuple: (tuple[0].split('\t'), tuple[1])).map(
    lambda tuple: (tuple[0][0], tuple[0][1], tuple[1]))
df = df.toDF(['part_number', 'description', 'keywords'])
df = df.withColumn('single_keyword', explode(split(F.col("keywords"), r"\s+"))).where('keywords != ""')
df = df.withColumn('matched_part_number', (df['part_number'] == df['single_keyword']))
df = df.filter(df['matched_part_number'] == F.lit(True))
df.write.save(path='part_number_search', format='csv', mode='append', sep='\t')
Are these the correct ways to do this? Is there anything I can do to process these data faster?
These are both valid solutions, and I have used both in different circumstances.
The broadcast approach moves less data: it sends only about 200 extra lines to each executor, whereas the Cartesian approach replicates each line of your 80-million-line file 200 times. It is therefore likely to end up being faster for you.
I have used the Cartesian approach when the lookup was too large to broadcast (much, much larger than 200 lines).
In your situation, I would use broadcast.
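For exact matches specifically, the broadcast side can be a plain set of keyword tokens, so each line costs one hash lookup instead of a regex scan. A minimal sketch in plain Python, assuming keywords are split on whitespace (as in the second method) and the part number is the first tab-separated field. `build_keyword_set` and `match_part_number` are hypothetical names, and the second part line is made up.

```python
def build_keyword_set(keyword_lines):
    """Collect every whitespace-separated token from the keyword lines."""
    tokens = set()
    for line in keyword_lines:
        tokens.update(line.split())
    return tokens


def match_part_number(line, keyword_set):
    """Return the part number if it is a known keyword token, else None."""
    part_number = line.split("\t", 1)[0]
    return part_number if part_number in keyword_set else None


keyword_lines = [
    "0.47uF 25V X7R 10% -TDK C2012X7R1E474K125AA",
    "20pF-50V NPO/COG - AVX- 08055A200JAT2A",
]
part_lines = [
    "C2012X7R1E474K125AA\tConn M12 Circular PIN 5 POS Screw ST Cable Mount 5 Terminal 1 Port",
    "UNKNOWN123\tSome other part",  # made-up line that should not match
]

keyword_set = build_keyword_set(keyword_lines)
matched = [m for m in (match_part_number(l, keyword_set) for l in part_lines) if m]
print(matched)

# In Spark (sketch, not run here): bc = sc.broadcast(keyword_set), then
# part_description.filter(lambda l: match_part_number(l, bc.value) is not None)
```

Keeping the broadcast handle and reading `bc.value` inside the function is what makes Spark ship the set to each executor once rather than with every task.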

Weird behavior upon splitting RDD into smaller RDDs and storing them in list

Because of resource constraints, I need to be able to split up a large RDD into n smaller RDDs and call spark-submit on them as separate jobs.
The code looks something like this:
def split_rdd_by_key(input_rdd, distinct_key_count, num_splits=10):
    # Calc. chunk indexes - lower and upper bounds for each smaller rdd
    chunk_gtor = __chunk_points(distinct_key_count, num_splits)
    smaller_rdds = []
    sort_key = "data_key"
    # Create sets of smaller rdds by filtering on indexes
    for item in chunk_gtor:
        lbval, ubval = item[0], item[1]
        print("lbval=%s ubval=%s" % (lbval, ubval))
        filt_rdd = input_rdd.filter(lambda x : x.key.entity >= lbval \
                                    and x.key.entity <= ubval)
        filt_count=filt_rdd.count()
        print("filt_count=%s" % filt_count)
        smaller_rdds.append(filt_rdd)
    return smaller_rdds
The code above prints the size of each smaller rdd as it is generated and appends it to the smaller_rdds list.
However, if I run the function above:
filtered_rdds=split_rdd_by_key(rdd, distinct_key_count)
and do the following:
# See the size of the 1st smaller rdd:
filtered_rdds[0].count()
it returns 31, which is much smaller than what was printed while split_rdd_by_key was running!
Can anyone help explain this? I must be missing something.
I figured out the issue. I needed to make a call to cache():
filt_rdd = input_rdd.filter(lambda x : x.key.entity >= lbval \
                            and x.key.entity <= ubval)
filt_rdd.cache()
Now it works OK :-)
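A likely explanation, which the answer above does not spell out: Python lambdas look up `lbval` and `ubval` when they are called, not when they are defined, and Spark transformations are lazy. A filter that is re-evaluated after the loop has finished therefore sees the bounds of the last chunk, whereas caching keeps the result that was computed while the loop variables still held the right values (for as long as the cached partitions are not evicted). The Spark-free sketch below reproduces the effect with hypothetical helper names and shows the usual fix of binding the bounds as default arguments.

```python
def make_filters_late(bounds):
    """Each lambda reads lb/ub at call time, so all of them see the last pair."""
    filters = []
    for lb, ub in bounds:
        filters.append(lambda x: lb <= x <= ub)
    return filters


def make_filters_bound(bounds):
    """Default arguments freeze lb/ub at definition time."""
    filters = []
    for lb, ub in bounds:
        filters.append(lambda x, lb=lb, ub=ub: lb <= x <= ub)
    return filters


def counts(filters, data):
    """Number of elements each filter lets through."""
    return [sum(1 for x in data if f(x)) for f in filters]


bounds = [(0, 4), (5, 19), (20, 29)]
data = list(range(30))
late_counts = counts(make_filters_late(bounds), data)
bound_counts = counts(make_filters_bound(bounds), data)
print(late_counts, bound_counts)
```

Applied to split_rdd_by_key, that would mean writing the filter as `lambda x, lb=lbval, ub=ubval: lb <= x.key.entity <= ub`, which does not rely on the cache staying populated.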

Resources