Commit 872a8543 authored by FRohde

KB<->Learner integration: added Encoder/Decoder implementations that allow KB to integrate with external model
parent f763cf0e
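For orientation, a rough sketch of how the pieces added in this commit are meant to fit together: LearnerImpl holds the Encog models, and the new handler classes expose them to the knowledge base as MachineLearningEncoder/MachineLearningDecoder. The item ids, normalization bounds, and file locations below are invented, and the snippet shows the intended call sequence rather than an executable test (resolving item ids additionally requires access to the OpenHAB2Model, see LearningHandler.resolve).

import de.tudresden.inf.st.eraser.feedbackloop.learner.ActivityLearningHandler;
import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl;
import de.tudresden.inf.st.eraser.jastadd.model.Item;
import de.tudresden.inf.st.eraser.jastadd.model.ItemPreference;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningResult;

import java.io.File;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class IntegrationSketch {
    public static void main(String[] args) {
        LearnerImpl learner = new LearnerImpl();

        // Placeholder normalization bounds and item ids; real values depend on the deployed sensors.
        List<Integer> inputMaxes  = Arrays.asList(100, 100);
        List<Integer> inputMins   = Arrays.asList(0, 0);
        List<Integer> targetMaxes = Collections.singletonList(3);
        List<Integer> targetMins  = Collections.singletonList(0);
        List<String> inputItemIds  = Arrays.asList("brightness_sensor", "motion_sensor");
        List<String> outputItemIds = Collections.singletonList("activity");

        // Model 0 is treated as the activity recognition model by the handlers below;
        // the File argument is the folder containing the persisted network "NN_0".
        learner.loadModelFromFile(new File("models"), 0,
                inputMaxes, inputMins, targetMaxes, targetMins, inputItemIds, outputItemIds);

        ActivityLearningHandler handler = new ActivityLearningHandler().setLearner(learner);

        // Encoder side: push changed items into the learner's input vectors.
        List<Item> changedItems = Collections.emptyList(); // items would come from the knowledge base
        handler.newData(changedItems);

        // Decoder side: read the classification back as item preferences.
        MachineLearningResult result = handler.classify();
        for (ItemPreference preference : result.getPreferences()) {
            System.out.println(preference);
        }
    }
}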
@@ -9,6 +9,7 @@ dependencies {
     compile group: 'net.sf.beaver', name: 'beaver-rt', version: '0.9.11'
     compile group: 'org.fusesource.mqtt-client', name: 'mqtt-client', version: '1.15'
     compile group: 'org.influxdb', name: 'influxdb-java', version: '2.15'
+    compile project(':feedbackloop.learner')
     testCompile group: 'org.testcontainers', name: 'testcontainers', version: '1.11.2'
     testCompile group: 'org.testcontainers', name: 'influxdb', version: '1.11.2'
     testCompile group: 'org.apache.logging.log4j', name: 'log4j-slf4j-impl', version: '2.11.2'
...
package de.tudresden.inf.st.eraser.feedbackloop.learner;

import de.tudresden.inf.st.eraser.jastadd.model.*;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

/**
 * Adapter that exposes the activity recognition model (model 0) of the learner to the knowledge base.
 *
 * @author rschoene - Initial contribution
 */
public class ActivityLearningHandler extends LearningHandler {

    @Override
    public ActivityLearningHandler setLearner(LearnerImpl learner) {
        return (ActivityLearningHandler) super.setLearner(learner);
    }

    @Override
    public ActivityLearningHandler setModel(InternalMachineLearningModel model) {
        return (ActivityLearningHandler) super.setModel(model);
    }

    @Override
    public List<Item> getTargets() {
        List<Item> targets = new ArrayList<>();
        List<String> itemIds = this.getLearner().getTargetItemsIdsActivityLearning();
        for (String itemId : itemIds) {
            targets.add(resolve(itemId));
        }
        return targets;
    }

    @Override
    public List<Item> getRelevantItems() {
        List<Item> relevantItems = new ArrayList<>();
        List<String> itemIds = this.getLearner().getRelevantItemsIdsActivityLearning();
        for (String itemId : itemIds) {
            relevantItems.add(resolve(itemId));
        }
        return relevantItems;
    }

    @Override
    public void triggerTraining() {
        getLogger().debug("Ignored training trigger.");
    }

    // classify using the input vector given by newData
    @Override
    public MachineLearningResult classify() {
        // The activity model has exactly one target: the activity item. It is resolved via the
        // activity-specific id list so that subclasses overriding getTargets() do not change this lookup.
        String activityItemId = this.getLearner().getTargetItemsIdsActivityLearning().get(0);
        ActivityItem activityItem = (ActivityItem) resolve(activityItemId);
        // prepare output: a single preference carrying the recognized activity value;
        // the shared getPreferenceItem helper is not used here, since it consults the
        // output index map of the preference model (model 1)
        List<ItemPreference> preferences = new ArrayList<>();
        preferences.add(new ItemPreferenceDouble(activityItem, this.getLearner().getActivity()));
        return new ExternalMachineLearningResult(preferences);
    }

    @Override
    public Instant lastModelUpdate() {
        return null;
    }
}
package de.tudresden.inf.st.eraser.feedbackloop.learner;

import de.tudresden.inf.st.eraser.jastadd.model.ItemPreference;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningResult;

import java.util.List;

/** Result wrapper that carries the preferences computed by the learner back to the knowledge base. */
public class ExternalMachineLearningResult implements MachineLearningResult {

    private final List<ItemPreference> preferences;

    public ExternalMachineLearningResult(List<ItemPreference> preferences) {
        this.preferences = preferences;
    }

    @Override
    public List<ItemPreference> getPreferences() {
        return this.preferences;
    }
}
\ No newline at end of file
@@ -2,7 +2,10 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner;
 import de.tudresden.inf.st.eraser.feedbackloop.api.EncogModel;
 import de.tudresden.inf.st.eraser.feedbackloop.api.Learner;
+import de.tudresden.inf.st.eraser.jastadd.model.ColorItem;
+import de.tudresden.inf.st.eraser.jastadd.model.Item;
 import de.tudresden.inf.st.eraser.jastadd.model.Root;
+import de.tudresden.inf.st.eraser.jastadd.model.TupleHSB;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.encog.neural.flat.FlatNetwork;
@@ -35,7 +38,7 @@ public class LearnerImpl implements Learner {
     private CSVFormat format = new CSVFormat('.', ',');
     private Map<Integer, Dataset> datasets = new HashMap<>();
     private Map<Integer, Network> models = new HashMap<>();
+    private Map<Integer, List<Double>> inputVectors = new HashMap<>();

     @Override
     public void setKnowledgeBase(Root knowledgeBase) {
@@ -66,24 +69,26 @@ public class LearnerImpl implements Learner {
     @Override
     public boolean loadModelFromFile(File file, int modelID, List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                                     List<Integer> targetMins) {
+                                     List<Integer> targetMins, List<String> listInputIndex, List<String> listOutputIndex) {
         logger.debug("Load model from file {}", file);
-        models.put(modelID, new Network(file.getAbsolutePath(), modelID, inputMaxes, inputMins, targetMaxes, targetMins));
+        models.put(modelID, new Network(file.getAbsolutePath(), modelID, inputMaxes, inputMins, targetMaxes, targetMins, listInputIndex, listOutputIndex));
+        inputVectors.put(modelID, new ArrayList<>());
         return true;
     }

     @Override
     public boolean loadModelFromFile(InputStream input, int modelID, List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                                     List<Integer> targetMins) {
+                                     List<Integer> targetMins, List<String> listInputIndex, List<String> listOutputIndex) {
         logger.debug("Load model from input stream");
-        models.put(modelID, new Network(input, modelID, inputMaxes, inputMins, targetMaxes, targetMins));
+        models.put(modelID, new Network(input, modelID, inputMaxes, inputMins, targetMaxes, targetMins, listInputIndex, listOutputIndex));
+        inputVectors.put(modelID, new ArrayList<>());
         return true;
     }

     @Override
     public boolean train(int inputCount, int outputCount, int hiddenCount, int hiddenNeuronCount, int modelID,
                          List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                         List<Integer> targetMins) {
+                         List<Integer> targetMins, List<String> listInputIndex, List<String> listOutputIndex) {
         // Method for the initial training of algorithms and models. That uses external data set for training.
         if (datasets.get(modelID) != null) {
@@ -92,7 +97,7 @@ public class LearnerImpl implements Learner {
             ReadCSV csv = set.getCsv();
             Network model = new Network(inputCount, outputCount, hiddenCount, hiddenNeuronCount, modelID, inputMaxes,
-                    inputMins, targetMaxes, targetMins);
+                    inputMins, targetMaxes, targetMins, listInputIndex, listOutputIndex);
             ArrayList<Double> input = new ArrayList<>();
             ArrayList<Double> target = new ArrayList<>();
@@ -114,6 +119,7 @@ public class LearnerImpl implements Learner {
             }
             models.put(modelID, model);
+            inputVectors.put(modelID, new ArrayList<>());
             model.saveModel(modelFolderPath);
             return true;
@@ -124,10 +130,10 @@ public class LearnerImpl implements Learner {
     @Override
     public boolean train(double[][] data, int inputCount, int outputCount, int hiddenCount, int hiddenNeuronCount, int modelID,
                          List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                         List<Integer> targetMins, List<Integer> targetColumns) {
+                         List<Integer> targetMins, List<Integer> targetColumns, List<String> listInputIndex, List<String> listOutputIndex) {
         Network model = new Network(inputCount, outputCount, hiddenCount, hiddenNeuronCount, modelID, inputMaxes,
-                inputMins, targetMaxes, targetMins);
+                inputMins, targetMaxes, targetMins, listInputIndex, listOutputIndex);
         return reTrainModel(model, data, targetColumns, modelID);
     }
@@ -163,6 +169,7 @@ public class LearnerImpl implements Learner {
         }
         models.put(modelID, model);
+        inputVectors.put(modelID, new ArrayList<>());
         model.saveModel(modelFolderPath);
         return true;
@@ -227,4 +234,72 @@ public class LearnerImpl implements Learner {
         return models.get(modelID).getNormalizersTar().get(columnNr);
     }
+
+    private void setValueInputVectorByModelId(int modelId, int index, Double value) {
+        List<Double> vector = this.inputVectors.get(modelId);
+        // the vector is created empty, so grow it on demand before setting a value
+        while (vector.size() <= index) {
+            vector.add(0.0);
+        }
+        vector.set(index, value);
+    }
+
+    private int getInputIndex(String itemId, int modelID) {
+        return this.models.get(modelID).getInputIndex(itemId);
+    }
+
+    public int getOutputIndex(String itemId, int modelID) {
+        return this.models.get(modelID).getOutputIndex(itemId);
+    }
+
+    public double getActivity() {
+        // model 0 is the activity recognition model
+        double[] outputActivityRecognition = this.models.get(0).computeResult(this.inputVectors.get(0));
+        return outputActivityRecognition[0];
+    }
+
+    public double[] getPreferencesForCurrentActivity() {
+        // model 1 is the preference learning model
+        return this.models.get(1).computeResult(this.inputVectors.get(1));
+    }
+
+    // prepare the input vector used by classify
+    private void updateInputVector(Item item, int modelId) {
+        int index = getInputIndex(item.getID(), modelId); // position of the item in the input vector
+        if (item instanceof ColorItem) { // color items have three values (HSB)
+            ColorItem colorItem = (ColorItem) item;
+            TupleHSB state = colorItem.get_state();
+            setValueInputVectorByModelId(modelId, index, Double.valueOf(state.getHue()));
+            setValueInputVectorByModelId(modelId, index + 1, Double.valueOf(state.getSaturation()));
+            setValueInputVectorByModelId(modelId, index + 2, Double.valueOf(state.getBrightness()));
+        } else {
+            setValueInputVectorByModelId(modelId, index, item.getStateAsDouble());
+        }
+    }
+
+    public void updateLearner(List<Item> changedItems) {
+        int nModels = this.models.size();
+        for (Item item : changedItems) {
+            for (int i = 0; i < nModels; i++) {
+                if (this.models.get(i).itemRelevant(item.getID())) {
+                    updateInputVector(item, i);
+                }
+            }
+        }
+    }
+
+    public List<String> getTargetItemsIdsPreferenceLearning() {
+        return this.models.get(1).getTargetItemsIds();
+    }
+
+    public List<String> getRelevantItemsIdsPreferenceLearning() {
+        return this.models.get(1).getRelevantItemsIds();
+    }
+
+    public List<String> getTargetItemsIdsActivityLearning() {
+        return this.models.get(0).getTargetItemsIds();
+    }
+
+    public List<String> getRelevantItemsIdsActivityLearning() {
+        return this.models.get(0).getRelevantItemsIds();
+    }
 }
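The new listInputIndex/listOutputIndex parameters fix the position of each item in a model's input and output vector; the handlers look those positions up via getInputIndex/getOutputIndex. A minimal sketch of loading both models with such index lists, continuing the sketch near the top of this page (same imports assumed; ids, bounds, and the "models" folder are invented, and by convention model 0 is activity recognition and model 1 is preference learning):

    static LearnerImpl loadBothModels() {
        LearnerImpl learner = new LearnerImpl();
        File modelFolder = new File("models"); // Network loads "NN_" + modelID from this folder

        // Activity model (id 0): sensor items in, one activity value out.
        learner.loadModelFromFile(modelFolder, 0,
                Arrays.asList(100, 100), Arrays.asList(0, 0),                // input maxes/mins
                Collections.singletonList(3), Collections.singletonList(0),  // target max/min
                Arrays.asList("brightness_sensor", "motion_sensor"),         // -> input indices 0, 1
                Collections.singletonList("activity"));                      // -> output index 0

        // Preference model (id 1): activity in, lamp color out; the single color target
        // occupies three output columns (hue, saturation, brightness).
        learner.loadModelFromFile(modelFolder, 1,
                Collections.singletonList(3), Collections.singletonList(0),
                Arrays.asList(360, 100, 100), Arrays.asList(0, 0, 0),
                Collections.singletonList("activity"),
                Collections.singletonList("lamp_color"));
        return learner;
    }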
package de.tudresden.inf.st.eraser.feedbackloop.learner;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.time.Instant;
import java.util.List;
/**
 * Base adapter that connects the knowledge base (MachineLearningEncoder/MachineLearningDecoder)
 * to the models held by LearnerImpl.
 *
 * @author rschoene - Initial contribution
 */
public abstract class LearningHandler implements MachineLearningEncoder, MachineLearningDecoder {

    private static final Logger logger = LogManager.getLogger(LearningHandler.class);

    private LearnerImpl learner;
    private InternalMachineLearningModel model;
    // must be set before resolve(String) is used
    private OpenHAB2Model openHAB2model;

    public LearnerImpl getLearner() {
        return this.learner;
    }

    public LearningHandler setLearner(LearnerImpl learner) {
        this.learner = learner;
        return this;
    }

    @Override
    public void newData(List<Item> changedItems) {
        // encoder side: update the input vector of every model that uses one of the changed items
        this.getLearner().updateLearner(changedItems);
    }

    public LearningHandler setModel(InternalMachineLearningModel model) {
        this.model = model;
        return this;
    }

    public static Logger getLogger() {
        return logger;
    }

    @Override
    public void setKnowledgeBaseRoot(Root root) {
        // ignored
    }

    public Item resolve(String itemId) {
        java.util.Optional<Item> maybeItem = this.openHAB2model.resolveItem(itemId);
        if (maybeItem.isPresent()) {
            return maybeItem.get();
        } else {
            logger.warn("Could not find item with id {}", itemId);
            return null;
        }
    }

    @Override
    public abstract List<Item> getTargets();

    @Override
    public abstract List<Item> getRelevantItems();

    @Override
    public void triggerTraining() {
        logger.debug("Ignored training trigger.");
    }

    // classify using the input vector given by newData
    @Override
    public abstract MachineLearningResult classify();

    /** Maps the (de-normalized) model output back to a preference for the given item. */
    public ItemPreference getPreferenceItem(Item item, double[] output_preferenceLearning) {
        ItemPreference preference;
        // the output index map of model 1 (preference learning) locates the item's slot(s) in the output vector
        int index = this.learner.getOutputIndex(item.getID(), 1);
        if (item instanceof ColorItem) {
            // color items occupy three consecutive slots: hue, saturation, brightness
            preference = new ItemPreferenceColor(item, TupleHSB.of(
                    (int) Math.round(output_preferenceLearning[index]),
                    (int) Math.round(output_preferenceLearning[index + 1]),
                    (int) Math.round(output_preferenceLearning[index + 2])));
        } else {
            preference = new ItemPreferenceDouble(item, output_preferenceLearning[index]);
        }
        return preference;
    }

    @Override
    public Instant lastModelUpdate() {
        return null;
    }
}
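getPreferenceItem maps a model output vector back to item preferences: a ColorItem consumes three consecutive slots (hue, saturation, brightness, each rounded to int), every other item a single slot located via the output index map. A self-contained toy illustration of that slot layout, with an invented item id and made-up output values:

import java.util.Arrays;

public class ColorOutputDemo {
    public static void main(String[] args) {
        // De-normalized output of the preference model for a single ColorItem target.
        double[] outputPreferenceLearning = {120.4, 79.6, 55.1};
        int index = 0; // LearnerImpl.getOutputIndex("lamp_color", 1) in the real code

        int hue = (int) Math.round(outputPreferenceLearning[index]);
        int saturation = (int) Math.round(outputPreferenceLearning[index + 1]);
        int brightness = (int) Math.round(outputPreferenceLearning[index + 2]);

        System.out.println("raw output: " + Arrays.toString(outputPreferenceLearning));
        System.out.printf("preference for lamp_color: HSB(%d, %d, %d)%n", hue, saturation, brightness);
    }
}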
@@ -18,6 +18,8 @@ import org.encog.neural.networks.training.propagation.back.Backpropagation;
 import org.encog.persist.EncogDirectoryPersistence;
 import org.encog.util.arrayutil.NormalizationAction;
 import org.encog.util.arrayutil.NormalizedField;
+import java.util.HashMap;
+import java.util.Map;
 import org.encog.util.simple.EncogUtility;

 /**
@@ -31,6 +33,8 @@ public class Network {
     private int modelID;
     private ArrayList<NormalizedField> normalizersIn;
     private ArrayList<NormalizedField> normalizersTar;
+    private Map<String, Integer> indexInputvector = new HashMap<>();
+    private Map<String, Integer> indexOutputvector = new HashMap<>();

     /**
      * Constructor for when the neural network is created from data.
@@ -44,10 +48,12 @@ public class Network {
      * @param inputMins list that contains min values of all input columns (sensors) e.g. light intensity 0
      * @param targetMaxes list that contains max values of all output columns (results) e.g. brightness 100 for preference learning
      * @param targetMins list that contains min values of all output columns (results) e.g. brightness 0 for preference learning
+     * @param listInputIndex list that contains the item IDs used to initialize the map indexInputvector
+     * @param listOutputIndex list that contains the item IDs used to initialize the map indexOutputvector
      */
     public Network(int inputCount, int outputCount, int hiddenCount, int hiddenNeuronCount, int modelID,
                    List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                   List<Integer> targetMins) {
+                   List<Integer> targetMins, List<String> listInputIndex, List<String> listOutputIndex) {
         normalizersIn = new ArrayList<>();
         normalizersTar = new ArrayList<>();
@@ -67,8 +73,36 @@ public class Network {
         addNormalizer(inputMaxes, inputMins, normalizersIn);
         addNormalizer(targetMaxes, targetMins, normalizersTar);
+        initializeMaps(listInputIndex, listOutputIndex);
     }
+
+    private void initializeMaps(List<String> listInputIndex, List<String> listOutputIndex) {
+        for (int i = 0; i < listInputIndex.size(); i++) {
+            this.indexInputvector.put(listInputIndex.get(i), i);
+        }
+        for (int i = 0; i < listOutputIndex.size(); i++) {
+            this.indexOutputvector.put(listOutputIndex.get(i), i);
+        }
+    }
+
+    public List<String> getTargetItemsIds() {
+        return new ArrayList<>(this.indexOutputvector.keySet());
+    }
+
+    public List<String> getRelevantItemsIds() {
+        return new ArrayList<>(this.indexInputvector.keySet());
+    }
+
+    public int getInputIndex(String itemId) {
+        return indexInputvector.get(itemId);
+    }
+
+    public int getOutputIndex(String itemId) {
+        return indexOutputvector.get(itemId);
+    }

     private void addNormalizer(List<Integer> maxes, List<Integer> mins, ArrayList<NormalizedField> normalizers) {
         for (int j = 0; j < maxes.size(); j++) {
             NormalizedField normalizer = new NormalizedField("in_" + j, NormalizationAction.Normalize,
@@ -87,10 +121,12 @@
      * @param inputMins list that contains min values of all input columns (sensors) e.g. light intensity 0
      * @param targetMaxes list that contains max values of all output columns (results) e.g. brightness 100 for preference learning
      * @param targetMins list that contains min values of all output columns (results) e.g. brightness 0 for preference learning
+     * @param listInputIndex list that contains the item IDs used to initialize the map indexInputvector
+     * @param listOutputIndex list that contains the item IDs used to initialize the map indexOutputvector
      */
     public Network(String path, int modelID, List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                   List<Integer> targetMins) {
-        this(() -> (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(path, "NN_" + modelID)), modelID, inputMaxes, inputMins, targetMaxes, targetMins);
+                   List<Integer> targetMins, List<String> listInputIndex, List<String> listOutputIndex) {
+        this(() -> (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(path, "NN_" + modelID)), modelID, inputMaxes, inputMins, targetMaxes, targetMins, listInputIndex, listOutputIndex);
     }

     /**
@@ -103,14 +139,16 @@
      * @param inputMins list that contains min values of all input columns (sensors) e.g. light intensity 0
      * @param targetMaxes list that contains max values of all output columns (results) e.g. brightness 100 for preference learning
      * @param targetMins list that contains min values of all output columns (results) e.g. brightness 0 for preference learning
+     * @param listInputIndex list that contains the item IDs used to initialize the map indexInputvector
+     * @param listOutputIndex list that contains the item IDs used to initialize the map indexOutputvector
      */
     public Network(InputStream input, int modelID, List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                   List<Integer> targetMins) {
-        this(() -> (BasicNetwork) EncogDirectoryPersistence.loadObject(input), modelID, inputMaxes, inputMins, targetMaxes, targetMins);
+                   List<Integer> targetMins, List<String> listInputIndex, List<String> listOutputIndex) {
+        this(() -> (BasicNetwork) EncogDirectoryPersistence.loadObject(input), modelID, inputMaxes, inputMins, targetMaxes, targetMins, listInputIndex, listOutputIndex);
     }

     private Network(LoadEncogModel loader, int modelID, List<Integer> inputMaxes, List<Integer> inputMins, List<Integer> targetMaxes,
-                    List<Integer> targetMins) {
+                    List<Integer> targetMins, List<String> listInputIndex, List<String> listOutputIndex) {
         this.modelID = modelID;
         normalizersIn = new ArrayList<>();
@@ -120,6 +158,7 @@
         addNormalizer(inputMaxes, inputMins, normalizersIn);
         addNormalizer(targetMaxes, targetMins, normalizersTar);
+        initializeMaps(listInputIndex, listOutputIndex);
     }

     @FunctionalInterface
@@ -199,4 +238,8 @@
     public ArrayList<NormalizedField> getNormalizersTar() {
         return normalizersTar;
     }
+
+    public boolean itemRelevant(String itemId) {
+        return this.indexInputvector.containsKey(itemId);
+    }
 }
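initializeMaps turns the order of listInputIndex/listOutputIndex into the vector positions that getInputIndex, getOutputIndex, and itemRelevant report. A self-contained toy illustration of that mapping, with invented item ids:

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class IndexMapDemo {
    public static void main(String[] args) {
        // Mirrors Network.initializeMaps: the position of an id in listInputIndex
        // becomes its index in the model's input vector.
        List<String> listInputIndex = Arrays.asList("brightness_sensor", "motion_sensor", "lamp_color");
        Map<String, Integer> indexInputvector = new HashMap<>();
        for (int i = 0; i < listInputIndex.size(); i++) {
            indexInputvector.put(listInputIndex.get(i), i);
        }

        // Network.getInputIndex("motion_sensor") -> 1
        System.out.println("motion_sensor -> input index " + indexInputvector.get("motion_sensor"));
        // Network.itemRelevant(...) is just a containsKey check on this map
        System.out.println("door_sensor relevant? " + indexInputvector.containsKey("door_sensor"));
    }
}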
package de.tudresden.inf.st.eraser.feedbackloop.learner;

import de.tudresden.inf.st.eraser.jastadd.model.*;

import java.util.ArrayList;
import java.util.List;

/**
 * Adapter that exposes the preference learning model (model 1) of the learner to the knowledge base.
 * Classification chains the activity recognition model into the preference model.
 *
 * @author rschoene - Initial contribution
 */
public class PreferenceLearningHandler extends ActivityLearningHandler {

    @Override
    public PreferenceLearningHandler setLearner(LearnerImpl learner) {
        return (PreferenceLearningHandler) super.setLearner(learner);
    }

    @Override
    public PreferenceLearningHandler setModel(InternalMachineLearningModel model) {
        return (PreferenceLearningHandler) super.setModel(model);
    }

    @Override
    public List<Item> getTargets() {
        List<Item> targets = new ArrayList<>();
        List<String> itemIds = this.getLearner().getTargetItemsIdsPreferenceLearning();
        for (String itemId : itemIds) {
            targets.add(resolve(itemId));
        }
        return targets;
    }

    @Override
    public List<Item> getRelevantItems() {
        List<Item> relevantItems = new ArrayList<>();
        List<String> itemIds = this.getLearner().getRelevantItemsIdsPreferenceLearning();
        for (String itemId : itemIds) {
            relevantItems.add(resolve(itemId));
        }
        return relevantItems;
    }

    @Override
    public void triggerTraining() {
        getLogger().debug("Ignored training trigger.");
    }

    // classify using the input vector given by newData
    @Override
    public MachineLearningResult classify() {
        // Run activity recognition first; its single target is the activity item.
        MachineLearningResult activityResult = super.classify();
        getLogger().debug("Activity recognition result: {}", activityResult.getPreferences());
        ActivityItem currentActivity = (ActivityItem) super.getTargets().get(0);
        // update the learner's state so the preference model sees the current activity
        List<Item> changedItems = new ArrayList<>();
        changedItems.add(currentActivity);
        this.getLearner().updateLearner(changedItems);
        // run preference learning on the updated input vector
        double[] outputPreferenceLearning = this.getLearner().getPreferencesForCurrentActivity();
        // prepare output: one preference per target item of the preference model
        List<ItemPreference> preferences = new ArrayList<>();
        List<Item> targets = getTargets();
        for (Item item : targets) {
            preferences.add(getPreferenceItem(item, outputPreferenceLearning));
        }
        return new ExternalMachineLearningResult(preferences);
    }
}
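A short usage fragment for this handler, continuing the sketch near the top of this page (same imports assumed; "learner" already has model 0 and model 1 loaded, and "changedItems" comes from the knowledge base):

    PreferenceLearningHandler preferenceHandler = new PreferenceLearningHandler().setLearner(learner);

    // whenever items change in the knowledge base:
    preferenceHandler.newData(changedItems);                     // encoder: refresh the input vectors
    MachineLearningResult result = preferenceHandler.classify(); // decoder: activity model feeds the preference model
    for (ItemPreference preference : result.getPreferences()) {
        System.out.println(preference); // hand the preference back to the knowledge base / executor
    }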