Select column name per row for max value in PySpark - apache-spark

I have a dataframe like this (only two columns are shown here, but the original dataframe has many more):
data = [(("ID1", 3, 5)), (("ID2", 4, 12)), (("ID3", 8, 3))]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()
+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1| 3| 5|
|ID2| 4| 12|
|ID3| 8| 3|
+---+----+----+
For each row, I want to extract the name of the column that holds the max value. The expected output is:
+---+----+----+-------+
| ID|colA|colB|Max_col|
+---+----+----+-------+
|ID1| 3| 5| colB|
|ID2| 4| 12| colB|
|ID3| 8| 3| colA|
+---+----+----+-------+
In case of a tie, where colA and colB have the same value, choose the first column.
How can I achieve this in PySpark?

You can use a UDF for the row-wise computation and a struct to pass multiple columns to the UDF. Hope this helps.
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType, StructType, StructField
from operator import itemgetter
data = [(("ID1", 3, 5,78)), (("ID2", 4, 12,45)), (("ID3", 70, 3,67))]
df = spark.createDataFrame(data, ["ID", "colA", "colB","colC"])
df.show()
+---+----+----+----+
| ID|colA|colB|colC|
+---+----+----+----+
|ID1|   3|   5|  78|
|ID2|   4|  12|  45|
|ID3|  70|   3|  67|
+---+----+----+----+
cols = df.columns
# to get max of values in a row
maxcol = F.udf(lambda row: max(row), IntegerType())
maxDF = df.withColumn("maxval", maxcol(F.struct([df[x] for x in df.columns[1:]])))
maxDF.show()
+---+----+----+----+------+
| ID|colA|colB|colC|maxval|
+---+----+----+----+------+
|ID1|   3|   5|  78|    78|
|ID2|   4|  12|  45|    45|
|ID3|  70|   3|  67|    70|
+---+----+----+----+------+
# to get max of value & corresponding column name
schema = StructType([
    StructField('maxval', IntegerType()),
    StructField('maxval_colname', StringType())
])
maxcol = F.udf(lambda row: max(row, key=itemgetter(0)), schema)
maxDF = df.withColumn('maxfield', maxcol(F.struct([F.struct(df[x], F.lit(x)) for x in df.columns[1:]]))) \
    .select(df.columns + ['maxfield.maxval', 'maxfield.maxval_colname'])
maxDF.show()
+---+----+----+----+------+--------------+
| ID|colA|colB|colC|maxval|maxval_colname|
+---+----+----+----+------+--------------+
|ID1|   3|   5|  78|    78|          colC|
|ID2|   4|  12|  45|    45|          colC|
|ID3|  70|   3|  67|    70|          colA|
+---+----+----+----+------+--------------+

There are multiple options to achieve this. I am providing an example for one and can give hints for the rest.
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W
from pyspark.sql import types as T
data = [(("ID1", 3, 5)), (("ID2", 4, 12)), (("ID3", 8, 3))]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()
+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1| 3| 5|
|ID2| 4| 12|
|ID3| 8| 3|
+---+----+----+
# F.array builds an array of [column name, value] pairs per row, e.g. [['colA', 3], ['colB', 5]], and F.explode turns that array into one row per pair
df = df.withColumn(
    "max_val",
    F.explode(
        F.array([
            F.array([F.lit(cl), F.col(cl)]) for cl in df.columns[1:]
        ])
    )
)
df.show()
+---+----+----+----------+
| ID|colA|colB| max_val|
+---+----+----+----------+
|ID1| 3| 5| [colA, 3]|
|ID1| 3| 5| [colB, 5]|
|ID2| 4| 12| [colA, 4]|
|ID2| 4| 12|[colB, 12]|
|ID3| 8| 3| [colA, 8]|
|ID3| 8| 3| [colB, 3]|
+---+----+----+----------+
# Then split each pair so that the column name and the value land in separate columns
df = df.select(
    "ID",
    "colA",
    "colB",
    F.col("max_val").getItem(0).alias("col_name"),
    F.col("max_val").getItem(1).cast(T.IntegerType()).alias("col_value"),
)
df.show()
+---+----+----+--------+---------+
| ID|colA|colB|col_name|col_value|
+---+----+----+--------+---------+
|ID1| 3| 5| colA| 3|
|ID1| 3| 5| colB| 5|
|ID2| 4| 12| colA| 4|
|ID2| 4| 12| colB| 12|
|ID3| 8| 3| colA| 8|
|ID3| 8| 3| colB| 3|
+---+----+----+--------+---------+
# Rank the values within each ID in descending order
df = df.withColumn(
    "rank",
    F.rank().over(W.partitionBy("ID").orderBy(F.col("col_value").desc()))
)
df.show()
+---+----+----+--------+---------+----+
| ID|colA|colB|col_name|col_value|rank|
+---+----+----+--------+---------+----+
|ID2| 4| 12| colB| 12| 1|
|ID2| 4| 12| colA| 4| 2|
|ID3| 8| 3| colA| 8| 1|
|ID3| 8| 3| colB| 3| 2|
|ID1| 3| 5| colB| 5| 1|
|ID1| 3| 5| colA| 3| 2|
+---+----+----+--------+---------+----+
# Finally, filter rank = 1: the max value gets rank 1 because we ordered by value descending
df.where("rank=1").show()
+---+----+----+--------+---------+----+
| ID|colA|colB|col_name|col_value|rank|
+---+----+----+--------+---------+----+
|ID2| 4| 12| colB| 12| 1|
|ID3| 8| 3| colA| 8| 1|
|ID1| 3| 5| colB| 5| 1|
+---+----+----+--------+---------+----+
Other options are:
Use a UDF on your base df that returns the column name holding the max value.
In the same example, after creating the col_name and col_value columns, instead of ranking, group by ID and take the max of col_value, then join back with the previous df (see the sketch below).
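A minimal sketch of that second hint, continuing from the df that already has the col_name and col_value columns (the alias is just for illustration):
# Keep only the rows whose col_value equals the per-ID maximum
max_per_id = df.groupBy("ID").agg(F.max("col_value").alias("col_value"))
df.join(max_per_id, on=["ID", "col_value"], how="inner").show()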

You can use the RDD API to add the new column:
from pyspark.sql import Row

df.rdd.map(lambda r: r.asDict())\
    .map(lambda r: Row(Max_col=max([i for i in r.items() if i[0] != 'ID'],
                                   key=lambda kv: kv[1])[0], **r))\
    .toDF()
Resulting in:
+---+-------+----+----+
| ID|Max_col|colA|colB|
+---+-------+----+----+
|ID1| colB| 3| 5|
|ID2| colB| 4| 12|
|ID3| colA| 8| 3|
+---+-------+----+----+

Extending what Suresh has done ... returning the appropriate column name:
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType, StringType
import numpy as np
data = [(("ID1", 3, 5,78)), (("ID2", 4, 12,45)), (("ID3", 68, 3,67))]
df = spark.createDataFrame(data, ["ID", "colA", "colB","colC"])
df.show()
cols = df.columns
maxcol = f.udf(lambda row: cols[row.index(max(row)) +1], StringType())
maxDF = df.withColumn("Max_col", maxcol(f.struct([df[x] for x in df.columns[1:]])))
maxDF.show(truncate=False)
+---+----+----+----+-------+
|ID |colA|colB|colC|Max_col|
+---+----+----+----+-------+
|ID1|3   |5   |78  |colC   |
|ID2|4   |12  |45  |colC   |
|ID3|68  |3   |67  |colA   |
+---+----+----+----+-------+

Try the following:
from pyspark.sql import functions as F
data = [(("ID1", 3, 5)), (("ID2", 4, 12)), (("ID3", 8, 3))]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.withColumn('max_col',
              # >= keeps colA on ties, matching the question's "choose the first column" rule
              F.when(F.col('colA') >= F.col('colB'), 'colA').
              otherwise('colB')).show()
Yields:
+---+----+----+-------+
| ID|colA|colB|max_col|
+---+----+----+-------+
|ID1| 3| 5| colB|
|ID2| 4| 12| colB|
|ID3| 8| 3| colA|
+---+----+----+-------+
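Since the original dataframe has many columns, here is a minimal sketch of the same idea generalized with F.greatest and F.coalesce (using df.columns[1:] assumes every column except ID is numeric); when() without otherwise() yields null, so coalesce returns the first column equal to the row maximum, which also satisfies the tie rule:
value_cols = df.columns[1:]
row_max = F.greatest(*[F.col(c) for c in value_cols])
# The first column equal to the row maximum wins, so ties go to the left-most column
max_col = F.coalesce(*[F.when(F.col(c) == row_max, F.lit(c)) for c in value_cols])
df.withColumn('max_col', max_col).show()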

Related

Iterating through rows to create custom formula structure in PySpark

I have a dataframe with variable names and a numerator and denominator column; each variable is a ratio of two columns (the definitions are shown in the first table below).
A second dataframe holds the actual data used to compute the attributes.
The goal is to create each attribute from the formula in the first dataframe and compute it against the second.
Currently my approach is very naive:
df = df.withColumn("var1", col('a') / col('b'))
# ... and so on, one withColumn per variable
Since I have >500 variables, any suggestions for a smarter way to get around this are welcome!
This can be achieved with a cross join, an unpivot (stack), and a pivot in PySpark.
import pyspark.sql.functions as f
from pyspark.sql.functions import *
from pyspark.sql.types import *
data = [
    ("var1", "a", "c"),
    ("var2", "b", "d"),
    ("var3", "b", "a"),
    ("var4", "d", "c")
]
schema = StructType([
    StructField('name', StringType(), True),
    StructField('numerator', StringType(), True),
    StructField('denonminator', StringType(), True)
])
data2 = [
    ("ID1", 6, 4, 3, 7),
    ("ID2", 1, 2, 3, 9)
]
schema2 = StructType([
    StructField('ID', StringType(), True),
    StructField('a', IntegerType(), True),
    StructField('b', IntegerType(), True),
    StructField('c', IntegerType(), True),
    StructField('d', IntegerType(), True)
])
df = spark.createDataFrame(data=data, schema=schema)
df2 = spark.createDataFrame(data=data2, schema=schema2)
df.createOrReplaceTempView("table1")
df2.createOrReplaceTempView("table2")
""" CRoss Join for Duplicating the values """
df3=spark.sql("select * from table1 cross join table2")
df3.createOrReplaceTempView("table3")
""" Unpivoting the values and joining to fecth the value of numerator and denominator"""
cols = df2.columns[1:]
df4=df2.selectExpr('ID', "stack({}, {})".format(len(cols), ', '.join(("'{}', {}".format(i, i) for i in cols))))
df4.createOrReplaceTempView("table4")
df5=spark.sql("select name,B.ID,round(B.col1/C.col1,2) as value from table3 A left outer join table4 B on A.ID=B.ID and a.numerator=b.col0 left outer join table4 C on A.ID=C.ID and a.denonminator=C.col0 order by name,ID")
""" Pivot for fetching the results """
df_final=df5.groupBy("ID").pivot("name").max("value")
The results of all intermediate and final dataframes
>>> df.show()
+----+---------+------------+
|name|numerator|denonminator|
+----+---------+------------+
|var1| a| c|
|var2| b| d|
|var3| b| a|
|var4| d| c|
+----+---------+------------+
>>> df2.show()
+---+---+---+---+---+
| ID| a| b| c| d|
+---+---+---+---+---+
|ID1| 6| 4| 3| 7|
|ID2| 1| 2| 3| 9|
+---+---+---+---+---+
>>> df3.show()
+----+---------+------------+---+---+---+---+---+
|name|numerator|denonminator| ID| a| b| c| d|
+----+---------+------------+---+---+---+---+---+
|var1| a| c|ID1| 6| 4| 3| 7|
|var2| b| d|ID1| 6| 4| 3| 7|
|var1| a| c|ID2| 1| 2| 3| 9|
|var2| b| d|ID2| 1| 2| 3| 9|
|var3| b| a|ID1| 6| 4| 3| 7|
|var4| d| c|ID1| 6| 4| 3| 7|
|var3| b| a|ID2| 1| 2| 3| 9|
|var4| d| c|ID2| 1| 2| 3| 9|
+----+---------+------------+---+---+---+---+---+
>>> df4.show()
+---+----+----+
| ID|col0|col1|
+---+----+----+
|ID1| a| 6|
|ID1| b| 4|
|ID1| c| 3|
|ID1| d| 7|
|ID2| a| 1|
|ID2| b| 2|
|ID2| c| 3|
|ID2| d| 9|
+---+----+----+
>>> df5.show()
+----+---+-----+
|name| ID|value|
+----+---+-----+
|var1|ID1| 2.0|
|var1|ID2| 0.33|
|var2|ID1| 0.57|
|var2|ID2| 0.22|
|var3|ID1| 0.67|
|var3|ID2| 2.0|
|var4|ID1| 2.33|
|var4|ID2| 3.0|
+----+---+-----+
>>> df_final.show()
+---+----+----+----+----+
| ID|var1|var2|var3|var4|
+---+----+----+----+----+
|ID2|0.33|0.22| 2.0| 3.0|
|ID1| 2.0|0.57|0.67|2.33|
+---+----+----+----+----+
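As a side note, on Spark 3.4 or later (an assumption about your environment), the stack() expression used to build df4 can be replaced by the DataFrame unpivot method:
# Equivalent of the stack() expression above; requires Spark >= 3.4
df4 = df2.unpivot("ID", df2.columns[1:], "col0", "col1")
df4.createOrReplaceTempView("table4")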

Rolling correlation and average (last 3) Per Group in PySpark

I have a dataframe like this
data = [(("ID1", 1, 5)), (("ID1", 2, 6)), (("ID1", 3, 7)),
(("ID1", 4, 4)), (("ID1", 5, 2)), (("ID1", 6, 2)),
(("ID2", 1, 4)), (("ID2", 2, 6)), (("ID2", 3, 1)), (("ID2", 4, 1)), (("ID2", 5, 4))]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()
+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1| 1| 5|
|ID1| 2| 6|
|ID1| 3| 7|
|ID1| 4| 4|
|ID1| 5| 2|
|ID1| 6| 2|
|ID2| 1| 4|
|ID2| 2| 6|
|ID2| 3| 1|
|ID2| 4| 1|
|ID2| 5| 4|
+---+----+----+
I want to calculate, per group, the rolling average of colB over the last 3 elements and its rolling correlation with colA over the same window.
Hence for ID1, for the first element (5): average = 5, corr = 0
for ID1, for the first 2 elements (5, 6): average = 5.5, corr with colA = 1
for ID1, for the first 3 elements (5, 6, 7): average = 6, corr with colA = 1
for ID1, for the elements (6, 7, 4): average = 5.66, corr with colA = -0.65
Expected output is like this
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1| 1| 5| 0| 5|
|ID1| 2| 6| 1| 5.5|
|ID1| 3| 7| 1| 6|
|ID1| 4| 4| -0.65| 5.66|
|ID1| 5| 2| -0.99| 4.33|
|ID1| 6| 2| -0.86| 2.66|
|ID2| 1| 4| 0| 4|
|ID2| 2| 6| 1| 5|
|ID2| 3| 1| -0.59| 3.66|
|ID2| 4| 1| -0.86| 2.66|
|ID2| 5| 4| 0.86| 2|
+---+----+----+----------+---------+
You can do it with the built-in functions avg and corr. Here is the Scala solution:
df
  .withColumn("indices", row_number().over(Window.partitionBy($"ID").orderBy($"colA")))
  .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(Window.partitionBy($"ID").orderBy($"colA").rowsBetween(-2L, Window.currentRow))).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(Window.partitionBy($"ID").orderBy($"colA").rowsBetween(-2L, Window.currentRow)))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()
gives:
+---+----+----+-------------------+------------------+
| ID|colA|colB| corr_last3| avg_last3|
+---+----+----+-------------------+------------------+
|ID1| 1| 5| 0.0| 5.0|
|ID1| 2| 6| 1.0| 5.5|
|ID1| 3| 7| 1.0| 6.0|
|ID1| 4| 4|-0.6546536707079772| 5.666666666666667|
|ID1| 5| 2|-0.9933992677987828| 4.333333333333333|
|ID1| 6| 2|-0.8660254037844386|2.6666666666666665|
|ID2| 1| 4| 0.0| 4.0|
|ID2| 2| 6| 1.0| 5.0|
|ID2| 3| 1|-0.5960395606792697|3.6666666666666665|
|ID2| 4| 1|-0.8660254037844387|2.6666666666666665|
|ID2| 5| 4| 0.8660254037844387| 2.0|
+---+----+----+-------------------+------------------+
The PySpark version of the answer is this:
from pyspark.sql import Window
from pyspark.sql.functions import rank, corr, when, mean, col, round
df = df\
    .withColumn("indices", rank().over(Window.partitionBy("ID").orderBy("colA")))\
    .withColumn("corr_last3", when(col("indices") > 1,
                                   corr(col("indices"), col("colB"))
                                   .over(Window.partitionBy("ID").orderBy("colA")
                                         .rowsBetween(-2, Window.currentRow))).otherwise(0.0))\
    .withColumn("avg_last3", mean(col("colB")).over(Window.partitionBy("ID").orderBy("colA").rowsBetween(-2, Window.currentRow)))\
    .drop(col("indices"))\
    .orderBy("ID", "colA")
df = df.withColumn("corr_last3", round(col("corr_last3"), 3))\
    .withColumn("avg_last3", round(col("avg_last3"), 3))
df.show()
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|       0.0|      5.0|
|ID1|   2|   6|       1.0|      5.5|
|ID1|   3|   7|       1.0|      6.0|
|ID1|   4|   4|    -0.655|    5.667|
|ID1|   5|   2|    -0.993|    4.333|
|ID1|   6|   2|    -0.866|    2.667|
|ID2|   1|   4|       0.0|      4.0|
|ID2|   2|   6|       1.0|      5.0|
|ID2|   3|   1|    -0.596|    3.667|
|ID2|   4|   1|    -0.866|    2.667|
|ID2|   5|   4|     0.866|      2.0|
+---+----+----+----------+---------+

Why sum is not displaying after aggregation & pivot?

I have student marks like below, and I want to pivot the subject name column and also get the total marks after the pivot.
Source table like:
+---------+-----------+-----+
|StudentId|SubjectName|Marks|
+---------+-----------+-----+
| 1| A| 10|
| 1| B| 20|
| 1| C| 30|
| 2| A| 20|
| 2| B| 25|
| 2| C| 30|
| 3| A| 10|
| 3| B| 20|
| 3| C| 20|
+---------+-----------+-----+
Destination:
+---------+---+---+---+-----+
|StudentId| A| B| C|Total|
+---------+---+---+---+-----+
| 1| 10| 20| 30| 60|
| 3| 10| 20| 20| 50|
| 2| 20| 25| 30| 75|
+---------+---+---+---+-----+
Please find the below source code:
val spark = SparkSession.builder().appName("test").master("local[*]").getOrCreate()
import spark.implicits._
val list = List((1, "A", 10), (1, "B", 20), (1, "C", 30), (2, "A", 20), (2, "B", 25), (2, "C", 30), (3, "A", 10),
(3, "B", 20), (3, "C", 20))
val df = list.toDF("StudentId", "SubjectName", "Marks")
df.show() // source table as per above
val df1 = df.groupBy("StudentId").pivot("SubjectName", Seq("A", "B", "C")).agg(sum("Marks"))
df1.show()
val df2 = df1.withColumn("Total", col("A") + col("B") + col("C"))
df2.show // required destitnation
val df3 = df.groupBy("StudentId").agg(sum("Marks").as("Total"))
df3.show()
df1 is not displaying the sum/total column; it displays like below.
+---------+---+---+---+
|StudentId| A| B| C|
+---------+---+---+---+
| 1| 10| 20| 30|
| 3| 10| 20| 20|
| 2| 20| 25| 30|
+---------+---+---+---+
df3 is able to create a new Total column, but why is df1 not able to create the new column?
Can anybody help me with what I am missing, or with whatever is wrong in my understanding of the pivot concept?
This is expected behaviour from Spark's pivot function: the .agg function is applied to the pivoted columns, which is why you cannot see the sum of marks as a new column.
Refer to this link for official documentation about pivot.
Example:
scala> df.groupBy("StudentId").pivot("SubjectName").agg(sum("Marks") + 2).show()
+---------+---+---+---+
|StudentId| A| B| C|
+---------+---+---+---+
| 1| 12| 22| 32|
| 3| 12| 22| 22|
| 2| 22| 27| 32|
+---------+---+---+---+
In the above example we have added 2 to all the pivoted columns.
Example 2:
To get a count using pivot and agg:
scala> df.groupBy("StudentId").pivot("SubjectName").agg(count("*")).show()
+---------+---+---+---+
|StudentId| A| B| C|
+---------+---+---+---+
| 1| 1| 1| 1|
| 3| 1| 1| 1|
| 2| 1| 1| 1|
+---------+---+---+---+
The .agg that follows pivot applies only to the pivoted data. To find the total you should add a new column and sum it, as below.
val cols = Seq("A", "B", "C")
val result = df.groupBy("StudentId")
  .pivot("SubjectName")
  .agg(sum("Marks"))
  .withColumn("Total", cols.map(col _).reduce(_ + _))
result.show(false)
Output:
+---------+---+---+---+-----+
|StudentId|A |B |C |Total|
+---------+---+---+---+-----+
|1 |10 |20 |30 |60 |
|3 |10 |20 |20 |50 |
|2 |20 |25 |30 |75 |
+---------+---+---+---+-----+

create unique id for combination of a pair of values from two columns in a spark dataframe

I have a Spark dataframe of six columns, say (col1, col2, ..., col6). I want to create a unique id for each combination of values from "col1" and "col2" and add it to the dataframe. Can someone help me with some PySpark code on how to do it?
You can achieve it using monotonically_increasing_id (PySpark >= 1.6) or monotonicallyIncreasingId (PySpark < 1.6):
>>> from pyspark.sql.functions import monotonically_increasing_id
>>> rdd=sc.parallelize([[12,23,3,4,5,6],[12,23,56,67,89,20],[12,23,0,0,0,0],[12,2,12,12,12,23],[1,2,3,4,56,7],[1,2,3,4,56,7]])
>>> df = rdd.toDF(['col_1','col_2','col_3','col_4','col_5','col_6'])
>>> df.show()
+-----+-----+-----+-----+-----+-----+
|col_1|col_2|col_3|col_4|col_5|col_6|
+-----+-----+-----+-----+-----+-----+
| 12| 23| 3| 4| 5| 6|
| 12| 23| 56| 67| 89| 20|
| 12| 23| 0| 0| 0| 0|
| 12| 2| 12| 12| 12| 23|
| 1| 2| 3| 4| 56| 7|
| 1| 2| 3| 4| 56| 7|
+-----+-----+-----+-----+-----+-----+
>>> df_1=df.groupBy(df.col_1,df.col_2).count().withColumn("id", monotonically_increasing_id()).select(['col_1','col_2','id'])
>>> df_1.show()
+-----+-----+-------------+
|col_1|col_2| id|
+-----+-----+-------------+
| 12| 23| 34359738368|
| 1| 2|1434519076864|
| 12| 2|1554778161152|
+-----+-----+-------------+
>>> df.join(df_1,(df.col_1==df_1.col_1) & (df.col_2==df_1.col_2)).drop(df_1.col_1).drop(df_1.col_2).show()
+-----+-----+-----+-----+-----+-----+-------------+
|col_3|col_4|col_5|col_6|col_1|col_2| id|
+-----+-----+-----+-----+-----+-----+-------------+
| 3| 4| 5| 6| 12| 23| 34359738368|
| 56| 67| 89| 20| 12| 23| 34359738368|
| 0| 0| 0| 0| 12| 23| 34359738368|
| 3| 4| 56| 7| 1| 2|1434519076864|
| 3| 4| 56| 7| 1| 2|1434519076864|
| 12| 12| 12| 23| 12| 2|1554778161152|
+-----+-----+-----+-----+-----+-----+-------------+
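Note that monotonically_increasing_id only guarantees increasing, unique values; the ids are not consecutive and depend on how the data is partitioned. If you prefer a small, dense id per (col_1, col_2) pair, a minimal alternative sketch uses dense_rank over a window (the global, unpartitioned window is fine for small data but moves everything to a single partition):
from pyspark.sql import Window
from pyspark.sql import functions as F
# The same (col_1, col_2) pair always gets the same dense id
w = Window.orderBy("col_1", "col_2")
df.withColumn("id", F.dense_rank().over(w)).show()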
If you really need to generate the unique ID from col1 and col2, you can also create a hash value leveraging Spark's sha2 function.
First, let's generate some dummy data:
from random import randint

max_range = 10
df1 = spark.createDataFrame(
    [(x, x * randint(1, max_range), x * 10 * randint(1, max_range)) for x in range(1, max_range)],
    ['C1', 'C2', 'C3'])
>>> df1.show()
+---+---+---+
| C1| C2| C3|
+---+---+---+
| 1| 1| 60|
| 2| 14|180|
| 3| 21|270|
| 4| 16|360|
| 5| 35|250|
| 6| 30|480|
| 7| 28|210|
| 8| 80|320|
| 9| 45|360|
+---+---+---+
Then create a new uid column from columns C2 and C3 with the following code:
from pyspark.sql.functions import col, sha2, concat
df1.withColumn("uid", sha2(concat(col("C2"), col("C3")), 256)).show(10, False)
And the output:
+---+---+---+--------------------+
| C1| C2| C3| uid|
+---+---+---+--------------------+
| 1| 1| 60|a512db2741cd20693...|
| 2| 14|180|2f6543dc6c0e06e4a...|
| 3| 21|270|bd3c65ddde4c6f733...|
| 4| 16|360|c7a1e8c59fc9dcc21...|
| 5| 35|250|cba1aeb7a72d9ae27...|
| 6| 30|480|ad7352ff8927cf790...|
| 7| 28|210|ea7bc25aa7cd3503f...|
| 8| 80|320|02e1d953517339552...|
| 9| 45|360|b485cf8f710a65755...|
+---+---+---+--------------------+
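One caveat worth noting: concatenating the raw values can collide (for example C2=1, C3=23 and C2=12, C3=3 both concatenate to "123"). A minimal sketch using concat_ws with a separator avoids this, assuming the separator does not appear in the data:
from pyspark.sql.functions import col, sha2, concat_ws
# The '||' separator keeps (1, 23) and (12, 3) from hashing to the same value
df1.withColumn("uid", sha2(concat_ws("||", col("C2"), col("C3")), 256)).show(10, False)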

Difference in dense rank and row number in spark

I am trying to understand the difference between dense_rank and row_number. Both start from 1 in each new window partition. Doesn't the rank of a row always start from 1? Any help would be appreciated.
The difference is when there are "ties" in the ordering column. Check the example below:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val df = Seq(("a", 10), ("a", 10), ("a", 20)).toDF("col1", "col2")
val windowSpec = Window.partitionBy("col1").orderBy("col2")
df
  .withColumn("rank", rank().over(windowSpec))
  .withColumn("dense_rank", dense_rank().over(windowSpec))
  .withColumn("row_number", row_number().over(windowSpec)).show
+----+----+----+----------+----------+
|col1|col2|rank|dense_rank|row_number|
+----+----+----+----------+----------+
| a| 10| 1| 1| 1|
| a| 10| 1| 1| 2|
| a| 20| 3| 2| 3|
+----+----+----+----------+----------+
Note that the value "10" exists twice in col2 within the same window (col1 = "a"). That's when you see a difference between the three functions.
I'm showing @Daniel's answer in Python, and I'm adding a comparison with count('*') that can be used if you want to get at most the top n rows per group.
from pyspark.sql.session import SparkSession
from pyspark.sql import Window
from pyspark.sql import functions as F
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    ['a', 10], ['a', 20], ['a', 30],
    ['a', 40], ['a', 40], ['a', 40], ['a', 40],
    ['a', 50], ['a', 50], ['a', 60]], ['part_col', 'order_col'])
window = Window.partitionBy("part_col").orderBy("order_col")
df = (df
      .withColumn("rank", F.rank().over(window))
      .withColumn("dense_rank", F.dense_rank().over(window))
      .withColumn("row_number", F.row_number().over(window))
      .withColumn("count", F.count('*').over(window))
      )
df.show()
+--------+---------+----+----------+----------+-----+
|part_col|order_col|rank|dense_rank|row_number|count|
+--------+---------+----+----------+----------+-----+
| a| 10| 1| 1| 1| 1|
| a| 20| 2| 2| 2| 2|
| a| 30| 3| 3| 3| 3|
| a| 40| 4| 4| 4| 7|
| a| 40| 4| 4| 5| 7|
| a| 40| 4| 4| 6| 7|
| a| 40| 4| 4| 7| 7|
| a| 50| 8| 5| 8| 9|
| a| 50| 8| 5| 9| 9|
| a| 60| 10| 6| 10| 10|
+--------+---------+----+----------+----------+-----+
For example, if you want to take at most 4 rows without randomly picking among the four rows with order_col = 40:
df.where("count <= 4").show()
+--------+---------+----+----------+----------+-----+
|part_col|order_col|rank|dense_rank|row_number|count|
+--------+---------+----+----------+----------+-----+
| a| 10| 1| 1| 1| 1|
| a| 20| 2| 2| 2| 2|
| a| 30| 3| 3| 3| 3|
+--------+---------+----+----------+----------+-----+
In summary, if you filter those columns with <= n you will get:
rank: at least n rows
dense_rank: at least n distinct order_col values
row_number: exactly n rows
count: at most n rows
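For instance, with the dataframe above and n = 4, the four filters keep different numbers of rows (a quick sketch using the columns computed earlier):
# rank <= 4 keeps 7 rows because of the tie at order_col = 40; row_number <= 4 keeps exactly 4
df.where("rank <= 4").count()        # 7
df.where("dense_rank <= 4").count()  # 7 rows, covering 4 distinct order_col values
df.where("row_number <= 4").count()  # 4
df.where("count <= 4").count()       # 3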
