Skip to content
Snippets Groups Projects
Commit 036dcf0b authored by boqiren's avatar boqiren
Browse files

first prototype learner backup

parent 67522fd8
No related branches found
No related tags found
No related merge requests found
/build/
logs/
// Dependency resolution via Maven Central only.
repositories {
mavenCentral()
}
// Produce Java 8 compatible bytecode.
sourceCompatibility = 1.8
apply plugin: 'java'
apply plugin: 'application'
dependencies {
// Sibling project with the shared ERASER base model.
compile project(':eraser-base')
compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.9.8'
// Log4j2 — configured via src/main/resources/log4j2.xml.
compile group: 'org.apache.logging.log4j', name: 'log4j-api', version: '2.11.2'
compile group: 'org.apache.logging.log4j', name: 'log4j-core', version: '2.11.2'
testCompile group: 'junit', name: 'junit', version: '4.12'
testCompile group: 'org.hamcrest', name: 'hamcrest-junit', version: '2.0.0.0'
// Encog machine-learning library used by the Learner class.
compile group: 'org.encog', name: 'encog-core', version: '3.4'
}
run {
// Entry point used by `gradle run`.
mainClassName = 'de.tudresden.inf.st.eraser.feedbackloop.learner_backup.Main'
standardInput = System.in
// Forward -PappArgs="['a','b']" to the application as program arguments.
if (project.hasProperty("appArgs")) {
args Eval.me(appArgs)
}
}
sourceSets {
main {
java {
srcDir 'src/main/java'
}
}
}
<?xml version="1.0" encoding="UTF-8"?>
<!-- Log4j2 configuration: colored console output plus a rolling log file under logs/. -->
<Configuration>
<Appenders>
<Console name="Console">
<PatternLayout pattern="%highlight{%d{HH:mm:ss.SSS} %-5level} %c{1.} - %msg%n"/>
</Console>
<!-- Rolls logs/eraser.log over on every application start, keeping at most 20 old files. -->
<RollingFile name="RollingFile" fileName="logs/eraser.log"
filePattern="logs/eraser-%i.log">
<PatternLayout pattern="%d{HH:mm:ss.SSS} %-5level %logger{36} - %msg%n"/>
<Policies>
<OnStartupTriggeringPolicy/>
</Policies>
<DefaultRolloverStrategy max="20"/>
</RollingFile>
</Appenders>
<Loggers>
<!-- Everything at debug level and above goes to both appenders. -->
<Root level="debug">
<AppenderRef ref="Console"/>
<AppenderRef ref="RollingFile"/>
</Root>
</Loggers>
</Configuration>
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;

import com.sun.javafx.tools.packager.Log;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.encog.ConsoleStatusReportable;
import org.encog.Encog;
import org.encog.bot.BotUtil;
import org.encog.ml.MLInput;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.sources.CSVDataSource;
import org.encog.ml.data.versatile.sources.VersatileDataSource;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
import org.encog.util.simple.EncogUtility;

import static org.encog.persist.EncogDirectoryPersistence.*;
/**
 * Trains and applies two chained Encog feed-forward networks:
 * an activity classifier (month/day/hour/minute -> activity label) and a
 * preference classifier (activity/brightness/time/minute -> preference label).
 * Trained models are persisted to and loaded from {@code datasets/backup/}.
 */
public class Learner {

    private static final Logger logger = LogManager.getLogger(Learner.class);

    /** CSVs re-read by {@link #predictor(String[])} to rebuild the normalization helpers. */
    private static final String ACTIVITY_DATA_CSV = "datasets/backup/activity_data.csv";
    private static final String PREFERENCE_DATA_CSV = "datasets/backup/preference_data.csv";

    // Persisted model locations (Encog .eg format).
    private String saveActivityModelFile = "datasets/backup/activity_model.eg";
    private String savePreferenceModelFile = "datasets/backup/preference_model.eg";

    // Shared state set by the analyse* methods and consumed by train/predict.
    private VersatileMLDataSet data;
    private EncogModel model;
    private NormalizationHelper activityHelper;
    private NormalizationHelper preferenceHelper;
    private MLRegression bestMethod;
    private String[] newData;
    private String activityResult;

