spark pipeline KMeansModel clusterCenters - apache-spark

I'm using a pipeline to cluster text documents. The last stage in the pipeline is ml.clustering.KMeans, which gives me a DataFrame with a column of cluster predictions. I would like to add the cluster centers as a column as well. I understand I could execute Vector[] clusterCenters = kmeansModel.clusterCenters();, convert the results into a DataFrame, and join those results to the other DataFrame; however, I was hoping to accomplish this in a way similar to the KMeans code below:
KMeans kMeans = new KMeans()
.setFeaturesCol("pca")
.setPredictionCol("kmeansclusterprediction")
.setK(5)
.setInitMode("random")
.setSeed(43L)
.setInitSteps(3)
.setMaxIter(15);
pipeline.setStages( ...
I was able to extend KMeans and call the fit method via a pipeline; however, I'm not having any luck extending KMeansModel. The constructor requires a String uid and a KMeansModel, but I don't know how to pass in the model when defining the stages and calling the setStages method.
I also looked into extending KMeans.scala; however, as a Java developer I only understand about half the code, so I'm hoping someone has an easier solution before I tackle that. Ultimately I would like to end up with a DataFrame as follows:
+--------------------+-----------------------+--------------------+
| docid|kmeansclusterprediction|kmeansclustercenters|
+--------------------+-----------------------+--------------------+
|2bcbcd54-c11a-48c...| 2| [-0.04, -7.72]|
|0e644620-f5ff-40f...| 3| [0.23, 1.08]|
|665c1c2b-3065-4e8...| 3| [0.23, 1.08]|
|598c6268-e4b9-4c9...| 0| [-15.81, 0.01]|
+--------------------+-----------------------+--------------------+
Any help or hints are greatly appreciated.
Thank you
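For reference, here is a minimal sketch of the join-based alternative mentioned in the question (not the approach ultimately used in the answer below); variable names such as predictionsDF, sqlContext and jsc are placeholders:
// Sketch only: assumes the Spark 1.x DataFrame API and the imports used elsewhere in
// this thread (RowFactory, DataTypes, StructField, StructType, Vector, Arrays, List, ArrayList).
Vector[] centers = kmeansModel.clusterCenters();
List<Row> centerRows = new ArrayList<Row>();
for (int i = 0; i < centers.length; i++) {
    // prediction id -> printable center, matching the desired output columns
    centerRows.add(RowFactory.create(i, Arrays.toString(centers[i].toArray())));
}
List<StructField> fields = new ArrayList<StructField>();
fields.add(DataTypes.createStructField("kmeansclusterprediction", DataTypes.IntegerType, false));
fields.add(DataTypes.createStructField("kmeansclustercenters", DataTypes.StringType, false));
DataFrame centersDF = sqlContext.createDataFrame(jsc.parallelize(centerRows),
    DataTypes.createStructType(fields));
DataFrame withCenters = predictionsDF.join(centersDF,
    predictionsDF.col("kmeansclusterprediction")
        .equalTo(centersDF.col("kmeansclusterprediction")), "inner")
    .drop(centersDF.col("kmeansclusterprediction"));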

Answering my own question: this was actually easy. I extended KMeans and KMeansModel; the extended KMeans fit method must return the extended KMeansModel. For example:
public class AnalyticsKMeansModel extends KMeansModel ...
public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans { ...
public AnalyticsKMeansModel fit(DataFrame dataset) {
JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>(){
private static final long serialVersionUID = -4588981547209486909L;
@Override
public Vector call(Row row) throws Exception {
Object point = row.getAs("pca");
Vector vector = (Vector)point;
return vector;
}
});
RDD<Vector> rdd = JavaRDD.toRDD(javaRDD);
org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
    .setK(BoxesRunTime.unboxToInt(this.$((Param<?>)this.k())))
    .setInitializationMode((String)this.$(this.initMode()))
    .setInitializationSteps(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.initSteps())))
    .setMaxIterations(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.maxIter())))
    .setSeed(BoxesRunTime.unboxToLong((Object)this.$((Param<?>)this.seed())))
    .setEpsilon(BoxesRunTime.unboxToDouble((Object)this.$((Param<?>)this.tol())));
org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);
AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);
return (AnalyticsKMeansModel) this.copyValues((Params)model, this.copyValues$default$2());
}
Once I changed the fit method to return my extended KMeansModel class everything worked as expected.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// import AnalyticsCluster; -- not valid as written; AnalyticsCluster (defined below) is assumed to be in the same package, so no import is needed
public class AnalyticsKMeansModel extends KMeansModel {
private static final long serialVersionUID = -8893355418042946358L;
public AnalyticsKMeansModel(String uid, org.apache.spark.mllib.clustering.KMeansModel parentModel) {
super(uid, parentModel);
}
public DataFrame transform(DataFrame dataset) {
Vector[] clusterCenters = super.clusterCenters();
List<AnalyticsCluster> analyticsClusters = new ArrayList<AnalyticsCluster>();
for (int i=0; i<clusterCenters.length;i++){
Integer clusterId = super.predict(clusterCenters[i]);
Vector vector = clusterCenters[i];
double[] point = vector.toArray();
AnalyticsCluster analyticsCluster = new AnalyticsCluster(clusterId, point, 0L);
analyticsClusters.add(analyticsCluster);
}
JavaSparkContext jsc = JavaSparkContext.fromSparkContext(dataset.sqlContext().sparkContext());
JavaRDD<AnalyticsCluster> javaRDD = jsc.parallelize(analyticsClusters);
JavaRDD<Row> javaRDDRow = javaRDD.map(new Function<AnalyticsCluster, Row>() {
private static final long serialVersionUID = -2677295862916670965L;
@Override
public Row call(AnalyticsCluster cluster) throws Exception {
Row row = RowFactory.create(
String.valueOf(cluster.getID()),
String.valueOf(Arrays.toString(cluster.getCenter()))
);
return row;
}
});
List<StructField> schemaColumns = new ArrayList<StructField>();
schemaColumns.add(DataTypes.createStructField(this.getPredictionCol(), DataTypes.StringType, false));
schemaColumns.add(DataTypes.createStructField("clusterpoint", DataTypes.StringType, false));
StructType dataFrameSchema = DataTypes.createStructType(schemaColumns);
DataFrame clusterPointsDF = dataset.sqlContext().createDataFrame(javaRDDRow, dataFrameSchema);
//SOMETIMES "K" IS SET TO A VALUE GREATER THAN THE NUMBER OF ACTUAL ROWS OF DATA ... GET DISTINCT VALUES
clusterPointsDF.registerTempTable("clusterPoints");
DataFrame clustersDF = clusterPointsDF.sqlContext().sql("select distinct " + this.getPredictionCol()+ ", clusterpoint from clusterPoints");
clustersDF.cache();
clusterPointsDF.sqlContext().dropTempTable("clusterPoints");
DataFrame transformedDF = super.transform(dataset);
transformedDF.cache();
DataFrame df = transformedDF.join(clustersDF,
transformedDF.col(this.getPredictionCol()).equalTo(clustersDF.col(this.getPredictionCol())), "inner")
.drop(clustersDF.col(this.getPredictionCol()));
return df;
}
}
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.Params;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import scala.runtime.BoxesRunTime;
public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans {
private static final long serialVersionUID = 8943702485821267996L;
private static String uid = null;
public AnalyticsKMeans(String uid){
AnalyticsKMeans.uid= uid;
}
public AnalyticsKMeansModel fit(DataFrame dataset) {
JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>(){
private static final long serialVersionUID = -4588981547209486909L;
@Override
public Vector call(Row row) throws Exception {
Object point = row.getAs("pca");
Vector vector = (Vector)point;
return vector;
}
});
RDD<Vector> rdd = JavaRDD.toRDD(javaRDD);
org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
    .setK(BoxesRunTime.unboxToInt(this.$((Param<?>)this.k())))
    .setInitializationMode((String)this.$(this.initMode()))
    .setInitializationSteps(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.initSteps())))
    .setMaxIterations(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.maxIter())))
    .setSeed(BoxesRunTime.unboxToLong((Object)this.$((Param<?>)this.seed())))
    .setEpsilon(BoxesRunTime.unboxToDouble((Object)this.$((Param<?>)this.tol())));
org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);
AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);
return (AnalyticsKMeansModel) this.copyValues((Params)model, this.copyValues$default$2());
}
}
import java.io.Serializable;
import java.util.Arrays;
public class AnalyticsCluster implements Serializable {
private static final long serialVersionUID = 6535671221958712594L;
private final int id;
private volatile double[] center;
private volatile long count;
public AnalyticsCluster(int id, double[] center, long initialCount) {
// Preconditions.checkArgument(center.length > 0);
// Preconditions.checkArgument(initialCount >= 1);
this.id = id;
this.center = center;
this.count = initialCount;
}
public int getID() {
return id;
}
public double[] getCenter() {
return center;
}
public long getCount() {
return count;
}
public synchronized void update(double[] newPoint, long newCount) {
int length = center.length;
// Preconditions.checkArgument(length == newPoint.length);
double[] newCenter = new double[length];
long newTotalCount = newCount + count;
double newToTotal = (double) newCount / newTotalCount;
for (int i = 0; i < length; i++) {
double centerI = center[i];
newCenter[i] = centerI + newToTotal * (newPoint[i] - centerI);
}
center = newCenter;
count = newTotalCount;
}
@Override
public synchronized String toString() {
return id + " " + Arrays.toString(center) + " " + count;
}
// public static void main(String[] args) {
// double[] point = new double[2];
// point[0] = 0.10150532938119154;
// point[1] = -0.23734759238651829;
//
// Cluster cluster = new Cluster(1,point, 10L);
// System.out.println("cluster: " + cluster.toString());
// }
}
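For completeness, a minimal sketch of how the extended estimator might be wired into the pipeline; the uid string, the trainingDF variable, and the placeholder for the earlier stages are assumptions, not part of the original post:
// Sketch only: assumes import org.apache.spark.ml.Pipeline, PipelineModel and PipelineStage.
AnalyticsKMeans kMeans = new AnalyticsKMeans("analyticsKMeans");
kMeans.setFeaturesCol("pca")
    .setPredictionCol("kmeansclusterprediction")
    .setK(5)
    .setInitMode("random")
    .setSeed(43L)
    .setInitSteps(3)
    .setMaxIter(15);
Pipeline pipeline = new Pipeline();
pipeline.setStages(new PipelineStage[] { /* ...earlier stages (tokenizer, PCA, ...), */ kMeans });
PipelineModel pipelineModel = pipeline.fit(trainingDF);
// The extended model's transform() runs here, so the result already contains
// the joined cluster-center column ("clusterpoint" in the class above).
DataFrame clustered = pipelineModel.transform(trainingDF);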

Related

Filter JavaRDD based on an ArrayList of index ids

I have a Dataset df whose contents are indexed by accountId, and I also have an ArrayList of accountIds. How do I filter or map the Dataset to create a new Dataset that contains only the rows whose accountId appears in the ArrayList?
I am using Java 8.
List<String> accountIdList= new ArrayList<String>();
accountIdList.add("1001");
accountIdList.add("1002");
accountIdList.add("1003");
accountIdList.add("1004");
Dataset<Row> filteredRows= df.filter(p-> df.col("accountId").equals(accountIdList));
I am trying to pass the list itself to the comparison operator. Do you think this is the correct approach?
If you are looking for the Java syntax:
Dataset<Row> filteredRows= df.where(df.col("accountId").isin(accountIdList.toArray()));
Use the Column.isin method:
import scala.collection.JavaConversions;
import static org.apache.spark.sql.functions.*;
Dataset<Row> filteredRows = df.where(col("accountId").isin(
JavaConversions.asScalaIterator(accountIdList.iterator()).toSeq()
));
Here is working code in Java. Hope it helps.
This is my sample file content (input):
1001
1008
1005
1009
1010
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
public class DatasetFilter {
private static List<String> sampleList = new ArrayList<String>();
public static void main(String[] args)
{
sampleList.add("1001");
sampleList.add("1002");
sampleList.add("1003");
sampleList.add("1004");
sampleList.add("1005");
SparkSession sparkSession = SparkSession.builder()
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.sql.warehouse.dir", "file:///C:/Users/user/workspace/Validation/spark-warehouse")
.master("local[*]").getOrCreate();
//Read the source-file.
Dataset<String> src = sparkSession.read().textFile("C:\\Users\\user\\Desktop\\dataSetFilterTest.txt");
src.show(10);
//Apply filter
Dataset<String> filteredSource = src.filter(new FilterFunction<String>() {
private static final long serialVersionUID = 1L;
@Override
public boolean call(String value) throws Exception {
System.out.println("***************************************");
boolean status = false;
Iterator<String> iterator = sampleList.iterator();
while (iterator.hasNext()) {
String val = iterator.next();
System.out.println("Val is :: " + val + " Value is :: " + value);
if (value.equalsIgnoreCase(val)) {
status = true;
break;
}
}
return status;
}
});
filteredSource.show();
System.out.println("Completed the job :)");
}
}
Output: src.show(10) prints the five input values; filteredSource.show() then prints only 1001 and 1005, the values present in both the file and the list.

Using a HashMap to make a plagiarism checker in Java

It reads an input.txt file and compares it with a target file.
If more than 3 words are the same as in the input file, the program should print "Plagiarized from" followed by the source id.
I used substring, so the program only compares the first 3 letters.
Should I tokenize instead, and how can the comparison be done on whole words? (A tokenized-comparison sketch follows the code below.)
Here is my code
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Scanner;
public class CheckPlagiarism{
public static void main(String args[]) throws FileNotFoundException
{
//Init HashMap
HashMap<String, Integer> corpus = new HashMap<>();
String fileName = args[0];
String targetName = args[1];
int matchCount = Integer.parseInt(args[2]);
Scanner scanner = new Scanner(new File(fileName));
while(scanner.hasNext())
{
String[] line = scanner.nextLine().split(":");
corpus.put(line[1], Integer.parseInt(line[0]));
}
boolean found = false;
scanner = new Scanner(new File(targetName));
while(scanner.hasNext())
{
String line = scanner.nextLine();
line = line.substring(0, matchCount);
for(Entry<String, Integer> temp: corpus.entrySet()){
String key=temp.getKey();
if(key.contains(line))
{
System.out.println("Plagiarized from " + temp.getValue());
found = true;
break;
}
}
}
if(!found)
{
System.out.println("Not Plagiarized");
}
}
}
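A minimal sketch of the tokenized comparison asked about above; it assumes the same "id:text" corpus format as the original code, and the class and method names are placeholders:
import java.util.HashSet;
import java.util.Set;
public class TrigramCheck {
    // Build the set of all 3-consecutive-word sequences in a line of text.
    static Set<String> trigrams(String text) {
        String[] words = text.toLowerCase().split("\\s+");
        Set<String> result = new HashSet<String>();
        for (int i = 0; i + 2 < words.length; i++) {
            result.add(words[i] + " " + words[i + 1] + " " + words[i + 2]);
        }
        return result;
    }
    // True if the two texts share at least one 3-word sequence.
    static boolean sharesTrigram(String input, String target) {
        Set<String> inputTrigrams = trigrams(input);
        for (String t : trigrams(target)) {
            if (inputTrigrams.contains(t)) {
                return true;
            }
        }
        return false;
    }
}
In the loop over corpus.entrySet(), replacing key.contains(line) with sharesTrigram(key, line) (on the full, untruncated line) compares 3-word sequences instead of the first few letters.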

Spark Parallelism with wholeTextFiles

I am trying to use the wholeTextFiles API for file processing. I have a lot of .gz files in a folder and want to read them with the wholeTextFiles API.
I have 4 executors, each with 1 core and 2GB RAM.
Only 2 executors are processing the job, and the processing is really slow; the other two executors are sitting idle.
How do I spread the job to the other 2 executors to increase the parallelism? (A repartition sketch follows the code below.)
package com.sss.ss.ss.WholeText;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import scala.Tuple2;
public class WholeText {
public static class mySchema implements Serializable {
private String CFIELD1 ;
private String CFIELD2 ;
public String getCFIELD1()
{
return CFIELD1;
}
public void setCFIELD1(String cFIELD1)
{
CFIELD1 = cFIELD1;
}
public String getCFIELD2()
{
return CFIELD2;
}
public void setCFIELD2(String cFIELD2)
{
CFIELD2 = cFIELD2;
}
}
public static void main(String[] args) throws InterruptedException {
SparkConf sparkConf = new SparkConf().setAppName("My app")
.setMaster("mymaster..")
.set("spark.driver.allowMultipleContexts", "true");
JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(15));
JavaPairRDD<String, String> wholeTextFiles = jssc.sparkContext().wholeTextFiles(args[0],Integer.parseInt(args[3]));
Integer ll = wholeTextFiles.getNumPartitions();
System.out.println("Number of Partitions"+ll);
JavaRDD<String> stringRDD = wholeTextFiles.
map(
new Function<Tuple2<String, String>, String>() {
private static final long serialVersionUID = -551872585218963131L;
public String call(Tuple2<String, String> v1) throws Exception
{
return v1._2;
}
}
).
flatMap
(new FlatMapFunction<String, String>()
{
public Iterator<String> call(String t) throws Exception
{
return Arrays.asList(t.split("\\r?\\n")).iterator();
}
}).
filter(new Function<String, Boolean>() {
private static final long serialVersionUID = 1L;
public Boolean call(String t) throws Exception {
int colons = 0;
String s = t;
if(s == null || s.trim().length() < 1) {
return false;
}
for(int i = 0; i < s.length(); i++) {
if(s.charAt(i) == ';') colons++;
}
System.out.println("colons="+colons);
if ((colons <=3)){
return false;
}
return true;
}
});
JavaRDD<mySchema> schemaRDD = stringRDD.map(new Function<String, mySchema>()
{
private static final long serialVersionUID = 1L;
public mySchema call(String line) throws Exception
{
String[] parts = line.split(";",-1);
mySchema mySchema = new mySchema();
mySchema.setCFIELD1 (parts[0]);
mySchema.setCFIELD2 (parts[1]);
return mySchema;
}
});
SQLContext hc = new HiveContext(jssc.sparkContext());
Dataset<Row> df = hc.createDataFrame(schemaRDD, mySchema.class);
df.createOrReplaceTempView("myView");
hc.sql("INSERT INTO -----
"from myView");
hc.sql("INSERT INTO .......
"from myView");
}
}
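A hedged sketch, not from the original thread: a common remedy is to repartition right after the read, since wholeTextFiles on a handful of .gz files often yields fewer partitions than there are executor cores:
// Sketch only: "8" is a placeholder (roughly executors x cores, or higher).
JavaPairRDD<String, String> spread = wholeTextFiles.repartition(8);
System.out.println("Partitions after repartition: " + spread.getNumPartitions());
// ...then build stringRDD from "spread" instead of "wholeTextFiles" so the
// flatMap/filter/map stages run on all executors.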

How can we get access to SQLContext in a flatmap function?

SQLContext, when accessed as below via a singleton class, works fine in local mode; however, when the job is submitted to the Spark master, it becomes null and throws NullPointerExceptions. How can this be fixed?
In our use case, the FlatMapFunction is expected to query another DStream, and the returned results are used to create a new stream.
I have extended the JavaStatefulNetworkWordCount example to print the changes to the state. I need to access the RDDs of a stateful DStream from another DStream using the SQLContext in order to create yet another DStream. How can this be achieved?
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import scala.Tuple2;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
public class JavaStatefulNetworkWordCount {
private static final Pattern SPACE = Pattern.compile(" ");
public static void main(String[] args) {
if (args.length < 2) {
System.err.println("Usage: JavaStatefulNetworkWordCount <hostname> <port>");
System.exit(1);
}
// Update the cumulative count function
final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
@Override
public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
Integer newSum = state.or(0);
for (Integer value : values) {
newSum += value;
}
return Optional.of(newSum);
}
};
// Create the context with a 1 second batch size
SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
// sparkConf.setMaster("local[5]");
// sparkConf.set("spark.executor.uri", "target/rkspark-0.0.1-SNAPSHOT.jar");
JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
ssc.checkpoint(".");
SQLContext sqlContext = JavaSQLContextSingleton.getInstance(ssc.sparkContext().sc());
// Initial RDD input to updateStateByKey
List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
new Tuple2<String, Integer>("world", 1));
JavaPairRDD<String, Integer> initialRDD = ssc.sc().parallelizePairs(tuples);
JavaReceiverInputDStream<String> lines = ssc.socketTextStream(
args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2);
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterable<String> call(String x) {
return Lists.newArrayList(SPACE.split(x));
}
});
JavaPairDStream<String, Integer> wordsDstream = words.mapToPair(
new PairFunction<String, String, Integer>() {
@Override
public Tuple2<String, Integer> call(String s) {
return new Tuple2<String, Integer>(s, 1);
}
});
// This will give a Dstream made of state (which is the cumulative count of the words)
JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD);
JavaDStream<WordCount> countStream = stateDstream.map(new Function<Tuple2<String, Integer>, WordCount>(){
@Override
public WordCount call(Tuple2<String, Integer> v1) throws Exception {
return new WordCount(v1._1,v1._2);
}});
countStream.foreachRDD(new Function<JavaRDD<WordCount>,Void>() {
@Override
public Void call(JavaRDD<WordCount> rdd) {
SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context());
DataFrame wordsDataFrame = sqlContext.createDataFrame(rdd, WordCount.class);
wordsDataFrame.registerTempTable("words");
return null;
}
});
wordsDstream.map(new Function<Tuple2<String,Integer>,String>(){
@Override
public String call(Tuple2<String, Integer> v1) throws Exception {
// Below sql context becomes null when run on a master instead of local.
SQLContext sqlContext = JavaSQLContextSingleton.getInstance();
DataFrame counterpartyIds = sqlContext.sql("select * from words where word ='"+v1._1()+"'");
Row[] rows = counterpartyIds.cache().collect();
if(rows.length>0){
Row row = rows[0];
return row.getInt(0)+"-"+ row.getString(1);
} else {
return "";
}
}
}).print();
ssc.start();
ssc.awaitTermination();
}
}
class JavaSQLContextSingleton {
static private transient SQLContext instance = null;
static public SQLContext getInstance(SparkContext sparkContext) {
if (instance == null) {
instance = new SQLContext(sparkContext);
}
return instance;
}
}
import java.io.Serializable;
public class WordCount implements Serializable{
public String getWord() {
return word;
}
public void setWord(String word) {
this.word = word;
}
public int getCount() {
return count;
}
public void setCount(int count) {
this.count = count;
}
String word;
public WordCount(String word, int count) {
super();
this.word = word;
this.count = count;
}
int count;
}
The SparkContext (and thus the SQLContext) is only available in the Driver and not serialized to the Workers. Your program works in local since it is running in the context of the driver where the context is available.
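A minimal sketch (not part of the original answer) of one way to restructure the failing map(): use transform(), whose function body runs on the driver once per batch, query there, and ship only a plain lookup map into the worker-side closure. It reuses wordsDstream and JavaSQLContextSingleton from the question; the variable names and the lookup-map approach are assumptions:
JavaDStream<String> joined = wordsDstream.transform(
    new Function<JavaPairRDD<String, Integer>, JavaRDD<String>>() {
        @Override
        public JavaRDD<String> call(JavaPairRDD<String, Integer> rdd) {
            // Runs on the driver: the singleton SQLContext is available here.
            // Assumes the "words" temp table registered in the foreachRDD above.
            SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context());
            Row[] rows = sqlContext.sql("select count, word from words").collect();
            final java.util.Map<String, Integer> lookup = new java.util.HashMap<String, Integer>();
            for (Row row : rows) {
                lookup.put(row.getString(1), row.getInt(0));
            }
            // Only the small lookup map is serialized into the worker-side closure.
            return rdd.map(new Function<Tuple2<String, Integer>, String>() {
                @Override
                public String call(Tuple2<String, Integer> v1) {
                    Integer count = lookup.get(v1._1());
                    return count == null ? "" : count + "-" + v1._1();
                }
            });
        }
    });
joined.print();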

JavaFX: Adding rows to TableView with a HashMap binding

Suppose that I have a map collection:
ObservableMap<K,V> map = FXCollections.observableHashMap();
I put 1 record into this map during FXML controller initialization, then wrap its entry set as an ObservableList:
ObservableList<Map.Entry<K,V>> list = FXCollections.observableArrayList(map.entrySet());
then set it as the items of my TableView: tableView.setItems(list);
Everything is fine when I run this JavaFX app and 1 record is showing.
Question is that:
When I add more records later to my map, my TableView will not refresh these records.
How could I bind a dynamical map collection into my TableView?
Thanks
If you use an ObservableList of ObservableMaps as your TableView's data structure
ObservableList<ObservableMap> rowMaps = FXCollections.observableArrayList();
tableView.setItems(rowMaps);
and implement your own ObservableMapValueFactory
import javafx.beans.property.ObjectProperty;
import javafx.beans.property.SimpleObjectProperty;
import javafx.beans.value.ObservableValue;
import javafx.collections.MapChangeListener;
import javafx.collections.ObservableMap;
import javafx.scene.control.TableColumn;
import javafx.scene.control.TableColumn.CellDataFeatures;
import javafx.util.Callback;
public class ObservableMapValueFactory<V> implements
Callback<TableColumn.CellDataFeatures<ObservableMap, V>, ObservableValue<V>> {
private final Object key;
public ObservableMapValueFactory(Object key) {
this.key = key;
}
@Override
public ObservableValue<V> call(CellDataFeatures<ObservableMap, V> features) {
final ObservableMap map = features.getValue();
final ObjectProperty<V> property = new SimpleObjectProperty<V>((V) map.get(key));
map.addListener(new MapChangeListener<Object, V>() {
public void onChanged(Change<?, ? extends V> change) {
if (key.equals(change.getKey())) {
property.set((V) map.get(key));
}
}
});
return property;
}
}
and then set it as the cell value factory for your column(s)
column.setCellValueFactory(new ObservableMapValueFactory<String>(columnId));
all changes to your data are reflected in the TableView, even changes only affecting the ObservableMaps.
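For illustration, a hedged usage sketch of that setup (the column key and the values are placeholders): adding a row is just adding another ObservableMap, and later changes to that map are picked up by the listener inside the factory:
// Assumes the rowMaps list above and a column created with new ObservableMapValueFactory<String>("firstName").
ObservableMap<String, String> row = FXCollections.observableHashMap();
row.put("firstName", "Ada");          // placeholder key and value
rowMaps.add(row);                     // the TableView shows the new row
row.put("firstName", "Ada Lovelace"); // the cell updates via the MapChangeListener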
You can bind a map directly to a TableView; consider this example from the JavaFX documentation:
import java.util.HashMap;
import java.util.Map;
import javafx.application.Application;
import javafx.collections.FXCollections;
import javafx.collections.ObservableList;
import javafx.geometry.Insets;
import javafx.scene.Group;
import javafx.scene.Scene;
import javafx.scene.control.Label;
import javafx.scene.control.TableCell;
import javafx.scene.control.TableColumn;
import javafx.scene.control.TableView;
import javafx.scene.control.cell.MapValueFactory;
import javafx.scene.control.cell.TextFieldTableCell;
import javafx.scene.layout.VBox;
import javafx.scene.text.Font;
import javafx.stage.Stage;
import javafx.util.Callback;
import javafx.util.StringConverter;
public class TableViewSample extends Application {
public static final String Column1MapKey = "A";
public static final String Column2MapKey = "B";
public static void main(String[] args) {
launch(args);
}
@Override
public void start(Stage stage) {
Scene scene = new Scene(new Group());
stage.setTitle("Table View Sample");
stage.setWidth(300);
stage.setHeight(500);
final Label label = new Label("Student IDs");
label.setFont(new Font("Arial", 20));
TableColumn<Map, String> firstDataColumn = new TableColumn<>("Class A");
TableColumn<Map, String> secondDataColumn = new TableColumn<>("Class B");
firstDataColumn.setCellValueFactory(new MapValueFactory(Column1MapKey));
firstDataColumn.setMinWidth(130);
secondDataColumn.setCellValueFactory(new MapValueFactory(Column2MapKey));
secondDataColumn.setMinWidth(130);
TableView table_view = new TableView<>(generateDataInMap());
table_view.setEditable(true);
table_view.getSelectionModel().setCellSelectionEnabled(true);
table_view.getColumns().setAll(firstDataColumn, secondDataColumn);
Callback<TableColumn<Map, String>, TableCell<Map, String>>
cellFactoryForMap = new Callback<TableColumn<Map, String>,
TableCell<Map, String>>() {
@Override
public TableCell call(TableColumn p) {
return new TextFieldTableCell(new StringConverter() {
@Override
public String toString(Object t) {
return t.toString();
}
@Override
public Object fromString(String string) {
return string;
}
});
}
};
firstDataColumn.setCellFactory(cellFactoryForMap);
secondDataColumn.setCellFactory(cellFactoryForMap);
final VBox vbox = new VBox();
vbox.setSpacing(5);
vbox.setPadding(new Insets(10, 0, 0, 10));
vbox.getChildren().addAll(label, table_view);
((Group) scene.getRoot()).getChildren().addAll(vbox);
stage.setScene(scene);
stage.show();
}
private ObservableList<Map> generateDataInMap() {
int max = 10;
ObservableList<Map> allData = FXCollections.observableArrayList();
for (int i = 1; i < max; i++) {
Map<String, String> dataRow = new HashMap<>();
String value1 = "A" + i;
String value2 = "B" + i;
dataRow.put(Column1MapKey, value1);
dataRow.put(Column2MapKey, value2);
allData.add(dataRow);
}
return allData;
}
}
The answer from ItachiUchiha uses the columns as keys and the rows as individual maps. If you'd like one map with the rows as key->value entries, you'll have to add a listener to the map that changes the backing list when you add or delete. I did something similar here: https://stackoverflow.com/a/21339428/2855515 (a sketch of the idea follows).
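A minimal sketch of that listener idea (not taken from the linked answer; tableView and the key/value types are placeholders): keep an ObservableList of entries in sync with the ObservableMap so the TableView sees additions and removals:
// Sketch only: assumes javafx.collections.* and java.util.Map are imported.
final ObservableMap<String, String> map = FXCollections.observableHashMap();
final ObservableList<Map.Entry<String, String>> rows =
        FXCollections.observableArrayList(map.entrySet());
tableView.setItems(rows);
map.addListener(new MapChangeListener<String, String>() {
    @Override
    public void onChanged(Change<? extends String, ? extends String> change) {
        // Simplest possible sync: rebuild the row list on every map change so the
        // TableView sees added and removed keys. Fine for small maps.
        rows.setAll(map.entrySet());
    }
});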
