SPARK : Set a column value based on multiple row conditions - apache-spark

I have a dataframe of the below format:
+----+---+-----+------+-----+------+
|AGEF|SEX|F0_34|F35_44|M0_34|M35_44|
+----+---+-----+------+-----+------+
|  30|  0|    0|     0|    0|     0|
|  94|  1|    0|     0|    0|     0|
|  94|  0|    0|     0|    0|     0|
|  94|  0|    0|     0|    0|     0|
|  94|  1|    0|     0|    0|     0|
|  44|  0|    0|     0|    0|     0|
|  66|  0|    0|     0|    0|     0|
|  66|  0|    0|     0|    0|     0|
|  74|  0|    0|     0|    0|     0|
|  74|  0|    0|     0|    0|     0|
|  29|  0|    0|     0|    0|     0|
+----+---+-----+------+-----+------+
Now, based on the values of the AGEF and SEX columns, I need to assign 1 to the corresponding column. Each column name is self-explanatory: F0_34 means female aged 0 to 34, and similarly for the other cases.
The expected output is:
+----+---+-----+------+-----+------+
|AGEF|SEX|F0_34|F35_44|M0_34|M35_44|
+----+---+-----+------+-----+------+
|  30|  0|    1|     0|    0|     0|
|  94|  1|    0|     0|    0|     0|
|  94|  0|    0|     0|    0|     0|
|  94|  0|    0|     0|    0|     0|
|  94|  1|    0|     0|    0|     0|
|  44|  0|    0|     1|    0|     0|
|  66|  0|    0|     0|    0|     0|
|  66|  0|    0|     0|    0|     0|
|  74|  0|    0|     0|    0|     0|
|  74|  0|    0|     0|    0|     0|
|  29|  0|    1|     0|    0|     0|
+----+---+-----+------+-----+------+
Thanks in Advance!!!

Typically the most efficient approach is to operate directly on SQL expressions. For example:
import spark.implicits._   // for the $"..." column syntax (already in scope in spark-shell)

def categorize(ageRanges: Seq[(Int, Int)], sexValues: Seq[(Int, String)]) = for {
  (ageL, ageH) <- ageRanges
  (sexV, sexL) <- sexValues
} yield ($"SEX" === sexV && $"AGEF".between(ageL, ageH)).alias(
  s"$sexL${ageL}_$ageH"
)

df.select(
  $"*" +: categorize(Seq((0, 34), (35, 44)), Seq((0, "F"), (1, "M"))): _*
)
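Note that the flag columns produced this way are booleans. If you want the 0/1 integers shown in the expected output, one option is to cast before aliasing; a small sketch reusing the helper above (categorizeAsInt is just categorize with a cast added):
import spark.implicits._

// Same helper as above, with an int cast so the flags come out as 0/1
// instead of true/false.
def categorizeAsInt(ageRanges: Seq[(Int, Int)], sexValues: Seq[(Int, String)]) = for {
  (ageL, ageH) <- ageRanges
  (sexV, sexL) <- sexValues
} yield ($"SEX" === sexV && $"AGEF".between(ageL, ageH))
  .cast("int")
  .alias(s"$sexL${ageL}_$ageH")

df.select(
  $"*" +: categorizeAsInt(Seq((0, 34), (35, 44)), Seq((0, "F"), (1, "M"))): _*
).show()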

The simplest way is to make a UDF that takes 5 parameters (e.g. actual_age, actual_sex, target_sex, target_min_age, target_max_age) and returns either 1 or 0. Something like this:
val ageRanger = udf[Int, Int, Int, Int, Int, Int]((age: Int, sex: Int, targetSex: Int, targetMinAge: Int, targetMaxAge: Int) => {
  if (age >= targetMinAge && age <= targetMaxAge && sex == targetSex) 1 else 0
})
Then if you had this DataFrame:
val df = Seq((30,0),(94,1),(94,0),(44,0)).toDF("AGEF", "SEX")
// +----+---+
// |AGEF|SEX|
// +----+---+
// | 30| 0|
// | 94| 1|
// | 94| 0|
// | 44| 0|
// +----+---+
df.withColumn("F0_34", ageRanger($"AGEF", $"SEX", lit(0), lit(0), lit(34)))
.withColumn("F35_44", ageRanger($"AGEF", $"SEX", lit(0), lit(35), lit(44)))
.show
// +----+---+-----+------+
// |AGEF|SEX|F0_34|F35_44|
// +----+---+-----+------+
// | 30| 0| 1| 0|
// | 94| 1| 0| 0|
// | 94| 0| 0| 0|
// | 44| 0| 0| 1|
// +----+---+-----+------+
Note that you have to pass values into the UDF as Columns, so I use lit(...) to wrap the hard-coded Int values. There could be a slicker way to do that, but it works fine this way.
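One possibly slicker variant (a sketch, not the answer's code) is to skip the UDF entirely and build the flag as a plain Column expression with when, so no lit(...) wrapping is needed:
import org.apache.spark.sql.functions.when
import spark.implicits._

// Sketch: the same logic as ageRanger, expressed as a Column so the
// hard-coded targets can be passed as plain Ints.
def ageRange(targetSex: Int, targetMinAge: Int, targetMaxAge: Int) =
  when($"SEX" === targetSex && $"AGEF".between(targetMinAge, targetMaxAge), 1).otherwise(0)

df.withColumn("F0_34", ageRange(0, 0, 34))
  .withColumn("F35_44", ageRange(0, 35, 44))
  .show()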

Related

Pyspark groupby for all column with unpivot

I have 101 columns from a pipe-delimited file and I'm looking to get counts for all columns by unpivoting (untransposing) the data.
Sample data:
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
|       rm_ky|flag_010961|flag_011622|flag_009670|flag_009708|flag_009890|flag_009893|
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
|193012020044|          0|          0|          0|          0|          0|          0|
|115012030044|          0|          0|          1|          1|          1|          1|
|140012220044|          0|          0|          0|          0|          0|          0|
|189012240044|          0|          0|          0|          0|          0|          0|
|151012350044|          0|          0|          0|          0|          0|          0|
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
I have tried doing it for each column individually, like:
df.groupBy("flag_011622").count().show()
+-----------+--------+
|flag_011622|   count|
+-----------+--------+
|          1|  192289|
|          0|69861980|
+-----------+--------+
Instead, I'm looking for something like the output below. Any suggestions for how to handle this without looping over each column?
+-----------+-----+---------+
|  flag_name|value|   counts|
+-----------+-----+---------+
|flag_011622|    1|   192289|
|flag_011622|    0| 69861980|
|flag_009670|    1|120011800|
|flag_009670|    0|   240507|
|flag_009708|    1|119049838|
|flag_009708|    0|  1202469|
+-----------+-----+---------+
You could use the stack SQL function: stack(n, expr1, ..., exprk) separates expr1, ..., exprk into n rows, which makes it a handy way to unpivot a fixed set of columns into name/value pairs.
Using your sample as df:
from pyspark.sql import functions as F

df = df.select(
    "rm_ky",
    F.expr(
        """stack(5,
        'flag_010961', flag_010961,
        'flag_009670', flag_009670,
        'flag_009708', flag_009708,
        'flag_009890', flag_009890,
        'flag_009893', flag_009893
        ) AS (flag_name, value)"""
    ),
)
gives:
+------------+-----------+-----+
|rm_ky |flag_name |value|
+------------+-----------+-----+
|193012020044|flag_010961|0 |
|193012020044|flag_009670|0 |
|193012020044|flag_009708|0 |
|193012020044|flag_009890|0 |
|193012020044|flag_009893|0 |
|115012030044|flag_010961|0 |
|115012030044|flag_009670|0 |
|115012030044|flag_009708|1 |
|115012030044|flag_009890|1 |
|115012030044|flag_009893|1 |
|140012220044|flag_010961|0 |
|140012220044|flag_009670|0 |
|140012220044|flag_009708|0 |
|140012220044|flag_009890|0 |
|140012220044|flag_009893|0 |
|189012240044|flag_010961|0 |
|189012240044|flag_009670|0 |
|189012240044|flag_009708|0 |
|189012240044|flag_009890|0 |
|189012240044|flag_009893|0 |
|151012350044|flag_010961|0 |
|151012350044|flag_009670|0 |
|151012350044|flag_009708|0 |
|151012350044|flag_009890|0 |
|151012350044|flag_009893|0 |
+------------+-----------+-----+
Which you can then group and order:
df = (
    df.groupBy("flag_name", "value")
    .agg(F.count("*").alias("counts"))
    .orderBy("flag_name", "value")
)
to get:
+-----------+-----+------+
|flag_name |value|counts|
+-----------+-----+------+
|flag_009670|0 |5 |
|flag_009708|0 |4 |
|flag_009708|1 |1 |
|flag_009890|0 |4 |
|flag_009890|1 |1 |
|flag_009893|0 |4 |
|flag_009893|1 |1 |
|flag_010961|0 |5 |
+-----------+-----+------+
Example:
data = [("193012020044", 0, 0, 0, 0, 0, 1),
        ("115012030044", 0, 0, 1, 1, 1, 1),
        ("140012220044", 0, 0, 0, 0, 0, 0),
        ("189012240044", 0, 1, 0, 0, 0, 0),
        ("151012350044", 0, 0, 0, 1, 1, 0)]
columns = ["rm_ky", "flag_010961", "flag_011622", "flag_009670", "flag_009708", "flag_009890", "flag_009893"]
df = spark.createDataFrame(data=data, schema=columns)
df.show()
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
| rm_ky|flag_010961|flag_011622|flag_009670|flag_009708|flag_009890|flag_009893|
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
|193012020044| 0| 0| 0| 0| 0| 1|
|115012030044| 0| 0| 1| 1| 1| 1|
|140012220044| 0| 0| 0| 0| 0| 0|
|189012240044| 0| 1| 0| 0| 0| 0|
|151012350044| 0| 0| 0| 1| 1| 0|
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
Creating an expression to unpivot:
x = ""
cnt = 0
for col in df.columns:
if col != 'rm_ky':
cnt += 1
x += "'"+str(col)+"', " + str(col) + ", "
x = x[:-2]
xpr = """stack({}, {}) as (Type,Value)""".format(cnt,x)
print(xpr)
>> stack(6, 'flag_010961', flag_010961, 'flag_011622', flag_011622, 'flag_009670', flag_009670, 'flag_009708', flag_009708, 'flag_009890', flag_009890, 'flag_009893', flag_009893) as (Type,Value)
Then, using expr and pivot:
from pyspark.sql import functions as F

df\
    .drop('rm_ky')\
    .select(F.lit('dummy'), F.expr(xpr))\
    .drop('dummy')\
    .groupBy('Type')\
    .pivot('Value')\
    .agg(F.count('Value'))\
    .fillna(0)\
    .show()
+-----------+---+---+
| Type| 0| 1|
+-----------+---+---+
|flag_009890| 3| 2|
|flag_009893| 3| 2|
|flag_011622| 4| 1|
|flag_010961| 5| 0|
|flag_009708| 3| 2|
|flag_009670| 4| 1|
+-----------+---+---+
I think this is what you are looking for:
>>> df2.show()
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
| rm_ky|flag_010961|flag_011622|flag_009670|flag_009708|flag_009890|flag_009893|
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
|193012020044| 0| 0| 0| 0| 0| 0|
|115012030044| 0| 0| 1| 1| 1| 1|
|140012220044| 0| 0| 0| 0| 0| 0|
|189012240044| 0| 0| 0| 0| 0| 0|
|151012350044| 0| 0| 0| 0| 0| 0|
+------------+-----------+-----------+-----------+-----------+-----------+-----------+
>>> from pyspark.sql.functions import expr
>>> unpivotExpr = "stack(6, 'flag_010961',flag_010961,'flag_011622',flag_011622,'flag_009670',flag_009670, 'flag_009708',flag_009708, 'flag_009890',flag_009890, 'flag_009893',flag_009893) as (flag,flag_val)"
>>> unPivotDF = df2.select("rm_ky", expr(unpivotExpr))
>>> unPivotDF.show()
+------------+-----------+--------+
| rm_ky| flag|flag_val|
+------------+-----------+--------+
|193012020044|flag_010961| 0|
|193012020044|flag_011622| 0|
|193012020044|flag_009670| 0|
|193012020044|flag_009708| 0|
|193012020044|flag_009890| 0|
|193012020044|flag_009893| 0|
|115012030044|flag_010961| 0|
|115012030044|flag_011622| 0|
|115012030044|flag_009670| 1|
|115012030044|flag_009708| 1|
|115012030044|flag_009890| 1|
|115012030044|flag_009893| 1|
|140012220044|flag_010961| 0|
|140012220044|flag_011622| 0|
|140012220044|flag_009670| 0|
|140012220044|flag_009708| 0|
|140012220044|flag_009890| 0|
|140012220044|flag_009893| 0|
|189012240044|flag_010961| 0|
|189012240044|flag_011622| 0|
+------------+-----------+--------+
only showing top 20 rows
>>> unPivotDF.groupBy("flag","flag_val").count().show()
+-----------+--------+-----+
| flag|flag_val|count|
+-----------+--------+-----+
|flag_009670| 0| 4|
|flag_009708| 0| 4|
|flag_009893| 0| 4|
|flag_009890| 0| 4|
|flag_009670| 1| 1|
|flag_009893| 1| 1|
|flag_011622| 0| 5|
|flag_010961| 0| 5|
|flag_009890| 1| 1|
|flag_009708| 1| 1|
+-----------+--------+-----+
>>> unPivotDF.groupBy("rm_ky","flag","flag_val").count().show()
+------------+-----------+--------+-----+
| rm_ky| flag|flag_val|count|
+------------+-----------+--------+-----+
|151012350044|flag_009708| 0| 1|
|115012030044|flag_010961| 0| 1|
|140012220044|flag_009670| 0| 1|
|189012240044|flag_010961| 0| 1|
|151012350044|flag_009670| 0| 1|
|115012030044|flag_009890| 1| 1|
|151012350044|flag_009890| 0| 1|
|189012240044|flag_009890| 0| 1|
|193012020044|flag_011622| 0| 1|
|193012020044|flag_009670| 0| 1|
|115012030044|flag_009670| 1| 1|
|140012220044|flag_011622| 0| 1|
|151012350044|flag_009893| 0| 1|
|140012220044|flag_009893| 0| 1|
|189012240044|flag_011622| 0| 1|
|189012240044|flag_009893| 0| 1|
|115012030044|flag_009893| 1| 1|
|140012220044|flag_009708| 0| 1|
|189012240044|flag_009708| 0| 1|
|193012020044|flag_010961| 0| 1|
+------------+-----------+--------+-----+

Is it possible to filter columns by the sum of their values in Spark?

I'm loading a sparse table using PySpark, and I want to remove all columns where the sum of all values in the column is below a threshold.
For example, the column sums of the following table:
+---+---+---+---+---+---+
| a| b| c| d| e| f|
+---+---+---+---+---+---+
| 1| 0| 1| 1| 0| 0|
| 1| 1| 0| 0| 0| 0|
| 1| 0| 0| 1| 1| 1|
| 1| 0| 0| 1| 1| 1|
| 1| 1| 0| 0| 1| 0|
| 0| 0| 1| 0| 1| 0|
+---+---+---+---+---+---+
are 5, 2, 2, 3, 4, and 2. Filtering for all columns with sum >= 3 should output this table:
+---+---+---+
| a| d| e|
+---+---+---+
| 1| 1| 0|
| 1| 0| 0|
| 1| 1| 1|
| 1| 1| 1|
| 1| 0| 1|
| 0| 0| 1|
+---+---+---+
I tried many different solutions without success. df.groupBy().sum() gives me the sums of the column values, so I'm looking for a way to filter those against the threshold and then select only the remaining columns from the original dataframe.
As there are not only 6 but a couple of thousand columns, I'm looking for a scalable solution where I don't have to type every column name. Thanks for the help!
You can do this with a collect (or a first) step.
from pyspark.sql import functions as F
sum_result = df.groupBy().agg(*(F.sum(col).alias(col) for col in df.columns)).first()
filtered_df = df.select(
    *(col for col, value in sum_result.asDict().items() if value >= 3)
)
filtered_df.show()
+---+---+---+
| a| d| e|
+---+---+---+
| 1| 1| 0|
| 1| 0| 0|
| 1| 1| 1|
| 1| 1| 1|
| 1| 0| 1|
| 0| 0| 1|
+---+---+---+
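For reference, the same idea in Scala, as a rough sketch (assumes all columns are numeric and uses the threshold of 3 from the example):
import org.apache.spark.sql.functions.sum

// Compute all column sums in one pass, then keep only the columns whose
// sum meets the threshold.
val sumExprs = df.columns.map(c => sum(c).alias(c))
val sums = df.groupBy().agg(sumExprs.head, sumExprs.tail: _*).first()
val keep = df.columns.filter(c => sums.getAs[Long](c) >= 3L)
val filteredDf = df.select(keep.map(df.col): _*)
filteredDf.show()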

Split numerical count in Spark DataFrame column into several columns

Let's say I have a spark DataFrame like this
+------------------+----------+--------------+-----+
| user| dt| action|count|
+------------------+----------+--------------+-----+
|Albert |2018-03-24|Action1 | 19|
|Albert |2018-03-25|Action1 | 1|
|Albert |2018-03-26|Action1 | 6|
|Barack |2018-03-26|Action2 | 3|
|Barack |2018-03-26|Action3 | 1|
|Donald |2018-03-26|Action3 | 29|
|Hillary |2018-03-24|Action1 | 4|
|Hillary |2018-03-26|Action2 | 2|
and I'd like to have the counts for Action1/Action2/Action3 in separate columns, i.e. to convert it into another DataFrame like this:
+------------------+----------+-------------+-------------+-------------+
| user| dt|action1_count|action2_count|action3_count|
+------------------+----------+-------------+-------------+-------------+
|Albert |2018-03-24| 19| 0| 0|
|Albert |2018-03-25| 1| 0| 0|
|Albert |2018-03-26| 6| 0| 0|
|Barack |2018-03-26| 0| 3| 0|
|Barack |2018-03-26| 0| 0| 1|
|Donald |2018-03-26| 0| 0| 29|
|Hillary |2018-03-24| 4| 0| 0|
|Hillary |2018-03-26| 0| 2| 0|
As I'm a newbie to Spark, my attempt to reach that was quite dull and straightforward:
Get 3 new DF's from filtering by each "action"
Join original DF with each of new ones, using the second DF's "count" in the new DF
The code I tried looked like this:
val a1 = originalDf.filter("action = 'Action1'")
val df1 = originalDf.as('o)
  .join(a1,
    ($"o.user" === $"a1.user" && $"o.dt" === $"a1.dt"),
    "left_outer")
  .select($"o.user", $"o.dt", $"a1.count".as("action1_count"))
Then do the same with Action2/Action3, then join those.
However, even at this stage I've already run into several problems with this approach:
It doesn't work at all; it fails with an error whose cause I don't understand: org.apache.spark.sql.AnalysisException: cannot resolve 'o.user' given input columns: [user, dt, action, count, user, dt, action, count];
Even if it succeeded, I assume I would have ended up with nulls where I need zeros.
I feel there should be a better way to do this, like some map construct or something, but at the moment I can't work out the transform required to convert the first dataframe into the second one.
Since right now I don't have a working solution at all, I'd be very thankful for any suggestions.
UPD: I might also get DFs that don't contain all 3 possible "action" values, for instance:
+------------------+----------+--------------+-----+
| user| dt| action|count|
+------------------+----------+--------------+-----+
|Albert |2018-03-24|Action1 | 19|
|Albert |2018-03-25|Action1 | 1|
|Albert |2018-03-26|Action1 | 6|
|Hillary |2018-03-24|Action1 | 4|
For those, I still need the resulting DF with 3 columns:
+------------------+----------+-------------+-------------+-------------+
| user| dt|action1_count|action2_count|action3_count|
+------------------+----------+-------------+-------------+-------------+
|Albert |2018-03-24| 19| 0| 0|
|Albert |2018-03-25| 1| 0| 0|
|Albert |2018-03-26| 6| 0| 0|
|Hillary |2018-03-24| 4| 0| 0|
You can avoid multiple joins by using when to select the appropriate value for each column.
As for your join, I don't really think it would fail with an exception like cannot resolve 'o.user'; you may want to check your code again.
val df = Seq(("Albert","2018-03-24","Action1",19),
("Albert","2018-03-25","Action1",1),
("Albert","2018-03-26","Action1",6),
("Barack","2018-03-26","Action2",3),
("Barack","2018-03-26","Action3",1),
("Donald","2018-03-26","Action3",29),
("Hillary","2018-03-24","Action1",4),
("Hillary","2018-03-26","Action2",2)).toDF("user", "dt", "action", "count")
val df2 = df.withColumn("count1", when($"action" === "Action1", $"count").otherwise(lit(0))).
withColumn("count2", when($"action" === "Action2", $"count").otherwise(lit(0))).
withColumn("count3", when($"action" === "Action3", $"count").otherwise(lit(0)))
+-------+----------+-------+-----+------+------+------+
|user |dt |action |count|count1|count2|count3|
+-------+----------+-------+-----+------+------+------+
|Albert |2018-03-24|Action1|19 |19 |0 |0 |
|Albert |2018-03-25|Action1|1 |1 |0 |0 |
|Albert |2018-03-26|Action1|6 |6 |0 |0 |
|Barack |2018-03-26|Action2|3 |0 |3 |0 |
|Barack |2018-03-26|Action3|1 |0 |0 |1 |
|Donald |2018-03-26|Action3|29 |0 |0 |29 |
|Hillary|2018-03-24|Action1|4 |4 |0 |0 |
|Hillary|2018-03-26|Action2|2 |0 |2 |0 |
+-------+----------+-------+-----+------+------+------+
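To get the exact column layout from the question, you could then drop the original action and count columns and rename the new ones (a small follow-up sketch using the df2 above):
// Drop the unpivoted columns and rename the flags to the requested names.
val result = df2
  .drop("action", "count")
  .withColumnRenamed("count1", "action1_count")
  .withColumnRenamed("count2", "action2_count")
  .withColumnRenamed("count3", "action3_count")
result.show()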
Here's one approach using pivot and first, with the advantage of not having to know what the action values are:
val df = Seq(
("Albert", "2018-03-24", "Action1", 19),
("Albert", "2018-03-25", "Action1", 1),
("Albert", "2018-03-26", "Action1", 6),
("Barack", "2018-03-26", "Action2", 3),
("Barack", "2018-03-26", "Action3", 1),
("Donald", "2018-03-26", "Action3", 29),
("Hillary", "2018-03-24", "Action1", 4),
("Hillary", "2018-03-26", "Action2", 2)
).toDF("user", "dt", "action", "count")
val pivotDF = df.groupBy("user", "dt", "action").pivot("action").agg(first($"count")).
na.fill(0).
orderBy("user", "dt", "action")
// +-------+----------+-------+-------+-------+-------+
// | user| dt| action|Action1|Action2|Action3|
// +-------+----------+-------+-------+-------+-------+
// | Albert|2018-03-24|Action1| 19| 0| 0|
// | Albert|2018-03-25|Action1| 1| 0| 0|
// | Albert|2018-03-26|Action1| 6| 0| 0|
// | Barack|2018-03-26|Action2| 0| 3| 0|
// | Barack|2018-03-26|Action3| 0| 0| 1|
// | Donald|2018-03-26|Action3| 0| 0| 29|
// |Hillary|2018-03-24|Action1| 4| 0| 0|
// |Hillary|2018-03-26|Action2| 0| 2| 0|
// +-------+----------+-------+-------+-------+-------+
[UPDATE]
Per the comments, if you need more Action? columns than appear in the pivot column, you can traverse the missing Action? values and add them as zero-filled columns:
val fullActionList = List("Action1", "Action2", "Action3", "Action4", "Action5")
val missingActions = fullActionList.diff(
pivotDF.select($"action").as[String].collect.toList.distinct
)
// missingActions: List[String] = List(Action4, Action5)
missingActions.foldLeft( pivotDF )( _.withColumn(_, lit(0)) ).
show
// +-------+----------+-------+-------+-------+-------+-------+-------+
// | user| dt| action|Action1|Action2|Action3|Action4|Action5|
// +-------+----------+-------+-------+-------+-------+-------+-------+
// | Albert|2018-03-24|Action1| 19| 0| 0| 0| 0|
// | Albert|2018-03-25|Action1| 1| 0| 0| 0| 0|
// | Albert|2018-03-26|Action1| 6| 0| 0| 0| 0|
// | Barack|2018-03-26|Action2| 0| 3| 0| 0| 0|
// | Barack|2018-03-26|Action3| 0| 0| 1| 0| 0|
// | Donald|2018-03-26|Action3| 0| 0| 29| 0| 0|
// |Hillary|2018-03-24|Action1| 4| 0| 0| 0| 0|
// |Hillary|2018-03-26|Action2| 0| 2| 0| 0| 0|
// +-------+----------+-------+-------+-------+-------+-------+-------+

Convert Dataframe to multiple 2D arrays

I have this dataset:
+----+-----+-------+-----+
|code|code2|machine|value|
+----+-----+-------+-----+
| 1| 2| A| 42|
| 2| 1| A| 11|
| 1| 4| A| 55|
| 1| 1| B| 2|
| 3| 3| B| 34|
| 3| 2| B| 111|
+----+-----+-------+-----+
For each machine, I want a kind of matrix like the following:
code and code2 are the row and column indices, and at the intersection I want to fill in the value.
Machine A
+----+----+----+----+----+
| A| 1| 2| 3| 4|
+----+----+----+----+----+
| 1| 0| 11| 0| 0|
| 2| 42| 0| 0| 0|
| 3| 0| 0| 0| 0|
| 4| 55| 0| 0| 0|
+----+----+----+----+----+
Machine B
+----+----+----+----+----+
| B| 1| 2| 3| 4|
+----+----+----+----+----+
| 1| 2| 0| 0| 0|
| 2| 0| 0| 111| 0|
| 3| 0| 0| 34| 0|
| 4| 0| 0| 0| 0|
+----+----+----+----+----+
I have multiple machines (an unknown number) and the codes can only be 0-255.
So my problem is how to build that matrix...
My first naive idea was to make a hashmap with the machine name as the key and a 256x256 2D array as the value. But I don't think it would be efficient, and I also don't know how to achieve that.
Or perhaps have a dataset for each machine?
If someone has an idea, I would like to hear it.
Btw I'm using Scala.
For maximum coding flexibility, you could switch to the RDD API. An example of a solution would give you an RDD that maps each machine to its matrix, represented as a Scala two-dimensional array. Note that Array.ofDim[Int](n, m) creates a two-dimensional array of size n*m with zeros everywhere.
val machineMatrices = df.rdd
  .map(x => x.getAs[String]("machine") ->
    (x.getAs[Int]("code"), x.getAs[Int]("code2"), x.getAs[Int]("value")))
  .groupByKey
  .mapValues { seq =>
    val result = Array.ofDim[Int](256, 256)
    seq.foreach { case (i, j, value) => result(i)(j) = value }
    result
  }
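If the number of machines is small, you could then collect the result into a Map on the driver (a sketch, using the machineMatrices name given to the RDD above):
// Each matrix is a 256x256 Array[Array[Int]], so only collect when the
// number of machines is small enough to fit on the driver.
val matrices: Map[String, Array[Array[Int]]] = machineMatrices.collectAsMap().toMap

// value at code=2, code2=1 for machine A (11 in the sample data)
println(matrices("A")(2)(1))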

Spark - Window with recursion? - Conditionally propagating values across rows

I have the following dataframe showing the revenue of purchases.
+-------+--------+-------+
|user_id|visit_id|revenue|
+-------+--------+-------+
| 1| 1| 0|
| 1| 2| 0|
| 1| 3| 0|
| 1| 4| 100|
| 1| 5| 0|
| 1| 6| 0|
| 1| 7| 200|
| 1| 8| 0|
| 1| 9| 10|
+-------+--------+-------+
Ultimately I want the new column purch_revenue to show the revenue generated by the purchase in every row.
As a workaround, I have also tried to introduce a purchase identifier purch_id which is incremented each time a purchase was made. So this is listed just as a reference.
+-------+--------+-------+-------------+--------+
|user_id|visit_id|revenue|purch_revenue|purch_id|
+-------+--------+-------+-------------+--------+
|      1|       1|      0|          100|       1|
|      1|       2|      0|          100|       1|
|      1|       3|      0|          100|       1|
|      1|       4|    100|          100|       1|
|      1|       5|      0|          200|       2|
|      1|       6|      0|          200|       2|
|      1|       7|    200|          200|       2|
|      1|       8|      0|           10|       3|
|      1|       9|     10|           10|       3|
+-------+--------+-------+-------------+--------+
I've tried to use the lag/lead function like this:
from pyspark.sql import functions as fn
from pyspark.sql.window import Window

user_timeline = Window.partitionBy("user_id").orderBy("visit_id")
find_rev = fn.when(fn.col("revenue") > 0, fn.col("revenue"))\
    .otherwise(fn.lead(fn.col("revenue"), 1).over(user_timeline))
df.withColumn("purch_revenue", find_rev)
This duplicates the revenue column if revenue > 0 and also pulls it up by one row. Clearly, I can chain this for a finite N, but that's not a solution.
Is there a way to apply this recursively until revenue > 0?
Alternatively, is there a way to increment a value based on a condition? I've tried to figure out a way to do that but struggled to find one.
Window functions don't support recursion, but it is not required here. This type of sessionization can be easily handled with a cumulative sum:
from pyspark.sql.functions import col, sum, when, lag
from pyspark.sql.window import Window
w = Window.partitionBy("user_id").orderBy("visit_id")
purch_id = sum(lag(when(
    col("revenue") > 0, 1).otherwise(0),
    1, 0
).over(w)).over(w) + 1

df.withColumn("purch_id", purch_id).show()
+-------+--------+-------+--------+
|user_id|visit_id|revenue|purch_id|
+-------+--------+-------+--------+
| 1| 1| 0| 1|
| 1| 2| 0| 1|
| 1| 3| 0| 1|
| 1| 4| 100| 1|
| 1| 5| 0| 2|
| 1| 6| 0| 2|
| 1| 7| 200| 2|
| 1| 8| 0| 3|
| 1| 9| 10| 3|
+-------+--------+-------+--------+
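From there, the purch_revenue column the question asks for can be filled in with a second window partitioned by user_id and purch_id. Here is a sketch of that idea, written in Scala (column names as above; assumes purch_id has already been added to df):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.max

// Within one purchase "session" only the purchasing visit has revenue > 0,
// so the max over the session is the revenue of that purchase.
val byPurchase = Window.partitionBy("user_id", "purch_id")
df.withColumn("purch_revenue", max("revenue").over(byPurchase)).show()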
