From bcd27db3c7827078f3818d08561bb26d1ab52c5e Mon Sep 17 00:00:00 2001 From: rschoene <rene.schoene@tu-dresden.de> Date: Fri, 25 Oct 2019 10:31:19 +0200 Subject: [PATCH] Make LearnerTest succeed for minimum accuracy. - Remove some logging of Learner --- .../feedbackloop.learner_backup/Learner.java | 4 +-- .../learner_backup/LearnerTest.java | 30 +++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java index c12c6091..3b77d9ff 100644 --- a/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java +++ b/feedbackloop.learner_backup/src/main/java/de/tudresden/inf/st/eraser/feedbackloop.learner_backup/Learner.java @@ -149,7 +149,7 @@ public class Learner { } String activity_predictor(String[] new_data) { - logger.info("Activity predicting ... ..."); +// logger.info("Activity predicting ... ..."); String activity_result; activityDataAnalyser("../datasets/backup/activity_data.csv"); BasicNetwork activity_method = (BasicNetwork) loadObject(save_activity_model_file); @@ -163,7 +163,7 @@ public class Learner { } String[] preference_predictor(String[] new_data) { - logger.info("Activity predicting ... ..."); +// logger.info("Preference predicting ... ..."); String[] preference_result; preference_result = new String[2]; preferenceDataAnalyser("../datasets/backup/preference_data.csv"); diff --git a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java index 9396fd1a..896f402c 100644 --- a/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java +++ b/feedbackloop.learner_backup/src/test/java/de/tudresden/inf/st/eraser/feedbackloop/learner_backup/LearnerTest.java @@ -3,19 +3,23 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; import com.opencsv.CSVReader; import de.tudresden.inf.st.eraser.jastadd.model.*; import de.tudresden.inf.st.eraser.util.ParserUtils; -import org.junit.Assert; +import org.apache.logging.log4j.LogManager; +import org.hamcrest.Matchers; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; +import java.net.MalformedURLException; import java.net.URL; +import java.nio.file.Paths; import java.util.Collections; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.*; /** * Testing the learner. @@ -24,7 +28,9 @@ import static org.junit.Assert.fail; */ public class LearnerTest { - private static final URL ACTIVITY_DATA = LearnerTest.class.getResource("/activity_data.csv"); + private static final double MIN_ACCURACY = 0.95; + + private static URL ACTIVITY_DATA; private MachineLearningEncoder encoder; private MachineLearningDecoder decoder; private String[] itemNames = new String[]{"m_accel_x", "m_accel_y", "m_accel_z", "m_rotation_x", "m_rotation_y", @@ -37,6 +43,11 @@ public class LearnerTest { "reading"}; private Root root; + @BeforeClass + public static void setActivityData() throws MalformedURLException { + ACTIVITY_DATA = Paths.get("src", "test", "resources", "activity_data.csv").toUri().toURL(); + } + private Root createKnowledgeBase() { Root result = Root.createEmptyRoot(); Group group = new Group(); @@ -70,6 +81,7 @@ public class LearnerTest { @Test public void testActivities() { // go through same csv as for training and test some of the values + int correct = 0, wrong = 0; try(InputStream is = ACTIVITY_DATA.openStream(); Reader reader = new InputStreamReader(is); CSVReader csvreader = new CSVReader(reader)) { @@ -92,12 +104,20 @@ public class LearnerTest { assertEquals("activity not updated", root.getSmartHomeEntityModel().getActivityItem(), update.getItem()); update.apply(); // check if the correct activity was set - assertEquals("wrong activity", line[12], root.currentActivityName()); + if (line[12].equals(root.currentActivityName())) { + correct++; + } else { + wrong++; + } } } } catch (IOException e) { e.printStackTrace(); fail(); } + assertThat(correct + wrong, greaterThan(0)); + double accuracy = correct * 1.0 / (correct + wrong); + assertThat(accuracy, greaterThan(MIN_ACCURACY)); + LogManager.getLogger(LearnerTest.class).info("Accuracy: {}", accuracy); } } -- GitLab