Select Git revision
RestHandler.jadd

René Schöne authored
RestHandler.jadd 5.09 KiB
import java.util.concurrent.TimeUnit;aspect RestHandler {
/**
 * Manages one {@link RestHandler} per port and remembers, for every connection it hands out,
 * which callback or supplier was registered, so that the connection can be undone later.
 */
public class RestServerHandler {
  private static final int DEFAULT_PORT = 4567;
  /** Handlers keyed by port; a handler is created (and started) on the first connect to its port. */
  private final java.util.Map<Integer, RestHandler> handlers = new java.util.HashMap<>();
  /** The callback or supplier registered for each token, needed to remove it again on disconnect. */
  private final java.util.Map<ConnectToken, Object> tokensForRemoval = new java.util.HashMap<>();
  /** Name passed on to every {@link RestHandler} created by this instance. */
  private final String name;

  /** Creates a server handler named {@code "RagConnect"}. */
  public RestServerHandler() {
    this("RagConnect");
  }

  /**
   * Creates a server handler.
   *
   * @param name name passed on to the per-port {@link RestHandler}s
   */
  public RestServerHandler(String name) {
    this.name = name;
  }

  /** Returns the explicit port of {@code uri}, or {@link #DEFAULT_PORT} if the URI has none. */
  private static int portOf(java.net.URI uri) {
    return uri.getPort() != -1 ? uri.getPort() : DEFAULT_PORT;
  }

  /** Returns the handler for the port of {@code uri}, creating and starting it on first use. */
  private RestHandler resolveHandler(java.net.URI uri) {
    int port = portOf(uri);
    RestHandler handler = handlers.get(port);
    if (handler == null) {
      // first connect to that server
      handler = new RestHandler(name);
      handler.setPort(port);
      handlers.put(port, handler);
    }
    return handler;
  }

  /**
   * Registers {@code callback} to receive the body of PUT requests to the path of {@code uri}.
   *
   * @param uri address whose port selects the server and whose path selects the route
   * @param callback consumer invoked with the request body
   * @return token identifying this connection for {@link #disconnect(ConnectToken)}
   */
  public ConnectToken newPUTConnection(java.net.URI uri, java.util.function.Consumer<String> callback) {
    ConnectToken connectToken = new ConnectToken(uri);
    resolveHandler(uri).newPUTConnection(uri.getPath(), callback);
    tokensForRemoval.put(connectToken, callback);
    return connectToken;
  }

  /**
   * Registers {@code supplier} to answer GET requests to the path of {@code uri}.
   *
   * @param uri address whose port selects the server and whose path selects the route
   * @param supplier producer of the response body
   * @return token identifying this connection for {@link #disconnect(ConnectToken)}
   */
  public ConnectToken newGETConnection(java.net.URI uri, SupplierWithException<String> supplier) {
    ConnectToken connectToken = new ConnectToken(uri);
    resolveHandler(uri).newGETConnection(uri.getPath(), supplier);
    tokensForRemoval.put(connectToken, supplier);
    return connectToken;
  }

  /**
   * Removes the callback or supplier that was registered under {@code connectToken}.
   * Never creates or starts a server: if no handler exists for the token's port, nothing happens.
   *
   * @param connectToken token returned by one of the {@code new...Connection} methods
   * @return {@code true} if a registration was removed
   */
  public boolean disconnect(ConnectToken connectToken) {
    RestHandler handler = handlers.get(portOf(connectToken.uri));
    if (handler == null) {
      return false;
    }
    // remove (not just get) so that disconnected tokens do not accumulate in the map
    Object callbackOrSupplier = tokensForRemoval.remove(connectToken);
    return callbackOrSupplier != null
        && handler.disconnect(connectToken.uri.getPath(), callbackOrSupplier);
  }

  /** Closes every handler created so far. */
  public void close() {
    for (RestHandler handler : handlers.values()) {
      handler.close();
    }
  }
}
/**
 * Helper class to receive updates and publish information via REST.
 *
 * <p>A PUT route passes the request body to every callback connected to its path. A GET route
 * answers with the value of the supplier currently connected to its path.
 *
 * <p>NOTE(review): all server calls are static calls on {@code spark.Spark}, so every instance of
 * this class configures the same process-wide server. Serving several ports at once presumably
 * does not work — TODO confirm.
 *
 * @author rschoene - Initial contribution
 */
public class RestHandler {
  private static final int DEFAULT_PORT = 4567;
  private final org.apache.logging.log4j.Logger logger;
  /** Currently only stored, not read by this class. */
  private final String name;
  private int port;
  /** NOTE(review): not read anywhere in this class. */
  private final java.util.concurrent.CountDownLatch exitCondition;
  /**
   * Dispatch knowledge: callbacks per PUT path. Each list is iterated by its route handler, which
   * may run concurrently with connect/disconnect, hence the copy-on-write lists.
   */
  private final java.util.Map<String, java.util.List<java.util.function.Consumer<String>>> callbacks;
  /** Dispatch knowledge: current supplier per GET path, looked up by the route on every request. */
  private final java.util.Map<String, SupplierWithException<String>> suppliers;
  /** GET paths for which a Spark route has already been registered (registered once per path). */
  private final java.util.Set<String> getRoutes;

  /** Creates a handler named {@code "RagConnect"} on the default port (not started yet). */
  public RestHandler() {
    this("RagConnect");
  }

  /**
   * Creates a handler on the default port. The server is not started until {@link #setPort(int)}.
   *
   * @param name name of this handler
   */
  public RestHandler(String name) {
    this.logger = org.apache.logging.log4j.LogManager.getLogger(RestHandler.class);
    this.name = name;
    this.port = DEFAULT_PORT;
    this.exitCondition = new java.util.concurrent.CountDownLatch(1);
    this.callbacks = new java.util.HashMap<>();
    this.suppliers = new java.util.concurrent.ConcurrentHashMap<>();
    this.getRoutes = new java.util.HashSet<>();
  }

  /**
   * Sets the port and starts the server on it.
   * NOTE(review): Spark rejects a port change after initialisation, so calling this twice likely
   * fails — TODO confirm.
   *
   * @param port port to listen on
   * @return this handler, for chaining
   */
  public RestHandler setPort(int port) {
    this.port = port;
    start();
    return this;
  }

  /**
   * Connects {@code callback} to PUT requests on {@code path}. The Spark route is registered only
   * for the first callback of a path. It answers "OK", or 500 listing the messages of all
   * callbacks that threw.
   *
   * @param path route path
   * @param callback consumer invoked with the request body
   */
  public void newPUTConnection(String path, java.util.function.Consumer<String> callback) {
    java.util.List<java.util.function.Consumer<String>> existing = callbacks.get(path);
    if (existing != null) {
      existing.add(callback);
      return;
    }
    // setup path
    java.util.List<java.util.function.Consumer<String>> callbackList =
        new java.util.concurrent.CopyOnWriteArrayList<>();
    callbackList.add(callback);
    callbacks.put(path, callbackList);
    spark.Spark.put(path, (request, response) -> {
      String content = request.body();
      // insertion-ordered (and de-duplicated) so the report follows callback order
      java.util.Set<String> errors = new java.util.LinkedHashSet<>();
      for (java.util.function.Consumer<String> f : callbackList) {
        try {
          f.accept(content);
        } catch (Exception e) {
          errors.add(describe(e));
        }
      }
      if (errors.isEmpty()) {
        return "OK";
      }
      return makeError(response, 500,
          "The following error(s) happened: [" + String.join("\n", errors) + "]");
    });
  }

  /**
   * Connects {@code supplier} to GET requests on {@code path}, replacing any previous supplier.
   * The route is registered once per path and looks up the current supplier on each request. It
   * answers 404 while no supplier is connected and 500 if the supplier throws.
   *
   * @param path route path
   * @param supplier producer of the response body; must not be {@code null}
   * @throws NullPointerException if {@code supplier} is {@code null}
   */
  public void newGETConnection(String path, SupplierWithException<String> supplier) {
    java.util.Objects.requireNonNull(supplier, "supplier");
    if (suppliers.get(path) != null) {
      logger.warn("Overriding existing supplier for '{}'", path);
    }
    suppliers.put(path, supplier);
    if (getRoutes.add(path)) {
      spark.Spark.get(path, (request, response) -> {
        // look up on every request so that disconnect and override take effect
        SupplierWithException<String> current = suppliers.get(path);
        if (current == null) {
          return makeError(response, 404, "No supplier connected for '" + path + "'");
        }
        try {
          return current.get();
        } catch (Exception e) {
          return makeError(response, 500, describe(e));
        }
      });
    }
  }

  /**
   * Removes a previously connected callback or supplier from {@code path}. The Spark route itself
   * stays registered.
   *
   * @param path route path
   * @param callbackOrSupplier the exact object passed to a {@code new...Connection} method
   * @return {@code true} if it was found and removed
   */
  public boolean disconnect(String path, Object callbackOrSupplier) {
    // only one will succeed (or false will be returned); the supplier map rejects null keys
    return callbacks.getOrDefault(path, java.util.Collections.emptyList()).remove(callbackOrSupplier)
        || (path != null && suppliers.remove(path, callbackOrSupplier));
  }

  /** Sets {@code statusCode} on {@code response} and returns {@code message} as the body. */
  private String makeError(spark.Response response, int statusCode, String message) {
    response.status(statusCode);
    return message;
  }

  /**
   * Returns a non-null description of {@code e}. Its message is used when present, otherwise
   * {@code e.toString()}. NOTE(review): Spark presumably treats a null route result as unmapped
   * (404), which would hide the intended 500.
   */
  private static String describe(Exception e) {
    return e.getMessage() != null ? e.getMessage() : e.toString();
  }

  /** Configures the port on the static Spark server, initialises it and waits for it to be up. */
  private void start() {
    logger.info("Starting REST server at {}", this.port);
    spark.Spark.port(this.port);
    spark.Spark.init();
    spark.Spark.awaitInitialization();
  }

  /** Stops the static Spark server and waits for it to stop. */
  public void close() {
    spark.Spark.stop();
    spark.Spark.awaitStop();
  }
}
/**
 * Like {@link java.util.function.Supplier}, but {@link #get()} may throw a checked exception.
 *
 * @param <T> type of the supplied value
 */
@FunctionalInterface
public interface SupplierWithException<T> {
  /**
   * Produces a value.
   *
   * @return the supplied value
   * @throws Exception if the value cannot be produced
   */
  T get() throws Exception;
}
}