    /**
     * Reads the activity CSV (no header row), defines its columns, selects a
     * feed-forward method and normalizes the data.
     * Side effects: reassigns {@link #data}, {@link #model} and {@link #activityHelper}.
     *
     * @param activityCsvUrl path to the activity training-data CSV
     */
    private void analyseActivityData(String activityCsvUrl) {
        VersatileDataSource source =
                new CSVDataSource(new File(activityCsvUrl), false, CSVFormat.DECIMAL_POINT);
        data = new VersatileMLDataSet(source);
        data.defineSourceColumn("monat", 0, ColumnType.continuous);
        data.defineSourceColumn("day", 1, ColumnType.continuous);
        data.defineSourceColumn("hour", 2, ColumnType.continuous);
        data.defineSourceColumn("minute", 3, ColumnType.continuous);
        // Column 4 holds the activity label; everything else is network input.
        ColumnDefinition outputColumn = data.defineSourceColumn("labels", 4, ColumnType.continuous);
        data.defineSingleOutputOthersInput(outputColumn);
        data.analyze();
        logger.debug("activity data analyzed");
        model = new EncogModel(data);
        model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
        data.normalize();
        activityHelper = data.getNormHelper();
        logger.debug("activity normalization: {}", activityHelper);
    }

    /**
     * Reads the preference CSV (no header row), defines its columns, selects a
     * feed-forward method and normalizes the data.
     * Side effects: reassigns {@link #data}, {@link #model} and {@link #preferenceHelper}.
     *
     * @param preferenceCsvUrl path to the preference training-data CSV
     */
    private void analysePreferenceData(String preferenceCsvUrl) {
        VersatileDataSource source =
                new CSVDataSource(new File(preferenceCsvUrl), false, CSVFormat.DECIMAL_POINT);
        data = new VersatileMLDataSet(source);
        data.defineSourceColumn("activity", 0, ColumnType.continuous);
        data.defineSourceColumn("brightness", 1, ColumnType.continuous);
        data.defineSourceColumn("time", 2, ColumnType.continuous);
        data.defineSourceColumn("minute", 3, ColumnType.continuous);
        // Column 4 holds the preference label; everything else is network input.
        ColumnDefinition outputColumn = data.defineSourceColumn("labels", 4, ColumnType.continuous);
        data.defineSingleOutputOthersInput(outputColumn);
        data.analyze();
        model = new EncogModel(data);
        model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
        data.normalize();
        preferenceHelper = data.getNormHelper();
        // Fixed: the original printed activityHelper here (copy-paste bug).
        logger.debug("preference normalization: {}", preferenceHelper);
    }

    /**
     * Trains both networks from the given CSVs and persists the results to disk.
     *
     * @param activityUrl   path to the activity training-data CSV
     * @param preferenceUrl path to the preference training-data CSV
     */
    void train(String activityUrl, String preferenceUrl) {
        trainActivityModel(activityUrl);
        logger.info("activity training finished");
        trainPreferenceModel(preferenceUrl);
        logger.info("preference training finished");
    }

    /** Trains and persists the activity network. */
    private void trainActivityModel(String activityCsvUrl) {
        analyseActivityData(activityCsvUrl);
        // 30% hold-back validation with a fixed seed (1001) for reproducibility.
        model.holdBackValidation(0.3, true, 1001);
        model.selectTrainingType(data);
        bestMethod = (MLRegression) model.crossvalidate(5, true);
        logger.debug("best activity method: {}", bestMethod);
        saveEncogModel(saveActivityModelFile);
    }

    /** Trains and persists the preference network. */
    private void trainPreferenceModel(String preferenceCsvUrl) {
        analysePreferenceData(preferenceCsvUrl);
        // Same validation setup as the activity model for consistency.
        model.holdBackValidation(0.3, true, 1001);
        model.selectTrainingType(data);
        bestMethod = (MLRegression) model.crossvalidate(5, true);
        logger.debug("best preference method: {}", bestMethod);
        saveEncogModel(savePreferenceModelFile);
    }

    /**
     * Predicts activity and preference for a new sample.
     * Expects 7 values: indices 0-3 feed the activity network; indices 4-6 feed the
     * preference network together with the predicted activity.
     *
     * @param newData raw (un-normalized) input values as strings
     * @return two-element array: {predicted activity, predicted preference}
     */
    String[] predictor(String[] newData) {
        this.newData = newData;
        // Re-analyze the training data only to rebuild the normalization helpers;
        // the trained networks themselves are loaded from disk below.
        analyseActivityData(ACTIVITY_DATA_CSV);
        analysePreferenceData(PREFERENCE_DATA_CSV);
        String[] result = new String[2];
        result[0] = predictActivity();
        result[1] = predictPreference();
        Encog.getInstance().shutdown();
        return result;
    }

    /** Loads the persisted activity network and classifies samples 0-3 of {@link #newData}. */
    private String predictActivity() {
        BasicNetwork network = (BasicNetwork) loadObject(new File(saveActivityModelFile));
        MLData input = activityHelper.allocateInputVector();
        // First four values: month, day, hour, minute.
        String[] sample = Arrays.copyOfRange(newData, 0, 4);
        activityHelper.normalizeInputVector(sample, input.getData(), false);
        MLData output = network.compute(input);
        logger.debug("activity input: {}", input);
        logger.debug("activity output: {}", output);
        activityResult = activityHelper.denormalizeOutputVectorToString(output)[0];
        logger.debug("predicted activity: {}", activityResult);
        return activityResult;
    }

    /**
     * Loads the persisted preference network and classifies the predicted activity
     * combined with samples 4-6 of {@link #newData}. Must run after
     * {@link #predictActivity()} so {@link #activityResult} is set.
     */
    private String predictPreference() {
        BasicNetwork network = (BasicNetwork) loadObject(new File(savePreferenceModelFile));
        MLData input = preferenceHelper.allocateInputVector();
        String[] sample = {activityResult, newData[4], newData[5], newData[6]};
        preferenceHelper.normalizeInputVector(sample, input.getData(), false);
        MLData output = network.compute(input);
        return preferenceHelper.denormalizeOutputVectorToString(output)[0];
    }

    /** Persists the currently best trained method to the given file. */
    private void saveEncogModel(String modelFileUrl) {
        saveObject(new File(modelFileUrl), this.bestMethod);
    }
}
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.jastadd.model.*;
/**
 * Demo entry point: feeds one hand-crafted sample through the persisted
 * activity and preference models and prints both predictions.
 */
public class Main {
    public static void main(String[] args) {
        // Sample stand-in for data from the knowledge base:
        // indices 0-3 = month, day, hour, minute (activity model input);
        // indices 4-6 = activity, brightness, time (preference model input).
        String[] newData = {"7", "20", "12", "13", "7", "25", "12"};

        Learner learner = new Learner();
        // Uncomment to retrain before predicting; predictor() loads the saved models.
        //learner.train("datasets/activity_data.csv", "datasets/preference_data.csv");
        String[] result = learner.predictor(newData);
        System.out.println("activity is:" + result[0]);
        // Fixed typo in the user-facing message ("perference" -> "preference").
        System.out.println("preference is: " + result[1]);
    }
}
<?xml version="1.0" encoding="UTF-8"?>
<!-- Log4j2 configuration: colored console output plus a rolling log file under logs/. -->
<Configuration>
<Appenders>
<Console name="Console">
<PatternLayout pattern="%highlight{%d{HH:mm:ss.SSS} %-5level} %c{1.} - %msg%n"/>
</Console>
<!-- Rolls logs/eraser.log over on every application start, keeping at most 20 old files. -->
<RollingFile name="RollingFile" fileName="logs/eraser.log"
filePattern="logs/eraser-%i.log">
<PatternLayout pattern="%d{HH:mm:ss.SSS} %-5level %logger{36} - %msg%n"/>
<Policies>
<OnStartupTriggeringPolicy/>
</Policies>
<DefaultRolloverStrategy max="20"/>
</RollingFile>
</Appenders>
<Loggers>
<!-- Everything at debug level and above goes to both appenders. -->
<Root level="debug">
<AppenderRef ref="Console"/>
<AppenderRef ref="RollingFile"/>
</Root>
</Loggers>
</Configuration>
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.jastadd.model.*;
import org.junit.Test;
import java.util.Set;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.fail;
/**
* TODO: Add description.
*
* @author rschoene - Initial contribution
*/
public class ATest {
@Test
public void test1() {
// Placeholder: fails unconditionally so the absence of a real test suite is
// visible in CI. TODO(review): replace with actual assertions for Learner.
fail();
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment