How to limit functions.collect_set in Spark SQL? - apache-spark

I'm dealing with a column of numbers in a large Spark DataFrame, and I would like to create a new column that stores an aggregated list of unique numbers that appear in that column.
Basically exactly what functions.collect_set does. However, I only need up to 1000 elements in the aggregated list. Is there any way to pass that parameter to functions.collect_set(), or any other way to get at most 1000 elements in the aggregated list, without using a UDAF?
Since the column is so large, I'd like to avoid collecting all elements and trimming the list afterwards.
Thanks!

Spark 2.4
As pointed out in a comment, Spark 2.4.0 comes with the slice standard function, which can do this sort of thing.
val usage = sql("describe function slice").as[String].collect()(2)
scala> println(usage)
Usage: slice(x, start, length) - Subsets array x starting from index start (array indices start at 1, or starting from the end if start is negative) with the specified length.
That gives the following query:
val q = input
  .groupBy('key)
  .agg(collect_set('id) as "collect")
  .withColumn("three_only", slice('collect, 1, 3))
scala> q.show(truncate = false)
+---+--------------------------------------+------------+
|key|collect |three_only |
+---+--------------------------------------+------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+
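For reference, a PySpark sketch of the same query (my addition, assuming an input_df DataFrame with key and id columns):
from pyspark.sql import functions as F

q = (input_df
    .groupBy("key")
    .agg(F.collect_set("id").alias("collect"))
    .withColumn("three_only", F.slice("collect", 1, 3)))
q.show(truncate=False)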
Before Spark 2.4
I'd use a UDF that does what you want after collect_set (or collect_list), or a much harder UDAF.
Given that most people have more experience with UDFs, I'd go with that first. Even though UDFs are not optimized, for this use case it's fine.
val limitUDF = udf { (nums: Seq[Long], limit: Int) => nums.take(limit) }
val sample = spark.range(50).withColumn("key", $"id" % 5)
scala> sample.groupBy("key").agg(collect_set("id") as "all").show(false)
+---+--------------------------------------+
|key|all |
+---+--------------------------------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|
+---+--------------------------------------+
scala> sample.
  groupBy("key").
  agg(collect_set("id") as "all").
  withColumn("limit(3)", limitUDF($"all", lit(3))).
  show(false)
+---+--------------------------------------+------------+
|key|all |limit(3) |
+---+--------------------------------------+------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+
See the functions object for the udf function's docs.
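For completeness, a minimal PySpark sketch of the same UDF approach (my adaptation, not from the original answer):
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, LongType

# Truncate the collected set to at most `limit` elements
limit_udf = F.udf(lambda nums, limit: nums[:limit], ArrayType(LongType()))

sample = spark.range(50).withColumn("key", F.col("id") % 5)
(sample
    .groupBy("key")
    .agg(F.collect_set("id").alias("all"))
    .withColumn("limit3", limit_udf("all", F.lit(3)))
    .show(truncate=False))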

I'm using a modified copy of the collect_set and collect_list functions; because of code scoping, the modified copies must be in the same package path as the originals. The linked code works for Spark 2.1.0; if you are using an earlier version, method signatures may differ.
Put this file (https://gist.github.com/lokkju/06323e88746c85b2ce4de3ea9cdef9bc) into your project as src/main/org/apache/spark/sql/catalyst/expression/collect_limit.scala
and use it as:
import org.apache.spark.sql.catalyst.expression.collect_limit._
df.groupBy('set_col).agg(collect_set_limit('set_col, 1000))

Another option is to apply limit and deduplication on the DataFrame itself. Note that limit before the aggregation restricts which rows get aggregated, while limit after the aggregation only caps the number of result rows:
scala> df.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 10| Name|2016| Country|
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.groupBy("C1").agg(sum("C0"))
res36: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res36.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
| Name| 30|
+-----+-------+
scala> df.limit(2).groupBy("C1").agg(sum("C0"))
res33: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res33.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
| Name| 10|
|Name1| 11|
+-----+-------+
scala> df.groupBy("C1").agg(sum("C0")).limit(2)
res2: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res2.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
+-----+-------+
scala> df.distinct
res8: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res8.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.dropDuplicates(Array("c1"))
res11: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res11.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 12|Name2|2017|Country2|
| 10| Name|2016| Country|
+---+-----+----+--------+

As other answers have mentioned, the performant way of doing this would be to write a UDAF. Unfortunately the UDAF API is not as extensible as the aggregate functions that ship with Spark, but you can use their internal APIs to build on the internal functions and do what you need.
Here is an implementation of collect_set_limit that is mostly a copy-paste of Spark's internal CollectSet AggregateFunction. I would just extend it, but it's a case class. Really, all that's needed is to override the update and merge methods to respect a passed-in limit:
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, ImperativeAggregate}

case class CollectSetLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {

  // Evaluated once; limitExp is expected to be a literal
  val limit = limitExp.eval(null).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty

  // Stop collecting new elements once the buffer is full
  override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = {
    if (buffer.size < limit) super.update(buffer, input)
    else buffer
  }

  // When merging partial buffers, only take as many elements as still fit
  override def merge(buffer: mutable.HashSet[Any], other: mutable.HashSet[Any]): mutable.HashSet[Any] = {
    if (buffer.size >= limit) buffer
    else buffer ++= other.take(limit - buffer.size)
  }

  override def prettyName: String = "collect_set_limit"
}
And to actually register it, we can go through Spark's internal FunctionRegistry, which takes the name and a builder, effectively a function that creates a CollectSetLimit from the provided expressions:
val collectSetBuilder = (args: Seq[Expression]) => CollectSetLimit(args(0), args(1))
FunctionRegistry.builtin.registerFunction("collect_set_limit", collectSetBuilder)
Edit:
Turns out adding it to the builtin registry only works if you haven't created the SparkContext yet, as Spark makes an immutable clone of the registry on startup. If you have an existing context, this should work to add it with reflection (note getDeclaredFields, since the backing field is private):
val field = classOf[SessionCatalog].getDeclaredFields.find(_.getName.endsWith("functionRegistry")).get
field.setAccessible(true)
val inUseRegistry = field.get(SparkSession.builder.getOrCreate.sessionState.catalog).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction("collect_set_limit", collectSetBuilder)

Use take:
val firstThousand = rdd.take(1000)
This will return the first 1000 elements.
In Scala, collect also has an overload that takes a partial function, which lets you be more specific about what is returned.
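In PySpark terms, a hedged sketch of the same idea is to filter first and then take (the partial-function overload of collect is Scala-only):
# A minimal sketch, assuming rdd holds the column's values:
# keep only the values of interest, then cap the result at 1000 elements.
first_thousand = rdd.filter(lambda x: x is not None).take(1000)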

Related

Can I use regexp_replace or some equivalent to replace multiple values in a pyspark dataframe column with one line of code?
Here is the code to create my dataframe:
from pyspark import SparkContext, SparkConf, SQLContext
from datetime import datetime

sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

data1 = [
    ('George', datetime(2010, 3, 24, 3, 19, 58), 13),
    ('George', datetime(2020, 9, 24, 3, 19, 6), 8),
    ('George', datetime(2009, 12, 12, 17, 21, 30), 5),
    ('Micheal', datetime(2010, 11, 22, 13, 29, 40), 12),
    ('Maggie', datetime(2010, 2, 8, 3, 31, 23), 8),
    ('Ravi', datetime(2009, 1, 1, 4, 19, 47), 2),
    ('Xien', datetime(2010, 3, 2, 4, 33, 51), 3),
]
df1 = sqlContext.createDataFrame(data1, ['name', 'trial_start_time', 'purchase_time'])
df1.show(truncate=False)
Here is the dataframe:
+-------+-------------------+-------------+
|name |trial_start_time |purchase_time|
+-------+-------------------+-------------+
|George |2010-03-24 07:19:58|13 |
|George |2020-09-24 07:19:06|8 |
|George |2009-12-12 22:21:30|5 |
|Micheal|2010-11-22 18:29:40|12 |
|Maggie |2010-02-08 08:31:23|8 |
|Ravi |2009-01-01 09:19:47|2 |
|Xien |2010-03-02 09:33:51|3 |
+-------+-------------------+-------------+
Here is a working example to replace one string:
from pyspark.sql.functions import regexp_replace, regexp_extract, col
df1.withColumn("name", regexp_replace('name', "Ravi", "Ravi_renamed")).show()
Here is the output:
+------------+-------------------+-------------+
| name| trial_start_time|purchase_time|
+------------+-------------------+-------------+
| George|2010-03-24 07:19:58| 13|
| George|2020-09-24 07:19:06| 8|
| George|2009-12-12 22:21:30| 5|
| Micheal|2010-11-22 18:29:40| 12|
| Maggie|2010-02-08 08:31:23| 8|
|Ravi_renamed|2009-01-01 09:19:47| 2|
| Xien|2010-03-02 09:33:51| 3|
+------------+-------------------+-------------+
In pandas I could replace multiple strings in one line of code with a lambda expression:
df1['name'].apply(lambda x: x.replace('George', 'George_renamed1').replace('Ravi', 'Ravi_renamed2'))
I am not sure if this can be done in pyspark with regexp_replace. Perhaps there is another alternative? When I read about using lambda expressions in pyspark, it seems I have to create UDFs (which seem to get a little long). But I am curious whether I can simply run some type of regex expression on multiple strings, as above, in one line of code.
This is what you're looking for:
Using when() (most readable):
from pyspark.sql.functions import when, col

df1 = df1.withColumn('name',
    when(col('name') == 'George', 'George_renamed1')
    .when(col('name') == 'Ravi', 'Ravi_renamed2')
    .otherwise(col('name'))
)
With a mapping expr (less explicit, but handy if there are many values to replace):
from pyspark.sql import functions as F

df1 = df1.withColumn('name', F.expr("coalesce(map('George', 'George_renamed1', 'Ravi', 'Ravi_renamed2')[name], name)"))
Or, if you already have a list to use, i.e.
name_changes = ['George', 'George_renamed1', 'Ravi', 'Ravi_renamed2']
# str()[1:-1] converts the list to a string and removes the [ ]
df1 = df1.withColumn('name', F.expr(f'coalesce(map({str(name_changes)[1:-1]})[name], name)'))
The above, but only using imported PySpark functions:
from pyspark.sql.functions import create_map, lit, coalesce

mapping_expr = create_map([lit(x) for x in name_changes])
df1 = df1.withColumn('name', coalesce(mapping_expr[df1['name']], 'name'))
Result
df1.withColumn('name', F.expr("coalesce(map('George', 'George_renamed1', 'Ravi', 'Ravi_renamed2')[name],name)")).show()
+---------------+-------------------+-------------+
| name| trial_start_time|purchase_time|
+---------------+-------------------+-------------+
|George_renamed1|2010-03-24 03:19:58| 13|
|George_renamed1|2020-09-24 03:19:06| 8|
|George_renamed1|2009-12-12 17:21:30| 5|
| Micheal|2010-11-22 13:29:40| 12|
| Maggie|2010-02-08 03:31:23| 8|
| Ravi_renamed2|2009-01-01 04:19:47| 2|
| Xien|2010-03-02 04:33:51| 3|
+---------------+-------------------+-------------+
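If you do want to stay with regexp_replace, a sketch (my addition, not part of the answer above): the calls can simply be chained, one per replacement, still in a single expression:
from pyspark.sql.functions import regexp_replace

# Anchored patterns so 'George' is only replaced when it is the whole value
df1 = df1.withColumn(
    "name",
    regexp_replace(
        regexp_replace("name", "^George$", "George_renamed1"),
        "^Ravi$", "Ravi_renamed2"))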

Convert pyspark column to a list

I thought this would be easy, but I can't find the answer :-)
How do I convert the name column into a list? I am hoping I can get isin to work rather than a join against another dataframe column, but isin seems to require a list (if I understand correctly).
Create the dataframe:
from pyspark import SparkContext, SparkConf, SQLContext
from datetime import datetime

sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

data2 = [
    ('George', datetime(2010, 3, 24, 3, 19, 58), 3),
    ('Sally', datetime(2009, 12, 12, 17, 21, 30), 5),
    ('Frank', datetime(2010, 11, 22, 13, 29, 40), 2),
    ('Paul', datetime(2010, 2, 8, 3, 31, 23), 8),
    ('Jesus', datetime(2009, 1, 1, 4, 19, 47), 2),
    ('Lou', datetime(2010, 3, 2, 4, 33, 51), 3),
]
df2 = sqlContext.createDataFrame(data2, ['name', 'trial_start_time', 'purchase_time'])
df2.show(truncate=False)
Should look like:
+------+-------------------+-------------+
|name |trial_start_time |purchase_time|
+------+-------------------+-------------+
|George|2010-03-24 07:19:58|3 |
|Sally |2009-12-12 22:21:30|5 |
|Frank |2010-11-22 18:29:40|2 |
|Paul |2010-02-08 08:31:23|8 |
|Jesus |2009-01-01 09:19:47|2 |
|Lou |2010-03-02 09:33:51|3 |
+------+-------------------+-------------+
I am not sure if collect is the closest I can come to this.
df2.select("name").collect()
[Row(name='George'),
Row(name='Sally'),
Row(name='Frank'),
Row(name='Paul'),
Row(name='Jesus'),
Row(name='Lou')]
Any suggestions on how to output the name column to a list?
It may need to look something like this:
[George, Sally, Frank, Paul, Jesus, Lou]
Use the collect_list function and then collect to get a list variable.
Example:
from pyspark.sql.functions import *
df2.agg(collect_list(col("name")).alias("name")).show(10,False)
#+----------------------------------------+
#|name |
#+----------------------------------------+
#|[George, Sally, Frank, Paul, Jesus, Lou]|
#+----------------------------------------+
lst = df2.agg(collect_list(col("name"))).collect()[0][0]
lst
#['George', 'Sally', 'Frank', 'Paul', 'Jesus', 'Lou']
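A simpler alternative (a sketch, not from the answer above) is to skip the aggregation and build the list on the driver straight from collect():
# Each Row can be indexed by field name
names = [row["name"] for row in df2.select("name").collect()]
# or, flattening via the underlying RDD:
names = df2.select("name").rdd.flatMap(lambda r: r).collect()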

Append a value to every element in a PySpark list DataFrame column

I have a dataframe like this:
Data       ID
[1,2,3,4]  22
I want to create a new column in which each element of the Data field has the ID appended to it with a | separator, and the entries are joined by the ~ symbol, like below:
Data       ID   New_Column
[1,2,3,4]  22   [1|22~2|22~3|22~4|22]
Note: the array size in the Data field is not fixed; it may be empty or hold any number of entries.
Can anyone please help me solve this?
package spark

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object DF extends App {

  val spark = SparkSession.builder()
    .master("local")
    .appName("DataFrame-example")
    .getOrCreate()

  import spark.implicits._

  val df = Seq(
    (22, Seq(1, 2, 3, 4)),
    (23, Seq(1, 2, 3, 4, 5, 6, 7, 8)),
    (24, Seq())
  ).toDF("ID", "Data")

  // Build "[x1|id~x2|id~...]" from the array and the id
  val arrUDF = udf((id: Long, array: Seq[Long]) => {
    val r = array.size match {
      case 0 => ""
      case _ => array.map(x => s"$x|$id").mkString("~")
    }
    s"[$r]"
  })

  val resDF = df.withColumn("New_Column", arrUDF('ID, 'Data))
  resDF.show(false)

  //+---+------------------------+-----------------------------------------+
  //|ID |Data                    |New_Column                               |
  //+---+------------------------+-----------------------------------------+
  //|22 |[1, 2, 3, 4]            |[1|22~2|22~3|22~4|22]                    |
  //|23 |[1, 2, 3, 4, 5, 6, 7, 8]|[1|23~2|23~3|23~4|23~5|23~6|23~7|23~8|23]|
  //|24 |[]                      |[]                                       |
  //+---+------------------------+-----------------------------------------+
}
Spark 2.4+
The PySpark equivalent goes like this:
import pyspark.sql.functions as f

df = spark.createDataFrame([(22, [1,2,3,4]), (23, [1,2,3,4,5,6,7,8]), (24, [])], ['Id', 'Data'])
df.show()
+---+--------------------+
| Id| Data|
+---+--------------------+
| 22| [1, 2, 3, 4]|
| 23|[1, 2, 3, 4, 5, 6...|
| 24| []|
+---+--------------------+
df.withColumn('ff', f.when(f.size('Data')==0,'').otherwise(f.expr('''concat_ws('~',transform(Data, x->concat(x,'|',Id)))'''))).show(20,False)
+---+------------------------+---------------------------------------+
|Id |Data |ff |
+---+------------------------+---------------------------------------+
|22 |[1, 2, 3, 4] |1|22~2|22~3|22~4|22 |
|23 |[1, 2, 3, 4, 5, 6, 7, 8]|1|23~2|23~3|23~4|23~5|23~6|23~7|23~8|23|
|24 |[] | |
+---+------------------------+---------------------------------------+
If you want the final output as an array:
df.withColumn('ff',f.array(f.when(f.size('Data')==0,'').otherwise(f.expr('''concat_ws('~',transform(Data, x->concat(x,'|',Id)))''')))).show(20,False)
+---+------------------------+-----------------------------------------+
|Id |Data |ff |
+---+------------------------+-----------------------------------------+
|22 |[1, 2, 3, 4] |[1|22~2|22~3|22~4|22] |
|23 |[1, 2, 3, 4, 5, 6, 7, 8]|[1|23~2|23~3|23~4|23~5|23~6|23~7|23~8|23]|
|24 |[] |[] |
+---+------------------------+-----------------------------------------+
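And if what you actually want is an array of the element|Id strings rather than one joined string wrapped in an array, a sketch (my addition) using the same transform higher-order function:
df.withColumn('ff', f.expr("transform(Data, x -> concat(x, '|', Id))")).show(truncate=False)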
Hope this helps
A udf can help:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType

def func(array, suffix):
    return '~'.join([str(x) + '|' + str(suffix) for x in array])

my_udf = F.udf(func, StringType())
df.withColumn("New_Column", my_udf("Data", "ID")).show()
prints
+------------+---+-------------------+
|        Data| ID|         New_Column|
+------------+---+-------------------+
|[1, 2, 3, 4]| 22|1|22~2|22~3|22~4|22|
+------------+---+-------------------+

How to get the topic using pyspark LDA

I have used LDA to find topics:
from pyspark.ml.clustering import LDA
lda = LDA(k=30, seed=123, optimizer="em", maxIter=10, featuresCol="features")
ldamodel = lda.fit(rescaledData)
When I run the code below, I get the result with topic, termIndices and termWeights:
ldatopics = ldamodel.describeTopics()
ldatopics.show(truncate=False)
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|topic|termIndices |termWeights |
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13643518321703474, 0.09550636070102715, 0.0948398694152729, 0.05766480874922468, 0.04482014536392716, 0.04435413761288223, 0.04390248330822808, 0.042949351875241376, 0.039792489008412854, 0.03959557452696915] |
|1 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13650714786977475, 0.09549302392463813, 0.09473573233391298, 0.05765172919949817, 0.04483412655497465, 0.04437734133645192, 0.043917280973977395, 0.042966507550942196, 0.03978157454461109, 0.039593095881148545] |
|2 |[18, 14, 1, 2, 6, 15, 19, 0, 5, 11]|[0.13628285951099658, 0.0960680469091329, 0.0957882659160997, 0.05753830851677705, 0.04476298485932895, 0.04426297294164529, 0.043901044582340995, 0.04291495373567377, 0.039780501358024425, 0.03946491774918109] |
|3 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13674345908386384, 0.09541103313663375, 0.09430192089972703, 0.05770255869404108, 0.044821207138198024, 0.04441671155466873, 0.043902024784365994, 0.04294093494951478, 0.03981953622791824, 0.039614320304130965] |
|4 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.1364122091134319, 0.0953545543783122, 0.09429906593265366, 0.05772444907583193, 0.044862863343866806, 0.04442734201228477, 0.04389557512474934, 0.04296443267889805, 0.03982659626276644, 0.03962640467091713] |
|5 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13670139028438008, 0.09553267874876908, 0.0946605951853061, 0.05768305061621247, 0.04480375989378822, 0.04435320773212808, 0.043914956101645836, 0.0429208816887896, 0.03981783866996495, 0.0395571526370012] |
|6 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13655841155238757, 0.09554895256001286, 0.0949549658561804, 0.05762248195720487, 0.04480233168639564, 0.04433230827791344, 0.04390710584467472, 0.042930742282966325, 0.039814043132483344, 0.039566695745809] |
|7 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13663380500078914, 0.09550193655114965, 0.09478283807049456, 0.05766262514700366, 0.04480957336655386, 0.044349558779903736, 0.04392217503495675, 0.04293742117414056, 0.03978830895805316, 0.03953804442817271] |
|8 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13657941125100242, 0.09553782026082008, 0.09488224957791296, 0.0576562888735838, 0.04478320710416449, 0.044348589637858433, 0.043920567136734125, 0.04291002130418712, 0.03979768943659999, 0.039567161368513126] |
|9 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13679375890331896, 0.09550217013809542, 0.09456405056150366, 0.057658739173097585, 0.04482551122651224, 0.0443527568707439, 0.04392317149752475, 0.04290757658934661, 0.03979803663250257, 0.03956130101562635] |
|10 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.1363532425412472, 0.09567087909028642, 0.09540557719048867, 0.0576159180949091, 0.044781629751243876, 0.04429913174899199, 0.04391116942509204, 0.042932143462292065, 0.03978668768070522, 0.03951443679027362] |
|11 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13652547151289018, 0.09563212101150799, 0.09559109602844051, 0.057554258392101834, 0.044746350635605996, 0.044288429293621624, 0.04393701569930993, 0.04291245197810508, 0.03976607824588965, 0.039502323331991364]|
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
Then I ran the code below for transforming, which gives my whole input data with a topicDistribution column. I want to know how I can map each row to its corresponding topic from above.
transformed = ldamodel.transform(rescaledData)
transformed.show(1000,truncate=False)
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|features |topicDistribution |
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|(20,[0,1,2,5,6,7,9,10,11,12,14,15,17,18],[1.104871083111864,2.992600400740947,2.6744520065699304,1.3684833461617696,2.2505718610977468,1.667281425641495,0.4887254650789885,1.6871692149804491,1.2395584596083011,1.1654793679713804,1.1856289804763125,0.8839314023462759,1.3072768360965874,1.8484873888968552]) |[0.033335920691559356,0.03333070077028474,0.033333697225675085,0.03333561564482505,0.03333642468043084,0.033329348877568624,0.033332822489008374,0.0333260618235691,0.0333303226504357,0.03333271017792987,0.03332747782859632,0.03332947925144302,0.03332588375760626,0.03333450406000825,0.03333113684458691,0.03333479529731098,0.03333650038104409,0.03333779171890448,0.033330895429371024,0.03333346818374588,0.03333587374669536,0.03333749064143864,0.033334190553966234,0.03333150654339392,0.03333448663832666,0.03333989460983954,0.03333245526240933,0.03333885169495637,0.03333318824494455,0.03333650428012542] |
|(20,[1,3,9,10,11,18],[0.997533466913649,1.2726157969776226,1.4661763952369655,1.6871692149804491,1.2395584596083011,1.2323249259312368]) |[0.03333282248535624,0.033332445609854544,0.033328001364314915,0.03333556680774651,0.03333511480351705,0.03333350870876004,0.03333187395000147,0.03333445718728295,0.03333444971729848,0.03333408945721941,0.033330403594416406,0.03333043254420633,0.03333157466407796,0.033332770049699215,0.033328359180981044,0.03333228442030213,0.03333371386494732,0.03333681533414124,0.0333317544337193,0.0333325923806738,0.03333507835375717,0.0333340926283776,0.033334289011977415,0.03333710235559251,0.033333843031028085,0.03333431818041727,0.03333321848222645,0.03333712581455898,0.03333342489994699,0.033334476683601115] |
|(20,[0,1,2,5,6,7,9,10,11,12,14,15,17,18],[1.104871083111864,2.992600400740947,2.6744520065699304,1.3684833461617696,2.2505718610977468,1.667281425641495,0.4887254650789885,1.6871692149804491,1.2395584596083011,1.1654793679713804,1.1856289804763125,0.8839314023462759,1.3072768360965874,1.8484873888968552]) |[0.03333351262221663,0.03334603652469889,0.03332161942705381,0.03333027102774255,0.03333390930010211,0.03333191764996061,0.033329934469446876,0.033333299309223546,0.033341041908943575,0.0333327792244044,0.0333317683142701,0.03332955596634091,0.033328132222973345,0.033337885062247365,0.03332933802183438,0.033340558825094374,0.03332932381968907,0.033330915221236455,0.03333188695334856,0.033327196223241304,0.0333344223760765,0.03332799716207421,0.033338186672894565,0.033336730538183736,0.03333440038338333,0.033337949318794795,0.033333179017769575,0.03333720003330217,0.0333312634707931,0.033337788932659276] |
|(20,[7,10,14,15,18],[1.667281425641495,1.6871692149804491,1.1856289804763125,0.8839314023462759,1.2323249259312368]) |[0.03333299162768077,0.03333267864965697,0.03333419860453984,0.0333335178997543,0.03333360724598342,0.03333242559458429,0.03333347136144222,0.03333355459499807,0.03333358147457144,0.03333288221459248,0.033333183195193086,0.03333412789444317,0.033334780263427975,0.033333456529848454,0.033336348770946406,0.03333413810342763,0.03333364116227056,0.03333271328980694,0.033333953427323364,0.03333358082109339,0.03333219641721605,0.03333309878777325,0.03333357447436248,0.03333149527336071,0.03333364840046119,0.03333405705432205,0.03333193039244976,0.033332142999523764,0.0333329476480237,0.033332075826922346] |
|(20,[7,10,14,15,18],[1.667281425641495,1.6871692149804491,1.1856289804763125,0.8839314023462759,1.2323249259312368]) |[0.033333498405871936,0.033333280285774286,0.03333426144981626,0.033332860785526656,0.033333582231251914,0.03333268142499894,0.0333323342942777,0.033333434542768936,0.03333306424165755,0.0333326212718143,0.03333359821538673,0.03333408970522017,0.03333440903364261,0.033333628480934824,0.033335988240581156,0.03333411840979216,0.03333341827050789,0.033332367335435646,0.033334058739684515,0.03333355988453005,0.033332524988062315,0.03333411721155432,0.033333323503042835,0.03333212784312619,0.0333335375674124,0.03333359055906317,0.033332183680577915,0.033332671344198254,0.03333288800754154,0.03333218004594688] |
|(20,[2,5,9,16,17,18],[0.8914840021899768,1.3684833461617696,0.4887254650789885,1.63668714555745,2.614553672193175,0.6161624629656184]) |[0.033334335349062126,0.033335078498013884,0.03332587515068904,0.03333663657766274,0.03333680954467168,0.03333423911759284,0.03333228558123402,0.03333307701307248,0.03333230070449398,0.03333357555321742,0.03333013898196302,0.033329558017855816,0.033331357742536066,0.03333281031681836,0.03332620861451679,0.03333164736607878,0.033333411422713344,0.03333811924641809,0.03333032668669555,0.03333155312673253,0.03333601811939769,0.03333261877594549,0.03333490102064452,0.033338814952549214,0.03333297326837463,0.03333518474628696,0.03333659535650031,0.033334882391957275,0.03333262407795885,0.03333604267834657] |
|(20,[9,14,17],[0.4887254650789885,1.1856289804763125,1.3072768360965874]) |[0.03333373548732593,0.03333366047427678,0.0333388921917903,0.0333324500674453,0.03333165231618069,0.033333396263220766,0.03333392152736817,0.03333280784507405,0.03333341964532808,0.033332027450640234,0.03333590751407632,0.03333625709954709,0.033335027917460715,0.033332984179205334,0.0333395676850832,0.03333465555831355,0.03333317988602309,0.03332999573625669,0.033335686662720146,0.03333444862817325,0.033331118794333245,0.03333227594267816,0.03333337815403729,0.03332815318734183,0.03333329306551438,0.033332463398394074,0.03333211180252395,0.03332955750997137,0.03333267853415129,0.033331295475544864] |
|(20,[0,1,2,5,6,8,9,12,14,19],[4.419484332447456,8.97780120222284,0.8914840021899768,1.3684833461617696,1.1252859305488734,1.233681992250468,0.4887254650789885,1.1654793679713804,42.68264329714725,3.685472474556911]) |[0.03333847676223903,0.03333217319316553,0.03361190962478921,0.03322296565909072,0.033224152351515865,0.033305780550557426,0.0333623247307499,0.03333530792027193,0.03335048739584223,0.03328160880463601,0.0334726823486143,0.03350523767287266,0.0334249096308398,0.0333474750843313,0.033626678634369765,0.033399676341965154,0.033312708717161896,0.03316394279048787,0.033455734324711994,0.033388088146836026,0.03321934579333965,0.03331319196962562,0.03332496308901043,0.03308256004388311,0.033355283259340625,0.03328280644324275,0.0332436575330529,0.03319398288438216,0.03331402807784297,0.03320786022123095] |
|(20,[2,8,10,13,15,16],[0.8914840021899768,3.701045976751404,1.6871692149804491,1.4594938792746623,0.8839314023462759,1.63668714555745]) |[0.0333356286958193,0.033332968808115676,0.03332196061288262,0.03333752852282297,0.03334171827141663,0.033333866506025614,0.03332902614535797,0.03332999721261129,0.03333228759033127,0.03333419089119785,0.033326295405433394,0.03332727748411373,0.03332970492260472,0.03333218324896459,0.033323419320752445,0.03333484772134479,0.033331408271507906,0.03333886892339302,0.033330223503670105,0.03333357612919576,0.033337166312023804,0.03333359583540509,0.033335058229509947,0.03334147190850166,0.03333378789809434,0.033336369234199595,0.033338417163238966,0.03333500103939193,0.03333456204893585,0.03333759214313734] |
|(20,[0,1,5,9,11,14,18],[2.209742166223728,1.995066933827298,1.3684833461617696,0.4887254650789885,1.2395584596083011,1.1856289804763125,4.313137240759328]) |[0.033331568372787335,0.03333199604802443,0.03333616626240888,0.03333311434679797,0.03332871558684149,0.03333377685027234,0.03333412369772984,0.03333326223955321,0.03333302350234804,0.033334011390192264,0.033332998734646985,0.03333572338531835,0.03333519102733107,0.033333669722909055,0.03333367298003543,0.03333316052548548,0.03333434503087792,0.03333204101430685,0.03333390212404242,0.03333317825522851,0.03333468470990933,0.0333342373715466,0.03333384912615554,0.03333262239313073,0.033332130871257,0.03333157173270879,0.033331620213171494,0.03333491236454045,0.03333364390007425,0.03333308622036798] |
|(20,[0,1,2,7,9,12,14,16,17],[1.104871083111864,1.995066933827298,1.7829680043799536,1.667281425641495,0.4887254650789885,3.4964381039141412,1.1856289804763125,1.63668714555745,1.3072768360965874]) |[0.033335136368759725,0.033335270064671074,0.033333243143688636,0.03333135191450168,0.03333521337181425,0.03333250351809707,0.03333535436589356,0.03333322757251989,0.033333747503130604,0.03333235416690325,0.03333273826809218,0.03333297626192852,0.033333297925505816,0.03333285267847805,0.033334741557111344,0.03333582976624498,0.03333192986853772,0.0333314389004771,0.03333186692894785,0.033332539165631093,0.03333280721245933,0.03333326073074502,0.03333359623114709,0.03333261337071022,0.03333404937430939,0.03333536871228382,0.033334559939835924,0.03333095099987788,0.0333326510660671,0.033332529051629894] |
|(20,[2,5],[0.8914840021899768,1.3684833461617696]) |[0.033333008832987246,0.03333320773569408,0.03333212285409669,0.03333370183210235,0.03333408881228382,0.03333369017540667,0.03333315896362525,0.03333316321736793,0.033333273291611384,0.033332955144148683,0.03333299410668819,0.03333254807040678,0.03333302471306508,0.03333324003641127,0.033332032586187325,0.0333328866035473,0.033333366369446636,0.03333432146865401,0.033333080986559745,0.033332919796171534,0.03333365180032741,0.03333319727206391,0.0333336947744764,0.033334184463059975,0.033333428924387516,0.033333998899679654,0.033333803141837516,0.03333374053474702,0.03333339151988759,0.03333412307307108] |
In order to remap the termIndices to words, you have to access the vocabulary of the CountVectorizer model. Please have a look at the pseudocode below:
from pyspark.ml.feature import CountVectorizer
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Reading your data...
cv = CountVectorizer(inputCol="some", outputCol="features", vocabSize=2000)
cvmodel = cv.fit(df)
# other stuff

vocab = cvmodel.vocabulary
vocab_broadcast = sc.broadcast(vocab)

# creating the LDA model
ldatopics = ldamodel.describeTopics()

def map_termID_to_Word(termIndices):
    words = []
    for termID in termIndices:
        words.append(vocab_broadcast.value[termID])
    return words

udf_map_termID_to_Word = udf(map_termID_to_Word, ArrayType(StringType()))
ldatopics_mapped = ldatopics.withColumn("topic_desc", udf_map_termID_to_Word(ldatopics.termIndices))
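You can then inspect the mapped topics, for example (a usage sketch):
# Show each topic's top words instead of raw term indices
ldatopics_mapped.select("topic", "topic_desc").show(truncate=False)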

LDA consistency in Pyspark

I am using pyspark (version 2.3.1) and I am trying to reproduce the same results given the following code:
lda = LDA(k=10, seed=5, optimizer="em", featuresCol="features")
ldamodel = lda.fit(rescaledData)
ldatopics = ldamodel.describeTopics()
ldatopics.show(10)
Output 1:
+-----+--------------------+--------------------+
|topic| termIndices| termWeights|
+-----+--------------------+--------------------+
| 0|[0, 199, 2, 35, 1...|[0.02179604286102...|
| 1|[267, 142, 76, 50...|[0.01640698273265...|
| 2|[14, 6, 12, 29, 7...|[0.01542644578135...|
| 3|[279, 193, 21, 74...|[0.01304181652577...|
| 4|[12, 70, 252, 151...|[0.01104580800704...|
| 5|[9, 75, 474, 255,...|[0.01606660426132...|
| 6|[13, 4, 88, 3, 27...|[0.02825736583107...|
| 7|[42, 146, 26, 700...|[0.01156411695149...|
| 8|[89, 2, 82, 403, ...|[0.01666772169015...|
| 9|[1, 303, 411, 83,...|[0.02547416776649...|
+-----+--------------------+--------------------+
Even though I use the seed, every time I restart the application (close and re-open the notebook) I get different results. Look at the second output:
+-----+--------------------+--------------------+
|topic| termIndices| termWeights|
+-----+--------------------+--------------------+
| 0|[403, 199, 414, 1...|[0.01236421045802...|
| 1|[75, 109, 251, 5,...|[0.01551907510059...|
| 2|[12, 188, 6, 314,...|[0.01206780033644...|
| 3|[91, 76, 23, 82, ...|[0.01244511461388...|
| 4|[162, 127, 12, 14...|[0.01380643020451...|
| 5|[4, 46, 7, 220, 2...|[0.01591219626409...|
| 6|[89, 71, 272, 279...|[0.02027028435250...|
| 7|[1, 3, 13, 57, 27...|[0.02192425215634...|
| 8|[2, 0, 35, 87, 65...|[0.02033711369900...|
| 9|[194, 15, 37, 42,...|[0.01436615776405...|
+-----+--------------------+--------------------+
Notice that I have the same problem in the .transform phase (even using the seed). The code used was the following:
paramMap = {ldamodel.seed: 5}
ldaResults = ldamodel.transform(rescaledData, params=paramMap)
Do you have any hint to help me?
Many thanks,
Lorenzo
