I would like to interpolate time series data. The challenge is to interpolate only if the time interval between the surrounding existing values does not exceed a specified limit.
Input data
from pyspark.sql import SparkSession
spark = SparkSession.builder.config("spark.driver.memory", "60g").getOrCreate()
df = spark.createDataFrame([{'timestamp': 1642205833225, 'value': 58.00},
{'timestamp': 1642205888654, 'value': float('nan')},
{'timestamp': 1642205899657, 'value': float('nan')},
{'timestamp': 1642205892970, 'value': 55.00},
{'timestamp': 1642206338180, 'value': float('nan')},
{'timestamp': 1642206353652, 'value': 56.45},
{'timestamp': 1642206853451, 'value': float('nan')},
{'timestamp': 1642207353652, 'value': 80.45}
])
df.show()
+-------------+-----+
| timestamp|value|
+-------------+-----+
|1642205833225| 58.0|
|1642205888654| NaN|
|1642205899657| NaN|
|1642205892970| 55.0|
|1642206338180| NaN|
|1642206353652|56.45|
|1642206853451| NaN|
|1642207353652|80.45|
+-------------+-----+
First, I want to calculate the time gap from each existing value to the next existing value (next_timestamp - current_timestamp).
+-------------+-----+---------------+
| timestamp|value|timegap_to_next|
+-------------+-----+---------------+
|1642205833225| 58.0| 59745|
|1642205888654| NaN| NaN|
|1642205899657| NaN| NaN|
|1642205892970| 55.0| 460682|
|1642206338180| NaN| NaN|
|1642206353652|56.45| 1030300|
|1642206853451| NaN| NaN|
|1642207383952|80.45| NaN|
+-------------+-----+---------------+
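To make the definition of timegap_to_next concrete, here is a plain-Python sketch (no Spark; the function name is only illustrative). It processes the rows in the order listed, treats NaN as a missing value, and uses the timestamps from the table above:

```python
import math

def timegap_to_next(rows):
    """For each (timestamp, value) row, return the gap to the next row that has a value.

    Rows are processed in the order given; NaN marks a missing value.
    Missing rows, and the last existing value, get None.
    """
    gaps = [None] * len(rows)
    next_ts = None  # timestamp of the nearest existing value after the current row
    for i in range(len(rows) - 1, -1, -1):
        ts, value = rows[i]
        if not math.isnan(value):
            if next_ts is not None:
                gaps[i] = next_ts - ts
            next_ts = ts
    return gaps

nan = float("nan")
rows = [
    (1642205833225, 58.00), (1642205888654, nan), (1642205899657, nan),
    (1642205892970, 55.00), (1642206338180, nan), (1642206353652, 56.45),
    (1642206853451, nan), (1642207383952, 80.45),
]
print(timegap_to_next(rows))
# [59745, None, None, 460682, None, 1030300, None, None]
```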
Based on the calculated time gap, the interpolation should only be done where the gap is within a threshold, in this case 500000.
Final output:
+-------------+-----+---------------+
| timestamp|value|timegap_to_next|
+-------------+-----+---------------+
|1642205833225| 58.0| 59745|
|1642205888654| 57.0| NaN|
|1642205899657| 56.0| NaN|
|1642205892970| 55.0| 460682|
|1642206338180|55.75| NaN|
|1642206353652|56.45| 1030300|
|1642206853451| NaN| NaN|
|1642207383952|80.45| NaN|
+-------------+-----+---------------+
Can anybody help me with this special case? That would be very nice!
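Before turning to Spark, it may help to pin the rule down in plain Python. The sketch below assumes rows sorted by timestamp, missing values stored as None, and a strict `<` comparison against the threshold; interpolate_with_gap_limit is a hypothetical name. A missing value is filled only when it has an existing value on both sides and those two values are less than max_gap apart:

```python
def interpolate_with_gap_limit(rows, max_gap):
    """Linearly interpolate missing values (None) by timestamp.

    rows: (timestamp, value) pairs sorted by timestamp.
    A missing value is filled only if the surrounding existing values are
    less than max_gap apart; otherwise it stays None.
    """
    n = len(rows)
    prev, nxt = [None] * n, [None] * n  # nearest existing (timestamp, value) before / after
    last = None
    for i in range(n):
        prev[i] = last
        if rows[i][1] is not None:
            last = rows[i]
    last = None
    for i in range(n - 1, -1, -1):
        nxt[i] = last
        if rows[i][1] is not None:
            last = rows[i]

    result = []
    for i, (ts, value) in enumerate(rows):
        if value is not None or prev[i] is None or nxt[i] is None:
            result.append(value)  # existing value, or no neighbour on one side
            continue
        (t0, v0), (t1, v1) = prev[i], nxt[i]
        if t1 - t0 < max_gap:
            result.append(v0 + (v1 - v0) / (t1 - t0) * (ts - t0))
        else:
            result.append(None)  # gap too large: leave the value missing
    return result

rows = [(0, None), (10, 10.0), (20, None), (30, 30.0), (40, None),
        (50, None), (60, 15.0), (70, None), (80, 25.0), (90, None)]
print(interpolate_with_gap_limit(rows, max_gap=30))
# [None, 10.0, 20.0, 30.0, None, None, 15.0, 20.0, 25.0, None]
```

With max_gap=500000 and the question's timestamps sorted, the row between 56.45 and 80.45 stays empty, because those two values are 1000000 apart.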
Given this input DataFrame:
df = spark.createDataFrame([
(1642205833225, 58.00), (1642205888654, float('nan')),
(1642205899657, float('nan')), (1642205899970, 55.00),
(1642206338180, float('nan')), (1642206353652, 56.45),
(1642206853451, float('nan')), (1642207353652, 80.45)
], ["timestamp", "value"])
# replace NaN value by Nulls
df = df.replace(float("nan"), None, ["value"])
You can use the window functions first and last (with ignorenulls=True) to get the next and previous non-null values for each row, and calculate the time gap like this:
from pyspark.sql import functions as F, Window
w1 = Window.orderBy("timestamp").rowsBetween(1, Window.unboundedFollowing)
w2 = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, -1)
df = (
    df.withColumn("next_val", F.first("value", ignorenulls=True).over(w1))
    .withColumn("next_timestamp", F.first(F.when(F.col("value").isNotNull(), F.col("timestamp")), ignorenulls=True).over(w1))
    .withColumn("prev_val", F.last("value", ignorenulls=True).over(w2))
    .withColumn("prev_timestamp", F.last(F.when(F.col("value").isNotNull(), F.col("timestamp")), ignorenulls=True).over(w2))
    # gap from each existing value to the next existing one (null for missing rows)
    .withColumn("timegap_to_next", F.when(F.col("value").isNotNull(), F.col("next_timestamp") - F.col("timestamp")))
)
Now you can do the linear interpolation of the value column, depending on your threshold, using a when expression:
# w3 reaches back to the last existing value at or before the current row
w3 = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, Window.currentRow)
df = df.withColumn(
    "value",
    F.coalesce(
        "value",
        F.when(
            F.last("timegap_to_next", ignorenulls=True).over(w3) < 500000,
            F.round(
                F.col("prev_val")
                + (F.col("next_val") - F.col("prev_val"))
                / (F.col("next_timestamp") - F.col("prev_timestamp"))
                * (F.col("timestamp") - F.col("prev_timestamp")),
                2,
            ),
        ),
    ),
).select("timestamp", "value", "timegap_to_next")
df.show()
#+-------------+-----+---------------+
#| timestamp|value|timegap_to_next|
#+-------------+-----+---------------+
#|1642205833225| 58.0| 66745|
#|1642205888654|55.51| null|
#|1642205899657|55.01| null|
#|1642205899970| 55.0| 453682|
#|1642206338180| 56.4| null|
#|1642206353652|56.45| 1000000|
#|1642206853451| null| null|
#|1642207353652|80.45| null|
#+-------------+-----+---------------+
I was looking for a solution to this problem too, and I simplified the approach a bit (window functions are used only in the first step), in case anyone else finds it useful.
Input dataframe:
df = spark.createDataFrame([
(0, None),
(10, 10.00),
(20, None), # 20
(30, 30.00),
(40, None), # 25
(50, None), # 20
(60, 15.00),
(70, None), # 20
(80, 25.00),
(90, None)
], ["timestamp", "value"])
Compute previous/next timestamp and values using window functions:
from pyspark.sql import functions as F, Window
w1 = Window.orderBy("timestamp").rowsBetween(1, Window.unboundedFollowing)
w2 = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, -1)
df = (
df.withColumn("next_val", F.first("value", ignorenulls=True).over(w1))
.withColumn("prev_val", F.last("value", ignorenulls=True).over(w2))
.withColumn("next_timestamp", F.first(F.when(F.col("value").isNotNull(), F.col("timestamp")), ignorenulls=True).over(w1))
.withColumn("prev_timestamp", F.last(F.when(F.col("value").isNotNull(), F.col("timestamp")), ignorenulls=True).over(w2))
)
df.show()
+---------+-----+--------+--------+--------------+--------------+
|timestamp|value|next_val|prev_val|next_timestamp|prev_timestamp|
+---------+-----+--------+--------+--------------+--------------+
| 0| null| 10.0| null| 10| null|
| 10| 10.0| 30.0| null| 30| null|
| 20| null| 30.0| 10.0| 30| 10|
| 30| 30.0| 15.0| 10.0| 60| 10|
| 40| null| 15.0| 30.0| 60| 30|
| 50| null| 15.0| 30.0| 60| 30|
| 60| 15.0| 25.0| 30.0| 80| 30|
| 70| null| 25.0| 15.0| 80| 60|
| 80| 25.0| null| 15.0| null| 60|
| 90| null| null| 25.0| null| 80|
+---------+-----+--------+--------+--------------+--------------+
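The four window columns can be reproduced in plain Python with one forward and one backward pass, which may help when checking the frames: F.last(..., ignorenulls=True) over (unboundedPreceding, -1) carries the last non-null value seen strictly before the row, and F.first(..., ignorenulls=True) over (1, unboundedFollowing) the first non-null value strictly after it. A sketch with a hypothetical function name, returning tuples in the column order of the table above:

```python
def prev_next_columns(rows):
    """rows: (timestamp, value) pairs sorted by timestamp; None marks a missing value.
    Returns (timestamp, value, next_val, prev_val, next_timestamp, prev_timestamp) tuples."""
    n = len(rows)
    prev = [(None, None)] * n  # (prev_val, prev_timestamp)
    nxt = [(None, None)] * n   # (next_val, next_timestamp)
    carry = (None, None)
    for i, (ts, value) in enumerate(rows):  # forward pass, strictly before row i
        prev[i] = carry
        if value is not None:
            carry = (value, ts)
    carry = (None, None)
    for i in range(n - 1, -1, -1):  # backward pass, strictly after row i
        ts, value = rows[i]
        nxt[i] = carry
        if value is not None:
            carry = (value, ts)
    return [(ts, value, nxt[i][0], prev[i][0], nxt[i][1], prev[i][1])
            for i, (ts, value) in enumerate(rows)]

rows = [(0, None), (10, 10.0), (20, None), (30, 30.0), (40, None),
        (50, None), (60, 15.0), (70, None), (80, 25.0), (90, None)]
for row in prev_next_columns(rows):
    print(row)
```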
Interpolate conditionally on the length of the gaps: in this case, the time gap between the two surrounding existing values must be less than 30.
df = (df.withColumn("timegap", F.when(F.col("value").isNull(), F.col("next_timestamp")-F.col("prev_timestamp")))
.withColumn("new_value",
F.when(
(F.col("value").isNull()) & (F.col('timegap')<30),
F.round(F.col('prev_val') + (F.col('next_val')-F.col('prev_val')) / (F.col('next_timestamp')-F.col('prev_timestamp')) * (F.col('timestamp')-F.col('prev_timestamp')), 2 )
).otherwise(F.col('value')))
)
df.select('timestamp','value', 'timegap', 'new_value').show()
+---------+-----+-------+---------+
|timestamp|value|timegap|new_value|
+---------+-----+-------+---------+
| 0| null| null| null|
| 10| 10.0| null| 10.0|
| 20| null| 20| 20.0|
| 30| 30.0| null| 30.0|
| 40| null| 30| null|
| 50| null| 30| null|
| 60| 15.0| null| 15.0|
| 70| null| 20| 20.0|
| 80| 25.0| null| 25.0|
| 90| null| null| null|
+---------+-----+-------+---------+
I have a DataFrame where I need to convert rows of the same group to columns, basically pivot them. Below is my df:
+------------+-------+-----+-------+
|Customer |ID |unit |order |
+------------+-------+-----+-------+
|John |123 |00015|1 |
|John |123 |00016|2 |
|John |345 |00205|3 |
|John |345 |00206|4 |
|John |789 |00283|5 |
|John |789 |00284|6 |
+------------+-------+-----+-------+
I need the result for the above to look like this:
+--------+-------+--------+----------+--------+--------+-----------+--------+-------+----------+
|state | ID_1 | unit_1 |seq_num_1 | ID_2 | unit_2 | seq_num_2 | ID_3 |unit_3 |seq_num_3 |
+--------+-------+--------+----------+--------+--------+-----------+--------+-------+----------+
|John | 123 | 00015 | 1 | 345 | 00205 | 3 | 789 |00283 | 5 |
|John | 123 | 00016 | 2 | 345 | 00206 | 4 | 789 |00284 | 6 |
+--------+-------+--------+----------+--------+--------+-----------+--------+-------+----------+
I tried groupBy with the pivot() function, but it throws an error saying large pivot values were found. Is there any way to get the result without using pivot()? Any help is greatly appreciated.
Thanks.
This looks like a typical case for the dense_rank() window function: use it to create a generic sequence (dr in the code below) of distinct IDs within each Customer group, then pivot on this sequence. We can do something similar for the order column using row_number(), so that it can be used in the groupby:
from pyspark.sql import Window, functions as F
# below I added an extra row for a reference when the number of rows vary for different IDs
df = spark.createDataFrame([
('John', '123', '00015', '1'), ('John', '123', '00016', '2'), ('John', '345', '00205', '3'),
('John', '345', '00206', '4'), ('John', '789', '00283', '5'), ('John', '789', '00284', '6'),
('John', '789', '00285', '7')
], ['Customer', 'ID', 'unit', 'order'])
Add two Window Specs: w1 to get dense_rank() of IDs over Customer and w2 to get row_number() of order under the same Customer and ID.
w1 = Window.partitionBy('Customer').orderBy('ID')
w2 = Window.partitionBy('Customer','ID').orderBy('order')
Add two new columns based on the two window specs above: dr (dense_rank) and sid (row_number):
df1 = df.select(
"*",
F.dense_rank().over(w1).alias('dr'),
F.row_number().over(w2).alias('sid')
)
+--------+---+-----+-----+---+---+
|Customer| ID| unit|order| dr|sid|
+--------+---+-----+-----+---+---+
| John|123|00015| 1| 1| 1|
| John|123|00016| 2| 1| 2|
| John|345|00205| 3| 2| 1|
| John|345|00206| 4| 2| 2|
| John|789|00283| 5| 3| 1|
| John|789|00284| 6| 3| 2|
| John|789|00285| 7| 3| 3|
+--------+---+-----+-----+---+---+
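To check what dr and sid mean, the two window functions can be mimicked in plain Python with itertools.groupby: rows are sorted by (Customer, ID, order), dr counts distinct IDs within a Customer, and sid restarts at 1 for each ID. This is a sketch with a hypothetical function name, and it relies on the order values sorting correctly as strings, which holds for the single-digit values here:

```python
from itertools import groupby

def add_dr_sid(rows):
    """rows: (Customer, ID, unit, order) tuples. Returns the tuples with dr and sid appended.

    dr  mirrors dense_rank() over partitionBy('Customer').orderBy('ID')
    sid mirrors row_number() over partitionBy('Customer', 'ID').orderBy('order')
    """
    result = []
    rows = sorted(rows, key=lambda r: (r[0], r[1], r[3]))
    for _, cust_rows in groupby(rows, key=lambda r: r[0]):
        for dr, (_, id_rows) in enumerate(groupby(cust_rows, key=lambda r: r[1]), start=1):
            for sid, row in enumerate(id_rows, start=1):
                result.append(row + (dr, sid))
    return result

rows = [
    ('John', '123', '00015', '1'), ('John', '123', '00016', '2'), ('John', '345', '00205', '3'),
    ('John', '345', '00206', '4'), ('John', '789', '00283', '5'), ('John', '789', '00284', '6'),
    ('John', '789', '00285', '7'),
]
for row in add_dr_sid(rows):
    print(row)
```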
Find max(dr), so that we can predefine the list of values to pivot on, range(1, N+1). Passing the values explicitly is more efficient, because pivot() otherwise has to compute the distinct values of dr first:
N = df1.agg(F.max('dr')).first()[0]
Group by Customer and sid, pivot on dr, and then aggregate:
df_new = df1.groupby('Customer','sid') \
.pivot('dr', range(1,N+1)) \
.agg(
F.first('ID').alias('ID'),
F.first('unit').alias('unit'),
F.first('order').alias('order')
)
df_new.show()
+--------+---+----+------+-------+----+------+-------+----+------+-------+
|Customer|sid|1_ID|1_unit|1_order|2_ID|2_unit|2_order|3_ID|3_unit|3_order|
+--------+---+----+------+-------+----+------+-------+----+------+-------+
| John| 1| 123| 00015| 1| 345| 00205| 3| 789| 00283| 5|
| John| 2| 123| 00016| 2| 345| 00206| 4| 789| 00284| 6|
| John| 3|null| null| null|null| null| null| 789| 00285| 7|
+--------+---+----+------+-------+----+------+-------+----+------+-------+
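The groupby/pivot/first step can likewise be sketched in plain Python: rows are keyed by (Customer, sid), each dr fills one block of ID/unit/order columns, and slots with no row stay None, which is where the nulls in the sid = 3 row come from. The input tuples are copied from the dr/sid table; the function name is hypothetical:

```python
def pivot_by_dr(rows, n):
    """rows: (Customer, sid, dr, ID, unit, order) tuples.

    Groups by (Customer, sid) and spreads ID/unit/order into one block of columns
    per dr in range(1, n + 1), keeping the first value seen (like first()).
    Missing blocks are filled with None.
    """
    table = {}
    for cust, sid, dr, id_, unit, order in rows:
        slots = table.setdefault((cust, sid), {})
        slots.setdefault(dr, (id_, unit, order))  # keep the first value per slot
    out = []
    for (cust, sid), slots in sorted(table.items()):
        row = [cust, sid]
        for dr in range(1, n + 1):
            row.extend(slots.get(dr, (None, None, None)))
        out.append(tuple(row))
    return out

rows = [('John', 1, 1, '123', '00015', '1'), ('John', 2, 1, '123', '00016', '2'),
        ('John', 1, 2, '345', '00205', '3'), ('John', 2, 2, '345', '00206', '4'),
        ('John', 1, 3, '789', '00283', '5'), ('John', 2, 3, '789', '00284', '6'),
        ('John', 3, 3, '789', '00285', '7')]
for row in pivot_by_dr(rows, 3):
    print(row)
```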
Rename the column names if needed:
import re
df_new.toDF(*['_'.join(reversed(re.split('_',c,1))) for c in df_new.columns]).show()
+--------+---+----+------+-------+----+------+-------+----+------+-------+
|Customer|sid|ID_1|unit_1|order_1|ID_2|unit_2|order_2|ID_3|unit_3|order_3|
+--------+---+----+------+-------+----+------+-------+----+------+-------+
| John| 1| 123| 00015| 1| 345| 00205| 3| 789| 00283| 5|
| John| 2| 123| 00016| 2| 345| 00206| 4| 789| 00284| 6|
| John| 3|null| null| null|null| null| null| 789| 00285| 7|
+--------+---+----+------+-------+----+------+-------+----+------+-------+
Below is my solution: rank the rows, then flatten the results.
df = spark.createDataFrame([
('John', '123', '00015', '1'), ('John', '123', '00016', '2'), ('John', '345', '00205', '3'),
('John', '345', '00206', '4'), ('John', '789', '00283', '5'), ('John', '789', '00284', '6'),
('John', '789', '00285', '7')
], ['Customer', 'ID', 'unit', 'order'])
from pyspark.sql import Window, functions as F
w1 = Window.partitionBy("customer").orderBy("order")
rankedDF = df.withColumn("rank", F.row_number().over(w1))
groupedDF = (rankedDF.select("customer", "rank", F.collect_list("ID").over(w1).alias("ID"),
                             F.collect_list("unit").over(w1).alias("unit"), F.collect_list("order").over(w1).alias("seq_num"))
             .groupBy("customer", "rank")
             .agg(F.max("ID").alias("ID"), F.max("unit").alias("unit"), F.max("seq_num").alias("seq_num")))
groupedColumns = [F.col("customer")]
# the i-th element of each collected list becomes ID_i, unit_i and seq_num_i
flattenedCols = [F.col(a)[i - 1].alias(a + "_" + str(i)) for i in [1, 2, 3] for a in ["ID", "unit", "seq_num"]]
finalDf = groupedDF.select(groupedColumns + flattenedCols)
There may be multiple ways to do this, but a pandas UDF is one of them. Here is a toy example based on your data:
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = pd.DataFrame({'Customer': ['John']*6,
'ID': [123]*2 + [345]*2 + [789]*2,
'unit': ['00015', '00016', '00205', '00206', '00283', '00284'],
'order': range(1, 7)})
sdf = spark.createDataFrame(df)
# Spark 2.4 syntax. Spark 3.0 is less verbose
return_types = 'state string, ID_1 int, unit_1 string, seq_num_1 int, ID_2 int, unit_2 string, seq_num_2 int, ID_3 int, unit_3 string, seq_num_3 int'
@pandas_udf(returnType=return_types, functionType=PandasUDFType.GROUPED_MAP)
def convert_to_wide(pdf):
    # one block of columns per ID, placed side by side on the shared Customer index
    groups = pdf.groupby('ID')
    out = pd.concat([group.set_index('Customer') for _, group in groups], axis=1).reset_index()
    out.columns = ['state', 'ID_1', 'unit_1', 'seq_num_1', 'ID_2', 'unit_2', 'seq_num_2', 'ID_3', 'unit_3', 'seq_num_3']
    return out
sdf.groupby('Customer').apply(convert_to_wide).show()
+-----+----+------+---------+----+------+---------+----+------+---------+
|state|ID_1|unit_1|seq_num_1|ID_2|unit_2|seq_num_2|ID_3|unit_3|seq_num_3|
+-----+----+------+---------+----+------+---------+----+------+---------+
| John| 123| 00015| 1| 345| 00205| 3| 789| 00283| 5|
| John| 123| 00016| 2| 345| 00206| 4| 789| 00284| 6|
+-----+----+------+---------+----+------+---------+----+------+---------+
I am using spark-sql 2.4.1 with Java 8.
I have a scenario where I need to copy the current row and create another row with a few columns modified. How can this be achieved in Spark SQL?
Example:
Given:
val data = List(
("20", "score", "school", 14 ,12),
("21", "score", "school", 13 , 13),
("22", "rate", "school", 11 ,14)
)
val df = data.toDF("id", "code", "entity", "value1","value2")
Current Output
+---+-----+------+------+------+
| id| code|entity|value1|value2|
+---+-----+------+------+------+
| 20|score|school| 14| 12|
| 21|score|school| 13| 13|
| 22| rate|school| 11| 14|
+---+-----+------+------+------+
When column "code" is "rate", copy the row as two rows: the original one, and a second one with the new code "new_rate", like below.
Expected output:
+---+--------+------+------+------+
| id| code|entity|value1|value2|
+---+--------+------+------+------+
| 20| score|school| 14| 12|
| 21| score|school| 13| 13|
| 22| rate|school| 11| 14|
| 22|new_rate|school| 11| 14|
+---+--------+------+------+------+
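How can I achieve this?
You can use this approach for your scenario:
import org.apache.spark.sql.functions.{concat, lit}
import spark.implicits._ // enables the $"..." column syntax
df.union(df.filter($"code"==="rate").withColumn("code",concat(lit("new_"), $"code"))).show()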
/*
+---+--------+------+------+------+
| id| code|entity|value1|value2|
+---+--------+------+------+------+
| 20| score|school| 14| 12|
| 21| score|school| 13| 13|
| 22| rate|school| 11| 14|
| 22|new_rate|school| 11| 14|
+---+--------+------+------+------+
*/
Use when to check whether code === "rate". If it matches, replace the column value with array(lit("rate"), lit("new_rate")); otherwise wrap the value as array($"code"). Then explode the code column.
Check the code below:
scala> df.show(false)
+---+-----+------+------+------+
|id |code |entity|value1|value2|
+---+-----+------+------+------+
|20 |score|school|14 |12 |
|21 |score|school|13 |13 |
|22 |rate |school|11 |14 |
+---+-----+------+------+------+
val colExpr = explode(
when(
$"code" === "rate",
array(
lit("rate"),
lit("new_rate")
)
)
.otherwise(array($"code"))
)
scala> df.withColumn("code",colExpr).show(false)
+---+--------+------+------+------+
|id |code |entity|value1|value2|
+---+--------+------+------+------+
|20 |score |school|14 |12 |
|21 |score |school|13 |13 |
|22 |rate |school|11 |14 |
|22 |new_rate|school|11 |14 |
+---+--------+------+------+------+
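The same row-multiplication idea, sketched in plain Python as a flatMap: each input row maps to a list of codes and yields one output row per code, which is what explode does with the array built by when/otherwise. The function name is hypothetical:

```python
def duplicate_rate_rows(rows):
    """rows: (id, code, entity, value1, value2) tuples.

    Mirrors explode(when(code == "rate", array("rate", "new_rate")).otherwise(array(code))):
    every row becomes one output row per element of its code list.
    """
    out = []
    for id_, code, entity, v1, v2 in rows:
        codes = ["rate", "new_rate"] if code == "rate" else [code]
        out.extend((id_, c, entity, v1, v2) for c in codes)
    return out

rows = [("20", "score", "school", 14, 12),
        ("21", "score", "school", 13, 13),
        ("22", "rate", "school", 11, 14)]
for row in duplicate_rate_rows(rows):
    print(row)
```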
Team, we need to generate a report of mismatched columns, matched on a key field, between two PySpark DataFrames with exactly the same structure.
Here is the first DataFrame:
>>> df.show()
+--------+----+----+----+----+----+----+----+----+
| key|col1|col2|col3|col4|col5|col6|col7|col8|
+--------+----+----+----+----+----+----+----+----+
| abcd| 123| xyz| a| ab| abc| def| qew| uvw|
| abcd1| 123| xyz| a| ab| abc| def| qew| uvw|
| abcd12| 123| xyz| a| ab| abc| def| qew| uvw|
| abcd123| 123| xyz| a| ab| abc| def| qew| uvw|
|abcd1234| 123| xyz| a| ab| abc| def| qew| uvw|
+--------+----+----+----+----+----+----+----+----+
And here is the second DataFrame:
>>> df1.show()
+--------+----+----+----+----+----+----+----+----+
| key|col1|col2|col3|col4|col5|col6|col7|col8|
+--------+----+----+----+----+----+----+----+----+
| abcd| 123| xyz| a| ab| abc| def| qew| uvw|
| abcdx| 123| xyz| a| ab| abc| def| qew| uvw|
| abcd12| 123| xyz| a| abx| abc|defg| qew| uvw|
| abcd123| 123| xyz| a| ab| abc|defg| qew| uvw|
|abcd1234| 123| xyz| a| ab|abcd|defg| qew| uvw|
+--------+----+----+----+----+----+----+----+----+
A full outer join gives me this:
>>> dfFull=df.join(df1,'key','outer')
>>> dfFull.show()
+--------+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
| key|col1|col2|col3|col4|col5|col6|col7|col8|col1|col2|col3|col4|col5|col6|col7|col8|
+--------+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
| abcd12| 123| xyz| a| ab| abc| def| qew| uvw| 123| xyz| a| abx| abc|defg| qew| uvw|
| abcd1| 123| xyz| a| ab| abc| def| qew| uvw|null|null|null|null|null|null|null|null|
|abcd1234| 123| xyz| a| ab| abc| def| qew| uvw| 123| xyz| a| ab|abcd|defg| qew| uvw|
| abcd123| 123| xyz| a| ab| abc| def| qew| uvw| 123| xyz| a| ab| abc|defg| qew| uvw|
| abcdx|null|null|null|null|null|null|null|null| 123| xyz| a| ab| abc| def| qew| uvw|
| abcd| 123| xyz| a| ab| abc| def| qew| uvw| 123| xyz| a| ab| abc| def| qew| uvw|
+--------+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
If I look at col6 alone, 5 values mismatch for the "key" field (the only match is the last record).
>>> dfFull.select('key',df['col6'],df1['col6']).show()
+--------+----+----+
| key|col6|col6|
+--------+----+----+
| abcd12| def|defg|
| abcd1| def|null|
|abcd1234| def|defg|
| abcd123| def|defg|
| abcdx|null| def|
| abcd| def| def|
+--------+----+----+
I need to generate a report like this for all the columns. The mismatch sample can be any mismatching record's value from each DataFrame:
colName,NumofMismatch,mismatchSampleFromDf,misMatchSamplefromDf1
col6,5,def,defg
col7,2,null,qew
col8,2,null,uvw
col5,3,null,abc
It is a column-wise summary, based on key, of how many values mismatch between the two DataFrames.
Sid
Assuming the two DataFrames are df1 and df2, you can try the following:
from pyspark.sql.functions import when, array, count, first
# list of columns to be compared
cols = df1.columns[1:]
df_new = (df1.join(df2, "key", "outer")
.select([ when(~df1[c].eqNullSafe(df2[c]), array(df1[c], df2[c])).alias(c) for c in cols ])
.selectExpr('stack({},{}) as (colName, mismatch)'.format(len(cols), ','.join('"{0}",`{0}`'.format(c) for c in cols)))
.filter('mismatch is not NULL'))
df_new.show(10)
+-------+-----------+
|colName| mismatch|
+-------+-----------+
| col4| [ab, abx]|
| col6|[def, defg]|
| col6|[def, defg]|
| col5|[abc, abcd]|
| col6|[def, defg]|
| col1| [, 123]|
| col2| [, xyz]|
| col3| [, a]|
| col4| [, ab]|
| col5| [, abc]|
+-------+-----------+
Notes: (1) the condition ~df1[c].eqNullSafe(df2[c]) used to find the mismatches satisfies either of the following:
+ df1[c] != df2[c]
+ df1[c] is NULL or df2[c] is NULL but not both
(2) Mismatches, if any, are saved as an ArrayType column with the first item from df1 and the second item from df2. NULL is returned when there is no mismatch, and those rows are filtered out later.
(3) The stack() expression dynamically generated by Python's format function is:
stack(8,"col1",`col1`,"col2",`col2`,"col3",`col3`,"col4",`col4`,"col5",`col5`,"col6",`col6`,"col7",`col7`,"col8",`col8`) as (colName, mismatch)
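What stack() does with these arguments can be modeled in a few lines of plain Python: it splits the 2 * 8 arguments into 8 rows of (colName, mismatch). The sketch assumes the number of arguments divides evenly by n (Spark pads the last row with nulls otherwise), and the input values are the mismatch arrays for key abcd12 from the question's data:

```python
def stack(n, *args):
    """Model of Spark SQL stack(n, expr1, ..., exprk): split the k arguments
    into n rows of k / n values each (assumes k is divisible by n)."""
    width = len(args) // n
    return [tuple(args[i * width:(i + 1) * width]) for i in range(n)]

# Per-column mismatch arrays for the joined row with key abcd12 (None = no mismatch)
mismatch = {"col1": None, "col2": None, "col3": None, "col4": ["ab", "abx"],
            "col5": None, "col6": ["def", "defg"], "col7": None, "col8": None}
args = [x for c in mismatch for x in (c, mismatch[c])]  # "col1", `col1`, "col2", `col2`, ...
rows = stack(len(mismatch), *args)
print([r for r in rows if r[1] is not None])  # like .filter('mismatch is not NULL')
# [('col4', ['ab', 'abx']), ('col6', ['def', 'defg'])]
```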
After we have df_new, then we can do the groupby + aggregation:
(df_new.groupby('colName')
    .agg(count('mismatch').alias('NumOfMismatch'), first('mismatch').alias('mismatch'))
    .selectExpr('colName', 'NumOfMismatch', 'mismatch[0] as misMatchFromdf1', 'mismatch[1] as misMatchFromdf2')
    .show())
+-------+-------------+---------------+---------------+
|colName|NumOfMismatch|misMatchFromdf1|misMatchFromdf2|
+-------+-------------+---------------+---------------+
| col8| 2| null| uvw|
| col3| 2| null| a|
| col4| 3| ab| abx|
| col1| 2| null| 123|
| col6| 5| def| defg|
| col5| 3| abc| abcd|
| col2| 2| null| xyz|
| col7| 2| null| qew|
+-------+-------------+---------------+---------------+
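The whole report can be cross-checked without Spark. The plain-Python sketch below (hypothetical names; each DataFrame modeled as a dict from key to a dict of column values) emulates the full outer join on key, treats a side that is missing the key as all None, and counts a mismatch with a null-safe inequality, which is what ~eqNullSafe expresses. The sample it keeps is the first mismatch in sorted key order, so it may differ from the one first() happens to return in Spark:

```python
def mismatch_report(left, right, cols):
    """left/right: {key: {column: value}}.
    Returns {column: (NumOfMismatch, (left_sample, right_sample))}."""
    report = {}
    for key in sorted(set(left) | set(right)):  # full outer join on key
        lrow, rrow = left.get(key, {}), right.get(key, {})  # missing side -> all None
        for c in cols:
            a, b = lrow.get(c), rrow.get(c)
            if a != b:  # null-safe inequality: None vs None counts as equal
                count, sample = report.get(c, (0, (a, b)))
                report[c] = (count + 1, sample)  # keep the first mismatch as the sample
    return report

base = {"col1": "123", "col2": "xyz", "col3": "a", "col4": "ab",
        "col5": "abc", "col6": "def", "col7": "qew", "col8": "uvw"}
df = {k: dict(base) for k in ["abcd", "abcd1", "abcd12", "abcd123", "abcd1234"]}
df1 = {"abcd": dict(base), "abcdx": dict(base),
       "abcd12": {**base, "col4": "abx", "col6": "defg"},
       "abcd123": {**base, "col6": "defg"},
       "abcd1234": {**base, "col5": "abcd", "col6": "defg"}}
cols = list(base)
for col, (n, sample) in sorted(mismatch_report(df, df1, cols).items()):
    print(col, n, sample)
```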
I have student marks like the ones below. I want to pivot the SubjectName column and also get the total marks after the pivot.
Source table:
+---------+-----------+-----+
|StudentId|SubjectName|Marks|
+---------+-----------+-----+
| 1| A| 10|
| 1| B| 20|
| 1| C| 30|
| 2| A| 20|
| 2| B| 25|
| 2| C| 30|
| 3| A| 10|
| 3| B| 20|
| 3| C| 20|
+---------+-----------+-----+
Destination:
+---------+---+---+---+-----+
|StudentId| A| B| C|Total|
+---------+---+---+---+-----+
| 1| 10| 20| 30| 60|
| 3| 10| 20| 20| 50|
| 2| 20| 25| 30| 75|
+---------+---+---+---+-----+
Here is my source code:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
val spark = SparkSession.builder().appName("test").master("local[*]").getOrCreate()
import spark.implicits._
val list = List((1, "A", 10), (1, "B", 20), (1, "C", 30), (2, "A", 20), (2, "B", 25), (2, "C", 30), (3, "A", 10),
(3, "B", 20), (3, "C", 20))
val df = list.toDF("StudentId", "SubjectName", "Marks")
df.show() // source table as per above
val df1 = df.groupBy("StudentId").pivot("SubjectName", Seq("A", "B", "C")).agg(sum("Marks"))
df1.show()
val df2 = df1.withColumn("Total", col("A") + col("B") + col("C"))
df2.show // required destination
val df3 = df.groupBy("StudentId").agg(sum("Marks").as("Total"))
df3.show()
df1 does not show the sum/total column; it shows the following:
+---------+---+---+---+
|StudentId| A| B| C|
+---------+---+---+---+
| 1| 10| 20| 30|
| 3| 10| 20| 20|
| 2| 20| 25| 30|
+---------+---+---+---+
df3 is able to create a new Total column, so why can't df1?
Can anybody tell me what I am missing, or what is wrong with my understanding of the pivot concept?
This is the expected behaviour of Spark's pivot function: the .agg function is applied to each pivoted column, which is why you do not see the sum of marks as a new column.
Refer to the official Spark documentation for pivot.
Example:
scala> df.groupBy("StudentId").pivot("SubjectName").agg(sum("Marks") + 2).show()
+---------+---+---+---+
|StudentId| A| B| C|
+---------+---+---+---+
| 1| 12| 22| 32|
| 3| 12| 22| 22|
| 2| 22| 27| 32|
+---------+---+---+---+
In the above example we have added 2 to all the pivoted columns.
Example2:
To get count using pivot and agg
scala> df.groupBy("StudentId").pivot("SubjectName").agg(count("*")).show()
+---------+---+---+---+
|StudentId| A| B| C|
+---------+---+---+---+
| 1| 1| 1| 1|
| 3| 1| 1| 1|
| 2| 1| 1| 1|
+---------+---+---+---+
The .agg after pivot applies only to the pivoted data. To get the total, add a new column that sums the pivoted columns, as below:
val cols = Seq("A", "B", "C")
val result = df.groupBy("StudentId")
.pivot("SubjectName")
.agg(sum("Marks"))
.withColumn("Total", cols.map(col _).reduce(_ + _))
result.show(false)
Output:
+---------+---+---+---+-----+
|StudentId|A |B |C |Total|
+---------+---+---+---+-----+
|1 |10 |20 |30 |60 |
|3 |10 |20 |20 |50 |
|2 |20 |25 |30 |75 |
+---------+---+---+---+-----+
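The point of this answer, that the aggregate in .agg is evaluated per pivoted cell and a row-wise total has to be added afterwards, can be sketched in plain Python. The function name is hypothetical, every SubjectName is assumed to appear in the subjects list, and the total follows Spark's `+` in becoming None when any pivoted cell is missing:

```python
def pivot_with_total(rows, subjects):
    """rows: (StudentId, SubjectName, Marks) tuples.

    Step 1 (pivot + agg): one cell per subject holding the sum of Marks.
    Step 2 (withColumn): Total = sum of the pivoted cells, None if any cell is None.
    """
    table = {}
    for student, subject, marks in rows:
        row = table.setdefault(student, {s: None for s in subjects})
        row[subject] = (row[subject] or 0) + marks
    for row in table.values():
        cells = [row[s] for s in subjects]
        row["Total"] = None if any(v is None for v in cells) else sum(cells)
    return table

marks = [(1, "A", 10), (1, "B", 20), (1, "C", 30), (2, "A", 20), (2, "B", 25),
         (2, "C", 30), (3, "A", 10), (3, "B", 20), (3, "C", 20)]
for student, row in sorted(pivot_with_total(marks, ["A", "B", "C"]).items()):
    print(student, row)
```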