Skip to content
Snippets Groups Projects
Commit d7cc199b authored by boqiren's avatar boqiren
Browse files

old version

parent 2bf72521
Branches
No related tags found
No related merge requests found
<?xml version="1.0" encoding="UTF-8"?>
<module type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
package de.tudresden.inf.st.eraser.starter;
import beaver.Parser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import de.tudresden.inf.st.eraser.feedbackloop.analyze.AnalyzeImpl;
import de.tudresden.inf.st.eraser.feedbackloop.api.Analyze;
import de.tudresden.inf.st.eraser.feedbackloop.api.Execute;
import de.tudresden.inf.st.eraser.feedbackloop.api.Learner;
import de.tudresden.inf.st.eraser.feedbackloop.api.Plan;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.feedbackloop.execute.ExecuteImpl;
import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl;
import de.tudresden.inf.st.eraser.feedbackloop.plan.PlanImpl;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.openhab2.OpenHab2Importer;
import de.tudresden.inf.st.eraser.spark.Application;
import de.tudresden.inf.st.eraser.util.JavaUtils;
import de.tudresden.inf.st.eraser.util.ParserUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
* This Starter combines and starts all modules. This includes:
*
* <ul>
* <li>Knowledge-Base in <code>eraser-base</code></li>
* <li>Feedback loop in <code>feedbackloop.{analyze,plan,execute}</code></li>
* <li>REST-API in <code>eraser-rest</code></li>
* </ul>
*/
public class EraserStarter {
private static final Logger logger = LogManager.getLogger(EraserStarter.class);
@SuppressWarnings("ResultOfMethodCallIgnored")
public static void main(String[] args) {
logger.info("Starting ERASER");
ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
File settingsFile = new File("starter-setting.yaml");
Setting settings;
try {
settings = mapper.readValue(settingsFile, Setting.class);
} catch (Exception e) {
logger.fatal("Could not read settings at '{}'. Exiting.", settingsFile.getAbsolutePath());
logger.catching(e);
System.exit(1);
return;
}
boolean startRest = settings.rest.use;
Root model;
switch (settings.initModelWith) {
case openhab:
OpenHab2Importer importer = new OpenHab2Importer();
try {
model = importer.importFrom(new URL(settings.openhab.url));
} catch (MalformedURLException e) {
logger.error("Could not parse URL {}", settings.openhab.url);
logger.catching(e);
System.exit(1);
return;
}
logger.info("Imported model {}", model.description());
break;
case load:
default:
try {
model = ParserUtils.load(settings.load.file, EraserStarter.class);
} catch (IOException | Parser.Exception e) {
logger.error("Problems parsing the given file {}", settings.load.file);
logger.catching(e);
System.exit(1);
return;
}
}
// initialize activity recognition
if (settings.activity.dummy) {
logger.info("Using dummy activity recognition");
model.getMachineLearningRoot().setActivityRecognition(DummyMachineLearningModel.createDefault());
} else {
logger.error("Reading activity recognition from file is not supported yet!");
// TODO
}
// initialize preference learning
if (settings.preference.dummy) {
logger.info("Using dummy preference learning");
model.getMachineLearningRoot().setPreferenceLearning(DummyMachineLearningModel.createDefault());
} else {
logger.info("Reading preference learning from file {}", settings.preference.file);
Learner learner = new LearnerImpl();
// there should be a method to load a model using an URL
Model preference = learner.getTrainedModel(settings.preference.realURL(), settings.preference.id);
NeuralNetworkRoot neuralNetwork = LearnerHelper.transform(preference);
if (neuralNetwork == null) {
logger.error("Could not create preference model, see possible previous errors.");
} else {
model.getMachineLearningRoot().setPreferenceLearning(neuralNetwork);
neuralNetwork.connectItems(settings.preference.items);
neuralNetwork.setOutputApplication(zeroToThree -> 25 * zeroToThree);
JavaUtils.ifPresentOrElse(
model.resolveItem(settings.preference.affectedItem),
item -> neuralNetwork.getOutputLayer().setAffectedItem(item),
() -> logger.error("Output item not set from value '{}'", settings.preference.affectedItem));
}
}
if (!model.getMachineLearningRoot().getActivityRecognition().check()) {
logger.fatal("Invalid activity recognition!");
System.exit(1);
}
if (!model.getMachineLearningRoot().getPreferenceLearning().check()) {
logger.fatal("Invalid preference learning!");
System.exit(1);
}
Analyze analyze = null;
if (settings.useMAPE) {
// configure and start mape loop
logger.info("Starting MAPE loop");
analyze = new AnalyzeImpl();
Plan plan = new PlanImpl();
Execute execute = new ExecuteImpl();
analyze.setPlan(plan);
plan.setExecute(execute);
analyze.setKnowledgeBase(model);
plan.setKnowledgeBase(model);
execute.setKnowledgeBase(model);
analyze.startAsThread(1, TimeUnit.SECONDS);
if (!startRest) {
// alternative exit condition
System.out.println("Hit [Enter] to exit");
try {
System.in.read();
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("Stopping...");
analyze.stop();
}
} else {
logger.info("No MAPE loop this time");
}
if (startRest) {
// start REST-API in new thread
logger.info("Starting REST server");
Lock lock = new ReentrantLock();
Condition quitCondition = lock.newCondition();
Thread t = new Thread(new ThreadGroup("REST-API"),
() -> Application.start(settings.rest.port, model, settings.rest.createDummyMLData, lock, quitCondition));
t.setDaemon(true);
t.start();
logger.info("Waiting until request is send to '/system/exit'");
try {
lock.lock();
if (t.isAlive()) {
quitCondition.await();
}
} catch (InterruptedException e) {
logger.warn("Waiting was interrupted");
} finally {
lock.unlock();
}
if (analyze != null) {
analyze.stop();
}
} else {
logger.info("No REST server this time");
}
logger.info("I'm done here.");
}
}
package de.tudresden.inf.st.eraser.starter;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.Arrays;
/**
* Transformation of a {@link Model} into a {@link MachineLearningModel}.
*
* @author rschoene - Initial contribution
*/
class LearnerHelper {
private static final Logger logger = LogManager.getLogger(LearnerHelper.class);
// Activation Functions
private static DoubleArrayDoubleFunction sigmoid = inputs -> Math.signum(Arrays.stream(inputs).sum());
private static DoubleArrayDoubleFunction tanh = inputs -> Math.tanh(Arrays.stream(inputs).sum());
private static DoubleArrayDoubleFunction function_one = inputs -> 1.0;
static NeuralNetworkRoot transform(Model model) {
NeuralNetworkRoot result = NeuralNetworkRoot.createEmpty();
ArrayList<Double> weights = model.getWeights();
// inputs
int inputSum = model.getInputLayerNumber() + model.getInputBias();
for (int i = 0; i < inputSum; ++i) {
InputNeuron inputNeuron = new InputNeuron();
result.addInputNeuron(inputNeuron);
}
InputNeuron bias = result.getInputNeuron(model.getInputBias());
OutputLayer outputLayer = new OutputLayer();
// output layer
for (int i = 0; i < model.getOutputLayerNumber(); ++i) {
OutputNeuron outputNeuron = new OutputNeuron();
setActivationFunction(outputNeuron, model.getOutputActivationFunction());
outputLayer.addOutputNeuron(outputNeuron);
}
result.setOutputLayer(outputLayer);
// hidden layer
int hiddenSum = model.gethiddenLayerNumber() + model.getHiddenBias();
HiddenNeuron[] hiddenNeurons = new HiddenNeuron[hiddenSum];
for (int i = 0; i < (hiddenNeurons.length); i++) {
if (i == model.gethiddenLayerNumber()) {
HiddenNeuron hiddenNeuron = new HiddenNeuron();
hiddenNeuron.setActivationFormula(function_one);
hiddenNeurons[i] = hiddenNeuron;
result.addHiddenNeuron(hiddenNeuron);
bias.connectTo(hiddenNeuron, 1.0);
for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) {
hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out));
}
} else {
HiddenNeuron hiddenNeuron = new HiddenNeuron();
setActivationFunction(hiddenNeuron, model.getHiddenActivationFunction());
hiddenNeurons[i] = hiddenNeuron;
result.addHiddenNeuron(hiddenNeuron);
for (int in = 0; in < inputSum; in++) {
// TODO replace 4 and 5 with model-attributes
result.getInputNeuron(in).connectTo(hiddenNeuron, weights.get((hiddenNeurons.length * 4 + in) + i * 5));
}
for (int out = 0; out < outputLayer.getNumOutputNeuron(); out++) {
hiddenNeuron.connectTo(outputLayer.getOutputNeuron(out), weights.get(i + hiddenSum * out));
}
}
}
outputLayer.setCombinator(LearnerHelper::predictor);
logger.info("Created model with {} input, {} hidden and {} output neurons",
result.getNumInputNeuron(), result.getNumHiddenNeuron(), result.getOutputLayer().getNumOutputNeuron());
return result;
}
private static void setActivationFunction(HiddenNeuron neuron, String functionName) {
switch (functionName) {
case "ActivationTANH": neuron.setActivationFormula(tanh); break;
case "ActivationLinear": neuron.setActivationFormula(function_one);
case "ActivationSigmoid": neuron.setActivationFormula(sigmoid); break;
default: throw new IllegalArgumentException("Unknown function " + functionName);
}
}
private static double predictor(double[] inputs) {
int index = 0;
double maxInput = StatUtils.max(inputs);
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] == maxInput) {
index = i;
}
}
//outputs from learner
final double[] outputs = new double[]{2.0, 1.0, 3.0, 0.0};
return outputs[index];
}
}
package de.tudresden.inf.st.eraser.starter;
import java.net.URL;
import java.util.List;
/**
* Setting bean.
*
* @author rschoene - Initial contribution
*/
@SuppressWarnings("WeakerAccess")
class Setting {
public class Rest {
public boolean use = true;
public int port = 4567;
public boolean createDummyMLData = false;
}
public class FileContainer {
public String file;
URL realURL() {
return Setting.class.getClassLoader().getResource(file);
}
}
public class MLContainer extends FileContainer {
public boolean dummy = false;
public int id = 1;
public List<String> items;
public String affectedItem;
}
public class OpenHabContainer {
public String url;
}
public enum InitModelWith {
load, openhab
}
public Rest rest;
public boolean useMAPE = true;
public FileContainer load;
public MLContainer activity;
public MLContainer preference;
public OpenHabContainer openhab;
public InitModelWith initModelWith = InitModelWith.load;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment