Skip to content
Snippets Groups Projects
Commit da21d563 authored by René Schöne's avatar René Schöne
Browse files

Use color similarity to test preferences.

parent cf9ce4a0
Branches
No related tags found
1 merge request!19dev to master
Pipeline #4719 failed
...@@ -84,6 +84,7 @@ public class LearnerTest { ...@@ -84,6 +84,7 @@ public class LearnerTest {
.orElseThrow(() -> new AssertionError( .orElseThrow(() -> new AssertionError(
"Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found"))) "Item " + LearnerTestConstants.PREFERENCE_OUTPUT_ITEM_NAME + " not found")))
.setStateOfOutputItem(Item::getStateAsString) .setStateOfOutputItem(Item::getStateAsString)
.setCheckUpdate(LearnerTestUtils::colorSimilar)
.setFactoryTarget(PREFERENCE_LEARNING) .setFactoryTarget(PREFERENCE_LEARNING)
.setSingleUpdateList(true)); .setSingleUpdateList(true));
} }
......
...@@ -8,6 +8,10 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; ...@@ -8,6 +8,10 @@ package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
public interface LearnerTestConstants { public interface LearnerTestConstants {
/** Minimal accuracy (correct / total classifications) */ /** Minimal accuracy (correct / total classifications) */
double MIN_ACCURACY = 0.8; double MIN_ACCURACY = 0.8;
/** Maximum difference when comparing colors */
double MAX_COLOR_DIFFERENCE = 0.2;
/** Weights for difference (in order: Hue, Saturation, Brightness) when comparing colors */
double[] COLOR_WEIGHTS = new double[]{0.8/360, 0.1/100, 0.1/100};
/** Names of item names for activity recognition, in test data */ /** Names of item names for activity recognition, in test data */
String[] ACTIVITY_INPUT_ITEM_NAMES = new String[]{"m_accel_x", "m_accel_y", "m_accel_z", "m_rotation_x", "m_rotation_y", String[] ACTIVITY_INPUT_ITEM_NAMES = new String[]{"m_accel_x", "m_accel_y", "m_accel_z", "m_rotation_x", "m_rotation_y",
"m_rotation_z", "w_accel_x", "w_accel_y", "w_accel_z", "w_rotation_x", "w_rotation_y", "w_rotation_z"}; "m_rotation_z", "w_accel_x", "w_accel_y", "w_accel_z", "w_rotation_x", "w_rotation_y", "w_rotation_z"};
......
package de.tudresden.inf.st.eraser.feedbackloop.learner_backup; package de.tudresden.inf.st.eraser.feedbackloop.learner_backup;
import de.tudresden.inf.st.eraser.feedbackloop.learner_backup.LearnerTestUtils.CheckUpdate;
import de.tudresden.inf.st.eraser.jastadd.model.Item; import de.tudresden.inf.st.eraser.jastadd.model.Item;
import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory; import de.tudresden.inf.st.eraser.jastadd.model.MachineLearningHandlerFactory;
import lombok.Data; import lombok.Data;
...@@ -14,7 +15,7 @@ import java.util.function.Supplier; ...@@ -14,7 +15,7 @@ import java.util.function.Supplier;
@Data @Data
@Accessors(chain = true) @Accessors(chain = true)
public class LearnerTestSettings { class LearnerTestSettings {
private URL configURL; private URL configURL;
private URL dataURL; private URL dataURL;
private String[] inputItemNames; private String[] inputItemNames;
...@@ -22,10 +23,12 @@ public class LearnerTestSettings { ...@@ -22,10 +23,12 @@ public class LearnerTestSettings {
private final Map<String, BiConsumer<Item, String>> specialInputHandler = new HashMap<>(); private final Map<String, BiConsumer<Item, String>> specialInputHandler = new HashMap<>();
private Supplier<Item> outputItemProvider; private Supplier<Item> outputItemProvider;
private Function<Item, String> stateOfOutputItem; private Function<Item, String> stateOfOutputItem;
private CheckUpdate checkUpdate = String::equals;
private MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget; private MachineLearningHandlerFactory.MachineLearningHandlerFactoryTarget factoryTarget;
private boolean singleUpdateList; private boolean singleUpdateList;
public LearnerTestSettings putSpecialInputHandler(String itemName, BiConsumer<Item, String> handler) { @SuppressWarnings("SameParameterValue")
LearnerTestSettings putSpecialInputHandler(String itemName, BiConsumer<Item, String> handler) {
specialInputHandler.put(itemName, handler); specialInputHandler.put(itemName, handler);
return this; return this;
} }
......
...@@ -5,6 +5,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.*; ...@@ -5,6 +5,7 @@ import de.tudresden.inf.st.eraser.jastadd.model.*;
import de.tudresden.inf.st.eraser.util.ParserUtils; import de.tudresden.inf.st.eraser.util.ParserUtils;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.junit.Test;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
...@@ -13,6 +14,8 @@ import java.io.Reader; ...@@ -13,6 +14,8 @@ import java.io.Reader;
import java.util.*; import java.util.*;
import java.util.stream.Stream; import java.util.stream.Stream;
import static de.tudresden.inf.st.eraser.feedbackloop.learner_backup.LearnerTestConstants.COLOR_WEIGHTS;
import static de.tudresden.inf.st.eraser.feedbackloop.learner_backup.LearnerTestConstants.MAX_COLOR_DIFFERENCE;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.*; import static org.junit.Assert.*;
...@@ -55,35 +58,35 @@ public class LearnerTestUtils { ...@@ -55,35 +58,35 @@ public class LearnerTestUtils {
} }
static void testLearner( static void testLearner(
LearnerSubjectUnderTest sut, LearnerTestSettings learnerTestSettings) throws IOException { LearnerSubjectUnderTest sut, LearnerTestSettings settings) throws IOException {
sut.initFor(learnerTestSettings.getFactoryTarget(), learnerTestSettings.getConfigURL()); sut.initFor(settings.getFactoryTarget(), settings.getConfigURL());
// maybe use factory.createModel() here instead // maybe use factory.createModel() here instead
// go through same csv as for training and test some of the values // go through same csv as for training and test some of the values
int correct = 0, wrong = 0; int correct = 0, wrong = 0;
try(InputStream is = learnerTestSettings.getDataURL().openStream(); try(InputStream is = settings.getDataURL().openStream();
Reader reader = new InputStreamReader(is); Reader reader = new InputStreamReader(is);
CSVReader csvreader = new CSVReader(reader)) { CSVReader csvreader = new CSVReader(reader)) {
int index = 0; int index = 0;
for (String[] line : csvreader) { for (String[] line : csvreader) {
if (++index % 10 == 0) { if (++index % 10 == 0) {
// only check every 10th line, push an update for every 12 input columns // only check every 10th line, push an update for every 12 input columns
List<Item> itemsToUpdate = new ArrayList<>(learnerTestSettings.getInputItemNames().length); List<Item> itemsToUpdate = new ArrayList<>(settings.getInputItemNames().length);
for (int i = 0; i < learnerTestSettings.getInputItemNames().length; i++) { for (int i = 0; i < settings.getInputItemNames().length; i++) {
String itemName = learnerTestSettings.getInputItemNames()[i]; String itemName = settings.getInputItemNames()[i];
Item item = sut.root.getSmartHomeEntityModel().resolveItem(itemName) Item item = sut.root.getSmartHomeEntityModel().resolveItem(itemName)
.orElseThrow(() -> new AssertionError("Item " + itemName + " not found")); .orElseThrow(() -> new AssertionError("Item " + itemName + " not found"));
if (learnerTestSettings.getSpecialInputHandler().containsKey(itemName)) { if (settings.getSpecialInputHandler().containsKey(itemName)) {
learnerTestSettings.getSpecialInputHandler().get(itemName).accept(item, line[i]); settings.getSpecialInputHandler().get(itemName).accept(item, line[i]);
} else { } else {
item.setStateFromString(line[i]); item.setStateFromString(line[i]);
} }
if (learnerTestSettings.isSingleUpdateList()) { if (settings.isSingleUpdateList()) {
itemsToUpdate.add(item); itemsToUpdate.add(item);
} else { } else {
sut.encoder.newData(Collections.singletonList(item)); sut.encoder.newData(Collections.singletonList(item));
} }
} }
if (learnerTestSettings.isSingleUpdateList()) { if (settings.isSingleUpdateList()) {
sut.encoder.newData(itemsToUpdate); sut.encoder.newData(itemsToUpdate);
} }
MachineLearningResult result = sut.decoder.classify(); MachineLearningResult result = sut.decoder.classify();
...@@ -91,12 +94,12 @@ public class LearnerTestUtils { ...@@ -91,12 +94,12 @@ public class LearnerTestUtils {
assertEquals("Not one item update!", 1, result.getNumItemUpdate()); assertEquals("Not one item update!", 1, result.getNumItemUpdate());
ItemUpdate update = result.getItemUpdate(0); ItemUpdate update = result.getItemUpdate(0);
// check that the output item is to be updated // check that the output item is to be updated
assertEquals("Output item not to be updated!", learnerTestSettings.getOutputItemProvider().get(), update.getItem()); assertEquals("Output item not to be updated!", settings.getOutputItemProvider().get(), update.getItem());
update.apply(); update.apply();
// check if the correct new state was set // check if the correct new state was set
String expected = learnerTestSettings.getExpectedOutput().apply(line); String expected = settings.getExpectedOutput().apply(line);
String actual = learnerTestSettings.getStateOfOutputItem().apply(update.getItem()); String actual = settings.getStateOfOutputItem().apply(update.getItem());
if (expected.equals(actual)) { if (settings.getCheckUpdate().assertEquals(expected, actual)) {
correct++; correct++;
} else { } else {
wrong++; wrong++;
...@@ -111,9 +114,111 @@ public class LearnerTestUtils { ...@@ -111,9 +114,111 @@ public class LearnerTestUtils {
assertThat(accuracy, greaterThan(LearnerTestConstants.MIN_ACCURACY)); assertThat(accuracy, greaterThan(LearnerTestConstants.MIN_ACCURACY));
} }
@FunctionalInterface
public interface CheckUpdate {
boolean assertEquals(String expected, String actual);
}
static String decodeOutput(String[] line) { static String decodeOutput(String[] line) {
int color = Integer.parseInt(line[2]); int color = Integer.parseInt(line[2]);
int brightness = Integer.parseInt(line[3]); int brightness = Integer.parseInt(line[3]);
return TupleHSB.of(color, 100, brightness).toString(); return TupleHSB.of(color, 100, brightness).toString();
} }
private static int hueDistance(int hue1, int hue2) {
int d = Math.abs(hue1 - hue2);
return d > 180 ? 360 - d : d;
}
/**
* Compares two colours given as strings of the form "HUE,SATURATION,BRIGHTNESS"
* @param expected the expected colour
* @param actual the computed, actual colour
* @return <code>true</code>, if both a colours are similar
*/
static boolean colorSimilar(String expected, String actual) {
TupleHSB expectedTuple = TupleHSB.parse(expected);
TupleHSB actualTuple = TupleHSB.parse(actual);
int diffHue = hueDistance(expectedTuple.getHue(), actualTuple.getHue());
int diffSaturation = Math.abs(expectedTuple.getSaturation() - actualTuple.getSaturation());
int diffBrightness = Math.abs(expectedTuple.getBrightness() - actualTuple.getBrightness());
double total = diffHue * COLOR_WEIGHTS[0] +
diffSaturation * COLOR_WEIGHTS[1] +
diffBrightness * COLOR_WEIGHTS[2];
// logger.debug("Diff expected {} and actual {}: H={} + S={} + B={} -> {} < {} ?", expected, actual,
// diffHue, diffSaturation, diffBrightness, total, MAX_COLOR_DIFFERENCE);
return total < MAX_COLOR_DIFFERENCE;
}
@Test
public void testColorSimilar() {
Map<String, TupleHSB> colors = new HashMap<>();
// reddish target colors
colors.put("pink", TupleHSB.of(350, 100, 82));
colors.put("orangeRed", TupleHSB.of(16, 100, 45));
colors.put("lightPink", TupleHSB.of(351, 100, 80));
colors.put("darkSalmon", TupleHSB.of(15, 71, 67));
colors.put("lightCoral", TupleHSB.of(0, 78, 63));
colors.put("darkRed", TupleHSB.of(0, 100, 16));
colors.put("indianRed", TupleHSB.of(0, 53, 49));
colors.put("lavenderBlush", TupleHSB.of(340, 100, 95));
colors.put("lavender", TupleHSB.of(240, 66, 90));
String[] targetColors = new String[]{"pink", "orangeRed", "lightPink", "darkSalmon", "lightCoral",
"darkRed", "indianRed", "lavenderBlush", "lavender"};
// reference colors
colors.put("blue", TupleHSB.of(240, 100, 11));
colors.put("blueViolet", TupleHSB.of(271, 75, 36));
colors.put("magenta", TupleHSB.of(300, 100, 41));
colors.put("purple", TupleHSB.of(300, 100, 20));
colors.put("red", TupleHSB.of(0, 100, 29));
colors.put("tomato", TupleHSB.of(9, 100, 55));
colors.put("orange", TupleHSB.of(39, 100, 67));
colors.put("yellow", TupleHSB.of(60, 100, 88));
colors.put("yellowGreen", TupleHSB.of(80, 60, 67));
colors.put("green", TupleHSB.of(120, 100, 29));
colors.put("springGreen", TupleHSB.of(150, 100, 64));
colors.put("cyan", TupleHSB.of(180, 100, 69));
colors.put("ivory", TupleHSB.of(60, 100, 98));
String[] referenceColors = new String[]{"blue", "blueViolet", "magenta", "purple", "red", "tomato",
"orange", "yellow", "yellowGreen", "green", "springGreen", "cyan", "ivory"};
/* Code to help producing similarity matrix */
// for (String target : targetColors) {
// String tmp = "";
// for (String reference : referenceColors) {
// tmp += assertColorSimilar(colors, target, reference) ? "x" : " ";
// tmp += ",";
// }
// System.out.println( "***" + target + ": " + tmp);
// }
String[] similarityMatrix = new String[]{
"blue, blueViolet, magenta, purple, red, tomato, orange, yellow, yellowGreen, green, springGreen, cyan, ivory", // <- reference colors
" , , x , x , x , x , x , x , , , , , x ", // pink
" , , x , x , x , x , x , x , , , , , x ", // orangeRed
" , , x , x , x , x , x , x , , , , , x ", // lightPink
" , , , , x , x , x , x , x , , , , x ", // darkSalmon
" , , x , x , x , x , x , x , x , , , , x ", // lightCoral
" , , x , x , x , x , x , , , , , , ", // darkRed
" , , x , , x , x , x , , , , , , ", // indianRed
" , , x , x , x , x , x , x , , , , , x ", // lavenderBlush
" x , x , , , , , , , , , , x , "}; // lavender
for (int targetIndex = 0; targetIndex < targetColors.length; targetIndex++) {
String target = targetColors[targetIndex];
String[] expectedValues = similarityMatrix[targetIndex + 1].split(",");
for (int referenceIndex = 0; referenceIndex < referenceColors.length; referenceIndex++) {
String reference = referenceColors[referenceIndex];
boolean expectedToBeSimilar = expectedValues[referenceIndex].contains("x");
String message = String.format("%s iss%s expected to be similar to %s, but %s!",
target, expectedToBeSimilar ? "" : " not", reference, expectedToBeSimilar ? "differs" : "it was");
assertEquals(message, expectedToBeSimilar,
colorSimilar(colors.get(reference).toString(), colors.get(target).toString()));
}
}
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment