Skip to content
Snippets Groups Projects
Commit 78229137 authored by boqiren's avatar boqiren
Browse files

learner with activity and preference

parent eda27000
Branches
No related tags found
No related merge requests found
Showing
with 535 additions and 494 deletions
...@@ -91,6 +91,9 @@ aspect Rules { ...@@ -91,6 +91,9 @@ aspect Rules {
public void TriggerRuleAction.applyFor(Item item) { public void TriggerRuleAction.applyFor(Item item) {
getRule().trigger(item); getRule().trigger(item);
} }
public void SetStateFromExpression.applyFor(Item item) {
getAffectedItem().setStateFromDouble(getNumberExpression().eval());
}
public void SetStateFromConstantStringAction.applyFor(Item item) { public void SetStateFromConstantStringAction.applyFor(Item item) {
getAffectedItem().setStateFromString(getNewState()); getAffectedItem().setStateFromString(getNewState());
} }
......
...@@ -13,6 +13,8 @@ rel TriggerRuleAction.Rule -> Rule ; ...@@ -13,6 +13,8 @@ rel TriggerRuleAction.Rule -> Rule ;
abstract SetStateAction : Action ; abstract SetStateAction : Action ;
rel SetStateAction.AffectedItem -> Item ; rel SetStateAction.AffectedItem -> Item ;
SetStateFromExpression : SetStateAction ::= NumberExpression ;
SetStateFromConstantStringAction : SetStateAction ::= <NewState:String> ; SetStateFromConstantStringAction : SetStateAction ::= <NewState:String> ;
SetStateFromLambdaAction : SetStateAction ::= <NewStateProvider:NewStateProvider> ; SetStateFromLambdaAction : SetStateAction ::= <NewStateProvider:NewStateProvider> ;
SetStateFromTriggeringItemAction : SetStateAction ::= ; SetStateFromTriggeringItemAction : SetStateAction ::= ;
......
...@@ -30,7 +30,7 @@ import de.tudresden.inf.st.eraser.jastadd.parser.EraserParser.Terminals; ...@@ -30,7 +30,7 @@ import de.tudresden.inf.st.eraser.jastadd.parser.EraserParser.Terminals;
%} %}
WhiteSpace = [ ] | \t | \f | \n | \r | \r\n WhiteSpace = [ ] | \t | \f | \n | \r | \r\n
//Identifier = [:jletter:][:jletterdigit:]* Identifier = [:jletter:][:jletterdigit:]*
Text = \" ([^\"]*) \" Text = \" ([^\"]*) \"
Integer = [:digit:]+ // | "+" [:digit:]+ | "-" [:digit:]+ Integer = [:digit:]+ // | "+" [:digit:]+ | "-" [:digit:]+
...@@ -125,7 +125,7 @@ Comment = "//" [^\n\r]+ ...@@ -125,7 +125,7 @@ Comment = "//" [^\n\r]+
")" { return sym(Terminals.RB_ROUND); } ")" { return sym(Terminals.RB_ROUND); }
"{" { return sym(Terminals.LB_CURLY); } "{" { return sym(Terminals.LB_CURLY); }
"}" { return sym(Terminals.RB_CURLY); } "}" { return sym(Terminals.RB_CURLY); }
//{Identifier} { return sym(Terminals.NAME); } {Identifier} { return sym(Terminals.NAME); }
{Text} { return symText(Terminals.TEXT); } {Text} { return symText(Terminals.TEXT); }
{Integer} { return sym(Terminals.INTEGER); } {Integer} { return sym(Terminals.INTEGER); }
{Real} { return sym(Terminals.REAL); } {Real} { return sym(Terminals.REAL); }
......
...@@ -90,7 +90,7 @@ NumberLiteralExpression literal_expression = ...@@ -90,7 +90,7 @@ NumberLiteralExpression literal_expression =
; ;
Designator designator = Designator designator =
TEXT.n {: return eph.createDesignator(n); :} NAME.n {: return eph.createDesignator(n); :}
; ;
Thing thing = Thing thing =
......
...@@ -48,6 +48,7 @@ public class EraserParserHelper { ...@@ -48,6 +48,7 @@ public class EraserParserHelper {
private Root root; private Root root;
private static boolean checkUnusedElements = true; private static boolean checkUnusedElements = true;
private static Root initialRoot = null;
private class ItemPrototype extends DefaultItem { private class ItemPrototype extends DefaultItem {
...@@ -61,21 +62,31 @@ public class EraserParserHelper { ...@@ -61,21 +62,31 @@ public class EraserParserHelper {
EraserParserHelper.checkUnusedElements = checkUnusedElements; EraserParserHelper.checkUnusedElements = checkUnusedElements;
} }
public static void setInitialRoot(Root root) {
EraserParserHelper.initialRoot = root;
}
/** /**
* Post processing step after parsing a model, to resolve all references within the model. * Post processing step after parsing a model, to resolve all references within the model.
* @throws java.util.NoSuchElementException if a reference can not be resolved
*/ */
public void resolveReferences() { public void resolveReferences() {
if (this.root == null) { if (this.root == null) {
// when parsing expressions // when parsing expressions
this.root = createRoot(); this.root = EraserParserHelper.initialRoot != null ? EraserParserHelper.initialRoot : createRoot();
} }
if (checkUnusedElements) { if (checkUnusedElements) {
fillUnused(); fillUnused();
} }
resolve(thingTypeMap, missingThingTypeMap, Thing::setType); resolve(thingTypeMap, missingThingTypeMap, Thing::setType);
resolve(channelTypeMap, missingChannelTypeMap, Channel::setType); resolve(channelTypeMap, missingChannelTypeMap, Channel::setType);
if (itemMap == null || itemMap.isEmpty()) {
missingItemForDesignator.forEach((designator, itemName) ->
JavaUtils.ifPresentOrElse(root.getOpenHAB2Model().resolveItem(itemName),
designator::setItem,
() -> logger.warn("Could not resolve item {} for {}", itemName, designator)));
} else {
resolve(itemMap, missingItemForDesignator, Designator::setItem); resolve(itemMap, missingItemForDesignator, Designator::setItem);
}
missingTopicMap.forEach((topic, parts) -> ParserUtils.createMqttTopic(topic, parts, this.root)); missingTopicMap.forEach((topic, parts) -> ParserUtils.createMqttTopic(topic, parts, this.root));
this.root.getMqttRoot().ensureCorrectPrefixes(); this.root.getMqttRoot().ensureCorrectPrefixes();
......
...@@ -6,6 +6,7 @@ import beaver.Symbol; ...@@ -6,6 +6,7 @@ import beaver.Symbol;
import de.tudresden.inf.st.eraser.jastadd.model.*; import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.jastadd.parser.EraserParser; import de.tudresden.inf.st.eraser.jastadd.parser.EraserParser;
import de.tudresden.inf.st.eraser.jastadd.scanner.EraserScanner; import de.tudresden.inf.st.eraser.jastadd.scanner.EraserScanner;
import de.tudresden.inf.st.eraser.parser.EraserParserHelper;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
...@@ -193,14 +194,23 @@ public class ParserUtils { ...@@ -193,14 +194,23 @@ public class ParserUtils {
} }
public static NumberExpression parseNumberExpression(String expression_string) throws IOException, Parser.Exception { public static NumberExpression parseNumberExpression(String expression_string) throws IOException, Parser.Exception {
return (NumberExpression) parseExpression(expression_string, EraserParser.AltGoals.number_expression); return parseNumberExpression(expression_string, null);
} }
public static LogicalExpression parseLogicalExpression(String expression_string) throws IOException, Parser.Exception { public static LogicalExpression parseLogicalExpression(String expression_string) throws IOException, Parser.Exception {
return (LogicalExpression) parseExpression(expression_string, EraserParser.AltGoals.logical_expression); return parseLogicalExpression(expression_string, null);
} }
private static Expression parseExpression(String expression_string, short alt_goal) throws IOException, Parser.Exception { public static NumberExpression parseNumberExpression(String expression_string, Root root) throws IOException, Parser.Exception {
return (NumberExpression) parseExpression(expression_string, EraserParser.AltGoals.number_expression, root);
}
public static LogicalExpression parseLogicalExpression(String expression_string, Root root) throws IOException, Parser.Exception {
return (LogicalExpression) parseExpression(expression_string, EraserParser.AltGoals.logical_expression, root);
}
private static Expression parseExpression(String expression_string, short alt_goal, Root root) throws IOException, Parser.Exception {
EraserParserHelper.setInitialRoot(root);
StringReader reader = new StringReader(expression_string); StringReader reader = new StringReader(expression_string);
if (verboseLoading) { if (verboseLoading) {
EraserScanner scanner = new EraserScanner(reader); EraserScanner scanner = new EraserScanner(reader);
...@@ -220,6 +230,7 @@ public class ParserUtils { ...@@ -220,6 +230,7 @@ public class ParserUtils {
Expression result = (Expression) parser.parse(scanner, alt_goal); Expression result = (Expression) parser.parse(scanner, alt_goal);
parser.resolveReferences(); parser.resolveReferences();
reader.close(); reader.close();
EraserParserHelper.setInitialRoot(null);
return result; return result;
} }
......
...@@ -3,6 +3,7 @@ package de.tudresden.inf.st.eraser; ...@@ -3,6 +3,7 @@ package de.tudresden.inf.st.eraser;
import beaver.Parser; import beaver.Parser;
import de.tudresden.inf.st.eraser.jastadd.model.*; import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.util.ParserUtils; import de.tudresden.inf.st.eraser.util.ParserUtils;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.io.IOException; import java.io.IOException;
...@@ -128,6 +129,16 @@ public class ExpressionParserTest { ...@@ -128,6 +129,16 @@ public class ExpressionParserTest {
assertThat(rightOfSub.getValue(), equalTo(8.0)); assertThat(rightOfSub.getValue(), equalTo(8.0));
} }
@Test
public void expressionWithItem() {
try {
ParserUtils.parseNumberExpression("(myItem * 3)");
} catch (IOException | Parser.Exception e) {
e.printStackTrace();
Assert.fail(e.getMessage());
}
}
@Test @Test
public void comparingExpressions() throws IOException, Parser.Exception { public void comparingExpressions() throws IOException, Parser.Exception {
comparingExpression("<", ComparatorType.LessThan, 1, 2); comparingExpression("<", ComparatorType.LessThan, 1, 2);
......
package de.tudresden.inf.st.eraser; package de.tudresden.inf.st.eraser;
import beaver.Parser;
import de.tudresden.inf.st.eraser.jastadd.model.*; import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.util.ParserUtils;
import de.tudresden.inf.st.eraser.util.TestUtils; import de.tudresden.inf.st.eraser.util.TestUtils;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.StreamSupport; import java.util.stream.StreamSupport;
/** /**
...@@ -24,26 +28,28 @@ public class RulesTest { ...@@ -24,26 +28,28 @@ public class RulesTest {
private static final double DELTA = 0.01d; private static final double DELTA = 0.01d;
class Counters implements Action2EditConsumer { class CountingAction extends NoopAction {
Map<Item, Integer> counters; final Map<Item, AtomicInteger> counters = new HashMap<>();
Counters() { CountingAction() {
reset(); reset();
} }
private AtomicInteger getAtomic(Item item) {
return counters.computeIfAbsent(item, unused -> new AtomicInteger(0));
}
@Override @Override
public void accept(Item item) { public void applyFor(Item item) {
counters.computeIfPresent(item, (i, value) -> value + 1); getAtomic(item).addAndGet(1);
counters.putIfAbsent(item, 1);
} }
int get(Item item) { int get(Item item) {
counters.putIfAbsent(item, 0); return getAtomic(item).get();
return counters.get(item);
} }
void reset() { void reset() {
counters = new HashMap<>(); counters.clear();
} }
} }
...@@ -54,8 +60,8 @@ public class RulesTest { ...@@ -54,8 +60,8 @@ public class RulesTest {
NumberItem item = modelAndItem.item; NumberItem item = modelAndItem.item;
Rule rule = new Rule(); Rule rule = new Rule();
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -97,14 +103,14 @@ public class RulesTest { ...@@ -97,14 +103,14 @@ public class RulesTest {
Root root = modelAndItem.model.getRoot(); Root root = modelAndItem.model.getRoot();
NumberItem item = modelAndItem.item; NumberItem item = modelAndItem.item;
Counters counter1 = new Counters(); CountingAction counter1 = new CountingAction();
Rule ruleA = new Rule(); Rule ruleA = new Rule();
ruleA.addAction(new LambdaAction(counter1)); ruleA.addAction(counter1);
root.addRule(ruleA); root.addRule(ruleA);
Rule ruleB = new Rule(); Rule ruleB = new Rule();
Counters counter2 = new Counters(); CountingAction counter2 = new CountingAction();
ruleB.addAction(new LambdaAction(counter2)); ruleB.addAction(counter2);
root.addRule(ruleB); root.addRule(ruleB);
ruleA.activateFor(item); ruleA.activateFor(item);
...@@ -134,8 +140,8 @@ public class RulesTest { ...@@ -134,8 +140,8 @@ public class RulesTest {
NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem); NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem);
Rule rule = new Rule(); Rule rule = new Rule();
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item1); rule.activateFor(item1);
...@@ -192,8 +198,8 @@ public class RulesTest { ...@@ -192,8 +198,8 @@ public class RulesTest {
NumberItem item = modelAndItem.item; NumberItem item = modelAndItem.item;
Rule rule = new Rule(); Rule rule = new Rule();
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -220,8 +226,8 @@ public class RulesTest { ...@@ -220,8 +226,8 @@ public class RulesTest {
ItemStateNumberCheck check2 = new ItemStateNumberCheck(ComparatorType.LessThan, 6); ItemStateNumberCheck check2 = new ItemStateNumberCheck(ComparatorType.LessThan, 6);
rule.addCondition(new ItemStateCheckCondition(check1)); rule.addCondition(new ItemStateCheckCondition(check1));
rule.addCondition(new ItemStateCheckCondition(check2)); rule.addCondition(new ItemStateCheckCondition(check2));
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -250,10 +256,10 @@ public class RulesTest { ...@@ -250,10 +256,10 @@ public class RulesTest {
NumberItem item = modelAndItem.item; NumberItem item = modelAndItem.item;
Rule rule = new Rule(); Rule rule = new Rule();
Counters counter1 = new Counters(); CountingAction counter1 = new CountingAction();
rule.addAction(new LambdaAction(counter1)); rule.addAction(counter1);
Counters counter2 = new Counters(); CountingAction counter2 = new CountingAction();
rule.addAction(new LambdaAction(counter2)); rule.addAction(counter2);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -277,12 +283,12 @@ public class RulesTest { ...@@ -277,12 +283,12 @@ public class RulesTest {
NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem); NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem);
Rule ruleA = new Rule(); Rule ruleA = new Rule();
Counters counter1 = new Counters(); CountingAction counter1 = new CountingAction();
ruleA.addAction(new LambdaAction(counter1)); ruleA.addAction(counter1);
Rule ruleB = new Rule(); Rule ruleB = new Rule();
Counters counter2 = new Counters(); CountingAction counter2 = new CountingAction();
ruleB.addAction(new LambdaAction(counter2)); ruleB.addAction(counter2);
ruleA.addAction(new TriggerRuleAction(ruleB)); ruleA.addAction(new TriggerRuleAction(ruleB));
...@@ -321,8 +327,8 @@ public class RulesTest { ...@@ -321,8 +327,8 @@ public class RulesTest {
Rule rule = new Rule(); Rule rule = new Rule();
rule.addAction(new SetStateFromConstantStringAction(item2, "5")); rule.addAction(new SetStateFromConstantStringAction(item2, "5"));
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -353,8 +359,8 @@ public class RulesTest { ...@@ -353,8 +359,8 @@ public class RulesTest {
Rule rule = new Rule(); Rule rule = new Rule();
rule.addAction(new SetStateFromLambdaAction(item2, provider)); rule.addAction(new SetStateFromLambdaAction(item2, provider));
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -394,9 +400,9 @@ public class RulesTest { ...@@ -394,9 +400,9 @@ public class RulesTest {
StringItem item2 = addStringItem(root.getOpenHAB2Model(), "0"); StringItem item2 = addStringItem(root.getOpenHAB2Model(), "0");
Rule rule = new Rule(); Rule rule = new Rule();
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new SetStateFromTriggeringItemAction(item2)); rule.addAction(new SetStateFromTriggeringItemAction(item2));
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -433,8 +439,8 @@ public class RulesTest { ...@@ -433,8 +439,8 @@ public class RulesTest {
action.addSourceItem(item2); action.addSourceItem(item2);
action.setAffectedItem(affectedItem); action.setAffectedItem(affectedItem);
rule.addAction(action); rule.addAction(action);
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -492,8 +498,8 @@ public class RulesTest { ...@@ -492,8 +498,8 @@ public class RulesTest {
Rule rule = new Rule(); Rule rule = new Rule();
rule.addAction(new AddDoubleToStateAction(affectedItem, 2)); rule.addAction(new AddDoubleToStateAction(affectedItem, 2));
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -529,8 +535,8 @@ public class RulesTest { ...@@ -529,8 +535,8 @@ public class RulesTest {
Rule rule = new Rule(); Rule rule = new Rule();
rule.addAction(new MultiplyDoubleToStateAction(affectedItem, 2)); rule.addAction(new MultiplyDoubleToStateAction(affectedItem, 2));
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
root.addRule(rule); root.addRule(rule);
rule.activateFor(item); rule.activateFor(item);
...@@ -567,13 +573,13 @@ public class RulesTest { ...@@ -567,13 +573,13 @@ public class RulesTest {
Rule ruleA = new Rule(); Rule ruleA = new Rule();
ruleA.addAction(new AddDoubleToStateAction(affectedItem, 2)); ruleA.addAction(new AddDoubleToStateAction(affectedItem, 2));
Counters counterA = new Counters(); CountingAction counterA = new CountingAction();
ruleA.addAction(new LambdaAction(counterA)); ruleA.addAction(counterA);
Rule ruleB = new Rule(); Rule ruleB = new Rule();
ruleB.addAction(new MultiplyDoubleToStateAction(affectedItem, 3)); ruleB.addAction(new MultiplyDoubleToStateAction(affectedItem, 3));
Counters counterB = new Counters(); CountingAction counterB = new CountingAction();
ruleB.addAction(new LambdaAction(counterB)); ruleB.addAction(counterB);
ruleA.addAction(new TriggerRuleAction(ruleB)); ruleA.addAction(new TriggerRuleAction(ruleB));
...@@ -612,11 +618,62 @@ public class RulesTest { ...@@ -612,11 +618,62 @@ public class RulesTest {
63, affectedItem.getState(), DELTA); 63, affectedItem.getState(), DELTA);
} }
@Test
public void testSetFromExpression() throws IOException, Parser.Exception {
TestUtils.ModelAndItem modelAndItem = createModelAndItem(3);
Root root = modelAndItem.model.getRoot();
NumberItem item1 = modelAndItem.item;
NumberItem item2 = TestUtils.addItemTo(root.getOpenHAB2Model(), 4, useUpdatingItem);
NumberItem affectedItem = TestUtils.addItemTo(root.getOpenHAB2Model(), 5, useUpdatingItem);
Rule rule = new Rule();
SetStateFromExpression action = new SetStateFromExpression();
// TODO item1 should be referred to as triggering item
action.setNumberExpression(ParserUtils.parseNumberExpression("(" + item1.getID() + " + " + item2.getID() + ")", root));
action.setAffectedItem(affectedItem);
rule.addAction(action);
CountingAction counter = new CountingAction();
rule.addAction(counter);
root.addRule(rule);
rule.activateFor(item1);
Assert.assertEquals(m("Counter not initialized correctly"), 0, counter.get(item1));
Assert.assertEquals(m("Second item not initialized correctly"),
4, item2.getState(), DELTA);
Assert.assertEquals(m("Affected item not initialized correctly"),
5, affectedItem.getState(), DELTA);
// 5 + 4 = 9
setState(item1, 5);
Assert.assertEquals(m("Change of item state should trigger the rule"), 1, counter.get(item1));
Assert.assertEquals(m("Change of item state should set the state of the affected item"),
9, affectedItem.getState(), DELTA);
// still 9
setState(item1, 5);
Assert.assertEquals(m("Change of item to same state should not trigger the rule"), 1, counter.get(item1));
Assert.assertEquals(m("Change of item to same state should not set the state of the affected item"),
9, affectedItem.getState(), DELTA);
// still 9 (changes of item2 do not trigger the rule)
setState(item2, 1);
Assert.assertEquals(m("Change of second item to same state should not trigger the rule"), 1, counter.get(item1));
Assert.assertEquals(m("Change of second item to same state should not set the state of the affected item"),
9, affectedItem.getState(), DELTA);
// 0 + 1 = 1
setState(item1, 0);
Assert.assertEquals(m("Change of item state should trigger the rule"), 2, counter.get(item1));
Assert.assertEquals(m("Change of item state should set the state of the affected item"),
1, affectedItem.getState(), DELTA);
}
@Test @Test
public void testCronJobRule() { public void testCronJobRule() {
Rule rule = new Rule(); Rule rule = new Rule();
Counters counter = new Counters(); CountingAction counter = new CountingAction();
rule.addAction(new LambdaAction(counter)); rule.addAction(counter);
Assert.assertEquals(m("Counter not initialized correctly"), 0, counter.get(null)); Assert.assertEquals(m("Counter not initialized correctly"), 0, counter.get(null));
......
<?xml version="1.0" encoding="UTF-8"?>
<module type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
package de.tudresden.inf.st.eraser.starter;
import beaver.Parser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import de.tudresden.inf.st.eraser.feedbackloop.analyze.AnalyzeImpl;
import de.tudresden.inf.st.eraser.feedbackloop.api.Analyze;
import de.tudresden.inf.st.eraser.feedbackloop.api.Execute;
import de.tudresden.inf.st.eraser.feedbackloop.api.Learner;
import de.tudresden.inf.st.eraser.feedbackloop.api.Plan;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.feedbackloop.execute.ExecuteImpl;
import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl;
import de.tudresden.inf.st.eraser.feedbackloop.plan.PlanImpl;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.openhab2.OpenHab2Importer;
import de.tudresden.inf.st.eraser.spark.Application;
import de.tudresden.inf.st.eraser.util.JavaUtils;
import de.tudresden.inf.st.eraser.util.ParserUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
* This Starter combines and starts all modules. This includes:
*
* <ul>
* <li>Knowledge-Base in <code>eraser-base</code></li>
* <li>Feedback loop in <code>feedbackloop.{analyze,plan,execute}</code></li>
* <li>REST-API in <code>eraser-rest</code></li>
* </ul>
*/
public class EraserStarter {
private static final Logger logger = LogManager.getLogger(EraserStarter.class);
@SuppressWarnings("ResultOfMethodCallIgnored")
public static void main(String[] args) {
logger.info("Starting ERASER");
ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
File settingsFile = new File("starter-setting.yaml");
Setting settings;
try {
settings = mapper.readValue(settingsFile, Setting.class);
} catch (Exception e) {
logger.fatal("Could not read settings at '{}'. Exiting.", settingsFile.getAbsolutePath());
logger.catching(e);
System.exit(1);
return;
}
boolean startRest = settings.rest.use;
Root model;
switch (settings.initModelWith) {
case openhab:
OpenHab2Importer importer = new OpenHab2Importer();
try {
model = importer.importFrom(new URL(settings.openhab.url));
} catch (MalformedURLException e) {
logger.error("Could not parse URL {}", settings.openhab.url);
logger.catching(e);
System.exit(1);
return;
}
logger.info("Imported model {}", model.description());
break;
case load:
default:
try {
model = ParserUtils.load(settings.load.file, EraserStarter.class);
} catch (IOException | Parser.Exception e) {
logger.error("Problems parsing the given file {}", settings.load.file);
logger.catching(e);
System.exit(1);
return;
}
}
// initialize activity recognition
if (settings.activity.dummy) {
logger.info("Using dummy activity recognition");
model.getMachineLearningRoot().setActivityRecognition(DummyMachineLearningModel.createDefault());
} else {
logger.error("Reading activity recognition from file is not supported yet!");
// TODO
}
// initialize preference learning
if (settings.preference.dummy) {
logger.info("Using dummy preference learning");
model.getMachineLearningRoot().setPreferenceLearning(DummyMachineLearningModel.createDefault());
} else {
logger.info("Reading preference learning from file {}", settings.preference.file);
Learner learner = new LearnerImpl();
// there should be a method to load a model using an URL
Model preference = learner.getTrainedModel(settings.preference.realURL(), settings.preference.id);
NeuralNetworkRoot neuralNetwork = LearnerHelper.transform(preference);
if (neuralNetwork == null) {
logger.error("Could not create preference model, see possible previous errors.");
} else {
model.getMachineLearningRoot().setPreferenceLearning(neuralNetwork);
neuralNetwork.connectItems(settings.preference.items);
neuralNetwork.setOutputApplication(zeroToThree -> 25 * zeroToThree);
JavaUtils.ifPresentOrElse(
model.resolveItem(settings.preference.affectedItem),
item -> neuralNetwork.getOutputLayer().setAffectedItem(item),
() -> logger.error("Output item not set from value '{}'", settings.preference.affectedItem));
}
}
if (!model.getMachineLearningRoot().getActivityRecognition().check()) {
logger.fatal("Invalid activity recognition!");
System.exit(1);
}
if (!model.getMachineLearningRoot().getPreferenceLearning().check()) {
logger.fatal("Invalid preference learning!");
System.exit(1);
}
Analyze analyze = null;
if (settings.useMAPE) {
// configure and start mape loop
logger.info("Starting MAPE loop");
analyze = new AnalyzeImpl();
Plan plan = new PlanImpl();
Execute execute = new ExecuteImpl();
analyze.setPlan(plan);
plan.setExecute(execute);
analyze.setKnowledgeBase(model);
plan.setKnowledgeBase(model);
execute.setKnowledgeBase(model);
analyze.startAsThread(1, TimeUnit.SECONDS);
if (!startRest) {
// alternative exit condition
System.out.println("Hit [Enter] to exit");
try {
System.in.read();
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("Stopping...");
analyze.stop();
}
} else {
logger.info("No MAPE loop this time");
}
if (startRest) {
// start REST-API in new thread
logger.info("Starting REST server");
Lock lock = new ReentrantLock();
Condition quitCondition = lock.newCondition();
Thread t = new Thread(new ThreadGroup("REST-API"),
() -> Application.start(settings.rest.port, model, settings.rest.createDummyMLData, lock, quitCondition));
t.setDaemon(true);
t.start();
logger.info("Waiting until request is send to '/system/exit'");
try {
lock.lock();
if (t.isAlive()) {
quitCondition.await();
}
} catch (InterruptedException e) {
logger.warn("Waiting was interrupted");
} finally {
lock.unlock();
}
if (analyze != null) {
analyze.stop();
}
} else {
logger.info("No REST server this time");
}
logger.info("I'm done here.");
}
}
package de.tudresden.inf.st.eraser.starter;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.Arrays;
/**
* Transformation of a {@link Model} into a {@link MachineLearningModel}.
*
* @author rschoene - Initial contribution
*/
class LearnerHelper {
private static final Logger logger = LogManager.getLogger(LearnerHelper.class);
// Activation Functions
private static DoubleArrayDoubleFunction sigmoid = inputs -> Math.signum(Arrays.stream(inputs).sum());
private static DoubleArrayDoubleFunction tanh = inputs -> Math.tanh(Arrays.stream(inputs).sum());
private static DoubleArrayDoubleFunction function_one = inputs -> 1.0;
static NeuralNetworkRoot transform(Model model) {
NeuralNetworkRoot result = NeuralNetworkRoot.createEmpty();
ArrayList<Double> weights = model.getWeights();
// inputs
int inputSum = model.getInputLayerNumber() + model.getInputBias();
for (int i = 0; i < inputSum; ++i) {
InputNeuron inputNeuron = new InputNeuron();
result.addInputNeuron(inputNeuron);
}
InputNeuron bias = result.getInputNeuron(model.getInputBias());
OutputLayer outputLayer = new OutputLayer();
// output layer
for (int i = 0; i < model.getOutputLayerNumber(); ++i) {
OutputNeuron outputNeuron = new OutputNeuron();
setActivationFunction(outputNeuron, model.getOutputActivationFunction());
outputLayer.addOutputNeuron(outputNeuron);
}
result.setOutputLayer(outputLayer);
// hidden layer
int hiddenSum = model.gethiddenLayerNumber() + model.getHiddenBias();
HiddenNeuron[] hiddenNeurons = new HiddenNeuron[hiddenSum];
for (int i = 0; i < (hiddenNeurons.length); i++) {
if (i == model.gethiddenLayerNumber()) {
HiddenNeuron hiddenNeuron = new HiddenNeuron();
hiddenNeuron.setActivationFormula(function_one);
hiddenNeurons[i] = hiddenNeuron;
result.addHiddenNeuron(hiddenNeuron);
bias.connectTo(hiddenNeuron, 1.0);
for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) {
hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out));
}
} else {
HiddenNeuron hiddenNeuron = new HiddenNeuron();
setActivationFunction(hiddenNeuron, model.getHiddenActivationFunction());
hiddenNeurons[i] = hiddenNeuron;
result.addHiddenNeuron(hiddenNeuron);
for (int in = 0; in < inputSum; in++) {
// TODO replace 4 and 5 with model-attributes
result.getInputNeuron(in).connectTo(hiddenNeuron, weights.get((hiddenNeurons.length * 4 + in) + i * 5));
}
for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) {
hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out));
}
}
}
outputLayer.setCombinator(LearnerHelper::predictor);
logger.info("Created model with {} input, {} hidden and {} output neurons",
result.getNumInputNeuron(), result.getNumHiddenNeuron(), result.getOutputLayer().getNumOutputNeuron());
return result;
}
private static void setActivationFunction(HiddenNeuron neuron, String functionName) {
switch (functionName) {
case "ActivationTANH": neuron.setActivationFormula(tanh); break;
case "ActivationLinear": neuron.setActivationFormula(function_one);
case "ActivationSigmoid": neuron.setActivationFormula(sigmoid); break;
default: throw new IllegalArgumentException("Unknown function " + functionName);
}
}
private static double predictor(double[] inputs) {
int index = 0;
double maxInput = StatUtils.max(inputs);
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] == maxInput) {
index = i;
}
}
//outputs from learner
final double[] outputs = new double[]{2.0, 1.0, 3.0, 0.0};
return outputs[index];
}
}
package de.tudresden.inf.st.eraser.starter;
import java.net.URL;
import java.util.List;
/**
* Setting bean.
*
* @author rschoene - Initial contribution
*/
@SuppressWarnings("WeakerAccess")
class Setting {
public class Rest {
public boolean use = true;
public int port = 4567;
public boolean createDummyMLData = false;
}
public class FileContainer {
public String file;
URL realURL() {
return Setting.class.getClassLoader().getResource(file);
}
}
public class MLContainer extends FileContainer {
public boolean dummy = false;
public int id = 1;
public List<String> items;
public String affectedItem;
}
public class OpenHabContainer {
public String url;
}
public enum InitModelWith {
load, openhab
}
public Rest rest;
public boolean useMAPE = true;
public FileContainer load;
public MLContainer activity;
public MLContainer preference;
public OpenHabContainer openhab;
public InitModelWith initModelWith = InitModelWith.load;
}
...@@ -15,6 +15,8 @@ dependencies { ...@@ -15,6 +15,8 @@ dependencies {
testCompile group: 'junit', name: 'junit', version: '4.12' testCompile group: 'junit', name: 'junit', version: '4.12'
testCompile group: 'org.hamcrest', name: 'hamcrest-junit', version: '2.0.0.0' testCompile group: 'org.hamcrest', name: 'hamcrest-junit', version: '2.0.0.0'
compile group: 'org.encog', name: 'encog-core', version: '3.4' compile group: 'org.encog', name: 'encog-core', version: '3.4'
implementation group: 'com.opencsv', name: 'opencsv', version: '4.1'
implementation group: 'commons-io', name: 'commons-io', version: '2.5'
} }
run { run {
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import java.io.*;
import java.util.*;
import com.opencsv.CSVReader;
import com.opencsv.CSVWriter;
public class CsvTransfer {
private static final String CSV_FILE_PATH
= "datasets/backup/activity_data_example.csv";
private static final String OUTPUT_FILE_PATH
= "datasets/backup/activity_data.csv";
public static void main(String[] args)
{
addDataToCSV(CSV_FILE_PATH, OUTPUT_FILE_PATH);
}
public static void addDataToCSV(String input, String output)
{
File input_file = new File(input);
File output_file = new File(output);
try {
// create FileWriter object with file as parameter
FileReader reader = new FileReader(input_file);
CSVReader csv_reader = new CSVReader(reader);
String[] nextRecord;
FileWriter writer = new FileWriter(output_file);
CSVWriter csv_writer = new CSVWriter(writer, ',',
CSVWriter.NO_QUOTE_CHARACTER,
CSVWriter.DEFAULT_ESCAPE_CHARACTER,
CSVWriter.DEFAULT_LINE_END);
List<String[]> data = new ArrayList<String[]>();
while ((nextRecord = csv_reader.readNext()) != null) {
data.add(nextRecord);
for (String cell : nextRecord) {
System.out.print(cell + "\t");
}
System.out.println();
}
csv_writer.writeAll(data);
writer.close();
csv_reader.close();
}
catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import com.opencsv.CSVWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class DummyPreference {
private static String activity; //Activity: walking, reading, working, dancing, lying, getting up
private static String watch_brightness; //dark: <45; dimmer 45-70; bright >70;
private static String light_color_openhab_H; //red 7; green 120; blue 240; yellow 60; sky blue 180; purple 300;
private static String brightness_output; //1-100**/
private static Random random = new Random();
public static void main(String[] args) {
creator();
}
static void creator(){
try{
FileWriter writer = new FileWriter("datasets/backup/preference_data.csv",true);
CSVWriter csv_writer = new CSVWriter(writer, ',',
CSVWriter.NO_QUOTE_CHARACTER,
CSVWriter.DEFAULT_ESCAPE_CHARACTER,
CSVWriter.DEFAULT_LINE_END);
//activity="walking" green
activity ="walking";
// activity ="reading";
csv_writer.writeAll(generator("walking","green"));
csv_writer.writeAll(generator("reading","sky blue"));
csv_writer.writeAll(generator("working","blue"));
csv_writer.writeAll(generator("dancing","purple"));
csv_writer.writeAll(generator("lying","red"));
csv_writer.writeAll(generator("getting up","yellow"));
csv_writer.close();
writer.close();
}catch (IOException e){e.printStackTrace();}
}
static List<String[]> generator(String activity_input, String color){
List<String[]> data = new ArrayList<String[]>();
activity = activity_input;
light_color_openhab_H =color;
//100 walking with different lighting intensity
for (int i=0; i<100; i++){
String[] add_data = new String[4];
int brightness = random.nextInt(3000);
System.out.println(brightness);
if (brightness<45){
watch_brightness = "dark";
brightness_output ="100";
}else if(45<=brightness && brightness<200){
watch_brightness = "dimmer";
brightness_output ="40";
}else if( 200<=brightness && brightness<1000){
watch_brightness = "medium";
brightness_output ="70";
}else{
watch_brightness = "bright";
brightness_output ="0";
}
add_data[0] = activity;
add_data[1] = watch_brightness;
add_data[2] = light_color_openhab_H;
add_data[3] = brightness_output;
data.add(add_data);
}
return data;
}
}
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import java.io.File; import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;
import com.sun.javafx.tools.packager.Log; import com.sun.javafx.tools.packager.Log;
import org.encog.ConsoleStatusReportable; import org.encog.ConsoleStatusReportable;
import org.encog.Encog; import org.encog.Encog;
import org.encog.bot.BotUtil;
import org.encog.ml.MLInput;
import org.encog.ml.MLRegression; import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData; import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper; import org.encog.ml.data.versatile.NormalizationHelper;
...@@ -22,8 +16,6 @@ import org.encog.ml.factory.MLMethodFactory; ...@@ -22,8 +16,6 @@ import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel; import org.encog.ml.model.EncogModel;
import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.BasicNetwork;
import org.encog.util.csv.CSVFormat; import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
import org.encog.util.simple.EncogUtility;
import static org.encog.persist.EncogDirectoryPersistence.*; import static org.encog.persist.EncogDirectoryPersistence.*;
public class Learner { public class Learner {
...@@ -32,128 +24,173 @@ public class Learner { ...@@ -32,128 +24,173 @@ public class Learner {
* intial train * intial train
* */ * */
private String csv_url_activity; private String csv_url_activity;
private String csv_url_perference; private String csv_url_preference;
private String save_activity_model_file = "datasets/backup/activity_model.eg"; private String save_activity_model_file = "datasets/backup/activity_model.eg";
private String save_perference_model_file = "datasets/backup/preference_model.eg"; private String save_preference_model_file = "datasets/backup/preference_model.eg";
private File csv_file; private File csv_file;
private VersatileDataSource souce; private VersatileDataSource a_souce;
private VersatileMLDataSet data; private VersatileMLDataSet a_data;
private EncogModel model; private VersatileDataSource p_souce;
private VersatileMLDataSet p_data;
private EncogModel a_model;
private EncogModel p_model;
private NormalizationHelper activity_helper; private NormalizationHelper activity_helper;
private NormalizationHelper preference_helper; private NormalizationHelper preference_helper;
private MLRegression best_method; private MLRegression a_best_method;
private MLRegression p_best_method;
private String[] new_data; private String[] new_data;
private String[] preference_result;
private String activity_result; private String activity_result;
private String preference_result;
private void activityDataAnalyser(String activity_csv_url){ private void activityDataAnalyser(String activity_csv_url){
this.csv_url_activity = activity_csv_url; this.csv_url_activity = activity_csv_url;
this.csv_file = new File(csv_url_activity); this.csv_file = new File(csv_url_activity);
souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT); a_souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT);
data = new VersatileMLDataSet(souce); a_data = new VersatileMLDataSet(a_souce);
data.defineSourceColumn("monat", 0, ColumnType.continuous); a_data.defineSourceColumn("m_accel_x", 0, ColumnType.continuous);
data.defineSourceColumn("day", 1, ColumnType.continuous); a_data.defineSourceColumn("m_accel_y", 1, ColumnType.continuous);
data.defineSourceColumn("hour", 2, ColumnType.continuous); a_data.defineSourceColumn("m_accel_z", 2, ColumnType.continuous);
data.defineSourceColumn("minute", 3, ColumnType.continuous); a_data.defineSourceColumn("m_rotation_x", 3, ColumnType.continuous);
ColumnDefinition outputColumn = data.defineSourceColumn("labels", 4, ColumnType.continuous); a_data.defineSourceColumn("m_rotation_y", 4, ColumnType.continuous);
data.defineSingleOutputOthersInput(outputColumn); a_data.defineSourceColumn("m_rotation_z", 5, ColumnType.continuous);
data.analyze(); a_data.defineSourceColumn("w_accel_x", 6, ColumnType.continuous);
System.out.println("get data "); a_data.defineSourceColumn("w_accel_y", 7, ColumnType.continuous);
model = new EncogModel(data); a_data.defineSourceColumn("w_accel_z", 8, ColumnType.continuous);
model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD); a_data.defineSourceColumn("w_rotation_x", 9, ColumnType.continuous);
data.normalize(); a_data.defineSourceColumn("w_rotation_y", 10, ColumnType.continuous);
activity_helper = data.getNormHelper(); a_data.defineSourceColumn("w_rotation_z", 11, ColumnType.continuous);
System.out.println(activity_helper.toString()); ColumnDefinition outputColumn = a_data.defineSourceColumn("labels", 12, ColumnType.nominal);
a_data.defineSingleOutputOthersInput(outputColumn);
a_data.analyze();
a_model = new EncogModel(a_data);
a_model.selectMethod(a_data, MLMethodFactory.TYPE_FEEDFORWARD);
a_data.normalize();
activity_helper = a_data.getNormHelper();
//System.out.println(activity_helper.toString());
} }
private void perferenceDataAnalyser(String perference_csv_url){ private void preferenceDataAnalyser(String preference_csv_url){
this.csv_url_perference = perference_csv_url; this.csv_url_preference = preference_csv_url;
this.csv_file = new File(this.csv_url_perference); this.csv_file = new File(this.csv_url_preference);
souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT); p_souce = new CSVDataSource(csv_file,false,CSVFormat.DECIMAL_POINT);
data = new VersatileMLDataSet(souce); p_data = new VersatileMLDataSet(p_souce);
data.defineSourceColumn("activity", 0, ColumnType.continuous); p_data.defineSourceColumn("activity", 0, ColumnType.nominal);
data.defineSourceColumn("brightness", 1, ColumnType.continuous); p_data.defineSourceColumn("w_brightness", 1, ColumnType.nominal);
data.defineSourceColumn("time", 2, ColumnType.continuous);
data.defineSourceColumn("minute", 3, ColumnType.continuous); ColumnDefinition outputColumn1 = p_data.defineSourceColumn("label1", 2, ColumnType.continuous);
ColumnDefinition outputColumn = data.defineSourceColumn("labels", 4, ColumnType.continuous); ColumnDefinition outputColumn2 = p_data.defineSourceColumn("label2", 3, ColumnType.continuous);
data.defineSingleOutputOthersInput(outputColumn); ColumnDefinition[] outputs = new ColumnDefinition[2];
data.analyze(); outputs[0] = outputColumn1;
model = new EncogModel(data); outputs[1] = outputColumn2;
model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD); p_data.defineMultipleOutputsOthersInput(outputs);
//model.setReport(new ConsoleStatusReportable()); p_data.analyze();
data.normalize(); p_model = new EncogModel(p_data);
preference_helper = data.getNormHelper(); p_model.selectMethod(p_data, MLMethodFactory.TYPE_FEEDFORWARD);
System.out.println(activity_helper.toString()); p_model.setReport(new ConsoleStatusReportable());
p_data.normalize();
preference_helper = p_data.getNormHelper();
System.out.println(preference_helper.toString());
} }
void train(String activity_url,String perference_url){ void train(String activity_url,String preference_url){
activity_train(activity_url); activity_train(activity_url);
Log.info("activity training finished"); Log.info("activity training finished");
preference_train(perference_url); preference_train(preference_url);
Log.info("preference training finished"); Log.info("preference training finished");
Encog.getInstance().shutdown();
} }
private void activity_train(String activity_csv_url){ public void activity_train(String activity_csv_url){
activityDataAnalyser(activity_csv_url); activityDataAnalyser(activity_csv_url);
model.holdBackValidation(0.3, true, 1001); a_model.holdBackValidation(0.3, true, 1001);
model.selectTrainingType(data); a_model.selectTrainingType(a_data);
best_method = (MLRegression)model.crossvalidate(5, true); a_best_method = (MLRegression)a_model.crossvalidate(5, true);
System.out.println(best_method); System.out.println(a_best_method);
saveEncogModel(save_activity_model_file); saveEncogModel(save_activity_model_file);
Encog.getInstance().shutdown();
} }
private void preference_train(String perfence_csv_url){ public void preference_train(String prefence_csv_url){
perferenceDataAnalyser(perfence_csv_url); preferenceDataAnalyser(prefence_csv_url);
model.holdBackValidation(0.3, true, 1001); p_model.holdBackValidation(0.3, true, 1001);
model.selectTrainingType(data); p_model.selectTrainingType(p_data);
best_method = (MLRegression)model.crossvalidate(5, true); p_best_method = (MLRegression)p_model.crossvalidate(5, true);
System.out.println(best_method); System.out.println(p_best_method);
saveEncogModel(save_perference_model_file); saveEncogModel(save_preference_model_file);
Encog.getInstance().shutdown();
} }
String[] predictor(String[] new_data){ String[] predictor(String[] new_data){
this.new_data = new_data; String[] preference_data = new String[2];
activityDataAnalyser("datasets/backup/activity_data.csv"); String[] result = new String[3];
perferenceDataAnalyser("datasets/backup/preference_data.csv"); String[] activity_data= new String[12];
String[] result = new String[2]; activity_data[0]=new_data[0];
result[0] = activity_predictor(); activity_data[1]=new_data[1];
result[1] = perference_predictor(); activity_data[2]=new_data[2];
activity_data[3]=new_data[3];
activity_data[4]=new_data[4];
activity_data[5]=new_data[5];
activity_data[6]=new_data[6];
activity_data[7]=new_data[7];
activity_data[8]=new_data[8];
activity_data[9]=new_data[9];
activity_data[10]=new_data[10];
activity_data[11]=new_data[11];
result[0] = activity_predictor(activity_data);
preference_data[0]=result[0];
preference_data[1]=new_data[12];
result[1] = preference_predictor(preference_data)[0];
result[2] = preference_predictor(preference_data)[1];
Encog.getInstance().shutdown(); Encog.getInstance().shutdown();
return result; return result;
} }
private String activity_predictor(){ public String activity_predictor(String[] new_data){
activityDataAnalyser("datasets/backup/activity_data.csv");
BasicNetwork activity_method = (BasicNetwork) loadObject(new File(save_activity_model_file)); BasicNetwork activity_method = (BasicNetwork) loadObject(new File(save_activity_model_file));
MLData input = activity_helper.allocateInputVector(); MLData input = activity_helper.allocateInputVector();
String[] activity_new_data = new String[4]; String[] activity_new_data = new String[12];
activity_new_data[0] = new_data[0]; activity_new_data[0] = new_data[0];
activity_new_data[1] = new_data[1]; activity_new_data[1] = new_data[1];
activity_new_data[2] = new_data[2]; activity_new_data[2] = new_data[2];
activity_new_data[3] = new_data[3]; activity_new_data[3] = new_data[3];
activity_new_data[4] = new_data[4];
activity_new_data[5] = new_data[5];
activity_new_data[6] = new_data[6];
activity_new_data[7] = new_data[7];
activity_new_data[8] = new_data[8];
activity_new_data[9] = new_data[9];
activity_new_data[10] = new_data[10];
activity_new_data[11] = new_data[11];
activity_helper.normalizeInputVector(activity_new_data,input.getData(),false); activity_helper.normalizeInputVector(activity_new_data,input.getData(),false);
MLData output = activity_method.compute(input); MLData output = activity_method.compute(input);
System.out.println("input:"+input);
System.out.println("output"+output);
activity_result = activity_helper.denormalizeOutputVectorToString(output)[0]; activity_result = activity_helper.denormalizeOutputVectorToString(output)[0];
System.out.println("output activity"+ activity_result);
return activity_result; return activity_result;
} }
private String perference_predictor(){ public String[] preference_predictor(String[] new_data){
BasicNetwork preference_method = (BasicNetwork)loadObject(new File(save_perference_model_file)); preference_result = new String[2];
preferenceDataAnalyser("datasets/backup/preference_data.csv");
BasicNetwork preference_method = (BasicNetwork)loadObject(new File(save_preference_model_file));
MLData input = preference_helper.allocateInputVector(); MLData input = preference_helper.allocateInputVector();
String[] perference_new_data = new String[4]; System.out.print("input: "+input);
perference_new_data[0] = activity_result; String[] preference_new_data = new String[2];
perference_new_data[1] = new_data[4]; //preference_new_data[0] = activity_result;
perference_new_data[2] = new_data[5]; preference_new_data[0] = new_data[0];
perference_new_data[3] = new_data[6]; preference_new_data[1] = new_data[1];
preference_helper.normalizeInputVector(perference_new_data, input.getData(),false); preference_helper.normalizeInputVector(preference_new_data, input.getData(),false);
System.out.println(preference_helper);
MLData output = preference_method.compute(input); MLData output = preference_method.compute(input);
preference_result = preference_helper.denormalizeOutputVectorToString(output)[0]; preference_result[0] = preference_helper.denormalizeOutputVectorToString(output)[0];
preference_result[1] = preference_helper.denormalizeOutputVectorToString(output)[1];
return preference_result; return preference_result;
} }
private void saveEncogModel(String model_file_url){ private void saveEncogModel(String model_file_url){
saveObject(new File(model_file_url), this.best_method); if (model_file_url.equals(save_activity_model_file)){saveObject(new File(model_file_url), this.a_best_method);}
else {
saveObject(new File(model_file_url), this.p_best_method);
}
} }
} }
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.jastadd.model.*; import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.encog.util.csv.ReadCSV;
import org.encog.util.csv.CSVFormat;
import java.util.Arrays;
public class Main { public class Main {
...@@ -8,22 +11,140 @@ public class Main { ...@@ -8,22 +11,140 @@ public class Main {
/** /**
* new data from KB * new data from KB
* */ * */
String[] new_data = new String[7]; String[] new_data = new String[13];
new_data[0]="7"; new_data[0]="0.10654198";
new_data[1]="20"; new_data[1]="8.6574335";
new_data[2]="12"; new_data[2]="4.414908";
new_data[3]="13"; new_data[3]="0.040269";
new_data[4]="7"; new_data[4]="0.516884";
new_data[5]="25"; new_data[5]="0.853285";
new_data[6]="12"; new_data[6]="1.2066777";
new_data[7]="-1.1444284";
new_data[8]="9.648633";
new_data[9]="1.2207031E-4";
new_data[10]="-0.055358887";
new_data[11]="0.5834961";
new_data[12]="bright";
Learner learner = new Learner();
String[] result =learner.predictor(new_data);
System.out.println(result[0]);
System.out.println(result[1]);
System.out.println(result[2]);
//learner.preference_train("datasets/backup/preference_data.csv");
//learner.train("datasets/backup/activity_data.csv","datasets/backup/preference_data.csv");
//walking,medium,120,70
//reading,bright,180,0
//activity_validation_learner();
//0.10654198,8.6574335,4.414908,0.040269,0.516884,0.853285,1.2066777,-1.1444284,9.648633,1.2207031E-4,-0.055358887,0.5834961 working
/**String[] new_data = new String[12];
new_data[0]="0.10654198";
new_data[1]="8.6574335";
new_data[2]="4.414908";
new_data[3]="0.040269";
new_data[4]="0.516884";
new_data[5]="0.853285";
new_data[6]="1.2066777";
new_data[7]="-1.1444284";
new_data[8]="9.648633";
new_data[9]="1.2207031E-4";
new_data[10]="-0.055358887";
new_data[11]="0.5834961";
String[] new_data_1 =new String[12];
//-2.6252422,8.619126,-2.7030537,0.552147,0.5078,0.450302,-8.1881695,-1.2641385,0.038307227,-0.34222412,0.49102783,-0.016540527,walking
new_data_1[0]="-2.6252422";
new_data_1[1]="8.619126";
new_data_1[2]="-2.7030537";
new_data_1[3]="0.552147";
new_data_1[4]="0.5078";
new_data_1[5]="0.450302";
new_data_1[6]="-8.1881695";
new_data_1[7]="-1.2641385";
new_data_1[8]="0.038307227";
new_data_1[9]="-0.34222412";
new_data_1[10]="0.49102783";
new_data_1[11]="-0.016540527";
/** /**
* learner.train(activity_csv_url, preference_data_url) * learner.train(activity_csv_url, preference_data_url)
* learner.predictor get the result from predictor for new data * learner.predictor get the result from predictor for new data
* */ * */
/**String[] new_data_2 = new String[12];
//-6.5565214,5.717354,5.6658783,0.185591,0.464146,0.413321,-20.580557,3.8498764,-0.4261679,0.7647095,-0.4713745,0.23999023,dancing
new_data_2[0]="-6.5565214";
new_data_2[1]="5.717354";
new_data_2[2]="5.6658783";
new_data_2[3]="0.185591";
new_data_2[4]="0.464146";
new_data_2[5]="0.413321";
new_data_2[6]="-20.580557";
new_data_2[7]="3.8498764";
new_data_2[8]="-0.4261679";
new_data_2[9]="0.7647095";
new_data_2[10]="-0.4713745";
new_data_2[11]="0.23999023";
String[] new_data_3 = new String[12];
new_data_3[0]="-5.3881507";
new_data_3[1]="0.25378537";
new_data_3[2]="7.69257";
new_data_3[3]="-0.122974";
new_data_3[4]="0.247411";
new_data_3[5]="0.439031";
new_data_3[6]="4.9224787";
new_data_3[7]="-10.601525";
new_data_3[8]="-4.927267";
new_data_3[9]="0.7946167";
new_data_3[10]="0.35272217";
new_data_3[11]="0.16192627";
//"-5.3881507","0.25378537","7.69257","-0.122974","0.247411","0.439031","4.9224787","-10.601525","-4.927267","0.7946167","0.35272217","0.16192627","lying"
Learner learner=new Learner(); Learner learner=new Learner();
//learner.train("datasets/activity_data.csv", "datasets/preference_data.csv"); //learner.train("datasets/backup/activity_data.csv", "datasets/preference_data.csv");
String[] result = learner.predictor(new_data); String[] result = learner.predictor(new_data_3);
System.out.println("activity is:" + result[0]); System.out.println("activity is:" + result[0]);
System.out.println("perference is: "+ result[1]); //System.out.println("perference is: "+ result[1]);**/
}
public static void activity_validation_learner(){
ReadCSV csv = new ReadCSV("datasets/backup/activity_data.csv", false, CSVFormat.DECIMAL_POINT);
String[] line = new String[12];
Learner learner=new Learner();
int wrong=0;
int right=0;
while(csv.next()) {
StringBuilder result = new StringBuilder();
line[0] = csv.get(0);
line[1] = csv.get(1);
line[2] = csv.get(2);
line[3] = csv.get(3);
line[4] = csv.get(4);
line[5] = csv.get(5);
line[6] = csv.get(6);
line[7] = csv.get(7);
line[8] = csv.get(8);
line[9] = csv.get(9);
line[10] = csv.get(10);
line[11] = csv.get(11);
String correct = csv.get(12);
String irisChosen = learner.predictor(line)[0];
result.append(Arrays.toString(line));
result.append(" -> predicted: ");
result.append(irisChosen);
result.append("(correct: ");
result.append(correct);
result.append(")");
if (irisChosen.equals(correct)!=true){
System.out.println(correct);
System.out.println(irisChosen);
++wrong;
}else{
++right;
}
System.out.println(result.toString());
}
System.out.println("wrong number"+wrong);
System.out.println("right number"+right);
//double validation = (double(right))/(double(wrong+right));
//System.out.println("%.2f"+validation);
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment