Skip to content
Snippets Groups Projects
Select Git revision
  • 1c945a6e63972c077103187b844b60175200dd2c
  • dev default protected
  • main protected
  • feature/ros-java-integration
4 results

MqttHandler.jadd

Blame
  • MqttHandler.jadd 17.36 KiB
    import java.util.List;
    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.TimeUnit;
    import java.util.function.BiConsumer;aspect MqttHandler {
    /**
     * Manages one {@link MqttHandler} per MQTT broker and routes subscribe, unsubscribe and publish
     * requests that are given as URIs to the handler of the broker named in the URI.
     * <p>
     * Host and port of a URI select the broker, its path (and fragment, see
     * {@link #extractTopic(java.net.URI)}) selects the topic.
     * This class uses plain {@link java.util.HashMap}s without any synchronization.
     */
    public class MqttServerHandler {
      /** Port used for URIs that do not specify a port. */
      private static final int DEFAULT_PORT = 1883;

      /** Already created handlers, keyed by {@code "host:port"} of their broker. */
      private final java.util.Map<String, MqttHandler> handlers = new java.util.HashMap<>();
      /** Callback registered for each token, needed to remove exactly this callback upon disconnect. */
      private final java.util.Map<ConnectToken, java.util.function.BiConsumer<String, byte[]>> tokensForRemoval = new java.util.HashMap<>();
      /** Maximum time to wait for a handler to become ready. */
      private long time;
      /** Unit of {@link #time}. */
      private java.util.concurrent.TimeUnit unit;
      /** Name handed to every handler created by this instance. */
      private String name;

      /** Creates a server handler with the name {@code "RagConnect"}. */
      public MqttServerHandler() {
        this("RagConnect");
      }

      /**
       * Creates a server handler that waits at most one second for a handler to become ready.
       *
       * @param name the name passed on to every created {@link MqttHandler}
       */
      public MqttServerHandler(String name) {
        this.name = name;
        setupWaitUntilReady(1, java.util.concurrent.TimeUnit.SECONDS);
      }

      /**
       * Configures how long {@link #resolveHandler(java.net.URI)} waits for a handler to become ready.
       *
       * @param time the maximum time to wait
       * @param unit the time unit of the time argument
       */
      public void setupWaitUntilReady(long time, java.util.concurrent.TimeUnit unit) {
        this.time = time;
        this.unit = unit;
      }

      /**
       * Returns the handler for the broker of the given URI, connecting to the broker on first use.
       * Waits for the handler to become ready, but returns it even if the wait timed out.
       *
       * @param uri the URI whose host and port identify the broker
       * @return the (possibly newly created) handler for that broker
       * @throws java.io.IOException if a new handler could not connect
       */
      public MqttHandler resolveHandler(java.net.URI uri) throws java.io.IOException {
        final String key = handlerKey(uri);
        MqttHandler handler = handlers.get(key);
        if (handler == null) {
          // first connect to that server
          handler = (this.name == null) ? new MqttHandler() : new MqttHandler(this.name);
          // always pass host and port separately, a host string containing ":" (IPv6 literal) must not be split
          handler.setHost(uri.getHost(), portOf(uri));
          handlers.put(key, handler);
        }
        handler.waitUntilReady(this.time, this.unit);
        return handler;
      }

      /**
       * Registers a callback for the topic of the given URI at the broker of that URI.
       * The token is returned regardless of whether the handler reported a successful subscription.
       *
       * @param uri      the URI describing broker and topic
       * @param callback the callback to run for each message, receiving topic and payload
       * @return a token to be passed to {@link #disconnect(ConnectToken)}
       * @throws java.io.IOException if a new handler could not connect
       */
      public ConnectToken newConnection(java.net.URI uri, java.util.function.BiConsumer<String, byte[]> callback) throws java.io.IOException {
        ConnectToken connectToken = new ConnectToken(uri);
        resolveHandler(uri).newConnection(extractTopic(uri), callback);
        tokensForRemoval.put(connectToken, callback);
        return connectToken;
      }

      /**
       * Removes the callback that was registered with the given token.
       * Never opens a new broker connection: an unknown broker or an unknown token yields {@code false}.
       *
       * @param connectToken the token returned by {@link #newConnection}
       * @return true if the callback was removed successfully, false otherwise
       * @throws java.io.IOException never thrown by this implementation, kept for compatibility
       */
      public boolean disconnect(ConnectToken connectToken) throws java.io.IOException {
        MqttHandler handler = handlers.get(handlerKey(connectToken.uri));
        java.util.function.BiConsumer<String, byte[]> callback = tokensForRemoval.get(connectToken);
        if (handler == null || callback == null) {
          return false;
        }
        handler.waitUntilReady(this.time, this.unit);
        boolean disconnected = handler.disconnect(extractTopic(connectToken.uri), callback);
        if (disconnected) {
          // forget the token, otherwise the map grows with every connection ever made
          tokensForRemoval.remove(connectToken);
        }
        return disconnected;
      }

      /**
       * Publishes a message to the topic of the given URI.
       *
       * @param uri   the URI describing broker and topic
       * @param bytes the payload
       * @throws java.io.IOException if a new handler could not connect
       */
      public void publish(java.net.URI uri, byte[] bytes) throws java.io.IOException {
        resolveHandler(uri).publish(extractTopic(uri), bytes);
      }

      /**
       * Publishes a message to the topic of the given URI.
       *
       * @param uri    the URI describing broker and topic
       * @param bytes  the payload
       * @param retain the retain flag passed on to the handler
       * @throws java.io.IOException if a new handler could not connect
       */
      public void publish(java.net.URI uri, byte[] bytes, boolean retain) throws java.io.IOException {
        resolveHandler(uri).publish(extractTopic(uri), bytes, retain);
      }

      /**
       * Publishes a message to the topic of the given URI.
       *
       * @param uri    the URI describing broker and topic
       * @param bytes  the payload
       * @param qos    the quality of service passed on to the handler
       * @param retain the retain flag passed on to the handler
       * @throws java.io.IOException if a new handler could not connect
       */
      public void publish(java.net.URI uri, byte[] bytes,
                          org.fusesource.mqtt.client.QoS qos, boolean retain) throws java.io.IOException {
        resolveHandler(uri).publish(extractTopic(uri), bytes, qos, retain);
      }

      /**
       * Computes the MQTT topic for a URI: its path without the leading slash.
       * If the URI has a fragment, a single {@code "#"} is appended to the path.
       * A URI without a path yields the empty string (or {@code "#"} if it has a fragment).
       *
       * @param uri the URI to extract the topic from
       * @return the topic, never null
       */
      public static String extractTopic(java.net.URI uri) {
        String path = uri.getPath();
        if (path == null) {
          // opaque URIs have no path
          path = "";
        }
        if (uri.getFragment() != null) {
          // do not also append fragment, as it is illegal, that anything follows "#" in a mqtt topic anyway
          path += "#";
        }
        // startsWith instead of charAt(0), the path may be empty
        if (path.startsWith("/")) {
          path = path.substring(1);
        }
        return path;
      }

      /** Closes all handlers created so far. */
      public void close() {
        for (MqttHandler handler : handlers.values()) {
          handler.close();
        }
      }

      /** Returns the port of the URI, or the default port if the URI does not specify one. */
      private static int portOf(java.net.URI uri) {
        return uri.getPort() == -1 ? DEFAULT_PORT : uri.getPort();
      }

      /** Returns the key identifying the broker of the URI, so that different ports get different handlers. */
      private static String handlerKey(java.net.URI uri) {
        return java.util.Objects.requireNonNull(uri.getHost(), "URI must contain a host") + ":" + portOf(uri);
      }

    }
    /**
     * Helper class to receive updates via MQTT and use callbacks to handle those messages.
     * <p>
     * Callbacks are stored per topic. Topics containing {@code "*"} or {@code "#"} are treated as
     * wildcard topics and matched against incoming topics using a regular expression.
     * The callback maps are plain {@link java.util.HashMap}s without any synchronization.
     *
     * @author rschoene - Initial contribution
     */
    public class MqttHandler {
      private static final int DEFAULT_PORT = 1883;
      /** Maximum time in seconds to wait for the broker to confirm a subscribe or unsubscribe. */
      private static final long OPERATION_TIMEOUT_SECONDS = 2;

      private final org.apache.logging.log4j.Logger logger;
      private final String name;

      /** The host running the MQTT broker. */
      private java.net.URI host;
      /** The connection to the MQTT broker. */
      private org.fusesource.mqtt.client.CallbackConnection connection;
      /** Whether we are connected yet */
      private final java.util.concurrent.CountDownLatch readyLatch;
      private boolean sendWelcomeMessage = true;
      private org.fusesource.mqtt.client.QoS qos;
      /** Dispatch knowledge */
      private final java.util.Map<String, java.util.List<java.util.function.BiConsumer<String, byte[]>>> normalCallbacks;
      private final java.util.Map<java.util.regex.Pattern, java.util.List<java.util.function.BiConsumer<String, byte[]>>> wildcardCallbacks;

      /** Creates a handler with the name {@code "RagConnect"}. */
      public MqttHandler() {
        this("RagConnect");
      }

      /**
       * Creates a handler that is not yet connected, see {@link #setHost(String)}.
       *
       * @param name the name of this handler, used in log output and in the welcome message
       */
      public MqttHandler(String name) {
        this.name = java.util.Objects.requireNonNull(name, "Name must be set");
        this.logger = org.apache.logging.log4j.LogManager.getLogger(MqttHandler.class);
        this.normalCallbacks = new java.util.HashMap<>();
        this.wildcardCallbacks = new java.util.HashMap<>();
        this.readyLatch = new java.util.concurrent.CountDownLatch(1);
        this.qos = org.fusesource.mqtt.client.QoS.AT_LEAST_ONCE;
      }

      /**
       * Disables publishing the welcome message once the connection is established.
       *
       * @return self
       */
      public MqttHandler dontSendWelcomeMessage() {
        this.sendWelcomeMessage = false;
        return this;
      }

      /**
       * Sets the host to receive messages from, and connects to it.
       * @param host name of the host to connect to, format is either <code>"$name"</code> or <code>"$name:$port"</code>
       * @throws java.io.IOException if could not connect, or could not subscribe to a topic
       * @return self
       */
      public MqttHandler setHost(String host) throws java.io.IOException {
        java.util.Objects.requireNonNull(host, "Host need to be set!");
        int colonIndex = host.indexOf(':');
        if (colonIndex < 0) {
          return setHost(host, DEFAULT_PORT);
        }
        return setHost(host.substring(0, colonIndex),
            Integer.parseInt(host.substring(colonIndex + 1)));
      }

      /**
       * Sets the host to receive messages from, and connects to it.
       * Connecting is asynchronous, use {@link #waitUntilReady} to wait for the connection.
       * NOTE(review): if sending the welcome message fails, this handler never becomes ready.
       *
       * @param host name of the host to connect to
       * @param port port of the MQTT broker on that host
       * @throws java.io.IOException if could not connect, or could not subscribe to a topic
       * @return self
       */
      public MqttHandler setHost(String host, int port) throws java.io.IOException {
        java.util.Objects.requireNonNull(host, "Host need to be set!");

        this.host = java.net.URI.create("tcp://" + host + ":" + port);
        logger.debug("Host for {} is {}", this.name, this.host);

        org.fusesource.mqtt.client.MQTT mqtt = new org.fusesource.mqtt.client.MQTT();
        mqtt.setHost(this.host);
        connection = mqtt.callbackConnection();
        java.util.concurrent.atomic.AtomicReference<Throwable> error = new java.util.concurrent.atomic.AtomicReference<>();

        // add the listener to dispatch messages later
        connection.listener(new org.fusesource.mqtt.client.ExtendedListener() {
          @Override
          public void onConnected() {
            logger.debug("Connected");
          }

          @Override
          public void onDisconnected() {
            logger.debug("Disconnected");
          }

          @Override
          public void onPublish(org.fusesource.hawtbuf.UTF8Buffer topic,
                                org.fusesource.hawtbuf.Buffer body,
                                org.fusesource.mqtt.client.Callback<org.fusesource.mqtt.client.Callback<Void>> ack) {
            // this method is called, whenever a MQTT message is received
            String topicString = topic.toString();
            java.util.List<java.util.function.BiConsumer<String, byte[]>> callbackList = callbacksFor(topicString);
            if (callbackList.isEmpty()) {
              logger.debug("Got a message at {}, but no callback to call. Forgot to subscribe?", topic);
            } else {
              byte[] message = body.toByteArray();
              for (java.util.function.BiConsumer<String, byte[]> callback : callbackList) {
                try {
                  callback.accept(topicString, message);
                } catch (Exception e) {
                  // a failing callback must not prevent the remaining callbacks from running
                  logger.catching(e);
                }
              }
            }
            ack.onSuccess(null);  // always acknowledge message
          }

          @Override
          public void onPublish(org.fusesource.hawtbuf.UTF8Buffer topicBuffer,
                                org.fusesource.hawtbuf.Buffer body,
                                Runnable ack) {
            // not used by this type of connection
            logger.warn("onPublish should not be called");
          }

          @Override
          public void onFailure(Throwable cause) {
            error.set(cause);
            // the error reference is only checked right after registration, so log late failures
            logger.warn("Connection failure for {}", name, cause);
          }
        });
        throwIf(error);

        // actually establish the connection
        connection.connect(new org.fusesource.mqtt.client.Callback<>() {
          @Override
          public void onSuccess(Void value) {
            if (MqttHandler.this.sendWelcomeMessage) {
              connection.publish("components",
                  (name + " is connected").getBytes(java.nio.charset.StandardCharsets.UTF_8),
                  org.fusesource.mqtt.client.QoS.AT_LEAST_ONCE,
                  false,
                  new org.fusesource.mqtt.client.Callback<>() {
                    @Override
                    public void onSuccess(Void value) {
                      logger.debug("success sending welcome message");
                      setReady();
                    }

                    @Override
                    public void onFailure(Throwable value) {
                      logger.debug("failure sending welcome message", value);
                    }
                  });
            } else {
              setReady();
            }
          }

          @Override
          public void onFailure(Throwable cause) {
            error.set(cause);
            // the error reference is only checked right after this call, so log late failures
            logger.warn("Could not connect {} to {}", name, MqttHandler.this.host, cause);
          }
        });
        throwIf(error);
        return this;
      }

      /**
       * Collects all callbacks registered for the given concrete topic: those registered for exactly
       * this topic, followed by those of every wildcard topic matching it.
       */
      private java.util.List<java.util.function.BiConsumer<String, byte[]>> callbacksFor(String topicString) {
        java.util.List<java.util.function.BiConsumer<String, byte[]>> result = new java.util.ArrayList<>();
        java.util.List<java.util.function.BiConsumer<String, byte[]>> normalCallbackList = normalCallbacks.get(topicString);
        if (normalCallbackList != null) {
          result.addAll(normalCallbackList);
        }
        for (java.util.Map.Entry<java.util.regex.Pattern, java.util.List<java.util.function.BiConsumer<String, byte[]>>> entry : wildcardCallbacks.entrySet()) {
          if (entry.getKey().matcher(topicString).matches()) {
            result.addAll(entry.getValue());
          }
        }
        return result;
      }

      /**
       * Returns the broker address this handler was set up with.
       *
       * @return the broker URI, or null if {@link #setHost(String)} was not called yet
       */
      public java.net.URI getHost() {
        return host;
      }

      /** Marks this handler as ready, releasing everyone waiting in {@link #waitUntilReady}. */
      private void setReady() {
        readyLatch.countDown();
      }

      /** Throws an {@link java.io.IOException} wrapping the stored error, if there is one. */
      private void throwIf(java.util.concurrent.atomic.AtomicReference<Throwable> error) throws java.io.IOException {
        Throwable cause = error.get();
        if (cause != null) {
          throw new java.io.IOException(cause);
        }
      }

      /**
       * Sets the quality of service used for subsequent subscriptions, and for publishing
       * whenever no explicit quality of service is given.
       *
       * @param qos the quality of service to use
       */
      public void setQoSForSubscription(org.fusesource.mqtt.client.QoS qos) {
        this.qos = qos;
      }

      /**
       * Establish a new connection for some topic.
       * @param topic    the topic to create a connection for, may contain the wildcards "*" and "#"
       * @param callback the callback to run if a new message arrives for this topic
       * @return true if successful stored this connection, false otherwise (e.g., on failed subscribe)
       */
      public boolean newConnection(String topic, java.util.function.Consumer<byte[]> callback) {
        return newConnection(topic, (ignoredTopicString, bytes) -> callback.accept(bytes));
      }

      /**
       * Establish a new connection for some topic.
       * Only the first callback of a topic triggers a subscription at the broker. If that
       * subscription fails or is not confirmed in time, the callback is removed again.
       *
       * @param topic    the topic to create a connection for, may contain the wildcards "*" and "#"
       * @param callback the callback to run if a new message arrives for this topic
       * @return true if successful stored this connection, false otherwise (e.g., on failed subscribe)
       */
      public boolean newConnection(String topic, java.util.function.BiConsumer<String, byte[]> callback) {
        if (readyLatch.getCount() > 0) {
          logger.error("Handler not ready");
          return false;
        }
        // register callback
        logger.debug("new connection for {}", topic);
        final java.util.List<java.util.function.BiConsumer<String, byte[]>> callbacksForTopic;
        java.util.regex.Pattern pattern = null;
        if (isWildcardTopic(topic)) {
          String regex = regexForTopic(topic);
          // Pattern has no value-based equals, so an existing entry has to be searched by its regex
          pattern = findWildcardPattern(regex);
          if (pattern == null) {
            pattern = java.util.regex.Pattern.compile(regex);
          }
          callbacksForTopic = wildcardCallbacks.computeIfAbsent(pattern, p -> new java.util.ArrayList<>());
        } else { // normal topic
          callbacksForTopic = normalCallbacks.computeIfAbsent(topic, t -> new java.util.ArrayList<>());
        }
        final boolean needSubscribe = callbacksForTopic.isEmpty();
        callbacksForTopic.add(callback);
        if (!needSubscribe || subscribe(topic)) {
          return true;
        }
        // roll back, otherwise a later connection for this topic would skip the subscription
        callbacksForTopic.remove(callback);
        if (callbacksForTopic.isEmpty()) {
          if (pattern != null) {
            wildcardCallbacks.remove(pattern);
          } else {
            normalCallbacks.remove(topic);
          }
        }
        return false;
      }

      /** Subscribes to the topic at the broker and waits for the confirmation; true on success. */
      private boolean subscribe(String topic) {
        java.util.concurrent.CountDownLatch operationFinished = new java.util.concurrent.CountDownLatch(1);
        java.util.concurrent.atomic.AtomicBoolean success = new java.util.concurrent.atomic.AtomicBoolean(true);
        org.fusesource.mqtt.client.Topic[] topicArray = { new org.fusesource.mqtt.client.Topic(topic, this.qos) };
        connection.getDispatchQueue().execute(() -> {
          connection.subscribe(topicArray, new org.fusesource.mqtt.client.Callback<>() {
            @Override
            public void onSuccess(byte[] qoses) {
              logger.debug("Subscribed to {}, qoses: {}", topic, qoses);
              operationFinished.countDown();
            }

            @Override
            public void onFailure(Throwable cause) {
              logger.error("Could not subscribe to {}", topic, cause);
              success.set(false);
              operationFinished.countDown();
            }
          });
        });
        boolean finished = awaitOperation(operationFinished, "subscribe to", topic);
        return finished && success.get();
      }

      /**
       * Waits for a broker operation to finish.
       * Returns false upon timeout or interruption; the interrupt flag is restored in the latter case.
       */
      private boolean awaitOperation(java.util.concurrent.CountDownLatch operationFinished, String operation, String topic) {
        try {
          if (operationFinished.await(OPERATION_TIMEOUT_SECONDS, java.util.concurrent.TimeUnit.SECONDS)) {
            return true;
          }
          logger.warn("Timeout while trying to {} '{}'", operation, topic);
        } catch (InterruptedException e) {
          logger.catching(e);
          Thread.currentThread().interrupt();
        }
        return false;
      }

      private boolean isWildcardTopic(String topic) {
        return topic.contains("*") || topic.contains("#");
      }

      /**
       * Translates a wildcard topic into a regular expression: "*" matches anything but "/",
       * "#" matches anything. All other text is quoted, so characters like "$" or "." are
       * matched literally.
       */
      private static String regexForTopic(String topic) {
        StringBuilder regex = new StringBuilder();
        int literalStart = 0;
        for (int index = 0; index < topic.length(); index++) {
          char current = topic.charAt(index);
          if (current != '*' && current != '#') {
            continue;
          }
          if (index > literalStart) {
            regex.append(java.util.regex.Pattern.quote(topic.substring(literalStart, index)));
          }
          regex.append(current == '*' ? "[^/]*" : ".*");
          literalStart = index + 1;
        }
        if (literalStart < topic.length()) {
          regex.append(java.util.regex.Pattern.quote(topic.substring(literalStart)));
        }
        return regex.toString();
      }

      /** Returns the registered wildcard pattern built from the given regex, or null if there is none. */
      private java.util.regex.Pattern findWildcardPattern(String regex) {
        for (java.util.regex.Pattern candidate : wildcardCallbacks.keySet()) {
          if (candidate.pattern().equals(regex)) {
            return candidate;
          }
        }
        return null;
      }

      /**
       * Removes a callback from a topic, and unsubscribes from the topic at the broker if no
       * callback is left for it.
       *
       * @param topic    the topic the callback was registered for, as passed to {@link #newConnection}
       * @param callback the callback to remove
       * @return true if the callback was found and, if needed, the unsubscription was confirmed
       */
      public boolean disconnect(String topic, java.util.function.BiConsumer<String, byte[]> callback) {
        final java.util.regex.Pattern pattern;
        final java.util.List<java.util.function.BiConsumer<String, byte[]>> callbacksForTopic;
        if (isWildcardTopic(topic)) {
          // compare using the regex of the topic, the raw topic never equals the stored pattern
          pattern = findWildcardPattern(regexForTopic(topic));
          callbacksForTopic = (pattern == null) ? null : wildcardCallbacks.get(pattern);
        } else {
          pattern = null;
          callbacksForTopic = normalCallbacks.get(topic);
        }

        if (callbacksForTopic == null) {
          logger.warn("Disconnect for not connected topic '{}'", topic);
          return false;
        }

        boolean removed = callbacksForTopic.remove(callback);
        if (!callbacksForTopic.isEmpty()) {
          // other callbacks still need the subscription
          return removed;
        }

        // no callbacks anymore for this topic, forget it and unsubscribe from mqtt
        if (pattern != null) {
          wildcardCallbacks.remove(pattern);
        } else {
          normalCallbacks.remove(topic);
        }
        boolean unsubscribed = unsubscribe(topic);
        return removed && unsubscribed;
      }

      /** Unsubscribes from the topic at the broker and waits for the confirmation; true on success. */
      private boolean unsubscribe(String topic) {
        java.util.concurrent.CountDownLatch operationFinished = new java.util.concurrent.CountDownLatch(1);
        java.util.concurrent.atomic.AtomicBoolean success = new java.util.concurrent.atomic.AtomicBoolean(true);
        connection.getDispatchQueue().execute(() -> {
          org.fusesource.hawtbuf.UTF8Buffer topicBuffer = org.fusesource.hawtbuf.Buffer.utf8(topic);
          org.fusesource.hawtbuf.UTF8Buffer[] topicArray = new org.fusesource.hawtbuf.UTF8Buffer[]{topicBuffer};
          connection.unsubscribe(topicArray, new org.fusesource.mqtt.client.Callback<>() {
            @Override
            public void onSuccess(Void value) {
              operationFinished.countDown();
            }

            @Override
            public void onFailure(Throwable cause) {
              success.set(false);
              logger.warn("Could not disconnect from {}", topic, cause);
              operationFinished.countDown();
            }
          });
        });
        boolean finished = awaitOperation(operationFinished, "unsubscribe from", topic);
        return finished && success.get();
      }

      /**
       * Waits until this updater is ready to receive MQTT messages.
       * If it already is ready, return immediately with the value <code>true</code>.
       * Otherwise waits for the given amount of time, and either return <code>true</code> within the timespan,
       * if it got ready, or <code>false</code> upon a timeout.
       * If the waiting thread is interrupted, <code>false</code> is returned and the interrupt flag is restored.
       * @param time the maximum time to wait
       * @param unit the time unit of the time argument
       * @return whether this updater is ready
       */
      public boolean waitUntilReady(long time, java.util.concurrent.TimeUnit unit) {
        try {
          return readyLatch.await(time, unit);
        } catch (InterruptedException e) {
          logger.catching(e);
          Thread.currentThread().interrupt();
          return false;
        }
      }

      /**
       * Disconnects from the broker. Only logs a warning if no connection was set up.
       */
      public void close() {
        if (connection == null) {
          logger.warn("Stopping without connection. Was setHost() called?");
          return;
        }
        connection.getDispatchQueue().execute(() -> {
          connection.disconnect(new org.fusesource.mqtt.client.Callback<>() {
            @Override
            public void onSuccess(Void value) {
              logger.info("Disconnected {} from {}", name, host);
            }

            @Override
            public void onFailure(Throwable ignored) {
              // Disconnects never fail. And we do not care either.
            }
          });
        });
      }

      /**
       * Publishes a non-retained message using the quality of service of this handler.
       *
       * @param topic the topic to publish to
       * @param bytes the payload
       */
      public void publish(String topic, byte[] bytes) {
        publish(topic, bytes, false);
      }

      /**
       * Publishes a message using the quality of service of this handler.
       *
       * @param topic  the topic to publish to
       * @param bytes  the payload
       * @param retain the retain flag of the message
       */
      public void publish(String topic, byte[] bytes, boolean retain) {
        publish(topic, bytes, this.qos, retain);
      }

      /**
       * Publishes a message asynchronously; success and failure are only logged.
       * Requires that {@link #setHost(String)} was called before.
       *
       * @param topic  the topic to publish to
       * @param bytes  the payload
       * @param qos    the quality of service of the message
       * @param retain the retain flag of the message
       */
      public void publish(String topic, byte[] bytes, org.fusesource.mqtt.client.QoS qos, boolean retain) {
        connection.getDispatchQueue().execute(() -> {
          connection.publish(topic, bytes, qos, retain, new org.fusesource.mqtt.client.Callback<>() {
            @Override
            public void onSuccess(Void value) {
              logger.debug("Published some bytes to {}", topic);
            }

            @Override
            public void onFailure(Throwable value) {
              logger.warn("Could not publish on topic '{}'", topic, value);
            }
          });
        });
      }
    }
    }