Is there a way to replace null values in a PySpark DataFrame with the last valid value? There are additional timestamp and session columns, if you need them for window partitioning and ordering. More specifically, I'd like to achieve the following conversion:
+---------+-----------+-----------+ +---------+-----------+-----------+
| session | timestamp | id| | session | timestamp | id|
+---------+-----------+-----------+ +---------+-----------+-----------+
| 1| 1| null| | 1| 1| null|
| 1| 2| 109| | 1| 2| 109|
| 1| 3| null| | 1| 3| 109|
| 1| 4| null| | 1| 4| 109|
| 1| 5| 109| => | 1| 5| 109|
| 1| 6| null| | 1| 6| 109|
| 1| 7| 110| | 1| 7| 110|
| 1| 8| null| | 1| 8| 110|
| 1| 9| null| | 1| 9| 110|
| 1| 10| null| | 1| 10| 110|
+---------+-----------+-----------+ +---------+-----------+-----------+
This uses the window function last with ignorenulls set to True, so each row takes the last non-null id seen so far in its session.
Let's re-create something similar to the original data:
import sys
from pyspark.sql.window import Window
import pyspark.sql.functions as func
d = [{'session': 1, 'ts': 1}, {'session': 1, 'ts': 2, 'id': 109}, {'session': 1, 'ts': 3}, {'session': 1, 'ts': 4, 'id': 110}, {'session': 1, 'ts': 5}, {'session': 1, 'ts': 6}]
df = spark.createDataFrame(d)
df.show()
# +-------+---+----+
# |session| ts| id|
# +-------+---+----+
# | 1| 1|null|
# | 1| 2| 109|
# | 1| 3|null|
# | 1| 4| 110|
# | 1| 5|null|
# | 1| 6|null|
# +-------+---+----+
Now, let's use window function last:
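# Spark recommends the named boundaries (Window.unboundedPreceding, Window.currentRow)
# over raw integers such as -sys.maxsize
w = (Window.partitionBy('session')
     .orderBy('ts')
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))
df.withColumn("id", func.last('id', ignorenulls=True).over(w)).show()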
# +-------+---+----+
# |session| ts| id|
# +-------+---+----+
# | 1| 1|null|
# | 1| 2| 109|
# | 1| 3| 109|
# | 1| 4| 110|
# | 1| 5| 110|
# | 1| 6| 110|
# +-------+---+----+
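To make the semantics of last(..., ignorenulls=True) over a frame from unboundedPreceding to currentRow concrete, here is a plain-Python sketch (no Spark needed) that mimics it: rows are grouped by session, sorted by ts, and each row takes the most recent non-null id seen so far in its own session. The function name forward_fill and the list-of-dicts row format are illustrative assumptions, not part of any Spark API.

```python
from itertools import groupby
from operator import itemgetter
from typing import Any, Dict, List, Optional


def forward_fill(rows: List[Dict[str, Any]], part: str, order: str,
                 col: str) -> List[Dict[str, Any]]:
    """Hypothetical helper mimicking last(col, ignorenulls=True) over
    Window.partitionBy(part).orderBy(order)
          .rowsBetween(Window.unboundedPreceding, Window.currentRow)."""
    out = []
    # Sort by partition key, then by ordering key, as the window does.
    ordered = sorted(rows, key=itemgetter(part, order))
    for _, group in groupby(ordered, key=itemgetter(part)):
        last_seen: Optional[Any] = None  # reset at every partition boundary
        for row in group:
            if row.get(col) is not None:
                last_seen = row[col]
            out.append({**row, col: last_seen})
    return out


data = [{'session': 1, 'ts': 1}, {'session': 1, 'ts': 2, 'id': 109},
        {'session': 1, 'ts': 3}, {'session': 1, 'ts': 4, 'id': 110},
        {'session': 1, 'ts': 5}, {'session': 1, 'ts': 6}]
print([r['id'] for r in forward_fill(data, 'session', 'ts', 'id')])
# [None, 109, 109, 110, 110, 110]
```

The leading null stays null because nothing precedes it in its partition, which matches the Spark output above.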
This seems to be doing the trick using Window functions:
from pyspark.sql.window import Window
import pyspark.sql.functions as func

def fill_nulls(df):
    # Use -1 as a stand-in for null so the comparisons below are never null-valued
    df_na = df.na.fill(-1)
    # Previous id within the session (-1 for the first row)
    lag = df_na.withColumn('id_lag', func.lag('id', default=-1)
                           .over(Window.partitionBy('session')
                                 .orderBy('timestamp')))
    # 1 when a new non-null id starts, else 0
    switch = lag.withColumn('id_change',
                            ((lag['id'] != lag['id_lag']) &
                             (lag['id'] != -1)).cast('integer'))
    # Running sum of the change flags numbers the sub-sessions
    switch_sess = switch.withColumn(
        'sub_session',
        func.sum("id_change")
        .over(
            Window.partitionBy("session")
            .orderBy("timestamp")
            .rowsBetween(Window.unboundedPreceding, Window.currentRow))
    )
    # Every row of a sub-session takes that sub-session's first id
    fid = switch_sess.withColumn('nn_id',
                                 func.first('id')
                                 .over(Window.partitionBy('session', 'sub_session')
                                       .orderBy('timestamp')))
    # Turn the -1 placeholder back into a real null (not the string 'null')
    fid_na = fid.replace(-1, None)
    ff = fid_na.drop('id', 'id_lag', 'id_change', 'sub_session') \
        .withColumnRenamed('nn_id', 'id')
    return ff
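The idea behind fill_nulls is to start a new sub-session whenever a non-null id differs from the previous row's (null-filled) id, and then give every row the first id of its sub-session. A plain-Python sketch of those intermediate columns for a single session, assuming the rows are already sorted by timestamp and using -1 as the null placeholder as the answer does (the helper name sub_sessions is hypothetical):

```python
from typing import Dict, List, Optional, Tuple


def sub_sessions(ids: List[Optional[int]]) -> Tuple[List[int], List[Optional[int]]]:
    """Return (sub_session, filled id) per row of one timestamp-sorted session."""
    filled = [-1 if v is None else v for v in ids]     # df.na.fill(-1)
    sub: List[int] = []
    result: List[Optional[int]] = []
    running = 0
    first_in_sub: Dict[int, int] = {}
    for i, cur in enumerate(filled):
        lag = filled[i - 1] if i > 0 else -1           # lag('id', default=-1)
        change = int(cur != lag and cur != -1)         # id_change
        running += change                              # running sum -> sub_session
        sub.append(running)
        first_in_sub.setdefault(running, cur)          # first('id') per sub_session
        val = first_in_sub[running]
        result.append(None if val == -1 else val)      # replace(-1, None)
    return sub, result


sub, filled = sub_sessions([None, 109, None, None, 109, None, 110, None, None, None])
print(sub)     # [0, 1, 1, 1, 2, 2, 3, 3, 3, 3]
print(filled)  # [None, 109, 109, 109, 109, 109, 110, 110, 110, 110]
```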
Here is the full null_test.py.
@Oleksiy's answer is great, but it didn't fully work for my requirements. Within a session, if multiple nulls are observed, all are filled with the first non-null value of the session. I needed the last non-null value to propagate forward.
The following tweak worked for my use case:
import pyspark.sql.functions as func
from pyspark.sql.window import Window

def fill_forward(df, id_column, key_column, fill_column):
    # Fill nulls with the last *non-null* value in the window
    ff = df.withColumn(
        'fill_fwd',
        func.last(fill_column, ignorenulls=True)  # skip nulls: take the last non-null value
        .over(
            Window.partitionBy(id_column)
            .orderBy(key_column)
            .rowsBetween(Window.unboundedPreceding, Window.currentRow))
    )
    # Drop the old column and rename the new one
    ff_out = ff.drop(fill_column).withColumnRenamed('fill_fwd', fill_column)
    return ff_out
Another trick is to convert the PySpark DataFrame to a pandas DataFrame, use pandas' built-in forward fill to replace nulls with the previous known good value, and then convert the result back to a PySpark DataFrame. Note that toPandas() collects all rows to the driver, so this only suits data that fits in driver memory.
Here is the code:
import pandas as pd

d = [{'session': 1, 'ts': 1}, {'session': 1, 'ts': 2, 'id': 109}, {'session': 1, 'ts': 3}, {'session': 1, 'ts': 4, 'id': 110}, {'session': 1, 'ts': 5}, {'session': 1, 'ts': 6}, {'session': 1, 'ts': 7, 'id': 110}, {'session': 1, 'ts': 8}, {'session': 1, 'ts': 9}, {'session': 1, 'ts': 10}]
dt = spark.createDataFrame(d)
# Arrow speeds up the Spark <-> pandas conversion
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
psdf = dt.toPandas().sort_values(['session', 'ts'])
# Forward-fill within each session (fillna(method='ffill') is deprecated in recent pandas)
psdf['id'] = psdf.groupby('session')['id'].ffill()
dt = spark.createDataFrame(psdf)
dt.show()
I am trying to achieve the expected output shown here:
+---+-----+--------+--------+--------+----+
| ID|State| Time|Expected|lagState|rank|
+---+-----+--------+--------+--------+----+
| 1| P|20220722| 1| null| 1|
| 1| P|20220723| 2| P| 2|
| 1| P|20220724| 3| P| 3|
| 1| P|20220725| 4| P| 4|
| 1| D|20220726| 1| P| 1|
| 1| O|20220727| 1| D| 1|
| 1| D|20220728| 1| O| 1|
| 1| P|20220729| 2| D| 1|
| 1| P|20220730| 3| P| 9|
| 1| P|20220731| 4| P| 10|
+---+-----+--------+--------+--------+----+
from pyspark.sql import functions as F
from pyspark.sql import Window as w

# create df
df = spark.createDataFrame(sc.parallelize([
[1, 'P', 20220722, 1],
[1, 'P', 20220723, 2],
[1, 'P', 20220724, 3],
[1, 'P', 20220725, 4],
[1, 'D', 20220726, 1],
[1, 'O', 20220727, 1],
[1, 'D', 20220728, 1],
[1, 'P', 20220729, 2],
[1, 'P', 20220730, 3],
[1, 'P', 20220731, 4],
]),
['ID', 'State', 'Time', 'Expected'])
# lag
df = df.withColumn('lagState', F.lag('State').over(w.partitionBy('id').orderBy('time')))
# rn
df = df.withColumn('rank', F.when( F.col('State') == F.col('lagState'), F.rank().over(w.partitionBy('id').orderBy('time', 'state'))).otherwise(1))
# view
df.show()
The general problem is that the rank at the tail of the DataFrame does not reset to the expected value.
import sys
from pyspark.sql import functions as func
from pyspark.sql import Window as wd

data_sdf. \
withColumn('st_notsame',
func.coalesce(func.col('state') != func.lag('state').over(wd.partitionBy('id').orderBy('time')),
func.lit(False)).cast('int')
). \
withColumn('rank_temp',
func.sum('st_notsame').over(wd.partitionBy('id').orderBy('time').rowsBetween(-sys.maxsize, 0))
). \
withColumn('rank',
func.row_number().over(wd.partitionBy('id', 'rank_temp').orderBy('time'))
). \
show()
# +---+-----+--------+--------+----------+---------+----+
# | id|state| time|expected|st_notsame|rank_temp|rank|
# +---+-----+--------+--------+----------+---------+----+
# | 1| P|20220722| 1| 0| 0| 1|
# | 1| P|20220723| 2| 0| 0| 2|
# | 1| P|20220724| 3| 0| 0| 3|
# | 1| P|20220725| 4| 0| 0| 4|
# | 1| D|20220726| 1| 1| 1| 1|
# | 1| O|20220727| 1| 1| 2| 1|
# | 1| D|20220728| 1| 1| 3| 1|
# | 1| P|20220729| 2| 1| 4| 1|
# | 1| P|20220730| 3| 0| 4| 2|
# | 1| P|20220731| 4| 0| 4| 3|
# +---+-----+--------+--------+----------+---------+----+
Your expected field looks slightly incorrect: I believe the rank for "20220729" should be 1, since the state changes from D to P on that date.
1. First, flag each row whose state differs from the previous row's state as 1 and every other row (including the first one) as 0; this enables a running sum.
2. Use a sum window with unbounded lookback for each id to get a temporary rank that identifies each run of consecutive states.
3. Use the temporary rank as a partition column for row_number().
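The three steps above are the classic "gaps and islands" pattern. A plain-Python sketch for a single id, assuming the states are already sorted by time (the function name run_ranks is hypothetical; Spark does the same thing with lag, a running sum and row_number):

```python
from typing import Dict, List, Optional


def run_ranks(states: List[str]) -> List[int]:
    """Row number within each run of consecutive equal states."""
    ranks: List[int] = []
    rank_temp = 0                 # running sum of st_notsame
    prev: Optional[str] = None
    counts: Dict[int, int] = {}   # row_number() per rank_temp partition
    for i, s in enumerate(states):
        st_notsame = int(i > 0 and s != prev)   # coalesce(state != lag(state), False)
        rank_temp += st_notsame
        counts[rank_temp] = counts.get(rank_temp, 0) + 1
        ranks.append(counts[rank_temp])
        prev = s
    return ranks


print(run_ranks(list("PPPPDODPPP")))
# [1, 2, 3, 4, 1, 1, 1, 1, 2, 3]
```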
I'm new to Spark, and I am trying to calculate a running sum over a window that is floored at 0 and capped at 8.
A toy example is given below (note that the actual data is closer to millions of rows):
import pyspark.sql.functions as F
from pyspark.sql import Window
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
pdf = pd.DataFrame({'aIds': [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
                    'day': [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
                    'eCounts': [-3, 3, -6, 3, 3, 6, -3, -6, 3, 3, 3, -3]})
sdf = spark.createDataFrame(pdf)
sdf = sdf.orderBy(sdf.aIds, sdf.day)
This creates the table:
+----+---+-------+
|aIds|day|eCounts|
+----+---+-------+
| 1| 1| -3|
| 1| 2| 3|
| 1| 3| -6|
| 1| 4| 3|
| 2| 1| 3|
| 2| 2| 6|
| 2| 3| -3|
| 2| 4| -6|
| 3| 1| 3|
| 3| 2| 3|
| 3| 3| 3|
| 3| 4| -3|
+----+---+-------+
Below is an example of the result of a plain running sum, along with the expected output, runSumCap:
+----+---+-------+------+---------+
|aIds|day|eCounts|runSum|runSumCap|
+----+---+-------+------+---------+
| 1| 1| -3| -3| 0| <-- reset to 0
| 1| 2| 3| 0| 3|
| 1| 3| -6| -6| 0| <-- reset to 0
| 1| 4| 3| -3| 3|
| 2| 1| 3| 3| 3|
| 2| 2| 6| 9| 8| <-- reset to 8
| 2| 3| -3| 6| 5|
| 2| 4| -6| 0| 0| <-- reset to 0
| 3| 1| 3| 3| 3|
| 3| 2| 3| 6| 6|
| 3| 3| 3| 9| 8| <-- reset to 8
| 3| 4| -3| 6| 5|
+----+---+-------+------+---------+
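The expected runSumCap cannot be obtained by clamping the ordinary running sum after the fact, because each clamp changes the starting point for the next addition. A plain-Python sketch contrasting the two (FLOOR and CAP are the 0 and 8 from the question; the function names are illustrative):

```python
from itertools import accumulate
from typing import List

FLOOR, CAP = 0, 8


def clamp(x: int) -> int:
    return max(FLOOR, min(CAP, x))


def clamped_after(counts: List[int]) -> List[int]:
    """Clamp the ordinary running sum afterwards (e.g. a clamp on F.sum over the window)."""
    return [clamp(s) for s in accumulate(counts)]


def clamped_each_step(counts: List[int]) -> List[int]:
    """Clamp after every addition, so the clamped value carries forward."""
    out, total = [], 0
    for c in counts:
        total = clamp(total + c)
        out.append(total)
    return out


print(clamped_after([3, 6, -3, -6]))      # [3, 8, 6, 0]
print(clamped_each_step([3, 6, -3, -6]))  # [3, 8, 5, 0]
```

The two differ on day 3 of aIds 2 (6 versus the expected 5), which is why the capped sum has to be computed as a sequential scan rather than as a post-processing step on F.sum.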
I know I can calculate the running sum as:
partition = Window.partitionBy('aIds').orderBy('aIds','day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
sdf1 = sdf.withColumn('runSum',F.sum(sdf.eCounts).over(partition))
sdf1.orderBy('aIds','day').show()
To achieve the expected output, I have tried using a pandas_udf to modify the sum:
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def runSumCap(counts):
    # the counts column is passed as a pandas Series
    floor = 0
    cap = 8
    runSum = 0
    runSumList = []
    for count in counts.tolist():
        runSum = runSum + count
        if runSum > cap:
            runSum = cap
        elif runSum < floor:
            runSum = floor
        runSumList += [runSum]
    return pd.Series(runSumList)

partition = Window.partitionBy('aIds').orderBy('aIds','day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
sdf1 = sdf.withColumn('runSum', runSumCap(sdf['eCounts']).over(partition))
However, this does not work, and it does not seem like the most efficient way to do this.
How can I make this work? Is there a way to keep it parallel, or do I have to go to pandas DataFrames?
EDIT:
Added some clarifications about the columns available to order the dataset by, and more insight into what I am trying to achieve.
EDIT2:
The answer provided by @DrChess almost yields the correct result, but the series doesn't line up with the correct day for some reason:
+----+---+-------+------+
|aIds|day|eCounts|runSum|
+----+---+-------+------+
| 1| 1| -3| 0|
| 1| 2| 3| 0|
| 1| 3| -6| 3|
| 1| 4| 3| 3|
| 2| 1| 3| 3|
| 2| 2| 6| 8|
| 2| 3| -3| 0|
| 2| 4| -6| 5|
| 3| 1| 3| 6|
| 3| 2| 3| 3|
| 3| 3| 3| 8|
| 3| 4| -3| 5|
+----+---+-------+------+
I found a way to do this by first building, in each row, an array (using collect_list as a window function) containing the values that make up the running sum up to that point.
I then defined a UDF (I couldn't make this work with pandas_udf), and this worked.
Below is a full reproducible example:
import pyspark.sql.functions as F
from pyspark.sql import Window
import pandas as pd
from pyspark.sql.types import LongType

def accumulate(iterable):
    # Replay the prefix, clamping after every addition.
    # Each row replays its whole prefix, so work grows quadratically with group size.
    total = 0
    ceil = 8
    floor = 0
    for element in iterable:
        total = total + element
        if total > ceil:
            total = ceil
        elif total < floor:
            total = floor
    return total

pdf = pd.DataFrame({'aIds': [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
                    'day': [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
                    'eCounts': [-3, 3, -6, 3, 3, 6, -3, -6, 3, 3, 3, -3]})
sdf = spark.createDataFrame(pdf)
sdf = sdf.orderBy(sdf.aIds, sdf.day)
runSumCap = F.udf(accumulate, LongType())
partition = Window.partitionBy('aIds').orderBy('aIds', 'day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
# Each row gets the list of eCounts seen so far in its aIds partition
sdf1 = sdf.withColumn('splitWindow', F.collect_list(sdf.eCounts).over(partition))
sdf2 = sdf1.withColumn('runSumCap', runSumCap(sdf1.splitWindow))
sdf2.orderBy('aIds', 'day').show()
This yields the expected result:
+----+---+-------+--------------+---------+
|aIds|day|eCounts| splitWindow|runSumCap|
+----+---+-------+--------------+---------+
| 1| 1| -3| [-3]| 0|
| 1| 2| 3| [-3, 3]| 3|
| 1| 3| -6| [-3, 3, -6]| 0|
| 1| 4| 3|[-3, 3, -6, 3]| 3|
| 2| 1| 3| [3]| 3|
| 2| 2| 6| [3, 6]| 8|
| 2| 3| -3| [3, 6, -3]| 5|
| 2| 4| -6|[3, 6, -3, -6]| 0|
| 3| 1| 3| [3]| 3|
| 3| 2| 3| [3, 3]| 6|
| 3| 3| 3| [3, 3, 3]| 8|
| 3| 4| -3| [3, 3, 3, -3]| 5|
+----+---+-------+--------------+---------+
Unfortunately, at the time of writing, window functions with a pandas_udf of type GROUPED_AGG do not work with bounded window frames such as .rowsBetween(Window.unboundedPreceding, Window.currentRow). They only work with fully unbounded windows, namely .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing). Additionally, the input is a pandas.Series but the output must be a single scalar of the declared type, so you can't produce per-row partial aggregations that way.
Instead, you could use a GROUPED_MAP pandas_udf, which works with df.groupBy().apply().
Here is some code:
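import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('aIds long, day long, eCounts long, runSum long', PandasUDFType.GROUPED_MAP)
def runSumCap(pdf):
    def _apply_on_series(counts):
        floor = 0
        cap = 8
        runSum = 0
        runSumList = []
        for count in counts.tolist():
            runSum = runSum + count
            if runSum > cap:
                runSum = cap
            elif runSum < floor:
                runSum = floor
            runSumList += [runSum]
        # Reuse the input index so the result lines up with the sorted rows;
        # a fresh 0..n-1 index would misalign values after sort_values
        return pd.Series(runSumList, index=counts.index)

    pdf = pdf.sort_values(by=['day'])
    pdf['runSum'] = _apply_on_series(pdf['eCounts'])
    return pdf

# On Spark 3.x, groupBy(...).applyInPandas(func, schema) is the non-deprecated equivalent
sdf1 = sdf.groupBy('aIds').apply(runSumCap)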
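The per-group logic inside runSumCap is a sequential scan (a fold): each value depends on the clamped value before it, which is why it needs all of a group's rows in order and cannot be split into independent partial aggregates. A compact plain-Python sketch of the same group, sort and scan steps, using itertools.accumulate with initial (Python 3.8+); the run_sum_cap name and the tuple input format are illustrative assumptions:

```python
from itertools import accumulate
from typing import Dict, List, Tuple


def run_sum_cap(rows: List[Tuple[int, int, int]], floor: int = 0,
                cap: int = 8) -> Dict[int, List[int]]:
    """rows are (aIds, day, eCounts); returns the capped running sum per aIds in day order."""
    groups: Dict[int, List[Tuple[int, int]]] = {}
    for key, day, count in rows:                      # like groupBy('aIds')
        groups.setdefault(key, []).append((day, count))
    result = {}
    for key, pairs in groups.items():
        counts = [c for _, c in sorted(pairs)]        # like sort_values('day')
        scan = accumulate(counts,
                          lambda total, c: max(floor, min(cap, total + c)),
                          initial=0)
        result[key] = list(scan)[1:]                  # drop the initial 0
    return result


rows = list(zip([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
                [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
                [-3, 3, -6, 3, 3, 6, -3, -6, 3, 3, 3, -3]))
print(run_sum_cap(rows))
# {1: [0, 3, 0, 3], 2: [3, 8, 5, 0], 3: [3, 6, 8, 5]}
```

Groups remain independent of each other, so in Spark the parallelism comes from distributing different aIds groups to different tasks, while each group is scanned sequentially.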
This question already has answers here:
Spark Equivalent of IF Then ELSE
Given a pyspark dataframe, for example:
import pandas as pd

ls = [
['1', 2],
['2', 7],
['1', 3],
['2',-6],
['1', 3],
['1', 5],
['1', 4],
['2', 7]
]
df = spark.createDataFrame(pd.DataFrame(ls, columns=['col1', 'col2']))
df.show()
+----+-----+
|col1| col2|
+----+-----+
| 1| 2|
| 2| 7|
| 1| 3|
| 2| -6|
| 1| 3|
| 1| 5|
| 1| 4|
| 2| 7|
+----+-----+
How can I apply a function to the col2 values where col1 == '1' and store the result in a new column?
For example, the function is:
f = x**2
Result should look like this:
+----+-----+-----+
|col1| col2| y|
+----+-----+-----+
| 1| 2| 4|
| 2| 7| null|
| 1| 3| 9|
| 2| -6| null|
| 1| 3| 9|
| 1| 5| 25|
| 1| 4| 16|
| 2| 7| null|
+----+-----+-----+
I tried defining a separate function and using df.withColumn(y).when(condition, function), but it didn't work.
So what is a way to do this?
I hope this helps:
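from pyspark.sql.functions import when
from pyspark.sql.types import IntegerType

def myFun(x):
    return (x**2).cast(IntegerType())

# col1 holds strings, so compare with '1'; rows failing the condition get null
df2 = df.withColumn("y", when(df.col1 == '1', myFun(df.col2)).otherwise(None))
df2.show()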
+----+----+----+
|col1|col2| y|
+----+----+----+
| 1| 2| 4|
| 2| 7|null|
| 1| 3| 9|
| 2| -6|null|
| 1| 3| 9|
| 1| 5| 25|
| 1| 4| 16|
| 2| 7|null|
+----+----+----+
I have the following DataFrame, ordered by group, n1, n2:
+-----+--+--+------+------+
|group|n1|n2|n1_ptr|n2_ptr|
+-----+--+--+------+------+
| 1| 0| 0| 1| 1|
| 1| 1| 1| 2| 2|
| 1| 1| 5| 2| 6|
| 1| 2| 2| 3| 3|
| 1| 2| 6| 3| 7|
| 1| 3| 3| 4| 4|
| 1| 3| 7| null| null|
| 1| 4| 4| 5| 5|
| 1| 5| 1| null| null|
| 1| 5| 5| null| null|
+-----+--+--+------+------+
Each row's n1_ptr and n2_ptr values refer to the n1 and n2 values of some other row in the group that comes later in the ordering. In other words, n1_ptr and n2_ptr are effectively pointers to another row. I want to use these pointers to identify chains of (n1, n2) pairs. For example, the chains in the given data would be: (0,0) -> (1,1) -> (2,2) -> (3,3) -> (4,4) -> (5,5); (1,5) -> (2,6) -> (3,7); and (5,1).
The ultimate goal is to consolidate each chain into a single row in a DataFrame describing the min and max n1 and n2 values in each chain. Continuing the example, this would yield:
+-----+------+------+------+------+
|group|n1_min|n2_min|n1_max|n2_max|
+-----+------+------+------+------+
| 1| 0| 0| 5| 5|
| 1| 1| 5| 3| 7|
| 1| 5| 1| 5| 1|
+-----+------+------+------+------+
It seems like a udf might do the trick, but I am concerned about performance. Is there a more sensible/performant way to go about this?
A good solution would be to use graphframes: https://graphframes.github.io/quick-start.html.
First let's change the structure of your initial dataframe:
import pyspark.sql.functions as psf
df = sc.parallelize([[1, 0, 0, 1, 1],[1, 1, 1, 2, 2],[1, 1, 5, 2, 6],
[1, 2, 2, 3, 3],[1, 2, 6, 3, 7],[1, 3, 3, 4, 4],
[1, 3, 7, None, None],[1, 4, 4, 5, 5],[1, 5, 1, None, None],
[1, 5, 5, None, None]]).toDF(["group","n1","n2","n1_ptr","n2_ptr"]).filter("n1_ptr IS NOT NULL")
df = df.select(
"group",
psf.struct("n1", "n2").alias("src"),
psf.struct(df.n1_ptr.alias("n1"), df.n2_ptr.alias("n2")).alias("dst"))
From df we'll build a vertex and an edge dataframe:
v = df.select(
"group",
psf.explode(psf.array("src", "dst")).alias("id"))
e = df.drop("group")
The next step is to find all connected components using graphframes:
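from graphframes import GraphFrame

# connectedComponents() requires a Spark checkpoint directory; this path is only an example
sc.setCheckpointDir("/tmp/graphframes_checkpoints")
g = GraphFrame(v, e)
res = g.connectedComponents()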
+-----+-----+------------+
|group| id| component|
+-----+-----+------------+
| 1|[0,0]|309237645312|
| 1|[1,1]|309237645312|
| 1|[1,1]|309237645312|
| 1|[2,2]|309237645312|
| 1|[1,5]| 85899345920|
| 1|[2,6]| 85899345920|
| 1|[2,2]|309237645312|
| 1|[3,3]|309237645312|
| 1|[2,6]| 85899345920|
| 1|[3,7]| 85899345920|
| 1|[3,3]|309237645312|
| 1|[4,4]|309237645312|
| 1|[3,7]| 85899345920|
| 1|[4,4]|309237645312|
| 1|[5,5]|309237645312|
| 1|[5,1]|292057776128|
| 1|[5,5]|309237645312|
+-----+-----+------------+
Now, since the relation in your graph edges implies that the node numbers n1 and n2 are monotonically increasing along a chain, we can simply aggregate by component and compute the min and max:
res.groupBy("group", "component").agg(
    psf.min("id").alias("min_id"),
    psf.max("id").alias("max_id")
).show()
+-----+------------+------+------+
|group| component|min_id|max_id|
+-----+------------+------+------+
| 1|309237645312| [0,0]| [5,5]|
| 1| 85899345920| [1,5]| [3,7]|
| 1|292057776128| [5,1]| [5,1]|
+-----+------------+------+------+
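For moderately sized groups, the same consolidation can be sketched without GraphFrames as a union-find (disjoint-set) over the pointer edges, followed by a per-component min/max of n1 and n2. This plain-Python sketch is a swapped-in technique, not the answer's method: it builds vertices from every row, so a row that neither points nor is pointed to, like (5, 1), forms its own chain. The names find and consolidate are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

Node = Tuple[int, int, int]  # (group, n1, n2)
Row = Tuple[int, int, int, Optional[int], Optional[int]]


def find(parent: Dict[Node, Node], x: Node) -> Node:
    # Path halving keeps the trees shallow.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def consolidate(rows: List[Row]) -> List[Tuple[int, int, int, int, int]]:
    """Return (group, n1_min, n2_min, n1_max, n2_max) per chain."""
    parent: Dict[Node, Node] = {}
    for g, n1, n2, _, _ in rows:
        parent[(g, n1, n2)] = (g, n1, n2)
    for g, n1, n2, p1, p2 in rows:
        if p1 is not None and p2 is not None:
            dst = (g, p1, p2)                 # pointers stay inside the group
            parent.setdefault(dst, dst)
            ra, rb = find(parent, (g, n1, n2)), find(parent, dst)
            if ra != rb:
                parent[rb] = ra               # union the two chains
    comps: Dict[Node, List[Node]] = {}
    for node in parent:
        comps.setdefault(find(parent, node), []).append(node)
    out = []
    for members in comps.values():
        n1s = [m[1] for m in members]
        n2s = [m[2] for m in members]
        out.append((members[0][0], min(n1s), min(n2s), max(n1s), max(n2s)))
    return sorted(out)


rows = [(1, 0, 0, 1, 1), (1, 1, 1, 2, 2), (1, 1, 5, 2, 6), (1, 2, 2, 3, 3),
        (1, 2, 6, 3, 7), (1, 3, 3, 4, 4), (1, 3, 7, None, None), (1, 4, 4, 5, 5),
        (1, 5, 1, None, None), (1, 5, 5, None, None)]
print(consolidate(rows))
# [(1, 0, 0, 5, 5), (1, 1, 5, 3, 7), (1, 5, 1, 5, 1)]
```

This runs on a single machine; inside Spark it could only be applied per group (for example with a grouped pandas function), and only if each group fits in memory.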
I am trying to understand the difference between dense_rank and row_number. In each new window partition, both start from 1. Doesn't the rank of a row also always start from 1? Any help would be appreciated.
The difference is when there are "ties" in the ordering column. Check the example below:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // needed for Seq(...).toDF
val df = Seq(("a", 10), ("a", 10), ("a", 20)).toDF("col1", "col2")
val windowSpec = Window.partitionBy("col1").orderBy("col2")
df
.withColumn("rank", rank().over(windowSpec))
.withColumn("dense_rank", dense_rank().over(windowSpec))
.withColumn("row_number", row_number().over(windowSpec)).show
+----+----+----+----------+----------+
|col1|col2|rank|dense_rank|row_number|
+----+----+----+----------+----------+
| a| 10| 1| 1| 1|
| a| 10| 1| 1| 2|
| a| 20| 3| 2| 3|
+----+----+----+----------+----------+
Note that the value "10" exists twice in col2 within the same window (col1 = "a"). That's when you see a difference between the three functions.
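All three functions restart at 1 in every partition; they only diverge on ties. A plain-Python sketch of how each value is assigned within one partition, assuming the partition is already sorted by the ordering column (the ranking helper is illustrative, not a Spark API):

```python
from typing import List, Tuple


def ranking(values: List[int]) -> List[Tuple[int, int, int]]:
    """(rank, dense_rank, row_number) for each value of one partition, in sorted order."""
    ordered = sorted(values)
    out = []
    rank = dense = 0
    for i, v in enumerate(ordered, start=1):
        if i == 1 or v != ordered[i - 2]:   # a new distinct value starts here
            rank = i                        # rank jumps past the earlier ties
            dense += 1                      # dense_rank never leaves gaps
        out.append((rank, dense, i))        # row_number is just the position
    return out


print(ranking([10, 10, 20]))  # [(1, 1, 1), (1, 1, 2), (3, 2, 3)]
```

Note that among tied rows Spark's row_number order is not guaranteed unless the ordering is made unique, whereas here ties are simply numbered in list order.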
I'm showing @Daniel's answer in Python, and I'm adding a comparison with count('*'), which can be used if you want to get at most the top n rows per group.
from pyspark.sql.session import SparkSession
from pyspark.sql import Window
from pyspark.sql import functions as F
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
['a', 10], ['a', 20], ['a', 30],
['a', 40], ['a', 40], ['a', 40], ['a', 40],
['a', 50], ['a', 50], ['a', 60]], ['part_col', 'order_col'])
window = Window.partitionBy("part_col").orderBy("order_col")
df = (df
.withColumn("rank", F.rank().over(window))
.withColumn("dense_rank", F.dense_rank().over(window))
.withColumn("row_number", F.row_number().over(window))
.withColumn("count", F.count('*').over(window))
)
df.show()
+--------+---------+----+----------+----------+-----+
|part_col|order_col|rank|dense_rank|row_number|count|
+--------+---------+----+----------+----------+-----+
| a| 10| 1| 1| 1| 1|
| a| 20| 2| 2| 2| 2|
| a| 30| 3| 3| 3| 3|
| a| 40| 4| 4| 4| 7|
| a| 40| 4| 4| 5| 7|
| a| 40| 4| 4| 6| 7|
| a| 40| 4| 4| 7| 7|
| a| 50| 8| 5| 8| 9|
| a| 50| 8| 5| 9| 9|
| a| 60| 10| 6| 10| 10|
+--------+---------+----+----------+----------+-----+
For example, if you want to take at most 4 rows without arbitrarily picking one of the four rows with 40 in the sorting column:
df.where("count <= 4").show()
+--------+---------+----+----------+----------+-----+
|part_col|order_col|rank|dense_rank|row_number|count|
+--------+---------+----+----------+----------+-----+
| a| 10| 1| 1| 1| 1|
| a| 20| 2| 2| 2| 2|
| a| 30| 3| 3| 3| 3|
+--------+---------+----+----------+----------+-----+
In summary, if you filter those columns with <= n, you will get (assuming the partition has enough rows):
rank: at least n rows
dense_rank: all rows holding the n smallest distinct order_col values (so also at least n rows)
row_number: exactly n rows
count: at most n rows
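The count column behaves this way because, when a window has an orderBy but no explicit frame, Spark uses RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so every tied row sees all of its peers. A plain-Python sketch reproducing the four columns for the example partition and counting how many rows each filter keeps for n = 4 (the helper names are hypothetical):

```python
from typing import Dict, List


def window_columns(values: List[int]) -> List[Dict[str, int]]:
    ordered = sorted(values)
    rows, rank, dense = [], 0, 0
    for i, v in enumerate(ordered, start=1):
        if i == 1 or v != ordered[i - 2]:
            rank, dense = i, dense + 1
        # Default frame with orderBy is RANGE UNBOUNDED PRECEDING..CURRENT ROW,
        # so count('*') includes every row whose value is <= v (ties included)
        count = sum(1 for w in ordered if w <= v)
        rows.append({"order_col": v, "rank": rank, "dense_rank": dense,
                     "row_number": i, "count": count})
    return rows


def rows_kept(values: List[int], column: str, n: int) -> int:
    return sum(1 for r in window_columns(values) if r[column] <= n)


vals = [10, 20, 30, 40, 40, 40, 40, 50, 50, 60]
for col in ("rank", "dense_rank", "row_number", "count"):
    print(col, rows_kept(vals, col, 4))
# rank 7
# dense_rank 7
# row_number 4
# count 3
```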