← source index

src/WebSocket.cpp

290 lines  ·  9.8 K  ·  cpp
#include "WebSocket.h"

//================================
//  WebSocket definition
//================================

// Creates an empty connection manager.
// `capacity` is stored as mCapacity; connect() refuses new connections once
// mSize reaches it. Both bookkeeping tables reserve room for that many entries
// up front.
WebSocket::WebSocket(int capacity)
    : mSize(0), mCapacity(capacity)
{
    mConnLookup.reserve(mCapacity);
    mConnCache.reserve(mCapacity);
}

// Closes every connection still registered in mConnCache.
// The ids are snapshotted first because close() erases entries from
// mConnCache, which would invalidate an iterator over the map itself.
// Results of close() are ignored: there is nobody left to report them to.
WebSocket::~WebSocket()
{
    std::vector<int> liveIds;
    liveIds.reserve(mConnCache.size());
    for(const auto& [connId, connData] : mConnCache)
        liveIds.push_back(connId);

    for(const int connId : liveIds)
        close(connId);
}

// Writes `msg` as one WebSocket message on connection `id` (blocking write).
//
// Returns:
//   {id, ""}  on success;
//   {0, err}  when `id` is not a known connection;
//   {-2, err} when the entry holds no ssl stream;
//   {-1, err} when the write throws.
std::pair<int, std::string> WebSocket::send(int id, const std::string& msg)
{
    auto it = mConnCache.find(id);
    if(it == mConnCache.end())
    {
        std::string err = "Cannot send [" + msg + "] to connection id [" + std::to_string(id) + "] because that id does not point to a valid connection.";
        return {0, err};
    }

    const auto& data = it->second;

    // recv() and close() both guard against a missing stream; do the same here
    // instead of dereferencing a null pointer (UB, not a catchable exception).
    if(!data || !data->sslSocket)
    {
        std::string err = "WebSocket ssl socket was empty when attempting to send to connection id [" + std::to_string(id) + ']';
        return {-2, err};
    }

    try
    {
        data->sslSocket->write(boost::asio::buffer(msg));
    }
    catch(const std::exception& e)
    {
        std::string err = "Exception while attempting to send [" + msg + "] to connection id [" + std::to_string(id) + "]: [" + e.what() + "]";
        return {-1, err};
    }

    return {id, ""};
}

// Closes connection `id` and opens a fresh one with the same host, port,
// target and instruments recorded in mConnDetails, keeping the same id.
//
// Returns:
//   {id, msg}  on success;
//   {0, err}   when no connection details are recorded for `id`;
//   {-1, err}  when connect() fails. If the preceding close() also reported an
//              error, that error is appended to `err`.
std::pair<int, std::string> WebSocket::reconnect(int id)
{
    auto it = mConnDetails.find(id);
    if(it == mConnDetails.end())
    {
        std::string err = "Cannot find connection details for connection id [" + std::to_string(id) + "] because that id does not point to a valid connection. ['reconnect']";
        return {0, err};
    }

    const auto& connectionDetails = it->second;

    // A failed close is not treated as fatal here; the old connection is being
    // replaced anyway. Its error text is kept so a connect failure can be diagnosed.
    auto [closeStatus, closeErr] = close(id);

    auto [connectStatus, connectErr] = connect(connectionDetails.host, connectionDetails.port,
                                               connectionDetails.target, connectionDetails.instruments, id);
    if(connectStatus <= 0)
    {
        // close() returns 0 for an unknown id (nothing to close), negatives for real failures.
        if(closeStatus < 0 && !closeErr.empty())
            connectErr += "\n\t" + closeErr;
        return {-1, connectErr};
    }

    // On success connect() returns the connection key in its message slot.
    std::string msg = "Successfully reconnected to [" + connectErr + "] on connection id [" + std::to_string(id) + ']';
    return {id, msg};
}

// Blocks until one complete WebSocket message is read from connection `id`.
//
// Returns:
//   {id, payload} on success;
//   {0, err}      when `id` is not a known connection;
//   {-1, err}     when the read throws;
//   {-2, err}     when the entry has no ssl stream, or when the read failed
//                 with connection_reset / operation_canceled. In that case
//                 reconnect() is called and its message is appended to `err`;
//   {-3, err}     when nothing was read (e.g. the stream is not open);
//   {-100, err}   for any other read error.
//
// NOTE(review): the payload is staged in the member mBuffer, which is shared by
// every connection and not guarded by any lock visible here, so concurrent
// calls would race on it.
std::pair<int, std::string> WebSocket::recv(int id)
{
    auto it = mConnCache.find(id);
    if(it == mConnCache.end())
    {
        std::string err = "Cannot read from connection id [" + std::to_string(id) + "] because that id does not point to a valid connection. ['recv']";
        return {0, err};
    }

    const auto& data = it->second;

    try
    {
        if(!data || !data->sslSocket)
        {
            std::string err = "WebSocket ssl socket was empty when attempting to recv the connection data [" + std::to_string(id) + ']';
            return {-2, err};
        }

        // Always drop the previous payload, even when no read happens below.
        // mBuffer is shared by all connections, so skipping this for a closed
        // stream would hand back the last message (possibly from another
        // connection) as if it had just arrived.
        mBuffer.consume(mBuffer.size());

        boost::beast::error_code errCode;
        if(data->sslSocket->is_open())
            data->sslSocket->read(mBuffer, errCode); // Blocks. TODO: switch to async_read.

        // Compare against portable error conditions, not raw value(): value()
        // ignores the error category, so it can match unrelated errors from
        // other categories and miss equivalent ones with different numeric codes.
        if(errCode == boost::beast::errc::connection_reset ||
           errCode == boost::beast::errc::operation_canceled)
        {
            std::string err = "Failure to read WebSocket ssl buffer. Connection timeout... [" + std::to_string(errCode.value()) + ": " + errCode.message() + ']';
            // reconnect() may erase or replace this cache entry, so `data`
            // must not be touched after this call.
            err += "\n\t" + reconnect(id).second;
            return {-2, err};
        }

        if(errCode)
        {
            std::string err = "Failure to read WebSocket ssl buffer. Unhandled error state... [" + std::to_string(errCode.value()) + ": " + errCode.message() + ']';
            return {-100, err};
        }

        if(mBuffer.size() == 0)
        {
            std::string err = "Connection [" + std::to_string(id) + "] is recovering from an outage...";
            return {-3, err};
        }

        return {id, boost::beast::buffers_to_string(mBuffer.data())};
    }
    catch(const std::exception& e)
    {
        std::string err = "Exception while attempting to read from connection id [" + std::to_string(id) + "]: [" + e.what() + "]";
        return {-1, err};
    }
}

// Opens a TLS WebSocket connection to host:port, performs the WebSocket
// handshake on `target`, and registers the connection.
// `instruments` only contributes to the connection key and the stored
// connection details; it is not sent here.
//
// id == 0 : new connection. If the same host/port/target/instruments key is
//           already registered, that connection's id is returned instead.
// id  > 0 : (re)connect under an existing id, replacing any entry still
//           registered for it (this is how reconnect() calls it).
// id  < 0 : treated as a new connection, without the duplicate-key check.
//
// Returns {id, connKey} on success, {existingId, err} for a duplicate,
// {0, err} when capacity is exhausted, and {-1, err} when any step throws.
std::pair<int, std::string> WebSocket::connect(std::string const& host, const std::string& port, std::string const& target, const std::string& instruments, int id)
{
    // Only brand-new connections consume capacity. A reconnect (id > 0) reuses
    // its existing slot and must not be refused once the table is full;
    // otherwise recv()'s automatic reconnect stops working at capacity.
    const bool isNewConnection = id <= 0;
    if(isNewConnection && mSize >= mCapacity) return {0, "Max connections reached. Discarding..."};

    std::string connKey = host + ':' + port + target + '|' + instruments;

    auto lookupIt = mConnLookup.find(connKey);
    if(lookupIt != mConnLookup.end() && id == 0)
    {
        std::string err = "Connection already established [" + connKey + "] on connection id [" + std::to_string(lookupIt->second) + ']';
        return {lookupIt->second, err};
    }

    try
    {
        std::shared_ptr<WebSocketData> data = std::make_shared<WebSocketData>();

        boost::asio::ssl::context sslCtx{boost::asio::ssl::context::tls_client};
        sslCtx.set_default_verify_paths();
        boost::asio::ip::tcp::resolver res{mCtx};
        boost::beast::websocket::stream<boost::asio::ip::tcp::socket> socket{mCtx};

        sslCtx.set_verify_mode(boost::asio::ssl::verify_peer);
        boost::beast::websocket::stream<boost::asio::ssl::stream<boost::asio::ip::tcp::socket>> sslSocket{mCtx, sslCtx};

        // Hand ownership of every transport object to `data` so they outlive this call.
        data->symKey = connKey;
        data->sslCtx = std::make_unique<boost::asio::ssl::context>(std::move(sslCtx));
        data->res = std::make_unique<boost::asio::ip::tcp::resolver>(std::move(res));
        data->socket = std::make_unique<boost::beast::websocket::stream<boost::asio::ip::tcp::socket>>(std::move(socket));
        data->sslSocket = std::make_unique<boost::beast::websocket::stream<boost::asio::ssl::stream<boost::asio::ip::tcp::socket>>>(std::move(sslSocket));

        // `res` was moved into data->res above; resolve through the owned
        // object rather than the moved-from local.
        auto const results = data->res->resolve(host, port);
        boost::asio::connect(boost::beast::get_lowest_layer(*data->sslSocket), results);

        // Set SNI Hostname (many hosts need this to handshake successfully)
        if(!SSL_set_tlsext_host_name(data->sslSocket->next_layer().native_handle(), host.c_str()))
        {
            throw boost::beast::system_error(
                static_cast<int>(::ERR_get_error()),
                boost::asio::error::get_ssl_category());
        }

        data->sslSocket->next_layer().set_verify_callback(boost::asio::ssl::host_name_verification(host));
        data->sslSocket->next_layer().handshake(boost::asio::ssl::stream_base::client);

        // Set a decorator to change the User-Agent of the handshake
        /*data->sslSocket->set_option(boost::beast::websocket::stream_base::decorator(
            [](boost::beast::websocket::request_type& req)
            {
                req.set(boost::beast::http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-coro");
            }));
        */

        data->sslSocket->handshake(host, target);

        if(id > 0)
        {
            // Replace whatever is still registered under this id. emplace()
            // would silently keep a stale entry (e.g. one close() failed to
            // erase) and discard the freshly opened connection.
            mConnCache.insert_or_assign(id, std::move(data));
            mConnLookup.insert_or_assign(connKey, id);
            return {id, connKey};
        }

        // mSize doubles as the id generator: it only ever grows, so ids stay unique.
        mSize++;

        ConnectionDetails connectionDetails;
        connectionDetails.host = host;
        connectionDetails.port = port;
        connectionDetails.target = target;
        connectionDetails.instruments = instruments;

        mConnCache.emplace(mSize, std::move(data));
        mConnLookup.emplace(connKey, mSize);
        mConnDetails.emplace(mSize, std::move(connectionDetails));
        return {mSize, connKey};
    }
    catch(const std::exception& e)
    {
        std::string err = "Exception while attempting to connect to [" + host + ":" + port + "] [" + target + "]: [" + e.what() + "]";
        return {-1, err};
    }
}

std::pair<int, std::string> WebSocket::close(int id)
{
    const auto& it = mConnCache.find(id);
    if(it == mConnCache.end())
    {
        std::string err = "Cannot close connection id [" + std::to_string(id) + "] because that id does not point to a valid connection.";
        return {0, err};
    }

    const auto& data = it->second;

    try
    {
        if(data && data->sslSocket)
        {
            boost::beast::error_code errCode;

            if(data->sslSocket->is_open())
            {
                data->sslSocket->close(boost::beast::websocket::close_code::normal, errCode);
            }
            else
            {
                data->sslSocket->next_layer().shutdown(errCode);
                data->sslSocket->next_layer().lowest_layer().close(errCode);
            }

            switch(errCode.value()) 
            {
            case boost::beast::errc::success:
            {
                break;
            }
            default:
            {
                std::string err = "Failure to close WebSocket ssl connection. [" + errCode.message() + ']';
                return {-3, err};
            }
            }
        }
        else
        {
            std::string err = "WebSocket ssl socket was empty when attempting to close the connection [" + std::to_string(id) + ']';
            return {-2, err};
        }
    }
    catch (std::exception const& e)
    {
        std::string err = "Exception while attempting to close connection id [" + std::to_string(id) + "]: [" + e.what() + "]";
        return {-1, err};
    }

    mConnLookup.erase(data->symKey);
    mConnCache.erase(it);
    return {id, ""};
}

// Looks up the connection key (built in connect() from host, port, target and
// instruments) for connection `id`.
// Returns {id, key} when the id is registered, otherwise {0, err}.
std::pair<int, std::string> WebSocket::getSymKey(int id)
{
    if(auto found = mConnCache.find(id); found != mConnCache.end())
        return {id, found->second->symKey};

    std::string err = "Cannot get symbology key [" + std::to_string(id) + "] because that id does not point to a valid connection.";
    return {0, err};
}

//================================
//================================
//================================