From 9e205334e868180a5572f03403bad396ca27cc11 Mon Sep 17 00:00:00 2001
From: Damon Kohler <damonkohler@google.com>
Date: Tue, 24 Jul 2012 15:42:27 +0200
Subject: [PATCH] Fixes bugs in transform logic and add tests. Removes
 unused/dangerous methods for converting quaternions to axis-angle. Adds new
 method for getting the matrix form of a quaternion.

---
 .../ros/rosjava_geometry/FrameTransform.java  |  54 ++++++++-
 .../rosjava_geometry/FrameTransformTree.java  |  27 +++--
 .../org/ros/rosjava_geometry/Quaternion.java  |  22 +---
 .../org/ros/rosjava_geometry/Transform.java   |  62 +++++++---
 .../org/ros/rosjava_geometry/Vector3.java     |   2 +-
 .../FrameTransformTreeTest.java               | 113 ++++++++++++++++++
 .../ros/rosjava_geometry/QuaternionTest.java  |  52 +-------
 7 files changed, 235 insertions(+), 97 deletions(-)
 create mode 100644 rosjava_geometry/src/test/java/org/ros/rosjava_geometry/FrameTransformTreeTest.java

diff --git a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransform.java b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransform.java
index d258da88..f0a16392 100644
--- a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransform.java
+++ b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransform.java
@@ -16,10 +16,12 @@
 
 package org.ros.rosjava_geometry;
 
+import org.ros.message.Time;
 import org.ros.namespace.GraphName;
 
 /**
- * Describes a {@link Transform} from a source frame to a target frame.
+ * Describes a {@link Transform} from data in the source frame to data in the
+ * target frame.
  * 
  * @author damonkohler@google.com (Damon Kohler)
  */
@@ -32,8 +34,8 @@ public class FrameTransform {
   public static FrameTransform
       fromTransformStamped(geometry_msgs.TransformStamped transformStamped) {
     Transform transform = Transform.newFromTransformMessage(transformStamped.getTransform());
-    String source = transformStamped.getHeader().getFrameId();
-    String target = transformStamped.getChildFrameId();
+    String target = transformStamped.getHeader().getFrameId();
+    String source = transformStamped.getChildFrameId();
     return new FrameTransform(transform, GraphName.of(source), GraphName.of(target));
   }
 
@@ -55,8 +57,54 @@ public class FrameTransform {
     return target;
   }
 
+  public geometry_msgs.TransformStamped toTransformStampedMessage(Time stamp,
+      geometry_msgs.TransformStamped result) {
+    result.getHeader().setFrameId(target.toString());
+    result.getHeader().setStamp(stamp);
+    result.setChildFrameId(source.toString());
+    transform.toTransformMessage(result.getTransform());
+    return result;
+  }
+
   @Override
   public String toString() {
     return String.format("FrameTransform<Source: %s, Target: %s, %s>", source, target, transform);
   }
+
+  @Override
+  public int hashCode() {
+    final int prime = 31;
+    int result = 1;
+    result = prime * result + ((source == null) ? 0 : source.hashCode());
+    result = prime * result + ((target == null) ? 0 : target.hashCode());
+    result = prime * result + ((transform == null) ? 0 : transform.hashCode());
+    return result;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj)
+      return true;
+    if (obj == null)
+      return false;
+    if (getClass() != obj.getClass())
+      return false;
+    FrameTransform other = (FrameTransform) obj;
+    if (source == null) {
+      if (other.source != null)
+        return false;
+    } else if (!source.equals(other.source))
+      return false;
+    if (target == null) {
+      if (other.target != null)
+        return false;
+    } else if (!target.equals(other.target))
+      return false;
+    if (transform == null) {
+      if (other.transform != null)
+        return false;
+    } else if (!transform.equals(other.transform))
+      return false;
+    return true;
+  }
 }
