How to get the topic using pyspark LDA - apache-spark

I have used LDA to find the topics:
from pyspark.ml.clustering import LDA
lda = LDA(k=30, seed=123, optimizer="em", maxIter=10, featuresCol="features")
ldamodel = lda.fit(rescaledData)
When I run the code below, I get a result with topic, termIndices, and termWeights:
ldatopics = ldamodel.describeTopics()
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|topic|termIndices |termWeights |
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13643518321703474, 0.09550636070102715, 0.0948398694152729, 0.05766480874922468, 0.04482014536392716, 0.04435413761288223, 0.04390248330822808, 0.042949351875241376, 0.039792489008412854, 0.03959557452696915] |
|1 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13650714786977475, 0.09549302392463813, 0.09473573233391298, 0.05765172919949817, 0.04483412655497465, 0.04437734133645192, 0.043917280973977395, 0.042966507550942196, 0.03978157454461109, 0.039593095881148545] |
|2 |[18, 14, 1, 2, 6, 15, 19, 0, 5, 11]|[0.13628285951099658, 0.0960680469091329, 0.0957882659160997, 0.05753830851677705, 0.04476298485932895, 0.04426297294164529, 0.043901044582340995, 0.04291495373567377, 0.039780501358024425, 0.03946491774918109] |
|3 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13674345908386384, 0.09541103313663375, 0.09430192089972703, 0.05770255869404108, 0.044821207138198024, 0.04441671155466873, 0.043902024784365994, 0.04294093494951478, 0.03981953622791824, 0.039614320304130965] |
|4 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.1364122091134319, 0.0953545543783122, 0.09429906593265366, 0.05772444907583193, 0.044862863343866806, 0.04442734201228477, 0.04389557512474934, 0.04296443267889805, 0.03982659626276644, 0.03962640467091713] |
|5 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13670139028438008, 0.09553267874876908, 0.0946605951853061, 0.05768305061621247, 0.04480375989378822, 0.04435320773212808, 0.043914956101645836, 0.0429208816887896, 0.03981783866996495, 0.0395571526370012] |
|6 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13655841155238757, 0.09554895256001286, 0.0949549658561804, 0.05762248195720487, 0.04480233168639564, 0.04433230827791344, 0.04390710584467472, 0.042930742282966325, 0.039814043132483344, 0.039566695745809] |
|7 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13663380500078914, 0.09550193655114965, 0.09478283807049456, 0.05766262514700366, 0.04480957336655386, 0.044349558779903736, 0.04392217503495675, 0.04293742117414056, 0.03978830895805316, 0.03953804442817271] |
|8 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13657941125100242, 0.09553782026082008, 0.09488224957791296, 0.0576562888735838, 0.04478320710416449, 0.044348589637858433, 0.043920567136734125, 0.04291002130418712, 0.03979768943659999, 0.039567161368513126] |
|9 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13679375890331896, 0.09550217013809542, 0.09456405056150366, 0.057658739173097585, 0.04482551122651224, 0.0443527568707439, 0.04392317149752475, 0.04290757658934661, 0.03979803663250257, 0.03956130101562635] |
|10 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.1363532425412472, 0.09567087909028642, 0.09540557719048867, 0.0576159180949091, 0.044781629751243876, 0.04429913174899199, 0.04391116942509204, 0.042932143462292065, 0.03978668768070522, 0.03951443679027362] |
|11 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13652547151289018, 0.09563212101150799, 0.09559109602844051, 0.057554258392101834, 0.044746350635605996, 0.044288429293621624, 0.04393701569930993, 0.04291245197810508, 0.03976607824588965, 0.039502323331991364]|
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
Then I ran the code below for the transform, which returns my whole input data with a "topicDistribution" column. I want to know how I can map each row to its corresponding topic from above.
transformed = ldamodel.transform(rescaledData)
transformed.show(1000,truncate=False)
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|features |topicDistribution |
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|(20,[0,1,2,5,6,7,9,10,11,12,14,15,17,18],[1.104871083111864,2.992600400740947,2.6744520065699304,1.3684833461617696,2.2505718610977468,1.667281425641495,0.4887254650789885,1.6871692149804491,1.2395584596083011,1.1654793679713804,1.1856289804763125,0.8839314023462759,1.3072768360965874,1.8484873888968552]) |[0.033335920691559356,0.03333070077028474,0.033333697225675085,0.03333561564482505,0.03333642468043084,0.033329348877568624,0.033332822489008374,0.0333260618235691,0.0333303226504357,0.03333271017792987,0.03332747782859632,0.03332947925144302,0.03332588375760626,0.03333450406000825,0.03333113684458691,0.03333479529731098,0.03333650038104409,0.03333779171890448,0.033330895429371024,0.03333346818374588,0.03333587374669536,0.03333749064143864,0.033334190553966234,0.03333150654339392,0.03333448663832666,0.03333989460983954,0.03333245526240933,0.03333885169495637,0.03333318824494455,0.03333650428012542] |
|(20,[1,3,9,10,11,18],[0.997533466913649,1.2726157969776226,1.4661763952369655,1.6871692149804491,1.2395584596083011,1.2323249259312368]) |[0.03333282248535624,0.033332445609854544,0.033328001364314915,0.03333556680774651,0.03333511480351705,0.03333350870876004,0.03333187395000147,0.03333445718728295,0.03333444971729848,0.03333408945721941,0.033330403594416406,0.03333043254420633,0.03333157466407796,0.033332770049699215,0.033328359180981044,0.03333228442030213,0.03333371386494732,0.03333681533414124,0.0333317544337193,0.0333325923806738,0.03333507835375717,0.0333340926283776,0.033334289011977415,0.03333710235559251,0.033333843031028085,0.03333431818041727,0.03333321848222645,0.03333712581455898,0.03333342489994699,0.033334476683601115] |
|(20,[0,1,2,5,6,7,9,10,11,12,14,15,17,18],[1.104871083111864,2.992600400740947,2.6744520065699304,1.3684833461617696,2.2505718610977468,1.667281425641495,0.4887254650789885,1.6871692149804491,1.2395584596083011,1.1654793679713804,1.1856289804763125,0.8839314023462759,1.3072768360965874,1.8484873888968552]) |[0.03333351262221663,0.03334603652469889,0.03332161942705381,0.03333027102774255,0.03333390930010211,0.03333191764996061,0.033329934469446876,0.033333299309223546,0.033341041908943575,0.0333327792244044,0.0333317683142701,0.03332955596634091,0.033328132222973345,0.033337885062247365,0.03332933802183438,0.033340558825094374,0.03332932381968907,0.033330915221236455,0.03333188695334856,0.033327196223241304,0.0333344223760765,0.03332799716207421,0.033338186672894565,0.033336730538183736,0.03333440038338333,0.033337949318794795,0.033333179017769575,0.03333720003330217,0.0333312634707931,0.033337788932659276] |
|(20,[7,10,14,15,18],[1.667281425641495,1.6871692149804491,1.1856289804763125,0.8839314023462759,1.2323249259312368]) |[0.03333299162768077,0.03333267864965697,0.03333419860453984,0.0333335178997543,0.03333360724598342,0.03333242559458429,0.03333347136144222,0.03333355459499807,0.03333358147457144,0.03333288221459248,0.033333183195193086,0.03333412789444317,0.033334780263427975,0.033333456529848454,0.033336348770946406,0.03333413810342763,0.03333364116227056,0.03333271328980694,0.033333953427323364,0.03333358082109339,0.03333219641721605,0.03333309878777325,0.03333357447436248,0.03333149527336071,0.03333364840046119,0.03333405705432205,0.03333193039244976,0.033332142999523764,0.0333329476480237,0.033332075826922346] |
|(20,[7,10,14,15,18],[1.667281425641495,1.6871692149804491,1.1856289804763125,0.8839314023462759,1.2323249259312368]) |[0.033333498405871936,0.033333280285774286,0.03333426144981626,0.033332860785526656,0.033333582231251914,0.03333268142499894,0.0333323342942777,0.033333434542768936,0.03333306424165755,0.0333326212718143,0.03333359821538673,0.03333408970522017,0.03333440903364261,0.033333628480934824,0.033335988240581156,0.03333411840979216,0.03333341827050789,0.033332367335435646,0.033334058739684515,0.03333355988453005,0.033332524988062315,0.03333411721155432,0.033333323503042835,0.03333212784312619,0.0333335375674124,0.03333359055906317,0.033332183680577915,0.033332671344198254,0.03333288800754154,0.03333218004594688] |
|(20,[2,5,9,16,17,18],[0.8914840021899768,1.3684833461617696,0.4887254650789885,1.63668714555745,2.614553672193175,0.6161624629656184]) |[0.033334335349062126,0.033335078498013884,0.03332587515068904,0.03333663657766274,0.03333680954467168,0.03333423911759284,0.03333228558123402,0.03333307701307248,0.03333230070449398,0.03333357555321742,0.03333013898196302,0.033329558017855816,0.033331357742536066,0.03333281031681836,0.03332620861451679,0.03333164736607878,0.033333411422713344,0.03333811924641809,0.03333032668669555,0.03333155312673253,0.03333601811939769,0.03333261877594549,0.03333490102064452,0.033338814952549214,0.03333297326837463,0.03333518474628696,0.03333659535650031,0.033334882391957275,0.03333262407795885,0.03333604267834657] |
|(20,[9,14,17],[0.4887254650789885,1.1856289804763125,1.3072768360965874]) |[0.03333373548732593,0.03333366047427678,0.0333388921917903,0.0333324500674453,0.03333165231618069,0.033333396263220766,0.03333392152736817,0.03333280784507405,0.03333341964532808,0.033332027450640234,0.03333590751407632,0.03333625709954709,0.033335027917460715,0.033332984179205334,0.0333395676850832,0.03333465555831355,0.03333317988602309,0.03332999573625669,0.033335686662720146,0.03333444862817325,0.033331118794333245,0.03333227594267816,0.03333337815403729,0.03332815318734183,0.03333329306551438,0.033332463398394074,0.03333211180252395,0.03332955750997137,0.03333267853415129,0.033331295475544864] |
|(20,[0,1,2,5,6,8,9,12,14,19],[4.419484332447456,8.97780120222284,0.8914840021899768,1.3684833461617696,1.1252859305488734,1.233681992250468,0.4887254650789885,1.1654793679713804,42.68264329714725,3.685472474556911]) |[0.03333847676223903,0.03333217319316553,0.03361190962478921,0.03322296565909072,0.033224152351515865,0.033305780550557426,0.0333623247307499,0.03333530792027193,0.03335048739584223,0.03328160880463601,0.0334726823486143,0.03350523767287266,0.0334249096308398,0.0333474750843313,0.033626678634369765,0.033399676341965154,0.033312708717161896,0.03316394279048787,0.033455734324711994,0.033388088146836026,0.03321934579333965,0.03331319196962562,0.03332496308901043,0.03308256004388311,0.033355283259340625,0.03328280644324275,0.0332436575330529,0.03319398288438216,0.03331402807784297,0.03320786022123095] |
|(20,[2,8,10,13,15,16],[0.8914840021899768,3.701045976751404,1.6871692149804491,1.4594938792746623,0.8839314023462759,1.63668714555745]) |[0.0333356286958193,0.033332968808115676,0.03332196061288262,0.03333752852282297,0.03334171827141663,0.033333866506025614,0.03332902614535797,0.03332999721261129,0.03333228759033127,0.03333419089119785,0.033326295405433394,0.03332727748411373,0.03332970492260472,0.03333218324896459,0.033323419320752445,0.03333484772134479,0.033331408271507906,0.03333886892339302,0.033330223503670105,0.03333357612919576,0.033337166312023804,0.03333359583540509,0.033335058229509947,0.03334147190850166,0.03333378789809434,0.033336369234199595,0.033338417163238966,0.03333500103939193,0.03333456204893585,0.03333759214313734] |
|(20,[0,1,5,9,11,14,18],[2.209742166223728,1.995066933827298,1.3684833461617696,0.4887254650789885,1.2395584596083011,1.1856289804763125,4.313137240759328]) |[0.033331568372787335,0.03333199604802443,0.03333616626240888,0.03333311434679797,0.03332871558684149,0.03333377685027234,0.03333412369772984,0.03333326223955321,0.03333302350234804,0.033334011390192264,0.033332998734646985,0.03333572338531835,0.03333519102733107,0.033333669722909055,0.03333367298003543,0.03333316052548548,0.03333434503087792,0.03333204101430685,0.03333390212404242,0.03333317825522851,0.03333468470990933,0.0333342373715466,0.03333384912615554,0.03333262239313073,0.033332130871257,0.03333157173270879,0.033331620213171494,0.03333491236454045,0.03333364390007425,0.03333308622036798] |
|(20,[0,1,2,7,9,12,14,16,17],[1.104871083111864,1.995066933827298,1.7829680043799536,1.667281425641495,0.4887254650789885,3.4964381039141412,1.1856289804763125,1.63668714555745,1.3072768360965874]) |[0.033335136368759725,0.033335270064671074,0.033333243143688636,0.03333135191450168,0.03333521337181425,0.03333250351809707,0.03333535436589356,0.03333322757251989,0.033333747503130604,0.03333235416690325,0.03333273826809218,0.03333297626192852,0.033333297925505816,0.03333285267847805,0.033334741557111344,0.03333582976624498,0.03333192986853772,0.0333314389004771,0.03333186692894785,0.033332539165631093,0.03333280721245933,0.03333326073074502,0.03333359623114709,0.03333261337071022,0.03333404937430939,0.03333536871228382,0.033334559939835924,0.03333095099987788,0.0333326510660671,0.033332529051629894] |
|(20,[2,5],[0.8914840021899768,1.3684833461617696]) |[0.033333008832987246,0.03333320773569408,0.03333212285409669,0.03333370183210235,0.03333408881228382,0.03333369017540667,0.03333315896362525,0.03333316321736793,0.033333273291611384,0.033332955144148683,0.03333299410668819,0.03333254807040678,0.03333302471306508,0.03333324003641127,0.033332032586187325,0.0333328866035473,0.033333366369446636,0.03333432146865401,0.033333080986559745,0.033332919796171534,0.03333365180032741,0.03333319727206391,0.0333336947744764,0.033334184463059975,0.033333428924387516,0.033333998899679654,0.033333803141837516,0.03333374053474702,0.03333339151988759,0.03333412307307108] |

In order to remap the termIndices to words, you have to access the vocabulary of the CountVectorizer model. Please have a look at the pseudocode below:
from pyspark.ml.feature import CountVectorizer
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Reading your data...
cv = CountVectorizer(inputCol="some", outputCol="features", vocabSize=2000)
cvmodel = cv.fit(df)
# other stuff

# Keep the fitted vocabulary available on the executors
vocab = cvmodel.vocabulary
vocab_broadcast = sc.broadcast(vocab)

# creating LDA model
ldatopics = ldamodel.describeTopics()

# Translate each term index into the corresponding word from the vocabulary
def map_termID_to_Word(termIndices):
    words = []
    for termID in termIndices:
        words.append(vocab_broadcast.value[termID])
    return words

udf_map_termID_to_Word = udf(map_termID_to_Word, ArrayType(StringType()))
ldatopics_mapped = ldatopics.withColumn("topic_desc", udf_map_termID_to_Word(ldatopics.termIndices))
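To relate the topicDistribution column from the transform back to these topics, one option is to take the index of the largest weight in each row's topicDistribution vector as that row's dominant topic, and then join it against the mapped describeTopics output. A minimal sketch, assuming the transformed and ldatopics_mapped DataFrames from above (argmax_topic is a hypothetical helper name):

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# Hypothetical helper: index of the largest topic weight in a row's distribution
def argmax_topic(topicDistribution):
    weights = topicDistribution.toArray().tolist()
    return int(weights.index(max(weights)))

udf_argmax_topic = udf(argmax_topic, IntegerType())

# Tag each document with its dominant topic...
docs_with_topic = transformed.withColumn(
    "dominant_topic", udf_argmax_topic(transformed.topicDistribution))

# ...and join on the human-readable topic descriptions built above
docs_with_topic.join(
    ldatopics_mapped,
    docs_with_topic.dominant_topic == ldatopics_mapped.topic,
    "left").show()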

Related

Select nth row after orderby in pyspark dataframe

I want to select the second row for each group of names. I used orderBy to sort by name and then by the purchase date/timestamp. It is important that I select the second purchase for each name (by datetime).
Here is the data to build dataframe:
data = [
('George', datetime(2020, 3, 24, 3, 19, 58), datetime(2018, 2, 24, 3, 22, 55)),
('Andrew', datetime(2019, 12, 12, 17, 21, 30), datetime(2019, 7, 21, 2, 14, 22)),
('Micheal', datetime(2018, 11, 22, 13, 29, 40), datetime(2018, 5, 17, 8, 10, 19)),
('Maggie', datetime(2019, 2, 8, 3, 31, 23), datetime(2019, 5, 19, 6, 11, 33)),
('Ravi', datetime(2019, 1, 1, 4, 19, 47), datetime(2019, 1, 1, 4, 22, 55)),
('Xien', datetime(2020, 3, 2, 4, 33, 51), datetime(2020, 5, 21, 7, 11, 50)),
('George', datetime(2020, 3, 24, 3, 19, 58), datetime(2020, 3, 24, 3, 22, 45)),
('Andrew', datetime(2019, 12, 12, 17, 21, 30), datetime(2019, 9, 19, 1, 14, 11)),
('Micheal', datetime(2018, 11, 22, 13, 29, 40), datetime(2018, 8, 19, 7, 11, 37)),
('Maggie', datetime(2019, 2, 8, 3, 31, 23), datetime(2018, 2, 19, 6, 11, 42)),
('Ravi', datetime(2019, 1, 1, 4, 19, 47), datetime(2019, 1, 1, 4, 22, 17)),
('Xien', datetime(2020, 3, 2, 4, 33, 51), datetime(2020, 6, 21, 7, 11, 11)),
('George', datetime(2020, 3, 24, 3, 19, 58), datetime(2020, 4, 24, 3, 22, 54)),
('Andrew', datetime(2019, 12, 12, 17, 21, 30), datetime(2019, 8, 30, 3, 12, 41)),
('Micheal', datetime(2018, 11, 22, 13, 29, 40), datetime(2017, 5, 17, 8, 10, 38)),
('Maggie', datetime(2019, 2, 8, 3, 31, 23), datetime(2020, 3, 19, 6, 11, 12)),
('Ravi', datetime(2019, 1, 1, 4, 19, 47), datetime(2018, 2, 1, 4, 22, 24)),
('Xien', datetime(2020, 3, 2, 4, 33, 51), datetime(2018, 9, 21, 7, 11, 41)),
]
df = sqlContext.createDataFrame(data, ['name', 'trial_start', 'purchase'])
df.show(truncate=False)
I order the data by name and then by purchase:
df.orderBy("name","purchase").show()
to produce the result:
+-------+-------------------+-------------------+
| name| trial_start| purchase|
+-------+-------------------+-------------------+
| Andrew|2019-12-12 22:21:30|2019-07-21 06:14:22|
| Andrew|2019-12-12 22:21:30|2019-08-30 07:12:41|
| Andrew|2019-12-12 22:21:30|2019-09-19 05:14:11|
| George|2020-03-24 07:19:58|2018-02-24 08:22:55|
| George|2020-03-24 07:19:58|2020-03-24 07:22:45|
| George|2020-03-24 07:19:58|2020-04-24 07:22:54|
| Maggie|2019-02-08 08:31:23|2018-02-19 11:11:42|
| Maggie|2019-02-08 08:31:23|2019-05-19 10:11:33|
| Maggie|2019-02-08 08:31:23|2020-03-19 10:11:12|
|Micheal|2018-11-22 18:29:40|2017-05-17 12:10:38|
|Micheal|2018-11-22 18:29:40|2018-05-17 12:10:19|
|Micheal|2018-11-22 18:29:40|2018-08-19 11:11:37|
| Ravi|2019-01-01 09:19:47|2018-02-01 09:22:24|
| Ravi|2019-01-01 09:19:47|2019-01-01 09:22:17|
| Ravi|2019-01-01 09:19:47|2019-01-01 09:22:55|
| Xien|2020-03-02 09:33:51|2018-09-21 11:11:41|
| Xien|2020-03-02 09:33:51|2020-05-21 11:11:50|
| Xien|2020-03-02 09:33:51|2020-06-21 11:11:11|
+-------+-------------------+-------------------+
How might I get the second row for each name? In pandas it was easy; I could just use nth. I have been looking at SQL but have not found a solution. Any suggestions appreciated.
The output I am looking for would be:
+-------+-------------------+-------------------+
| name| trial_start| purchase|
+-------+-------------------+-------------------+
| Andrew|2019-12-12 22:21:30|2019-08-30 07:12:41|
| George|2020-03-24 07:19:58|2020-03-24 07:22:45|
| Maggie|2019-02-08 08:31:23|2019-05-19 10:11:33|
|Micheal|2018-11-22 18:29:40|2018-05-17 12:10:19|
| Ravi|2019-01-01 09:19:47|2019-01-01 09:22:17|
| Xien|2020-03-02 09:33:51|2020-05-21 11:11:50|
+-------+-------------------+-------------------+
Use the window row_number() function, then filter to only the second row after ordering by purchase.
Example:
from pyspark.sql import *
from pyspark.sql.functions import *

w = Window.partitionBy("name").orderBy(col("purchase"))
df.withColumn("rn", row_number().over(w)).filter(col("rn") == 2).drop("rn").show()
SQL API:
df.createOrReplaceTempView("tmp")
spark.sql("SET spark.sql.parser.quotedRegexColumnNames=true")
spark.sql("select `(rn)?+.+` from (select *, row_number() over(partition by name order by purchase) rn from tmp) e where rn = 2").\
    show()
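The same pattern generalizes to any nth row per group; a small sketch (nth_purchase is a helper name introduced here, not part of the original answer):

from pyspark.sql import Window
from pyspark.sql.functions import col, row_number

def nth_purchase(df, n):
    # Rank purchases within each name, then keep only the nth one
    w = Window.partitionBy("name").orderBy(col("purchase"))
    return df.withColumn("rn", row_number().over(w)).filter(col("rn") == n).drop("rn")

nth_purchase(df, 2).show(truncate=False)  # second purchase per name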

Swap pair of elements along an axis

I have a 2d numpy array as such:
import numpy as np
a = np.arange(20).reshape((2,10))
# array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
I want to swap pairs of elements in each row. The desired output looks like this:
# array([[ 9, 0, 2, 1, 4, 3, 6, 5, 8, 7],
# [19, 10, 12, 11, 14, 13, 16, 15, 18, 17]])
I managed to find a solution in 1d:
a = np.arange(10)
# does the job for all pairs except the first
output = np.roll(np.flip(np.roll(a,-1).reshape((-1,2)),1).flatten(),2)
# first pair done manually
output[0] = a[-1]
output[1] = a[0]
Any ideas on a "numpy only" solution for the 2d case?
Since the first pair doesn't follow the usual pair swap, we can handle it separately. The rest is relatively straightforward: reshape to split the axis into pairs, then flip the pair axis. Hence, it would be -
In [42]: a # 2D input array
Out[42]:
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
In [43]: b2 = a[:,1:-1].reshape(a.shape[0],-1,2)[...,::-1].reshape(a.shape[0],-1)
In [44]: np.hstack((a[:,[-1,0]],b2))
Out[44]:
array([[ 9, 0, 2, 1, 4, 3, 6, 5, 8, 7],
[19, 10, 12, 11, 14, 13, 16, 15, 18, 17]])
Alternatively, stack and then reshape+flip-axis -
In [50]: a1 = np.hstack((a[:,[0,-1]],a[:,1:-1]))
In [51]: a1.reshape(a.shape[0],-1,2)[...,::-1].reshape(a.shape[0],-1)
Out[51]:
array([[ 9, 0, 2, 1, 4, 3, 6, 5, 8, 7],
[19, 10, 12, 11, 14, 13, 16, 15, 18, 17]])
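For reuse, the second approach can be wrapped in a small function; a sketch (swap_pairs is a name introduced here):

import numpy as np

def swap_pairs(a):
    """Swap adjacent pairs per row, treating (last, first) as the first pair."""
    # Move the last column next to the first, keep the middle columns as-is...
    a1 = np.hstack((a[:, [0, -1]], a[:, 1:-1]))
    # ...then flip each consecutive pair of columns
    return a1.reshape(a.shape[0], -1, 2)[..., ::-1].reshape(a.shape[0], -1)

a = np.arange(20).reshape((2, 10))
print(swap_pairs(a))
# [[ 9  0  2  1  4  3  6  5  8  7]
#  [19 10 12 11 14 13 16 15 18 17]]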

How to limit functions.collect_set in Spark SQL?

I'm dealing with a column of numbers in a large spark DataFrame, and I would like to create a new column that stores an aggregated list of unique numbers that appear in that column.
Basically exactly what functions.collect_set does. However, I only need up to 1000 elements in the aggregated list. Is there any way to pass that parameter to functions.collect_set(), or any other way to get only up to 1000 elements in the aggregated list, without using a UDAF?
Since the column is so large, I'd like to avoid collecting all elements and trimming the list afterwards.
Thanks!
Spark 2.4
As pointed out in a comment, Spark 2.4.0 comes with the slice standard function, which can do this sort of thing.
val usage = sql("describe function slice").as[String].collect()(2)
scala> println(usage)
Usage: slice(x, start, length) - Subsets array x starting from index start (array indices start at 1, or starting from the end if start is negative) with the specified length.
That gives the following query:
val q = input
.groupBy('key)
.agg(collect_set('id) as "collect")
.withColumn("three_only", slice('collect, 1, 3))
scala> q.show(truncate = false)
+---+--------------------------------------+------------+
|key|collect |three_only |
+---+--------------------------------------+------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+
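For completeness, the same Spark 2.4+ idea in PySpark, which is what the rest of this page uses; a sketch assuming a DataFrame named input_df with the same key and id columns as the Scala example:

from pyspark.sql import functions as F

# Collect the unique ids per key, then keep at most the first 1000 of them
q = (input_df
     .groupBy("key")
     .agg(F.collect_set("id").alias("collect"))
     .withColumn("up_to_1000", F.slice("collect", 1, 1000)))
q.show(truncate=False)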
Before Spark 2.4
I'd use a UDF that would do what you want after collect_set (or collect_list) or a much harder UDAF.
Given more experience with UDFs, I'd go with that first. Even though UDFs are not optimized, for this use case it's fine.
val limitUDF = udf { (nums: Seq[Long], limit: Int) => nums.take(limit) }
val sample = spark.range(50).withColumn("key", $"id" % 5)
scala> sample.groupBy("key").agg(collect_set("id") as "all").show(false)
+---+--------------------------------------+
|key|all |
+---+--------------------------------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|
+---+--------------------------------------+
scala> sample.
groupBy("key").
agg(collect_set("id") as "all").
withColumn("limit(3)", limitUDF($"all", lit(3))).
show(false)
+---+--------------------------------------+------------+
|key|all |limit(3) |
+---+--------------------------------------+------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+
See functions object (for udf function's docs).
I'm using a modified copy of the collect_set and collect_list functions; because of code scopes, the modified copies must be in the same package path as the originals. The linked code works for Spark 2.1.0; if you are using a prior version, method signatures may be different.
Throw this file (https://gist.github.com/lokkju/06323e88746c85b2ce4de3ea9cdef9bc) into your project as src/main/org/apache/spark/sql/catalyst/expression/collect_limit.scala
Use it as:
import org.apache.spark.sql.catalyst.expression.collect_limit._
df.groupBy('set_col).agg(collect_set_limit('set_col, 1000))
scala> df.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 10| Name|2016| Country|
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.groupBy("C1").agg(sum("C0"))
res36: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res36.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
| Name| 30|
+-----+-------+
scala> df.limit(2).groupBy("C1").agg(sum("C0"))
res33: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res33.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
| Name| 10|
|Name1| 11|
+-----+-------+
scala> df.groupBy("C1").agg(sum("C0")).limit(2)
res2: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res2.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
+-----+-------+
scala> df.distinct
res8: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res8.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.dropDuplicates(Array("c1"))
res11: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res11.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 12|Name2|2017|Country2|
| 10| Name|2016| Country|
+---+-----+----+--------+
As other answers have mentioned, the performant way of doing this would be to write a UDAF. Unfortunately, the UDAF API is not as extensible as the aggregate functions that ship with Spark. However, you can use its internal APIs to build on the internal functions and do what you need.
Here is an implementation of collect_set_limit that is mostly a copy-paste of Spark's internal CollectSet AggregateFunction. I would just extend it, but it's a case class. Really, all that's needed is to override the update and merge methods to respect a passed-in limit:
case class CollectSetLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {

  val limit = limitExp.eval(null).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty

  override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = {
    if (buffer.size < limit) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.HashSet[Any], other: mutable.HashSet[Any]): mutable.HashSet[Any] = {
    if (buffer.size >= limit) buffer
    else buffer ++= other.take(limit - buffer.size)
  }

  override def prettyName: String = "collect_set_limit"
}
To actually register it, we can go through Spark's internal FunctionRegistry, which takes the function name and a builder, effectively a function that creates a CollectSetLimit from the provided expressions:
val collectSetBuilder = (args: Seq[Expression]) => CollectSetLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_set_limit", collectSetBuilder )
Edit:
It turns out that adding it to the builtin registry only works if you haven't created the SparkContext yet, as an immutable clone of the registry is made on startup. If you have an existing context, then this should work to add it via reflection:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_set_limit", collectSetBuilder )
Use take:
val firstThousand = rdd.take(1000)
This will return the first 1000 elements. collect can also be given a filter function, which would allow you to be more specific about what is returned.

Strange Behaviour when Updating Cassandra row

I am using pyspark and pyspark-cassandra.
I have noticed this behaviour on multiple versions of Cassandra (3.0.x and 3.6.x) using COPY, sstableloader, and now saveToCassandra in pyspark.
I have the following schema
CREATE TABLE test (
id int,
time timestamp,
a int,
b int,
c int,
PRIMARY KEY ((id), time)
) WITH CLUSTERING ORDER BY (time DESC);
and the following data
(1, datetime.datetime(2015, 3, 1, 0, 18, 18, tzinfo=<UTC>), 1, 0, 0)
(1, datetime.datetime(2015, 3, 1, 0, 19, 12, tzinfo=<UTC>), 0, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 22, 59, tzinfo=<UTC>), 1, 0, 0)
(1, datetime.datetime(2015, 3, 1, 0, 23, 52, tzinfo=<UTC>), 0, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 32, 2, tzinfo=<UTC>), 1, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 32, 8, tzinfo=<UTC>), 0, 2, 0)
(1, datetime.datetime(2015, 3, 1, 0, 43, 30, tzinfo=<UTC>), 1, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 44, 12, tzinfo=<UTC>), 0, 2, 0)
(1, datetime.datetime(2015, 3, 1, 0, 48, 49, tzinfo=<UTC>), 1, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 49, 7, tzinfo=<UTC>), 0, 2, 0)
(1, datetime.datetime(2015, 3, 1, 0, 50, 5, tzinfo=<UTC>), 1, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 50, 53, tzinfo=<UTC>), 0, 2, 0)
(1, datetime.datetime(2015, 3, 1, 0, 51, 53, tzinfo=<UTC>), 1, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 51, 59, tzinfo=<UTC>), 0, 2, 0)
(1, datetime.datetime(2015, 3, 1, 0, 54, 35, tzinfo=<UTC>), 1, 1, 0)
(1, datetime.datetime(2015, 3, 1, 0, 55, 28, tzinfo=<UTC>), 0, 2, 0)
(1, datetime.datetime(2015, 3, 1, 0, 55, 55, tzinfo=<UTC>), 1, 2, 0)
(1, datetime.datetime(2015, 3, 1, 0, 56, 24, tzinfo=<UTC>), 0, 3, 0)
(1, datetime.datetime(2015, 3, 1, 1, 11, 14, tzinfo=<UTC>), 1, 2, 0)
(1, datetime.datetime(2015, 3, 1, 1, 11, 17, tzinfo=<UTC>), 2, 1, 0)
(1, datetime.datetime(2015, 3, 1, 1, 12, 8, tzinfo=<UTC>), 1, 2, 0)
(1, datetime.datetime(2015, 3, 1, 1, 12, 10, tzinfo=<UTC>), 0, 3, 0)
(1, datetime.datetime(2015, 3, 1, 1, 17, 43, tzinfo=<UTC>), 1, 2, 0)
(1, datetime.datetime(2015, 3, 1, 1, 17, 49, tzinfo=<UTC>), 0, 3, 0)
(1, datetime.datetime(2015, 3, 1, 1, 24, 12, tzinfo=<UTC>), 1, 2, 0)
(1, datetime.datetime(2015, 3, 1, 1, 24, 18, tzinfo=<UTC>), 2, 1, 0)
(1, datetime.datetime(2015, 3, 1, 1, 24, 18, tzinfo=<UTC>), 1, 2, 0)
(1, datetime.datetime(2015, 3, 1, 1, 24, 24, tzinfo=<UTC>), 2, 1, 0)
Towards the end of the data, there are two rows which have the same timestamp.
(1, datetime.datetime(2015, 3, 1, 1, 24, 18, tzinfo=<UTC>), 2, 1, 0)
(1, datetime.datetime(2015, 3, 1, 1, 24, 18, tzinfo=<UTC>), 1, 2, 0)
It is my understanding that when I save to Cassandra, one of these will "win" and there will only be one row.
After writing to Cassandra using
rdd.saveToCassandra(keyspace, table, ['id', 'time', 'a', 'b', 'c'])
Neither row appears to have won. Rather, the rows seem to have "merged".
1 | 2015-03-01 01:17:43+0000 | 1 | 2 | 0
1 | 2015-03-01 01:17:49+0000 | 0 | 3 | 0
1 | 2015-03-01 01:24:12+0000 | 1 | 2 | 0
1 | 2015-03-01 01:24:18+0000 | 2 | 2 | 0
1 | 2015-03-01 01:24:24+0000 | 2 | 1 | 0
Rather than the 2015-03-01 01:24:18+0000 containing (1, 2, 0) or (2, 1, 0), it contains (2, 2, 0).
What is happening here? I can't for the life of me figure out what is causing this behaviour.
This is a little-known effect that comes from batching data together. Batching writes assigns the same timestamp to all inserts in the batch. If two writes are done with the exact same timestamp, a special merge rule applies since there was no "last" write. The Spark Cassandra Connector uses intra-partition batches by default, so this kind of clobbering of values is very likely to happen.
The behavior with two identical write timestamps is a merge based on the greater value.
Given Table (key, a, b)
Batch
Insert "foo", 2, 1
Insert "foo", 1, 2
End batch
The batch gives both mutations the same timestamp. Cassandra cannot choose a "last written" since they both happened at the same time; instead it just chooses the greater value of the two. The merged result will be
"foo", 2, 2

Load data from file and normalize

How do I normalize data loaded from a file? Here is what I have. The data looks like this:
65535, 3670, 65535, 3885, -0.73, 1
65535, 3962, 65535, 3556, -0.72, 1
The last value in each line is the target. I want to keep the same structure of the data, but with normalized values.
import numpy as np
from sklearn import preprocessing

dataset = np.loadtxt('infrared_data.txt', delimiter=',')
# select first 5 columns as the data
X = dataset[:, 0:5]
# is that correct? Should I normalize along 0 axis?
normalized_X = preprocessing.normalize(X, axis=0)
y = dataset[:, 5]
Now the question is how to correctly pack normalized_X and y back together so that the data has the structure:
dataset = [[normalized_X[0], y[0]], [normalized_X[1], y[1]], ...]
It sounds like you're asking for np.column_stack. For example, let's set up some dummy data:
import numpy as np
x = np.arange(25).reshape(5, 5)
y = np.arange(5) + 1000
Which gives us:
X:
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]])
Y:
array([1000, 1001, 1002, 1003, 1004])
And we want:
new = np.column_stack([x, y])
Which gives us:
New:
array([[ 0, 1, 2, 3, 4, 1000],
[ 5, 6, 7, 8, 9, 1001],
[ 10, 11, 12, 13, 14, 1002],
[ 15, 16, 17, 18, 19, 1003],
[ 20, 21, 22, 23, 24, 1004]])
If you'd prefer less typing, you can also use:
In [4]: np.c_[x, y]
Out[4]:
array([[ 0, 1, 2, 3, 4, 1000],
[ 5, 6, 7, 8, 9, 1001],
[ 10, 11, 12, 13, 14, 1002],
[ 15, 16, 17, 18, 19, 1003],
[ 20, 21, 22, 23, 24, 1004]])
However, I'd discourage using np.c_ for anything other than interactive use, simply due to readability concerns.
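Applied back to the original question, the pieces fit together roughly like this (a sketch assuming preprocessing is scikit-learn's sklearn.preprocessing, as in the question's code):

import numpy as np
from sklearn import preprocessing

dataset = np.loadtxt('infrared_data.txt', delimiter=',')
X = dataset[:, 0:5]
y = dataset[:, 5]

# Scale each feature column to unit norm (axis=0 normalizes per column)
normalized_X = preprocessing.normalize(X, axis=0)

# Re-attach the target as the last column, preserving the row-per-sample layout
normalized_dataset = np.column_stack([normalized_X, y])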
