Skip to content
Snippets Groups Projects
Commit 78229137 authored by boqiren's avatar boqiren
Browse files

learner with activity and preference

parent eda27000
No related branches found
No related tags found
No related merge requests found
Showing
with 535 additions and 494 deletions
......@@ -91,6 +91,9 @@ aspect Rules {
public void TriggerRuleAction.applyFor(Item item) {
getRule().trigger(item);
}
public void SetStateFromExpression.applyFor(Item item) {
getAffectedItem().setStateFromDouble(getNumberExpression().eval());
}
public void SetStateFromConstantStringAction.applyFor(Item item) {
getAffectedItem().setStateFromString(getNewState());
}
......
......@@ -13,6 +13,8 @@ rel TriggerRuleAction.Rule -> Rule ;
abstract SetStateAction : Action ;
rel SetStateAction.AffectedItem -> Item ;
SetStateFromExpression : SetStateAction ::= NumberExpression ;
SetStateFromConstantStringAction : SetStateAction ::= <NewState:String> ;
SetStateFromLambdaAction : SetStateAction ::= <NewStateProvider:NewStateProvider> ;
SetStateFromTriggeringItemAction : SetStateAction ::= ;
......
......@@ -30,7 +30,7 @@ import de.tudresden.inf.st.eraser.jastadd.parser.EraserParser.Terminals;
%}
WhiteSpace = [ ] | \t | \f | \n | \r | \r\n
//Identifier = [:jletter:][:jletterdigit:]*
Identifier = [:jletter:][:jletterdigit:]*
Text = \" ([^\"]*) \"
Integer = [:digit:]+ // | "+" [:digit:]+ | "-" [:digit:]+
......@@ -125,7 +125,7 @@ Comment = "//" [^\n\r]+
")" { return sym(Terminals.RB_ROUND); }
"{" { return sym(Terminals.LB_CURLY); }
"}" { return sym(Terminals.RB_CURLY); }
//{Identifier} { return sym(Terminals.NAME); }
{Identifier} { return sym(Terminals.NAME); }
{Text} { return symText(Terminals.TEXT); }
{Integer} { return sym(Terminals.INTEGER); }
{Real} { return sym(Terminals.REAL); }
......
......@@ -90,7 +90,7 @@ NumberLiteralExpression literal_expression =
;
Designator designator =
TEXT.n {: return eph.createDesignator(n); :}
NAME.n {: return eph.createDesignator(n); :}
;
Thing thing =
......
......@@ -48,6 +48,7 @@ public class EraserParserHelper {
private Root root;
private static boolean checkUnusedElements = true;
private static Root initialRoot = null;
private class ItemPrototype extends DefaultItem {
......@@ -61,21 +62,31 @@ public class EraserParserHelper {
EraserParserHelper.checkUnusedElements = checkUnusedElements;
}
public static void setInitialRoot(Root root) {
EraserParserHelper.initialRoot = root;
}
/**
* Post processing step after parsing a model, to resolve all references within the model.
* @throws java.util.NoSuchElementException if a reference can not be resolved
*/
public void resolveReferences() {
if (this.root == null) {
// when parsing expressions
this.root = createRoot();
this.root = EraserParserHelper.initialRoot != null ? EraserParserHelper.initialRoot : createRoot();
}
if (checkUnusedElements) {
fillUnused();
}
resolve(thingTypeMap, missingThingTypeMap, Thing::setType);
resolve(channelTypeMap, missingChannelTypeMap, Channel::setType);
if (itemMap == null || itemMap.isEmpty()) {
missingItemForDesignator.forEach((designator, itemName) ->
JavaUtils.ifPresentOrElse(root.getOpenHAB2Model().resolveItem(itemName),
designator::setItem,
() -> logger.warn("Could not resolve item {} for {}", itemName, designator)));
} else {
resolve(itemMap, missingItemForDesignator, Designator::setItem);
}
missingTopicMap.forEach((topic, parts) -> ParserUtils.createMqttTopic(topic, parts, this.root));
this.root.getMqttRoot().ensureCorrectPrefixes();
......
......@@ -6,6 +6,7 @@ import beaver.Symbol;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.jastadd.parser.EraserParser;
import de.tudresden.inf.st.eraser.jastadd.scanner.EraserScanner;
import de.tudresden.inf.st.eraser.parser.EraserParserHelper;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
......@@ -193,14 +194,23 @@ public class ParserUtils {
}
public static NumberExpression parseNumberExpression(String expression_string) throws IOException, Parser.Exception {
return (NumberExpression) parseExpression(expression_string, EraserParser.AltGoals.number_expression);
return parseNumberExpression(expression_string, null);
}
public static LogicalExpression parseLogicalExpression(String expression_string) throws IOException, Parser.Exception {
return (LogicalExpression) parseExpression(expression_string, EraserParser.AltGoals.logical_expression);
return parseLogicalExpression(expression_string, null);
}
private static Expression parseExpression(String expression_string, short alt_goal) throws IOException, Parser.Exception {
public static NumberExpression parseNumberExpression(String expression_string, Root root) throws IOException, Parser.Exception {
return (NumberExpression) parseExpression(expression_string, EraserParser.AltGoals.number_expression, root);
}
public static LogicalExpression parseLogicalExpression(String expression_string, Root root) throws IOException, Parser.Exception {
return (LogicalExpression) parseExpression(expression_string, EraserParser.AltGoals.logical_expression, root);
}
private static Expression parseExpression(String expression_string, short alt_goal, Root root) throws IOException, Parser.Exception {
EraserParserHelper.setInitialRoot(root);
StringReader reader = new StringReader(expression_string);
if (verboseLoading) {
EraserScanner scanner = new EraserScanner(reader);
......@@ -220,6 +230,7 @@ public class ParserUtils {
Expression result = (Expression) parser.parse(scanner, alt_goal);
parser.resolveReferences();
reader.close();
EraserParserHelper.setInitialRoot(null);
return result;
}
......
......@@ -3,6 +3,7 @@ package de.tudresden.inf.st.eraser;
import beaver.Parser;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.util.ParserUtils;
import org.junit.Assert;
import org.junit.Test;
import java.io.IOException;
......@@ -128,6 +129,16 @@ public class ExpressionParserTest {
assertThat(rightOfSub.getValue(), equalTo(8.0));
}
@Test
public void expressionWithItem() {
try {
ParserUtils.parseNumberExpression("(myItem * 3)");
} catch (IOException | Parser.Exception e) {
e.printStackTrace();
Assert.fail(e.getMessage());
}
}
@Test
public void comparingExpressions() throws IOException, Parser.Exception {
comparingExpression("<", ComparatorType.LessThan, 1, 2);
......
package de.tudresden.inf.st.eraser;
import beaver.Parser;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.util.ParserUtils;
import de.tudresden.inf.st.eraser.util.TestUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.StreamSupport;
/**
......@@ -24,26 +28,28 @@ public class RulesTest {
private static final double DELTA = 0.01d;
class Counters implements Action2EditConsumer {
Map<Item, Integer> counters;
class CountingAction extends NoopAction {
final Map<Item, AtomicInteger> counters = new HashMap<>();
Counters() {
CountingAction() {
reset();
}
private AtomicInteger getAtomic(Item item) {
return counters.computeIfAbsent(item, unused -> new AtomicInteger(0));
}
@Override
public void accept(Item item) {
counters.computeIfPresent(item, (i, value) -> value + 1);
counters.putIfAbsent(item, 1);
public void applyFor(Item item) {
getAtomic(item).addAndGet(1);
}
int get(Item item) {
counters.putIfAbsent(item, 0);
return counters.get(item);
return getAtomic(item).get();
}
void reset() {
counters = new HashMap<>();
counters.clear();
}
}
......@@ -54,8 +60,8 @@ public class RulesTest {
NumberItem item = modelAndItem.item;
Rule rule = new Rule();
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -97,14 +103,14 @@ public class RulesTest {
Root root = modelAndItem.model.getRoot();
NumberItem item = modelAndItem.item;
Counters counter1 = new Counters();
CountingAction counter1 = new CountingAction();
Rule ruleA = new Rule();
ruleA.addAction(new LambdaAction(counter1));
ruleA.addAction(counter1);
root.addRule(ruleA);
Rule ruleB = new Rule();
Counters counter2 = new Counters();
ruleB.addAction(new LambdaAction(counter2));
CountingAction counter2 = new CountingAction();
ruleB.addAction(counter2);
root.addRule(ruleB);
ruleA.activateFor(item);
......@@ -134,8 +140,8 @@ public class RulesTest {
NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem);
Rule rule = new Rule();
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item1);
......@@ -192,8 +198,8 @@ public class RulesTest {
NumberItem item = modelAndItem.item;
Rule rule = new Rule();
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -220,8 +226,8 @@ public class RulesTest {
ItemStateNumberCheck check2 = new ItemStateNumberCheck(ComparatorType.LessThan, 6);
rule.addCondition(new ItemStateCheckCondition(check1));
rule.addCondition(new ItemStateCheckCondition(check2));
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -250,10 +256,10 @@ public class RulesTest {
NumberItem item = modelAndItem.item;
Rule rule = new Rule();
Counters counter1 = new Counters();
rule.addAction(new LambdaAction(counter1));
Counters counter2 = new Counters();
rule.addAction(new LambdaAction(counter2));
CountingAction counter1 = new CountingAction();
rule.addAction(counter1);
CountingAction counter2 = new CountingAction();
rule.addAction(counter2);
root.addRule(rule);
rule.activateFor(item);
......@@ -277,12 +283,12 @@ public class RulesTest {
NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem);
Rule ruleA = new Rule();
Counters counter1 = new Counters();
ruleA.addAction(new LambdaAction(counter1));
CountingAction counter1 = new CountingAction();
ruleA.addAction(counter1);
Rule ruleB = new Rule();
Counters counter2 = new Counters();
ruleB.addAction(new LambdaAction(counter2));
CountingAction counter2 = new CountingAction();
ruleB.addAction(counter2);
ruleA.addAction(new TriggerRuleAction(ruleB));
......@@ -321,8 +327,8 @@ public class RulesTest {
Rule rule = new Rule();
rule.addAction(new SetStateFromConstantStringAction(item2, "5"));
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -353,8 +359,8 @@ public class RulesTest {
Rule rule = new Rule();
rule.addAction(new SetStateFromLambdaAction(item2, provider));
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -394,9 +400,9 @@ public class RulesTest {
StringItem item2 = addStringItem(root.getOpenHAB2Model(), "0");
Rule rule = new Rule();
Counters counter = new Counters();
CountingAction counter = new CountingAction();
rule.addAction(new SetStateFromTriggeringItemAction(item2));
rule.addAction(new LambdaAction(counter));
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -433,8 +439,8 @@ public class RulesTest {
action.addSourceItem(item2);
action.setAffectedItem(affectedItem);
rule.addAction(action);
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -492,8 +498,8 @@ public class RulesTest {
Rule rule = new Rule();
rule.addAction(new AddDoubleToStateAction(affectedItem, 2));
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -529,8 +535,8 @@ public class RulesTest {
Rule rule = new Rule();
rule.addAction(new MultiplyDoubleToStateAction(affectedItem, 2));
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item);
......@@ -567,13 +573,13 @@ public class RulesTest {
Rule ruleA = new Rule();
ruleA.addAction(new AddDoubleToStateAction(affectedItem, 2));
Counters counterA = new Counters();
ruleA.addAction(new LambdaAction(counterA));
CountingAction counterA = new CountingAction();
ruleA.addAction(counterA);
Rule ruleB = new Rule();
ruleB.addAction(new MultiplyDoubleToStateAction(affectedItem, 3));
Counters counterB = new Counters();
ruleB.addAction(new LambdaAction(counterB));
CountingAction counterB = new CountingAction();
ruleB.addAction(counterB);
ruleA.addAction(new TriggerRuleAction(ruleB));
......@@ -612,11 +618,62 @@ public class RulesTest {
63, affectedItem.getState(), DELTA);
}
@Test
public void testSetFromExpression() throws IOException, Parser.Exception {
TestUtils.ModelAndItem modelAndItem = createModelAndItem(3);
Root root = modelAndItem.model.getRoot();
NumberItem item1 = modelAndItem.item;
NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem);
NumberItem affectedItem = TestUtils.addItemTo(root.getOpenHAB2Model(), 5, useUpdatingItem);
Rule rule = new Rule();
SetStateFromExpression action = new SetStateFromExpression();
// TODO item1 should be referred to as triggering item
action.setNumberExpression(ParserUtils.parseNumberExpression("(" + item1.getID() + " + " + item2.getID() + ")", root));
action.setAffectedItem(affectedItem);
rule.addAction(action);
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item1);
Assert.assertEquals(m("Counter not initialized correctly"), 0, counter.get(item1));
Assert.assertEquals(m("Second item not initialized correctly"),
4, item2.getState(), DELTA);
Assert.assertEquals(m("Affected item not initialized correctly"),
5, affectedItem.getState(), DELTA);
// 5 + 4 = 9
setState(item1, 5);
Assert.assertEquals(m("Change of item state should trigger the rule"), 1, counter.get(item1));
Assert.assertEquals(m("Change of item state should set the state of the affected item"),
9, affectedItem.getState(), DELTA);
// still 9
setState(item1, 5);
Assert.assertEquals(m("Change of item to same state should not trigger the rule"), 1, counter.get(item1));
Assert.assertEquals(m("Change of item to same state should not set the state of the affected item"),
9, affectedItem.getState(), DELTA);
// still 9 (changes of item2 do not trigger the rule)
setState(item2, 1);
Assert.assertEquals(m("Change of second item to same state should not trigger the rule"), 1, counter.get(item1));
Assert.assertEquals(m("Change of second item to same state should not set the state of the affected item"),
9, affectedItem.getState(), DELTA);
// 0 + 1 = 1
setState(item1, 0);
Assert.assertEquals(m("Change of item state should trigger the rule"), 2, counter.get(item1));
Assert.assertEquals(m("Change of item state should set the state of the affected item"),
1, affectedItem.getState(), DELTA);
}
@Test
public void testCronJobRule() {
Rule rule = new Rule();
Counters counter = new Counters();
rule.addAction(new LambdaAction(counter));
CountingAction counter = new CountingAction();
rule.addAction(counter);
Assert.assertEquals(m("Counter not initialized correctly"), 0, counter.get(null));
......
<?xml version="1.0" encoding="UTF-8"?>
<module type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
package de.tudresden.inf.st.eraser.starter;
import beaver.Parser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import de.tudresden.inf.st.eraser.feedbackloop.analyze.AnalyzeImpl;
import de.tudresden.inf.st.eraser.feedbackloop.api.Analyze;
import de.tudresden.inf.st.eraser.feedbackloop.api.Execute;
import de.tudresden.inf.st.eraser.feedbackloop.api.Learner;
import de.tudresden.inf.st.eraser.feedbackloop.api.Plan;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.feedbackloop.execute.ExecuteImpl;
import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl;
import de.tudresden.inf.st.eraser.feedbackloop.plan.PlanImpl;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.openhab2.OpenHab2Importer;
import de.tudresden.inf.st.eraser.spark.Application;
import de.tudresden.inf.st.eraser.util.JavaUtils;
import de.tudresden.inf.st.eraser.util.ParserUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
* This Starter combines and starts all modules. This includes:
*
* <ul>
* <li>Knowledge-Base in <code>eraser-base</code></li>
* <li>Feedback loop in <code>feedbackloop.{analyze,plan,execute}</code></li>
* <li>REST-API in <code>eraser-rest</code></li>
* </ul>
*/
public class EraserStarter {
private static final Logger logger = LogManager.getLogger(EraserStarter.class);
@SuppressWarnings("ResultOfMethodCallIgnored")
public static void main(String[] args) {
logger.info("Starting ERASER");
ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
File settingsFile = new File("starter-setting.yaml");
Setting settings;
try {
settings = mapper.readValue(settingsFile, Setting.class);
} catch (Exception e) {
logger.fatal("Could not read settings at '{}'. Exiting.", settingsFile.getAbsolutePath());
logger.catching(e);
System.exit(1);
return;
}
boolean startRest = settings.rest.use;
Root model;
switch (settings.initModelWith) {
case openhab:
OpenHab2Importer importer = new OpenHab2Importer();
try {
model = importer.importFrom(new URL(settings.openhab.url));
} catch (MalformedURLException e) {
logger.error("Could not parse URL {}", settings.openhab.url);
logger.catching(e);
System.exit(1);
return;
}
logger.info("Imported model {}", model.description());
break;
case load:
default:
try {
model = ParserUtils.load(settings.load.file, EraserStarter.class);
} catch (IOException | Parser.Exception e) {
logger.error("Problems parsing the given file {}", settings.load.file);
logger.catching(e);
System.exit(1);
return;
}
}
// initialize activity recognition
if (settings.activity.dummy) {
logger.info("Using dummy activity recognition");
model.getMachineLearningRoot().setActivityRecognition(DummyMachineLearningModel.createDefault());
} else {
logger.error("Reading activity recognition from file is not supported yet!");
// TODO
}
// initialize preference learning
if (settings.preference.dummy) {
logger.info("Using dummy preference learning");
model.getMachineLearningRoot().setPreferenceLearning(DummyMachineLearningModel.createDefault());
} else {
logger.info("Reading preference learning from file {}", settings.preference.file);
Learner learner = new LearnerImpl();
// there should be a method to load a model using an URL
Model preference = learner.getTrainedModel(settings.preference.realURL(), settings.preference.id);
NeuralNetworkRoot neuralNetwork = LearnerHelper.transform(preference);
if (neuralNetwork == null) {
logger.error("Could not create preference model, see possible previous errors.");
} else {
model.getMachineLearningRoot().setPreferenceLearning(neuralNetwork);
neuralNetwork.connectItems(settings.preference.items);
neuralNetwork.setOutputApplication(zeroToThree -> 25 * zeroToThree);
JavaUtils.ifPresentOrElse(
model.resolveItem(settings.preference.affectedItem),
item -> neuralNetwork.getOutputLayer().setAffectedItem(item),
() -> logger.error("Output item not set from value '{}'", settings.preference.affectedItem));
}
}
if (!model.getMachineLearningRoot().getActivityRecognition().check()) {
logger.fatal("Invalid activity recognition!");
System.exit(1);
}
if (!model.getMachineLearningRoot().getPreferenceLearning().check()) {
logger.fatal("Invalid preference learning!");
System.exit(1);
}
Analyze analyze = null;
if (settings.useMAPE) {
// configure and start mape loop
logger.info("Starting MAPE loop");
analyze = new AnalyzeImpl();
Plan plan = new PlanImpl();
Execute execute = new ExecuteImpl();
analyze.setPlan(plan);
plan.setExecute(execute);
analyze.setKnowledgeBase(model);
plan.setKnowledgeBase(model);
execute.setKnowledgeBase(model);
analyze.startAsThread(1, TimeUnit.SECONDS);
if (!startRest) {
// alternative exit condition
System.out.println("Hit [Enter] to exit");
try {
System.in.read();
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("Stopping...");
analyze.stop();
}
} else {
logger.info("No MAPE loop this time");
}
if (startRest) {
// start REST-API in new thread
logger.info("Starting REST server");
Lock lock = new ReentrantLock();
Condition quitCondition = lock.newCondition();
Thread t = new Thread(new ThreadGroup("REST-API"),
() -> Application.start(settings.rest.port, model, settings.rest.createDummyMLData, lock, quitCondition));
t.setDaemon(true);
t.start();
logger.info("Waiting until request is send to '/system/exit'");
try {
lock.lock();
if (t.isAlive()) {
quitCondition.await();
}
} catch (InterruptedException e) {
logger.warn("Waiting was interrupted");
} finally {
lock.unlock();
}
if (analyze != null) {
analyze.stop();
}
} else {
logger.info("No REST server this time");
}
logger.info("I'm done here.");
}
}
package de.tudresden.inf.st.eraser.starter;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.Arrays;
/**
* Transformation of a {@link Model} into a {@link MachineLearningModel}.
*
* @author rschoene - Initial contribution
*/
class LearnerHelper {
private static final Logger logger = LogManager.getLogger(LearnerHelper.class);
// Activation Functions
private static DoubleArrayDoubleFunction sigmoid = inputs -> Math.signum(Arrays.stream(inputs).sum());
private static DoubleArrayDoubleFunction tanh = inputs -> Math.tanh(Arrays.stream(inputs).sum());
private static DoubleArrayDoubleFunction function_one = inputs -> 1.0;
static NeuralNetworkRoot transform(Model model) {
NeuralNetworkRoot result = NeuralNetworkRoot.createEmpty();
ArrayList<Double> weights = model.getWeights();
// inputs
int inputSum = model.getInputLayerNumber() + model.getInputBias();
for (int i = 0; i < inputSum; ++i) {
InputNeuron inputNeuron = new InputNeuron();
result.addInputNeuron(inputNeuron);
}
InputNeuron bias = result.getInputNeuron(model.getInputBias());
OutputLayer outputLayer = new OutputLayer();
// output layer
for (int i = 0; i < model.getOutputLayerNumber(); ++i) {
OutputNeuron outputNeuron = new OutputNeuron();
setActivationFunction(outputNeuron, model.getOutputActivationFunction());
outputLayer.addOutputNeuron(outputNeuron);
}
result.setOutputLayer(outputLayer);
// hidden layer
int hiddenSum = model.gethiddenLayerNumber() + model.getHiddenBias();
HiddenNeuron[] hiddenNeurons = new HiddenNeuron[hiddenSum];
for (int i = 0; i < (hiddenNeurons.length); i++) {
if (i == model.gethiddenLayerNumber()) {
HiddenNeuron hiddenNeuron = new HiddenNeuron();
hiddenNeuron.setActivationFormula(function_one);
hiddenNeurons[i] = hiddenNeuron;
result.addHiddenNeuron(hiddenNeuron);
bias.connectTo(hiddenNeuron, 1.0);
for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) {
hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out));
}
} else {
HiddenNeuron hiddenNeuron = new HiddenNeuron();
setActivationFunction(hiddenNeuron, model.getHiddenActivationFunction());
hiddenNeurons[i] = hiddenNeuron;
result.addHiddenNeuron(hiddenNeuron);
for (int in = 0; in < inputSum; in++) {
// TODO replace 4 and 5 with model-attributes
result.getInputNeuron(in).connectTo(hiddenNeuron, weights.get((hiddenNeurons.length * 4 + in) + i * 5));
}
for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) {
hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out));
}
}
}
outputLayer.setCombinator(LearnerHelper::predictor);
logger.info("Created model with {} input, {} hidden and {} output neurons",
result.getNumInputNeuron(), result.getNumHiddenNeuron(), result.getOutputLayer().getNumOutputNeuron());
return result;
}
private static void setActivationFunction(HiddenNeuron neuron, String functionName) {
switch (functionName) {
case "ActivationTANH": neuron.setActivationFormula(tanh); break;
case "ActivationLinear": neuron.setActivationFormula(function_one);
case "ActivationSigmoid": neuron.setActivationFormula(sigmoid); break;
default: throw new IllegalArgumentException("Unknown function " + functionName);
}
}
private static double predictor(double[] inputs) {
int index = 0;
double maxInput = StatUtils.max(inputs);
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] == maxInput) {
index = i;
}
}
//outputs from learner
final double[] outputs = new double[]{2.0, 1.0, 3.0, 0.0};
return outputs[index];
}
}
package de.tudresden.inf.st.eraser.starter;
import java.net.URL;
import java.util.List;
/**
* Setting bean.
*
* @author rschoene - Initial contribution
*/
@SuppressWarnings("WeakerAccess")
class Setting {
public class Rest {
public boolean use = true;
public int port = 4567;
public boolean createDummyMLData = false;
}
public class FileContainer {
public String file;
URL realURL() {
return Setting.class.getClassLoader().getResource(file);
}
}
public class MLContainer extends FileContainer {
public boolean dummy = false;
public int id = 1;
public List<String> items;
public String affectedItem;
}
public class OpenHabContainer {
public String url;
}
public enum InitModelWith {
load, openhab
}
public Rest rest;
public boolean useMAPE = true;
public FileContainer load;
public MLContainer activity;
public MLContainer preference;
public OpenHabContainer openhab;
public InitModelWith initModelWith = InitModelWith.load;
}
......@@ -15,6 +15,8 @@ dependencies {
testCompile group: 'junit', name: 'junit', version: '4.12'
testCompile group: 'org.hamcrest', name: 'hamcrest-junit', version: '2.0.0.0'
compile group: 'org.encog', name: 'encog-core', version: '3.4'
implementation group: 'com.opencsv', name: 'opencsv', version: '4.1'
implementation group: 'commons-io', name: 'commons-io', version: '2.5'
}
run {
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import java.io.*;
import java.util.*;
import com.opencsv.CSVReader;
import com.opencsv.CSVWriter;
public class CsvTransfer {
private static final String CSV_FILE_PATH
= "datasets/backup/activity_data_example.csv";
private static final String OUTPUT_FILE_PATH
= "datasets/backup/activity_data.csv";
public static void main(String[] args)
{
addDataToCSV(CSV_FILE_PATH, OUTPUT_FILE_PATH);
}
public static void addDataToCSV(String input, String output)
{
File input_file = new File(input);
File output_file = new File(output);
try {
// create FileWriter object with file as parameter
FileReader reader = new FileReader(input_file);
CSVReader csv_reader = new CSVReader(reader);
String[] nextRecord;
FileWriter writer = new FileWriter(output_file);
CSVWriter csv_writer = new CSVWriter(writer, ',',
CSVWriter.NO_QUOTE_CHARACTER,
CSVWriter.DEFAULT_ESCAPE_CHARACTER,
CSVWriter.DEFAULT_LINE_END);
List<String[]> data = new ArrayList<String[]>();
while ((nextRecord = csv_reader.readNext()) != null) {
data.add(nextRecord);
for (String cell : nextRecord) {
System.out.print(cell + "\t");
}
System.out.println();
}
csv_writer.writeAll(data);
writer.close();
csv_reader.close();
}
catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import com.opencsv.CSVWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class DummyPreference {
private static String activity; //Activity: walking, reading, working, dancing, lying, getting up
private static String watch_brightness; //dark: <45; dimmer 45-70; bright >70;
private static String light_color_openhab_H; //red 7; green 120; blue 240; yellow 60; sky blue 180; purple 300;
private static String brightness_output; //1-100**/
private static Random random = new Random();
public static void main(String[] args) {
creator();
}
static void creator(){
try{
FileWriter writer = new FileWriter("datasets/backup/preference_data.csv",true);
CSVWriter csv_writer = new CSVWriter(writer, ',',
CSVWriter.NO_QUOTE_CHARACTER,
CSVWriter.DEFAULT_ESCAPE_CHARACTER,
CSVWriter.DEFAULT_LINE_END);
//activity="walking" green
activity ="walking";
// activity ="reading";
csv_writer.writeAll(generator("walking","green"));
csv_writer.writeAll(generator("reading","sky blue"));
csv_writer.writeAll(generator("working","blue"));
csv_writer.writeAll(generator("dancing","purple"));
csv_writer.writeAll(generator("lying","red"));
csv_writer.writeAll(generator("getting up","yellow"));
csv_writer.close();
writer.close();
}catch (IOException e){e.printStackTrace();}
}
static List<String[]> generator(String activity_input, String color){
List<String[]> data = new ArrayList<String[]>();
activity = activity_input;
light_color_openhab_H =color;
//100 walking with different lighting intensity
for (int i=0; i<100; i++){
String[] add_data = new String[4];
int brightness = random.nextInt(3000);
System.out.println(brightness);
if (brightness<45){
watch_brightness = "dark";
brightness_output ="100";
}else if(45<=brightness && brightness<200){
watch_brightness = "dimmer";
brightness_output ="40";
}else if( 200<=brightness && brightness<1000){
watch_brightness = "medium";
brightness_output ="70";
}else{
watch_brightness = "bright";
brightness_output ="0";
}
add_data[0] = activity;
add_data[1] = watch_brightness;
add_data[2] = light_color_openhab_H;
add_data[3] = brightness_output;
data.add(add_data);
}
return data;
}
}
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;
import com.sun.javafx.tools.packager.Log;
import org.encog.ConsoleStatusReportable;
import org.encog.Encog;
import org.encog.bot.BotUtil;
import org.encog.ml.MLInput;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper;
......@@ -22,8 +16,6 @@ import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
import org.encog.util.simple.EncogUtility;
import static org.encog.persist.EncogDirectoryPersistence.*;
public class Learner {
......@@ -32,128 +24,173 @@ public class Learner {
* intial train
* */
private String csv_url_activity;
private String csv_url_perference;
private String csv_url_preference;
private String save_activity_model_file = "datasets/backup/activity_model.eg";
private String save_perference_model_file = "datasets/backup/preference_model.eg";
private String save_preference_model_file = "datasets/backup/preference_model.eg";
private File csv_file;
private VersatileDataSource souce;
private VersatileMLDataSet data;
private EncogModel model;
private VersatileDataSource a_souce;
private VersatileMLDataSet a_data;
private VersatileDataSource p_souce;
private VersatileMLDataSet p_data;
private EncogModel a_model;
private EncogModel p_model;
private NormalizationHelper activity_helper;
private NormalizationHelper preference_helper;
private MLRegression best_method;
private MLRegression a_best_method;
private MLRegression p_best_method;
private String[] new_data;
private String[] preference_result;
private String activity_result;
private String preference_result;
private void activityDataAnalyser(String activity_csv_url){
this.csv_url_activity = activity_csv_url;
this.csv_file = new File(csv_url_activity);
souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT);
data = new VersatileMLDataSet(souce);
data.defineSourceColumn("monat", 0, ColumnType.continuous);
data.defineSourceColumn("day", 1, ColumnType.continuous);
data.defineSourceColumn("hour", 2, ColumnType.continuous);
data.defineSourceColumn("minute", 3, ColumnType.continuous);
ColumnDefinition outputColumn = data.defineSourceColumn("labels", 4, ColumnType.continuous);
data.defineSingleOutputOthersInput(outputColumn);
data.analyze();
System.out.println("get data ");
model = new EncogModel(data);
model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
data.normalize();
activity_helper = data.getNormHelper();
System.out.println(activity_helper.toString());
a_souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT);
a_data = new VersatileMLDataSet(a_souce);
a_data.defineSourceColumn("m_accel_x", 0, ColumnType.continuous);
a_data.defineSourceColumn("m_accel_y", 1, ColumnType.continuous);
a_data.defineSourceColumn("m_accel_z", 2, ColumnType.continuous);
a_data.defineSourceColumn("m_rotation_x", 3, ColumnType.continuous);
a_data.defineSourceColumn("m_rotation_y", 4, ColumnType.continuous);
a_data.defineSourceColumn("m_rotation_z", 5, ColumnType.continuous);
a_data.defineSourceColumn("w_accel_x", 6, ColumnType.continuous);
a_data.defineSourceColumn("w_accel_y", 7, ColumnType.continuous);
a_data.defineSourceColumn("w_accel_z", 8, ColumnType.continuous);
a_data.defineSourceColumn("w_rotation_x", 9, ColumnType.continuous);
a_data.defineSourceColumn("w_rotation_y", 10, ColumnType.continuous);
a_data.defineSourceColumn("w_rotation_z", 11, ColumnType.continuous);
ColumnDefinition outputColumn = a_data.defineSourceColumn("labels", 12, ColumnType.nominal);
a_data.defineSingleOutputOthersInput(outputColumn);
a_data.analyze();
a_model = new EncogModel(a_data);
a_model.selectMethod(a_data, MLMethodFactory.TYPE_FEEDFORWARD);
a_data.normalize();
activity_helper = a_data.getNormHelper();
//System.out.println(activity_helper.toString());
}
private void perferenceDataAnalyser(String perference_csv_url){
this.csv_url_perference = perference_csv_url;
this.csv_file = new File(this.csv_url_perference);
souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT);
data = new VersatileMLDataSet(souce);
data.defineSourceColumn("activity", 0, ColumnType.continuous);
data.defineSourceColumn("brightness", 1, ColumnType.continuous);
data.defineSourceColumn("time", 2, ColumnType.continuous);
data.defineSourceColumn("minute", 3, ColumnType.continuous);
ColumnDefinition outputColumn = data.defineSourceColumn("labels", 4, ColumnType.continuous);
data.defineSingleOutputOthersInput(outputColumn);
data.analyze();
model = new EncogModel(data);
model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
//model.setReport(new ConsoleStatusReportable());
data.normalize();
preference_helper = data.getNormHelper();
System.out.println(activity_helper.toString());
private void preferenceDataAnalyser(String preference_csv_url){
this.csv_url_preference = preference_csv_url;
this.csv_file = new File(this.csv_url_preference);
p_souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT);
p_data = new VersatileMLDataSet(p_souce);
p_data.defineSourceColumn("activity", 0, ColumnType.nominal);
p_data.defineSourceColumn("w_brightness", 1, ColumnType.nominal);
ColumnDefinition outputColumn1 = p_data.defineSourceColumn("label1", 2, ColumnType.continuous);
ColumnDefinition outputColumn2 = p_data.defineSourceColumn("label2", 3, ColumnType.continuous);
ColumnDefinition[] outputs = new ColumnDefinition[2];
outputs[0] = outputColumn1;
outputs[1] = outputColumn2;
p_data.defineMultipleOutputsOthersInput(outputs);
p_data.analyze();
p_model = new EncogModel(p_data);
p_model.selectMethod(p_data, MLMethodFactory.TYPE_FEEDFORWARD);
p_model.setReport(new ConsoleStatusReportable());
p_data.normalize();
preference_helper = p_data.getNormHelper();
System.out.println(preference_helper.toString());
}
void train(String activity_url,String perference_url){
void train(String activity_url,String preference_url){
activity_train(activity_url);
Log.info("activity training finished");
preference_train(perference_url);
preference_train(preference_url);
Log.info("preference training finished");
Encog.getInstance().shutdown();
}
private void activity_train(String activity_csv_url){
public void activity_train(String activity_csv_url){
activityDataAnalyser(activity_csv_url);
model.holdBackValidation(0.3, true, 1001);
model.selectTrainingType(data);
best_method = (MLRegression)model.crossvalidate(5, true);
System.out.println(best_method);
a_model.holdBackValidation(0.3, true, 1001);
a_model.selectTrainingType(a_data);
a_best_method = (MLRegression)a_model.crossvalidate(5, true);
System.out.println(a_best_method);
saveEncogModel(save_activity_model_file);
Encog.getInstance().shutdown();
}
private void preference_train(String perfence_csv_url){
perferenceDataAnalyser(perfence_csv_url);
model.holdBackValidation(0.3, true, 1001);
model.selectTrainingType(data);
best_method = (MLRegression)model.crossvalidate(5, true);
System.out.println(best_method);
saveEncogModel(save_perference_model_file);
public void preference_train(String prefence_csv_url){
preferenceDataAnalyser(prefence_csv_url);
p_model.holdBackValidation(0.3, true, 1001);
p_model.selectTrainingType(p_data);
p_best_method = (MLRegression)p_model.crossvalidate(5, true);
System.out.println(p_best_method);
saveEncogModel(save_preference_model_file);
Encog.getInstance().shutdown();
}
String[] predictor(String[] new_data){
this.new_data = new_data;
activityDataAnalyser("datasets/backup/activity_data.csv");
perferenceDataAnalyser("datasets/backup/preference_data.csv");
String[] result = new String[2];
result[0] = activity_predictor();
result[1] = perference_predictor();
String[] preference_data = new String[2];
String[] result = new String[3];
String[] activity_data= new String[12];
activity_data[0]=new_data[0];
activity_data[1]=new_data[1];
activity_data[2]=new_data[2];
activity_data[3]=new_data[3];
activity_data[4]=new_data[4];
activity_data[5]=new_data[5];
activity_data[6]=new_data[6];
activity_data[7]=new_data[7];
activity_data[8]=new_data[8];
activity_data[9]=new_data[9];
activity_data[10]=new_data[10];
activity_data[11]=new_data[11];
result[0] = activity_predictor(activity_data);
preference_data[0]=result[0];
preference_data[1]=new_data[12];
result[1] = preference_predictor(preference_data)[0];
result[2] = preference_predictor(preference_data)[1];
Encog.getInstance().shutdown();
return result;
}
private String activity_predictor(){
public String activity_predictor(String[] new_data){
activityDataAnalyser("datasets/backup/activity_data.csv");
BasicNetwork activity_method = (BasicNetwork) loadObject(new File(save_activity_model_file));
MLData input = activity_helper.allocateInputVector();
String[] activity_new_data = new String[4];
String[] activity_new_data = new String[12];
activity_new_data[0] = new_data[0];
activity_new_data[1] = new_data[1];
activity_new_data[2] = new_data[2];
activity_new_data[3] = new_data[3];
activity_new_data[4] = new_data[4];
activity_new_data[5] = new_data[5];
activity_new_data[6] = new_data[6];
activity_new_data[7] = new_data[7];
activity_new_data[8] = new_data[8];
activity_new_data[9] = new_data[9];
activity_new_data[10] = new_data[10];
activity_new_data[11] = new_data[11];
activity_helper.normalizeInputVector(activity_new_data,input.getData(),false);
MLData output = activity_method.compute(input);
System.out.println("input:"+input);
System.out.println("output"+output);
activity_result = activity_helper.denormalizeOutputVectorToString(output)[0];
System.out.println("output activity"+ activity_result);
return activity_result;
}
private String perference_predictor(){
BasicNetwork preference_method = (BasicNetwork)loadObject(new File(save_perference_model_file));
public String[] preference_predictor(String[] new_data){
preference_result = new String[2];
preferenceDataAnalyser("datasets/backup/preference_data.csv");
BasicNetwork preference_method = (BasicNetwork)loadObject(new File(save_preference_model_file));
MLData input = preference_helper.allocateInputVector();
String[] perference_new_data = new String[4];
perference_new_data[0] = activity_result;
perference_new_data[1] = new_data[4];
perference_new_data[2] = new_data[5];
perference_new_data[3] = new_data[6];
preference_helper.normalizeInputVector(perference_new_data, input.getData(),false);
System.out.print("input: "+input);
String[] preference_new_data = new String[2];
//preference_new_data[0] = activity_result;
preference_new_data[0] = new_data[0];
preference_new_data[1] = new_data[1];
preference_helper.normalizeInputVector(preference_new_data, input.getData(),false);
System.out.println(preference_helper);
MLData output = preference_method.compute(input);
preference_result = preference_helper.denormalizeOutputVectorToString(output)[0];
preference_result[0] = preference_helper.denormalizeOutputVectorToString(output)[0];
preference_result[1] = preference_helper.denormalizeOutputVectorToString(output)[1];
return preference_result;
}
private void saveEncogModel(String model_file_url){
saveObject(new File(model_file_url), this.best_method);
if (model_file_url.equals(save_activity_model_file)){saveObject(new File(model_file_url), this.a_best_method);}
else {
saveObject(new File(model_file_url), this.p_best_method);
}
}
}
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.encog.util.csv.ReadCSV;
import org.encog.util.csv.CSVFormat;
import java.util.Arrays;
public class Main {
......@@ -8,22 +11,140 @@ public class Main {
/**
* new data from KB
* */
String[] new_data = new String[7];
new_data[0]="7";
new_data[1]="20";
new_data[2]="12";
new_data[3]="13";
new_data[4]="7";
new_data[5]="25";
new_data[6]="12";
String[] new_data = new String[13];
new_data[0]="0.10654198";
new_data[1]="8.6574335";
new_data[2]="4.414908";
new_data[3]="0.040269";
new_data[4]="0.516884";
new_data[5]="0.853285";
new_data[6]="1.2066777";
new_data[7]="-1.1444284";
new_data[8]="9.648633";
new_data[9]="1.2207031E-4";
new_data[10]="-0.055358887";
new_data[11]="0.5834961";
new_data[12]="bright";
Learner learner = new Learner();
String[] result =learner.predictor(new_data);
System.out.println(result[0]);
System.out.println(result[1]);
System.out.println(result[2]);
//learner.preference_train("datasets/backup/preference_data.csv");
//learner.train("datasets/backup/activity_data.csv","datasets/backup/preference_data.csv");
//walking,medium,120,70
//reading,bright,180,0
//activity_validation_learner();
//0.10654198,8.6574335,4.414908,0.040269,0.516884,0.853285,1.2066777,-1.1444284,9.648633,1.2207031E-4,-0.055358887,0.5834961 working
/**String[] new_data = new String[12];
new_data[0]="0.10654198";
new_data[1]="8.6574335";
new_data[2]="4.414908";
new_data[3]="0.040269";
new_data[4]="0.516884";
new_data[5]="0.853285";
new_data[6]="1.2066777";
new_data[7]="-1.1444284";
new_data[8]="9.648633";
new_data[9]="1.2207031E-4";
new_data[10]="-0.055358887";
new_data[11]="0.5834961";
String[] new_data_1 =new String[12];
//-2.6252422,8.619126,-2.7030537,0.552147,0.5078,0.450302,-8.1881695,-1.2641385,0.038307227,-0.34222412,0.49102783,-0.016540527,walking
new_data_1[0]="-2.6252422";
new_data_1[1]="8.619126";
new_data_1[2]="-2.7030537";
new_data_1[3]="0.552147";
new_data_1[4]="0.5078";
new_data_1[5]="0.450302";
new_data_1[6]="-8.1881695";
new_data_1[7]="-1.2641385";
new_data_1[8]="0.038307227";
new_data_1[9]="-0.34222412";
new_data_1[10]="0.49102783";
new_data_1[11]="-0.016540527";
/**
* learner.train(activity_csv_url, preference_data_url)
* learner.predictor get the result from predictor for new data
* */
/**String[] new_data_2 = new String[12];
//-6.5565214,5.717354,5.6658783,0.185591,0.464146,0.413321,-20.580557,3.8498764,-0.4261679,0.7647095,-0.4713745,0.23999023,dancing
new_data_2[0]="-6.5565214";
new_data_2[1]="5.717354";
new_data_2[2]="5.6658783";
new_data_2[3]="0.185591";
new_data_2[4]="0.464146";
new_data_2[5]="0.413321";
new_data_2[6]="-20.580557";
new_data_2[7]="3.8498764";
new_data_2[8]="-0.4261679";
new_data_2[9]="0.7647095";
new_data_2[10]="-0.4713745";
new_data_2[11]="0.23999023";
String[] new_data_3 = new String[12];
new_data_3[0]="-5.3881507";
new_data_3[1]="0.25378537";
new_data_3[2]="7.69257";
new_data_3[3]="-0.122974";
new_data_3[4]="0.247411";
new_data_3[5]="0.439031";
new_data_3[6]="4.9224787";
new_data_3[7]="-10.601525";
new_data_3[8]="-4.927267";
new_data_3[9]="0.7946167";
new_data_3[10]="0.35272217";
new_data_3[11]="0.16192627";
//"-5.3881507","0.25378537","7.69257","-0.122974","0.247411","0.439031","4.9224787","-10.601525","-4.927267","0.7946167","0.35272217","0.16192627","lying"
Learner learner=new Learner();
//learner.train("datasets/activity_data.csv", "datasets/preference_data.csv");
String[] result = learner.predictor(new_data);
//learner.train("datasets/backup/activity_data.csv", "datasets/preference_data.csv");
String[] result = learner.predictor(new_data_3);
System.out.println("activity is:" + result[0]);
System.out.println("perference is: "+ result[1]);
//System.out.println("perference is: "+ result[1]);**/
}
public static void activity_validation_learner(){
ReadCSV csv = new ReadCSV("datasets/backup/activity_data.csv", false, CSVFormat.DECIMAL_POINT);
String[] line = new String[12];
Learner learner=new Learner();
int wrong=0;
int right=0;
while(csv.next()) {
StringBuilder result = new StringBuilder();
line[0] = csv.get(0);
line[1] = csv.get(1);
line[2] = csv.get(2);
line[3] = csv.get(3);
line[4] = csv.get(4);
line[5] = csv.get(5);
line[6] = csv.get(6);
line[7] = csv.get(7);
line[8] = csv.get(8);
line[9] = csv.get(9);
line[10] = csv.get(10);
line[11] = csv.get(11);
String correct = csv.get(12);
String irisChosen = learner.predictor(line)[0];
result.append(Arrays.toString(line));
result.append(" -> predicted: ");
result.append(irisChosen);
result.append("(correct: ");
result.append(correct);
result.append(")");
if (irisChosen.equals(correct)!=true){
System.out.println(correct);
System.out.println(irisChosen);
++wrong;
}else{
++right;
}
System.out.println(result.toString());
}
System.out.println("wrong number"+wrong);
System.out.println("right number"+right);
//double validation = (double(right))/(double(wrong+right));
//System.out.println("%.2f"+validation);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment