array_contains() on array of structs Column is not Iterable - apache-spark

I have 2 columns that have this schema:
root
|-- parent_column: array (nullable = true)
| |-- element: struct (containsNull = false)
| | |-- item_1: integer (nullable = true)
| | |-- item_2: long (nullable = true)
| | |-- item_3: integer (nullable = true)
| | |-- item_4: boolean (nullable = true)
|-- child_column: struct (nullable = false)
| |-- item_1: integer (nullable = true)
| |-- item_2: long (nullable = true)
| |-- item_3: integer (nullable = true)
| |-- item_4: boolean (nullable = false)
I wanted to check if the child_column exists in the parent_column by doing array_contains(F.col('parent_column'), F.col('child_column')), but I am running into a "Column is not iterable" error.
Sample data:
+----------------------------------------------+--------------------------------------------+--------------+
|parent_column | child_column | data_check |
+----------------------------------------------+--------------------------------------------+--------------+
|[[1, 2, 3, 4, false]] | [1, 2, 3, 4, false] | true |
|[[1, 2, 3, 4, false]] | [6, 7, 8, 9, false] | false |
|...                                           |...                                         |...           |
+----------------------------------------------+--------------------------------------------+--------------+
Sample runnable code:
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [([(2, 2, 2,)],)],
    'parent_column:array<struct<item_1:bigint,item_2:bigint,item_3:bigint>>'
)
df = df.withColumn(
    'child_column',
    F.expr("transform(parent_column, x -> struct(x.item_1 as item_1, x.item_2 as item_2, x.item_3 as item_3))")
)
# WITH ERRORS
# df = df.withColumn(
#     'contains',
#     F.array_contains(F.col('parent_column'), F.col('child_column'))
# )
df.show(truncate=False)
In my mind I am checking if a struct exists in an array of structs. So I am not sure why I am getting this error. Any tips?

Seems like your sample data is off: in your runnable snippet, transform returns an array, so child_column is an array of structs rather than a single struct. I fixed it by taking element [0]; see the child_column definition below. Not sure if this is the problem with your original query.
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame(
...     [([(2, 2, 2,)],)],
...     'parent_column:array<struct<item_1:bigint,item_2:bigint,item_3:bigint>>'
... )
>>>
>>> df = df.withColumn(
...     'child_column',
...     F.expr("transform(parent_column, x -> struct(x.item_1 as item_1, x.item_2 as item_2, x.item_3 as item_3))")[0]
... )
>>> df.withColumn('contains', F.expr("array_contains(parent_column, child_column)")).show()
+-------------+------------+--------+
|parent_column|child_column|contains|
+-------------+------------+--------+
| [[2, 2, 2]]| [2, 2, 2]| true|
+-------------+------------+--------+
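For reference, here is the fixed example end to end. The key point is that transform returns an array, so indexing with [0] turns child_column into a single struct whose type matches the element type of parent_column. A sketch, assuming Spark 3.x, where F.array_contains also accepts a Column as the value to search for (on older versions keep the expr form above):

from pyspark.sql import functions as F

df = spark.createDataFrame(
    [([(2, 2, 2,)],)],
    'parent_column:array<struct<item_1:bigint,item_2:bigint,item_3:bigint>>'
)
# transform(...) yields array<struct<...>>; [0] picks out a single struct
df = df.withColumn(
    'child_column',
    F.expr("transform(parent_column, x -> struct(x.item_1 as item_1, x.item_2 as item_2, x.item_3 as item_3))")[0]
)
# Spark 3.x accepts a Column as the second argument here
df.withColumn('contains', F.array_contains('parent_column', F.col('child_column'))).show(truncate=False)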

Related

pyspark sort an array of arrays by an inner element's value

I have the following df:
+----+---------------------------------------------------------------------------------------------+
| id | id_info                                                                                     |
+----+---------------------------------------------------------------------------------------------+
|id_1| [[1, 8, 2, "bar"], [5, 9, 2, "foo"], [4, 3, 2, "something"], [9, null, 2, "this_is_null"]] |
I would like this sorted by the second element in descending order, so:
+----+---------------------------------------------------------------------------------------------+
| id | id_info                                                                                     |
+----+---------------------------------------------------------------------------------------------+
|id_1| [[5, 9, 2, "foo"], [1, 8, 2, "bar"], [4, 3, 2, "something"], [9, null, 2, "this_is_null"]] |
I came up with something like this :
def def_sort(x):
    return sorted(x, key=lambda x: x[1], reverse=True)

udf_sort = F.udf(def_sort, T.ArrayType(T.ArrayType(T.IntegerType())))
df.select("id", udf_sort("id_info"))
I'm not sure how to handle null values like this; also, is there maybe a built-in function for this? Can I somehow do it with F.array_sort?
The elements of the array contain integers and a string, so I assume that the column id_info is an array of structs.
So the schema of the input data would be similar to
root
|-- id: string (nullable = true)
|-- id_info: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- col1: integer (nullable = true)
| | |-- col2: integer (nullable = true)
| | |-- col3: integer (nullable = true)
| | |-- col4: string (nullable = true)
The names of the elements of the struct might be different.
With this schema information we can use array_sort to order the array:
df.selectExpr("array_sort(id_info, (l,r) -> \
case when l['col2'] > r['col2'] then -1 else 1 end) as sorted") \
.show(truncate=False)
prints
+----------------------------------------------------------------------------------+
|sorted |
+----------------------------------------------------------------------------------+
|[{5, 9, 2, foo}, {1, 8, 2, bar}, {4, 3, 2, something}, {9, null, 2, this_is_null}]|
+----------------------------------------------------------------------------------+
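Since the question also asks about nulls: the comparator above does not treat them explicitly (a null comparison always falls into the else branch), so where null entries land depends on the order of comparisons. Here is a sketch of the same array_sort comparator form (Spark 3.0+) with nulls pinned to the end, assuming that is the desired behavior:

df.selectExpr("""
    array_sort(id_info, (l, r) ->
        case when l['col2'] is null and r['col2'] is null then 0
             when l['col2'] is null then 1
             when r['col2'] is null then -1
             when l['col2'] > r['col2'] then -1
             when l['col2'] < r['col2'] then 1
             else 0
        end) as sorted
""").show(truncate=False)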
You can try explode followed by orderBy on id and the second element in descending order, then groupBy + collect_list:
out = (sdf.select("*", F.explode("id_info").alias("element"))
          .withColumn("second_ele", F.element_at("element", 2))
          .orderBy("id", F.desc("second_ele"))
          .groupBy("id").agg(F.collect_list("element").alias("id_info"))
      )
out.show(truncate=False)
+----+-----------------------------------------------------------------------+
|id |id_info |
+----+-----------------------------------------------------------------------+
|id_1|[[5, 9, 2, null], [1, 8, 2, null], [4, 3, 2, null], [9, null, 2, null]]|
+----+-----------------------------------------------------------------------+

How to get the topic using pyspark LDA

I have used LDA for finding the topics:
from pyspark.ml.clustering import LDA
lda = LDA(k=30, seed=123, optimizer="em", maxIter=10, featuresCol="features")
ldamodel = lda.fit(rescaledData)
When I run the code below, I get the result with topic, termIndices and termWeights:
ldatopics = ldamodel.describeTopics()
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|topic|termIndices |termWeights |
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13643518321703474, 0.09550636070102715, 0.0948398694152729, 0.05766480874922468, 0.04482014536392716, 0.04435413761288223, 0.04390248330822808, 0.042949351875241376, 0.039792489008412854, 0.03959557452696915] |
|1 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13650714786977475, 0.09549302392463813, 0.09473573233391298, 0.05765172919949817, 0.04483412655497465, 0.04437734133645192, 0.043917280973977395, 0.042966507550942196, 0.03978157454461109, 0.039593095881148545] |
|2 |[18, 14, 1, 2, 6, 15, 19, 0, 5, 11]|[0.13628285951099658, 0.0960680469091329, 0.0957882659160997, 0.05753830851677705, 0.04476298485932895, 0.04426297294164529, 0.043901044582340995, 0.04291495373567377, 0.039780501358024425, 0.03946491774918109] |
|3 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13674345908386384, 0.09541103313663375, 0.09430192089972703, 0.05770255869404108, 0.044821207138198024, 0.04441671155466873, 0.043902024784365994, 0.04294093494951478, 0.03981953622791824, 0.039614320304130965] |
|4 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.1364122091134319, 0.0953545543783122, 0.09429906593265366, 0.05772444907583193, 0.044862863343866806, 0.04442734201228477, 0.04389557512474934, 0.04296443267889805, 0.03982659626276644, 0.03962640467091713] |
|5 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13670139028438008, 0.09553267874876908, 0.0946605951853061, 0.05768305061621247, 0.04480375989378822, 0.04435320773212808, 0.043914956101645836, 0.0429208816887896, 0.03981783866996495, 0.0395571526370012] |
|6 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13655841155238757, 0.09554895256001286, 0.0949549658561804, 0.05762248195720487, 0.04480233168639564, 0.04433230827791344, 0.04390710584467472, 0.042930742282966325, 0.039814043132483344, 0.039566695745809] |
|7 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13663380500078914, 0.09550193655114965, 0.09478283807049456, 0.05766262514700366, 0.04480957336655386, 0.044349558779903736, 0.04392217503495675, 0.04293742117414056, 0.03978830895805316, 0.03953804442817271] |
|8 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13657941125100242, 0.09553782026082008, 0.09488224957791296, 0.0576562888735838, 0.04478320710416449, 0.044348589637858433, 0.043920567136734125, 0.04291002130418712, 0.03979768943659999, 0.039567161368513126] |
|9 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13679375890331896, 0.09550217013809542, 0.09456405056150366, 0.057658739173097585, 0.04482551122651224, 0.0443527568707439, 0.04392317149752475, 0.04290757658934661, 0.03979803663250257, 0.03956130101562635] |
|10 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.1363532425412472, 0.09567087909028642, 0.09540557719048867, 0.0576159180949091, 0.044781629751243876, 0.04429913174899199, 0.04391116942509204, 0.042932143462292065, 0.03978668768070522, 0.03951443679027362] |
|11 |[18, 1, 14, 2, 6, 15, 19, 0, 5, 11]|[0.13652547151289018, 0.09563212101150799, 0.09559109602844051, 0.057554258392101834, 0.044746350635605996, 0.044288429293621624, 0.04393701569930993, 0.04291245197810508, 0.03976607824588965, 0.039502323331991364]|
+-----+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
Then I ran the code below for transforming, which gives my whole input data with a topicDistribution column. I want to know how I can map each row to its corresponding topic from above.
transformed = ldamodel.transform(rescaledData)
transformed.show(1000,truncate=False)
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|features |topicDistribution |
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|(20,[0,1,2,5,6,7,9,10,11,12,14,15,17,18],[1.104871083111864,2.992600400740947,2.6744520065699304,1.3684833461617696,2.2505718610977468,1.667281425641495,0.4887254650789885,1.6871692149804491,1.2395584596083011,1.1654793679713804,1.1856289804763125,0.8839314023462759,1.3072768360965874,1.8484873888968552]) |[0.033335920691559356,0.03333070077028474,0.033333697225675085,0.03333561564482505,0.03333642468043084,0.033329348877568624,0.033332822489008374,0.0333260618235691,0.0333303226504357,0.03333271017792987,0.03332747782859632,0.03332947925144302,0.03332588375760626,0.03333450406000825,0.03333113684458691,0.03333479529731098,0.03333650038104409,0.03333779171890448,0.033330895429371024,0.03333346818374588,0.03333587374669536,0.03333749064143864,0.033334190553966234,0.03333150654339392,0.03333448663832666,0.03333989460983954,0.03333245526240933,0.03333885169495637,0.03333318824494455,0.03333650428012542] |
|(20,[1,3,9,10,11,18],[0.997533466913649,1.2726157969776226,1.4661763952369655,1.6871692149804491,1.2395584596083011,1.2323249259312368]) |[0.03333282248535624,0.033332445609854544,0.033328001364314915,0.03333556680774651,0.03333511480351705,0.03333350870876004,0.03333187395000147,0.03333445718728295,0.03333444971729848,0.03333408945721941,0.033330403594416406,0.03333043254420633,0.03333157466407796,0.033332770049699215,0.033328359180981044,0.03333228442030213,0.03333371386494732,0.03333681533414124,0.0333317544337193,0.0333325923806738,0.03333507835375717,0.0333340926283776,0.033334289011977415,0.03333710235559251,0.033333843031028085,0.03333431818041727,0.03333321848222645,0.03333712581455898,0.03333342489994699,0.033334476683601115] |
|(20,[0,1,2,5,6,7,9,10,11,12,14,15,17,18],[1.104871083111864,2.992600400740947,2.6744520065699304,1.3684833461617696,2.2505718610977468,1.667281425641495,0.4887254650789885,1.6871692149804491,1.2395584596083011,1.1654793679713804,1.1856289804763125,0.8839314023462759,1.3072768360965874,1.8484873888968552]) |[0.03333351262221663,0.03334603652469889,0.03332161942705381,0.03333027102774255,0.03333390930010211,0.03333191764996061,0.033329934469446876,0.033333299309223546,0.033341041908943575,0.0333327792244044,0.0333317683142701,0.03332955596634091,0.033328132222973345,0.033337885062247365,0.03332933802183438,0.033340558825094374,0.03332932381968907,0.033330915221236455,0.03333188695334856,0.033327196223241304,0.0333344223760765,0.03332799716207421,0.033338186672894565,0.033336730538183736,0.03333440038338333,0.033337949318794795,0.033333179017769575,0.03333720003330217,0.0333312634707931,0.033337788932659276] |
|(20,[7,10,14,15,18],[1.667281425641495,1.6871692149804491,1.1856289804763125,0.8839314023462759,1.2323249259312368]) |[0.03333299162768077,0.03333267864965697,0.03333419860453984,0.0333335178997543,0.03333360724598342,0.03333242559458429,0.03333347136144222,0.03333355459499807,0.03333358147457144,0.03333288221459248,0.033333183195193086,0.03333412789444317,0.033334780263427975,0.033333456529848454,0.033336348770946406,0.03333413810342763,0.03333364116227056,0.03333271328980694,0.033333953427323364,0.03333358082109339,0.03333219641721605,0.03333309878777325,0.03333357447436248,0.03333149527336071,0.03333364840046119,0.03333405705432205,0.03333193039244976,0.033332142999523764,0.0333329476480237,0.033332075826922346] |
|(20,[7,10,14,15,18],[1.667281425641495,1.6871692149804491,1.1856289804763125,0.8839314023462759,1.2323249259312368]) |[0.033333498405871936,0.033333280285774286,0.03333426144981626,0.033332860785526656,0.033333582231251914,0.03333268142499894,0.0333323342942777,0.033333434542768936,0.03333306424165755,0.0333326212718143,0.03333359821538673,0.03333408970522017,0.03333440903364261,0.033333628480934824,0.033335988240581156,0.03333411840979216,0.03333341827050789,0.033332367335435646,0.033334058739684515,0.03333355988453005,0.033332524988062315,0.03333411721155432,0.033333323503042835,0.03333212784312619,0.0333335375674124,0.03333359055906317,0.033332183680577915,0.033332671344198254,0.03333288800754154,0.03333218004594688] |
|(20,[2,5,9,16,17,18],[0.8914840021899768,1.3684833461617696,0.4887254650789885,1.63668714555745,2.614553672193175,0.6161624629656184]) |[0.033334335349062126,0.033335078498013884,0.03332587515068904,0.03333663657766274,0.03333680954467168,0.03333423911759284,0.03333228558123402,0.03333307701307248,0.03333230070449398,0.03333357555321742,0.03333013898196302,0.033329558017855816,0.033331357742536066,0.03333281031681836,0.03332620861451679,0.03333164736607878,0.033333411422713344,0.03333811924641809,0.03333032668669555,0.03333155312673253,0.03333601811939769,0.03333261877594549,0.03333490102064452,0.033338814952549214,0.03333297326837463,0.03333518474628696,0.03333659535650031,0.033334882391957275,0.03333262407795885,0.03333604267834657] |
|(20,[9,14,17],[0.4887254650789885,1.1856289804763125,1.3072768360965874]) |[0.03333373548732593,0.03333366047427678,0.0333388921917903,0.0333324500674453,0.03333165231618069,0.033333396263220766,0.03333392152736817,0.03333280784507405,0.03333341964532808,0.033332027450640234,0.03333590751407632,0.03333625709954709,0.033335027917460715,0.033332984179205334,0.0333395676850832,0.03333465555831355,0.03333317988602309,0.03332999573625669,0.033335686662720146,0.03333444862817325,0.033331118794333245,0.03333227594267816,0.03333337815403729,0.03332815318734183,0.03333329306551438,0.033332463398394074,0.03333211180252395,0.03332955750997137,0.03333267853415129,0.033331295475544864] |
|(20,[0,1,2,5,6,8,9,12,14,19],[4.419484332447456,8.97780120222284,0.8914840021899768,1.3684833461617696,1.1252859305488734,1.233681992250468,0.4887254650789885,1.1654793679713804,42.68264329714725,3.685472474556911]) |[0.03333847676223903,0.03333217319316553,0.03361190962478921,0.03322296565909072,0.033224152351515865,0.033305780550557426,0.0333623247307499,0.03333530792027193,0.03335048739584223,0.03328160880463601,0.0334726823486143,0.03350523767287266,0.0334249096308398,0.0333474750843313,0.033626678634369765,0.033399676341965154,0.033312708717161896,0.03316394279048787,0.033455734324711994,0.033388088146836026,0.03321934579333965,0.03331319196962562,0.03332496308901043,0.03308256004388311,0.033355283259340625,0.03328280644324275,0.0332436575330529,0.03319398288438216,0.03331402807784297,0.03320786022123095] |
|(20,[2,8,10,13,15,16],[0.8914840021899768,3.701045976751404,1.6871692149804491,1.4594938792746623,0.8839314023462759,1.63668714555745]) |[0.0333356286958193,0.033332968808115676,0.03332196061288262,0.03333752852282297,0.03334171827141663,0.033333866506025614,0.03332902614535797,0.03332999721261129,0.03333228759033127,0.03333419089119785,0.033326295405433394,0.03332727748411373,0.03332970492260472,0.03333218324896459,0.033323419320752445,0.03333484772134479,0.033331408271507906,0.03333886892339302,0.033330223503670105,0.03333357612919576,0.033337166312023804,0.03333359583540509,0.033335058229509947,0.03334147190850166,0.03333378789809434,0.033336369234199595,0.033338417163238966,0.03333500103939193,0.03333456204893585,0.03333759214313734] |
|(20,[0,1,5,9,11,14,18],[2.209742166223728,1.995066933827298,1.3684833461617696,0.4887254650789885,1.2395584596083011,1.1856289804763125,4.313137240759328]) |[0.033331568372787335,0.03333199604802443,0.03333616626240888,0.03333311434679797,0.03332871558684149,0.03333377685027234,0.03333412369772984,0.03333326223955321,0.03333302350234804,0.033334011390192264,0.033332998734646985,0.03333572338531835,0.03333519102733107,0.033333669722909055,0.03333367298003543,0.03333316052548548,0.03333434503087792,0.03333204101430685,0.03333390212404242,0.03333317825522851,0.03333468470990933,0.0333342373715466,0.03333384912615554,0.03333262239313073,0.033332130871257,0.03333157173270879,0.033331620213171494,0.03333491236454045,0.03333364390007425,0.03333308622036798] |
|(20,[0,1,2,7,9,12,14,16,17],[1.104871083111864,1.995066933827298,1.7829680043799536,1.667281425641495,0.4887254650789885,3.4964381039141412,1.1856289804763125,1.63668714555745,1.3072768360965874]) |[0.033335136368759725,0.033335270064671074,0.033333243143688636,0.03333135191450168,0.03333521337181425,0.03333250351809707,0.03333535436589356,0.03333322757251989,0.033333747503130604,0.03333235416690325,0.03333273826809218,0.03333297626192852,0.033333297925505816,0.03333285267847805,0.033334741557111344,0.03333582976624498,0.03333192986853772,0.0333314389004771,0.03333186692894785,0.033332539165631093,0.03333280721245933,0.03333326073074502,0.03333359623114709,0.03333261337071022,0.03333404937430939,0.03333536871228382,0.033334559939835924,0.03333095099987788,0.0333326510660671,0.033332529051629894] |
|(20,[2,5],[0.8914840021899768,1.3684833461617696]) |[0.033333008832987246,0.03333320773569408,0.03333212285409669,0.03333370183210235,0.03333408881228382,0.03333369017540667,0.03333315896362525,0.03333316321736793,0.033333273291611384,0.033332955144148683,0.03333299410668819,0.03333254807040678,0.03333302471306508,0.03333324003641127,0.033332032586187325,0.0333328866035473,0.033333366369446636,0.03333432146865401,0.033333080986559745,0.033332919796171534,0.03333365180032741,0.03333319727206391,0.0333336947744764,0.033334184463059975,0.033333428924387516,0.033333998899679654,0.033333803141837516,0.03333374053474702,0.03333339151988759,0.03333412307307108] |
In order to remap the termIndices to words you have to access the vocabulary of the CountVectorizer model. Please have a look at the pseudocode below:
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from pyspark.ml.feature import CountVectorizer

# Reading your data...
cv = CountVectorizer(inputCol="some", outputCol="features", vocabSize=2000)
cvmodel = cv.fit(df)
# other stuff
vocab = cvmodel.vocabulary
vocab_broadcast = sc.broadcast(vocab)
# creating LDA model
ldatopics = ldamodel.describeTopics()

def map_termID_to_Word(termIndices):
    words = []
    for termID in termIndices:
        words.append(vocab_broadcast.value[termID])
    return words

udf_map_termID_to_Word = udf(map_termID_to_Word, ArrayType(StringType()))
ldatopics_mapped = ldatopics.withColumn("topic_desc", udf_map_termID_to_Word(ldatopics.termIndices))
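To get from topicDistribution back to a topic, note that each row's dominant topic is just the index of the largest entry in the distribution vector. A minimal sketch (assuming topicDistribution is an ML vector, as returned by LDAModel.transform; the column name dominant_topic is made up here):

import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# index of the largest probability = the row's most likely topic
udf_dominant_topic = udf(lambda dist: int(np.argmax(dist.toArray())), IntegerType())
transformed_mapped = transformed.withColumn("dominant_topic", udf_dominant_topic("topicDistribution"))

You can then join transformed_mapped to ldatopics_mapped on dominant_topic = topic to attach the mapped words to each document.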

Subtract Two Arrays to Get A New Array in Pyspark

I am new to Spark.
I can sum, subtract, or multiply arrays in Python with pandas and NumPy, but I am having difficulty doing something similar in Spark (Python). I am on Databricks.
For example, this kind of approach gives a huge error message which I don't want to copy-paste here:
differencer=udf(lambda x,y: x-y, ArrayType(FloatType()))
df.withColumn('difference', differencer('Array1', 'Array2'))
Schema looks like this:
root
|-- col1: integer (nullable = true)
|-- time: timestamp (nullable = true)
|-- num: integer (nullable = true)
|-- part: integer (nullable = true)
|-- result: integer (nullable = true)
|-- Array1: array (nullable = true)
| |-- element: float (containsNull = true)
|-- Array2: array (nullable = false)
| |-- element: float (containsNull = true)
I just want to create a new column subtracting those 2 array columns. Actually, I will get the RMSE between them. But I think I can handle it once I learn how to get this difference.
Arrays look like this (I am just typing in some integers):
Array1_row1[5, 4, 2, 4, 3]
Array2_row1[4, 3, 1, 2, 1]
So the resulting array for row1 should be:
DiffCol_row1[1, 1, 1, 2, 2]
Thanks for any suggestions or directions.
You can use arrays_zip and transform:
from pyspark.sql.functions import expr
df = spark.createDataFrame(
    [([5, 4, 2, 4, 3], [4, 3, 1, 2, 1])], ("array1", "array2")
)
df.withColumn(
    "array3",
    expr("transform(arrays_zip(array1, array2), x -> x.array1 - x.array2)")
).show()
# +---------------+---------------+---------------+
# | array1| array2| array3|
# +---------------+---------------+---------------+
# |[5, 4, 2, 4, 3]|[4, 3, 1, 2, 1]|[1, 1, 1, 2, 2]|
# +---------------+---------------+---------------+
A valid udf would require equivalent logic, i.e.:
from pyspark.sql.functions import udf

@udf("array<double>")
def differencer(xs, ys):
    if xs and ys:
        return [float(x - y) for x, y in zip(xs, ys)]

df.withColumn("array3", differencer("array1", "array2")).show()
# +---------------+---------------+--------------------+
# | array1| array2| array3|
# +---------------+---------------+--------------------+
# |[5, 4, 2, 4, 3]|[4, 3, 1, 2, 1]|[1.0, 1.0, 1.0, 2...|
# +---------------+---------------+--------------------+
You can use zip_with (since Spark 2.4):
from pyspark.sql.functions import expr
df = spark.createDataFrame(
    [([5, 4, 2, 4, 3], [4, 3, 1, 2, 1])], ("array1", "array2")
)
df.withColumn(
    "array3",
    expr("zip_with(array1, array2, (x, y) -> x - y)")
).show()
# +---------------+---------------+---------------+
# | array1| array2| array3|
# +---------------+---------------+---------------+
# |[5, 4, 2, 4, 3]|[4, 3, 1, 2, 1]|[1, 1, 1, 2, 2]|
# +---------------+---------------+---------------+
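Since the stated end goal is the RMSE between the two arrays, here is a sketch building on zip_with plus aggregate (it assumes both arrays are non-null and of equal length):

from pyspark.sql.functions import expr

df.withColumn(
    "rmse",
    expr("""
        sqrt(
            aggregate(
                zip_with(array1, array2, (x, y) -> pow(x - y, 2)),
                0D,
                (acc, v) -> acc + v
            ) / size(array1)
        )
    """)
).show()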

create combination dataframe from another dataframe

I have a dataframe similar to the following:
from pyspark.sql import Row
q = sc.parallelize([
    Row(items=[1]), Row(items=[2]), Row(items=[2, 1]), Row(items=[5]),
    Row(items=[5, 2]), Row(items=[5, 2, 1]), Row(items=[5, 1]), Row(items=[3]),
    Row(items=[3, 5]), Row(items=[3, 5, 2]), Row(items=[3, 5, 2, 1]),
    Row(items=[3, 5, 1]), Row(items=[3, 2]), Row(items=[3, 2, 1]), Row(items=[3, 1])
])
I need to create a new dataframe which contains all combination of items:
+------------+--------------+
| left       | right        |
+------------+--------------+
| [1]|[2]
| [1]|[2, 1]
| [1]|[5]
| [1]|[5,2]
| [1]|[5,2,1]
| [1]|[5,1]
| [1]|[3]
| [1]|[3,5]
| [1]|[3, 5, 2]
| [1]|[3, 5, 2, 1]
| [1]|[3,5,1]
| [1]|[3,2]
| [1]|[3,2,1]
| [1]|[3, 1]|
| [2]|[1]
| [2]|[2,1]
...
+------------+--------------+
I need to create a dataframe that looks like the one above.
Try the crossJoin method:
q.crossJoin(q)
That should do the trick.
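Two caveats, since q as defined above is an RDD of Rows: crossJoin is a DataFrame method, and a self-join duplicates the items column, so aliases are needed to tell left from right. A sketch (filtering out identical pairs is an assumption based on the expected output):

from pyspark.sql import functions as F

df = q.toDF()
result = (df.alias("l")
            .crossJoin(df.alias("r"))
            .select(F.col("l.items").alias("left"), F.col("r.items").alias("right"))
            .where(F.col("left") != F.col("right")))  # drop (x, x) pairs
result.show()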

PySpark DF column creation with UDF to mimic np.roll function from numpy

Trying to create a new column with a PySpark UDF, but the values are null!
Create the DF
data_list = [['a', [1, 2, 3]], ['b', [4, 5, 6]],['c', [2, 4, 6, 8]],['d', [4, 1]],['e', [1,2]]]
all_cols = ['COL1','COL2']
df = sqlContext.createDataFrame(data_list, all_cols)
df.show()
+----+------------+
|COL1| COL2|
+----+------------+
| a| [1, 2, 3]|
| b| [4, 5, 6]|
| c|[2, 4, 6, 8]|
| d| [4, 1]|
| e| [1, 2]|
+----+------------+
df.printSchema()
root
|-- COL1: string (nullable = true)
|-- COL2: array (nullable = true)
| |-- element: long (containsNull = true)
Create a function
def cr_pair(idx_src, idx_dest):
    idx_dest.append(idx_dest.pop(0))
    return idx_src, idx_dest
lst1 = [1,2,3]
lst2 = [1,2,3]
cr_pair(lst1, lst2)
([1, 2, 3], [2, 3, 1])
Create and register a UDF
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
from pyspark.sql.types import ArrayType
get_idx_pairs = udf(lambda x: cr_pair(x, x), ArrayType(IntegerType()))
Add a new column to the DF
df = df.select('COL1', 'COL2', get_idx_pairs('COL2').alias('COL3'))
df.printSchema()
root
|-- COL1: string (nullable = true)
|-- COL2: array (nullable = true)
| |-- element: long (containsNull = true)
|-- COL3: array (nullable = true)
| |-- element: integer (containsNull = true)
df.show()
+----+------------+------------+
|COL1| COL2| COL3|
+----+------------+------------+
| a| [1, 2, 3]|[null, null]|
| b| [4, 5, 6]|[null, null]|
| c|[2, 4, 6, 8]|[null, null]|
| d| [4, 1]|[null, null]|
| e| [1, 2]|[null, null]|
+----+------------+------------+
Here is where the problem is: I am getting all null values in the COL3 column.
The intended outcome should be:
+----+------------+----------------------------+
|COL1| COL2| COL3|
+----+------------+----------------------------+
| a| [1, 2, 3]|[[1 ,2, 3], [2, 3, 1]] |
| b| [4, 5, 6]|[[4, 5, 6], [5, 6, 4]] |
| c|[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
| d| [4, 1]|[[4, 1], [1, 4]] |
| e| [1, 2]|[[1, 2], [2, 1]] |
+----+------------+----------------------------+
Your UDF should return ArrayType(ArrayType(IntegerType())) since you are expecting a list of lists in your column; besides, it only needs one parameter:
def cr_pair(idx_src):
    return idx_src, idx_src[1:] + idx_src[:1]

get_idx_pairs = udf(cr_pair, ArrayType(ArrayType(IntegerType())))
df.withColumn('COL3', get_idx_pairs(df['COL2'])).show(5, False)
+----+------------+----------------------------+
|COL1|COL2 |COL3 |
+----+------------+----------------------------+
|a   |[1, 2, 3]   |[[1, 2, 3], [2, 3, 1]]      |
|b   |[4, 5, 6]   |[[4, 5, 6], [5, 6, 4]]      |
|c   |[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|d   |[4, 1]      |[[4, 1], [1, 4]]            |
|e   |[1, 2]      |[[1, 2], [2, 1]]            |
+----+------------+----------------------------+
It seems like what you want to do is circularly shift the elements in your list. Here is a non-udf approach using pyspark.sql.functions.posexplode() (Spark version 2.1 and above):
import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.partitionBy("COL1", "COL2").orderBy(f.col("pos") == 0, "pos")
df = df.select("*", f.posexplode("COL2"))\
    .select("COL1", "COL2", "pos", f.collect_list("col").over(w).alias('COL3'))\
    .where("pos = 0")\
    .drop("pos")\
    .withColumn("COL3", f.array("COL2", "COL3"))
df.show(truncate=False)
#+----+------------+----------------------------------------------------+
#|COL1|COL2 |COL3 |
#+----+------------+----------------------------------------------------+
#|a |[1, 2, 3] |[WrappedArray(1, 2, 3), WrappedArray(2, 3, 1)] |
#|b |[4, 5, 6] |[WrappedArray(4, 5, 6), WrappedArray(5, 6, 4)] |
#|c |[2, 4, 6, 8]|[WrappedArray(2, 4, 6, 8), WrappedArray(4, 6, 8, 2)]|
#|d |[4, 1] |[WrappedArray(4, 1), WrappedArray(1, 4)] |
#|e |[1, 2] |[WrappedArray(1, 2), WrappedArray(2, 1)] |
#+----+------------+----------------------------------------------------+
Using posexplode will return two columns- the position in the list (pos) and the value (col). The trick here is that we order by f.col("pos") == 0 first and then "pos". This will move the first position in the array to the end of the list.
Though this output prints differently than the list of lists you would expect in Python, the contents of COL3 are indeed a list of lists of integers.
df.printSchema()
#root
# |-- COL1: string (nullable = true)
# |-- COL2: array (nullable = true)
# | |-- element: long (containsNull = true)
# |-- COL3: array (nullable = false)
# | |-- element: array (containsNull = true)
# | | |-- element: long (containsNull = true)
Update
The "WrappedArray prefix" is just the way Spark prints nested lists. The underlying array is exactly as you need it. One way to verify this is by calling collect() and inspecting the data:
results = df.collect()
print([(r["COL1"], r["COL3"]) for r in results])
#[(u'a', [[1, 2, 3], [2, 3, 1]]),
# (u'b', [[4, 5, 6], [5, 6, 4]]),
# (u'c', [[2, 4, 6, 8], [4, 6, 8, 2]]),
# (u'd', [[4, 1], [1, 4]]),
# (u'e', [[1, 2], [2, 1]])]
Or if you converted df to a pandas DataFrame:
print(df.toPandas())
# COL1 COL2 COL3
#0 a [1, 2, 3] ([1, 2, 3], [2, 3, 1])
#1 b [4, 5, 6] ([4, 5, 6], [5, 6, 4])
#2 c [2, 4, 6, 8] ([2, 4, 6, 8], [4, 6, 8, 2])
#3 d [4, 1] ([4, 1], [1, 4])
#4 e [1, 2] ([1, 2], [2, 1])
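On Spark 2.4+, the circular shift itself can also be written without a UDF or a window, using the built-in slice and concat array functions. A sketch, not part of the original answers:

import pyspark.sql.functions as F

# [x1, x2, ..., xn] -> [[x1, x2, ..., xn], [x2, ..., xn, x1]]
df.withColumn(
    "COL3",
    F.array(
        F.col("COL2"),
        F.expr("concat(slice(COL2, 2, size(COL2) - 1), slice(COL2, 1, 1))")
    )
).show(truncate=False)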
