Extracting value using Window and Partition - apache-spark

I have a dataframe in pyspark
id | value
1 0
1 1
1 0
2 1
2 0
3 0
3 0
3 1
I want to extract, within each id group, every row from the first occurrence of 1 in the value column onward. I have created a Window partitioned by id, but I do not know how to get the rows that come after the value 1.
I'm expecting the result to be:
id | value
1 1
1 0
2 1
2 0
3 1

The solution below may be relevant. It works well on small data, but it may cause problems on large data if the rows of one id are spread across multiple partitions.
df = sqlContext.createDataFrame([
[1, 0],
[1, 1],
[1, 0],
[2, 1],
[2, 0],
[3, 0],
[3, 0],
[3, 1]
],
['id', 'Value']
)
df.show()
+---+-----+
| id|Value|
+---+-----+
| 1| 0|
| 1| 1|
| 1| 0|
| 2| 1|
| 2| 0|
| 3| 0|
| 3| 0|
| 3| 1|
+---+-----+
#importing Libraries
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W
import sys
#This way we can generate a cumulative sum for values
df.withColumn(
"sum",
F.sum(
"value"
).over(W.partitionBy(["id"]).rowsBetween(-sys.maxsize, 0))
).show()
+---+-----+-----+
| id|Value|sum |
+---+-----+-----+
| 1| 0| 0|
| 1| 1| 1|
| 1| 0| 1|
| 3| 0| 0|
| 3| 0| 0|
| 3| 1| 1|
| 2| 1| 1|
| 2| 0| 1|
+---+-----+-----+
#Filter all those which are having sum > 0
df.withColumn(
"sum",
F.sum(
"value"
).over(W.partitionBy(["id"]).rowsBetween(-sys.maxsize, 0))
).where("sum > 0").show()
+---+-----+-----+
| id|Value|sum |
+---+-----+-----+
| 1| 1| 1|
| 1| 0| 1|
| 3| 1| 1|
| 2| 1| 1|
| 2| 0| 1|
+---+-----+-----+
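Before running this, make sure the data is partitioned by id so that no id spans two partitions. Note that the window defines no ordering, so the running sum relies on the rows of each id keeping their original order.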
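To check the logic independently of Spark, the cumulative-sum filter can be mirrored in plain Python. This is only a sketch of the idea, not the Spark implementation: the helper name rows_from_first_one is hypothetical, and it assumes the rows are already in the desired order with all rows of an id adjacent, as in the sample data.

```python
from itertools import accumulate, groupby
from operator import itemgetter


def rows_from_first_one(rows):
    """Keep, per id, every row from the first value == 1 onward.

    `rows` is a list of (id, value) tuples, already ordered, with all rows
    of one id adjacent (an assumption that matches the sample data).
    """
    kept = []
    for _, group in groupby(rows, key=itemgetter(0)):
        group = list(group)
        # Running sum of `value` inside the group, like the window sum.
        running = accumulate(value for _, value in group)
        # A row survives once the running sum has become positive.
        kept.extend(row for row, total in zip(group, running) if total > 0)
    return kept


rows = [(1, 0), (1, 1), (1, 0), (2, 1), (2, 0), (3, 0), (3, 0), (3, 1)]
result = rows_from_first_one(rows)
print(result)
```

Because the running sum depends on row order, the same filter on a differently ordered input gives a different result, which is why an explicit ordering column matters in the Spark version.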

Ideally, you would need to:
Create a window partitioned by id and ordered the same way the dataframe already is
Keep only the rows for which there is a "one" before them in the window
AFAIK, there is no look up function within windows in Spark. Yet, you could follow this idea and work something out. Let's first create the data and import functions and windows.
import pyspark.sql.functions as F
from pyspark.sql.window import Window
l = [(1, 0), (1, 1), (1, 0), (2, 1), (2, 0), (3, 0), (3, 0), (3, 1)]
df = spark.createDataFrame(l, ['id', 'value'])
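Then, let's add an index to the dataframe to be able to order the windows. This is cheap: monotonically_increasing_id needs no shuffle, and its values are increasing but not consecutive.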
indexedDf = df.withColumn("index", F.monotonically_increasing_id())
Then we create a window that only looks at the values before the current row, ordered by that index and partitioned by id.
w = Window.partitionBy("id").orderBy("index").rowsBetween(Window.unboundedPreceding, 0)
Finally, we use that window to collect the set of values up to and including each row, and filter out the rows whose set does not contain 1. Optionally, we sort by index again because windowing does not preserve the original row order.
indexedDf\
.withColumn('set', F.collect_set(F.col('value')).over(w))\
.where(F.array_contains(F.col('set'), 1))\
.orderBy("index")\
.select("id", "value").show()
+---+-----+
| id|value|
+---+-----+
| 1| 1|
| 1| 0|
| 2| 1|
| 2| 0|
| 3| 1|
+---+-----+

Related

Check multiple columns for any column greater than zero using a regex

I need to apply a when function on multiple columns. I want to check if at least one of the columns has a value greater than 0.
This is my solution:
df.withColumn("any value", F.when(
(col("col1") > 0) |
(col("col2") > 0) |
(col("col3") > 0) |
...
(col("colX") > 0)
, "any greater than 0").otherwise(None))
Is it possible to do the same task with a regex, so I don't have to write all the column names?
So let's create sample data:
df = spark.createDataFrame(
[(0, 0, 0, 0), (0, 0, 2, 0), (0, 0, 0, 0), (1, 0, 0, 0)],
['a', 'b', 'c', 'd']
)
Then, you can build your condition from a list of columns (say all the columns of the dataframe) using map and reduce like this:
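from functools import reduce  # reduce is not a builtin in Python 3
from pyspark.sql import functions as F
cols = df.columns
# OR together one "column > 0" test per column
condition = reduce(lambda a, b: a | b, map(lambda c: F.col(c) > 0, cols))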
df.withColumn("any value", F.when(condition, "any greater than 0")).show()
which yields:
+---+---+---+---+------------------+
| a| b| c| d| any value|
+---+---+---+---+------------------+
| 0| 0| 0| 0| null|
| 0| 0| 2| 0|any greater than 0|
| 0| 0| 0| 0| null|
| 1| 0| 0| 0|any greater than 0|
+---+---+---+---+------------------+
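The question asked about picking the columns with a regex, which the answer above leaves implicit. A minimal plain-Python sketch of that step, assuming rows are represented as dicts and using a hypothetical helper any_positive: the column names are filtered with re.fullmatch and the per-column checks are folded with reduce, just as in the answer. In PySpark the same name filter could be applied to df.columns before building the condition.

```python
import re
from functools import reduce
from operator import or_


def any_positive(row, pattern):
    """Return True if any column whose name fully matches `pattern` is > 0."""
    selected = [name for name in row if re.fullmatch(pattern, name)]
    # Fold the per-column checks with OR; False is the result for no match.
    return reduce(or_, (row[name] > 0 for name in selected), False)


rows = [
    {"a": 0, "b": 0, "c": 0, "d": 0},
    {"a": 0, "b": 0, "c": 2, "d": 0},
    {"a": 0, "b": 0, "c": 0, "d": 0},
    {"a": 1, "b": 0, "c": 0, "d": 0},
]
flags = [any_positive(row, r"[a-d]") for row in rows]
print(flags)  # [False, True, False, True]
```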
Another option is to build an array from the columns and use forall (available in pyspark.sql.functions since Spark 3.1) to check whether every element is 0, then assign the value conditionally. Code below, with F imported as above:
df = df.withColumn('any value', F.array(df.columns)) \
.withColumn('any value', F.when(F.forall('any value', lambda x: x == 0), None).otherwise("any greater than 0"))
df.show()
+---+---+---+---+------------------+
| a| b| c| d| any value|
+---+---+---+---+------------------+
| 0| 0| 0| 0| null|
| 0| 0| 2| 0|any greater than 0|
| 0| 0| 0| 0| null|
| 1| 0| 0| 0|any greater than 0|
+---+---+---+---+------------------+

How to find the distribution of a column in PySpark dataframe for all the unique values present in that column?

I have a PySpark dataframe-
df = spark.createDataFrame([
("u1", 0),
("u2", 0),
("u3", 1),
("u4", 2),
("u5", 3),
("u6", 2),],
['user_id', 'medals'])
df.show()
Output-
+-------+------+
|user_id|medals|
+-------+------+
| u1| 0|
| u2| 0|
| u3| 1|
| u4| 2|
| u5| 3|
| u6| 2|
+-------+------+
I want to get the distribution of the medals column for all the users. So if there are n unique values in the medals column, I want n columns in the output dataframe with corresponding number of users who received that many medals.
The output for the data given above should look like-
+--------+--------+--------+--------+
|medals_0|medals_1|medals_2|medals_3|
+--------+--------+--------+--------+
| 2| 1| 2| 1|
+--------+--------+--------+--------+
How do I achieve this?
It's a simple pivot:
df.groupBy().pivot("medals").count().show()
+---+---+---+---+
| 0| 1| 2| 3|
+---+---+---+---+
| 2| 1| 2| 1|
+---+---+---+---+
If you also want the word medals in the column names, you can rename the columns like this:
medals_df = df.groupBy().pivot("medals").count()
for col in medals_df.columns:
    medals_df = medals_df.withColumnRenamed(col, "medals_{}".format(col))
medals_df.show()
+--------+--------+--------+--------+
|medals_0|medals_1|medals_2|medals_3|
+--------+--------+--------+--------+
| 2| 1| 2| 1|
+--------+--------+--------+--------+
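For comparison, here is a plain-Python sketch of the same count-and-rename logic, useful for checking the pivot on a small sample. The helper name medal_distribution is hypothetical, and the rows are assumed to be (user_id, medals) tuples.

```python
from collections import Counter


def medal_distribution(rows):
    """Count users per distinct medals value, keyed as medals_<value>."""
    counts = Counter(medals for _, medals in rows)
    # Sort the distinct values so the keys come out in ascending order.
    return {f"medals_{value}": counts[value] for value in sorted(counts)}


rows = [("u1", 0), ("u2", 0), ("u3", 1), ("u4", 2), ("u5", 3), ("u6", 2)]
distribution = medal_distribution(rows)
print(distribution)
```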

How spark RangeBetween works with Descending Order?

I thought rangeBetween(start, end) looks into values of the range(cur_value - start, cur_value + end). https://spark.apache.org/docs/2.3.0/api/java/org/apache/spark/sql/expressions/WindowSpec.html
But I saw an example that used a descending orderBy() on a timestamp and then rangeBetween with (unboundedPreceding, 0), which led me to explore the following example:
dd = spark.createDataFrame(
[(1, "a"), (3, "a"), (3, "a"), (1, "b"), (2, "b"), (3, "b")],
['id', 'category']
)
dd.show()
# output
+---+--------+
| id|category|
+---+--------+
| 1| a|
| 3| a|
| 3| a|
| 1| b|
| 2| b|
| 3| b|
+---+--------+
It seems to include preceding row whose value is higher by 1.
from pyspark.sql.functions import desc, sum as Fsum
from pyspark.sql.window import Window
byCategoryOrderedById = Window.partitionBy('category')\
.orderBy(desc('id'))\
.rangeBetween(-1, Window.currentRow)
dd.withColumn("sum", Fsum('id').over(byCategoryOrderedById)).show()
# output
+---+--------+---+
| id|category|sum|
+---+--------+---+
| 3| b| 3|
| 2| b| 5|
| 1| b| 3|
| 3| a| 6|
| 3| a| 6|
| 1| a| 1|
+---+--------+---+
And with start set to -2, it includes values greater by up to 2, but only from preceding rows.
byCategoryOrderedById = Window.partitionBy('category')\
.orderBy(desc('id'))\
.rangeBetween(-2,Window.currentRow)
dd.withColumn("sum", Fsum('id').over(byCategoryOrderedById)).show()
# output
+---+--------+---+
| id|category|sum|
+---+--------+---+
| 3| b| 3|
| 2| b| 5|
| 1| b| 6|
| 3| a| 6|
| 3| a| 6|
| 1| a| 7|
+---+--------+---+
So, what is the exact behavior of rangeBetween with desc orderBy?
It's not well documented, but with range (value-based) frames the sort direction affects which values fall inside the frame.
Let's take the example you provided:
RANGE BETWEEN 1 PRECEDING AND CURRENT ROW
Depending on the order by direction, 1 PRECEDING means:
current_row_value - 1 if ASC
current_row_value + 1 if DESC
Consider the row with value 1 in partition b.
With the descending order, the frame includes :
(current_value and all preceding values where x = current_value + 1) = (1, 2)
With the ascending order, the frame includes:
(current_value and all preceding values where x = current_value - 1) = (1)
PS: rangeBetween(-1, Window.currentRow) with descending order selects the same values as rangeBetween(Window.currentRow, 1) with ascending order.
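The rule above can be modelled in a few lines of plain Python. This is a sketch of the semantics only, under the assumption of finite numeric offsets (no unbounded boundaries); range_frame_sum is a hypothetical helper, not a Spark function. It reproduces the sums shown in the question for partition b.

```python
def range_frame_sum(values, start, end, descending=False):
    """Model sum(x) over a RANGE frame for one partition.

    `start` and `end` follow rangeBetween(start, end): negative means
    PRECEDING, 0 means CURRENT ROW, positive means FOLLOWING.
    """
    sums = []
    for current in values:
        if descending:
            # Descending order flips the direction of the offsets.
            low, high = current - end, current - start
        else:
            low, high = current + start, current + end
        # A range frame selects by value, so ties are always included.
        sums.append(sum(v for v in values if low <= v <= high))
    return sums


partition_b = [3, 2, 1]  # ids of category b, in descending order
print(range_frame_sum(partition_b, -1, 0, descending=True))  # [3, 5, 3]
print(range_frame_sum(partition_b, -2, 0, descending=True))  # [3, 5, 6]
print(range_frame_sum([1, 2, 3], 0, 1, descending=False))    # [3, 5, 3]
```

The last call illustrates the equivalence from the PS: each id gets the same sum with (currentRow, 1) in ascending order as with (-1, currentRow) in descending order.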

how can I create a pyspark udf using multiple columns?

I need to write some custom code using multiple columns within a group of my data.
My custom code is to set a flag if a value is over a threshold, but suppress the flag if it is within a certain time of a previous flag.
Here is some sample code:
df = spark.createDataFrame(
[
("a", 1, 0),
("a", 2, 1),
("a", 3, 1),
("a", 4, 1),
("a", 5, 1),
("a", 6, 0),
("a", 7, 1),
("a", 8, 1),
("b", 1, 0),
("b", 2, 1)
],
["group_col","order_col", "flag_col"]
)
df.show()
+---------+---------+--------+
|group_col|order_col|flag_col|
+---------+---------+--------+
| a| 1| 0|
| a| 2| 1|
| a| 3| 1|
| a| 4| 1|
| a| 5| 1|
| a| 6| 0|
| a| 7| 1|
| a| 8| 1|
| b| 1| 0|
| b| 2| 1|
+---------+---------+--------+
from pyspark.sql.functions import udf, col, asc
from pyspark.sql.window import Window
def _suppress(dates=None, alert_flags=None, window=2):
    sup_alert_flag = alert_flag
    last_alert_date = None
    for i, alert_flag in enumerate(alert_flag):
        current_date = dates[i]
        if alert_flag == 1:
            if not last_alert_date:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            elif (current_date - last_alert_date) > window:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            else:
                sup_alert_flag[i] = 0
        else:
            alert_flag = 0
    return sup_alert_flag
suppress_udf = udf(_suppress, DoubleType())
df_out = df.withColumn("supressed_flag_col", suppress_udf(dates=col("order_col"), alert_flags=col("flag_col"), window=4).Window.partitionBy(col("group_col")).orderBy(asc("order_col")))
df_out.show()
The above fails, but my expected output is the following:
+---------+---------+--------+------------------+
|group_col|order_col|flag_col|supressed_flag_col|
+---------+---------+--------+------------------+
| a| 1| 0| 0|
| a| 2| 1| 1|
| a| 3| 1| 0|
| a| 4| 1| 0|
| a| 5| 1| 0|
| a| 6| 0| 0|
| a| 7| 1| 1|
| a| 8| 1| 0|
| b| 1| 0| 0|
| b| 2| 1| 1|
+---------+---------+--------+------------------+
Editing answer after more thought.
The general problem seems to be that the result for the current row depends on the result for the previous row. In effect, there is a recurrence relation. I haven't found a good way to implement a recursive UDF in Spark, and as far as I can tell the distributed nature of the data in Spark makes this difficult to achieve. The following solution should work, but it collects each group into a single list, so it may not scale to large data sets.
from pyspark.sql import Row
import pyspark.sql.functions as F
import pyspark.sql.types as T
suppress_flag_row = Row("order_col", "flag_col", "res_flag")
def suppress_flag(date_alert_flags, window_size):
    # Process one group's rows in order_col order
    sorted_alerts = sorted(date_alert_flags, key=lambda x: x["order_col"])
    res_flags = []
    last_alert_date = None
    for row in sorted_alerts:
        current_date = row["order_col"]
        aflag = row["flag_col"]
        if aflag == 1 and (not last_alert_date or (current_date - last_alert_date) > window_size):
            res = suppress_flag_row(current_date, aflag, True)
            last_alert_date = current_date
        else:
            res = suppress_flag_row(current_date, aflag, False)
        res_flags.append(res)
    return res_flags
in_fields = [T.StructField("order_col", T.IntegerType(), nullable=True )]
in_fields.append( T.StructField("flag_col", T.IntegerType(), nullable=True) )
out_fields = in_fields
out_fields.append(T.StructField("res_flag", T.BooleanType(), nullable=True) )
out_schema = T.StructType(out_fields)
suppress_udf = F.udf(suppress_flag, T.ArrayType(out_schema) )
window_size = 4
tmp = df.groupBy("group_col").agg( F.collect_list( F.struct( F.col("order_col"), F.col("flag_col") ) ).alias("date_alert_flags"))
tmp2 = tmp.select(F.col("group_col"), suppress_udf(F.col("date_alert_flags"), F.lit(window_size)).alias("suppress_res"))
expand_fields = [F.col("group_col")] + [F.col("res_expand")[f.name].alias(f.name) for f in out_fields]
final_df = tmp2.select(F.col("group_col"), F.explode(F.col("suppress_res")).alias("res_expand")).select( expand_fields )
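The recurrence at the heart of the UDF can be exercised in plain Python before wiring it into Spark. The sketch below uses a hypothetical helper named suppress and assumes each group is a list of (order_col, flag_col) tuples. It tests the last kept alert with "is None" rather than truthiness, so an order value of 0 would still count as a previous alert. With window=4 it reproduces the expected flags for group a from the question.

```python
def suppress(alerts, window):
    """Return (order, flag, kept) tuples for one group.

    An alert is kept when its flag is 1 and no earlier alert was kept
    within `window` of it; every other row gets kept = 0.
    """
    result = []
    last_kept = None  # order value of the most recently kept alert
    for order, flag in sorted(alerts):
        keep = flag == 1 and (last_kept is None or order - last_kept > window)
        if keep:
            last_kept = order
        result.append((order, flag, int(keep)))
    return result


group_a = [(1, 0), (2, 1), (3, 1), (4, 1), (5, 1), (6, 0), (7, 1), (8, 1)]
kept_flags = [kept for _, _, kept in suppress(group_a, window=4)]
print(kept_flags)  # [0, 1, 0, 0, 0, 0, 1, 0]
```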
I don't think you need a custom function for this. You can use the rowsBetween option on a window to look at the preceding 5 rows. Please check and let me know if I missed something.
>>> from pyspark.sql import functions as F
>>> from pyspark.sql import Window
>>> w = Window.partitionBy('group_col').orderBy('order_col').rowsBetween(-5,-1)
>>> df = df.withColumn('supr_flag_col',F.when(F.sum('flag_col').over(w) == 0,1).otherwise(0))
>>> df.orderBy('group_col','order_col').show()
+---------+---------+--------+-------------+
|group_col|order_col|flag_col|supr_flag_col|
+---------+---------+--------+-------------+
| a| 1| 0| 0|
| a| 2| 1| 1|
| a| 3| 1| 0|
| b| 1| 0| 0|
| b| 2| 1| 1|
+---------+---------+--------+-------------+

[Py]Spark SQL: Merge two or more rows based on equal values in different columns

I have the following DataFrame ordered by group, n1, n2
+-----+--+--+------+------+
|group|n1|n2|n1_ptr|n2_ptr|
+-----+--+--+------+------+
| 1| 0| 0| 1| 1|
| 1| 1| 1| 2| 2|
| 1| 1| 5| 2| 6|
| 1| 2| 2| 3| 3|
| 1| 2| 6| 3| 7|
| 1| 3| 3| 4| 4|
| 1| 3| 7| null| null|
| 1| 4| 4| 5| 5|
| 1| 5| 1| null| null|
| 1| 5| 5| null| null|
+-----+--+--+------+------+
Each row's n1_ptr and n2_ptr values refer to the n1 and n2 values of some other row in the group that comes later in the ordering. In other words, n1_ptr and n2_ptr are effectively pointers to another row. I want to use these pointers to identify chains of (n1, n2) pairs. For example, the chains in the given data would be: (0,0) -> (1,1) -> (2,2) -> (3,3) -> (4,4) -> (5,5); (1,5) -> (2,6) -> (3,7); and (5,1).
The ultimate goal is to consolidate each chain into a single row in a DataFrame describing the min and max n1 and n2 values in each chain. Continuing the example, this would yield
+-----+------+------+------+------+
|group|n1_min|n2_min|n1_max|n2_max|
+-----+------+------+------+------+
| 1| 0| 0| 5| 5|
| 1| 1| 5| 3| 7|
| 1| 5| 1| 5| 1|
+-----+------+------+------+------+
It seems like a udf might do the trick, but I am concerned about performance. Is there a more sensible/performant way to go about this?
A good solution would be to use graphframes: https://graphframes.github.io/quick-start.html.
First let's change the structure of your initial dataframe:
import pyspark.sql.functions as psf
df = sc.parallelize([[1, 0, 0, 1, 1],[1, 1, 1, 2, 2],[1, 1, 5, 2, 6],
[1, 2, 2, 3, 3],[1, 2, 6, 3, 7],[1, 3, 3, 4, 4],
[1, 3, 7, None, None],[1, 4, 4, 5, 5],[1, 5, 1, None, None],
[1, 5, 5, None, None]]).toDF(["group","n1","n2","n1_ptr","n2_ptr"]).filter("n1_ptr IS NOT NULL")
df = df.select(
"group",
psf.struct("n1", "n2").alias("src"),
psf.struct(df.n1_ptr.alias("n1"), df.n2_ptr.alias("n2")).alias("dst"))
From df we'll build a vertex and an edge dataframe:
v = df.select(
"group",
psf.explode(psf.array("src", "dst")).alias("id"))
e = df.drop("group")
The next step is to find all connected components using graphframes:
from graphframes import *
g = GraphFrame(v, e)
res = g.connectedComponents()
+-----+-----+------------+
|group| id| component|
+-----+-----+------------+
| 1|[0,0]|309237645312|
| 1|[1,1]|309237645312|
| 1|[1,1]|309237645312|
| 1|[2,2]|309237645312|
| 1|[1,5]| 85899345920|
| 1|[2,6]| 85899345920|
| 1|[2,2]|309237645312|
| 1|[3,3]|309237645312|
| 1|[2,6]| 85899345920|
| 1|[3,7]| 85899345920|
| 1|[3,3]|309237645312|
| 1|[4,4]|309237645312|
| 1|[3,7]| 85899345920|
| 1|[4,4]|309237645312|
| 1|[5,5]|309237645312|
| 1|[5,1]|292057776128|
| 1|[5,5]|309237645312|
+-----+-----+------------+
Now, since the edges of your graph imply that the node values n1 and n2 increase monotonically along a chain, we can simply aggregate by component and compute the min and max:
res.groupBy("group", "component").agg(
psf.min("id").alias("min_id"),
psf.max("id").alias("max_id")
)
+-----+------------+------+------+
|group| component|min_id|max_id|
+-----+------------+------+------+
| 1|309237645312| [0,0]| [5,5]|
| 1| 85899345920| [1,5]| [3,7]|
| 1|292057776128| [5,1]| [5,1]|
+-----+------------+------+------+
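For small inputs, or to sanity-check the graphframes result, the same grouping can be sketched in plain Python with a union-find structure. This is a swapped-in technique, not what graphframes does internally; chain_bounds is a hypothetical helper, and it assumes the rows of a single group given as (n1, n2, n1_ptr, n2_ptr) tuples. Unlike the filtered dataframe above, it keeps rows with null pointers as vertices, so an isolated pair such as (5, 1) forms its own chain.

```python
from collections import defaultdict


def chain_bounds(rows):
    """Return sorted (min_pair, max_pair) tuples, one per chain."""
    parent = {}

    def find(node):
        # Register unseen nodes as their own root, then walk to the root.
        parent.setdefault(node, node)
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # path halving
            node = parent[node]
        return node

    for n1, n2, p1, p2 in rows:
        src_root = find((n1, n2))
        if p1 is not None and p2 is not None:
            # Union the source pair with the pair it points to.
            parent[src_root] = find((p1, p2))

    components = defaultdict(list)
    for node in parent:
        components[find(node)].append(node)
    # Tuples compare field by field, like the struct min/max above.
    return sorted((min(nodes), max(nodes)) for nodes in components.values())


rows = [
    (0, 0, 1, 1), (1, 1, 2, 2), (1, 5, 2, 6), (2, 2, 3, 3), (2, 6, 3, 7),
    (3, 3, 4, 4), (3, 7, None, None), (4, 4, 5, 5), (5, 1, None, None),
    (5, 5, None, None),
]
bounds = chain_bounds(rows)
print(bounds)
```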
