Skip to content
Snippets Groups Projects
Commit da21d563 authored by René Schöne's avatar René Schöne
Browse files

Use color similarity to test preferences.

parent cf9ce4a0
No related branches found
No related tags found
1 merge request!19dev to master
Pipeline #4719 failed
...@@ -84,6 +84,7 @@ public class LearnerTest { ...@@ -84,6 +84,7 @@ public class LearnerTest {
.orElseThrow(() -> new AssertionError( .orElseThrow(() -> new AssertionError(
"Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found"))) "Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found")))
.setStateOfOutputItem(Item::getStateAsString) .setStateOfOutputItem(Item::getStateAsString)
.setCheckUpdate(LearnerTestUtils::colorSimilar)
.setFactoryTarget(PREFERENCE_LEARNING) .setFactoryTarget(PREFERENCE_LEARNING)
.setSingleUpdateList(true)); .setSingleUpdateList(true));
} }
......
...@@ -8,6 +8,10 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; ...@@ -8,6 +8,10 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
public interface LearnerTestConstants { public interface LearnerTestConstants {
/** Minimal accuracy (correct / total classifications) */ /** Minimal accuracy (correct / total classifications) */
double MIN_ACCURACY = 0.8; double MIN_ACCURACY = 0.8;
/** Maximum difference when comparing colors */
double MAX_COLOR_DIFFERENCE = 0.2;
/** Weights for difference (in order: Hue, Saturation, Brightness) when comparing colors */
double[] COLOR_WEIGHTS = new double[]{0.8/360, 0.1/100, 0.1/100};
/** Names of item names for activity recognition, in test data */ /** Names of item names for activity recognition, in test data */
String[] ACTIVITY_INPUT_ITEM_NAMES = new String[]{"m_accel_x", "m_accel_y", "m_accel_z", "m_rotation_x", "m_rotation_y", String[] ACTIVITY_INPUT_ITEM_NAMES = new String[]{"m_accel_x", "m_accel_y", "m_accel_z", "m_rotation_x", "m_rotation_y",
"m_rotation_z", "w_accel_x", "w_accel_y", "w_accel_z", "w_rotation_x", "w_rotation_y", "w_rotation_z"}; "m_rotation_z", "w_accel_x", "w_accel_y", "w_accel_z", "w_rotation_x", "w_rotation_y", "w_rotation_z"};
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.LearnerTestUtils.CheckUpdate;
import de.tudresden.inf.st.eraser.jastadd.model.Item; import de.tudresden.inf.st.eraser.jastadd.model.Item;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory; import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory;
import lombok.Data; import lombok.Data;
...@@ -14,7 +15,7 @@ import java.util.function.Supplier; ...@@ -14,7 +15,7 @@ import java.util.function.Supplier;
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class LearnerTestSettings { class LearnerTestSettings {
private URL configURL; private URL configURL;
private URL dataURL; private URL dataURL;
private String[] inputItemNames; private String[] inputItemNames;
...@@ -22,10 +23,12 @@ public class LearnerTestSettings { ...@@ -22,10 +23,12 @@ public class LearnerTestSettings {
private final Map<String, BiConsumer<Item, String>> specialInputHandler = new HashMap<>(); private final Map<String, BiConsumer<Item, String>> specialInputHandler = new HashMap<>();
private Supplier<Item> outputItemProvider; private Supplier<Item> outputItemProvider;
private Function<Item, String> stateOfOutputItem; private Function<Item, String> stateOfOutputItem;
private CheckUpdate checkUpdate = String::equals;
private MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget; private MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget;
private boolean singleUpdateList; private boolean singleUpdateList;
public LearnerTestSettings putSpecialInputHandler(String itemName, BiConsumer<Item, String> handler) { @SuppressWarnings("SameParameterValue")
LearnerTestSettings putSpecialInputHandler(String itemName, BiConsumer<Item, String> handler) {
specialInputHandler.put(itemName, handler); specialInputHandler.put(itemName, handler);
return this; return this;
} }
......
...@@ -5,6 +5,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.*; ...@@ -5,6 +5,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.util.ParserUtils; import de.tudresden.inf.st.eraser.util.ParserUtils;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.junit.Test;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
...@@ -13,6 +14,8 @@ import java.io.Reader; ...@@ -13,6 +14,8 @@ import java.io.Reader;
import java.util.*; import java.util.*;
import java.util.stream.Stream; import java.util.stream.Stream;
import static de.tudresden.inf.st.eraser.feedbackloop.learner_backup.LearnerTestConstants.COLOR_WEIGHTS;
import static de.tudresden.inf.st.eraser.feedbackloop.learner_backup.LearnerTestConstants.MAX_COLOR_DIFFERENCE;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.*; import static org.junit.Assert.*;
...@@ -55,35 +58,35 @@ public class LearnerTestUtils { ...@@ -55,35 +58,35 @@ public class LearnerTestUtils {
} }
static void testLearner( static void testLearner(
LearnerSubjectUnderTest sut, LearnerTestSettings learnerTestSettings) throws IOException { LearnerSubjectUnderTest sut, LearnerTestSettings settings) throws IOException {
sut.initFor(learnerTestSettings.getFactoryTarget(), learnerTestSettings.getConfigURL()); sut.initFor(settings.getFactoryTarget(), settings.getConfigURL());
// maybe use factory.createModel() here instead // maybe use factory.createModel() here instead
// go through same csv as for training and test some of the values // go through same csv as for training and test some of the values
int correct = 0, wrong = 0; int correct = 0, wrong = 0;
try(InputStream is = learnerTestSettings.getDataURL().openStream(); try(InputStream is = settings.getDataURL().openStream();
Reader reader = new InputStreamReader(is); Reader reader = new InputStreamReader(is);
CSVReader csvreader = new CSVReader(reader)) { CSVReader csvreader = new CSVReader(reader)) {
int index = 0; int index = 0;
for (String[] line : csvreader) { for (String[] line : csvreader) {
if (++index % 10 == 0) { if (++index % 10 == 0) {
// only check every 10th line, push an update for every 12 input columns // only check every 10th line, push an update for every 12 input columns
List<Item> itemsToUpdate = new ArrayList<>(learnerTestSettings.getInputItemNames().length); List<Item> itemsToUpdate = new ArrayList<>(settings.getInputItemNames().length);
for (int i = 0; i < learnerTestSettings.getInputItemNames().length; i++) { for (int i = 0; i < settings.getInputItemNames().length; i++) {
String itemName = learnerTestSettings.getInputItemNames()[i]; String itemName = settings.getInputItemNames()[i];
Item item = sut.root.getSmartHomeEntityModel().resolveItem(itemName) Item item = sut.root.getSmartHomeEntityModel().resolveItem(itemName)
.orElseThrow(() -> new AssertionError("Item " + itemName + " not found")); .orElseThrow(() -> new AssertionError("Item " + itemName + " not found"));
if (learnerTestSettings.getSpecialInputHandler().containsKey(itemName)) { if (settings.getSpecialInputHandler().containsKey(itemName)) {
learnerTestSettings.getSpecialInputHandler().get(itemName).accept(item, line[i]); settings.getSpecialInputHandler().get(itemName).accept(item, line[i]);
} else { } else {
item.setStateFromString(line[i]); item.setStateFromString(line[i]);
} }
if (learnerTestSettings.isSingleUpdateList()) { if (settings.isSingleUpdateList()) {
itemsToUpdate.add(item); itemsToUpdate.add(item);
} else { } else {
sut.encoder.newData(Collections.singletonList(item)); sut.encoder.newData(Collections.singletonList(item));
} }
} }
if (learnerTestSettings.isSingleUpdateList()) { if (settings.isSingleUpdateList()) {
sut.encoder.newData(itemsToUpdate); sut.encoder.newData(itemsToUpdate);
} }
MachineLearningResult result = sut.decoder.classify(); MachineLearningResult result = sut.decoder.classify();
...@@ -91,12 +94,12 @@ public class LearnerTestUtils { ...@@ -91,12 +94,12 @@ public class LearnerTestUtils {
assertEquals("Not one item update!", 1, result.getNumItemUpdate()); assertEquals("Not one item update!", 1, result.getNumItemUpdate());
ItemUpdate update = result.getItemUpdate(0); ItemUpdate update = result.getItemUpdate(0);
// check that the output item is to be updated // check that the output item is to be updated
assertEquals("Output item not to be updated!", learnerTestSettings.getOutputItemProvider().get(), update.getItem()); assertEquals("Output item not to be updated!", settings.getOutputItemProvider().get(), update.getItem());
update.apply(); update.apply();
// check if the correct new state was set // check if the correct new state was set
String expected = learnerTestSettings.getExpectedOutput().apply(line); String expected = settings.getExpectedOutput().apply(line);
String actual = learnerTestSettings.getStateOfOutputItem().apply(update.getItem()); String actual = settings.getStateOfOutputItem().apply(update.getItem());
if (expected.equals(actual)) { if (settings.getCheckUpdate().assertEquals(expected, actual)) {
correct++; correct++;
} else { } else {
wrong++; wrong++;
...@@ -111,9 +114,111 @@ public class LearnerTestUtils { ...@@ -111,9 +114,111 @@ public class LearnerTestUtils {
assertThat(accuracy, greaterThan(LearnerTestConstants.MIN_ACCURACY)); assertThat(accuracy, greaterThan(LearnerTestConstants.MIN_ACCURACY));
} }
@FunctionalInterface
public interface CheckUpdate {
boolean assertEquals(String expected, String actual);
}
static String decodeOutput(String[] line) { static String decodeOutput(String[] line) {
int color = Integer.parseInt(line[2]); int color = Integer.parseInt(line[2]);
int brightness = Integer.parseInt(line[3]); int brightness = Integer.parseInt(line[3]);
return TupleHSB.of(color, 100, brightness).toString(); return TupleHSB.of(color, 100, brightness).toString();
} }
private static int hueDistance(int hue1, int hue2) {
int d = Math.abs(hue1 - hue2);
return d > 180 ? 360 - d : d;
}
/**
* Compares two colours given as strings of the form "HUE,SATURATION,BRIGHTNESS"
* @param expected the expected colour
* @param actual the computed, actual colour
* @return <code>true</code>, if both a colours are similar
*/
static boolean colorSimilar(String expected, String actual) {
TupleHSB expectedTuple = TupleHSB.parse(expected);
TupleHSB actualTuple = TupleHSB.parse(actual);
int diffHue = hueDistance(expectedTuple.getHue(), actualTuple.getHue());
int diffSaturation = Math.abs(expectedTuple.getSaturation() - actualTuple.getSaturation());
int diffBrightness = Math.abs(expectedTuple.getBrightness() - actualTuple.getBrightness());
double total = diffHue * COLOR_WEIGHTS[0] +
diffSaturation * COLOR_WEIGHTS[1] +
diffBrightness * COLOR_WEIGHTS[2];
// logger.debug("Diff expected {} and actual {}: H={} + S={} + B={} -> {} < {} ?", expected, actual,
// diffHue, diffSaturation, diffBrightness, total, MAX_COLOR_DIFFERENCE);
return total < MAX_COLOR_DIFFERENCE;
}
@Test
public void testColorSimilar() {
Map<String, TupleHSB> colors = new HashMap<>();
// reddish target colors
colors.put("pink", TupleHSB.of(350, 100, 82));
colors.put("orangeRed", TupleHSB.of(16, 100, 45));
colors.put("lightPink", TupleHSB.of(351, 100, 80));
colors.put("darkSalmon", TupleHSB.of(15, 71, 67));
colors.put("lightCoral", TupleHSB.of(0, 78, 63));
colors.put("darkRed", TupleHSB.of(0, 100, 16));
colors.put("indianRed", TupleHSB.of(0, 53, 49));
colors.put("lavenderBlush", TupleHSB.of(340, 100, 95));
colors.put("lavender", TupleHSB.of(240, 66, 90));
String[] targetColors = new String[]{"pink", "orangeRed", "lightPink", "darkSalmon", "lightCoral",
"darkRed", "indianRed", "lavenderBlush", "lavender"};
// reference colors
colors.put("blue", TupleHSB.of(240, 100, 11));
colors.put("blueViolet", TupleHSB.of(271, 75, 36));
colors.put("magenta", TupleHSB.of(300, 100, 41));
colors.put("purple", TupleHSB.of(300, 100, 20));
colors.put("red", TupleHSB.of(0, 100, 29));
colors.put("tomato", TupleHSB.of(9, 100, 55));
colors.put("orange", TupleHSB.of(39, 100, 67));
colors.put("yellow", TupleHSB.of(60, 100, 88));
colors.put("yellowGreen", TupleHSB.of(80, 60, 67));
colors.put("green", TupleHSB.of(120, 100, 29));
colors.put("springGreen", TupleHSB.of(150, 100, 64));
colors.put("cyan", TupleHSB.of(180, 100, 69));
colors.put("ivory", TupleHSB.of(60, 100, 98));
String[] referenceColors = new String[]{"blue", "blueViolet", "magenta", "purple", "red", "tomato",
"orange", "yellow", "yellowGreen", "green", "springGreen", "cyan", "ivory"};
/* Code to help producing similarity matrix */
// for (String target : targetColors) {
// String tmp = "";
// for (String reference : referenceColors) {
// tmp += assertColorSimilar(colors, target, reference) ? "x" : " ";
// tmp += ",";
// }
// System.out.println( "***" + target + ": " + tmp);
// }
String[] similarityMatrix = new String[]{
"blue, blueViolet, magenta, purple, red, tomato, orange, yellow, yellowGreen, green, springGreen, cyan, ivory", // <- reference colors
" , , x , x , x , x , x , x , , , , , x ", // pink
" , , x , x , x , x , x , x , , , , , x ", // orangeRed
" , , x , x , x , x , x , x , , , , , x ", // lightPink
" , , , , x , x , x , x , x , , , , x ", // darkSalmon
" , , x , x , x , x , x , x , x , , , , x ", // lightCoral
" , , x , x , x , x , x , , , , , , ", // darkRed
" , , x , , x , x , x , , , , , , ", // indianRed
" , , x , x , x , x , x , x , , , , , x ", // lavenderBlush
" x , x , , , , , , , , , , x , "}; // lavender
for (int targetIndex = 0; targetIndex < targetColors.length; targetIndex++) {
String target = targetColors[targetIndex];
String[] expectedValues = similarityMatrix[targetIndex + 1].split(",");
for (int referenceIndex = 0; referenceIndex < referenceColors.length; referenceIndex++) {
String reference = referenceColors[referenceIndex];
boolean expectedToBeSimilar = expectedValues[referenceIndex].contains("x");
String message = String.format("%s iss%s expected to be similar to %s, but %s!",
target, expectedToBeSimilar ? "" : " not", reference, expectedToBeSimilar ? "differs" : "it was");
assertEquals(message, expectedToBeSimilar,
colorSimilar(colors.get(reference).toString(), colors.get(target).toString()));
}
}
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment