From bafabb58c678117d9e1caeb33952ecb82cdd36c3 Mon Sep 17 00:00:00 2001 From: rschoene <rene.schoene@tu-dresden.de> Date: Wed, 30 Oct 2019 10:35:01 +0100 Subject: [PATCH] Restructure Learner to only handle one goal. - Introduce JSON definition of inputs and outputs, not final yet - Introduce LearnerTestSettings to avoid long parameter list --- .../model/MachineLearningHandlerFactory.java | 77 ++++++- .../inf/st/eraser/starter/EraserStarter.java | 12 +- .../src/main/resources/activity_data.csv | 1 + .../main/resources/activity_definition.json | 1 + eraser.starter/starter-setting.yaml | 2 +- feedbackloop.learner_backup/build.gradle | 5 +- .../DummyPreference.java | 43 ++-- .../feedbackloop.learner_backup/Learner.java | 202 +++++++----------- .../MachineLearningHandlerFactoryImpl.java | 15 +- .../MachineLearningImpl.java | 34 ++- .../feedbackloop.learner_backup/Main.java | 60 ++++-- .../ReaderCSV.java | 6 +- .../learner_backup/data/LearnerSettings.java | 37 ++++ .../data/LearnerSettingsColumnDefinition.java | 15 ++ .../main/resources/activity_definition.json | 21 ++ .../main/resources/preference_definition.json | 12 ++ .../LearnerSubjectUnderTest.java | 5 +- .../learner_backup/LearnerTest.java | 59 +++-- .../learner_backup/LearnerTestSettings.java | 80 +++++++ .../learner_backup/LearnerTestUtils.java | 41 ++-- .../test/resources/activity_definition.json | 1 + .../test/resources/preference_definition.json | 1 + 22 files changed, 492 insertions(+), 238 deletions(-) create mode 120000 eraser.starter/src/main/resources/activity_data.csv create mode 120000 eraser.starter/src/main/resources/activity_definition.json create mode 100644 feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettings.java create mode 100644 feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettingsColumnDefinition.java create mode 100644 feedbackloop.learner_backup/src/main/resources/activity_definition.json create mode 100644 feedbackloop.learner_backup/src/main/resources/preference_definition.json create mode 100644 feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestSettings.java create mode 120000 feedbackloop.learner_backup/src/test/resources/activity_definition.json create mode 120000 feedbackloop.learner_backup/src/test/resources/preference_definition.json diff --git a/eraser-base/src/main/java/de/tudresden/inf/st/eraser/jastadd/model/MachineLearningHandlerFactory.java b/eraser-base/src/main/java/de/tudresden/inf/st/eraser/jastadd/model/MachineLearningHandlerFactory.java index 5d982f34..a8a7dcfe 100644 --- a/eraser-base/src/main/java/de/tudresden/inf/st/eraser/jastadd/model/MachineLearningHandlerFactory.java +++ b/eraser-base/src/main/java/de/tudresden/inf/st/eraser/jastadd/model/MachineLearningHandlerFactory.java @@ -1,6 +1,13 @@ package de.tudresden.inf.st.eraser.jastadd.model; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; import java.net.URL; +import java.time.Instant; +import java.util.Collections; +import java.util.List; /** * Factory to create new handlers ({@link MachineLearningEncoder} and {@link MachineLearningDecoder}) @@ -21,7 +28,7 @@ public abstract class MachineLearningHandlerFactory implements MachineLearningSe this.knowledgeBase = root; } - public abstract void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl); + public abstract void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) throws IOException; public abstract MachineLearningEncoder createEncoder(); @@ -46,4 +53,72 @@ public abstract class MachineLearningHandlerFactory implements MachineLearningSe public void shutdown() { // empty by default } + + /** + * Creates a new factory to be used, if there was an error during configuration of a factory + * @return a new factory logging warning messages upon each invoked method + */ + public static MachineLearningHandlerFactory createErrorFactory() { + return new MachineLearningHandlerFactory() { + private final Logger logger = LogManager.getLogger(MachineLearningHandlerFactory.class); + @Override + public void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) { + logger.warn("initializeFor called for ErrorFactory"); + } + + @Override + public MachineLearningEncoder createEncoder() { + return new MachineLearningEncoder() { + @Override + public void newData(List<Item> changedItems) { + logger.warn("newData called for encoder of ErrorFactory"); + } + + @Override + public List<Item> getTargets() { + logger.warn("getTargets called for encoder of ErrorFactory"); + return Collections.emptyList(); + } + + @Override + public List<Item> getRelevantItems() { + logger.warn("getRelevantItems called for encoder of ErrorFactory"); + return Collections.emptyList(); + } + + @Override + public void triggerTraining() { + logger.warn("triggerTraining called for encoder of ErrorFactory"); + } + + @Override + public void setKnowledgeBaseRoot(Root root) { + logger.warn("setKnowledgeBaseRoot called for encoder of ErrorFactory"); + } + }; + } + + @Override + public MachineLearningDecoder createDecoder() { + return new MachineLearningDecoder() { + @Override + public MachineLearningResult classify() { + logger.warn("classify called for decoder of ErrorFactory"); + return new MachineLearningResult(); + } + + @Override + public Instant lastModelUpdate() { + logger.warn("lastModelUpdate called for decoder of ErrorFactory"); + return null; + } + + @Override + public void setKnowledgeBaseRoot(Root root) { + logger.warn("setKnowledgeBaseRoot called for decoder of ErrorFactory"); + } + }; + } + }; + } } diff --git a/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java b/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java index 37b1f52f..c6134e86 100644 --- a/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java +++ b/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java @@ -223,23 +223,25 @@ public class EraserStarter { } private static MachineLearningHandlerFactory createFactory(MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget target, Setting.MLContainer config, Root root) { - MachineLearningHandlerFactory factory = new DummyMachineLearningHandlerFactory(); + MachineLearningHandlerFactory factory; String niceTargetName = target.toString().toLowerCase().replace("_", " "); if (config.dummy || config.factory == null) { logger.info("Using dummy {}, ignoring other settings for this", niceTargetName); + factory = new DummyMachineLearningHandlerFactory(); } else { try { Class<? extends MachineLearningHandlerFactory> clazz = Class.forName(config.factory) .asSubclass(MachineLearningHandlerFactory.class); factory = clazz.newInstance(); - } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { - logger.error("Could not instantiate machine learning factory for {} with class '{}'. Using dummy instead.", + factory.setKnowledgeBaseRoot(root); + factory.initializeFor(target, config.realURL()); + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException | IOException e) { + logger.error("Could not instantiate machine learning factory for {} with class '{}'. Using error factory instead.", niceTargetName, config); logger.catching(e); + factory = MachineLearningHandlerFactory.createErrorFactory(); } } - factory.setKnowledgeBaseRoot(root); - factory.initializeFor(target, config.realURL()); return factory; } } diff --git a/eraser.starter/src/main/resources/activity_data.csv b/eraser.starter/src/main/resources/activity_data.csv new file mode 120000 index 00000000..b758a4ea --- /dev/null +++ b/eraser.starter/src/main/resources/activity_data.csv @@ -0,0 +1 @@ +../../../../feedbackloop.learner_backup/src/test/resources/activity_data.csv \ No newline at end of file diff --git a/eraser.starter/src/main/resources/activity_definition.json b/eraser.starter/src/main/resources/activity_definition.json new file mode 120000 index 00000000..9a98b24c --- /dev/null +++ b/eraser.starter/src/main/resources/activity_definition.json @@ -0,0 +1 @@ +../../../../feedbackloop.learner_backup/src/main/resources/activity_definition.json \ No newline at end of file diff --git a/eraser.starter/starter-setting.yaml b/eraser.starter/starter-setting.yaml index 6882f55d..3423a904 100644 --- a/eraser.starter/starter-setting.yaml +++ b/eraser.starter/starter-setting.yaml @@ -26,7 +26,7 @@ load: activity: factory: de.tudresden.inf.st.eraser.feedbackloop.learner_backup.MachineLearningHandlerFactoryImpl # File to read in. Expected format depends on factory - file: ../datasets/backup/activity_data.csv + file: ../datasets/backup/activity_definition.json external: true # Use dummy model in which the current activity is directly editable. Default: false. dummy: false diff --git a/feedbackloop.learner_backup/build.gradle b/feedbackloop.learner_backup/build.gradle index 5ab2310a..4ee3fb5d 100644 --- a/feedbackloop.learner_backup/build.gradle +++ b/feedbackloop.learner_backup/build.gradle @@ -1,4 +1,7 @@ -apply plugin: 'application' +plugins { + id 'application' + id 'io.franzbecker.gradle-lombok' version '3.0.0' +} dependencies { compile project(':eraser-base') diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/DummyPreference.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/DummyPreference.java index 59f3ad9f..58eaf48b 100644 --- a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/DummyPreference.java +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/DummyPreference.java @@ -9,25 +9,19 @@ import java.util.Random; public class DummyPreference { private static String activity; //Activity: walking, reading, working, dancing, lying, getting up - private static String watch_brightness; //dark: <45; dimmer 45-70; bright >70; - private static String light_color_openhab_H; //red 7; green 120; blue 240; yellow 60; sky blue 180; purple 300; - private static String brightness_output; //1-100**/ private static Random random = new Random(); public static void main(String[] args) { creator(); } - static void creator(){ - - - try{ + private static void creator(){ + try { FileWriter writer = new FileWriter("datasets/backup/preference_data.csv",true); CSVWriter csv_writer = new CSVWriter(writer, ',', CSVWriter.NO_QUOTE_CHARACTER, CSVWriter.DEFAULT_ESCAPE_CHARACTER, CSVWriter.DEFAULT_LINE_END); - //activity="walking" green activity ="walking"; @@ -40,33 +34,46 @@ public class DummyPreference { csv_writer.writeAll(generator("getting up","yellow")); csv_writer.close(); writer.close(); - }catch (IOException e){e.printStackTrace();} + } catch (IOException e) { + e.printStackTrace(); + } } - static List<String[]> generator(String activity_input, String color){ - List<String[]> data = new ArrayList<String[]>(); + + /** + * Generate random data + * @param activity_input input for activity + * @param color red 7; green 120; blue 240; yellow 60; sky blue 180; purple 300; + * @return generated data + */ + private static List<String[]> generator(String activity_input, String color){ + List<String[]> data = new ArrayList<>(); + //dark: <45; dimmer 45-70; bright >70; + //1-100**/ + String brightness_output; + String watch_brightness; activity = activity_input; - light_color_openhab_H =color; + // //100 walking with different lighting intensity - for (int i=0; i<100; i++){ + for (int i=0; i<100; i++) { String[] add_data = new String[4]; int brightness = random.nextInt(3000); System.out.println(brightness); - if (brightness<45){ + if (brightness<45) { watch_brightness = "dark"; brightness_output ="100"; - }else if(45<=brightness && brightness<200){ + } else if(brightness < 200) { watch_brightness = "dimmer"; brightness_output ="40"; - }else if( 200<=brightness && brightness<1000){ + } else if(brightness < 1000) { watch_brightness = "medium"; brightness_output ="70"; - }else{ + } else { watch_brightness = "bright"; brightness_output ="0"; } add_data[0] = activity; add_data[1] = watch_brightness; - add_data[2] = light_color_openhab_H; + add_data[2] = color; add_data[3] = brightness_output; data.add(add_data); } diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java index c9d33ba6..f7fe7afe 100644 --- a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java @@ -1,5 +1,8 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; +import com.fasterxml.jackson.databind.ObjectMapper; +import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data.LearnerSettings; +import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data.LearnerSettingsColumnDefinition; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.encog.Encog; @@ -13,145 +16,106 @@ import org.encog.ml.data.versatile.sources.VersatileDataSource; import org.encog.ml.factory.MLMethodFactory; import org.encog.ml.model.EncogModel; import org.encog.neural.networks.BasicNetwork; +import org.encog.persist.EncogDirectoryPersistence; import org.encog.util.csv.CSVFormat; import java.io.File; import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; -import static org.encog.persist.EncogDirectoryPersistence.loadObject; -import static org.encog.persist.EncogDirectoryPersistence.saveObject; - +/** + * Internal class for Neural Networks using Encog. + * Expects either a CSV file with data points to learn from, + * or serialized versions of a trained network and its normalization helpers. + * + * @author rschoene - Initial contribution + */ public class Learner { + private File modelFile; + private NormalizationHelper normalizationHelper; + private BasicNetwork network; + private static final Logger logger = LogManager.getLogger(Learner.class); + + private final LearnerSettings settings; + /** - * intial train + * Creates a new learner object using the configuration at the given URL. + * @param configURL the location of the configuration + * @throws IOException if an error occurs in {@link ObjectMapper#readValue(URL, Class)} */ - private File save_activity_model_file; - private File save_preference_model_file; - private File csv_file; - private VersatileMLDataSet a_data; - private VersatileMLDataSet p_data; - private EncogModel a_model; - private EncogModel p_model; - private NormalizationHelper activity_helper; - private NormalizationHelper preference_helper; - private BasicNetwork a_best_method; - private BasicNetwork p_best_method; - private final Logger logger = LogManager.getLogger(Learner.class); + Learner(URL configURL) throws IOException { + this(new ObjectMapper().readValue(configURL, LearnerSettings.class)); + } - public Learner() { - try { - save_activity_model_file = File.createTempFile("activity_model", "eg"); - } catch (IOException e) { - // use local alternative - save_activity_model_file = new File("activity_model.eq"); - } + Learner(LearnerSettings settings) { + this.settings = settings; try { - save_preference_model_file = File.createTempFile("preference_model", "eg"); + modelFile = File.createTempFile(settings.name + "_model", "eg"); } catch (IOException e) { // use local alternative - save_preference_model_file = new File("preference_model.eg"); + modelFile = new File(settings.name + "_model.eg"); } - save_activity_model_file.deleteOnExit(); - save_preference_model_file.deleteOnExit(); - } - - private void activityDataAnalyser(String activity_csv_url) { - VersatileDataSource a_source; - String csv_url_activity; - csv_url_activity = activity_csv_url; - this.csv_file = new File(csv_url_activity); - a_source = new CSVDataSource(csv_file, true, CSVFormat.DECIMAL_POINT); - a_data = new VersatileMLDataSet(a_source); - String[] activity_inputs = {"m_accel_x", "m_accel_y", "m_accel_z", - "m_rotation_x", "m_rotation_y", "m_rotation_z", - "w_accel_x", "w_accel_y", "w_accel_z", - "w_rotation_x", "w_rotation_y", "w_rotation_z" - }; - - for (int i = 0; i < activity_inputs.length; i++) { - a_data.defineSourceColumn(activity_inputs[i], i, ColumnType.continuous); - } - ColumnDefinition outputColumn = a_data.defineSourceColumn("labels", 12, ColumnType.nominal); - a_data.defineSingleOutputOthersInput(outputColumn); - a_data.analyze(); - a_model = new EncogModel(a_data); - a_model.selectMethod(a_data, MLMethodFactory.TYPE_FEEDFORWARD); - a_data.normalize(); - activity_helper = a_data.getNormHelper(); - } - - private void preferenceDataAnalyser(String preference_csv_url) { - VersatileDataSource p_source; - String csv_url_preference; - csv_url_preference = preference_csv_url; - this.csv_file = new File(csv_url_preference); - p_source = new CSVDataSource(csv_file, true, CSVFormat.DECIMAL_POINT); - p_data = new VersatileMLDataSet(p_source); - p_data.defineSourceColumn("activity", 0, ColumnType.nominal); - p_data.defineSourceColumn("w_brightness", 1, ColumnType.nominal); - ColumnDefinition outputColumn1 = p_data.defineSourceColumn("label1", 2, ColumnType.continuous); - ColumnDefinition outputColumn2 = p_data.defineSourceColumn("label2", 3, ColumnType.continuous); - ColumnDefinition[] outputs = new ColumnDefinition[2]; - outputs[0] = outputColumn1; - outputs[1] = outputColumn2; - p_data.defineMultipleOutputsOthersInput(outputs); - p_data.analyze(); - p_model = new EncogModel(p_data); - p_model.selectMethod(p_data, MLMethodFactory.TYPE_FEEDFORWARD); - p_data.normalize(); - preference_helper = p_data.getNormHelper(); + modelFile.deleteOnExit(); } - void activity_train(String activity_csv_url) { - logger.info("Activity training is beginning ... ..."); - activityDataAnalyser(activity_csv_url); - a_model.holdBackValidation(0.3, true, 1001); - a_model.selectTrainingType(a_data); - a_best_method = (BasicNetwork) a_model.crossvalidate(5, true); - saveEncogModel(save_activity_model_file); - logger.info("Activity training is finished ... ..."); - } - - void preference_train(String prefence_csv_url) { - logger.info("Preference training is beginning ... ..."); - preferenceDataAnalyser(prefence_csv_url); - p_model.holdBackValidation(0.3, true, 1001); - p_model.selectTrainingType(p_data); - p_best_method = (BasicNetwork) p_model.crossvalidate(5, true); - saveEncogModel(save_preference_model_file); - logger.info("Preference training is finished ... ..."); - } - - String activity_predictor(String[] new_data) { - String activity_result; - MLData input = activity_helper.allocateInputVector(); - activity_helper.normalizeInputVector(new_data, input.getData(), false); - MLData output = a_best_method.compute(input); - activity_result = activity_helper.denormalizeOutputVectorToString(output)[0]; - logger.debug("Activity Predictor result is: {}", activity_result); - return activity_result; - } - - String[] preference_predictor(String[] new_data) { - MLData input = preference_helper.allocateInputVector(); - preference_helper.normalizeInputVector(new_data, input.getData(), false); - MLData output = p_best_method.compute(input); - String[] preference_result = new String[] { - preference_helper.denormalizeOutputVectorToString(output)[0], - preference_helper.denormalizeOutputVectorToString(output)[1] - }; - logger.debug("Preference Predictor result, Color: {}, brightness: {}", - Math.round(Float.parseFloat(preference_result[0])), Math.round(Float.parseFloat(preference_result[1]))); - return preference_result; + /** + * Begin training using the training set specified in settings. + * @throws MalformedURLException if the location of the training set in the settings is malformed + */ + void train() throws MalformedURLException { + URL location = new File(settings.initialDataFile).toURI().toURL(); + train(location); } - private void saveEncogModel(File modelFile) { - if (modelFile.equals(save_activity_model_file)) { - saveObject(modelFile, this.a_best_method); + /** + * Begin training with the given initial training set. + * @param location the location of the training set + */ + void train(URL location) { + logger.info("Training for {} begins using {}", settings.name, location); + VersatileDataSource source; + File csvFile = new File(location.getFile()); + source = new CSVDataSource(csvFile, true, CSVFormat.DECIMAL_POINT); + VersatileMLDataSet data = new VersatileMLDataSet(source); + final int inputSize = settings.inputColumns.size(); + for (int index = 0; index < inputSize; index++) { + LearnerSettingsColumnDefinition columnDefinition = settings.inputColumns.get(index); + data.defineSourceColumn(columnDefinition.name, index, columnDefinition.type); + } + List<ColumnDefinition> targets = new ArrayList<>(); + for (int targetIndex = 0; targetIndex < settings.targetColumns.size(); targetIndex++) { + LearnerSettingsColumnDefinition columnDefinition = settings.targetColumns.get(targetIndex); + targets.add(data.defineSourceColumn(columnDefinition.name, inputSize + targetIndex, columnDefinition.type)); + } + if (targets.size() == 1) { + data.defineSingleOutputOthersInput(targets.get(0)); } else { - saveObject(modelFile, this.p_best_method); + data.defineMultipleOutputsOthersInput(targets.toArray(new ColumnDefinition[0])); } + data.analyze(); + EncogModel model = new EncogModel(data); + model.selectMethod(data, settings.trainingMethod); + data.normalize(); + normalizationHelper = data.getNormHelper(); + model.holdBackValidation(settings.validationPercent, settings.shuffleForValidation, settings.validationSeed); + model.selectTrainingType(data); + network = (BasicNetwork) model.crossvalidate(settings.validationFolds, settings.shuffleForValidation); + EncogDirectoryPersistence.saveObject(modelFile, network); + logger.info("Training for {} finished", settings.name); + } + + String[] predictor(String[] newData) { + String[] result; + MLData input = normalizationHelper.allocateInputVector(); + normalizationHelper.normalizeInputVector(newData, input.getData(), false); + MLData output = network.compute(input); + result = normalizationHelper.denormalizeOutputVectorToString(output); + logger.debug("Result prediction for {} applied on {} is {}:", settings.name, newData, result); + return result; } void shutdown() { diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningHandlerFactoryImpl.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningHandlerFactoryImpl.java index 3d4500e5..8e808454 100644 --- a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningHandlerFactoryImpl.java +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningHandlerFactoryImpl.java @@ -4,6 +4,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningDecoder; import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningEncoder; import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory; +import java.io.IOException; import java.net.URL; /** @@ -14,24 +15,22 @@ import java.net.URL; public class MachineLearningHandlerFactoryImpl extends MachineLearningHandlerFactory { private MachineLearningImpl handler; - private static Learner learner = new Learner(); @Override - public void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) { + public void initializeFor(MachineLearningHandlerFactoryTarget target, URL configUrl) throws IOException { switch (target) { case ACTIVITY_RECOGNITION: - handler = new MachineLearningImpl(learner, MachineLearningImpl.GOAL_ACTIVITY_PHONE_AND_WATCH); + handler = new MachineLearningImpl(MachineLearningImpl.GOAL_ACTIVITY_PHONE_AND_WATCH, configUrl); handler.setKnowledgeBaseRoot(knowledgeBase); - handler.initActivities(configUrl.getFile()); break; case PREFERENCE_LEARNING: - handler = new MachineLearningImpl(learner, MachineLearningImpl.GOAL_PREFERENCE_BRIGHTNESS_IRIS); + handler = new MachineLearningImpl(MachineLearningImpl.GOAL_PREFERENCE_BRIGHTNESS_IRIS, configUrl); handler.setKnowledgeBaseRoot(knowledgeBase); - handler.initPreferences(configUrl.getFile()); break; default: throw new UnsupportedOperationException("Target " + target + " is not supported"); } + handler.startTraining(); } @Override @@ -45,6 +44,8 @@ public class MachineLearningHandlerFactoryImpl extends MachineLearningHandlerFac } public void shutdown() { - learner.shutdown(); + if (handler != null) { + handler.getLearner().shutdown(); + } } } diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningImpl.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningImpl.java index 0155d93a..f5d24109 100644 --- a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningImpl.java +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/MachineLearningImpl.java @@ -4,6 +4,9 @@ import de.tudresden.inf.st.eraser.jastadd.model.*; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; import java.time.Instant; import java.util.*; import java.util.stream.Collectors; @@ -29,20 +32,17 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn private Root root; private String activity_result; + private Instant lastModelUpdate; - public MachineLearningImpl(Learner learner, int goal) { - this.learner = learner; + public MachineLearningImpl(int goal, URL configURL) throws IOException { + this.learner = new Learner(configURL); this.goal = goal; } @Override public void setKnowledgeBaseRoot(Root root) { this.root = root; - updateItems(); - } - - private void updateItems() { - SmartHomeEntityModel model = root.getSmartHomeEntityModel(); + SmartHomeEntityModel model = this.root.getSmartHomeEntityModel(); List<String> targetItemNames, relevantItemNames; switch (this.goal) { case GOAL_ACTIVITY_PHONE_AND_WATCH: @@ -123,7 +123,7 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn } } if(!empty){ - this.activity_result = learner.activity_predictor(a_new_data); + this.activity_result = learner.predictor(a_new_data)[0]; } a_length = 0; Arrays.fill(this.a_new_data, null); @@ -140,9 +140,7 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn this.p_new_data[1] = item.getStateAsString(); } } - logger.debug("debug_preference_new_data: {}", Arrays.toString(this.p_new_data)); - this.preference_result = learner.preference_predictor(this.p_new_data); - logger.debug("preference for {} is {}", Arrays.toString(this.p_new_data), this.preference_result); + this.preference_result = learner.predictor(this.p_new_data); } } @@ -163,7 +161,7 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn @Override public Instant lastModelUpdate() { - return null; + return this.lastModelUpdate; } @Override @@ -202,14 +200,12 @@ public class MachineLearningImpl implements MachineLearningDecoder, MachineLearn } } - void initActivities(String filenameOfCsv) { - logger.debug("init activities with {}", filenameOfCsv); - learner.activity_train(filenameOfCsv); + void startTraining() throws MalformedURLException { + learner.train(); + this.lastModelUpdate = Instant.now(); } - void initPreferences(String filenameOfCsv) { - logger.debug("init preferences with {}", filenameOfCsv); - learner.preference_train(filenameOfCsv); + Learner getLearner() { + return learner; } - } diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Main.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Main.java index ee281c9b..5df0fc27 100644 --- a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Main.java +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Main.java @@ -1,30 +1,62 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; +import com.fasterxml.jackson.databind.ObjectMapper; +import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data.LearnerSettings; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.encog.util.csv.CSVFormat; import org.encog.util.csv.ReadCSV; +import java.io.File; +import java.io.IOException; +import java.net.MalformedURLException; +import java.nio.file.Paths; import java.util.Arrays; +import java.util.stream.Collectors; public class Main { + private static final Logger logger = LogManager.getLogger(Main.class); + public static void main(String[] args) { - /* - new data from KB - */ - //ReaderCSV reader = new ReaderCSV("datasets/backup/activity_data.csv","preference"); - //reader.updater(); - //Learner learner=new Learner(); - //learner.preference_train("../datasets/backup/preference_data.csv"); - //learner.train("datasets/backup/activity_data.csv","datasets/backup/preference_data.csv"); - activity_validation_learner(); +// ReaderCSV reader = new ReaderCSV("datasets/backup/activity_data.csv","preference"); +// reader.updater(); +// Learner learner=new Learner(); +// learner.preference_train("../datasets/backup/preference_data.csv"); +// learner.train("datasets/backup/activity_data.csv","datasets/backup/preference_data.csv"); +// activity_validation_learner(); + testSettings(); + } + + private static void testSettings() { + ObjectMapper mapper = new ObjectMapper(); + File settingsFile = Paths.get("src", "main", "resources", "activity_definition.json").toFile(); + LearnerSettings settings; + try { + settings = mapper.readValue(settingsFile, LearnerSettings.class); + } catch (IOException e) { + logger.catching(e); + return; + } + System.out.println("settings.name = " + settings.name); + System.out.println("settings.inputColumns = " + settings.inputColumns + .stream() + .map(col -> "(" + col.name + "," + col.type + ")") + .collect(Collectors.joining(";"))); + System.out.println("settings.targetColumns = " + settings.targetColumns + .stream() + .map(col -> "(" + col.name + "," + col.type + ")") + .collect(Collectors.joining(";"))); } - private static void activity_validation_learner() { + private static void activity_validation_learner() throws IOException { ReadCSV csv = new ReadCSV("../datasets/backup/activity_data.csv", true, CSVFormat.DECIMAL_POINT); String[] line = new String[11]; - Learner learner = new Learner(); - learner.activity_train("../datasets/backup/activity_data.csv"); - learner.preference_train("../datasets/backup/preference_data.csv"); + Learner learner = new Learner(new ObjectMapper().readValue( + Paths.get("src", "main", "resources", "activity_definition.json").toFile(), + LearnerSettings.class)); + learner.train(Paths.get("src", "test", "activity_data.csv").toUri().toURL()); +// learner.preference_train("../datasets/backup/preference_data.csv"); int wrong = 0; int right = 0; int i = 0; @@ -46,7 +78,7 @@ public class Main { line[10] = csv.get(10); //line[11] = csv.get(11); String correct = csv.get(11); - String irisChosen = learner.activity_predictor(line); + String irisChosen = learner.predictor(line)[0]; result.append(Arrays.toString(line)); result.append(" -> predicted: "); result.append(irisChosen); diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/ReaderCSV.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/ReaderCSV.java index 76c1ce79..0fd0f005 100644 --- a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/ReaderCSV.java +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/ReaderCSV.java @@ -19,7 +19,7 @@ public class ReaderCSV{ //read every 5 s from csv //activity CSV - /** + /* * Col 1: smartphone acceleration x * Col 2: smartphone acceleration y * Col 3: smartphone acceleration z @@ -33,7 +33,7 @@ public class ReaderCSV{ * Col 11: watch rotation y * Col 12: watch rotation z/*/ //preference CSV - /** + /* * Col 1: Activity * Col 2: watch brightness range "bright, medium, dimmer, dark"*/ @@ -113,4 +113,4 @@ public class ReaderCSV{ } } -} \ No newline at end of file +} diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettings.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettings.java new file mode 100644 index 00000000..cd035bd1 --- /dev/null +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettings.java @@ -0,0 +1,37 @@ +package de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data; + +import com.fasterxml.jackson.annotation.JsonInclude; +import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.Learner; +import org.encog.ml.data.versatile.columns.ColumnDefinition; +import org.encog.ml.factory.MLMethodFactory; +import org.encog.ml.model.EncogModel; + +import java.util.List; + +/** + * Settings to initialize the {@link Learner}. + * + * @author rschoene - Initial contribution + */ +@JsonInclude(JsonInclude.Include.NON_DEFAULT) +public class LearnerSettings { + /** Description what this learner is about to learn */ + public String name; + /** Input columns (of the CSV used for training) */ + public List<LearnerSettingsColumnDefinition> inputColumns; + /** Target columns (of the CSV used for training) */ + public List<LearnerSettingsColumnDefinition> targetColumns; + /** Training method */ + public String trainingMethod = MLMethodFactory.TYPE_FEEDFORWARD; + /** Training parameter. Used in {@link EncogModel#holdBackValidation(double, boolean, int)} */ + public double validationPercent = 0.3; + /** Training parameter. Used in {@link EncogModel#holdBackValidation(double, boolean, int)} and {@link EncogModel#crossvalidate(int, boolean)} */ + public boolean shuffleForValidation = true; + /** Training parameter. Used in {@link EncogModel#holdBackValidation(double, boolean, int)} */ + public int validationSeed = 1001; + /** Training parameter. Used in {@link EncogModel#crossvalidate(int, boolean)} */ + public int validationFolds = 5; + + /** Filename to load initial data from. Should be in another settings file! */ + public String initialDataFile = null; +} diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettingsColumnDefinition.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettingsColumnDefinition.java new file mode 100644 index 00000000..8f940837 --- /dev/null +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/data/LearnerSettingsColumnDefinition.java @@ -0,0 +1,15 @@ +package de.tudresden.inf.st.eraser.feedbackloop.learner_backup.data; + +import com.fasterxml.jackson.annotation.JsonInclude; +import org.encog.ml.data.versatile.columns.ColumnType; + +/** + * Simple representation of a {@link org.encog.ml.data.versatile.columns.ColumnDefinition}. + * + * @author rschoene - Initial contribution + */ +@JsonInclude(JsonInclude.Include.NON_DEFAULT) +public class LearnerSettingsColumnDefinition { + public String name; + public ColumnType type; +} diff --git a/feedbackloop.learner_backup/src/main/resources/activity_definition.json b/feedbackloop.learner_backup/src/main/resources/activity_definition.json new file mode 100644 index 00000000..49325fc7 --- /dev/null +++ b/feedbackloop.learner_backup/src/main/resources/activity_definition.json @@ -0,0 +1,21 @@ +{ + "name": "activity", + "inputColumns": [ + { "name": "m_accel_x", "type":"continuous" }, + { "name": "m_accel_y", "type": "continuous" }, + { "name": "m_accel_z", "type": "continuous" }, + { "name": "m_rotation_x", "type": "continuous" }, + { "name": "m_rotation_y", "type": "continuous" }, + { "name": "m_rotation_z", "type": "continuous" }, + { "name": "w_accel_x", "type": "continuous" }, + { "name": "w_accel_y", "type": "continuous" }, + { "name": "w_accel_z", "type": "continuous" }, + { "name": "w_rotation_x", "type": "continuous" }, + { "name": "w_rotation_y", "type": "continuous" }, + { "name": "w_rotation_z", "type": "continuous" } + ], + "targetColumns": [ + { "name": "labels", "type": "nominal" } + ], + "initialDataFile": "src/test/resources/activity_data.csv" +} diff --git a/feedbackloop.learner_backup/src/main/resources/preference_definition.json b/feedbackloop.learner_backup/src/main/resources/preference_definition.json new file mode 100644 index 00000000..0d0d5011 --- /dev/null +++ b/feedbackloop.learner_backup/src/main/resources/preference_definition.json @@ -0,0 +1,12 @@ +{ + "name": "preference", + "inputColumns": [ + { "name": "activity", "type":"nominal" }, + { "name": "w_brightness", "type": "nominal" } + ], + "targetColumns": [ + { "name": "label1", "type": "continuous" }, + { "name": "label2", "type": "continuous" } + ], + "initialDataFile": "src/test/resources/preference_data.csv" +} diff --git a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerSubjectUnderTest.java b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerSubjectUnderTest.java index 5743f0a9..7a9a9c70 100644 --- a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerSubjectUnderTest.java +++ b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerSubjectUnderTest.java @@ -5,6 +5,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningEncoder; import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory; import de.tudresden.inf.st.eraser.jastadd.model.Root; +import java.io.IOException; import java.net.URL; /** @@ -24,8 +25,8 @@ class LearnerSubjectUnderTest { factory.setKnowledgeBaseRoot(root); } - void initFor(MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget, URL inputCsvFileName) { - factory.initializeFor(factoryTarget, inputCsvFileName); + void initFor(MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget, URL configURL) throws IOException { + factory.initializeFor(factoryTarget, configURL); encoder = factory.createEncoder(); decoder = factory.createDecoder(); } diff --git a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java index b82c6d3f..4e5b4cba 100644 --- a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java +++ b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java @@ -1,9 +1,9 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; import de.tudresden.inf.st.eraser.jastadd.model.*; -import org.encog.Encog; import org.junit.*; +import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; import java.nio.file.Path; @@ -21,14 +21,18 @@ import static de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFac */ public class LearnerTest { + private static URL ACTIVITY_CONFIG; private static URL ACTIVITY_DATA; + private static URL PREFERENCE_CONFIG; private static URL PREFERENCE_DATA; private LearnerSubjectUnderTest sut; @BeforeClass public static void setData() throws MalformedURLException { Path base = Paths.get("src", "test", "resources"); + ACTIVITY_CONFIG = base.resolve("activity_definition.json").toUri().toURL(); ACTIVITY_DATA = base.resolve("activity_data.csv").toUri().toURL(); + PREFERENCE_CONFIG = base.resolve("preference_definition.json").toUri().toURL(); PREFERENCE_DATA = base.resolve("preference_data.csv").toUri().toURL(); } @@ -39,21 +43,21 @@ public class LearnerTest { } @Test - public void testActivities() { - LearnerTestUtils.testLearner(sut, ACTIVITY_DATA, - LearnerTestConstants.ACTIVITY_INPUT_ITEM_NAMES, line -> line[12], Collections.emptyMap(), - () -> sut.root.getSmartHomeEntityModel().getActivityItem(), - item -> sut.root.currentActivityName(), - ACTIVITY_RECOGNITION, false); + public void testActivities() throws IOException { + LearnerTestUtils.testLearner(sut, + new LearnerTestSettings() + .setConfigURL(ACTIVITY_CONFIG) + .setDataURL(ACTIVITY_DATA) + .setInputItemNames(LearnerTestConstants.ACTIVITY_INPUT_ITEM_NAMES) + .setExpectedOutput(line -> line[12]) + .setOutputItemProvider(() -> sut.root.getSmartHomeEntityModel().getActivityItem()) + .setStateOfOutputItem(item -> sut.root.currentActivityName()) + .setFactoryTarget(ACTIVITY_RECOGNITION) + .setSingleUpdateList(false)); } @Test - public void testPreferences() { - Map<String, BiConsumer<Item, String>> specialHandler = new HashMap<>(); - specialHandler.put("activity", (item, value) -> item.setStateFromLong( - sut.root.resolveActivity(value) - .orElseThrow(() -> new AssertionError("Activity " + value + " not found")) - .getIdentifier())); + public void testPreferences() throws IOException { // specialHandler.put("w_brightness", (item, value) -> { // int target; // switch (value) { @@ -66,13 +70,28 @@ public class LearnerTest { // } // item.setStateFromLong(target); // }); - LearnerTestUtils.testLearner(sut, PREFERENCE_DATA, - LearnerTestConstants.PREFERENCE_INPUT_ITEM_NAMES, LearnerTestUtils::decodeOutput, specialHandler, - () -> sut.root.getSmartHomeEntityModel().resolveItem(LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME) - .orElseThrow(() -> new AssertionError( - "Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found")), - Item::getStateAsString, - PREFERENCE_LEARNING, true); + LearnerTestUtils.testLearner(sut, + new LearnerTestSettings() + .setConfigURL(PREFERENCE_CONFIG) + .setDataURL(PREFERENCE_DATA) + .setInputItemNames(LearnerTestConstants.PREFERENCE_INPUT_ITEM_NAMES) + .setExpectedOutput(LearnerTestUtils::decodeOutput) + .putSpecialInputHandler("activity", (item, value) -> item.setStateFromLong( + sut.root.resolveActivity(value) + .orElseThrow(() -> new AssertionError("Activity " + value + " not found")) + .getIdentifier())) + .setOutputItemProvider(() -> sut.root.getSmartHomeEntityModel().resolveItem(LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME) + .orElseThrow(() -> new AssertionError( + "Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found"))) + .setStateOfOutputItem(Item::getStateAsString) + .setFactoryTarget(PREFERENCE_LEARNING) + .setSingleUpdateList(true)); + } + + @Ignore + @Test + public void testLoadedActivities() { + // TODO test with pre-trained model } @After diff --git a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestSettings.java b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestSettings.java new file mode 100644 index 00000000..4cd4576c --- /dev/null +++ b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestSettings.java @@ -0,0 +1,80 @@ +package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; + +import de.tudresden.inf.st.eraser.jastadd.model.Item; +import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory; +import lombok.Data; +import lombok.experimental.Accessors; + +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.function.Supplier; + +@Data +@Accessors(chain = true) +public class LearnerTestSettings { + private URL configURL; + private URL dataURL; + private String[] inputItemNames; + private Function<String[], String> expectedOutput; + private final Map<String, BiConsumer<Item, String>> specialInputHandler = new HashMap<>(); + private Supplier<Item> outputItemProvider; + private Function<Item, String> stateOfOutputItem; + private MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget; + private boolean singleUpdateList; + + public LearnerTestSettings putSpecialInputHandler(String itemName, BiConsumer<Item, String> handler) { + specialInputHandler.put(itemName, handler); + return this; + } + +// public LearnerTestSettings(URL configURL, URL inputCsvFileName, String[] inputItemNames, Function<String[], String> expectedOutput, Map<String, BiConsumer<Item, String>> specialInputHandler, Supplier<Item> outputItemProvider, Function<Item, String> stateOfOutputItem, MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget, boolean singleUpdateList) { +// this.configURL = configURL; +// this.inputCsvFileName = inputCsvFileName; +// this.inputItemNames = inputItemNames; +// this.expectedOutput = expectedOutput; +// this.specialInputHandler = specialInputHandler; +// this.outputItemProvider = outputItemProvider; +// this.stateOfOutputItem = stateOfOutputItem; +// this.factoryTarget = factoryTarget; +// this.singleUpdateList = singleUpdateList; +// } +// +// public URL getConfigURL() { +// return configURL; +// } +// +// public URL getInputCsvFileName() { +// return inputCsvFileName; +// } +// +// public String[] getInputItemNames() { +// return inputItemNames; +// } +// +// public Function<String[], String> getExpectedOutput() { +// return expectedOutput; +// } +// +// public Map<String, BiConsumer<Item, String>> getSpecialInputHandler() { +// return specialInputHandler; +// } +// +// public Supplier<Item> getOutputItemProvider() { +// return outputItemProvider; +// } +// +// public Function<Item, String> getStateOfOutputItem() { +// return stateOfOutputItem; +// } +// +// public MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget getFactoryTarget() { +// return factoryTarget; +// } +// +// public boolean isSingleUpdateList() { +// return singleUpdateList; +// } +} diff --git a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestUtils.java b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestUtils.java index f26ecbc1..da9f3680 100644 --- a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestUtils.java +++ b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTestUtils.java @@ -10,11 +10,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; -import java.net.URL; import java.util.*; -import java.util.function.BiConsumer; -import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.Stream; import static org.hamcrest.Matchers.greaterThan; @@ -59,43 +55,35 @@ public class LearnerTestUtils { } static void testLearner( - LearnerSubjectUnderTest sut, - URL inputCsvFileName, - String[] inputItemNames, - Function<String[], String> expectedOutput, - Map<String, BiConsumer<Item, String>> specialInputHandler, - Supplier<Item> outputItemProvider, - Function<Item, String> stateOfOutputItem, - MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget, - boolean singleUpdateList) { - sut.initFor(factoryTarget, inputCsvFileName); + LearnerSubjectUnderTest sut, LearnerTestSettings learnerTestSettings) throws IOException { + sut.initFor(learnerTestSettings.getFactoryTarget(), learnerTestSettings.getConfigURL()); // maybe use factory.createModel() here instead // go through same csv as for training and test some of the values int correct = 0, wrong = 0; - try(InputStream is = inputCsvFileName.openStream(); + try(InputStream is = learnerTestSettings.getDataURL().openStream(); Reader reader = new InputStreamReader(is); CSVReader csvreader = new CSVReader(reader)) { int index = 0; for (String[] line : csvreader) { if (++index % 10 == 0) { // only check every 10th line, push an update for every 12 input columns - List<Item> itemsToUpdate = new ArrayList<>(inputItemNames.length); - for (int i = 0; i < inputItemNames.length; i++) { - String itemName = inputItemNames[i]; + List<Item> itemsToUpdate = new ArrayList<>(learnerTestSettings.getInputItemNames().length); + for (int i = 0; i < learnerTestSettings.getInputItemNames().length; i++) { + String itemName = learnerTestSettings.getInputItemNames()[i]; Item item = sut.root.getSmartHomeEntityModel().resolveItem(itemName) .orElseThrow(() -> new AssertionError("Item " + itemName + " not found")); - if (specialInputHandler.containsKey(itemName)) { - specialInputHandler.get(itemName).accept(item, line[i]); + if (learnerTestSettings.getSpecialInputHandler().containsKey(itemName)) { + learnerTestSettings.getSpecialInputHandler().get(itemName).accept(item, line[i]); } else { item.setStateFromString(line[i]); } - if (singleUpdateList) { + if (learnerTestSettings.isSingleUpdateList()) { itemsToUpdate.add(item); } else { sut.encoder.newData(Collections.singletonList(item)); } } - if (singleUpdateList) { + if (learnerTestSettings.isSingleUpdateList()) { sut.encoder.newData(itemsToUpdate); } MachineLearningResult result = sut.decoder.classify(); @@ -103,11 +91,11 @@ public class LearnerTestUtils { assertEquals("Not one item update!", 1, result.getNumItemUpdate()); ItemUpdate update = result.getItemUpdate(0); // check that the output item is to be updated - assertEquals("Output item not to be updated!", outputItemProvider.get(), update.getItem()); + assertEquals("Output item not to be updated!", learnerTestSettings.getOutputItemProvider().get(), update.getItem()); update.apply(); // check if the correct new state was set - String expected = expectedOutput.apply(line); - String actual = stateOfOutputItem.apply(update.getItem()); + String expected = learnerTestSettings.getExpectedOutput().apply(line); + String actual = learnerTestSettings.getStateOfOutputItem().apply(update.getItem()); if (expected.equals(actual)) { correct++; } else { @@ -116,9 +104,6 @@ public class LearnerTestUtils { } } } - } catch (IOException e) { - e.printStackTrace(); - fail(); } assertThat(correct + wrong, greaterThan(0)); double accuracy = correct * 1.0 / (correct + wrong); diff --git a/feedbackloop.learner_backup/src/test/resources/activity_definition.json b/feedbackloop.learner_backup/src/test/resources/activity_definition.json new file mode 120000 index 00000000..c6b66b3d --- /dev/null +++ b/feedbackloop.learner_backup/src/test/resources/activity_definition.json @@ -0,0 +1 @@ +../../../src/main/resources/activity_definition.json \ No newline at end of file diff --git a/feedbackloop.learner_backup/src/test/resources/preference_definition.json b/feedbackloop.learner_backup/src/test/resources/preference_definition.json new file mode 120000 index 00000000..0f88f935 --- /dev/null +++ b/feedbackloop.learner_backup/src/test/resources/preference_definition.json @@ -0,0 +1 @@ +../../../src/main/resources/preference_definition.json \ No newline at end of file -- GitLab