Skip to content
Snippets Groups Projects
Commit 0c8ebacd authored by René Schöne's avatar René Schöne
Browse files

Internal rewrite of MachineLearningResult to a nonterminal.

- Make preference less noisy in logging
- MachineLearningResult and ItemPreference have equals added
- For classification, MachineLearningModel.classify() should be used as it sets lastPreference automatically
- Analyze also checks for updated preference
- Plan adds changed items of preference to recognition event
- Learner returns recognized activity correctly
parent 38bc86fe
Branches
No related tags found
1 merge request!19dev to master
Showing
with 150 additions and 142 deletions
...@@ -3,21 +3,21 @@ aspect DecisionTree { ...@@ -3,21 +3,21 @@ aspect DecisionTree {
// let DecisionTreeLeaf implement Leaf // let DecisionTreeLeaf implement Leaf
public class DecisionTreeLeaf implements Leaf { } public class DecisionTreeLeaf implements Leaf { }
//--- classify --- //--- internalClassify ---
syn DecisionTreeLeaf DecisionTreeRoot.classify() { syn DecisionTreeLeaf DecisionTreeRoot.internalClassify() {
return getRootRule().classify(); return getRootRule().internalClassify();
} }
syn DecisionTreeLeaf DecisionTreeElement.classify(); syn DecisionTreeLeaf DecisionTreeElement.internalClassify();
syn DecisionTreeLeaf DecisionTreeRule.classify(); syn DecisionTreeLeaf DecisionTreeRule.internalClassify();
syn DecisionTreeLeaf ItemStateCheckRule.classify() { syn DecisionTreeLeaf ItemStateCheckRule.internalClassify() {
boolean chooseLeft = getItemStateCheck().holds(); boolean chooseLeft = getItemStateCheck().holds();
return (chooseLeft ? getLeft() : getRight()).classify(); return (chooseLeft ? getLeft() : getRight()).internalClassify();
} }
syn DecisionTreeLeaf DecisionTreeLeaf.classify() = this; syn DecisionTreeLeaf DecisionTreeLeaf.internalClassify() = this;
//--- holds --- //--- holds ---
syn boolean ItemStateCheck.holds() = holdsFor(getItem()); syn boolean ItemStateCheck.holds() = holdsFor(getItem());
...@@ -47,19 +47,19 @@ aspect DecisionTree { ...@@ -47,19 +47,19 @@ aspect DecisionTree {
} }
//--- computePreferences --- //--- computePreferences ---
syn List<ItemPreference> DecisionTreeLeaf.computePreferences() { syn MachineLearningResult DecisionTreeLeaf.computePreferences() {
// iterate over preference of this leaf, and all its parents and ancestors // iterate over preference of this leaf, and all its parents and ancestors
List<ItemPreference> result = new ArrayList<>(); MachineLearningResult result = new MachineLearningResult();
Set<Item> seenItems = new HashSet<>(); Set<Item> seenItems = new HashSet<>();
List<DecisionTreeElement> ancestors = ancestors(); List<DecisionTreeElement> ancestors = ancestors();
for (ItemPreference pref : getPreferenceList()) { for (ItemPreference pref : getPreferenceList()) {
result.add(pref); result.addItemPreference(pref);
seenItems.add(pref.getItem()); seenItems.add(pref.getItem());
} }
for (DecisionTreeElement ancestor : ancestors) { for (DecisionTreeElement ancestor : ancestors) {
for (ItemPreference pref : ancestor.getPreferenceList()) { for (ItemPreference pref : ancestor.getPreferenceList()) {
if (!seenItems.contains(pref.getItem())) { if (!seenItems.contains(pref.getItem())) {
result.add(pref); result.addItemPreference(pref);
seenItems.add(pref.getItem()); seenItems.add(pref.getItem());
} }
} }
......
...@@ -401,7 +401,6 @@ aspect ItemHandling { ...@@ -401,7 +401,6 @@ aspect ItemHandling {
//--- ItemPreference.apply --- //--- ItemPreference.apply ---
public abstract void ItemPreference.apply(); public abstract void ItemPreference.apply();
public void ItemPreferenceColor.apply() { public void ItemPreferenceColor.apply() {
logger.debug("Apply color preference {} -> {}", getItem().getID(), getPreferredHSB());
getItem().setStateFromColor(getPreferredHSB()); getItem().setStateFromColor(getPreferredHSB());
getItem().freeze(); getItem().freeze();
for (Item controller : getItem().getControlledByList()) { for (Item controller : getItem().getControlledByList()) {
...@@ -409,8 +408,8 @@ aspect ItemHandling { ...@@ -409,8 +408,8 @@ aspect ItemHandling {
} }
getItem().unfreeze(); getItem().unfreeze();
} }
//--- ItemPreference.apply ---
public void ItemPreferenceDouble.apply() { public void ItemPreferenceDouble.apply() {
logger.debug("Apply double preference {} -> {}", getItem().getID(), getPreferredValue());
getItem().setStateFromDouble(getPreferredValue()); getItem().setStateFromDouble(getPreferredValue());
getItem().freeze(); getItem().freeze();
for (Item controller : getItem().getControlledByList()) { for (Item controller : getItem().getControlledByList()) {
...@@ -419,6 +418,14 @@ aspect ItemHandling { ...@@ -419,6 +418,14 @@ aspect ItemHandling {
getItem().unfreeze(); getItem().unfreeze();
} }
//--- ItemPreference.describe ---
syn String ItemPreference.describe() = getItem().getID() + " -> " + getNewStateAsString();
//--- ItemPreference.getNewStateAsString ---
syn String ItemPreference.getNewStateAsString();
eq ItemPreferenceColor.getNewStateAsString() = getPreferredHSB().toString();
eq ItemPreferenceDouble.getNewStateAsString() = Double.toString(getPreferredValue());
// // override Item.init$Children from JastAdd's own ASTNode aspect // // override Item.init$Children from JastAdd's own ASTNode aspect
// refine ASTNode public void Item.init$Children() { // refine ASTNode public void Item.init$Children() {
// refined(); // refined();
......
...@@ -8,10 +8,10 @@ aspect MachineLearning { ...@@ -8,10 +8,10 @@ aspect MachineLearning {
public interface Leaf { public interface Leaf {
String getLabel(); String getLabel();
int getActivityIdentifier(); int getActivityIdentifier();
List<ItemPreference> computePreferences(); MachineLearningResult computePreferences();
} }
syn Leaf InternalMachineLearningModel.classify(); syn Leaf InternalMachineLearningModel.internalClassify();
//--- currentActivityName --- //--- currentActivityName ---
syn String Root.currentActivityName() = JavaUtils.ifPresentOrElseReturn( syn String Root.currentActivityName() = JavaUtils.ifPresentOrElseReturn(
...@@ -31,8 +31,8 @@ aspect MachineLearning { ...@@ -31,8 +31,8 @@ aspect MachineLearning {
return (int) ((ItemPreferenceDouble) preferences.get(0)).getPreferredValue(); return (int) ((ItemPreferenceDouble) preferences.get(0)).getPreferredValue();
} }
//--- currentPreferences --- // //--- currentPreferences ---
syn List<ItemPreference> Root.currentPreferences() = getMachineLearningRoot().getPreferenceLearning().getDecoder().classify().getPreferences(); // syn List<ItemPreference> Root.currentPreferences() = getMachineLearningRoot().getPreferenceLearning().getDecoder().classify().getItemPreferences();
//--- canSetActivity --- //--- canSetActivity ---
syn boolean MachineLearningModel.canSetActivity() = false; syn boolean MachineLearningModel.canSetActivity() = false;
...@@ -45,7 +45,7 @@ aspect MachineLearning { ...@@ -45,7 +45,7 @@ aspect MachineLearning {
} }
//--- DummyMachineLearningModel.classify --- //--- DummyMachineLearningModel.classify ---
eq DummyMachineLearningModel.classify() { eq DummyMachineLearningModel.internalClassify() {
if (logger.isInfoEnabled() && getItemList().size() > 0) { if (logger.isInfoEnabled() && getItemList().size() > 0) {
logger.info("Dummy classification of {}, values of connected items: {}", logger.info("Dummy classification of {}, values of connected items: {}",
mlKind(), mlKind(),
...@@ -167,6 +167,43 @@ aspect MachineLearning { ...@@ -167,6 +167,43 @@ aspect MachineLearning {
return this.decoder; return this.decoder;
} }
//--- classify ---
public MachineLearningResult MachineLearningModel.classify() {
MachineLearningResult result = getDecoder().classify();
setLastPreference(result);
return result;
}
//--- equals ---
public boolean MachineLearningResult.equals(Object other) {
if (!(other instanceof MachineLearningResult)) {
return false;
}
MachineLearningResult otherResult = (MachineLearningResult) other;
if (getNumItemPreference() != otherResult.getNumItemPreference()) {
return false;
}
for (int i = 0; i < getNumItemPreference(); i++) {
if (!getItemPreference(i).equals(otherResult.getItemPreference(i))) {
return false;
}
}
return true;
}
public abstract boolean ItemPreference.equals(Object other);
public boolean ItemPreferenceDouble.equals(Object other) {
if (!(other instanceof ItemPreferenceDouble)) {
return false;
}
return getItem() == ((ItemPreferenceDouble) other).getItem() && getPreferredValue() == ((ItemPreferenceDouble) other).getPreferredValue();
}
public boolean ItemPreferenceColor.equals(Object other) {
if (!(other instanceof ItemPreferenceColor)) {
return false;
}
return getItem() == ((ItemPreferenceColor) other).getItem() && getPreferredHSB() == ((ItemPreferenceColor) other).getPreferredHSB();
}
} }
aspect ChangeEvents { aspect ChangeEvents {
...@@ -177,19 +214,30 @@ aspect ChangeEvents { ...@@ -177,19 +214,30 @@ aspect ChangeEvents {
RecognitionEvent result = new RecognitionEvent(); RecognitionEvent result = new RecognitionEvent();
result.initChangeEvent(); result.initChangeEvent();
for (Item relevantItem : modelOfRecognition.getRelevantItems()) { for (Item relevantItem : modelOfRecognition.getRelevantItems()) {
result.addChangedItem(ChangedItem.newFromItem(relevantItem)); result.addRelevantItem(ChangedItem.newFrom(relevantItem));
}
for (ItemPreference preference : modelOfRecognition.getLastPreference().getItemPreferences()) {
result.addChangedItem(ChangedItem.newFrom(preference));
} }
return result; return result;
} }
//--- newFrom Item --- //--- newFrom Item ---
public static ChangedItem ChangedItem.newFromItem(Item source) { public static ChangedItem ChangedItem.newFrom(Item source) {
ChangedItem result = new ChangedItem(); ChangedItem result = new ChangedItem();
result.setItem(source); result.setItem(source);
result.setNewStateAsString(source.getStateAsString()); result.setNewStateAsString(source.getStateAsString());
return result; return result;
} }
//--- newFrom ItemPreference ---
public static ChangedItem ChangedItem.newFrom(ItemPreference update) {
ChangedItem result = new ChangedItem();
result.setItem(update.getItem());
result.setNewStateAsString(update.getNewStateAsString());
return result;
}
//--- initChangeEvent --- //--- initChangeEvent ---
protected void ChangeEvent.initChangeEvent() { protected void ChangeEvent.initChangeEvent() {
this.setCreated(Instant.now()); this.setCreated(Instant.now());
......
...@@ -10,10 +10,11 @@ rel ChangedItem.Item -> Item ; ...@@ -10,10 +10,11 @@ rel ChangedItem.Item -> Item ;
RecognitionEvent : ChangeEvent ; RecognitionEvent : ChangeEvent ;
rel RecognitionEvent.Activity -> Activity ; rel RecognitionEvent.Activity -> Activity ;
rel RecognitionEvent.RelevantItem* -> ChangedItem ;
ManualChangeEvent : ChangeEvent ; ManualChangeEvent : ChangeEvent ;
abstract MachineLearningModel ::= ; abstract MachineLearningModel ::= LastPreference:MachineLearningResult ;
rel MachineLearningModel.RelevantItem* <-> Item.RelevantInMachineLearningModel* ; rel MachineLearningModel.RelevantItem* <-> Item.RelevantInMachineLearningModel* ;
rel MachineLearningModel.TargetItem* <-> Item.TargetInMachineLearningModel* ; rel MachineLearningModel.TargetItem* <-> Item.TargetInMachineLearningModel* ;
...@@ -21,6 +22,8 @@ ExternalMachineLearningModel : MachineLearningModel ; ...@@ -21,6 +22,8 @@ ExternalMachineLearningModel : MachineLearningModel ;
abstract InternalMachineLearningModel : MachineLearningModel ::= <OutputApplication:DoubleDoubleFunction> ; abstract InternalMachineLearningModel : MachineLearningModel ::= <OutputApplication:DoubleDoubleFunction> ;
MachineLearningResult ::= ItemPreference* ;
abstract ItemPreference ::= ; abstract ItemPreference ::= ;
rel ItemPreference.Item -> Item ; rel ItemPreference.Item -> Item ;
......
...@@ -18,17 +18,19 @@ aspect NeuralNetwork { ...@@ -18,17 +18,19 @@ aspect NeuralNetwork {
public int getActivityIdentifier() { public int getActivityIdentifier() {
return (int) number; return (int) number;
} }
public List<ItemPreference> computePreferences() { public MachineLearningResult computePreferences() {
return Collections.singletonList(new ItemPreferenceDouble(affectedItem, number)); MachineLearningResult result = new MachineLearningResult();
result.addItemPreference(new ItemPreferenceDouble(affectedItem, number));
return result;
} }
} }
//--- classify --- //--- internalClassify ---
syn DoubleNumber NeuralNetworkRoot.classify() { syn DoubleNumber NeuralNetworkRoot.internalClassify() {
return getOutputLayer().classify(); return getOutputLayer().internalClassify();
} }
syn DoubleNumber OutputLayer.classify() { syn DoubleNumber OutputLayer.internalClassify() {
double[] inputs = new double[getNumOutputNeuron()]; double[] inputs = new double[getNumOutputNeuron()];
for (int i = 0; i < getNumOutputNeuron(); ++i) { for (int i = 0; i < getNumOutputNeuron(); ++i) {
OutputNeuron n = getOutputNeuron(i); OutputNeuron n = getOutputNeuron(i);
......
...@@ -48,8 +48,7 @@ public class InternalMachineLearningHandler implements MachineLearningEncoder, M ...@@ -48,8 +48,7 @@ public class InternalMachineLearningHandler implements MachineLearningEncoder, M
@Override @Override
public MachineLearningResult classify() { public MachineLearningResult classify() {
List<ItemPreference> preferences = model.classify().computePreferences(); return model.internalClassify().computePreferences();
return new InternalMachineLearningResult(preferences);
} }
@Override @Override
......
package de.tudresden.inf.st.eraser.jastadd.model;
import java.util.List;
/**
* Result of a classification returned by an internally held machine learning model.
*
* @author rschoene - Initial contribution
*/
public class InternalMachineLearningResult implements MachineLearningResult {
private final List<ItemPreference> preferences;
InternalMachineLearningResult(List<ItemPreference> preferences) {
this.preferences = preferences;
}
@Override
public List<ItemPreference> getPreferences() {
return this.preferences;
}
}
package de.tudresden.inf.st.eraser.jastadd.model;
import de.tudresden.inf.st.eraser.jastadd.model.ItemPreference;
import java.util.List;
/**
* Representation of a classification result using a MachineLearningModel.
*
* @author rschoene - Initial contribution
*/
@SuppressWarnings("unused")
public interface MachineLearningResult {
// Object rawClass();
// double rawConfidence();
// can be used for both activity and preferences
/**
* Get the result as a list of item preferences, i.e., new states to be set for those items.
* @return the classification result as item preferences
*/
List<ItemPreference> getPreferences();
}
...@@ -25,17 +25,17 @@ public class DecisionTreeTest { ...@@ -25,17 +25,17 @@ public class DecisionTreeTest {
dtroot.setRootRule(newRule(isLessThanFour, isFourOrGreater, "check item1", check)); dtroot.setRootRule(newRule(isLessThanFour, isFourOrGreater, "check item1", check));
// current value is four, so return value should be "four or greater" // current value is four, so return value should be "four or greater"
Leaf leaf = dtroot.classify(); Leaf leaf = dtroot.internalClassify();
Assert.assertEquals(isFourOrGreater, leaf); Assert.assertEquals(isFourOrGreater, leaf);
// change value to 5, so return value should still be "four or greater" // change value to 5, so return value should still be "four or greater"
mai.item.setState(5); mai.item.setState(5);
leaf = dtroot.classify(); leaf = dtroot.internalClassify();
Assert.assertEquals(isFourOrGreater, leaf); Assert.assertEquals(isFourOrGreater, leaf);
// change value to 2, so return value should now be "less than four" // change value to 2, so return value should now be "less than four"
mai.item.setState(2); mai.item.setState(2);
leaf = dtroot.classify(); leaf = dtroot.internalClassify();
Assert.assertEquals(isLessThanFour, leaf); Assert.assertEquals(isLessThanFour, leaf);
} }
...@@ -64,17 +64,17 @@ public class DecisionTreeTest { ...@@ -64,17 +64,17 @@ public class DecisionTreeTest {
dtroot.setRootRule(newRule(rule25, rule75, "50-item1", check50)); dtroot.setRootRule(newRule(rule25, rule75, "50-item1", check50));
// current value is 20, so return value should be "less than 25" // current value is 20, so return value should be "less than 25"
Leaf leaf = dtroot.classify(); Leaf leaf = dtroot.internalClassify();
Assert.assertEquals(isLessThan25, leaf); Assert.assertEquals(isLessThan25, leaf);
// change value to 25, so return value should still be "25 or greater" // change value to 25, so return value should still be "25 or greater"
mai.item.setState(25); mai.item.setState(25);
leaf = dtroot.classify(); leaf = dtroot.internalClassify();
Assert.assertEquals(is25OrGreater, leaf); Assert.assertEquals(is25OrGreater, leaf);
// change value to 100, so return value should now be "greater than 75" // change value to 100, so return value should now be "greater than 75"
mai.item.setState(100); mai.item.setState(100);
leaf = dtroot.classify(); leaf = dtroot.internalClassify();
Assert.assertEquals(is75OrGreater, leaf); Assert.assertEquals(is75OrGreater, leaf);
} }
...@@ -139,7 +139,7 @@ public class DecisionTreeTest { ...@@ -139,7 +139,7 @@ public class DecisionTreeTest {
for (TestResult result : testResults) { for (TestResult result : testResults) {
mai.item.setState(Math.round(result.value)); mai.item.setState(Math.round(result.value));
Leaf leaf = dtroot.classify(); Leaf leaf = dtroot.internalClassify();
Assert.assertEquals(result.chooseLeft ? leaf : right, leaf); Assert.assertEquals(result.chooseLeft ? leaf : right, leaf);
} }
} }
......
...@@ -50,17 +50,17 @@ public class NeuralNetworkTest { ...@@ -50,17 +50,17 @@ public class NeuralNetworkTest {
neuralNetworkRoot.setOutputLayer(outputLayer); neuralNetworkRoot.setOutputLayer(outputLayer);
// Current value is 1, so return value is 1 * 4 * 0.6 = 2.4, 1 * 4 * 0.4 = 1.6, both are less than 4. so 0. // Current value is 1, so return value is 1 * 4 * 0.6 = 2.4, 1 * 4 * 0.4 = 1.6, both are less than 4. so 0.
Leaf leaf = neuralNetworkRoot.classify(); Leaf leaf = neuralNetworkRoot.internalClassify();
assertLeafEqual(0, leaf); assertLeafEqual(0, leaf);
// Current value is 2, so return value is 2 * 4 * 0.6 = 4.8, 2 * 4 * 0.4 = 3.2, first is less than 4. so 1. // Current value is 2, so return value is 2 * 4 * 0.6 = 4.8, 2 * 4 * 0.4 = 3.2, first is less than 4. so 1.
mai.item.setState(2); mai.item.setState(2);
leaf = neuralNetworkRoot.classify(); leaf = neuralNetworkRoot.internalClassify();
assertLeafEqual(1, leaf); assertLeafEqual(1, leaf);
// Current value is 5, so return value is 5 * 4 * 0.6 = 12, 5 * 4 * 0.4 = 8, both are greater than 4. so 3. // Current value is 5, so return value is 5 * 4 * 0.6 = 12, 5 * 4 * 0.4 = 8, both are greater than 4. so 3.
mai.item.setState(5); mai.item.setState(5);
leaf = neuralNetworkRoot.classify(); leaf = neuralNetworkRoot.internalClassify();
assertLeafEqual(3, leaf); assertLeafEqual(3, leaf);
} }
......
...@@ -3,6 +3,8 @@ package de.tudresden.inf.st.eraser.feedbackloop.analyze; ...@@ -3,6 +3,8 @@ package de.tudresden.inf.st.eraser.feedbackloop.analyze;
import de.tudresden.inf.st.eraser.feedbackloop.api.Analyze; import de.tudresden.inf.st.eraser.feedbackloop.api.Analyze;
import de.tudresden.inf.st.eraser.feedbackloop.api.Plan; import de.tudresden.inf.st.eraser.feedbackloop.api.Plan;
import de.tudresden.inf.st.eraser.jastadd.model.Activity; import de.tudresden.inf.st.eraser.jastadd.model.Activity;
import de.tudresden.inf.st.eraser.jastadd.model.ItemPreference;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningResult;
import de.tudresden.inf.st.eraser.jastadd.model.Root; import de.tudresden.inf.st.eraser.jastadd.model.Root;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
...@@ -17,10 +19,12 @@ public class AnalyzeImpl implements Analyze { ...@@ -17,10 +19,12 @@ public class AnalyzeImpl implements Analyze {
private Root knowledgeBase; private Root knowledgeBase;
private Plan plan; private Plan plan;
private Activity mostRecentActivity; private Activity mostRecentActivity;
private MachineLearningResult mostRecentPreferences;
private Logger logger = LogManager.getLogger(AnalyzeImpl.class); private Logger logger = LogManager.getLogger(AnalyzeImpl.class);
public AnalyzeImpl() { public AnalyzeImpl() {
this.mostRecentActivity = null; this.mostRecentActivity = null;
this.mostRecentPreferences = null;
} }
@Override @Override
...@@ -40,17 +44,32 @@ public class AnalyzeImpl implements Analyze { ...@@ -40,17 +44,32 @@ public class AnalyzeImpl implements Analyze {
@Override @Override
public void analyzeLatestChanges() { public void analyzeLatestChanges() {
MachineLearningResult recognitionResult = knowledgeBase.getMachineLearningRoot().getActivityRecognition().getDecoder().classify();
recognitionResult.getItemPreferences().forEach(ItemPreference::apply);
knowledgeBase.currentActivity().ifPresent(activity -> { knowledgeBase.currentActivity().ifPresent(activity -> {
MachineLearningResult newMLResult = knowledgeBase.getMachineLearningRoot().getPreferenceLearning().classify();
// check if activity has changed
if (!activity.equals(mostRecentActivity)) { if (!activity.equals(mostRecentActivity)) {
// new! inform plan! // new! inform plan!
logger.info("Found new activity '{}'", activity.getLabel()); logger.info("Found new activity '{}'", activity.getLabel());
mostRecentActivity = activity; try {
informPlan(activity);
} catch (Exception e) {
logger.catching(e);
}
} else {
// if no change, also check, if preferences have changed
if (!newMLResult.equals(mostRecentPreferences)) {
logger.info("Preferences have changed for same activity '{}'", activity.getLabel());
try { try {
informPlan(activity); informPlan(activity);
} catch (Exception e) { } catch (Exception e) {
logger.catching(e); logger.catching(e);
} }
} }
}
mostRecentActivity = activity;
mostRecentPreferences = newMLResult;
}); });
} }
} }
...@@ -22,7 +22,7 @@ public interface Execute { ...@@ -22,7 +22,7 @@ public interface Execute {
void setKnowledgeBase(Root knowledgeBase); void setKnowledgeBase(Root knowledgeBase);
/** /**
* <b>Deprecated</b>: Use {@link #updateItems(List)} instead. * <b>Deprecated</b>: Use {@link #updateItems(Iterable)} instead.
* @param brightnessAndRgbForItems Map, keys are item names, values are RGB and brightness values * @param brightnessAndRgbForItems Map, keys are item names, values are RGB and brightness values
*/ */
@Deprecated @Deprecated
...@@ -32,5 +32,5 @@ public interface Execute { ...@@ -32,5 +32,5 @@ public interface Execute {
* Updates items according to given preferences * Updates items according to given preferences
* @param preferences tuples containing item and its new HSB value * @param preferences tuples containing item and its new HSB value
*/ */
void updateItems(List<ItemPreference> preferences); void updateItems(Iterable<ItemPreference> preferences);
} }
...@@ -21,7 +21,7 @@ public interface Plan { ...@@ -21,7 +21,7 @@ public interface Plan {
void planToMatchPreferences(Activity activity); void planToMatchPreferences(Activity activity);
default void informExecute(List<ItemPreference> preferences) { default void informExecute(Iterable<ItemPreference> preferences) {
getExecute().updateItems(preferences); getExecute().updateItems(preferences);
} }
} }
...@@ -54,7 +54,7 @@ public class ExecuteImpl implements Execute { ...@@ -54,7 +54,7 @@ public class ExecuteImpl implements Execute {
} }
@Override @Override
public void updateItems(List<ItemPreference> preferences) { public void updateItems(Iterable<ItemPreference> preferences) {
for (ItemPreference preference : preferences) { for (ItemPreference preference : preferences) {
preference.apply(); preference.apply();
} }
......
...@@ -226,7 +226,7 @@ public class Main { ...@@ -226,7 +226,7 @@ public class Main {
List<String> output = new ArrayList<>(); List<String> output = new ArrayList<>();
Function<DoubleNumber, String> leafToString = classification -> Double.toString(classification.number); Function<DoubleNumber, String> leafToString = classification -> Double.toString(classification.number);
Function<NeuralNetworkRoot, DoubleNumber> classify = NeuralNetworkRoot::classify; Function<NeuralNetworkRoot, DoubleNumber> classify = NeuralNetworkRoot::internalClassify;
DoubleNumber classification = classify.apply(nn); DoubleNumber classification = classify.apply(nn);
output.add(leafToString.apply(classification)); output.add(leafToString.apply(classification));
System.out.println(output); System.out.println(output);
......
...@@ -160,21 +160,27 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn ...@@ -160,21 +160,27 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
@Override @Override
public MachineLearningResult classify() { public MachineLearningResult classify() {
MachineLearningResult result = new MachineLearningResult();
switch (this.goal) { switch (this.goal) {
case GOAL_ACTIVITY_PHONE_AND_WATCH: case GOAL_ACTIVITY_PHONE_AND_WATCH:
String activityStringValue = activity_result; String activityStringValue = activity_result;
if (activityStringValue == null) {
return result;
}
Item activityItem = resolve(this.root.getSmartHomeEntityModel(), "activity"); Item activityItem = resolve(this.root.getSmartHomeEntityModel(), "activity");
//activityItem.setStateFromString(activityStringValue); //activityItem.setStateFromString(activityStringValue);
// FIXME how to translate activityStringValue to a number? or should activity item state better be a String? // FIXME how to translate activityStringValue to a number? or should activity item state better be a String?
for (int i=0; i< activites.length;i++){ ItemPreference classifiedActivity = null;
if(activites[i].equals(activityStringValue)){ for (Activity activity : this.root.getMachineLearningRoot().getActivityList()) {
activityItem.setStateFromString(String.valueOf(i)); if (activity.getLabel().equals(activityStringValue)) {
ItemPreference classifiedActivity = new ItemPreferenceDouble(activityItem,i); classifiedActivity = new ItemPreferenceDouble(activityItem, activity.getIdentifier());
return new MachineLearningResultImpl(classifiedActivity);
} }
logger.debug("Classify would return activity: {}", activityStringValue); logger.debug("Classify would return activity: {}", activityStringValue);
} }
if (classifiedActivity != null) {
result.addItemPreference(classifiedActivity);
}
return result;
case GOAL_PREFERENCE_BRIGHTNESS_IRIS: case GOAL_PREFERENCE_BRIGHTNESS_IRIS:
// String[] preference = {result[1], result[2]}; // String[] preference = {result[1], result[2]};
// FIXME what is the meaning of result[1] and result[2] // FIXME what is the meaning of result[1] and result[2]
...@@ -185,11 +191,11 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn ...@@ -185,11 +191,11 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
color = Math.round(Float.valueOf(preference_result[0])); color = Math.round(Float.valueOf(preference_result[0]));
brightness = Math.round(Float.valueOf(preference_result[1])); brightness = Math.round(Float.valueOf(preference_result[1]));
} }
ItemPreference classifiedPreference = new ItemPreferenceColor(iris1, TupleHSB.of(color, 100, brightness)); result.addItemPreference(new ItemPreferenceColor(iris1, TupleHSB.of(color, 100, brightness)));
return new MachineLearningResultImpl(classifiedPreference); return result;
default: default:
logger.error("Unknown goal value ({}) set in classify", this.goal); logger.error("Unknown goal value ({}) set in classify", this.goal);
return new EmptyMachineLearningResult(); return result;
} }
} }
...@@ -202,12 +208,4 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn ...@@ -202,12 +208,4 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
learner.preference_train(filenameOfCsv); learner.preference_train(filenameOfCsv);
} }
class EmptyMachineLearningResult implements MachineLearningResult {
@Override
public List<ItemPreference> getPreferences() {
return Collections.emptyList();
}
}
} }
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.jastadd.model.ItemPreference;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningResult;
import java.util.Collections;
import java.util.List;
/**
* TODO: Add description.
*
* @author rschoene - Initial contribution
*/
public class MachineLearningResultImpl implements MachineLearningResult {
private final ItemPreference preference;
MachineLearningResultImpl(ItemPreference preference) {
this.preference = preference;
}
@Override
public List<ItemPreference> getPreferences() {
return Collections.singletonList(preference);
}
}
...@@ -37,14 +37,18 @@ public class PlanImpl implements Plan { ...@@ -37,14 +37,18 @@ public class PlanImpl implements Plan {
@Override @Override
public void planToMatchPreferences(Activity activity) { public void planToMatchPreferences(Activity activity) {
logger.info("Plan got new activity [{}]: {}", activity.getIdentifier(), activity.getLabel()); logger.info("Plan got new activity [{}]: {}", activity.getIdentifier(), activity.getLabel());
List<ItemPreference> preferences = knowledgeBase.currentPreferences(); MachineLearningResult mlResult = knowledgeBase.getMachineLearningRoot().getPreferenceLearning().getLastPreference();
knowledgeBase.getMachineLearningRoot().addChangeEvent(createRecognitionEvent(activity)); knowledgeBase.getMachineLearningRoot().addChangeEvent(createRecognitionEvent(activity, mlResult.getItemPreferences()));
informExecute(preferences); knowledgeBase.getMachineLearningRoot().getChangeEventList();
informExecute(mlResult.getItemPreferences());
} }
private ChangeEvent createRecognitionEvent(Activity activity) { private ChangeEvent createRecognitionEvent(Activity activity, Iterable<ItemPreference> preferences) {
RecognitionEvent result = RecognitionEvent.createRecognitionEvent(knowledgeBase.getMachineLearningRoot().getActivityRecognition()); RecognitionEvent result = RecognitionEvent.createRecognitionEvent(knowledgeBase.getMachineLearningRoot().getActivityRecognition());
result.setActivity(activity); result.setActivity(activity);
for (ItemPreference preference : preferences) {
result.addChangedItem(ChangedItem.newFrom(preference));
}
return result; return result;
} }
} }
...@@ -158,7 +158,7 @@ public class Main { ...@@ -158,7 +158,7 @@ public class Main {
hiddenNeuron.connectTo(output, 1.0/pr.hiddenNeurons.length); hiddenNeuron.connectTo(output, 1.0/pr.hiddenNeurons.length);
} }
classifyTimed(pr.nn, NeuralNetworkRoot::classify, classifyTimed(pr.nn, NeuralNetworkRoot::internalClassify,
classification -> Double.toString(classification.number)); classification -> Double.toString(classification.number));
} }
...@@ -183,7 +183,7 @@ public class Main { ...@@ -183,7 +183,7 @@ public class Main {
} }
} }
classifyTimed(pr.nn, NeuralNetworkRoot::classify, classifyTimed(pr.nn, NeuralNetworkRoot::internalClassify,
classification -> Double.toHexString(classification.number)); classification -> Double.toHexString(classification.number));
// long before = System.nanoTime(); // long before = System.nanoTime();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment