Creating a custom counter in Spark based on dataframe conditions - apache-spark
Current Dataset
+---+-----+-----+-----+----+
| ID|Event|Index|start| end|
+---+-----+-----+-----+----+
| 1| run| 0|start|null|
| 1| run| 1| null|null|
| 1| run| 2| null|null|
| 1| swim| 3| null| end|
| 1| run| 4|start|null|
| 1| swim| 5| null|null|
| 1| swim| 6| null| end|
| 1| run| 7|start|null|
| 1| run| 8| null|null|
| 1| run| 9| null|null|
| 1| swim| 10| null| end|
| 1| run| 11|start|null|
| 1| run| 12| null|null|
| 1| run| 13| null| end|
| 2| run| 14|start|null|
| 2| run| 15| null|null|
| 2| run| 16| null|null|
| 2| swim| 17| null| end|
| 2| run| 18|start|null|
| 2| swim| 19| null|null|
| 2| swim| 20| null|null|
| 2| swim| 21| null|null|
| 2| swim| 22| null| end|
| 2| run| 23|start|null|
| 2| run| 24| null|null|
| 2| run| 25| null| end|
| 3| run| 26|start|null|
| 3| run| 27| null|null|
| 3| swim| 28| null|null|
+---+-----+-----+-----+----+
Dataset I'm After
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventID|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
I am trying to create the EventID column shown above. Is there a way to create a counter inside a UDF that updates based on column conditions? Note: I'm not sure a UDF is the best approach here.
Here is my current logic:
When a "start" value is seen, start counting.
When an "end" value is seen, stop counting.
Every time a new ID is seen, reset the counter to 1.
Thank you all for any assistance.
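To make the intended behaviour concrete, here is a small plain-Python sketch of the counter logic (illustrative only; it is not a Spark solution, and the answers below avoid row-by-row loops):

# Illustrative only: the counter described above, written as a plain loop over
# (ID, start) pairs that are already ordered by Index.
def assign_event_ids(rows):
    event_ids = []
    counter = 0
    last_id = None
    for row_id, start in rows:
        if row_id != last_id:      # new ID seen: reset the counter
            counter = 0
            last_id = row_id
        if start == 'start':       # a "start" begins a new event
            counter += 1
        event_ids.append(counter)
    return event_ids

# assign_event_ids([(1, 'start'), (1, None), (1, None), (1, 'start')]) -> [1, 1, 1, 2]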
Here is the raw code to produce the current dataframe:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Current Dataset
data = [
    (1, "run", 0, 'start', None),
    (1, "run", 1, None, None),
    (1, "run", 2, None, None),
    (1, "swim", 3, None, 'end'),
    (1, "run", 4, 'start', None),
    (1, "swim", 5, None, None),
    (1, "swim", 6, None, 'end'),
    (1, "run", 7, 'start', None),
    (1, "run", 8, None, None),
    (1, "run", 9, None, None),
    (1, "swim", 10, None, 'end'),
    (1, "run", 11, 'start', None),
    (1, "run", 12, None, None),
    (1, "run", 13, None, 'end'),
    (2, "run", 14, 'start', None),
    (2, "run", 15, None, None),
    (2, "run", 16, None, None),
    (2, "swim", 17, None, 'end'),
    (2, "run", 18, 'start', None),
    (2, "swim", 19, None, None),
    (2, "swim", 20, None, None),
    (2, "swim", 21, None, None),
    (2, "swim", 22, None, 'end'),
    (2, "run", 23, 'start', None),
    (2, "run", 24, None, None),
    (2, "run", 25, None, 'end'),
    (3, "run", 26, 'start', None),
    (3, "run", 27, None, None),
    (3, "swim", 28, None, None)
]

schema = StructType([
    StructField('ID', IntegerType(), True),
    StructField('Event', StringType(), True),
    StructField('Index', IntegerType(), True),
    StructField('start', StringType(), True),
    StructField('end', StringType(), True)
])
df = spark.createDataFrame(data=data, schema=schema)
df.show(30)
You can use a window function:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('ID').orderBy('Index').rowsBetween(Window.unboundedPreceding, 0)

df.withColumn('EventId', F.sum(F.when(F.col('start') == 'start', 1).otherwise(0)).over(w)) \
  .orderBy('ID', 'Index').show(100)
which results in:
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventId|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
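The same conditional running count can also be written in Spark SQL (a sketch, assuming the DataFrame has been registered as a temporary view named events):

df.createOrReplaceTempView('events')
spark.sql("""
    SELECT *,
           SUM(CASE WHEN `start` = 'start' THEN 1 ELSE 0 END)
               OVER (PARTITION BY ID ORDER BY `Index`) AS EventID
    FROM events
    ORDER BY ID, `Index`
""").show(100)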
You can calculate a dense_rank based on the index of the most recent 'start' value:
from pyspark.sql import functions as F, Window
df2 = df.withColumn(
    'laststart',
    F.last(F.when(F.col('start') == 'start', F.col('Index')), True).over(Window.partitionBy('ID').orderBy('Index'))
).withColumn(
    'EventID',
    F.dense_rank().over(Window.partitionBy('ID').orderBy('laststart'))
)
df2.show(999)
+---+-----+-----+-----+----+---------+-------+
| ID|Event|Index|start| end|laststart|EventID|
+---+-----+-----+-----+----+---------+-------+
| 1| run| 0|start|null| 0| 1|
| 1| run| 1| null|null| 0| 1|
| 1| run| 2| null|null| 0| 1|
| 1| swim| 3| null| end| 0| 1|
| 1| run| 4|start|null| 4| 2|
| 1| swim| 5| null|null| 4| 2|
| 1| swim| 6| null| end| 4| 2|
| 1| run| 7|start|null| 7| 3|
| 1| run| 8| null|null| 7| 3|
| 1| run| 9| null|null| 7| 3|
| 1| swim| 10| null| end| 7| 3|
| 1| run| 11|start|null| 11| 4|
| 1| run| 12| null|null| 11| 4|
| 1| run| 13| null| end| 11| 4|
| 2| run| 14|start|null| 14| 1|
| 2| run| 15| null|null| 14| 1|
| 2| run| 16| null|null| 14| 1|
| 2| swim| 17| null| end| 14| 1|
| 2| run| 18|start|null| 18| 2|
| 2| swim| 19| null|null| 18| 2|
| 2| swim| 20| null|null| 18| 2|
| 2| swim| 21| null|null| 18| 2|
| 2| swim| 22| null| end| 18| 2|
| 2| run| 23|start|null| 23| 3|
| 2| run| 24| null|null| 23| 3|
| 2| run| 25| null| end| 23| 3|
| 3| run| 26|start|null| 26| 1|
| 3| run| 27| null|null| 26| 1|
| 3| swim| 28| null|null| 26| 1|
+---+-----+-----+-----+----+---------+-------+
Related
window function on a subset of data
I have a table like the below. I want to calculate an average of median but only for Q=2 and Q=3. I don't want to include other Qs but still preserve the data.

df = spark.createDataFrame(
    [('2018-03-31', 6, 1), ('2018-03-31', 27, 2), ('2018-03-31', 3, 3), ('2018-03-31', 44, 4),
     ('2018-06-30', 6, 1), ('2018-06-30', 4, 3), ('2018-06-30', 32, 2), ('2018-06-30', 112, 4),
     ('2018-09-30', 2, 1), ('2018-09-30', 23, 4), ('2018-09-30', 37, 3), ('2018-09-30', 3, 2)],
    ['date', 'median', 'Q'])

+----------+--------+---+
|      date| median | Q |
+----------+--------+---+
|2018-03-31|       6|  1|
|2018-03-31|      27|  2|
|2018-03-31|       3|  3|
|2018-03-31|      44|  4|
|2018-06-30|       6|  1|
|2018-06-30|       4|  3|
|2018-06-30|      32|  2|
|2018-06-30|     112|  4|
|2018-09-30|       2|  1|
|2018-09-30|      23|  4|
|2018-09-30|      37|  3|
|2018-09-30|       3|  2|
+----------+--------+---+

Expected output:

+----------+--------+---+------------+
|      date| median | Q |      result|
+----------+--------+---+------------+
|2018-03-31|       6|  1|        null|
|2018-03-31|      27|  2|          15|
|2018-03-31|       3|  3|          15|
|2018-03-31|      44|  4|        null|
|2018-06-30|       6|  1|        null|
|2018-06-30|       4|  3|          18|
|2018-06-30|      32|  2|          18|
|2018-06-30|     112|  4|        null|
|2018-09-30|       2|  1|        null|
|2018-09-30|      23|  4|        null|
|2018-09-30|      37|  3|          20|
|2018-09-30|       3|  2|          20|
+----------+--------+---+------------+

OR

+----------+--------+---+------------+
|      date| median | Q |      result|
+----------+--------+---+------------+
|2018-03-31|       6|  1|          15|
|2018-03-31|      27|  2|          15|
|2018-03-31|       3|  3|          15|
|2018-03-31|      44|  4|          15|
|2018-06-30|       6|  1|          18|
|2018-06-30|       4|  3|          18|
|2018-06-30|      32|  2|          18|
|2018-06-30|     112|  4|          18|
|2018-09-30|       2|  1|          20|
|2018-09-30|      23|  4|          20|
|2018-09-30|      37|  3|          20|
|2018-09-30|       3|  2|          20|
+----------+--------+---+------------+

I tried the following code but when I include the where statement it drops Q=1 and Q=4.

window = (
    Window
    .partitionBy("date")
    .orderBy("date")
)
df_avg = (
    df
    .where(
        (F.col("Q") == 2) | (F.col("Q") == 3)
    )
    .withColumn("result", F.avg("median").over(window))
)
For both of your expected outputs you can use conditional aggregation: avg combined with when (otherwise).

If you want the 1st expected output:

window = (
    Window
    .partitionBy("date", F.col("Q").isin([2, 3]))
)
df_avg = (
    df.withColumn(
        "result",
        F.when(F.col("Q").isin([2, 3]), F.avg("median").over(window))
    )
)

For the 2nd expected output:

window = (
    Window
    .partitionBy("date")
)
df_avg = (
    df.withColumn(
        "result",
        F.avg(F.when(F.col("Q").isin([2, 3]), F.col("median"))).over(window)
    )
)
Alternatively, since you are really aggregating a (small?) subset, you can replace the window with an aggregation that is joined back onto the original DataFrame:

>>> from pyspark.sql.functions import avg, col
>>> df_avg = df.where(col("Q").isin([2, 3])).groupBy("date").agg(avg("median").alias("result"))
>>> df_result = df.join(df_avg, ["date"], "left")

This gives the 2nd expected output, and might turn out to be faster than using a window.
Pyspark update record based on last value using timestamp and column value
I'm struggling to figure this out. I need to find the last record with reason backfill and update the non-backfill record with the greatest timestamp. Here is what I've tried -

w = Window.orderBy("idx")
w1 = Window.partitionBy('reason').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

df_uahr.withColumn('idx', F.monotonically_increasing_id()) \
    .withColumn("app_data_new", F.last(F.lead("app_data").over(w)).over(w1)) \
    .orderBy("idx").show()

+----------------------+-------------+-------------------+-------------------+------------+---+------------+
|upstart_application_id|       reason|         created_at|         updated_at|    app_data|idx|app_data_new|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
|                     2| disqualified|2018-07-12 15:57:26|2018-07-12 15:57:26|  app_data_a|  0|  app_data_c|
|                     2|     backfill|2020-05-29 17:47:09|2021-05-29 17:47:09|  app_data_c|  1|        null|
|                     2|     backfill|2022-03-09 09:47:09|2022-03-09 09:47:09|  app_data_d|  2|        null|
|                     2|         test|2022-04-09 09:47:09|2022-04-09 09:47:09|  app_data_e|  3|  app_data_f|
|                     2|         test|2022-04-19 09:47:09|2022-04-19 09:47:09|app_data_e_a|  4|  app_data_f|
|                     2|     backfill|2022-05-09 09:47:09|2022-05-09 09:47:09|  app_data_f|  5|        null|
|                     2|        after|2023-04-09 09:47:09|2023-04-09 09:47:09|  app_data_g|  6|  app_data_h|
|                     2|     backfill|2023-05-09 09:47:09|2023-05-09 09:47:09|  app_data_h|  7|        null|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+

Expected value

+----------------------+-------------+-------------------+-------------------+------------+---+------------+
|upstart_application_id|       reason|         created_at|         updated_at|    app_data|idx|app_data_new|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
|                     2| disqualified|2018-07-12 15:57:26|2018-07-12 15:57:26|  app_data_a|  0|  app_data_d|
|                     2|     backfill|2020-05-29 17:47:09|2021-05-29 17:47:09|  app_data_c|  1|        null|
|                     2|     backfill|2022-03-09 09:47:09|2022-03-09 09:47:09|  app_data_d|  2|        null|
|                     2|         test|2022-04-09 09:47:09|2022-04-09 09:47:09|  app_data_e|  3|        null|
|                     2|         test|2022-04-19 09:47:09|2022-04-19 09:47:09|app_data_e_a|  4|  app_data_f|
|                     2|     backfill|2022-05-09 09:47:09|2022-05-09 09:47:09|  app_data_f|  5|        null|
|                     2|        after|2023-04-09 09:47:09|2023-04-09 09:47:09|  app_data_g|  6|  app_data_h|
|                     2|     backfill|2023-05-09 09:47:09|2023-05-09 09:47:09|  app_data_h|  7|        null|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
Check start, middle and end of groups in Spark
I have a Spark dataframe that looks like this:

+---+-----+------+-------------+
| id|Phase|Switch|InputFileName|
+---+-----+------+-------------+
|  1|    2|     1|        fileA|
|  2|    2|     1|        fileA|
|  3|    2|     1|        fileA|
|  4|    2|     0|        fileA|
|  5|    2|     0|        fileA|
|  6|    2|     1|        fileA|
| 11|    2|     1|        fileB|
| 12|    2|     1|        fileB|
| 13|    2|     0|        fileB|
| 14|    2|     0|        fileB|
| 15|    2|     1|        fileB|
| 16|    2|     1|        fileB|
| 21|    4|     1|        fileB|
| 22|    4|     1|        fileB|
| 23|    4|     1|        fileB|
| 24|    4|     1|        fileB|
| 25|    4|     1|        fileB|
| 26|    4|     0|        fileB|
| 31|    1|     0|        fileC|
| 32|    1|     0|        fileC|
| 33|    1|     0|        fileC|
| 34|    1|     0|        fileC|
| 35|    1|     0|        fileC|
| 36|    1|     0|        fileC|
+---+-----+------+-------------+

For each group (a combination of InputFileName and Phase) I need to run a validation function which checks that Switch equals 1 at the very start and end of the group, and transitions to 0 at any point in-between. The function should add the validation result as a new column. The expected output is below (gaps are just to highlight the different groups):

+---+-----+------+-------------+-----+
| id|Phase|Switch|InputFileName|Valid|
+---+-----+------+-------------+-----+
|  1|    2|     1|        fileA| true|
|  2|    2|     1|        fileA| true|
|  3|    2|     1|        fileA| true|
|  4|    2|     0|        fileA| true|
|  5|    2|     0|        fileA| true|
|  6|    2|     1|        fileA| true|

| 11|    2|     1|        fileB| true|
| 12|    2|     1|        fileB| true|
| 13|    2|     0|        fileB| true|
| 14|    2|     0|        fileB| true|
| 15|    2|     1|        fileB| true|
| 16|    2|     1|        fileB| true|

| 21|    4|     1|        fileB|false|
| 22|    4|     1|        fileB|false|
| 23|    4|     1|        fileB|false|
| 24|    4|     1|        fileB|false|
| 25|    4|     1|        fileB|false|
| 26|    4|     0|        fileB|false|

| 31|    1|     0|        fileC|false|
| 32|    1|     0|        fileC|false|
| 33|    1|     0|        fileC|false|
| 34|    1|     0|        fileC|false|
| 35|    1|     0|        fileC|false|
| 36|    1|     0|        fileC|false|
+---+-----+------+-------------+-----+

I have previously solved this using Pyspark and a Pandas UDF:

df = df.groupBy("InputFileName", "Phase").apply(validate_profile)

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def validate_profile(df: pd.DataFrame):
    first_valid = True if df["Switch"].iloc[0] == 1 else False
    during_valid = (df["Switch"].iloc[1:-1] == 0).any()
    last_valid = True if df["Switch"].iloc[-1] == 1 else False
    df["Valid"] = first_valid & during_valid & last_valid
    return df

However, now I need to rewrite this in Scala. I just want to know the best way of accomplishing this.

I'm currently trying window functions to get the first and last ids of each group:

val minIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy("id")
val maxIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy(col("id").desc)

I can then add the min and max ids as separate columns and use when to get the start and end values of Switch:

df.withColumn("MinId", min("id").over(minIdWindow))
  .withColumn("MaxId", max("id").over(maxIdWindow))
  .withColumn("Valid", when(
    col("id") === col("MinId"), col("Switch")
  ).when(
    col("id") === col("MaxId"), col("Switch")
  ))

This gets me the start and end values, but I'm not sure how to check if Switch equals 0 in between. Am I on the right track using window functions? Or would you recommend an alternative solution?
Try this,

val wind = Window.partitionBy("InputFileName", "Phase").orderBy("id")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

val df1 = df.withColumn("Valid",
  when(first("Switch").over(wind) === 1 &&
       last("Switch").over(wind) === 1 &&
       min("Switch").over(wind) === 0, true)
    .otherwise(false))

df1.orderBy("id").show() //Ordering for display purpose

Output:

+---+-----+------+-------------+-----+
| id|Phase|Switch|InputFileName|Valid|
+---+-----+------+-------------+-----+
|  1|    2|     1|        fileA| true|
|  2|    2|     1|        fileA| true|
|  3|    2|     1|        fileA| true|
|  4|    2|     0|        fileA| true|
|  5|    2|     0|        fileA| true|
|  6|    2|     1|        fileA| true|
| 11|    2|     1|        fileB| true|
| 12|    2|     1|        fileB| true|
| 13|    2|     0|        fileB| true|
| 14|    2|     0|        fileB| true|
| 15|    2|     1|        fileB| true|
| 16|    2|     1|        fileB| true|
| 21|    4|     1|        fileB|false|
| 22|    4|     1|        fileB|false|
| 23|    4|     1|        fileB|false|
| 24|    4|     1|        fileB|false|
| 25|    4|     1|        fileB|false|
| 26|    4|     0|        fileB|false|
| 31|    1|     0|        fileC|false|
| 32|    1|     0|        fileC|false|
+---+-----+------+-------------+-----+
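For readers following along from the main (PySpark) question, a possible PySpark version of the same idea (a sketch, not part of the original answer):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# Same approach: compare the first, last and minimum Switch value per group.
wind = (Window.partitionBy("InputFileName", "Phase")
              .orderBy("id")
              .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

df1 = df.withColumn(
    "Valid",
    (F.first("Switch").over(wind) == 1)
    & (F.last("Switch").over(wind) == 1)
    & (F.min("Switch").over(wind) == 0)
)
df1.orderBy("id").show(30)  # ordering for display purposes only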
Enumerate blocks of successively equal values in Spark
I want to find the IDs of groups (or blocks) of trues in a Spark DataFrame. That is, I want to go from this:

>>> df.show()
+---------+-----+
|timestamp| bool|
+---------+-----+
|        1|false|
|        2| true|
|        3| true|
|        4|false|
|        5| true|
|        6| true|
|        7| true|
|        8| true|
|        9|false|
|       10|false|
|       11|false|
|       12|false|
|       13|false|
|       14| true|
|       15| true|
|       16| true|
+---------+-----+

to this:

>>> df.show()
+---------+-----+-----+
|timestamp| bool|block|
+---------+-----+-----+
|        1|false|    0|
|        2| true|    1|
|        3| true|    1|
|        4|false|    0|
|        5| true|    2|
|        6| true|    2|
|        7| true|    2|
|        8| true|    2|
|        9|false|    0|
|       10|false|    0|
|       11|false|    0|
|       12|false|    0|
|       13|false|    0|
|       14| true|    3|
|       15| true|    3|
|       16| true|    3|
+---------+-----+-----+

(the zeros are optional, could be Null or -1 or whatever is easier to implement)
I have a solution in Scala; it should be easy to adapt to PySpark. Consider the following dataframe df:

+---------+-----+
|timestamp| bool|
+---------+-----+
|        1|false|
|        2| true|
|        3| true|
|        4|false|
|        5| true|
|        6| true|
|        7| true|
|        8| true|
|        9|false|
|       10|false|
|       11|false|
|       12|false|
|       13|false|
|       14| true|
|       15| true|
|       16| true|
+---------+-----+

then you could do:

df
  .withColumn("prev_bool", lag($"bool", 1).over(Window.orderBy($"timestamp")))
  .withColumn("block", sum(when(!$"prev_bool" and $"bool", 1).otherwise(0)).over(Window.orderBy($"timestamp")))
  .drop($"prev_bool")
  .withColumn("block", when($"bool", $"block").otherwise(0))
  .show()

+---------+-----+-----+
|timestamp| bool|block|
+---------+-----+-----+
|        1|false|    0|
|        2| true|    1|
|        3| true|    1|
|        4|false|    0|
|        5| true|    2|
|        6| true|    2|
|        7| true|    2|
|        8| true|    2|
|        9|false|    0|
|       10|false|    0|
|       11|false|    0|
|       12|false|    0|
|       13|false|    0|
|       14| true|    3|
|       15| true|    3|
|       16| true|    3|
+---------+-----+-----+
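A possible PySpark adaptation of the Scala code above (a sketch; like the original, it uses a global window with no partitioning, so Spark will warn about moving all data to a single partition):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w = Window.orderBy("timestamp")

df_blocks = (
    df
    .withColumn("prev_bool", F.lag("bool", 1).over(w))
    # a block starts where the previous row was false and this row is true
    .withColumn("block", F.sum(F.when(~F.col("prev_bool") & F.col("bool"), 1).otherwise(0)).over(w))
    .drop("prev_bool")
    # zero out the block id on false rows, matching the desired output
    .withColumn("block", F.when(F.col("bool"), F.col("block")).otherwise(0))
)
df_blocks.show()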
Spark - Window with recursion? - Conditionally propagating values across rows
I have the following dataframe showing the revenue of purchases.

+-------+--------+-------+
|user_id|visit_id|revenue|
+-------+--------+-------+
|      1|       1|      0|
|      1|       2|      0|
|      1|       3|      0|
|      1|       4|    100|
|      1|       5|      0|
|      1|       6|      0|
|      1|       7|    200|
|      1|       8|      0|
|      1|       9|     10|
+-------+--------+-------+

Ultimately I want the new column purch_revenue to show the revenue generated by the purchase in every row. As a workaround, I have also tried to introduce a purchase identifier purch_id which is incremented each time a purchase was made. So this is listed just as a reference.

+-------+--------+-------+-------------+--------+
|user_id|visit_id|revenue|purch_revenue|purch_id|
+-------+--------+-------+-------------+--------+
|      1|       1|      0|          100|       1|
|      1|       2|      0|          100|       1|
|      1|       3|      0|          100|       1|
|      1|       4|    100|          100|       1|
|      1|       5|      0|          100|       2|
|      1|       6|      0|          100|       2|
|      1|       7|    200|          100|       2|
|      1|       8|      0|          100|       3|
|      1|       9|     10|          100|       3|
+-------+--------+-------+-------------+--------+

I've tried to use the lag/lead function like this:

user_timeline = Window.partitionBy("user_id").orderBy("visit_id")
find_rev = fn.when(fn.col("revenue") > 0, fn.col("revenue"))\
    .otherwise(fn.lead(fn.col("revenue"), 1).over(user_timeline))
df.withColumn("purch_revenue", find_rev)

This duplicates the revenue column if revenue > 0 and also pulls it up by one row. Clearly, I can chain this for a finite N, but that's not a solution. Is there a way to apply this recursively until revenue > 0? Alternatively, is there a way to increment a value based on a condition? I've tried to figure out a way to do that but struggled to find one.
Window functions don't support recursion but it is not required here. This type of sessionization can be easily handled with a cumulative sum:

from pyspark.sql.functions import col, sum, when, lag
from pyspark.sql.window import Window

w = Window.partitionBy("user_id").orderBy("visit_id")
purch_id = sum(
    lag(when(col("revenue") > 0, 1).otherwise(0), 1, 0).over(w)
).over(w) + 1

df.withColumn("purch_id", purch_id).show()

+-------+--------+-------+--------+
|user_id|visit_id|revenue|purch_id|
+-------+--------+-------+--------+
|      1|       1|      0|       1|
|      1|       2|      0|       1|
|      1|       3|      0|       1|
|      1|       4|    100|       1|
|      1|       5|      0|       2|
|      1|       6|      0|       2|
|      1|       7|    200|       2|
|      1|       8|      0|       3|
|      1|       9|     10|       3|
+-------+--------+-------+--------+
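To get the purch_revenue column the question ultimately asks for, one possible follow-up (a sketch, assuming each block's purchase revenue is its single non-zero revenue value) is to spread each block's maximum revenue across its rows:

from pyspark.sql.functions import max as max_

w_block = Window.partitionBy("user_id", "purch_id")
df.withColumn("purch_id", purch_id) \
  .withColumn("purch_revenue", max_("revenue").over(w_block)) \
  .show()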