window function on a subset of data - apache-spark

I have a table like the one below. I want to calculate the average of the median column, but only for Q=2 and Q=3. I don't want to include the other Qs, but I still want to preserve their rows.
df = spark.createDataFrame([
    ('2018-03-31', 6, 1), ('2018-03-31', 27, 2), ('2018-03-31', 3, 3), ('2018-03-31', 44, 4),
    ('2018-06-30', 6, 1), ('2018-06-30', 4, 3), ('2018-06-30', 32, 2), ('2018-06-30', 112, 4),
    ('2018-09-30', 2, 1), ('2018-09-30', 23, 4), ('2018-09-30', 37, 3), ('2018-09-30', 3, 2),
], ['date', 'median', 'Q'])
+----------+--------+---+
| date| median | Q |
+----------+--------+---+
|2018-03-31| 6| 1|
|2018-03-31| 27| 2|
|2018-03-31| 3| 3|
|2018-03-31| 44| 4|
|2018-06-30| 6| 1|
|2018-06-30| 4| 3|
|2018-06-30| 32| 2|
|2018-06-30| 112| 4|
|2018-09-30| 2| 1|
|2018-09-30| 23| 4|
|2018-09-30| 37| 3|
|2018-09-30| 3| 2|
+----------+--------+---+
Expected output:
+----------+--------+---+------------+
| date| median | Q |result |
+----------+--------+---+------------+
|2018-03-31| 6| 1| null|
|2018-03-31| 27| 2| 15|
|2018-03-31| 3| 3| 15|
|2018-03-31| 44| 4| null|
|2018-06-30| 6| 1| null|
|2018-06-30| 4| 3| 18|
|2018-06-30| 32| 2| 18|
|2018-06-30| 112| 4| null|
|2018-09-30| 2| 1| null|
|2018-09-30| 23| 4| null|
|2018-09-30| 37| 3| 20|
|2018-09-30| 3| 2| 20|
+----------+--------+---+------------+
OR
+----------+--------+---+------------+
| date| median | Q |result |
+----------+--------+---+------------+
|2018-03-31| 6| 1| 15|
|2018-03-31| 27| 2| 15|
|2018-03-31| 3| 3| 15|
|2018-03-31| 44| 4| 15|
|2018-06-30| 6| 1| 18|
|2018-06-30| 4| 3| 18|
|2018-06-30| 32| 2| 18|
|2018-06-30| 112| 4| 18|
|2018-09-30| 2| 1| 20|
|2018-09-30| 23| 4| 20|
|2018-09-30| 37| 3| 20|
|2018-09-30| 3| 2| 20|
+----------+--------+---+------------+
I tried the following code, but when I include the where clause it drops the Q=1 and Q=4 rows.
from pyspark.sql import Window
import pyspark.sql.functions as F

window = (
    Window
    .partitionBy("date")
    .orderBy("date")
)

df_avg = (
    df
    .where(
        (F.col("Q") == 2) |
        (F.col("Q") == 3)
    )
    .withColumn("result", F.avg("median").over(window))
)

For both of your expected outputs you can use conditional aggregation: avg combined with when (and otherwise).
If you want the first expected output:
window = (
    Window
    .partitionBy("date", F.col("Q").isin([2, 3]))
)

df_avg = (
    df.withColumn("result", F.when(F.col("Q").isin([2, 3]), F.avg("median").over(window)))
)
For the second expected output:
window = (
    Window
    .partitionBy("date")
)

df_avg = (
    df.withColumn("result", F.avg(F.when(F.col("Q").isin([2, 3]), F.col("median"))).over(window))
)

Alternatively, since you are really aggregating a (small?) subset, replace the window with an aggregate-then-join:
>>> from pyspark.sql.functions import avg, col
>>> df_avg = df.where(col("Q").isin([2, 3])).groupBy("date", "Q").agg(avg("median").alias("result"))
>>> df_result = df.join(df_avg, ["date", "Q"], "left")
This might turn out to be faster than using a window.
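If you want the second expected output with this join approach instead (the average filled in on every row), a small variation of the same idea should work: aggregate per date only and join back on date alone. A minimal, untested sketch:
from pyspark.sql import functions as F

df_avg = (
    df.where(F.col("Q").isin([2, 3]))   # keep only the quarters that feed the average
      .groupBy("date")
      .agg(F.avg("median").alias("result"))
)
# left join on "date" only, so every Q row of the original df gets the group average
df_result = df.join(df_avg, ["date"], "left")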

Related

subtract column values over a window function

This is probably very simple, but for some reason I can't figure it out. I have a df that looks like the one below. I need to create a column where I subtract the 2&3avg column from the median column when Q = 1. I assume I need to use a window function and then create a UDF to subtract the columns? Or am I overcomplicating this?
+----------+--------+---+------------+--------+
| date| median | Q | 2&3avg | result |
+----------+--------+---+------------+--------+
|2018-03-31| 6| 1| 15| -9|
|2018-03-31| 27| 2| 15| -9|
|2018-03-31| 3| 3| 15| -9|
|2018-03-31| 44| 4| 15| -9|
|2018-06-30| 6| 1| 18| -12|
|2018-06-30| 4| 3| 18| -12|
|2018-06-30| 32| 2| 18| -12|
|2018-06-30| 112| 4| 18| -12|
|2018-09-30| 2| 1| 20| -18|
|2018-09-30| 23| 4| 20| -18|
|2018-09-30| 37| 3| 20| -18|
|2018-09-30| 3| 2| 20| -18|
+----------+--------+---+------------+--------+
You can indeed calculate the median using a window function. To avoid using a UDF, you can use expr to get the percentile. Then you can use a simple when to do the calculation only where Q = 1:
from pyspark.sql import Window as W
import pyspark.sql.functions as F

w = W.partitionBy('Q')
percentile = F.expr('percentile_approx(2and3avg, 0.5)')

df\
    .withColumn('med_val', percentile.over(w))\
    .withColumn('new_col', F.when(F.col('Q') == 1, F.col('med_val') - F.col('median'))
                            .otherwise(F.lit('N/A')))\
    .orderBy('date', 'Q')\
    .show()
+----------+------+---+--------+------+-------+-------+-------+
| date|median| Q|2and3avg|result|2result|med_val|new_col|
+----------+------+---+--------+------+-------+-------+-------+
|2018-03-31| 6| 1| 15| -9| -9| 18| 12|
|2018-03-31| 27| 2| 15| -9| -9| 18| N/A|
|2018-03-31| 3| 3| 15| -9| -9| 18| N/A|
|2018-03-31| 44| 4| 15| -9| -9| 18| N/A|
|2018-06-30| 6| 1| 18| -12| -12| 18| 12|
|2018-06-30| 32| 2| 18| -12| -12| 18| N/A|
|2018-06-30| 4| 3| 18| -12| -12| 18| N/A|
|2018-06-30| 112| 4| 18| -12| -12| 18| N/A|
|2018-09-30| 2| 1| 20| -18| -18| 18| 16|
|2018-09-30| 3| 2| 20| -18| -18| 18| N/A|
|2018-09-30| 37| 3| 20| -18| -18| 18| N/A|
|2018-09-30| 23| 4| 20| -18| -18| 18| N/A|
+----------+------+---+--------+------+-------+-------+-------+
You can use PySpark's when/otherwise functions.
Since you have to subtract only where Q == 1, for any other value the result will be null:
>>> from pyspark.sql.functions import when, col
>>> df.withColumn("result", when(col('Q') == 1, (df['median'] - df['2_3avg'])).otherwise("nulll")).show()
+----------+------+---+------+------+
| date|median| Q|2_3avg|result|
+----------+------+---+------+------+
|2018-03-31| 6| 1| 15| -9|
|2018-03-31| 27| 2| 15| nulll|
|2018-03-31| 3| 3| 15| nulll|
|2018-03-31| 44| 4| 15| nulll|
|2018-06-30| 6| 1| 18| -12|
|2018-06-30| 4| 3| 18| nulll|
|2018-06-30| 32| 2| 18| nulll|
|2018-06-30| 112| 4| 18| nulll|
|2018-09-30| 2| 1| 20| -18|
|2018-09-30| 23| 4| 20| nulll|
|2018-09-30| 37| 3| 20| nulll|
|2018-09-30| 3| 2| 20| nulll|
+----------+------+---+------+------+
I think this answers your question:
>>> import pyspark.sql.functions as F
>>> df1 = df.withColumn("result", when(col('Q') == 1, (df['median'] - df['2_3avg'])).otherwise("nulll"))
>>> df1.groupby("date").agg(F.collect_list("median"), F.collect_list("Q"), F.collect_list("2_3avg"), F.collect_list("result")).show()
+----------+--------------------+---------------+--------------------+--------------------+
| date|collect_list(median)|collect_list(Q)|collect_list(2_3avg)|collect_list(result)|
+----------+--------------------+---------------+--------------------+--------------------+
|2018-06-30| [6, 4, 32, 112]| [1, 3, 2, 4]| [18, 18, 18, 18]|[-12, nulll, null...|
|2018-03-31| [6, 27, 3, 44]| [1, 2, 3, 4]| [15, 15, 15, 15]|[-9, nulll, nulll...|
|2018-09-30| [2, 23, 37, 3]| [1, 4, 3, 2]| [20, 20, 20, 20]|[-18, nulll, null...|
+----------+--------------------+---------------+--------------------+--------------------+
Adding to the answer by Sachin, assuming Q = 1 is the smallest Q in each group:
from pyspark.sql import Window
from pyspark.sql.functions import when, col, first

dql.withColumn("result", when(col('Q') == 1, (dql['median'] - dql['2&3avg'])).otherwise("nulll"))\
   .withColumn("result", first("result", True).over(Window.partitionBy("date").orderBy("Q")))\
   .show()
#output
+----------+------+---+------+------+
| date|median| Q|2&3avg|result|
+----------+------+---+------+------+
|2018-03-31| 6| 1| 15| -9.0|
|2018-03-31| 27| 2| 15| -9.0|
|2018-03-31| 3| 3| 15| -9.0|
|2018-03-31| 44| 4| 15| -9.0|
|2018-06-30| 6| 1| 18| -12.0|
|2018-06-30| 32| 2| 18| -12.0|
|2018-06-30| 4| 3| 18| -12.0|
|2018-06-30| 112| 4| 18| -12.0|
|2018-09-30| 2| 1| 20| -18.0|
|2018-09-30| 3| 2| 20| -18.0|
|2018-09-30| 37| 3| 20| -18.0|
|2018-09-30| 23| 4| 20| -18.0|
+----------+------+---+------+------+
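Another option that avoids relying on Q = 1 sorting first: compute the Q == 1 difference once per date and broadcast it to every row of the group with a plain partition window. A sketch, assuming the column is named 2_3avg as in Sachin's answer:
from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.partitionBy("date")
df_result = df.withColumn(
    "result",
    # max() ignores nulls, so it picks up the single Q == 1 difference per date
    F.max(F.when(F.col("Q") == 1, F.col("median") - F.col("2_3avg"))).over(w)
)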

Creating a custom counter in Spark based on dataframe conditions

Current Dataset
+---+-----+-----+-----+----+
| ID|Event|Index|start| end|
+---+-----+-----+-----+----+
| 1| run| 0|start|null|
| 1| run| 1| null|null|
| 1| run| 2| null|null|
| 1| swim| 3| null| end|
| 1| run| 4|start|null|
| 1| swim| 5| null|null|
| 1| swim| 6| null| end|
| 1| run| 7|start|null|
| 1| run| 8| null|null|
| 1| run| 9| null|null|
| 1| swim| 10| null| end|
| 1| run| 11|start|null|
| 1| run| 12| null|null|
| 1| run| 13| null| end|
| 2| run| 14|start|null|
| 2| run| 15| null|null|
| 2| run| 16| null|null|
| 2| swim| 17| null| end|
| 2| run| 18|start|null|
| 2| swim| 19| null|null|
| 2| swim| 20| null|null|
| 2| swim| 21| null|null|
| 2| swim| 22| null| end|
| 2| run| 23|start|null|
| 2| run| 24| null|null|
| 2| run| 25| null| end|
| 3| run| 26|start|null|
| 3| run| 27| null|null|
| 3| swim| 28| null|null|
+---+-----+-----+-----+----+
Dataset I'm After
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventID|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
I am trying to create the above EventID column. Is there a way to create a counter inside a UDF that updates based on column conditions? Note, I'm not sure if a UDF is the best approach here.
Here is my current thinking-logic:
When a "start" value is seen, start counting.
When an "end" value is seen, end counting
Every time a new ID is seen, reset the counter to 1
Thank you all for any assistance.
Here is the raw code to produce the current dataframe:
# Current Dataset
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

data = [
(1, "run", 0, 'start', None),
(1, "run", 1, None, None),
(1, "run", 2, None, None),
(1, "swim", 3, None, 'end'),
(1, "run", 4, 'start',None),
(1, "swim", 5, None, None),
(1, "swim", 6, None, 'end'),
(1, "run",7, 'start', None),
(1, "run",8, None, None),
(1, "run",9, None, None),
(1, "swim",10, None, 'end'),
(1, "run",11, 'start', None),
(1, "run",12, None, None),
(1, "run",13, None, 'end'),
(2, "run",14, 'start', None),
(2, "run",15, None, None),
(2, "run",16, None, None),
(2, "swim",17, None, 'end'),
(2, "run",18, 'start', None),
(2, "swim",19, None, None),
(2, "swim",20, None, None),
(2, "swim",21, None, None),
(2, "swim",22, None, 'end'),
(2, "run",23, 'start', None),
(2, "run",24, None, None),
(2, "run",25, None, 'end'),
(3, "run",26, 'start', None),
(3, "run",27, None, None),
(3, "swim",28, None, None)
]
schema = StructType([
    StructField('ID', IntegerType(), True),
    StructField('Event', StringType(), True),
    StructField('Index', IntegerType(), True),
    StructField('start', StringType(), True),
    StructField('end', StringType(), True)
])
df = spark.createDataFrame(data=data, schema=schema)
df.show(30)
You can use a window function:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
w = Window.partitionBy('ID').orderBy('Index').rowsBetween(Window.unboundedPreceding, 0)

df.withColumn('EventId', F.sum(F.when(F.col('start') == 'start', 1).otherwise(0)).over(w))\
  .orderBy('ID', 'Index').show(100)
results in
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventId|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
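If you prefer Spark SQL, the same cumulative sum over a when/case can be written as below. This is just a sketch of the equivalent query; the temp-view name events is made up for illustration:
df.createOrReplaceTempView("events")
spark.sql("""
    SELECT *,
           -- running count of 'start' markers per ID, ordered by Index
           SUM(CASE WHEN `start` = 'start' THEN 1 ELSE 0 END)
               OVER (PARTITION BY ID ORDER BY Index) AS EventID
    FROM events
""").orderBy("ID", "Index").show(30)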
You can calculate dense_rank based on the most recent start time:
from pyspark.sql import functions as F, Window
df2 = df.withColumn(
    'laststart',
    F.last(F.when(F.col('start') == 'start', F.col('Index')), True).over(Window.partitionBy('ID').orderBy('Index'))
).withColumn(
    'EventID',
    F.dense_rank().over(Window.partitionBy('ID').orderBy('laststart'))
)
df2.show(999)
+---+-----+-----+-----+----+---------+-------+
| ID|Event|Index|start| end|laststart|EventID|
+---+-----+-----+-----+----+---------+-------+
| 1| run| 0|start|null| 0| 1|
| 1| run| 1| null|null| 0| 1|
| 1| run| 2| null|null| 0| 1|
| 1| swim| 3| null| end| 0| 1|
| 1| run| 4|start|null| 4| 2|
| 1| swim| 5| null|null| 4| 2|
| 1| swim| 6| null| end| 4| 2|
| 1| run| 7|start|null| 7| 3|
| 1| run| 8| null|null| 7| 3|
| 1| run| 9| null|null| 7| 3|
| 1| swim| 10| null| end| 7| 3|
| 1| run| 11|start|null| 11| 4|
| 1| run| 12| null|null| 11| 4|
| 1| run| 13| null| end| 11| 4|
| 2| run| 14|start|null| 14| 1|
| 2| run| 15| null|null| 14| 1|
| 2| run| 16| null|null| 14| 1|
| 2| swim| 17| null| end| 14| 1|
| 2| run| 18|start|null| 18| 2|
| 2| swim| 19| null|null| 18| 2|
| 2| swim| 20| null|null| 18| 2|
| 2| swim| 21| null|null| 18| 2|
| 2| swim| 22| null| end| 18| 2|
| 2| run| 23|start|null| 23| 3|
| 2| run| 24| null|null| 23| 3|
| 2| run| 25| null| end| 23| 3|
| 3| run| 26|start|null| 26| 1|
| 3| run| 27| null|null| 26| 1|
| 3| swim| 28| null|null| 26| 1|
+---+-----+-----+-----+----+---------+-------+

Aggregating two columns with Pyspark

Learning Apache Spark through PySpark and having issues.
I have the following DF:
+----------+------------+-----------+----------------+
| game_id|posteam_type|total_plays|total_touchdowns|
+----------+------------+-----------+----------------+
|2009092003| home| 90| 3|
|2010091912| home| 95| 0|
|2010112106| home| 75| 0|
|2010121213| home| 85| 3|
|2009092011| null| 9| null|
|2010110703| null| 2| null|
|2010112111| null| 6| null|
|2011100909| home| 102| 3|
|2011120800| home| 72| 2|
|2012010110| home| 74| 6|
|2012110410| home| 68| 1|
|2012120911| away| 91| 2|
|2011103008| null| 6| null|
|2012111100| null| 3| null|
|2013092212| home| 86| 6|
|2013112407| home| 73| 4|
|2013120106| home| 99| 3|
|2014090705| home| 94| 3|
|2014101203| home| 77| 4|
|2014102611| home| 107| 6|
+----------+------------+-----------+----------------+
I'm attempting to find the average number of plays it takes to score a TD, i.e. sum(total_plays)/sum(total_touchdowns).
I figured out the code to get the sums but can't figure out how to get the total average:
plays = nfl_game_play.groupBy().agg({'total_plays': 'sum'}).collect()
touchdowns = nfl_game_play.groupBy().agg({'total_touchdowns': 'sum'}).collect()
As you can see, I tried storing each as a variable, but I can't see how to combine them beyond just remembering each value and doing the division manually.
Try the code below:
Example:
df.show()
#+-----------+----------------+
#|total_plays|total_touchdowns|
#+-----------+----------------+
#| 90| 3|
#| 95| 0|
#| 9| null|
#+-----------+----------------+
from pyspark.sql.functions import *
total_avg=df.groupBy().agg(sum("total_plays")/sum("total_touchdowns")).collect()[0][0]
#64.66666666666667
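A roughly equivalent sketch using explicit column functions and an alias (column names assumed from the example above). Note that sum() ignores nulls, so rows with a null total_touchdowns contribute nothing to the denominator, while their plays still count in the numerator:
from pyspark.sql import functions as F

total_avg = (
    df.agg((F.sum("total_plays") / F.sum("total_touchdowns")).alias("plays_per_td"))
      .first()["plays_per_td"]
)
# total_avg -> 64.66666666666667 for the three example rows above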

how to split one column and keep other columns in pyspark dataframe?

I have data like this:
>>> data = sc.parallelize([[1,5,10,0,[1,2,3,4,5,6]],[0,10,20,1,[2,3,4,5,6,7]],[1,15,25,0,[3,4,5,6,7,8]],[0,30,40,1,[4,5,6,7,8,9]]]).toDF(('a','b','c',"d","e"))
>>> data.show()
+---+---+---+---+------------------+
| a| b| c| d| e|
+---+---+---+---+------------------+
| 1| 5| 10| 0|[1, 2, 3, 4, 5, 6]|
| 0| 10| 20| 1|[2, 3, 4, 5, 6, 7]|
| 1| 15| 25| 0|[3, 4, 5, 6, 7, 8]|
| 0| 30| 40| 1|[4, 5, 6, 7, 8, 9]|
+---+---+---+---+------------------+
# columns that should be kept in the result
keep_cols = ["a","b"]
# column 'e' should be split into split_e_cols
split_e_cols = ["one","two","three","four","five","six"]
# I hope the result dataframe has keep_cols + split_e_cols
I want to split column e into multiple columns and keep columns a and b at the same time.
I have tried:
data.select(*(col("e").getItem(i).alias(split_e_cols[i]) for i in range(len(len(split_e_cols)))))
and
data.select("e").rdd.flatMap(lambda x:x).toDF(split_e_cols)
neither can keep columns a and b.
Could anyone help me? Thanks.
Try this:
select_cols = [col(c) for c in keep_cols] + [col("e").getItem(i).alias(split_e_cols[i]) for i in range(len(split_e_cols))]
data.select(*select_cols).show()
#+---+---+---+---+-----+----+----+---+
#| a| b|one|two|three|four|five|six|
#+---+---+---+---+-----+----+----+---+
#| 1| 5| 1| 2| 3| 4| 5| 6|
#| 0| 10| 2| 3| 4| 5| 6| 7|
#| 1| 15| 3| 4| 5| 6| 7| 8|
#| 0| 30| 4| 5| 6| 7| 8| 9|
#+---+---+---+---+-----+----+----+---+
Or using a for loop and withColumn:
data = data.select(keep_cols + ["e"])
for i in range(len(split_e_cols)):
    data = data.withColumn(split_e_cols[i], col("e").getItem(i))
data.drop("e").show()
You can concatenate the lists using +:
from pyspark.sql.functions import col
data.select(
keep_cols +
[col("e").getItem(i).alias(split_e_cols[i]) for i in range(len(split_e_cols))]
).show()
+---+---+---+---+-----+----+----+---+
| a| b|one|two|three|four|five|six|
+---+---+---+---+-----+----+----+---+
| 1| 5| 1| 2| 3| 4| 5| 6|
| 0| 10| 2| 3| 4| 5| 6| 7|
| 1| 15| 3| 4| 5| 6| 7| 8|
| 0| 30| 4| 5| 6| 7| 8| 9|
+---+---+---+---+-----+----+----+---+
A more pythonic way is to use enumerate instead of range(len()):
from pyspark.sql.functions import col
data.select(
keep_cols +
[col("e").getItem(i).alias(c) for (i, c) in enumerate(split_e_cols)]
).show()
+---+---+---+---+-----+----+----+---+
| a| b|one|two|three|four|five|six|
+---+---+---+---+-----+----+----+---+
| 1| 5| 1| 2| 3| 4| 5| 6|
| 0| 10| 2| 3| 4| 5| 6| 7|
| 1| 15| 3| 4| 5| 6| 7| 8|
| 0| 30| 4| 5| 6| 7| 8| 9|
+---+---+---+---+-----+----+----+---+

Check start, middle and end of groups in Spark

I have a Spark dataframe that looks like this:
+---+-----------+-------------------------+---------------+
| id| Phase | Switch | InputFileName |
+---+-----------+-------------------------+---------------+
| 1| 2| 1| fileA|
| 2| 2| 1| fileA|
| 3| 2| 1| fileA|
| 4| 2| 0| fileA|
| 5| 2| 0| fileA|
| 6| 2| 1| fileA|
| 11| 2| 1| fileB|
| 12| 2| 1| fileB|
| 13| 2| 0| fileB|
| 14| 2| 0| fileB|
| 15| 2| 1| fileB|
| 16| 2| 1| fileB|
| 21| 4| 1| fileB|
| 22| 4| 1| fileB|
| 23| 4| 1| fileB|
| 24| 4| 1| fileB|
| 25| 4| 1| fileB|
| 26| 4| 0| fileB|
| 31| 1| 0| fileC|
| 32| 1| 0| fileC|
| 33| 1| 0| fileC|
| 34| 1| 0| fileC|
| 35| 1| 0| fileC|
| 36| 1| 0| fileC|
+---+-----------+-------------------------+---------------+
For each group (a combination of InputFileName and Phase) I need to run a validation function which checks that Switch equals 1 at the very start and end of the group, and transitions to 0 at any point in-between. The function should add the validation result as a new column. The expected output is below: (gaps are just to highlight the different groups)
+---+-----------+-------------------------+---------------+--------+
| id| Phase | Switch | InputFileName | Valid |
+---+-----------+-------------------------+---------------+--------+
| 1| 2| 1| fileA| true |
| 2| 2| 1| fileA| true |
| 3| 2| 1| fileA| true |
| 4| 2| 0| fileA| true |
| 5| 2| 0| fileA| true |
| 6| 2| 1| fileA| true |
| 11| 2| 1| fileB| true |
| 12| 2| 1| fileB| true |
| 13| 2| 0| fileB| true |
| 14| 2| 0| fileB| true |
| 15| 2| 1| fileB| true |
| 16| 2| 1| fileB| true |
| 21| 4| 1| fileB| false|
| 22| 4| 1| fileB| false|
| 23| 4| 1| fileB| false|
| 24| 4| 1| fileB| false|
| 25| 4| 1| fileB| false|
| 26| 4| 0| fileB| false|
| 31| 1| 0| fileC| false|
| 32| 1| 0| fileC| false|
| 33| 1| 0| fileC| false|
| 34| 1| 0| fileC| false|
| 35| 1| 0| fileC| false|
| 36| 1| 0| fileC| false|
+---+-----------+-------------------------+---------------+--------+
I have previously solved this using PySpark and a pandas UDF:
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)  # schema is defined elsewhere and omitted here
def validate_profile(df: pd.DataFrame):
    first_valid = True if df["Switch"].iloc[0] == 1 else False
    during_valid = (df["Switch"].iloc[1:-1] == 0).any()
    last_valid = True if df["Switch"].iloc[-1] == 1 else False
    df["Valid"] = first_valid & during_valid & last_valid
    return df

df = df.groupBy("InputFileName", "Phase").apply(validate_profile)
However, now I need to rewrite this in Scala. I just want to know the best way of accomplishing this.
I'm currently trying window functions to get the first and last ids of each group:
val minIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy("id")
val maxIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy(col("id").desc)
I can then add the min and max ids as separate columns and use when to get the start and end values of Switch:
df.withColumn("MinId", min("id").over(minIdWindow))
.withColumn("MaxId", max("id").over(maxIdWindow))
.withColumn("Valid", when(
col("id") === col("MinId"), col("Switch")
).when(
col("id") === col("MaxId"), col("Switch")
))
This gets me the start and end values, but I'm not sure how to check if Switch equals 0 in between. Am I on the right track using window functions? Or would you recommend an alternative solution?
Try this,
val wind = Window.partitionBy("InputFileName", "Phase").orderBy("id")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

val df1 = df.withColumn("Valid",
  when(first("Switch").over(wind) === 1
    && last("Switch").over(wind) === 1
    && min("Switch").over(wind) === 0, true)
    .otherwise(false))

df1.orderBy("id").show() // Ordering for display purposes only
Output:
+---+-----+------+-------------+-----+
| id|Phase|Switch|InputFileName|Valid|
+---+-----+------+-------------+-----+
| 1| 2| 1| fileA| true|
| 2| 2| 1| fileA| true|
| 3| 2| 1| fileA| true|
| 4| 2| 0| fileA| true|
| 5| 2| 0| fileA| true|
| 6| 2| 1| fileA| true|
| 11| 2| 1| fileB| true|
| 12| 2| 1| fileB| true|
| 13| 2| 0| fileB| true|
| 14| 2| 0| fileB| true|
| 15| 2| 1| fileB| true|
| 16| 2| 1| fileB| true|
| 21| 4| 1| fileB|false|
| 22| 4| 1| fileB|false|
| 23| 4| 1| fileB|false|
| 24| 4| 1| fileB|false|
| 25| 4| 1| fileB|false|
| 26| 4| 0| fileB|false|
| 31| 1| 0| fileC|false|
| 32| 1| 0| fileC|false|
+---+-----+------+-------------+-----+
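For reference, the same first/last/min window logic translated back into PySpark. A minimal sketch assuming the original column names; the accepted approach above is the Scala version:
from pyspark.sql import Window
from pyspark.sql import functions as F

w = (
    Window.partitionBy("InputFileName", "Phase")
          .orderBy("id")
          .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

df1 = df.withColumn(
    "Valid",
    # Switch must be 1 on the first and last row of the group, and hit 0 somewhere in between
    (F.first("Switch").over(w) == 1)
    & (F.last("Switch").over(w) == 1)
    & (F.min("Switch").over(w) == 0),
)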