diff --git a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransformTree.java b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransformTree.java
index fdbe0c81..7af9a5e8 100644
--- a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransformTree.java
+++ b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/FrameTransformTree.java
@@ -27,20 +27,19 @@ import java.util.Map;
 
 /**
  * A tree of {@link FrameTransform}s.
- * 
  * <p>
  * {@link FrameTransformTree} does not currently support time travel. Lookups
  * always use the newest {@link TransformStamped}.
  * 
+ * @author damonkohler@google.com (Damon Kohler)
  * @author moesenle@google.com (Lorenz Moesenlechner)
- * 
  */
 public class FrameTransformTree {
 
   private final NameResolver nameResolver;
 
   /**
-   * A {@link Map} of the most recent {@link LazyFrameTransform} by source
+   * A {@link Map} of the most recent {@link LazyFrameTransform} by target
    * frame.
    */
   private final Map<GraphName, LazyFrameTransform> transforms;
@@ -51,14 +50,19 @@ public class FrameTransformTree {
   }
 
   /**
-   * Updates the transform tree with the provided transform.
+   * Updates the tree with the provided {@link geometry_msgs.TransformStamped}
+   * message.
+   * <p>
+   * Note that the tree is updated lazily. Modifications to the provided
+   * {@link geometry_msgs.TransformStamped} message may cause unpredictable
+   * results.
    * 
    * @param transformStamped
-   *          the transform to update
+   *          the {@link geometry_msgs.TransformStamped} message to update with
    */
   public void updateTransform(geometry_msgs.TransformStamped transformStamped) {
-    GraphName source = nameResolver.resolve(transformStamped.getChildFrameId());
-    transforms.put(source, new LazyFrameTransform(transformStamped));
+    GraphName target = nameResolver.resolve(transformStamped.getChildFrameId());
+    transforms.put(target, new LazyFrameTransform(transformStamped));
   }
 
   private FrameTransform getLatestTransform(GraphName frame) {
@@ -111,12 +115,13 @@ public class FrameTransformTree {
     FrameTransform result =
         new FrameTransform(Transform.newIdentityTransform(), sourceFrame, sourceFrame);
     while (true) {
-      FrameTransform parent = getLatestTransform(result.getTargetFrame());
-      if (parent == null) {
+      FrameTransform resultToParent = getLatestTransform(result.getTargetFrame());
+      if (resultToParent == null) {
         return result;
       }
-      Transform transform = result.getTransform().multiply(parent.getTransform());
-      GraphName targetFrame = nameResolver.resolve(parent.getSourceFrame());
+      // Now resultToParent.getSourceFrame() == result.getTargetFrame()
+      Transform transform = resultToParent.getTransform().multiply(result.getTransform());
+      GraphName targetFrame = nameResolver.resolve(resultToParent.getTargetFrame());
       result = new FrameTransform(transform, sourceFrame, targetFrame);
     }
   }
diff --git a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Quaternion.java b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Quaternion.java
index 676ac49e..68f3e2f9 100644
--- a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Quaternion.java
+++ b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Quaternion.java
@@ -68,18 +68,6 @@ public class Quaternion {
     this.w = w;
   }
 
-  public double getAngle() {
-    return 2 * Math.acos(w);
-  }
-
-  public Vector3 getAxis() {
-    double length = Math.sqrt(1 - w * w);
-    if (length > 1e-9) {
-      return new Vector3(x / length, y / length, z / length);
-    }
-    return new Vector3(0, 0, 0);
-  }
-
   public Quaternion invert() {
     return new Quaternion(-x, -y, -z, w);
   }
@@ -136,6 +124,11 @@ public class Quaternion {
     this.w = w;
   }
 
+  @Override
+  public String toString() {
+    return String.format("Quaternion<x: %.4f, y: %.4f, z: %.4f, w: %.4f>", x, y, z, w);
+  }
+
   @Override
   public int hashCode() {
     final int prime = 31;
@@ -171,9 +164,4 @@ public class Quaternion {
       return false;
     return true;
   }
-
-  @Override
-  public String toString() {
-    return String.format("Quaternion<x: %.4f, y: %.4f, z: %.4f, w: %.4f>", x, y, z, w);
-  }
 }
diff --git a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Transform.java b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Transform.java
index 1587f78b..a24e4fa4 100644
--- a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Transform.java
+++ b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Transform.java
@@ -35,8 +35,7 @@ public class Transform {
   }
 
   public Transform multiply(Transform other) {
-    return new Transform(transformVector(other.getTranslation()),
-        transformQuaternion(other.getRotation()));
+    return new Transform(translate(other.getTranslation()), rotate(other.getRotation()));
   }
 
   public Transform invert() {
@@ -44,29 +43,33 @@ public class Transform {
     return new Transform(inverseRotation.rotateVector(translation.invert()), inverseRotation);
   }
 
-  public Vector3 transformVector(Vector3 vector) {
+  public Vector3 translate(Vector3 vector) {
     return translation.add(rotation.rotateVector(vector));
   }
 
-  public Quaternion transformQuaternion(Quaternion quaternion) {
+  public Quaternion rotate(Quaternion quaternion) {
     return rotation.multiply(quaternion);
   }
 
+  public double[] toMatrix() {
+    double x = getRotation().getX();
+    double y = getRotation().getY();
+    double z = getRotation().getZ();
+    double w = getRotation().getW();
+    return new double[] {
+        1 - 2 * y * y - 2 * z * z, 2 * x * y + 2 * z * w, 2 * x * z - 2 * y * w, 0,
+        2 * x * y - 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z + 2 * x * w, 0,
+        2 * x * z + 2 * y * w, 2 * y * z - 2 * x * w, 1 - 2 * x * x - 2 * y * y, 0,
+        getTranslation().getX(), getTranslation().getY(), getTranslation().getZ(), 1
+        };
+  }
+
   public geometry_msgs.Transform toTransformMessage(geometry_msgs.Transform result) {
     result.setTranslation(translation.toVector3Message(result.getTranslation()));
     result.setRotation(rotation.toQuaternionMessage(result.getRotation()));
     return result;
   }
 
-  public geometry_msgs.TransformStamped toTransformStampedMessage(GraphName frame,
-      GraphName childFrame, Time stamp, geometry_msgs.TransformStamped result) {
-    result.getHeader().setFrameId(frame.toString());
-    result.getHeader().setStamp(stamp);
-    result.setChildFrameId(childFrame.toString());
-    result.setTransform(toTransformMessage(result.getTransform()));
-    return result;
-  }
-
   public geometry_msgs.Pose toPoseMessage(geometry_msgs.Pose result) {
     result.setPosition(translation.toPointMessage(result.getPosition()));
     result.setOrientation(rotation.toQuaternionMessage(result.getOrientation()));
@@ -108,11 +111,42 @@ public class Transform {
   }
 
   public static Transform newIdentityTransform() {
-    return new Transform(Vector3.newIdentityVector3(), Quaternion.newIdentityQuaternion());
+    return new Transform(Vector3.newZeroVector(), Quaternion.newIdentityQuaternion());
   }
 
   @Override
   public String toString() {
     return String.format("Transform<%s, %s>", translation, rotation);
   }
+
+  @Override
+  public int hashCode() {
+    final int prime = 31;
+    int result = 1;
+    result = prime * result + ((rotation == null) ? 0 : rotation.hashCode());
+    result = prime * result + ((translation == null) ? 0 : translation.hashCode());
+    return result;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj)
+      return true;
+    if (obj == null)
+      return false;
+    if (getClass() != obj.getClass())
+      return false;
+    Transform other = (Transform) obj;
+    if (rotation == null) {
+      if (other.rotation != null)
+        return false;
+    } else if (!rotation.equals(other.rotation))
+      return false;
+    if (translation == null) {
+      if (other.translation != null)
+        return false;
+    } else if (!translation.equals(other.translation))
+      return false;
+    return true;
+  }
 }
diff --git a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Vector3.java b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Vector3.java
index 4302dca2..39456e90 100644
--- a/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Vector3.java
+++ b/rosjava_geometry/src/main/java/org/ros/rosjava_geometry/Vector3.java
@@ -103,7 +103,7 @@ public class Vector3 {
     return new Vector3(message.getX(), message.getY(), message.getZ());
   }
 
-  public static Vector3 newIdentityVector3() {
+  public static Vector3 newZeroVector() {
     return new Vector3(0, 0, 0);
   }
 
diff --git a/rosjava_geometry/src/test/java/org/ros/rosjava_geometry/FrameTransformTreeTest.java b/rosjava_geometry/src/test/java/org/ros/rosjava_geometry/FrameTransformTreeTest.java
new file mode 100644
index 00000000..bc5f78cf
--- /dev/null
+++ b/rosjava_geometry/src/test/java/org/ros/rosjava_geometry/FrameTransformTreeTest.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2011 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package org.ros.rosjava_geometry;
+
+import static org.junit.Assert.assertEquals;
+
+import org.junit.Test;
+import org.ros.internal.message.DefaultMessageFactory;
+import org.ros.internal.message.definition.MessageDefinitionReflectionProvider;
+import org.ros.message.MessageDefinitionProvider;
+import org.ros.message.MessageFactory;
+import org.ros.message.Time;
+import org.ros.namespace.GraphName;
+import org.ros.namespace.NameResolver;
+
+/**
+ * @author damonkohler@google.com (Damon Kohler)
+ */
+public class FrameTransformTreeTest {
+
+  @Test
+  public void testIdentityTransforms() {
+    MessageDefinitionProvider messageDefinitionProvider = new MessageDefinitionReflectionProvider();
+    MessageFactory messageFactory = new DefaultMessageFactory(messageDefinitionProvider);
+    NameResolver nameResolver = NameResolver.newRoot();
+    FrameTransformTree frameTransformTree = new FrameTransformTree(nameResolver);
+
+    {
+      geometry_msgs.TransformStamped message =
+          messageFactory.newFromType(geometry_msgs.TransformStamped._TYPE);
+      Transform transform = Transform.newIdentityTransform();
+      FrameTransform frameTransform =
+          new FrameTransform(transform, GraphName.of("baz"), GraphName.of("bar"));
+      frameTransform.toTransformStampedMessage(new Time(), message);
+      frameTransformTree.updateTransform(message);
+    }
+
+    {
+      geometry_msgs.TransformStamped message =
+          messageFactory.newFromType(geometry_msgs.TransformStamped._TYPE);
+      Transform transform = Transform.newIdentityTransform();
+      FrameTransform frameTransform =
+          new FrameTransform(transform, GraphName.of("bar"), GraphName.of("foo"));
+      frameTransform.toTransformStampedMessage(new Time(), message);
+      frameTransformTree.updateTransform(message);
+    }
+
+    FrameTransform frameTransform =
+        frameTransformTree.newFrameTransform(GraphName.of("baz"), GraphName.of("foo"));
+    assertEquals(nameResolver.resolve("baz"), frameTransform.getSourceFrame());
+    assertEquals(nameResolver.resolve("foo"), frameTransform.getTargetFrame());
+    assertEquals(Transform.newIdentityTransform(), frameTransform.getTransform());
+  }
+
+  @Test
+  public void testTransformToRoot() {
+    MessageDefinitionProvider messageDefinitionProvider = new MessageDefinitionReflectionProvider();
+    MessageFactory messageFactory = new DefaultMessageFactory(messageDefinitionProvider);
+    NameResolver nameResolver = NameResolver.newRoot();
+    FrameTransformTree frameTransformTree = new FrameTransformTree(nameResolver);
+
+    {
+      geometry_msgs.TransformStamped message =
+          messageFactory.newFromType(geometry_msgs.TransformStamped._TYPE);
+      Vector3 vector = Vector3.newZeroVector();
+      Quaternion quaternion = new Quaternion(Math.sqrt(0.5), 0, 0, Math.sqrt(0.5));
+      Transform transform = new Transform(vector, quaternion);
+      GraphName source = GraphName.of("baz");
+      GraphName target = GraphName.of("bar");
+      FrameTransform frameTransform = new FrameTransform(transform, source, target);
+      frameTransform.toTransformStampedMessage(new Time(), message);
+      frameTransformTree.updateTransform(message);
+    }
+
+    {
+      geometry_msgs.TransformStamped message =
+          messageFactory.newFromType(geometry_msgs.TransformStamped._TYPE);
+      Vector3 vector = new Vector3(0, 1, 0);
+      Quaternion quaternion = Quaternion.newIdentityQuaternion();
+      Transform transform = new Transform(vector, quaternion);
+      GraphName source = GraphName.of("bar");
+      GraphName target = GraphName.of("foo");
+      FrameTransform frameTransform = new FrameTransform(transform, source, target);
+      frameTransform.toTransformStampedMessage(new Time(), message);
+      frameTransformTree.updateTransform(message);
+    }
+
+    FrameTransform frameTransform =
+        frameTransformTree.newFrameTransform(GraphName.of("baz"), GraphName.of("foo"));
+    // If we were to reverse the order of the transforms in our implementation,
+    // we would expect the translation vector to be <0, 0, 1> instead.
+    Vector3 vector = new Vector3(0, 1, 0);
+    Quaternion quaternion = new Quaternion(Math.sqrt(0.5), 0, 0, Math.sqrt(0.5));
+    Transform transform = new Transform(vector, quaternion);
+    assertEquals(nameResolver.resolve("baz"), frameTransform.getSourceFrame());
+    assertEquals(nameResolver.resolve("foo"), frameTransform.getTargetFrame());
+    assertEquals(transform, frameTransform.getTransform());
+  }
+}
diff --git a/rosjava_geometry/src/test/java/org/ros/rosjava_geometry/QuaternionTest.java b/rosjava_geometry/src/test/java/org/ros/rosjava_geometry/QuaternionTest.java
index ee33c1f9..1a9694e6 100644
--- a/rosjava_geometry/src/test/java/org/ros/rosjava_geometry/QuaternionTest.java
+++ b/rosjava_geometry/src/test/java/org/ros/rosjava_geometry/QuaternionTest.java
@@ -25,56 +25,6 @@ import org.junit.Test;
  */
 public class QuaternionTest {
 
-  @Test
-  public void testCalculateRotationAngleAxis() {
-    Quaternion quaternion;
-    Vector3 axis;
-
-    quaternion = new Quaternion(0, 0, 0, 1);
-    assertEquals(0.0, quaternion.getAngle(), 1e-9);
-    axis = quaternion.getAxis();
-    assertEquals(0, axis.getX(), 1e-9);
-    assertEquals(0, axis.getY(), 1e-9);
-    assertEquals(0, axis.getZ(), 1e-9);
-
-    quaternion = new Quaternion(0, 0, 1, 0);
-    assertEquals(Math.PI, quaternion.getAngle(), 1e-9);
-    axis = quaternion.getAxis();
-    assertEquals(0, axis.getX(), 1e-9);
-    assertEquals(0, axis.getY(), 1e-9);
-    assertEquals(1, axis.getZ(), 1e-9);
-
-    quaternion = new Quaternion(0, 0, -0.7071067811865475, 0.7071067811865475);
-    // The actual angle is -Math.PI / 2 but this is represented by a flipped
-    // rotation axis in the quaternion.
-    assertEquals(Math.PI / 2, quaternion.getAngle(), 1e-9);
-    axis = quaternion.getAxis();
-    assertEquals(0, axis.getX(), 1e-9);
-    assertEquals(0, axis.getY(), 1e-9);
-    assertEquals(-1, axis.getZ(), 1e-9);
-
-    quaternion = new Quaternion(0, 0, 0.9238795325112867, 0.38268343236508984);
-    assertEquals(0.75 * Math.PI, quaternion.getAngle(), 1e-9);
-    axis = quaternion.getAxis();
-    assertEquals(0, axis.getX(), 1e-9);
-    assertEquals(0, axis.getY(), 1e-9);
-    assertEquals(1, axis.getZ(), 1e-9);
-
-    quaternion = new Quaternion(0, 0, -0.9238795325112867, 0.38268343236508984);
-    assertEquals(0.75 * Math.PI, quaternion.getAngle(), 1e-9);
-    axis = quaternion.getAxis();
-    assertEquals(0, axis.getX(), 1e-9);
-    assertEquals(0, axis.getY(), 1e-9);
-    assertEquals(-1, axis.getZ(), 1e-9);
-
-    quaternion = new Quaternion(0, 0, 0.7071067811865475, -0.7071067811865475);
-    assertEquals(1.5 * Math.PI, quaternion.getAngle(), 1e-9);
-    axis = quaternion.getAxis();
-    assertEquals(0, axis.getX(), 1e-9);
-    assertEquals(0, axis.getY(), 1e-9);
-    assertEquals(1, axis.getZ(), 1e-9);
-  }
-
   @Test
   public void testAxisAngleToQuaternion() {
     Quaternion quaternion;
@@ -136,7 +86,7 @@ public class QuaternionTest {
     Quaternion quaternion = Quaternion.newFromAxisAngle(new Vector3(0, 0, 1), Math.PI / 2);
     Quaternion inverse = quaternion.invert();
     Quaternion rotated = quaternion.multiply(inverse);
-    assertEquals(0, rotated.getAngle(), 1e-9);
+    assertEquals(1, rotated.getW(), 1e-9);
   }
 
   @Test
-- 
GitLab