#include <ros/ros.h>
#include <stdio.h>
#include <unistd.h> // for write, close
#include <fcntl.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <boost/make_shared.hpp>
#include <algorithm> // for std::find
#include <cerrno> // for errno
#include <memory>
#include "tron_adapter.h"

// one raw octet as it is sent to or received from the socket
using byte = uint8_t;

// some helper functions -----------------------------------------------

// logs the first buf_length bytes of buf as space separated decimal numbers
//   buf:        bytes to log
//   buf_length: number of bytes to log
//   send:       true labels the line "sending", false labels it "received"
inline void byte_info(const byte* buf, int buf_length, bool send){
    std::stringstream line;
    line << (send ? "sending" : "received") << " bytes:";
    for (int i = 0; i < buf_length; i++) line << " " << static_cast<int>(buf[i]);
    // the assembled text is passed as an argument, it must never be used as
    // the printf-style format string itself
    ROS_INFO("%s", line.str().c_str());
}

// gets count bytes from socket file descriptor (with timeout)
const double SECONDS_BEFORE_TIMEOUT = 30;
std::unique_ptr<byte[]> get_bytes_socket(int fd, int count){
    std::unique_ptr<byte[]> arr = std::make_unique<byte[]>(count);
    int already_read = 0;
    ros::Time start_time = ros::Time::now();
    int success = 0;
    while (already_read < count) {
        success = recv(fd, &arr[already_read], count-already_read, MSG_DONTWAIT);
        if (success == 0) throw "connection was closed by TRON";
        if (success == -1){
            if (ros::Time::now().toSec() - start_time.toSec() >= SECONDS_BEFORE_TIMEOUT)
                throw "timeout while reading bytes from socket file descriptor";
            continue;
        }
        already_read += success; // read returns number of bytes read
    }
    byte_info(arr.get(), already_read, false);
    return arr; // no explicit move needed since return value is rvalue
};

// converts the 4 bytes starting at buf from network order (most significant
// byte first) to a signed 32 bit integer in host order
// buf must point to at least 4 readable bytes; it does not need any particular
// alignment since the value is assembled byte by byte
inline int32_t network_bytes_to_int_32(byte *buf){
    const uint32_t bits = (static_cast<uint32_t>(buf[0]) << 24)
                        | (static_cast<uint32_t>(buf[1]) << 16)
                        | (static_cast<uint32_t>(buf[2]) << 8)
                        |  static_cast<uint32_t>(buf[3]);
    // same bit pattern, read as signed (two's complement)
    return static_cast<int32_t>(bits);
}

// converts the 2 bytes starting at buf from network order (most significant
// byte first) to an unsigned 16 bit integer in host order
// buf must point to at least 2 readable bytes; it does not need any particular
// alignment since the value is assembled byte by byte
inline uint16_t network_bytes_to_uint_16(byte *buf) {
    const unsigned int high = buf[0];
    const unsigned int low = buf[1];
    return static_cast<uint16_t>((high << 8) | low);
}

// wrapping get_bytes_socket for converting to 32 bit integer
int32_t get_int_socket(int fd) {
    auto ack = get_bytes_socket(fd, 4);
    return network_bytes_to_int_32(ack.get());
}

// writes num in network order (most significant byte first) into
// buf[index] ... buf[index + 3]
void add_int32_in_network_order(int32_t num, byte *buf, int index){
    const uint32_t bits = static_cast<uint32_t>(num);
    for (int shift = 24; shift >= 0; shift -= 8)
        buf[index++] = static_cast<byte>((bits >> shift) & 0xFF);
}

// wraps write() for printing and throwing on errors
// write() may accept fewer bytes than requested, so it is repeated until all
// length bytes of buf were handed over
// throws (const char*) "sending failed" if write() reports an error or makes
// no progress
inline void send_bytes(int fd, const byte *buf, int length){
    byte_info(buf, length);
    int sent = 0;
    while (sent < length) {
        const ssize_t written = write(fd, buf + sent, length - sent);
        if (written <= 0) throw "sending failed";
        sent += static_cast<int>(written);
    }
}

// returns false if nothing more to read and true if 4 bytes are read successfully
// used to reduce overhead in testing phase
// the first recv does not wait for data; if it delivers only a part of the 4
// bytes (possible on a stream socket), the missing rest is waited for instead
// of giving up on a unit that has already started to arrive
// throws (const char*) if the connection was closed or the rest cannot be read
inline bool read_4_bytes_nonblock(int fd, byte *buf) {
    int bytes_recv = recv(fd, buf, 4, MSG_DONTWAIT);
    if (bytes_recv == -1) return false; // nothing more to read
    if (bytes_recv == 0) throw "connection was closed";
    while (bytes_recv < 4) {
        const int more = recv(fd, buf + bytes_recv, 4 - bytes_recv, MSG_WAITALL);
        if (more == 0) throw "connection was closed";
        if (more < 0) throw "could not read full 4 bytes";
        bytes_recv += more;
    }
    byte_info(buf, 4, false);
    return true;
}

// creates an IPv4 stream socket and connects it to the given address
//   IP:   IPv4 address in the text form accepted by inet_pton
//   port: port number in host order
// returns file descriptor of the connected socket
// throws (const char*) if the socket cannot be created, the IP cannot be
// converted or connecting fails; in the latter two cases the socket is closed
// before throwing so that the file descriptor does not leak
int create_connected_socket(std::string IP, uint16_t port){
    const int socketfd = socket(AF_INET, SOCK_STREAM, 0);
    if (socketfd < 0) {
        throw "failed to create socket";
    }
    ROS_INFO("socket created successfully");

    struct sockaddr_in addr = {}; // zero all fields that are not set below
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    if (inet_pton(AF_INET, IP.c_str(), &addr.sin_addr) != 1) {
        close(socketfd);
        throw "IP could not be converted";
    }
    if (connect(socketfd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
        close(socketfd);
        throw "failed to connect";
    }
    ROS_INFO("successfully connected");
    return socketfd;
}

// true if this system stores integers in network (big endian) order, which is
// the only case in which htonl leaves a value unchanged
// ROS uses little endian for its messages
const bool SYS_IS_BIG_ENDIAN = (47 == htonl(47));


// connects to IP:PORT via create_connected_socket and keeps the resulting
// file descriptor in socket_fd; anything thrown while connecting propagates
TRON_Adapter::TRON_Adapter(std::string IP, uint16_t PORT) {
    const int connected_fd = create_connected_socket(IP, PORT);
    socket_fd = connected_fd;
}

// declares channel to TRON (as input if channelIsInput, as output otherwise)
// and returns a Mapping holding topic and a copy of the declared channel
// the returned Mapping is not added to mappings here
Mapping TRON_Adapter::createMapping(std::string topic, std::string channel, bool channelIsInput){
    Mapping created;
    created.topic = topic;
    const std::unique_ptr<Channel> declared = send_channel_decl_msg(channelIsInput, channel);
    created.channel = *declared;
    return created;
}

// attaches the TRON variable name_tron to the channel of map and appends
// byte_offset and both converter functions to the corresponding lists of map,
// so that the entries for one variable share the same position in each list
// (presumably byte_offset locates the value inside the topic message - confirm
// against the code using map.byte_offset)
void TRON_Adapter::add_var_to_mapping(Mapping& map, std::string name_tron, int byte_offset, 
                         int32_t (*conv_to_TRON)(byte*, int*),
                         void (*conv_to_topic)(int32_t, byte*, int*)){
    // declare the variable first: nothing is stored in map if this throws
    Channel& target = map.channel;
    add_var_to_channel(target, target.is_input, name_tron);

    map.byte_offset.push_back(byte_offset);
    map.converters_to_TRON.push_back(conv_to_TRON);
    map.converters_to_topics.push_back(conv_to_topic);
}

void TRON_Adapter::get_error_msg(int32_t errorcode) {
    ROS_WARN("got error, trying to get corresponding message");
    byte get_err_msg_msg[5];
    get_err_msg_msg[0] = GET_ERROR_MSG;
    add_int32_in_network_order(errorcode, get_err_msg_msg, 1);
    send_bytes(socket_fd, get_err_msg_msg, 5);
    byte err_msg_length = get_bytes_socket(socket_fd, 1)[0];
    auto err_msg = get_bytes_socket(socket_fd, err_msg_length);
    std::string msg_str = std::string(reinterpret_cast<char*>(err_msg.get()), (size_t) err_msg_length);
    ROS_FATAL("TRON sent error message: %s", msg_str.c_str());
    throw "got error from TRON";
}

// attaches the variable var to the channel chan
// if a channel with the same name in mappings already lists var, nothing is
// sent and var is only appended to chan.vars; otherwise the variable is
// declared to TRON first
//   chan:     channel the variable is attached to
//   is_input: selects the command ADD_VAR_TO_INPUT or ADD_VAR_TO_OUTPUT
//   var:      name of the variable, at most 255 characters
// throws (const char*) if var is too long for the one byte length field, and
// via get_error_msg if TRON answers with a negative value
void TRON_Adapter::add_var_to_channel(Channel& chan, bool is_input, std::string var) {
    bool var_already_declared = false;
    for (const Mapping& map : mappings) {
        if (chan.name != map.channel.name) continue;
        const auto& known = map.channel.vars;
        if (std::find(known.begin(), known.end(), var) != known.end()) {
            var_already_declared = true;
            break;
        }
    }
    if (var_already_declared) {
        ROS_INFO("variable %s was already declared to channel %s", var.c_str(), chan.name.c_str());
        chan.vars.push_back(var);
        return;
    }

    // the name length is transmitted in a single byte; a longer name would be
    // truncated there and produce a corrupt packet
    if (var.length() > 255) throw "variable name is too long";

    // packet: command byte, channel identifier (4 bytes), name length, name
    const size_t msg_length = 6 + var.length();
    std::unique_ptr<byte[]> msg = std::make_unique<byte[]>(msg_length);
    msg[0] = is_input ? ADD_VAR_TO_INPUT : ADD_VAR_TO_OUTPUT;
    add_int32_in_network_order(chan.identifier, msg.get(), 1);
    msg[5] = static_cast<byte>(var.length());
    for (size_t i = 0; i < var.length(); i++) msg[6+i] = static_cast<byte>(var[i]);

    ROS_INFO("attaching variable %s to channel %s", var.c_str(), chan.name.c_str());
    send_bytes(socket_fd, msg.get(), static_cast<int>(msg_length));
    int32_t ack = get_int_socket(socket_fd);
    if (ack < 0) get_error_msg(ack);
    ROS_INFO("success: attached variable");
    chan.vars.push_back(var);
}

std::unique_ptr<Channel> TRON_Adapter::send_channel_decl_msg(bool is_input, std::string name) {
    // prepare packet
    size_t msg_length = 2 + name.length();
    byte msg[msg_length];
    msg[0] = is_input ? DECL_CHAN_INPUT : DECL_CHAN_OUTPUT;
    msg[1] = name.length();
    for (int i = 2, c = 0; i < msg_length; i++, c++) msg[i] = name[c];

    // send packet
    ROS_INFO("declaring channel %s as %s", name.c_str(), (is_input ? "input" : "output"));
    send_bytes(socket_fd, msg, msg_length);

    // get answer from TRON
    int32_t channel_identifier = get_int_socket(socket_fd);
    if (channel_identifier < 0) { // error handling
        get_error_msg(channel_identifier);
    }
    if (channel_identifier == 0) throw "did not get channel identifier";
    // assigned channel ID successfully
    ROS_INFO("success: identifier for channel %s is %i", name.c_str(), channel_identifier);
    return std::make_unique<Channel>(name, channel_identifier, is_input);
}

// sends the time unit and afterwards the timeout to TRON, each followed by
// waiting for an acknowledgement
//   microseconds: time unit, sent as 8 bytes in network order after SET_TIME_UNIT
//   timeout:      timeout in time units, sent as 4 bytes in network order after SET_TIMEOUT
// any answer other than 0 is passed to get_error_msg, which throws
void TRON_Adapter::set_time_unit_and_timeout(uint64_t microseconds, int32_t timeout){
    byte msg[9];
    msg[0] = SET_TIME_UNIT;

    // htonl does not exist for 64 bit values: shifting writes the most
    // significant byte first regardless of the byte order of this system
    for (int i = 0; i < 8; i++)
        msg[i+1] = static_cast<byte>((microseconds >> (8 * (7 - i))) & 0xFF);

    // the cast makes the argument match the format specifier on every platform
    ROS_INFO("setting time unit: %llu microseconds", static_cast<unsigned long long>(microseconds));
    send_bytes(socket_fd, msg, 9);
    int32_t ack = get_int_socket(socket_fd);
    if (ack != 0) get_error_msg(ack);
    ROS_INFO("success: set time unit");

    // the buffer is reused, only its first 5 bytes are sent for the timeout
    msg[0] = SET_TIMEOUT;
    add_int32_in_network_order(timeout, msg, 1);
    ROS_INFO("setting timeout to %i units", timeout);
    send_bytes(socket_fd, msg, 5);
    ack = get_int_socket(socket_fd);
    if (ack != 0) get_error_msg(ack);
    ROS_INFO("success: set timeout");
}
 
void TRON_Adapter::request_start() {
    /* documentation confuses codes for start and getErrorMessage, actually used:
    64 is start
    127 is gerErrorMessage */
    ROS_INFO("requesting start");
    byte start = REQUEST_START;
    send_bytes(socket_fd, &start, 1);
    byte answer = get_bytes_socket(socket_fd, 1)[0];
    if (answer != ANSWER_START) throw "starting failed";
    ROS_INFO("success: starting test phase");
}

// sends the values vars[0] ... vars[var_count - 1] to TRON for the channel chan
// and increments acks_missing, the acknowledgement itself is not waited for
//   chan:      channel to report to
//   var_count: number of values in vars
//   vars:      values to send, must hold at least var_count elements
void TRON_Adapter::report_now(Channel& chan, uint16_t var_count, int32_t *vars){
    // packet: channel identifier (4 bytes), variable count (2 bytes),
    // then 4 bytes per value, everything in network order
    const int msg_length = 6 + 4 * var_count;
    std::unique_ptr<byte[]> msg = std::make_unique<byte[]>(msg_length);

    add_int32_in_network_order(chan.identifier, msg.get(), 0);

    msg[4] = static_cast<byte>(var_count >> 8);
    msg[5] = static_cast<byte>(var_count & 0xFF);

    for (unsigned short i = 0; i < var_count; i++) 
        add_int32_in_network_order(vars[i], msg.get(), 6 + i * 4);

    ROS_INFO("sending to output channel %s", chan.name.c_str());
    if (var_count == 0) ROS_INFO("no variables attached");
    for (unsigned short i = 0; i < var_count; i++) {
        // more values than variables attached to the channel may be passed in,
        // chan.vars must not be read out of bounds just for logging
        if (i < chan.vars.size())
            ROS_INFO("attached value %i to variable %s", vars[i], chan.vars[i].c_str());
        else
            ROS_INFO("attached value %i to undeclared variable number %i", vars[i], i + 1);
    }
    send_bytes(socket_fd, msg.get(), msg_length);
    acks_missing++;
}

void TRON_Adapter::report_now(std::string chan, uint16_t var_count, int32_t *vars){
    for (Mapping& map : mappings)
        if (map.channel.name == chan && !map.channel.is_input) {
            report_now(map.channel, var_count, vars);
            return;
        }
    throw "could not report to channel";
}

void TRON_Adapter::process_TRONs_msgs(){
    /* note: TRONs communication after start is not guaranteed to be synchronous,
    thus incoming messages must be checked for their content */
    while (true){
        // get 4 bytes at a time as an int32
        byte bytes[4];
        if (!read_4_bytes_nonblock(socket_fd, bytes)) 
            break; // no bytes left to process
        int32_t next = network_bytes_to_int_32(bytes);

        // bytes are acknowledgement
        if (next == ACK_SINGLE) {
            if (--acks_missing < 0) throw "too many acknowledgements";
            ROS_INFO("got acknowledgement for output");
            continue;
        }

        // bytes are channel identifier
        // find corresponding channel
        const Channel *chan = nullptr;
        for (Mapping& map : mappings) if (map.channel.identifier == next) chan = &map.channel;
        /* note: this only happens if message number 12 in TRON User Manual is received,
            which should not be the case according to the documentation */
        if (chan == nullptr) 
            throw "channel could not be identified";

        ROS_INFO("got channel identifier (%s) for input", chan->name.c_str());

        // channel identified, assuming all following bytes are correct

        // get number of variables
        recv(socket_fd, bytes, 2, MSG_DONTWAIT);
        byte_info(bytes, 2, false);
        uint16_t var_count = network_bytes_to_uint_16(bytes);
        ROS_INFO("got variable count %i", var_count);

        int32_t vals[var_count];
        // process variables
        for (uint16_t i = 0; i < var_count; i++) {
            recv(socket_fd, bytes, 4, MSG_DONTWAIT);
            next = network_bytes_to_int_32(bytes);
            std::string var = chan->vars[i];
            ROS_INFO("got variable number %i:  value of %s is %i", i+1, var.c_str(), next);
            vals[i] = next;
        }

        for (Mapping& map : mappings)
            if (map.channel.name == chan->name && map.channel.is_input) {
                if (!map.input_callback.empty())
                    map.input_callback(map, vals);  
                else throw "no callback declared";
            }
                          

        // send acknowledgement
        add_int32_in_network_order(ACK_SINGLE, bytes, 0);
        ROS_INFO("sending acknowledgement");
        send_bytes(socket_fd, bytes, 4);
    }
}