Skip to content
Snippets Groups Projects
Commit bafabb58 authored by René Schöne's avatar René Schöne
Browse files

Restructure Learner to only handle one goal.

- Introduce JSON definition of inputs and outputs, not final yet
- Introduce LearnerTestSettings to avoid long parameter list
parent cc5c9cde
No related branches found
No related tags found
1 merge request!19dev to master
Pipeline #4716 failed
Showing
with 490 additions and 238 deletions
package de.tudresden.inf.st.eraser.jastadd.model;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.net.URL;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
/**
* Factory to create new handlers ({@link MachineLearningEncoder} and {@link MachineLearningDecoder})
......@@ -21,7 +28,7 @@ public abstract class MachineLearningHandlerFactory implements MachineLearningSe
this.knowledgeBase = root;
}
public abstract void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl);
public abstract void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) throws IOException;
public abstract MachineLearningEncoder createEncoder();
......@@ -46,4 +53,72 @@ public abstract class MachineLearningHandlerFactory implements MachineLearningSe
public void shutdown() {
// empty by default
}
/**
* Creates a new factory to be used, if there was an error during configuration of a factory
* @return a new factory logging warning messages upon each invoked method
*/
public static MachineLearningHandlerFactory createErrorFactory() {
return new MachineLearningHandlerFactory() {
private final Logger logger = LogManager.getLogger(MachineLearningHandlerFactory.class);
@Override
public void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) {
logger.warn("initializeFor called for ErrorFactory");
}
@Override
public MachineLearningEncoder createEncoder() {
return new MachineLearningEncoder() {
@Override
public void newData(List<Item> changedItems) {
logger.warn("newData called for encoder of ErrorFactory");
}
@Override
public List<Item> getTargets() {
logger.warn("getTargets called for encoder of ErrorFactory");
return Collections.emptyList();
}
@Override
public List<Item> getRelevantItems() {
logger.warn("getRelevantItems called for encoder of ErrorFactory");
return Collections.emptyList();
}
@Override
public void triggerTraining() {
logger.warn("triggerTraining called for encoder of ErrorFactory");
}
@Override
public void setKnowledgeBaseRoot(Root root) {
logger.warn("setKnowledgeBaseRoot called for encoder of ErrorFactory");
}
};
}
@Override
public MachineLearningDecoder createDecoder() {
return new MachineLearningDecoder() {
@Override
public MachineLearningResult classify() {
logger.warn("classify called for decoder of ErrorFactory");
return new MachineLearningResult();
}
@Override
public Instant lastModelUpdate() {
logger.warn("lastModelUpdate called for decoder of ErrorFactory");
return null;
}
@Override
public void setKnowledgeBaseRoot(Root root) {
logger.warn("setKnowledgeBaseRoot called for decoder of ErrorFactory");
}
};
}
};
}
}
......@@ -223,23 +223,25 @@ public class EraserStarter {
}
private static MachineLearningHandlerFactory createFactory(MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget target, Setting.MLContainer config, Root root) {
MachineLearningHandlerFactory factory = new DummyMachineLearningHandlerFactory();
MachineLearningHandlerFactory factory;
String niceTargetName = target.toString().toLowerCase().replace("_", " ");
if (config.dummy || config.factory == null) {
logger.info("Using dummy {}, ignoring other settings for this", niceTargetName);
factory = new DummyMachineLearningHandlerFactory();
} else {
try {
Class<? extends MachineLearningHandlerFactory> clazz = Class.forName(config.factory)
.asSubclass(MachineLearningHandlerFactory.class);
factory = clazz.newInstance();
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
logger.error("Could not instantiate machine learning factory for {} with class '{}'. Using dummy instead.",
factory.setKnowledgeBaseRoot(root);
factory.initializeFor(target, config.realURL());
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException | IOException e) {
logger.error("Could not instantiate machine learning factory for {} with class '{}'. Using error factory instead.",
niceTargetName, config);
logger.catching(e);
factory = MachineLearningHandlerFactory.createErrorFactory();
}
}
factory.setKnowledgeBaseRoot(root);
factory.initializeFor(target, config.realURL());
return factory;
}
}
../../../../feedbackloop.learner_backup/src/test/resources/activity_data.csv
\ No newline at end of file
../../../../feedbackloop.learner_backup/src/main/resources/activity_definition.json
\ No newline at end of file
......@@ -26,7 +26,7 @@ load:
activity:
factory: de.tudresden.inf.st.eraser.feedbackloop.learner_backup.MachineLearningHandlerFactoryImpl
# File to read in. Expected format depends on factory
file: ../datasets/backup/activity_data.csv
file: ../datasets/backup/activity_definition.json
external: true
# Use dummy model in which the current activity is directly editable. Default: false.
dummy: false
......
apply plugin: 'application'
plugins {
id 'application'
id 'io.franzbecker.gradle-lombok' version '3.0.0'
}
dependencies {
compile project(':eraser-base')
......
......@@ -9,17 +9,12 @@ import java.util.Random;
public class DummyPreference {
private static String activity; //Activity: walking, reading, working, dancing, lying, getting up
private static String watch_brightness; //dark: <45; dimmer 45-70; bright >70;
private static String light_color_openhab_H; //red 7; green 120; blue 240; yellow 60; sky blue 180; purple 300;
private static String brightness_output; //1-100**/
private static Random random = new Random();
public static void main(String[] args) {
creator();
}
static void creator(){
private static void creator(){
try {
FileWriter writer = new FileWriter("datasets/backup/preference_data.csv",true);
CSVWriter csv_writer = new CSVWriter(writer, ',',
......@@ -27,7 +22,6 @@ public class DummyPreference {
CSVWriter.DEFAULT_ESCAPE_CHARACTER,
CSVWriter.DEFAULT_LINE_END);
//activity="walking" green
activity ="walking";
......@@ -40,12 +34,25 @@ public class DummyPreference {
csv_writer.writeAll(generator("getting up","yellow"));
csv_writer.close();
writer.close();
}catch (IOException e){e.printStackTrace();}
} catch (IOException e) {
e.printStackTrace();
}
static List<String[]> generator(String activity_input, String color){
List<String[]> data = new ArrayList<String[]>();
}
/**
* Generate random data
* @param activity_input input for activity
* @param color red 7; green 120; blue 240; yellow 60; sky blue 180; purple 300;
* @return generated data
*/
private static List<String[]> generator(String activity_input, String color){
List<String[]> data = new ArrayList<>();
//dark: <45; dimmer 45-70; bright >70;
//1-100**/
String brightness_output;
String watch_brightness;
activity = activity_input;
light_color_openhab_H =color;
//
//100 walking with different lighting intensity
for (int i=0; i<100; i++) {
String[] add_data = new String[4];
......@@ -54,10 +61,10 @@ public class DummyPreference {
if (brightness<45) {
watch_brightness = "dark";
brightness_output ="100";
}else if(45<=brightness && brightness<200){
} else if(brightness < 200) {
watch_brightness = "dimmer";
brightness_output ="40";
}else if( 200<=brightness && brightness<1000){
} else if(brightness < 1000) {
watch_brightness = "medium";
brightness_output ="70";
} else {
......@@ -66,7 +73,7 @@ public class DummyPreference {
}
add_data[0] = activity;
add_data[1] = watch_brightness;
add_data[2] = light_color_openhab_H;
add_data[2] = color;
add_data[3] = brightness_output;
data.add(add_data);
}
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import com.fasterxml.jackson.databind.ObjectMapper;
import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data.LearnerSettings;
import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data.LearnerSettingsColumnDefinition;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.encog.Encog;
......@@ -13,145 +16,106 @@ import org.encog.ml.data.versatile.sources.VersatileDataSource;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.neural.networks.BasicNetwork;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.csv.CSVFormat;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import static org.encog.persist.EncogDirectoryPersistence.loadObject;
import static org.encog.persist.EncogDirectoryPersistence.saveObject;
/**
* Internal class for Neural Networks using Encog.
* Expects either a CSV file with data points to learn from,
* or serialized versions of a trained network and its normalization helpers.
*
* @author rschoene - Initial contribution
*/
public class Learner {
private File modelFile;
private NormalizationHelper normalizationHelper;
private BasicNetwork network;
private static final Logger logger = LogManager.getLogger(Learner.class);
private final LearnerSettings settings;
/**
* intial train
* Creates a new learner object using the configuration at the given URL.
* @param configURL the location of the configuration
* @throws IOException if an error occurs in {@link ObjectMapper#readValue(URL, Class)}
*/
private File save_activity_model_file;
private File save_preference_model_file;
private File csv_file;
private VersatileMLDataSet a_data;
private VersatileMLDataSet p_data;
private EncogModel a_model;
private EncogModel p_model;
private NormalizationHelper activity_helper;
private NormalizationHelper preference_helper;
private BasicNetwork a_best_method;
private BasicNetwork p_best_method;
private final Logger logger = LogManager.getLogger(Learner.class);
public Learner() {
try {
save_activity_model_file = File.createTempFile("activity_model", "eg");
} catch (IOException e) {
// use local alternative
save_activity_model_file = new File("activity_model.eq");
Learner(URL configURL) throws IOException {
this(new ObjectMapper().readValue(configURL, LearnerSettings.class));
}
Learner(LearnerSettings settings) {
this.settings = settings;
try {
save_preference_model_file = File.createTempFile("preference_model", "eg");
modelFile = File.createTempFile(settings.name + "_model", "eg");
} catch (IOException e) {
// use local alternative
save_preference_model_file = new File("preference_model.eg");
}
save_activity_model_file.deleteOnExit();
save_preference_model_file.deleteOnExit();
}
private void activityDataAnalyser(String activity_csv_url) {
VersatileDataSource a_source;
String csv_url_activity;
csv_url_activity = activity_csv_url;
this.csv_file = new File(csv_url_activity);
a_source = new CSVDataSource(csv_file, true, CSVFormat.DECIMAL_POINT);
a_data = new VersatileMLDataSet(a_source);
String[] activity_inputs = {"m_accel_x", "m_accel_y", "m_accel_z",
"m_rotation_x", "m_rotation_y", "m_rotation_z",
"w_accel_x", "w_accel_y", "w_accel_z",
"w_rotation_x", "w_rotation_y", "w_rotation_z"
};
for (int i = 0; i < activity_inputs.length; i++) {
a_data.defineSourceColumn(activity_inputs[i], i, ColumnType.continuous);
modelFile = new File(settings.name + "_model.eg");
}
ColumnDefinition outputColumn = a_data.defineSourceColumn("labels", 12, ColumnType.nominal);
a_data.defineSingleOutputOthersInput(outputColumn);
a_data.analyze();
a_model = new EncogModel(a_data);
a_model.selectMethod(a_data, MLMethodFactory.TYPE_FEEDFORWARD);
a_data.normalize();
activity_helper = a_data.getNormHelper();
modelFile.deleteOnExit();
}
private void preferenceDataAnalyser(String preference_csv_url) {
VersatileDataSource p_source;
String csv_url_preference;
csv_url_preference = preference_csv_url;
this.csv_file = new File(csv_url_preference);
p_source = new CSVDataSource(csv_file, true, CSVFormat.DECIMAL_POINT);
p_data = new VersatileMLDataSet(p_source);
p_data.defineSourceColumn("activity", 0, ColumnType.nominal);
p_data.defineSourceColumn("w_brightness", 1, ColumnType.nominal);
ColumnDefinition outputColumn1 = p_data.defineSourceColumn("label1", 2, ColumnType.continuous);
ColumnDefinition outputColumn2 = p_data.defineSourceColumn("label2", 3, ColumnType.continuous);
ColumnDefinition[] outputs = new ColumnDefinition[2];
outputs[0] = outputColumn1;
outputs[1] = outputColumn2;
p_data.defineMultipleOutputsOthersInput(outputs);
p_data.analyze();
p_model = new EncogModel(p_data);
p_model.selectMethod(p_data, MLMethodFactory.TYPE_FEEDFORWARD);
p_data.normalize();
preference_helper = p_data.getNormHelper();
/**
* Begin training using the training set specified in settings.
* @throws MalformedURLException if the location of the training set in the settings is malformed
*/
void train() throws MalformedURLException {
URL location = new File(settings.initialDataFile).toURI().toURL();
train(location);
}
void activity_train(String activity_csv_url) {
logger.info("Activity training is beginning ... ...");
activityDataAnalyser(activity_csv_url);
a_model.holdBackValidation(0.3, true, 1001);
a_model.selectTrainingType(a_data);
a_best_method = (BasicNetwork) a_model.crossvalidate(5, true);
saveEncogModel(save_activity_model_file);
logger.info("Activity training is finished ... ...");
/**
* Begin training with the given initial training set.
* @param location the location of the training set
*/
void train(URL location) {
logger.info("Training for {} begins using {}", settings.name, location);
VersatileDataSource source;
File csvFile = new File(location.getFile());
source = new CSVDataSource(csvFile, true, CSVFormat.DECIMAL_POINT);
VersatileMLDataSet data = new VersatileMLDataSet(source);
final int inputSize = settings.inputColumns.size();
for (int index = 0; index < inputSize; index++) {
LearnerSettingsColumnDefinition columnDefinition = settings.inputColumns.get(index);
data.defineSourceColumn(columnDefinition.name, index, columnDefinition.type);
}
void preference_train(String prefence_csv_url) {
logger.info("Preference training is beginning ... ...");
preferenceDataAnalyser(prefence_csv_url);
p_model.holdBackValidation(0.3, true, 1001);
p_model.selectTrainingType(p_data);
p_best_method = (BasicNetwork) p_model.crossvalidate(5, true);
saveEncogModel(save_preference_model_file);
logger.info("Preference training is finished ... ...");
List<ColumnDefinition> targets = new ArrayList<>();
for (int targetIndex = 0; targetIndex < settings.targetColumns.size(); targetIndex++) {
LearnerSettingsColumnDefinition columnDefinition = settings.targetColumns.get(targetIndex);
targets.add(data.defineSourceColumn(columnDefinition.name, inputSize + targetIndex, columnDefinition.type));
}
String activity_predictor(String[] new_data) {
String activity_result;
MLData input = activity_helper.allocateInputVector();
activity_helper.normalizeInputVector(new_data, input.getData(), false);
MLData output = a_best_method.compute(input);
activity_result = activity_helper.denormalizeOutputVectorToString(output)[0];
logger.debug("Activity Predictor result is: {}", activity_result);
return activity_result;
if (targets.size() == 1) {
data.defineSingleOutputOthersInput(targets.get(0));
} else {
data.defineMultipleOutputsOthersInput(targets.toArray(new ColumnDefinition[0]));
}
String[] preference_predictor(String[] new_data) {
MLData input = preference_helper.allocateInputVector();
preference_helper.normalizeInputVector(new_data, input.getData(), false);
MLData output = p_best_method.compute(input);
String[] preference_result = new String[] {
preference_helper.denormalizeOutputVectorToString(output)[0],
preference_helper.denormalizeOutputVectorToString(output)[1]
};
logger.debug("Preference Predictor result, Color: {}, brightness: {}",
Math.round(Float.parseFloat(preference_result[0])), Math.round(Float.parseFloat(preference_result[1])));
return preference_result;
data.analyze();
EncogModel model = new EncogModel(data);
model.selectMethod(data, settings.trainingMethod);
data.normalize();
normalizationHelper = data.getNormHelper();
model.holdBackValidation(settings.validationPercent, settings.shuffleForValidation, settings.validationSeed);
model.selectTrainingType(data);
network = (BasicNetwork) model.crossvalidate(settings.validationFolds, settings.shuffleForValidation);
EncogDirectoryPersistence.saveObject(modelFile, network);
logger.info("Training for {} finished", settings.name);
}
private void saveEncogModel(File modelFile) {
if (modelFile.equals(save_activity_model_file)) {
saveObject(modelFile, this.a_best_method);
} else {
saveObject(modelFile, this.p_best_method);
}
String[] predictor(String[] newData) {
String[] result;
MLData input = normalizationHelper.allocateInputVector();
normalizationHelper.normalizeInputVector(newData, input.getData(), false);
MLData output = network.compute(input);
result = normalizationHelper.denormalizeOutputVectorToString(output);
logger.debug("Result prediction for {} applied on {} is {}:", settings.name, newData, result);
return result;
}
void shutdown() {
......
......@@ -4,6 +4,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningDecoder;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningEncoder;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory;
import java.io.IOException;
import java.net.URL;
/**
......@@ -14,24 +15,22 @@ import java.net.URL;
public class MachineLearningHandlerFactoryImpl extends MachineLearningHandlerFactory {
private MachineLearningImpl handler;
private static Learner learner = new Learner();
@Override
public void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) {
public void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) throws IOException {
switch (target) {
case ACTIVITY_RECOGNITION:
handler = new MachineLearningImpl(learner, MachineLearningImpl.GOAL_ACTIVITY_PHONE_AND_WATCH);
handler = new MachineLearningImpl(MachineLearningImpl.GOAL_ACTIVITY_PHONE_AND_WATCH, configUrl);
handler.setKnowledgeBaseRoot(knowledgeBase);
handler.initActivities(configUrl.getFile());
break;
case PREFERENCE_LEARNING:
handler = new MachineLearningImpl(learner, MachineLearningImpl.GOAL_PREFERENCE_BRIGHTNESS_IRIS);
handler = new MachineLearningImpl(MachineLearningImpl.GOAL_PREFERENCE_BRIGHTNESS_IRIS, configUrl);
handler.setKnowledgeBaseRoot(knowledgeBase);
handler.initPreferences(configUrl.getFile());
break;
default:
throw new UnsupportedOperationException("Target " + target + " is not supported");
}
handler.startTraining();
}
@Override
......@@ -45,6 +44,8 @@ public class MachineLearningHandlerFactoryImpl extends MachineLearningHandlerFac
}
public void shutdown() {
learner.shutdown();
if (handler != null) {
handler.getLearner().shutdown();
}
}
}
......@@ -4,6 +4,9 @@ import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.time.Instant;
import java.util.*;
import java.util.stream.Collectors;
......@@ -29,20 +32,17 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
private Root root;
private String activity_result;
private Instant lastModelUpdate;
public MachineLearningImpl(Learner learner, int goal) {
this.learner = learner;
public MachineLearningImpl(int goal, URL configURL) throws IOException {
this.learner = new Learner(configURL);
this.goal = goal;
}
@Override
public void setKnowledgeBaseRoot(Root root) {
this.root = root;
updateItems();
}
private void updateItems() {
SmartHomeEntityModel model = root.getSmartHomeEntityModel();
SmartHomeEntityModel model = this.root.getSmartHomeEntityModel();
List<String> targetItemNames, relevantItemNames;
switch (this.goal) {
case GOAL_ACTIVITY_PHONE_AND_WATCH:
......@@ -123,7 +123,7 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
}
}
if(!empty){
this.activity_result = learner.activity_predictor(a_new_data);
this.activity_result = learner.predictor(a_new_data)[0];
}
a_length = 0;
Arrays.fill(this.a_new_data, null);
......@@ -140,9 +140,7 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
this.p_new_data[1] = item.getStateAsString();
}
}
logger.debug("debug_preference_new_data: {}", Arrays.toString(this.p_new_data));
this.preference_result = learner.preference_predictor(this.p_new_data);
logger.debug("preference for {} is {}", Arrays.toString(this.p_new_data), this.preference_result);
this.preference_result = learner.predictor(this.p_new_data);
}
}
......@@ -163,7 +161,7 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
@Override
public Instant lastModelUpdate() {
return null;
return this.lastModelUpdate;
}
@Override
......@@ -202,14 +200,12 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn
}
}
void initActivities(String filenameOfCsv) {
logger.debug("init activities with {}", filenameOfCsv);
learner.activity_train(filenameOfCsv);
void startTraining() throws MalformedURLException {
learner.train();
this.lastModelUpdate = Instant.now();
}
void initPreferences(String filenameOfCsv) {
logger.debug("init preferences with {}", filenameOfCsv);
learner.preference_train(filenameOfCsv);
Learner getLearner() {
return learner;
}
}
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import com.fasterxml.jackson.databind.ObjectMapper;
import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data.LearnerSettings;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.stream.Collectors;
public class Main {
private static final Logger logger = LogManager.getLogger(Main.class);
public static void main(String[] args) {
/*
new data from KB
*/
// ReaderCSV reader = new ReaderCSV("datasets/backup/activity_data.csv","preference");
// reader.updater();
// Learner learner=new Learner();
// learner.preference_train("../datasets/backup/preference_data.csv");
// learner.train("datasets/backup/activity_data.csv","datasets/backup/preference_data.csv");
activity_validation_learner();
// activity_validation_learner();
testSettings();
}
private static void activity_validation_learner() {
private static void testSettings() {
ObjectMapper mapper = new ObjectMapper();
File settingsFile = Paths.get("src", "main", "resources", "activity_definition.json").toFile();
LearnerSettings settings;
try {
settings = mapper.readValue(settingsFile, LearnerSettings.class);
} catch (IOException e) {
logger.catching(e);
return;
}
System.out.println("settings.name = " + settings.name);
System.out.println("settings.inputColumns = " + settings.inputColumns
.stream()
.map(col -> "(" + col.name + "," + col.type + ")")
.collect(Collectors.joining(";")));
System.out.println("settings.targetColumns = " + settings.targetColumns
.stream()
.map(col -> "(" + col.name + "," + col.type + ")")
.collect(Collectors.joining(";")));
}
private static void activity_validation_learner() throws IOException {
ReadCSV csv = new ReadCSV("../datasets/backup/activity_data.csv", true, CSVFormat.DECIMAL_POINT);
String[] line = new String[11];
Learner learner = new Learner();
learner.activity_train("../datasets/backup/activity_data.csv");
learner.preference_train("../datasets/backup/preference_data.csv");
Learner learner = new Learner(new ObjectMapper().readValue(
Paths.get("src", "main", "resources", "activity_definition.json").toFile(),
LearnerSettings.class));
learner.train(Paths.get("src", "test", "activity_data.csv").toUri().toURL());
// learner.preference_train("../datasets/backup/preference_data.csv");
int wrong = 0;
int right = 0;
int i = 0;
......@@ -46,7 +78,7 @@ public class Main {
line[10] = csv.get(10);
//line[11] = csv.get(11);
String correct = csv.get(11);
String irisChosen = learner.activity_predictor(line);
String irisChosen = learner.predictor(line)[0];
result.append(Arrays.toString(line));
result.append(" -> predicted: ");
result.append(irisChosen);
......
......@@ -19,7 +19,7 @@ public class ReaderCSV{
//read every 5 s from csv
//activity CSV
/**
/*
* Col 1: smartphone acceleration x
* Col 2: smartphone acceleration y
* Col 3: smartphone acceleration z
......@@ -33,7 +33,7 @@ public class ReaderCSV{
* Col 11: watch rotation y
* Col 12: watch rotation z/*/
//preference CSV
/**
/*
* Col 1: Activity
* Col 2: watch brightness range "bright, medium, dimmer, dark"*/
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data;
import com.fasterxml.jackson.annotation.JsonInclude;
import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.Learner;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import java.util.List;
/**
* Settings to initialize the {@link Learner}.
*
* @author rschoene - Initial contribution
*/
@JsonInclude(JsonInclude.Include.NON_DEFAULT)
public class LearnerSettings {
/** Description what this learner is about to learn */
public String name;
/** Input columns (of the CSV used for training) */
public List<LearnerSettingsColumnDefinition> inputColumns;
/** Target columns (of the CSV used for training) */
public List<LearnerSettingsColumnDefinition> targetColumns;
/** Training method */
public String trainingMethod = MLMethodFactory.TYPE_FEEDFORWARD;
/** Training parameter. Used in {@link EncogModel#holdBackValidation(double, boolean, int)} */
public double validationPercent = 0.3;
/** Training parameter. Used in {@link EncogModel#holdBackValidation(double, boolean, int)} and {@link EncogModel#crossvalidate(int, boolean)} */
public boolean shuffleForValidation = true;
/** Training parameter. Used in {@link EncogModel#holdBackValidation(double, boolean, int)} */
public int validationSeed = 1001;
/** Training parameter. Used in {@link EncogModel#crossvalidate(int, boolean)} */
public int validationFolds = 5;
/** Filename to load initial data from. Should be in another settings file! */
public String initialDataFile = null;
}
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data;
import com.fasterxml.jackson.annotation.JsonInclude;
import org.encog.ml.data.versatile.columns.ColumnType;
/**
* Simple representation of a {@link org.encog.ml.data.versatile.columns.ColumnDefinition}.
*
* @author rschoene - Initial contribution
*/
@JsonInclude(JsonInclude.Include.NON_DEFAULT)
public class LearnerSettingsColumnDefinition {
public String name;
public ColumnType type;
}
{
"name": "activity",
"inputColumns": [
{ "name": "m_accel_x", "type":"continuous" },
{ "name": "m_accel_y", "type": "continuous" },
{ "name": "m_accel_z", "type": "continuous" },
{ "name": "m_rotation_x", "type": "continuous" },
{ "name": "m_rotation_y", "type": "continuous" },
{ "name": "m_rotation_z", "type": "continuous" },
{ "name": "w_accel_x", "type": "continuous" },
{ "name": "w_accel_y", "type": "continuous" },
{ "name": "w_accel_z", "type": "continuous" },
{ "name": "w_rotation_x", "type": "continuous" },
{ "name": "w_rotation_y", "type": "continuous" },
{ "name": "w_rotation_z", "type": "continuous" }
],
"targetColumns": [
{ "name": "labels", "type": "nominal" }
],
"initialDataFile": "src/test/resources/activity_data.csv"
}
{
"name": "preference",
"inputColumns": [
{ "name": "activity", "type":"nominal" },
{ "name": "w_brightness", "type": "nominal" }
],
"targetColumns": [
{ "name": "label1", "type": "continuous" },
{ "name": "label2", "type": "continuous" }
],
"initialDataFile": "src/test/resources/preference_data.csv"
}
......@@ -5,6 +5,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningEncoder;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory;
import de.tudresden.inf.st.eraser.jastadd.model.Root;
import java.io.IOException;
import java.net.URL;
/**
......@@ -24,8 +25,8 @@ class LearnerSubjectUnderTest {
factory.setKnowledgeBaseRoot(root);
}
void initFor(MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget, URL inputCsvFileName) {
factory.initializeFor(factoryTarget, inputCsvFileName);
void initFor(MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget, URL configURL) throws IOException {
factory.initializeFor(factoryTarget, configURL);
encoder = factory.createEncoder();
decoder = factory.createDecoder();
}
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.encog.Encog;
import org.junit.*;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Path;
......@@ -21,14 +21,18 @@ import static de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFac
*/
public class LearnerTest {
private static URL ACTIVITY_CONFIG;
private static URL ACTIVITY_DATA;
private static URL PREFERENCE_CONFIG;
private static URL PREFERENCE_DATA;
private LearnerSubjectUnderTest sut;
@BeforeClass
public static void setData() throws MalformedURLException {
Path base = Paths.get("src", "test", "resources");
ACTIVITY_CONFIG = base.resolve("activity_definition.json").toUri().toURL();
ACTIVITY_DATA = base.resolve("activity_data.csv").toUri().toURL();
PREFERENCE_CONFIG = base.resolve("preference_definition.json").toUri().toURL();
PREFERENCE_DATA = base.resolve("preference_data.csv").toUri().toURL();
}
......@@ -39,21 +43,21 @@ public class LearnerTest {
}
@Test
public void testActivities() {
LearnerTestUtils.testLearner(sut, ACTIVITY_DATA,
LearnerTestConstants.ACTIVITY_INPUT_ITEM_NAMES, line -> line[12], Collections.emptyMap(),
() -> sut.root.getSmartHomeEntityModel().getActivityItem(),
item -> sut.root.currentActivityName(),
ACTIVITY_RECOGNITION, false);
public void testActivities() throws IOException {
LearnerTestUtils.testLearner(sut,
new LearnerTestSettings()
.setConfigURL(ACTIVITY_CONFIG)
.setDataURL(ACTIVITY_DATA)
.setInputItemNames(LearnerTestConstants.ACTIVITY_INPUT_ITEM_NAMES)
.setExpectedOutput(line -> line[12])
.setOutputItemProvider(() -> sut.root.getSmartHomeEntityModel().getActivityItem())
.setStateOfOutputItem(item -> sut.root.currentActivityName())
.setFactoryTarget(ACTIVITY_RECOGNITION)
.setSingleUpdateList(false));
}
@Test
public void testPreferences() {
Map<String, BiConsumer<Item, String>> specialHandler = new HashMap<>();
specialHandler.put("activity", (item, value) -> item.setStateFromLong(
sut.root.resolveActivity(value)
.orElseThrow(() -> new AssertionError("Activity " + value + " not found"))
.getIdentifier()));
public void testPreferences() throws IOException {
// specialHandler.put("w_brightness", (item, value) -> {
// int target;
// switch (value) {
......@@ -66,13 +70,28 @@ public class LearnerTest {
// }
// item.setStateFromLong(target);
// });
LearnerTestUtils.testLearner(sut, PREFERENCE_DATA,
LearnerTestConstants.PREFERENCE_INPUT_ITEM_NAMES, LearnerTestUtils::decodeOutput, specialHandler,
() -> sut.root.getSmartHomeEntityModel().resolveItem(LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME)
LearnerTestUtils.testLearner(sut,
new LearnerTestSettings()
.setConfigURL(PREFERENCE_CONFIG)
.setDataURL(PREFERENCE_DATA)
.setInputItemNames(LearnerTestConstants.PREFERENCE_INPUT_ITEM_NAMES)
.setExpectedOutput(LearnerTestUtils::decodeOutput)
.putSpecialInputHandler("activity", (item, value) -> item.setStateFromLong(
sut.root.resolveActivity(value)
.orElseThrow(() -> new AssertionError("Activity " + value + " not found"))
.getIdentifier()))
.setOutputItemProvider(() -> sut.root.getSmartHomeEntityModel().resolveItem(LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME)
.orElseThrow(() -> new AssertionError(
"Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found")),
Item::getStateAsString,
PREFERENCE_LEARNING, true);
"Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found")))
.setStateOfOutputItem(Item::getStateAsString)
.setFactoryTarget(PREFERENCE_LEARNING)
.setSingleUpdateList(true));
}
@Ignore
@Test
public void testLoadedActivities() {
// TODO test with pre-trained model
}
@After
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.jastadd.model.Item;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory;
import lombok.Data;
import lombok.experimental.Accessors;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
@Data
@Accessors(chain = true)
public class LearnerTestSettings {
private URL configURL;
private URL dataURL;
private String[] inputItemNames;
private Function<String[], String> expectedOutput;
private final Map<String, BiConsumer<Item, String>> specialInputHandler = new HashMap<>();
private Supplier<Item> outputItemProvider;
private Function<Item, String> stateOfOutputItem;
private MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget;
private boolean singleUpdateList;
public LearnerTestSettings putSpecialInputHandler(String itemName, BiConsumer<Item, String> handler) {
specialInputHandler.put(itemName, handler);
return this;
}
// public LearnerTestSettings(URL configURL, URL inputCsvFileName, String[] inputItemNames, Function<String[], String> expectedOutput, Map<String, BiConsumer<Item, String>> specialInputHandler, Supplier<Item> outputItemProvider, Function<Item, String> stateOfOutputItem, MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget, boolean singleUpdateList) {
// this.configURL = configURL;
// this.inputCsvFileName = inputCsvFileName;
// this.inputItemNames = inputItemNames;
// this.expectedOutput = expectedOutput;
// this.specialInputHandler = specialInputHandler;
// this.outputItemProvider = outputItemProvider;
// this.stateOfOutputItem = stateOfOutputItem;
// this.factoryTarget = factoryTarget;
// this.singleUpdateList = singleUpdateList;
// }
//
// public URL getConfigURL() {
// return configURL;
// }
//
// public URL getInputCsvFileName() {
// return inputCsvFileName;
// }
//
// public String[] getInputItemNames() {
// return inputItemNames;
// }
//
// public Function<String[], String> getExpectedOutput() {
// return expectedOutput;
// }
//
// public Map<String, BiConsumer<Item, String>> getSpecialInputHandler() {
// return specialInputHandler;
// }
//
// public Supplier<Item> getOutputItemProvider() {
// return outputItemProvider;
// }
//
// public Function<Item, String> getStateOfOutputItem() {
// return stateOfOutputItem;
// }
//
// public MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget getFactoryTarget() {
// return factoryTarget;
// }
//
// public boolean isSingleUpdateList() {
// return singleUpdateList;
// }
}
......@@ -10,11 +10,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.greaterThan;
......@@ -59,43 +55,35 @@ public class LearnerTestUtils {
}
static void testLearner(
LearnerSubjectUnderTest sut,
URL inputCsvFileName,
String[] inputItemNames,
Function<String[], String> expectedOutput,
Map<String, BiConsumer<Item, String>> specialInputHandler,
Supplier<Item> outputItemProvider,
Function<Item, String> stateOfOutputItem,
MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget,
boolean singleUpdateList) {
sut.initFor(factoryTarget, inputCsvFileName);
LearnerSubjectUnderTest sut, LearnerTestSettings learnerTestSettings) throws IOException {
sut.initFor(learnerTestSettings.getFactoryTarget(), learnerTestSettings.getConfigURL());
// maybe use factory.createModel() here instead
// go through same csv as for training and test some of the values
int correct = 0, wrong = 0;
try(InputStream is = inputCsvFileName.openStream();
try(InputStream is = learnerTestSettings.getDataURL().openStream();
Reader reader = new InputStreamReader(is);
CSVReader csvreader = new CSVReader(reader)) {
int index = 0;
for (String[] line : csvreader) {
if (++index % 10 == 0) {
// only check every 10th line, push an update for every 12 input columns
List<Item> itemsToUpdate = new ArrayList<>(inputItemNames.length);
for (int i = 0; i < inputItemNames.length; i++) {
String itemName = inputItemNames[i];
List<Item> itemsToUpdate = new ArrayList<>(learnerTestSettings.getInputItemNames().length);
for (int i = 0; i < learnerTestSettings.getInputItemNames().length; i++) {
String itemName = learnerTestSettings.getInputItemNames()[i];
Item item = sut.root.getSmartHomeEntityModel().resolveItem(itemName)
.orElseThrow(() -> new AssertionError("Item " + itemName + " not found"));
if (specialInputHandler.containsKey(itemName)) {
specialInputHandler.get(itemName).accept(item, line[i]);
if (learnerTestSettings.getSpecialInputHandler().containsKey(itemName)) {
learnerTestSettings.getSpecialInputHandler().get(itemName).accept(item, line[i]);
} else {
item.setStateFromString(line[i]);
}
if (singleUpdateList) {
if (learnerTestSettings.isSingleUpdateList()) {
itemsToUpdate.add(item);
} else {
sut.encoder.newData(Collections.singletonList(item));
}
}
if (singleUpdateList) {
if (learnerTestSettings.isSingleUpdateList()) {
sut.encoder.newData(itemsToUpdate);
}
MachineLearningResult result = sut.decoder.classify();
......@@ -103,11 +91,11 @@ public class LearnerTestUtils {
assertEquals("Not one item update!", 1, result.getNumItemUpdate());
ItemUpdate update = result.getItemUpdate(0);
// check that the output item is to be updated
assertEquals("Output item not to be updated!", outputItemProvider.get(), update.getItem());
assertEquals("Output item not to be updated!", learnerTestSettings.getOutputItemProvider().get(), update.getItem());
update.apply();
// check if the correct new state was set
String expected = expectedOutput.apply(line);
String actual = stateOfOutputItem.apply(update.getItem());
String expected = learnerTestSettings.getExpectedOutput().apply(line);
String actual = learnerTestSettings.getStateOfOutputItem().apply(update.getItem());
if (expected.equals(actual)) {
correct++;
} else {
......@@ -116,9 +104,6 @@ public class LearnerTestUtils {
}
}
}
} catch (IOException e) {
e.printStackTrace();
fail();
}
assertThat(correct + wrong, greaterThan(0));
double accuracy = correct * 1.0 / (correct + wrong);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment