From d1aa00b1499c0673d44b01722faaf4c5f3c14bfc Mon Sep 17 00:00:00 2001 From: rschoene <rene.schoene@tu-dresden.de> Date: Tue, 30 Apr 2019 12:32:40 +0200 Subject: [PATCH] Move LearnerHelper to feedbackloop.learner module --- .../de/tudresden/inf/st/eraser/starter/EraserStarter.java | 1 + .../st/eraser/feedbackloop/learner}/LearnerHelper.java | 8 ++++---- .../inf/st/eraser/feedbackloop/learner/Main.java | 8 ++------ 3 files changed, 7 insertions(+), 10 deletions(-) rename {eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter => feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner}/LearnerHelper.java (96%) diff --git a/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java b/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java index 2e5f36c4..b207dc5f 100644 --- a/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java +++ b/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/EraserStarter.java @@ -10,6 +10,7 @@ import de.tudresden.inf.st.eraser.feedbackloop.api.Learner; import de.tudresden.inf.st.eraser.feedbackloop.api.Plan; import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model; import de.tudresden.inf.st.eraser.feedbackloop.execute.ExecuteImpl; +import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerHelper; import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl; import de.tudresden.inf.st.eraser.feedbackloop.plan.PlanImpl; import de.tudresden.inf.st.eraser.jastadd.model.DummyMachineLearningModel; diff --git a/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/LearnerHelper.java b/feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner/LearnerHelper.java similarity index 96% rename from eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/LearnerHelper.java rename to feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner/LearnerHelper.java index 0c180ca7..c33928e0 100644 --- a/eraser.starter/src/main/java/de/tudresden/inf/st/eraser/starter/LearnerHelper.java +++ b/feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner/LearnerHelper.java @@ -1,4 +1,4 @@ -package de.tudresden.inf.st.eraser.starter; +package de.tudresden.inf.st.eraser.feedbackloop.learner; import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model; import de.tudresden.inf.st.eraser.jastadd.model.*; @@ -17,11 +17,11 @@ import java.util.List; import java.util.stream.Collectors; /** - * Transformation of a {@link Model} into a {@link MachineLearningModel}. + * Transformation of a {@link Model} into a {@link NeuralNetworkRoot}. * * @author rschoene - Initial contribution */ -class LearnerHelper { +public class LearnerHelper { private static final Logger logger = LogManager.getLogger(LearnerHelper.class); @@ -30,7 +30,7 @@ class LearnerHelper { private static DoubleArrayDoubleFunction tanh = inputs -> Math.tanh(Arrays.stream(inputs).sum()); private static DoubleArrayDoubleFunction function_one = inputs -> 1.0; - static NeuralNetworkRoot transform(Model model) { + public static NeuralNetworkRoot transform(Model model) { NeuralNetworkRoot result = NeuralNetworkRoot.createEmpty(); List<Double> weights = model.getWeights(); logger.debug("Got {} weights", weights.size()); diff --git a/feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner/Main.java b/feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner/Main.java index 2f91a3a3..19765c1e 100644 --- a/feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner/Main.java +++ b/feedbackloop.learner/src/main/java/de/tudresden/inf/st/eraser/feedbackloop/learner/Main.java @@ -1,24 +1,19 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner; import de.tudresden.inf.st.eraser.feedbackloop.api.Learner; -import de.tudresden.inf.st.eraser.feedbackloop.api.model.*; +import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model; import de.tudresden.inf.st.eraser.jastadd.model.*; import org.apache.commons.math3.stat.StatUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.encog.ml.data.MLData; -import org.encog.ml.data.versatile.NormalizationHelper; import java.io.File; -import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.concurrent.TimeUnit; import java.util.function.Function; -import java.util.stream.Collectors; @SuppressWarnings("unused") public class Main { @@ -70,6 +65,7 @@ public class Main { InitialDataConfig.inputMaxes, InitialDataConfig.inputMins, InitialDataConfig.targetMaxes, InitialDataConfig.targetMins); printModel(learner.getTrainedModel(1)); + NeuralNetworkRoot eraserModel = LearnerHelper.transform(learner.getTrainedModel(1)); } private static void printModel(Model model) { -- GitLab