Enumerate blocks of successively equal values in Spark - apache-spark

I want to find the IDs of groups (or blocks) of trues in a Spark DataFrame. That is, I want to go from this:
>>> df.show()
+---------+-----+
|timestamp| bool|
+---------+-----+
| 1|false|
| 2| true|
| 3| true|
| 4|false|
| 5| true|
| 6| true|
| 7| true|
| 8| true|
| 9|false|
| 10|false|
| 11|false|
| 12|false|
| 13|false|
| 14| true|
| 15| true|
| 16| true|
+---------+-----+
to this:
>>> df.show()
+---------+-----+-----+
|timestamp| bool|block|
+---------+-----+-----+
| 1|false| 0|
| 2| true| 1|
| 3| true| 1|
| 4|false| 0|
| 5| true| 2|
| 6| true| 2|
| 7| true| 2|
| 8| true| 2|
| 9|false| 0|
| 10|false| 0|
| 11|false| 0|
| 12|false| 0|
| 13|false| 0|
| 14| true| 3|
| 15| true| 3|
| 16| true| 3|
+---------+-----+-----+
(the zeros are optional, could be Null or -1 or whatever is easier to implement)

I have a solution in Scala; it should be easy to adapt to PySpark (a rough PySpark adaptation is sketched after the output below). Consider the following dataframe df:
+---------+-----+
|timestamp| bool|
+---------+-----+
| 1|false|
| 2| true|
| 3| true|
| 4|false|
| 5| true|
| 6| true|
| 7| true|
| 8| true|
| 9|false|
| 10|false|
| 11|false|
| 12|false|
| 13|false|
| 14| true|
| 15| true|
| 16| true|
+---------+-----+
then you could do:
import spark.implicits._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, sum, when}

df
  // bool value of the previous row
  .withColumn("prev_bool", lag($"bool", 1).over(Window.orderBy($"timestamp")))
  // running count of false -> true transitions numbers each block of trues
  .withColumn("block", sum(when(!$"prev_bool" and $"bool", 1).otherwise(0)).over(Window.orderBy($"timestamp")))
  .drop($"prev_bool")
  .withColumn("block", when($"bool", $"block").otherwise(0))
  .show()
+---------+-----+-----+
|timestamp| bool|block|
+---------+-----+-----+
| 1|false| 0|
| 2| true| 1|
| 3| true| 1|
| 4|false| 0|
| 5| true| 2|
| 6| true| 2|
| 7| true| 2|
| 8| true| 2|
| 9|false| 0|
| 10|false| 0|
| 11|false| 0|
| 12|false| 0|
| 13|false| 0|
| 14| true| 3|
| 15| true| 3|
| 16| true| 3|
+---------+-----+-----+
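For reference, a rough PySpark adaptation of the same approach (just a sketch, assuming the DataFrame is called df and has the timestamp and bool columns shown above):
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.orderBy("timestamp")

result = (
    df
    # bool value of the previous row
    .withColumn("prev_bool", F.lag("bool", 1).over(w))
    # running count of false -> true transitions numbers each block of trues
    .withColumn("block", F.sum(F.when(~F.col("prev_bool") & F.col("bool"), 1).otherwise(0)).over(w))
    .drop("prev_bool")
    # rows where bool is false keep block 0 (use None instead of 0 if you prefer nulls)
    .withColumn("block", F.when(F.col("bool"), F.col("block")).otherwise(0))
)
result.show()
Note that, like the Scala version, this would label a block starting on the very first row with 0, because lag returns null there; coalescing prev_bool to false avoids that edge case.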

Related

Pyspark update record based on last value using timestamp and column value

I'm struggling to figure this out. I need to find the last record with reason 'backfill' and update the non-backfill record with the greatest timestamp.
Here is what I've tried:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
w = Window.orderBy("idx")
w1 = Window.partitionBy('reason').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df_uahr.withColumn('idx', F.monotonically_increasing_id()) \
    .withColumn("app_data_new", F.last(F.lead("app_data").over(w)).over(w1)).orderBy("idx").show()
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
|upstart_application_id| reason| created_at| updated_at| app_data|idx|app_data_new|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
| 2|disqualified |2018-07-12 15:57:26|2018-07-12 15:57:26| app_data_a| 0| app_data_c|
| 2| backfill|2020-05-29 17:47:09|2021-05-29 17:47:09| app_data_c| 1| null|
| 2| backfill|2022-03-09 09:47:09|2022-03-09 09:47:09| app_data_d| 2| null|
| 2| test|2022-04-09 09:47:09|2022-04-09 09:47:09| app_data_e| 3| app_data_f|
| 2| test|2022-04-19 09:47:09|2022-04-19 09:47:09|app_data_e_a| 4| app_data_f|
| 2| backfill|2022-05-09 09:47:09|2022-05-09 09:47:09| app_data_f| 5| null|
| 2| after|2023-04-09 09:47:09|2023-04-09 09:47:09| app_data_g| 6| app_data_h|
| 2| backfill|2023-05-09 09:47:09|2023-05-09 09:47:09| app_data_h| 7| null|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
Expected value
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
|upstart_application_id| reason| created_at| updated_at| app_data|idx|app_data_new|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
| 2|disqualified |2018-07-12 15:57:26|2018-07-12 15:57:26| app_data_a| 0| app_data_d|
| 2| backfill|2020-05-29 17:47:09|2021-05-29 17:47:09| app_data_c| 1| null|
| 2| backfill|2022-03-09 09:47:09|2022-03-09 09:47:09| app_data_d| 2| null|
| 2| test|2022-04-09 09:47:09|2022-04-09 09:47:09| app_data_e| 3| null|
| 2| test|2022-04-19 09:47:09|2022-04-19 09:47:09|app_data_e_a| 4| app_data_f|
| 2| backfill|2022-05-09 09:47:09|2022-05-09 09:47:09| app_data_f| 5| null|
| 2| after|2023-04-09 09:47:09|2023-04-09 09:47:09| app_data_g| 6| app_data_h|
| 2| backfill|2023-05-09 09:47:09|2023-05-09 09:47:09| app_data_h| 7| null|
+----------------------+-------------+-------------------+-------------------+------------+---+------------+
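No answer is attached to this one here, but under one reading of the requirement (copy the app_data of the last row of each consecutive run of backfill rows onto the non-backfill row immediately before that run, which matches the expected output above), a window-based sketch along these lines could work. Column names follow the question, the idx column from monotonically_increasing_id() is assumed to be present already, and a real dataset would presumably also partition by upstart_application_id:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.orderBy("idx")

df2 = (
    df_uahr
    .withColumn("is_backfill", (F.col("reason") == "backfill").cast("int"))
    # a backfill row whose previous row is not backfill starts a new run
    .withColumn("run_start",
        (F.col("is_backfill") == 1) & (F.coalesce(F.lag("is_backfill").over(w), F.lit(0)) == 0))
    # running count of run starts gives every backfill run its own id
    .withColumn("run_id", F.sum(F.col("run_start").cast("int")).over(w))
    # app_data of the last backfill row within each run
    .withColumn("run_last_app_data",
        F.last(F.when(F.col("is_backfill") == 1, F.col("app_data")), True).over(
            Window.partitionBy("run_id").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))
    # the non-backfill row directly before a run picks up that run's last app_data
    .withColumn("app_data_new",
        F.when((F.col("is_backfill") == 0) & (F.lead("is_backfill").over(w) == 1),
               F.lead("run_last_app_data").over(w)))
    .drop("is_backfill", "run_start", "run_id", "run_last_app_data")
)
df2.orderBy("idx").show(truncate=False)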

Creating a custom counter in Spark based on dataframe conditions

Current Dataset
+---+-----+-----+-----+----+
| ID|Event|Index|start| end|
+---+-----+-----+-----+----+
| 1| run| 0|start|null|
| 1| run| 1| null|null|
| 1| run| 2| null|null|
| 1| swim| 3| null| end|
| 1| run| 4|start|null|
| 1| swim| 5| null|null|
| 1| swim| 6| null| end|
| 1| run| 7|start|null|
| 1| run| 8| null|null|
| 1| run| 9| null|null|
| 1| swim| 10| null| end|
| 1| run| 11|start|null|
| 1| run| 12| null|null|
| 1| run| 13| null| end|
| 2| run| 14|start|null|
| 2| run| 15| null|null|
| 2| run| 16| null|null|
| 2| swim| 17| null| end|
| 2| run| 18|start|null|
| 2| swim| 19| null|null|
| 2| swim| 20| null|null|
| 2| swim| 21| null|null|
| 2| swim| 22| null| end|
| 2| run| 23|start|null|
| 2| run| 24| null|null|
| 2| run| 25| null| end|
| 3| run| 26|start|null|
| 3| run| 27| null|null|
| 3| swim| 28| null|null|
+---+-----+-----+-----+----+
Dataset I'm After
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventID|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
I am trying to create the above EventID Column. Is there a way to create a counter inside of a udf that updates based on column conditions? Note, I'm not sure if a UDF is the best approach here.
Here is my current thinking-logic:
When a "start" value is seen, start counting.
When an "end" value is seen, end counting
Every time a new ID is seen, reset the counter to 1
Thank you all for any assistance.
Here is the raw code to produce the current dataframe:
# Current Dataset
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
data = [
(1, "run", 0, 'start', None),
(1, "run", 1, None, None),
(1, "run", 2, None, None),
(1, "swim", 3, None, 'end'),
(1, "run", 4, 'start',None),
(1, "swim", 5, None, None),
(1, "swim", 6, None, 'end'),
(1, "run",7, 'start', None),
(1, "run",8, None, None),
(1, "run",9, None, None),
(1, "swim",10, None, 'end'),
(1, "run",11, 'start', None),
(1, "run",12, None, None),
(1, "run",13, None, 'end'),
(2, "run",14, 'start', None),
(2, "run",15, None, None),
(2, "run",16, None, None),
(2, "swim",17, None, 'end'),
(2, "run",18, 'start', None),
(2, "swim",19, None, None),
(2, "swim",20, None, None),
(2, "swim",21, None, None),
(2, "swim",22, None, 'end'),
(2, "run",23, 'start', None),
(2, "run",24, None, None),
(2, "run",25, None, 'end'),
(3, "run",26, 'start', None),
(3, "run",27, None, None),
(3, "swim",28, None, None)
]
schema = StructType([
    StructField('ID', IntegerType(), True),
    StructField('Event', StringType(), True),
    StructField('Index', IntegerType(), True),
    StructField('start', StringType(), True),
    StructField('end', StringType(), True)
])
df = spark.createDataFrame(data=data, schema=schema)
df.show(30)
You can use a window function:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
w = Window.partitionBy('ID').orderBy('Index').rowsBetween(Window.unboundedPreceding, 0)
df.withColumn('EventId', F.sum(F.when(F.col('start') == 'start', 1).otherwise(0)).over(w)) \
    .orderBy('ID', 'Index').show(100)
results in
+---+-----+-----+-----+----+-------+
| ID|Event|Index|start| end|EventId|
+---+-----+-----+-----+----+-------+
| 1| run| 0|start|null| 1|
| 1| run| 1| null|null| 1|
| 1| run| 2| null|null| 1|
| 1| swim| 3| null| end| 1|
| 1| run| 4|start|null| 2|
| 1| swim| 5| null|null| 2|
| 1| swim| 6| null| end| 2|
| 1| run| 7|start|null| 3|
| 1| run| 8| null|null| 3|
| 1| run| 9| null|null| 3|
| 1| swim| 10| null| end| 3|
| 1| run| 11|start|null| 4|
| 1| run| 12| null|null| 4|
| 1| run| 13| null| end| 4|
| 2| run| 14|start|null| 1|
| 2| run| 15| null|null| 1|
| 2| run| 16| null|null| 1|
| 2| swim| 17| null| end| 1|
| 2| run| 18|start|null| 2|
| 2| swim| 19| null|null| 2|
| 2| swim| 20| null|null| 2|
| 2| swim| 21| null|null| 2|
| 2| swim| 22| null| end| 2|
| 2| run| 23|start|null| 3|
| 2| run| 24| null|null| 3|
| 2| run| 25| null| end| 3|
| 3| run| 26|start|null| 1|
| 3| run| 27| null|null| 1|
| 3| swim| 28| null|null| 1|
+---+-----+-----+-----+----+-------+
You can calculate dense_rank based on the Index of the most recent 'start':
from pyspark.sql import functions as F, Window
df2 = df.withColumn(
    'laststart',
    F.last(F.when(F.col('start') == 'start', F.col('Index')), True).over(Window.partitionBy('ID').orderBy('Index'))
).withColumn(
    'EventID',
    F.dense_rank().over(Window.partitionBy('ID').orderBy('laststart'))
)
df2.show(999)
+---+-----+-----+-----+----+---------+-------+
| ID|Event|Index|start| end|laststart|EventID|
+---+-----+-----+-----+----+---------+-------+
| 1| run| 0|start|null| 0| 1|
| 1| run| 1| null|null| 0| 1|
| 1| run| 2| null|null| 0| 1|
| 1| swim| 3| null| end| 0| 1|
| 1| run| 4|start|null| 4| 2|
| 1| swim| 5| null|null| 4| 2|
| 1| swim| 6| null| end| 4| 2|
| 1| run| 7|start|null| 7| 3|
| 1| run| 8| null|null| 7| 3|
| 1| run| 9| null|null| 7| 3|
| 1| swim| 10| null| end| 7| 3|
| 1| run| 11|start|null| 11| 4|
| 1| run| 12| null|null| 11| 4|
| 1| run| 13| null| end| 11| 4|
| 2| run| 14|start|null| 14| 1|
| 2| run| 15| null|null| 14| 1|
| 2| run| 16| null|null| 14| 1|
| 2| swim| 17| null| end| 14| 1|
| 2| run| 18|start|null| 18| 2|
| 2| swim| 19| null|null| 18| 2|
| 2| swim| 20| null|null| 18| 2|
| 2| swim| 21| null|null| 18| 2|
| 2| swim| 22| null| end| 18| 2|
| 2| run| 23|start|null| 23| 3|
| 2| run| 24| null|null| 23| 3|
| 2| run| 25| null| end| 23| 3|
| 3| run| 26|start|null| 26| 1|
| 3| run| 27| null|null| 26| 1|
| 3| swim| 28| null|null| 26| 1|
+---+-----+-----+-----+----+---------+-------+

Update column groupwise using multiple condition with window

I came across the window functions PySpark offers and they seem to be quite useful. Unfortunately, when trying to solve problems I often don't get them to work. Now I wonder whether my problem can be solved with window functions at all...
Here's my task:
Starting with a dataframe mockup like below:
values = [(0,"a",True,True),(1,"a",True,True),(2,"a",True,True),(3,"a",True,True),(4,"a",True,True),
(0,"b",False,True),(1,"b",True,True),(2,"b",True,True),(3,"b",False,True),(4,"b",True,True),
(0,"c",False,True),(1,"c",True,True),(2,"c",True,True),(3,"c",False,True),(4,"c",False,True)]
columns = ['index', 'name', 'Res','solution']
mockup= spark.createDataFrame(values, columns)
mockup.show()
+-----+----+-----+--------+
|index|name| Res|solution|
+-----+----+-----+--------+
| 0| a| true| true|
| 1| a| true| true|
| 2| a| true| true|
| 3| a| true| true|
| 4| a| true| true|
| 0| b|false| true|
| 1| b| true| true|
| 2| b| true| true|
| 3| b|false| true|
| 4| b| true| true|
| 0| c|false| true|
| 1| c| true| true|
| 2| c| true| true|
| 3| c|false| true|
| 4| c|false| true|
+-----+----+-----+--------+
I now want to update the solution column using multiple conditions.
If there are more than 2 false values per group (name), OR if there are exactly two false values in a group but none of them is at index = 0, the solution column should be false for the whole group; otherwise true.
See the desired outcome:
+-----+----+-----+--------+
|index|name| Res|solution|
+-----+----+-----+--------+
| 0| a| true| true|
| 1| a| true| true|
| 2| a| true| true|
| 3| a| true| true|
| 4| a| true| true|
| 0| b|false| true|
| 1| b| true| true|
| 2| b| true| true|
| 3| b|false| true|
| 4| b| true| true|
| 0| c|false| false|
| 1| c| true| false|
| 2| c| true| false|
| 3| c|false| false|
| 4| c|false| false|
+-----+----+-----+--------+
I managed to solve the problem with the solution below, but I hope there is a more elegant way to do this - maybe with windows. With window functions I always struggle with where to put the window and how to use it in a more complex "when" condition.
My not so great solution :0)
import pyspark.sql.functions as F

# count the False values in Res per name
df = mockup.filter(mockup.Res == False).groupby(mockup.name).count()
false_filter_1 = df.filter(F.col('count') > 2) \
    .select('name').collect()
false_filter_2 = df.filter(F.col('count') == 2) \
    .select('name').collect()
array_false_1 = [str(row['name']) for row in false_filter_1]
array_false_2 = [str(row['name']) for row in false_filter_2]
false_filter_3 = mockup.filter((mockup['index'] == 0) & (mockup['Res'] == False)) \
    .select('name').collect()
array_false_3 = [str(row['name']) for row in false_filter_3]
mockup = mockup.withColumn("over_2",
        F.when(F.col('name').isin(array_false_1), True).otherwise(False)) \
    .withColumn("eq_2",
        F.when(F.col('name').isin(array_false_2), True).otherwise(False)) \
    .withColumn("at0",
        F.when(F.col('name').isin(array_false_3), True).otherwise(False)) \
    .withColumn("solution",
        F.when(((F.col('eq_2') == True) & (F.col('at0') == True)) | ((F.col('over_2') == False) & (F.col('eq_2') == False)), True).otherwise(False)) \
    .drop('over_2') \
    .drop('eq_2') \
    .drop('at0')
mockup.show()
Here's my attempt at coding up your description. The output is different from your "expected" output because I guess you dealt with some logic incorrectly? b and c have the same pattern in your dataframe but somehow one of them is true and the other one is false.
from pyspark.sql import functions as F, Window
df2 = mockup.withColumn(
    'false_count',
    F.count(F.when(F.col('Res') == False, 1)).over(Window.partitionBy('name'))
).withColumn(
    'false_at_0',
    F.count(F.when((F.col('Res') == False) & (F.col('index') == 0), 1)).over(Window.partitionBy('name'))
).withColumn(
    'solution',
    ~((F.col('false_count') > 2) | ((F.col('false_count') == 2) & (F.col('false_at_0') != 1)))
)
df2.show()
+-----+----+-----+--------+-----------+----------+
|index|name| Res|solution|false_count|false_at_0|
+-----+----+-----+--------+-----------+----------+
| 0| c|false| true| 2| 1|
| 1| c| true| true| 2| 1|
| 2| c| true| true| 2| 1|
| 3| c|false| true| 2| 1|
| 4| c| true| true| 2| 1|
| 0| b|false| true| 2| 1|
| 1| b| true| true| 2| 1|
| 2| b| true| true| 2| 1|
| 3| b|false| true| 2| 1|
| 4| b| true| true| 2| 1|
| 0| a| true| true| 0| 0|
| 1| a| true| true| 0| 0|
| 2| a| true| true| 0| 0|
| 3| a| true| true| 0| 0|
| 4| a| true| true| 0| 0|
+-----+----+-----+--------+-----------+----------+
Another perhaps more useful example:
values = [(0,"a",True,True),(1,"a",True,True),(2,"a",True,True),(3,"a",True,True),(4,"a",True,True),
(0,"b",False,True),(1,"b",True,True),(2,"b",True,True),(3,"b",False,True),(4,"b",True,True),
(0,"c",True,True),(1,"c",False,True),(2,"c",True,True),(3,"c",False,True),(4,"c",True,True),
(0,"d",True,True),(1,"d",False,True),(2,"d",False,True),(3,"d",False,True),(4,"d",True,True)]
columns = ['index', 'name', 'Res','solution']
mockup= spark.createDataFrame(values, columns)
which, after being processed by the first code block above, gives
+-----+----+-----+--------+-----------+----------+
|index|name| Res|solution|false_count|false_at_0|
+-----+----+-----+--------+-----------+----------+
| 0| d| true| false| 3| 0|
| 1| d|false| false| 3| 0|
| 2| d|false| false| 3| 0|
| 3| d|false| false| 3| 0|
| 4| d| true| false| 3| 0|
| 0| c| true| false| 2| 0|
| 1| c|false| false| 2| 0|
| 2| c| true| false| 2| 0|
| 3| c|false| false| 2| 0|
| 4| c| true| false| 2| 0|
| 0| b|false| true| 2| 1|
| 1| b| true| true| 2| 1|
| 2| b| true| true| 2| 1|
| 3| b|false| true| 2| 1|
| 4| b| true| true| 2| 1|
| 0| a| true| true| 0| 0|
| 1| a| true| true| 0| 0|
| 2| a| true| true| 0| 0|
| 3| a| true| true| 0| 0|
| 4| a| true| true| 0| 0|
+-----+----+-----+--------+-----------+----------+

Check start, middle and end of groups in Spark

I have a Spark dataframe that looks like this:
+---+-----------+-------------------------+---------------+
| id| Phase | Switch | InputFileName |
+---+-----------+-------------------------+---------------+
| 1| 2| 1| fileA|
| 2| 2| 1| fileA|
| 3| 2| 1| fileA|
| 4| 2| 0| fileA|
| 5| 2| 0| fileA|
| 6| 2| 1| fileA|
| 11| 2| 1| fileB|
| 12| 2| 1| fileB|
| 13| 2| 0| fileB|
| 14| 2| 0| fileB|
| 15| 2| 1| fileB|
| 16| 2| 1| fileB|
| 21| 4| 1| fileB|
| 22| 4| 1| fileB|
| 23| 4| 1| fileB|
| 24| 4| 1| fileB|
| 25| 4| 1| fileB|
| 26| 4| 0| fileB|
| 31| 1| 0| fileC|
| 32| 1| 0| fileC|
| 33| 1| 0| fileC|
| 34| 1| 0| fileC|
| 35| 1| 0| fileC|
| 36| 1| 0| fileC|
+---+-----------+-------------------------+---------------+
For each group (a combination of InputFileName and Phase) I need to run a validation function which checks that Switch equals 1 at the very start and end of the group, and transitions to 0 at any point in-between. The function should add the validation result as a new column. The expected output is below: (gaps are just to highlight the different groups)
+---+-----------+-------------------------+---------------+--------+
| id| Phase | Switch | InputFileName | Valid |
+---+-----------+-------------------------+---------------+--------+
| 1| 2| 1| fileA| true |
| 2| 2| 1| fileA| true |
| 3| 2| 1| fileA| true |
| 4| 2| 0| fileA| true |
| 5| 2| 0| fileA| true |
| 6| 2| 1| fileA| true |
| 11| 2| 1| fileB| true |
| 12| 2| 1| fileB| true |
| 13| 2| 0| fileB| true |
| 14| 2| 0| fileB| true |
| 15| 2| 1| fileB| true |
| 16| 2| 1| fileB| true |
| 21| 4| 1| fileB| false|
| 22| 4| 1| fileB| false|
| 23| 4| 1| fileB| false|
| 24| 4| 1| fileB| false|
| 25| 4| 1| fileB| false|
| 26| 4| 0| fileB| false|
| 31| 1| 0| fileC| false|
| 32| 1| 0| fileC| false|
| 33| 1| 0| fileC| false|
| 34| 1| 0| fileC| false|
| 35| 1| 0| fileC| false|
| 36| 1| 0| fileC| false|
+---+-----------+-------------------------+---------------+--------+
I have previously solved this using Pyspark and a Pandas UDF:
df = df.groupBy("InputFileName", "Phase").apply(validate_profile)
#pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def validate_profile(df: pd.DataFrame):
first_valid = True if df["Switch"].iloc[0] == 1 else False
during_valid = (df["Switch"].iloc[1:-1] == 0).any()
last_valid = True if df["Switch"].iloc[-1] == 1 else False
df["Valid"] = first_valid & during_valid & last_valid
return df
However, now I need to rewrite this in Scala. I just want to know the best way of accomplishing this.
I'm currently trying window functions to get the first and last ids of each group:
val minIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy("id")
val maxIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy(col("id").desc)
I can then add the min and max ids as separate columns and use when to get the start and end values of Switch:
df.withColumn("MinId", min("id").over(minIdWindow))
.withColumn("MaxId", max("id").over(maxIdWindow))
.withColumn("Valid", when(
col("id") === col("MinId"), col("Switch")
).when(
col("id") === col("MaxId"), col("Switch")
))
This gets me the start and end values, but I'm not sure how to check if Switch equals 0 in between. Am I on the right track using window functions? Or would you recommend an alternative solution?
Try this,
val wind = Window.partitionBy("InputFileName", "Phase").orderBy("id")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

val df1 = df.withColumn("Valid",
  when(first("Switch").over(wind) === 1
    && last("Switch").over(wind) === 1
    && min("Switch").over(wind) === 0, true)
    .otherwise(false))

df1.orderBy("id").show() // ordering is for display purposes only
Output:
+---+-----+------+-------------+-----+
| id|Phase|Switch|InputFileName|Valid|
+---+-----+------+-------------+-----+
| 1| 2| 1| fileA| true|
| 2| 2| 1| fileA| true|
| 3| 2| 1| fileA| true|
| 4| 2| 0| fileA| true|
| 5| 2| 0| fileA| true|
| 6| 2| 1| fileA| true|
| 11| 2| 1| fileB| true|
| 12| 2| 1| fileB| true|
| 13| 2| 0| fileB| true|
| 14| 2| 0| fileB| true|
| 15| 2| 1| fileB| true|
| 16| 2| 1| fileB| true|
| 21| 4| 1| fileB|false|
| 22| 4| 1| fileB|false|
| 23| 4| 1| fileB|false|
| 24| 4| 1| fileB|false|
| 25| 4| 1| fileB|false|
| 26| 4| 0| fileB|false|
| 31| 1| 0| fileC|false|
| 32| 1| 0| fileC|false|
+---+-----+------+-------------+-----+
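As a side note on the PySpark version shown in the question: on Spark 3.x the GROUPED_MAP pandas_udf is deprecated in favor of applyInPandas, so the same grouped-map call would usually be written roughly like this (a sketch, assuming validate_profile is the plain, undecorated function and output_schema is a hypothetical schema with the input columns plus a boolean Valid column):
result = df.groupBy("InputFileName", "Phase").applyInPandas(validate_profile, schema=output_schema)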

How to use variable arguments _* in udf with Scala/Spark?

I have a dataframe where the number of columns is variable. Every column's type is Int and I want to get the sum of all columns. I thought of using :_*; this is my code:
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, udf}
import spark.implicits._

val arr = Array(1, 4, 3, 2, 5, 7, 3, 5, 4, 18)
val input = new ArrayBuffer[(Int, Int)]()
for (i <- 0 until 10) {
  input.append((i, arr(i % 10)))
}
var df = sc.parallelize(input, 3).toDF("value1", "value2")
val cols = new ArrayBuffer[Column]()
val colNames = df.columns
for (name <- colNames) {
  cols.append(col(name))
}
val func = udf((s: Int*) => s.sum) // this is the line that fails to compile
df.withColumn("sum", func(cols: _*)).show()
But I get an error:
Error:(101, 27) ')' expected but identifier found.
val func = udf((s: Int*) => s.sum)
How can I use :_* in a udf?
My expected result is:
+------+------+---+
|value1|value2|sum|
+------+------+---+
| 0| 1| 1|
| 1| 4| 5|
| 2| 3| 5|
| 3| 2| 5|
| 4| 5| 9|
| 5| 7| 12|
| 6| 3| 9|
| 7| 5| 12|
| 8| 4| 12|
| 9| 18| 27|
+------+------+---+
This may be what you expect:
val func = udf((s: Seq[Int]) => s.sum)
df.withColumn("sum", func(array(cols: _*))).show()
where array is org.apache.spark.sql.functions.array, which creates a new array column. The input columns must all have the same data type.
Spark UDFs do not support variable-length arguments.
Here is a solution for your problem.
import spark.implicits._

val input = Array(1, 4, 3, 2, 5, 7, 3, 5, 4, 18).zipWithIndex
var df = spark.sparkContext.parallelize(input, 3).toDF("value2", "value1")

// sum every column by folding them with + on Column
df.withColumn("total", df.columns.map(col(_)).reduce(_ + _)).show()
Output:
+------+------+-----+
|value2|value1|total|
+------+------+-----+
| 1| 0| 1|
| 4| 1| 5|
| 3| 2| 5|
| 2| 3| 5|
| 5| 4| 9|
| 7| 5| 12|
| 3| 6| 9|
| 5| 7| 12|
| 4| 8| 12|
| 18| 9| 27|
+------+------+-----+
Hope this helps
You can try VectorAssembler:
import org.apache.spark.ml.feature.VectorAssembler
import breeze.linalg.DenseVector

val assembler = new VectorAssembler().
  setInputCols(Array("your column name")).
  setOutputCol("allNum")

val assembledDF = assembler.transform(df)
assembledDF.show
+------+------+----------+
|value1|value2| allNum|
+------+------+----------+
| 0| 1| [0.0,1.0]|
| 1| 4| [1.0,4.0]|
| 2| 3| [2.0,3.0]|
| 3| 2| [3.0,2.0]|
| 4| 5| [4.0,5.0]|
| 5| 7| [5.0,7.0]|
| 6| 3| [6.0,3.0]|
| 7| 5| [7.0,5.0]|
| 8| 4| [8.0,4.0]|
| 9| 18|[9.0,18.0]|
+------+------+----------+
import org.apache.spark.ml.linalg.Vector

def yourSumUDF = udf((allNum: Vector) => new DenseVector(allNum.toArray).sum)
assembledDF.withColumn("sum", yourSumUDF($"allNum")).show
+------+------+----------+----+
|value1|value2| allNum| sum|
+------+------+----------+----+
| 0| 1| [0.0,1.0]| 1.0|
| 1| 4| [1.0,4.0]| 5.0|
| 2| 3| [2.0,3.0]| 5.0|
| 3| 2| [3.0,2.0]| 5.0|
| 4| 5| [4.0,5.0]| 9.0|
| 5| 7| [5.0,7.0]|12.0|
| 6| 3| [6.0,3.0]| 9.0|
| 7| 5| [7.0,5.0]|12.0|
| 8| 4| [8.0,4.0]|12.0|
| 9| 18|[9.0,18.0]|27.0|
+------+------+----------+----+
