diff --git a/eraser.starter/src/main/test.iml b/eraser.starter/src/main/test.iml new file mode 100644 index 0000000000000000000000000000000000000000..c90834f2d607afe55e6104d8aa2cdfffb713f688 --- /dev/null +++ b/eraser.starter/src/main/test.iml @@ -0,0 +1,11 @@ +<?xml version="1.0" encoding="UTF-8"?> +<module type="JAVA_MODULE" version="4"> + <component name="NewModuleRootManager" inherit-compiler-output="true"> + <exclude-output /> + <content url="file://$MODULE_DIR$"> + <sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" /> + </content> + <orderEntry type="inheritedJdk" /> + <orderEntry type="sourceFolder" forTests="false" /> + </component> +</module> \ No newline at end of file diff --git a/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/EraserStarterTest.java b/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/EraserStarterTest.java new file mode 100644 index 0000000000000000000000000000000000000000..5ba76bb5ddd0c36dddbdd7dd1e37eb8bcc03d1cc --- /dev/null +++ b/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/EraserStarterTest.java @@ -0,0 +1,187 @@ +package de.tudresden.inf.st.eraser.starter; + +import beaver.Parser; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import de.tudresden.inf.st.eraser.feedbackloop.analyze.AnalyzeImpl; +import de.tudresden.inf.st.eraser.feedbackloop.api.Analyze; +import de.tudresden.inf.st.eraser.feedbackloop.api.Execute; +import de.tudresden.inf.st.eraser.feedbackloop.api.Learner; +import de.tudresden.inf.st.eraser.feedbackloop.api.Plan; +import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model; +import de.tudresden.inf.st.eraser.feedbackloop.execute.ExecuteImpl; +import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl; +import de.tudresden.inf.st.eraser.feedbackloop.plan.PlanImpl; +import de.tudresden.inf.st.eraser.jastadd.model.*; +import de.tudresden.inf.st.eraser.openhab2.OpenHab2Importer; +import de.tudresden.inf.st.eraser.spark.Application; +import de.tudresden.inf.st.eraser.util.JavaUtils; +import de.tudresden.inf.st.eraser.util.ParserUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.File; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +/** + * This Starter combines and starts all modules. This includes: + * + * <ul> + * <li>Knowledge-Base in <code>eraser-base</code></li> + * <li>Feedback loop in <code>feedbackloop.{analyze,plan,execute}</code></li> + * <li>REST-API in <code>eraser-rest</code></li> + * </ul> + */ +public class EraserStarter { + + private static final Logger logger = LogManager.getLogger(EraserStarter.class); + + @SuppressWarnings("ResultOfMethodCallIgnored") + public static void main(String[] args) { + logger.info("Starting ERASER"); + ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + File settingsFile = new File("starter-setting.yaml"); + Setting settings; + try { + settings = mapper.readValue(settingsFile, Setting.class); + } catch (Exception e) { + logger.fatal("Could not read settings at '{}'. Exiting.", settingsFile.getAbsolutePath()); + logger.catching(e); + System.exit(1); + return; + } + boolean startRest = settings.rest.use; + + Root model; + switch (settings.initModelWith) { + case openhab: + OpenHab2Importer importer = new OpenHab2Importer(); + try { + model = importer.importFrom(new URL(settings.openhab.url)); + } catch (MalformedURLException e) { + logger.error("Could not parse URL {}", settings.openhab.url); + logger.catching(e); + System.exit(1); + return; + } + logger.info("Imported model {}", model.description()); + break; + case load: + default: + try { + model = ParserUtils.load(settings.load.file, EraserStarter.class); + } catch (IOException | Parser.Exception e) { + logger.error("Problems parsing the given file {}", settings.load.file); + logger.catching(e); + System.exit(1); + return; + } + } + + // initialize activity recognition + if (settings.activity.dummy) { + logger.info("Using dummy activity recognition"); + model.getMachineLearningRoot().setActivityRecognition(DummyMachineLearningModel.createDefault()); + } else { + logger.error("Reading activity recognition from file is not supported yet!"); + // TODO + } + + // initialize preference learning + if (settings.preference.dummy) { + logger.info("Using dummy preference learning"); + model.getMachineLearningRoot().setPreferenceLearning(DummyMachineLearningModel.createDefault()); + } else { + logger.info("Reading preference learning from file {}", settings.preference.file); + Learner learner = new LearnerImpl(); + // there should be a method to load a model using an URL + Model preference = learner.getTrainedModel(settings.preference.realURL(), settings.preference.id); + NeuralNetworkRoot neuralNetwork = LearnerHelper.transform(preference); + if (neuralNetwork == null) { + logger.error("Could not create preference model, see possible previous errors."); + } else { + model.getMachineLearningRoot().setPreferenceLearning(neuralNetwork); + neuralNetwork.connectItems(settings.preference.items); + neuralNetwork.setOutputApplication(zeroToThree -> 25 * zeroToThree); + JavaUtils.ifPresentOrElse( + model.resolveItem(settings.preference.affectedItem), + item -> neuralNetwork.getOutputLayer().setAffectedItem(item), + () -> logger.error("Output item not set from value '{}'", settings.preference.affectedItem)); + } + } + if (!model.getMachineLearningRoot().getActivityRecognition().check()) { + logger.fatal("Invalid activity recognition!"); + System.exit(1); + } + if (!model.getMachineLearningRoot().getPreferenceLearning().check()) { + logger.fatal("Invalid preference learning!"); + System.exit(1); + } + + Analyze analyze = null; + if (settings.useMAPE) { + // configure and start mape loop + logger.info("Starting MAPE loop"); + analyze = new AnalyzeImpl(); + Plan plan = new PlanImpl(); + Execute execute = new ExecuteImpl(); + + analyze.setPlan(plan); + plan.setExecute(execute); + + analyze.setKnowledgeBase(model); + plan.setKnowledgeBase(model); + execute.setKnowledgeBase(model); + + analyze.startAsThread(1, TimeUnit.SECONDS); + if (!startRest) { + // alternative exit condition + System.out.println("Hit [Enter] to exit"); + try { + System.in.read(); + } catch (IOException e) { + e.printStackTrace(); + } + + System.out.println("Stopping..."); + analyze.stop(); + } + } else { + logger.info("No MAPE loop this time"); + } + + if (startRest) { + // start REST-API in new thread + logger.info("Starting REST server"); + Lock lock = new ReentrantLock(); + Condition quitCondition = lock.newCondition(); + Thread t = new Thread(new ThreadGroup("REST-API"), + () -> Application.start(settings.rest.port, model, settings.rest.createDummyMLData, lock, quitCondition)); + t.setDaemon(true); + t.start(); + logger.info("Waiting until request is send to '/system/exit'"); + try { + lock.lock(); + if (t.isAlive()) { + quitCondition.await(); + } + } catch (InterruptedException e) { + logger.warn("Waiting was interrupted"); + } finally { + lock.unlock(); + } + if (analyze != null) { + analyze.stop(); + } + } else { + logger.info("No REST server this time"); + } + logger.info("I'm done here."); + } +} diff --git a/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/LearnerHelperTest.java b/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/LearnerHelperTest.java new file mode 100644 index 0000000000000000000000000000000000000000..34404713f0e3abc32cf87ebca43db783c9dc3288 --- /dev/null +++ b/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/LearnerHelperTest.java @@ -0,0 +1,102 @@ +package de.tudresden.inf.st.eraser.starter; + +import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model; +import de.tudresden.inf.st.eraser.jastadd.model.*; +import org.apache.commons.math3.stat.StatUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.ArrayList; +import java.util.Arrays; + +/** + * Transformation of a {@link Model} into a {@link MachineLearningModel}. + * + * @author rschoene - Initial contribution + */ +class LearnerHelper { + + private static final Logger logger = LogManager.getLogger(LearnerHelper.class); + + // Activation Functions + private static DoubleArrayDoubleFunction sigmoid = inputs -> Math.signum(Arrays.stream(inputs).sum()); + private static DoubleArrayDoubleFunction tanh = inputs -> Math.tanh(Arrays.stream(inputs).sum()); + private static DoubleArrayDoubleFunction function_one = inputs -> 1.0; + + static NeuralNetworkRoot transform(Model model) { + NeuralNetworkRoot result = NeuralNetworkRoot.createEmpty(); + ArrayList<Double> weights = model.getWeights(); + + // inputs + int inputSum = model.getInputLayerNumber() + model.getInputBias(); + for (int i = 0; i < inputSum; ++i) { + InputNeuron inputNeuron = new InputNeuron(); + result.addInputNeuron(inputNeuron); + } + InputNeuron bias = result.getInputNeuron(model.getInputBias()); + + OutputLayer outputLayer = new OutputLayer(); + // output layer + for (int i = 0; i < model.getOutputLayerNumber(); ++i) { + OutputNeuron outputNeuron = new OutputNeuron(); + setActivationFunction(outputNeuron, model.getOutputActivationFunction()); + outputLayer.addOutputNeuron(outputNeuron); + } + result.setOutputLayer(outputLayer); + + // hidden layer + int hiddenSum = model.gethiddenLayerNumber() + model.getHiddenBias(); + HiddenNeuron[] hiddenNeurons = new HiddenNeuron[hiddenSum]; + for (int i = 0; i < (hiddenNeurons.length); i++) { + if (i == model.gethiddenLayerNumber()) { + HiddenNeuron hiddenNeuron = new HiddenNeuron(); + hiddenNeuron.setActivationFormula(function_one); + hiddenNeurons[i] = hiddenNeuron; + result.addHiddenNeuron(hiddenNeuron); + bias.connectTo(hiddenNeuron, 1.0); + for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) { + hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out)); + } + } else { + HiddenNeuron hiddenNeuron = new HiddenNeuron(); + setActivationFunction(hiddenNeuron, model.getHiddenActivationFunction()); + hiddenNeurons[i] = hiddenNeuron; + result.addHiddenNeuron(hiddenNeuron); + for (int in = 0; in < inputSum; in++) { + // TODO replace 4 and 5 with model-attributes + result.getInputNeuron(in).connectTo(hiddenNeuron, weights.get((hiddenNeurons.length * 4 + in) + i * 5)); + } + for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) { + hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out)); + } + } + } + outputLayer.setCombinator(LearnerHelper::predictor); + logger.info("Created model with {} input, {} hidden and {} output neurons", + result.getNumInputNeuron(), result.getNumHiddenNeuron(), result.getOutputLayer().getNumOutputNeuron()); + return result; + } + + private static void setActivationFunction(HiddenNeuron neuron, String functionName) { + switch (functionName) { + case "ActivationTANH": neuron.setActivationFormula(tanh); break; + case "ActivationLinear": neuron.setActivationFormula(function_one); + case "ActivationSigmoid": neuron.setActivationFormula(sigmoid); break; + default: throw new IllegalArgumentException("Unknown function " + functionName); + } + } + + private static double predictor(double[] inputs) { + int index = 0; + double maxInput = StatUtils.max(inputs); + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] == maxInput) { + index = i; + } + } + //outputs from learner + final double[] outputs = new double[]{2.0, 1.0, 3.0, 0.0}; + return outputs[index]; + } + +} diff --git a/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/SettingTest.java b/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/SettingTest.java new file mode 100644 index 0000000000000000000000000000000000000000..aa11fd9b7b1caf0b749e47996c1cf1885b8525d1 --- /dev/null +++ b/eraser.starter/src/test/java/de/tudresden/inf/st/eraser/starter/SettingTest.java @@ -0,0 +1,43 @@ +package de.tudresden.inf.st.eraser.starter; + +import java.net.URL; +import java.util.List; + +/** + * Setting bean. + * + * @author rschoene - Initial contribution + */ +@SuppressWarnings("WeakerAccess") +class Setting { + public class Rest { + public boolean use = true; + public int port = 4567; + public boolean createDummyMLData = false; + } + public class FileContainer { + public String file; + URL realURL() { + return Setting.class.getClassLoader().getResource(file); + } + } + public class MLContainer extends FileContainer { + public boolean dummy = false; + public int id = 1; + public List<String> items; + public String affectedItem; + } + public class OpenHabContainer { + public String url; + } + public enum InitModelWith { + load, openhab + } + public Rest rest; + public boolean useMAPE = true; + public FileContainer load; + public MLContainer activity; + public MLContainer preference; + public OpenHabContainer openhab; + public InitModelWith initModelWith = InitModelWith.load; +}