I wrote my own data provider, which should read a file in chunks and provide the chunks to my Spock specification.
While debugging, the next() method returns a proper batch, and hasNext() returns false once the reader cannot read any more lines.
But I get this exception: SpockExecutionException: Data provider has no data
Here are my provider and my feature:
class DumpProvider implements Iterable<ArrayList<String>> {
private File fileHandle
private BufferedReader fileReader
private ArrayList<String> currentBatch = new ArrayList<String>()
private int chunksize
private boolean hasNext = true
DumpProvider(String pathToFile, int chunksize) {
this.chunksize = chunksize
this.fileHandle = new File(pathToFile)
this.fileReader = this.fileHandle.newReader()
}
@Override
Iterator iterator() {
new Iterator<ArrayList<String>>() {
@Override
boolean hasNext() {
if (hasNext) {
String nextLine = fileReader.readLine()
if (nextLine != null) {
currentBatch.push(nextLine)
} else {
hasNext = false
fileReader.close()
fileHandle = null
}
}
return hasNext
}
@Override
ArrayList<String> next() {
(chunksize - currentBatch.size()).times {
String line = fileReader.readLine()
if (line != null) {
currentBatch.push(line)
}
}
def batch = new ArrayList<String>(currentBatch)
currentBatch = new ArrayList<String>()
return batch
}
@Override
void remove() {
throw new UnsupportedOperationException();
}
}
}
}
Spock Feature
def "small import"() {
when:
println 'test'
println profileJSONStrings
connector.insertMultiple(profileJSONStrings as ArrayList<String>)
then:
println "hello"
where:
profileJSONStrings << dataProvider
}
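One possible cause, which is an assumption and not confirmed in the question, is that the Iterable gets iterated more than once (for example, if the framework asks for its size before the real run). DumpProvider keeps the reader, the current batch and the hasNext flag in the provider object itself, so a second iterator() call would find the file already closed. The Java sketch below (ChunkedLines and its Supplier parameter are hypothetical names) keeps all reading state inside each iterator, so every pass re-opens the input and yields the same chunks:

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Supplier;

// Hypothetical chunking provider: each iterator() call opens its own reader,
// so iterating the Iterable twice still yields every chunk.
class ChunkedLines implements Iterable<List<String>> {
    private final Supplier<BufferedReader> readerFactory;
    private final int chunkSize;

    ChunkedLines(Supplier<BufferedReader> readerFactory, int chunkSize) {
        if (chunkSize < 1) throw new IllegalArgumentException("chunkSize must be >= 1");
        this.readerFactory = readerFactory;
        this.chunkSize = chunkSize;
    }

    @Override
    public Iterator<List<String>> iterator() {
        BufferedReader reader = readerFactory.get();   // per-iterator state
        return new Iterator<List<String>>() {
            private List<String> pending = readChunk();   // look-ahead; null means exhausted

            private List<String> readChunk() {
                List<String> chunk = new ArrayList<>(chunkSize);
                try {
                    String line;
                    while (chunk.size() < chunkSize && (line = reader.readLine()) != null) {
                        chunk.add(line);
                    }
                    if (chunk.isEmpty()) {
                        reader.close();   // end of input: release the file handle
                        return null;
                    }
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                return chunk;
            }

            @Override
            public boolean hasNext() {
                return pending != null;   // no side effects, safe to call repeatedly
            }

            @Override
            public List<String> next() {
                if (pending == null) throw new NoSuchElementException();
                List<String> current = pending;
                pending = readChunk();
                return current;
            }
        };
    }
}
```

In the spec, the supplier would open the dump file (in Groovy, for example, `new File(pathToFile).newReader()`). The look-ahead chunk also keeps hasNext() free of side effects, unlike the original, where each hasNext() call reads and stores another line.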
I have written a JAR that provides a Jedis connection pool, and using it I wrote a Groovy script in NiFi for a Redis location search. But it behaves strangely: sometimes it works and sometimes it doesn't.
Redis.java
public class Redis {
private static Object staticLock = new Object();
private static JedisPool pool;
private static String host;
private static int port;
private static int connectTimeout;
private static int operationTimeout;
private static String password;
private static JedisPoolConfig config;
public static void initializeSettings(String host, int port, String password, int connectTimeout, int operationTimeout) {
Redis.host = host;
Redis.port = port;
Redis.password = password;
Redis.connectTimeout = connectTimeout;
Redis.operationTimeout = operationTimeout;
}
public static JedisPool getPoolInstance() {
if (pool == null) { // avoid synchronization lock if initialization has already happened
synchronized(staticLock) {
if (pool == null) { // don't re-initialize if another thread beat us to it.
JedisPoolConfig poolConfig = getPoolConfig();
boolean useSsl = port == 6380 ? true : false;
int db = 0;
String clientName = "MyClientName"; // null means use default
SSLSocketFactory sslSocketFactory = null; // null means use default
SSLParameters sslParameters = null; // null means use default
HostnameVerifier hostnameVerifier = new SimpleHostNameVerifier(host);
pool = new JedisPool(poolConfig, host, port);
//(poolConfig, host, port, connectTimeout,operationTimeout,password, db,
// clientName, useSsl, sslSocketFactory, sslParameters, hostnameVerifier);
}
}
}
return pool;
}
public static JedisPoolConfig getPoolConfig() {
if (config == null) {
JedisPoolConfig poolConfig = new JedisPoolConfig();
int maxConnections = 200;
poolConfig.setMaxTotal(maxConnections);
poolConfig.setMaxIdle(maxConnections);
poolConfig.setBlockWhenExhausted(true);
poolConfig.setMaxWaitMillis(operationTimeout);
poolConfig.setMinIdle(50);
Redis.config = poolConfig;
}
return config;
}
public static String getPoolCurrentUsage()
{
JedisPool jedisPool = getPoolInstance();
JedisPoolConfig poolConfig = getPoolConfig();
int active = jedisPool.getNumActive();
int idle = jedisPool.getNumIdle();
int total = active + idle;
String log = String.format(
"JedisPool: Active=%d, Idle=%d, Waiters=%d, total=%d, maxTotal=%d, minIdle=%d, maxIdle=%d",
active,
idle,
jedisPool.getNumWaiters(),
total,
poolConfig.getMaxTotal(),
poolConfig.getMinIdle(),
poolConfig.getMaxIdle()
);
return log;
}
private static class SimpleHostNameVerifier implements HostnameVerifier {
private String exactCN;
private String wildCardCN;
public SimpleHostNameVerifier(String cacheHostname)
{
exactCN = "CN=" + cacheHostname;
wildCardCN = "CN=*" + cacheHostname.substring(cacheHostname.indexOf('.'));
}
public boolean verify(String s, SSLSession sslSession) {
try {
String cn = sslSession.getPeerPrincipal().getName();
return cn.equalsIgnoreCase(wildCardCN) || cn.equalsIgnoreCase(exactCN);
} catch (SSLPeerUnverifiedException ex) {
return false;
}
}
}
}
CustomFunction:
public class Functions {
SecureRandom rand = new SecureRandom();
private static final String UTF8= "UTF-8";
public static JedisPool jedisPool=null;
public static String searchPlace(double lattitude,double longitude) {
try(Jedis jedis = jedisPool.getResource()) {
}
catch(Exception e){
log.error("exception", e);
}
}
}
Groovyscript:
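import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.io.OutputStreamCallback;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.commons.io.IOUtils;
import groovy.json.JsonSlurper;
import java.nio.charset.StandardCharsets;
import com.customlib.functions.*;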
def flowFile = session.get();
if (flowFile == null) {
return;
}
def flowFiles = [] as List<FlowFile>
def failflowFiles = [] as List<FlowFile>
def input=null;
def data=null;
static onStart(ProcessContext context){
Redis.initializeSettings("host", 6379, null,0,0);
Functions.jedisPool= Redis.getPoolInstance();
}
static onStop(ProcessContext context){
Functions.jedisPool.destroy();
}
try{
log.warn('is jedispool connected::::'+Functions.jedisPool.isClosed());
def inputStream = session.read(flowFile)
def writer = new StringWriter();
IOUtils.copy(inputStream, writer, "UTF-8");
data=writer.toString();
input = new JsonSlurper().parseText( data );
log.warn('place is::::' + Functions.getLocationByLatLong(input["data"]["lat"], input["data"]["longi"]));
.......
...........
}
catch(Exception e){
newFlowFile = session.write(newFlowFile, { outputStream ->
outputStream.write( data.getBytes(StandardCharsets.UTF_8) )
} as OutputStreamCallback)
failflowFiles<< newFlowFile;
}
session.transfer(flowFiles, REL_SUCCESS)
session.transfer(failflowFiles, REL_FAILURE)
session.remove(flowFile)
NiFi runs as a 3-node cluster. The function library is configured in the Groovy script's module directory. In the Groovy script processor above, the "is jedispool connected::::" log statement sometimes prints false and sometimes true. Right after the JAR is first deployed it works every time, but later the behavior becomes unpredictable, and I can't figure out what is wrong in the code. How does the Groovy script load the JAR, and how can I implement this library-based search from a Groovy script?
Redis.pool never goes back to null after initialization: you call destroy() on the pool but never set the reference to null.
getPoolInstance() creates a new pool only when pool is null, so after onStop() destroys the pool, later calls keep returning the same destroyed instance.
I also don't see any reason to have two variables holding a reference to the same pool, one in Redis and one in Functions.
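A minimal Java sketch of the fix the answer implies, assuming a generic AutoCloseable stands in for JedisPool (with Jedis you might adapt shutdown() to call destroy(), as the original onStop() does). The hypothetical LazyPool class clears its reference when it shuts the pool down, so the next get() builds a fresh pool instead of handing back a destroyed one. It also marks the field volatile, which double-checked locking requires in Java and which the Redis class above lacks:

```java
import java.util.function.Supplier;

// Hypothetical lazily created, resettable pool holder.
final class LazyPool<T extends AutoCloseable> {
    private final Object lock = new Object();
    private final Supplier<T> factory;
    private volatile T pool;   // volatile: required for double-checked locking to be safe

    LazyPool(Supplier<T> factory) {
        this.factory = factory;
    }

    T get() {
        T p = pool;                       // fast path without locking
        if (p == null) {
            synchronized (lock) {
                p = pool;
                if (p == null) {          // another thread may have created it meanwhile
                    p = factory.get();
                    pool = p;
                }
            }
        }
        return p;
    }

    void shutdown() {
        synchronized (lock) {
            T p = pool;
            pool = null;                  // the step the answer points out is missing
            if (p != null) {
                try {
                    p.close();
                } catch (Exception e) {
                    throw new IllegalStateException("failed to close pool", e);
                }
            }
        }
    }
}
```

With a single holder like this, Functions would call get() on it rather than keep its own static copy of the pool, which also addresses the answer's last point.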
Question: This doesn't seem to work.
@Post("/register")
@Consumes(MediaType.APPLICATION_FORM_URLENCODED)
@View("register")
@Error(exception = ConstraintViolationException.class)
def register(HttpRequest<?> request, ConstraintViolationException constraintViolationException) {
Optional<RegisterFormData> registerFormDataOptional = request.getBody(RegisterFormData.class)
Map<String, Object> map = new HashMap<>()
if(registerFormDataOptional.isPresent()){
RegisterRequest registerRequest = new RegisterRequest(registerFormDataOptional.get().properties)
registerRequest.returnSecureToken = true
try {
def registerResponse = firebaseClient.register(registerRequest, this.firebaseApiKey).blockingSingle()
SendEmailVerificationRequest sendEmailVerificationRequest = new SendEmailVerificationRequest()
sendEmailVerificationRequest.requestType = 'VERIFY_EMAIL'
sendEmailVerificationRequest.idToken = registerResponse.idToken
firebaseClient.sendEmailVerification(sendEmailVerificationRequest, this.firebaseApiKey)
HttpResponse.redirect(URI.create('/register-success'))
}catch(HttpClientResponseException ex){
map.put('errors', [ex.message])
return map
}
}else{
map.put('errors', violationMessageSource.violationsMessages(constraintViolationException.constraintViolations))
return map
}
}
gives me:
{"message":"Required argument [ConstraintViolationException constraintViolationException] not specified","path":"/constraintViolationException","_links":{"self":{"href":"/auth/register","templated":false}}}
I'm currently using Micronaut and Thymeleaf. Does anyone know what I am missing? I was following the examples from https://guides.micronaut.io/micronaut-error-handling/guide/index.html
The whole point here is to pass error messages from the controller back to the UI when constraint violations happen. The default approach, using the @Body and @Valid annotations, doesn't work here because it returns JSON errors without rendering any view.
@Inject
Validator validator
@Inject
ViolationMessageSource violationMessageSource
@Post("/register")
@Consumes(MediaType.APPLICATION_FORM_URLENCODED)
def register(HttpRequest<?> request, @Body RegisterFormData registerFormData) {
//validate registerformdata object
Map<String, Object> map = new HashMap<>()
Set<ConstraintViolation<RegisterFormData>> violations = validator.validate(registerFormData)
if (violations.size() > 0) {
map.put('registerFormData', registerFormData)
map.put('errors', violationMessageSource.violationsMessages(violations))
HttpResponse.redirect(URI.create('/register')).body(map)
} else {
RegisterRequest registerRequest = new RegisterRequest(registerFormData.properties)
registerRequest.returnSecureToken = true
try {
def registerResponse = firebaseClient.register(registerRequest, this.firebaseApiKey).blockingSingle()
SendEmailVerificationRequest sendEmailVerificationRequest = new SendEmailVerificationRequest()
sendEmailVerificationRequest.requestType = 'VERIFY_EMAIL'
sendEmailVerificationRequest.idToken = registerResponse.idToken
firebaseClient.sendEmailVerification(sendEmailVerificationRequest, this.firebaseApiKey)
HttpResponse.redirect(URI.create('/register-success'))
} catch (HttpClientResponseException ex) {
map.put('errors', [ex.message])
HttpResponse.redirect(URI.create('/register')).body(map)
}
}
}
where the injected Validator bean is provided by a Micronaut factory like this:
@Factory
class ValidatorConfig {
Validator validator
@PostConstruct
void initialize(){
ValidatorFactory factory = Validation.buildDefaultValidatorFactory()
validator = factory.getValidator()
}
@Bean
Validator getValidator(){
return validator
}
}
and my message source like so
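import javax.inject.Singleton;
import javax.validation.ConstraintViolation;
import javax.validation.Path;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@Singleton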
public class ViolationMessageSource {
public List<String> violationsMessages(Set<ConstraintViolation<?>> violations) {
return violations.stream()
.map(ViolationMessageSource::violationMessage)
.collect(Collectors.toList());
}
private static String violationMessage(ConstraintViolation violation) {
StringBuilder sb = new StringBuilder();
Path.Node lastNode = lastNode(violation.getPropertyPath());
if (lastNode != null) {
sb.append(lastNode.getName());
sb.append(" ");
}
sb.append(violation.getMessage());
return sb.toString();
}
private static Path.Node lastNode(Path path) {
Path.Node lastNode = null;
for (final Path.Node node : path) {
lastNode = node;
}
return lastNode;
}
}
This answer is based on the fundamentals of javax validation (https://www.baeldung.com/javax-validation) and on error handling in Micronaut (https://guides.micronaut.io/micronaut-error-handling/guide/index.html).
I'm trying to figure out how to skip serializing empty collections using YamlDotNet. I have experimented with both a custom ChainedObjectGraphVisitor and IYamlTypeConverter. I'm new to using YamlDotNet and have some knowledge gaps here.
Below is my implementation for the visitor pattern, which results in a YamlDotNet.Core.YamlException "Expected SCALAR, SEQUENCE-START, MAPPING-START, or ALIAS, got MappingEnd" error. I do see some online content for MappingStart/MappingEnd, but I'm not sure how it fits into what I'm trying to do (eliminate clutter from lots of empty collections). Any pointers in the right direction are appreciated.
Instantiating the serializer:
var serializer = new YamlDotNet.Serialization.SerializerBuilder()
.WithNamingConvention(new YamlDotNet.Serialization.NamingConventions.CamelCaseNamingConvention())
.WithEmissionPhaseObjectGraphVisitor(args => new YamlIEnumerableSkipEmptyObjectGraphVisitor(args.InnerVisitor))
.Build();
ChainedObjectGraphVisitor implementation:
public sealed class YamlIEnumerableSkipEmptyObjectGraphVisitor : ChainedObjectGraphVisitor
{
public YamlIEnumerableSkipEmptyObjectGraphVisitor(IObjectGraphVisitor<IEmitter> nextVisitor)
: base(nextVisitor)
{
}
public override bool Enter(IObjectDescriptor value, IEmitter context)
{
bool retVal;
if (typeof(System.Collections.IEnumerable).IsAssignableFrom(value.Value.GetType()))
{ // We have a collection
var enumerableObject = (System.Collections.IEnumerable)value.Value;
if (enumerableObject.GetEnumerator().MoveNext()) // Returns true if the collection is not empty.
{ // Serialize it as normal.
retVal = base.Enter(value, context);
}
else
{ // Skip this item.
retVal = false;
}
}
else
{ // Not a collection, normal serialization.
retVal = base.Enter(value, context);
}
return retVal;
}
}
I believe the answer is to also override the EnterMapping() method in the base class with logic that is similar to what was done in the Enter() method:
public override bool EnterMapping(IPropertyDescriptor key, IObjectDescriptor value, IEmitter context)
{
bool retVal = false;
if (value.Value == null)
return retVal;
if (typeof(System.Collections.IEnumerable).IsAssignableFrom(value.Value.GetType()))
{ // We have a collection
var enumerableObject = (System.Collections.IEnumerable)value.Value;
if (enumerableObject.GetEnumerator().MoveNext()) // Returns true if the collection is not empty.
{ // Don't skip this item - serialize it as normal.
retVal = base.EnterMapping(key, value, context);
}
// Else we have an empty collection and the initialized return value of false is correct.
}
else
{ // Not a collection, normal serialization.
retVal = base.EnterMapping(key, value, context);
}
return retVal;
}
I ended up with the following class:
using System.Collections;
using YamlDotNet.Core;
using YamlDotNet.Serialization;
using YamlDotNet.Serialization.ObjectGraphVisitors;
sealed class YamlIEnumerableSkipEmptyObjectGraphVisitor : ChainedObjectGraphVisitor
{
public YamlIEnumerableSkipEmptyObjectGraphVisitor(IObjectGraphVisitor<IEmitter> nextVisitor): base(nextVisitor)
{
}
private bool IsEmptyCollection(IObjectDescriptor value)
{
if (value.Value == null)
return true;
if (typeof(IEnumerable).IsAssignableFrom(value.Value.GetType()))
return !((IEnumerable)value.Value).GetEnumerator().MoveNext();
return false;
}
public override bool Enter(IObjectDescriptor value, IEmitter context)
{
if (IsEmptyCollection(value))
return false;
return base.Enter(value, context);
}
public override bool EnterMapping(IPropertyDescriptor key, IObjectDescriptor value, IEmitter context)
{
if (IsEmptyCollection(value))
return false;
return base.EnterMapping(key, value, context);
}
}
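The filtering rule in IsEmptyCollection can be sketched in Java as a language-neutral illustration. SkipEmpty is a hypothetical class, not part of YamlDotNet or any Java YAML library: it drops null values and empty Iterables, and adds a Map check because, unlike a C# dictionary, a Java Map is not Iterable. One C# detail the class above inherits is worth knowing: System.String implements IEnumerable, so the visitor also omits empty strings, whereas Java's String is not Iterable and this sketch keeps them.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical Java analogue of IsEmptyCollection: drop entries whose value is
// null or an empty Iterable/Map before handing the properties to a serializer.
final class SkipEmpty {
    static boolean isEmptyCollection(Object value) {
        if (value == null) return true;   // mirrors the C# null check
        if (value instanceof Iterable) return !((Iterable<?>) value).iterator().hasNext();
        if (value instanceof Map) return ((Map<?, ?>) value).isEmpty();
        return false;
    }

    static Map<String, Object> withoutEmpty(Map<String, Object> properties) {
        Map<String, Object> result = new LinkedHashMap<>();   // keeps property order
        properties.forEach((key, value) -> {
            if (!isEmptyCollection(value)) result.put(key, value);
        });
        return result;
    }
}
```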
You can specify DefaultValuesHandling in the serializer:
var serializer = new SerializerBuilder()
.ConfigureDefaultValuesHandling(DefaultValuesHandling.OmitEmptyCollections)
.Build();
or in a YamlMember attribute on a field or property:
public class MyDtoClass
{
[YamlMember(DefaultValuesHandling = DefaultValuesHandling.OmitEmptyCollections)]
public List<string> MyCollection;
}
I want to show, in a ListView via an adapter, all rows of the table that belong to a specific username.
I get the error "org.json.JSONException: Value [] of type org.json.JSONArray cannot be converted to JSONObject".
Below is the code.
The PHP file works fine: it returns all the data for the specified username.
ExpenseList.php
<?php
require_once("Config.php");
$response = array();
if(isset($_GET['apicall'])){
switch($_GET['apicall']){
case 'expense':
if(isTheseParametersAvailable(array('username'))){
$username = $_POST['username'];
$stmt = $con->prepare("SELECT * FROM Expense_Master WHERE VV_User_Name = '$username' ORDER BY VD_Expense_Date ASC");
$stmt->execute();
$stmt->store_result();
if($stmt->num_rows > 0){
$stmt->bind_result($expenseid, $userid, $username, $entrydate, $expensedate, $credit, $debit, $description, $modifieddate);
$products = array();
while($stmt->fetch()){
$temp = array();
$temp['expenseid'] = $expenseid;
$temp['userid'] = $userid;
$temp['username'] = $username;
$temp['entrydate'] = $entrydate;
$temp['expensedate'] = $expensedate;
$temp['credit'] = $credit;
$temp['debit'] = $debit;
$temp['description'] = $description;
$temp['modifieddate'] = $modifieddate;
array_push($products, $temp);
$response['error'] = false;
$response['message'] = 'Fetch successfull';
}
}else{
$response['error'] = false;
$response['message'] = 'Invalid username';
}
}
break;
$response['error'] = true;
$response['message'] = 'Invalid Operation Called';
}
}else{
$response['error'] = true;
$response['message'] = 'Invalid API Call';
}
echo json_encode($response);
echo json_encode($products);
function isTheseParametersAvailable($params){
foreach($params as $param){
if(!isset($_POST[$param])){
return false;
}
}
return true;
}
?>
Here is ExpenseList.java:
private static final String EXPENSE_URL = "http://server.com/ExpenseList.php?apicall=expense";
private List<ExpenseListNotes> userNotes = new ArrayList<>();
ListView listView;
@Override
protected void onCreate(Bundle savedInstanceState)
{
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_expense_list);
spinnerUserNotes = new ArrayList<SpinnerUserNotes>();
expenseData();
}
private void expenseData() {
//if everything is fine
class expenseData extends AsyncTask<Void, Void, String> {
@Override
protected void onPreExecute() {
super.onPreExecute();
}
@Override
protected void onPostExecute(String s) {
super.onPostExecute(s);
try {
JSONArray array = new JSONArray(s);
for (int i = 0; i < array.length(); i++) {
//getting product object from json array
JSONObject product = array.getJSONObject(i);
userNotes.add(new ExpenseListNotes(
product.getInt("expenseid"),
product.getString("userid"),
product.getString("username"),
product.getString("entrydate"),
product.getString("expensedate"),
product.getString("credit"),
product.getString("debit"),
product.getString("description")));
}
ExpenseListAdapter adapter = new ExpenseListAdapter(ExpenseList.this, userNotes);
listView.setAdapter(adapter);
} catch (JSONException e) {
e.printStackTrace();
Toast.makeText(getApplicationContext(), "Error : "+e.toString(), Toast.LENGTH_LONG).show();
Log.e("Error", e.toString());
}
}
@Override
protected String doInBackground(Void... voids) {
//creating request handler object
RequestHandler requestHandler = new RequestHandler();
//creating request parameters
HashMap<String, String> params = new HashMap<>();
params.put("username", "Alpesh");
//returing the response
return requestHandler.sendPostRequest(EXPENSE_URL, params);
}
}
expenseData ul = new expenseData();
ul.execute();
}
Here is the adapter class:
private class ExpenseListAdapter extends BaseAdapter
{
private Context context;
private List<ExpenseListNotes> invoiceModelArrayList;
public ExpenseListAdapter(Context context, List<ExpenseListNotes> invoiceModelArrayList) {
this.context = context;
this.invoiceModelArrayList = (List<ExpenseListNotes>) invoiceModelArrayList;
}
@Override
public int getCount() {
return invoiceModelArrayList.size();
}
@Override
public Object getItem(int i) {
return invoiceModelArrayList.get(i);
}
@Override
public long getItemId(int i) {
return 0;
}
@Override
public View getView(final int i, View v, ViewGroup viewGroup)
{
final ViewHolder holder;
ButterKnife.bind(this, v);
if (v == null)
{
holder = new ViewHolder();
LayoutInflater inflater = (LayoutInflater) context.getSystemService(Context.LAYOUT_INFLATER_SERVICE);
v = inflater.inflate(R.layout.row_expense_list, null, true);
holder.tvRELUserid = v.findViewById(R.id.tvRELUserid);
holder.tvRELExpenseID = v.findViewById(R.id.tvRELExpenseID);
holder.tvRELUsername = v.findViewById(R.id.tvRELUsername);
holder.tvRELCredit = v.findViewById(R.id.tvRELCredit);
holder.tvRELDebit = v.findViewById(R.id.tvRELDebit);
holder.tvRELExpenseDate = v.findViewById(R.id.tvRELExpenseDate);
holder.tvRELEntryDate = v.findViewById(R.id.tvRELEntryDate);
holder.tvRELDescription = v.findViewById(R.id.tvRELDescription);
holder.btnDelete = v.findViewById(R.id.btnRELDelete);
holder.btnUpdate = v.findViewById(R.id.btnRELUpdate);
v.setTag(holder);
}
else
{
holder = (ViewHolder)v.getTag();
}
holder.tvRELExpenseID.setText(String.valueOf(invoiceModelArrayList.get(i).getExpenseID()));
holder.tvRELUserid.setText(String.valueOf(invoiceModelArrayList.get(i).getUserID()));
holder.tvRELUsername.setText(String.valueOf(invoiceModelArrayList.get(i).getUsername()));
holder.tvRELCredit.setText(String.valueOf(invoiceModelArrayList.get(i).getCredit()));
holder.tvRELDebit.setText(String.valueOf(invoiceModelArrayList.get(i).getDebit()));
holder.tvRELExpenseDate.setText(String.valueOf(invoiceModelArrayList.get(i).getExpenseDate()));
holder.tvRELEntryDate.setText(String.valueOf(invoiceModelArrayList.get(i).getEntryDate()));
holder.tvRELDescription.setText(String.valueOf(invoiceModelArrayList.get(i).getDescription()));
return v;
}
public void setFilter(List<ExpenseListNotes> newList)
{
invoiceModelArrayList = new ArrayList<>();
invoiceModelArrayList.addAll(newList);
notifyDataSetChanged();
}
}
private class ViewHolder {
TextView tvRELUserid, tvRELExpenseID, tvRELUsername, tvRELCredit, tvRELDebit, tvRELExpenseDate, tvRELEntryDate, tvRELDescription;
Button btnDelete, btnUpdate;
}
Code for the RequestHandler class:
public class RequestHandler
{
public String sendPostRequest(String requestURL, HashMap<String, String> postDataParams)
{
URL url;
StringBuilder sb = new StringBuilder();
try
{
url = new URL(requestURL);
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setReadTimeout(15000);
conn.setConnectTimeout(15000);
conn.setRequestMethod("POST");
conn.setDoInput(true);
conn.setDoOutput(true);
OutputStream os = conn.getOutputStream();
BufferedWriter writer = new BufferedWriter(
new OutputStreamWriter(os, "UTF-8"));
writer.write(getPostDataString(postDataParams));
writer.flush();
writer.close();
os.close();
int responseCode = conn.getResponseCode();
if (responseCode == HttpsURLConnection.HTTP_OK) {
BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream()));
sb = new StringBuilder();
String response;
while ((response = br.readLine()) != null) {
sb.append(response);
}
}
}
catch (Exception e)
{
e.printStackTrace();
}
return sb.toString();
}
//this method is converting keyvalue pairs data into a query string as needed to send to the server
private String getPostDataString(HashMap<String, String> params) throws UnsupportedEncodingException
{
StringBuilder result = new StringBuilder();
boolean first = true;
for (Map.Entry<String, String> entry : params.entrySet())
{
if (first)
first = false;
else
result.append("&");
result.append(URLEncoder.encode(entry.getKey(), "UTF-8"));
result.append("=");
result.append(URLEncoder.encode(entry.getValue(), "UTF-8"));
}
return result.toString();
}
}
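getPostDataString builds an application/x-www-form-urlencoded body by hand. As a sketch only, the same encoding can be written with StringJoiner; FormBody is a hypothetical name, and the URLEncoder.encode(String, Charset) overload it uses needs Java 10 or later, so on older Android API levels keep the (String, String) overload the original uses:

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.StringJoiner;

// Compact equivalent of getPostDataString: URL-encodes every key and value
// and joins the key=value pairs with '&'.
final class FormBody {
    static String encode(Map<String, String> params) {
        StringJoiner joiner = new StringJoiner("&");
        for (Map.Entry<String, String> e : params.entrySet()) {
            joiner.add(URLEncoder.encode(e.getKey(), StandardCharsets.UTF_8)
                    + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8));
        }
        return joiner.toString();
    }
}
```

For the map used in doInBackground, this produces `username=Alpesh`; spaces become `+` and reserved characters such as `&` are percent-encoded, exactly as with the original loop.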
Where am I going wrong?
Using Spark 1.6 and the ML library, I am saving the results of a trained RandomForestClassificationModel using toDebugString():
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
val stringModel =rfModel.toDebugString
//save stringModel into a file in the driver in format .txt
My idea is to read the .txt file back later and load the trained random forest from it. Is that possible?
That won't work. toDebugString only produces debug output that describes how the model was computed.
If you want to keep the model for later use, you can do what we do (although we work in pure Java): simply serialize the RandomForestModel object. Default Java serialization may run into version incompatibilities, so we use Hessian instead. It has survived version upgrades: we started with Spark 1.6.1 and it still works with Spark 2.0.2.
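The default-serialization variant mentioned above looks roughly like this in Java. It is a sketch under the assumption that the model object is Serializable (Spark ML pipeline stages generally are); SerializationRoundTrip is a hypothetical helper, and the Hessian-based approach the answer actually uses is not shown:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

final class SerializationRoundTrip {
    // Writes obj with default Java serialization; the bytes can then be saved to a file.
    static byte[] toBytes(Serializable obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Reads the object back; the reading side needs serialization-compatible class versions.
    static Object fromBytes(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

For example, round-tripping a Serializable list through toBytes and fromBytes returns an equal but distinct copy. To persist a model you would write the bytes with `Files.write` and read them back with `Files.readAllBytes`; as the answer warns, bytes written under one Spark version may fail to deserialize under another.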
If you're OK with not sticking to ml, just use the mllib implementation: the RandomForestModel you get from mllib has a save function.
At least for Spark 2.1.0 you can do this with the following Java (sorry - no Scala) code. However, it may not be the smartest idea to rely on an undocumented format that may change without notice.
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.net.URL;
import java.util.*;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static java.nio.charset.StandardCharsets.US_ASCII;
/**
* RandomForest.
*/
public abstract class RandomForest {
private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);
protected final List<Node> trees = new ArrayList<>();
/**
* @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
* @throws IOException
*/
public RandomForest(final URL model) throws IOException {
try (final BufferedReader reader = new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
Node node;
while ((node = load(reader)) != null) {
trees.add(node);
}
}
if (trees.isEmpty()) throw new IOException("Failed to read trees from " + model);
if (LOG.isDebugEnabled()) LOG.debug("Found " + trees.size() + " trees.");
}
private static Node load(final BufferedReader reader) throws IOException {
final Pattern ifPattern = Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
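// also accept negative values and exponents such as -0.5 or 1.0E-4, which regression models can produce
final Pattern predictPattern = Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");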
Node root = null;
final List<Node> stack = new ArrayList<>();
String line;
while ((line = reader.readLine()) != null) {
final String trimmed = line.trim();
//System.out.println(trimmed);
if (trimmed.startsWith("RandomForest")) {
// skip the header line of the first tree
reader.readLine();
} else if (trimmed.startsWith("Tree")) {
break;
} else if (trimmed.startsWith("If")) {
// extract feature index
final Matcher m = ifPattern.matcher(trimmed);
if (!m.matches()) throw new IOException("Unexpected line: " + trimmed);
final int featureIndex = Integer.parseInt(m.group(1));
final String operator = m.group(2);
final String operand = m.group(3);
final Predicate<Float> predicate;
if ("<=".equals(operator)) {
predicate = new LessOrEqual(Float.parseFloat(operand));
} else if (">".equals(operator)) {
predicate = new Greater(Float.parseFloat(operand));
} else if ("in".equals(operator)) {
predicate = new In(parseFloatArray(operand));
} else if ("not in".equals(operator)) {
predicate = new NotIn(parseFloatArray(operand));
} else {
predicate = null;
}
final Node node = new Node(featureIndex, predicate);
if (stack.isEmpty()) {
root = node;
} else {
insert(stack, node);
}
stack.add(node);
} else if (trimmed.startsWith("Predict")) {
final Matcher m = predictPattern.matcher(trimmed);
if (!m.matches()) throw new IOException("Unexpected line: " + trimmed);
final Object node = Float.parseFloat(m.group(1));
insert(stack, node);
}
}
return root;
}
private static void insert(final List<Node> stack, final Object node) {
Node parent = stack.get(stack.size() - 1);
while (parent.getLeftChild() != null && parent.getRightChild() != null) {
stack.remove(stack.size() - 1);
parent = stack.get(stack.size() - 1);
}
if (parent.getLeftChild() == null) parent.setLeftChild(node);
else parent.setRightChild(node);
}
private static float[] parseFloatArray(final String set) {
final StringTokenizer st = new StringTokenizer(set, "{,}");
final float[] floats = new float[st.countTokens()];
for (int i=0; st.hasMoreTokens(); i++) {
floats[i] = Float.parseFloat(st.nextToken());
}
return floats;
}
public abstract float predict(final float[] features);
public String toDebugString() {
try {
final StringWriter sw = new StringWriter();
for (int i=0; i<trees.size(); i++) {
sw.write("Tree " + i + ":\n");
print(sw, "", trees.get(i));
}
return sw.toString();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
private static void print(final Writer w, final String indent, final Object object) throws IOException {
if (object instanceof Number) {
w.write(indent + "Predict: " + object + "\n");
} else if (object instanceof Node) {
final Node node = (Node) object;
// left node
w.write(indent + node + "\n");
print(w, indent + " ", node.getLeftChild());
w.write(indent + "Else\n");
print(w, indent + " ", node.getRightChild());
}
}
@Override
public String toString() {
return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
}
/**
* Node.
*/
protected static class Node {
private final int featureIndex;
private final Predicate<Float> predicate;
private Object leftChild;
private Object rightChild;
public Node(final int featureIndex, final Predicate<Float> predicate) {
Objects.requireNonNull(predicate);
this.featureIndex = featureIndex;
this.predicate = predicate;
}
public void setLeftChild(final Object leftChild) {
this.leftChild = leftChild;
}
public void setRightChild(final Object rightChild) {
this.rightChild = rightChild;
}
public Object getLeftChild() {
return leftChild;
}
public Object getRightChild() {
return rightChild;
}
public Object eval(final float[] features) {
Object result = this;
do {
final Node node = (Node)result;
result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
} while (result instanceof Node);
return result;
}
@Override
public String toString() {
return "If (feature " + featureIndex + " " + predicate + ")";
}
}
private static class LessOrEqual implements Predicate<Float> {
private final float value;
public LessOrEqual(final float value) {
this.value = value;
}
@Override
public boolean test(final Float f) {
return f <= value;
}
@Override
public String toString() {
return "<= " + value;
}
}
private static class Greater implements Predicate<Float> {
private final float value;
public Greater(final float value) {
this.value = value;
}
@Override
public boolean test(final Float f) {
return f > value;
}
@Override
public String toString() {
return "> " + value;
}
}
private static class In implements Predicate<Float> {
private final float[] array;
public In(final float[] array) {
this.array = array;
}
@Override
public boolean test(final Float f) {
for (int i=0; i<array.length; i++) {
if (array[i] == f) return true;
}
return false;
}
@Override
public String toString() {
return "in " + Arrays.toString(array);
}
}
private static class NotIn implements Predicate<Float> {
private final float[] array;
public NotIn(final float[] array) {
this.array = array;
}
@Override
public boolean test(final Float f) {
for (int i=0; i<array.length; i++) {
if (array[i] == f) return false;
}
return true;
}
@Override
public String toString() {
return "not in " + Arrays.toString(array);
}
}
}
To use the class for classification:
import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
/**
* RandomForestClassifier.
*/
public class RandomForestClassifier extends RandomForest {
public RandomForestClassifier(final URL model) throws IOException {
super(model);
}
@Override
public float predict(final float[] features) {
final Map<Object, Integer> counts = new HashMap<>();
trees.stream().map(node -> node.eval(features))
.forEach(result -> {
Integer count = counts.get(result);
if (count == null) {
counts.put(result, 1);
} else {
counts.put(result, count + 1);
}
});
return (Float)counts.entrySet()
.stream()
.sorted((o1, o2) -> Integer.compare(o2.getValue(), o1.getValue()))
.map(Map.Entry::getKey)
.findFirst().get();
}
}
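The vote counting in predict() can also be expressed with the Streams API. The following standalone sketch (MajorityVote is a hypothetical name) takes the per-tree predictions and returns the most frequent label. As in the original, ties are broken arbitrarily because the counts live in a HashMap:

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Majority vote over per-tree predictions: count each label, return the most frequent one.
final class MajorityVote {
    static float vote(List<Float> treePredictions) {
        Map<Float, Long> counts = treePredictions.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        return counts.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalArgumentException("no predictions"));
    }
}
```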
For regression:
import java.io.IOException;
import java.net.URL;
/**
* RandomForestRegressor.
*/
public class RandomForestRegressor extends RandomForest {
public RandomForestRegressor(final URL model) throws IOException {
super(model);
}
@Override
public float predict(final float[] features) {
return (float)trees
.stream()
.mapToDouble(node -> ((Number)node.eval(features)).doubleValue())
.average()
.getAsDouble();
}
}