Skip to content
Snippets Groups Projects
Commit d1aa00b1 authored by René Schöne's avatar René Schöne
Browse files

Move LearnerHelper to feedbackloop.learner module

parent 235b969c
No related branches found
No related tags found
No related merge requests found
...@@ -10,6 +10,7 @@ import de.tudresden.inf.st.eraser.feedbackloop.api.Learner; ...@@ -10,6 +10,7 @@ import de.tudresden.inf.st.eraser.feedbackloop.api.Learner;
import de.tudresden.inf.st.eraser.feedbackloop.api.Plan; import de.tudresden.inf.st.eraser.feedbackloop.api.Plan;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model; import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.feedbackloop.execute.ExecuteImpl; import de.tudresden.inf.st.eraser.feedbackloop.execute.ExecuteImpl;
import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerHelper;
import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl; import de.tudresden.inf.st.eraser.feedbackloop.learner.LearnerImpl;
import de.tudresden.inf.st.eraser.feedbackloop.plan.PlanImpl; import de.tudresden.inf.st.eraser.feedbackloop.plan.PlanImpl;
import de.tudresden.inf.st.eraser.jastadd.model.DummyMachineLearningModel; import de.tudresden.inf.st.eraser.jastadd.model.DummyMachineLearningModel;
......
package de.tudresden.inf.st.eraser.starter; package de.tudresden.inf.st.eraser.feedbackloop.learner;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model; import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.jastadd.model.*; import de.tudresden.inf.st.eraser.jastadd.model.*;
...@@ -17,11 +17,11 @@ import java.util.List; ...@@ -17,11 +17,11 @@ import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
* Transformation of a {@link Model} into a {@link MachineLearningModel}. * Transformation of a {@link Model} into a {@link NeuralNetworkRoot}.
* *
* @author rschoene - Initial contribution * @author rschoene - Initial contribution
*/ */
class LearnerHelper { public class LearnerHelper {
private static final Logger logger = LogManager.getLogger(LearnerHelper.class); private static final Logger logger = LogManager.getLogger(LearnerHelper.class);
...@@ -30,7 +30,7 @@ class LearnerHelper { ...@@ -30,7 +30,7 @@ class LearnerHelper {
private static DoubleArrayDoubleFunction tanh = inputs -> Math.tanh(Arrays.stream(inputs).sum()); private static DoubleArrayDoubleFunction tanh = inputs -> Math.tanh(Arrays.stream(inputs).sum());
private static DoubleArrayDoubleFunction function_one = inputs -> 1.0; private static DoubleArrayDoubleFunction function_one = inputs -> 1.0;
static NeuralNetworkRoot transform(Model model) { public static NeuralNetworkRoot transform(Model model) {
NeuralNetworkRoot result = NeuralNetworkRoot.createEmpty(); NeuralNetworkRoot result = NeuralNetworkRoot.createEmpty();
List<Double> weights = model.getWeights(); List<Double> weights = model.getWeights();
logger.debug("Got {} weights", weights.size()); logger.debug("Got {} weights", weights.size());
......
package de.tudresden.inf.st.eraser.feedbackloop.learner; package de.tudresden.inf.st.eraser.feedbackloop.learner;
import de.tudresden.inf.st.eraser.feedbackloop.api.Learner; import de.tudresden.inf.st.eraser.feedbackloop.api.Learner;
import de.tudresden.inf.st.eraser.feedbackloop.api.model.*; import de.tudresden.inf.st.eraser.feedbackloop.api.model.Model;
import de.tudresden.inf.st.eraser.jastadd.model.*; import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.apache.commons.math3.stat.StatUtils; import org.apache.commons.math3.stat.StatUtils;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors;
@SuppressWarnings("unused") @SuppressWarnings("unused")
public class Main { public class Main {
...@@ -70,6 +65,7 @@ public class Main { ...@@ -70,6 +65,7 @@ public class Main {
InitialDataConfig.inputMaxes, InitialDataConfig.inputMins, InitialDataConfig.inputMaxes, InitialDataConfig.inputMins,
InitialDataConfig.targetMaxes, InitialDataConfig.targetMins); InitialDataConfig.targetMaxes, InitialDataConfig.targetMins);
printModel(learner.getTrainedModel(1)); printModel(learner.getTrainedModel(1));
NeuralNetworkRoot eraserModel = LearnerHelper.transform(learner.getTrainedModel(1));
} }
private static void printModel(Model model) { private static void printModel(Model model) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment