From bd7f30aec53fc776be678577dbe4f9afec5898a6 Mon Sep 17 00:00:00 2001 From: Edwin Mons Date: Fri, 23 May 2014 11:01:23 +0200 Subject: Sluift component support Change-Id: Ib8af01c04c866e198c04d35236dea4da464c9116 diff --git a/Sluift/ClientHelpers.cpp b/Sluift/ClientHelpers.cpp deleted file mode 100644 index 8e07112..0000000 --- a/Sluift/ClientHelpers.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2013 Remko Tronçon - * Licensed under the GNU General Public License. - * See the COPYING file for more information. - */ - -#include - -#include - -using namespace Swift; - -std::string Swift::getClientErrorString(const ClientError& error) { - std::string reason = "Disconnected: "; - switch(error.getType()) { - case ClientError::UnknownError: reason += "Unknown Error"; break; - case ClientError::DomainNameResolveError: reason += "Unable to find server"; break; - case ClientError::ConnectionError: reason += "Error connecting to server"; break; - case ClientError::ConnectionReadError: reason += "Error while receiving server data"; break; - case ClientError::ConnectionWriteError: reason += "Error while sending data to the server"; break; - case ClientError::XMLError: reason += "Error parsing server data"; break; - case ClientError::AuthenticationFailedError: reason += "Login/password invalid"; break; - case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break; - case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break; - case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break; - case ClientError::UnexpectedElementError: reason += "Unexpected response"; break; - case ClientError::ResourceBindError: reason += "Error binding resource"; break; - case ClientError::RevokedError: reason += "Certificate got revoked"; break; - case ClientError::RevocationCheckFailedError: reason += "Failed to do revokation check"; break; - case ClientError::SessionStartError: reason += "Error starting session"; break; - case ClientError::StreamError: reason += "Stream error"; break; - case ClientError::TLSError: reason += "Encryption error"; break; - case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break; - case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break; - case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break; - case ClientError::CertificateCardRemoved: reason += "Certificate card removed"; break; - case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break; - case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break; - case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break; - case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break; - case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break; - case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break; - case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break; - case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break; - case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break; - case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break; - } - return reason; -} - diff --git a/Sluift/ClientHelpers.h b/Sluift/ClientHelpers.h deleted file mode 100644 index eb78ba6..0000000 --- a/Sluift/ClientHelpers.h +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2013 Remko Tronçon - * Licensed under the GNU General Public License. - * See the COPYING file for more information. - */ - -#pragma once - -#include -#include - -#include - -namespace Swift { - class ClientError; - - std::string getClientErrorString(const ClientError& error); -} diff --git a/Sluift/Examples/Component.lua b/Sluift/Examples/Component.lua new file mode 100644 index 0000000..b5d6539 --- /dev/null +++ b/Sluift/Examples/Component.lua @@ -0,0 +1,55 @@ +--[[ + Copyright (c) 2014 Edwin Mons and Remko Tronçon + Licensed under the GNU General Public License v3. + See Documentation/Licenses/GPLv3.txt for more information. +--]] + +--[[ + + Component example. + + This script connects to an XMPP server as a component, and listens to + messages received. + + The following environment variables are used: + * SLUIFT_COMP_DOMAIN: Component domain name + * SLUIFT_COMP_SECRET: Component secret + * SLUIFT_COMP_HOST: XMPP server host name + * SLUIFT_COMP_PORT: XMPP server component port + * SLUIFT_JID: Recipient of presence and initial message + * SLUIFT_DEBUG: Sets whether debugging should be turned on + +--]] + +require "sluift" + +sluift.debug = os.getenv("SLUIFT_DEBUG") or false + +config = { + domain = os.getenv('SLUIFT_COMP_DOMAIN'), + secret = os.getenv('SLUIFT_COMP_SECRET'), + host = os.getenv('SLUIFT_COMP_HOST'), + port = os.getenv('SLUIFT_COMP_PORT'), + jid = os.getenv('SLUIFT_JID') +} + +-- Create the component, and connect +comp = sluift.new_component(config.domain, config.secret); +comp:connect(config) + +-- Send initial presence and message +-- Assumes the remote client already has this component user on his roster +comp:send_presence{from='root@' .. config.domain, to=config.jid} +comp:send_message{from='root@' .. config.domain, to=config.jid, body='Component active'} + +-- Listen for messages, and reply if one is received +for message in comp:messages() do + print("Received a message from " .. message.from) + comp:send_message{to=message.from, from=message.to, body='I received: ' .. message['body']} + + -- Send out a ping to demonstrate we can do more than just send messages + comp:get{to=message.from, query=''} +end + +comp:disconnect() + diff --git a/Sluift/Helpers.cpp b/Sluift/Helpers.cpp new file mode 100644 index 0000000..29e2b04 --- /dev/null +++ b/Sluift/Helpers.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2013-2014 Kevin Smith and Remko Tronçon + * Licensed under the GNU General Public License. + * See the COPYING file for more information. + */ + +#include + +#include +#include + +using namespace Swift; + +template std::string Swift::getCommonErrorString(T& error) { + std::string reason = "Disconnected: "; + switch(error.getType()) { + case T::UnknownError: reason += "Unknown Error"; break; + case T::ConnectionError: reason += "Error connecting to server"; break; + case T::ConnectionReadError: reason += "Error while receiving server data"; break; + case T::ConnectionWriteError: reason += "Error while sending data to the server"; break; + case T::XMLError: reason += "Error parsing server data"; break; + case T::AuthenticationFailedError: reason += "Login/password invalid"; break; + case T::UnexpectedElementError: reason += "Unexpected response"; break; + } + return reason; +} + +std::string Swift::getErrorString(const ClientError& error) { + std::string reason = getCommonErrorString(error); + switch(error.getType()) { + case ClientError::DomainNameResolveError: reason += "Unable to find server"; break; + case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break; + case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break; + case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break; + case ClientError::ResourceBindError: reason += "Error binding resource"; break; + case ClientError::RevokedError: reason += "Certificate got revoked"; break; + case ClientError::RevocationCheckFailedError: reason += "Failed to do revokation check"; break; + case ClientError::SessionStartError: reason += "Error starting session"; break; + case ClientError::StreamError: reason += "Stream error"; break; + case ClientError::TLSError: reason += "Encryption error"; break; + case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break; + case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break; + case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break; + case ClientError::CertificateCardRemoved: reason += "Certificate card removed"; break; + case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break; + case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break; + case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break; + case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break; + case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break; + case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break; + case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break; + case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break; + case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break; + case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break; + } + return reason; +} + +std::string Swift::getErrorString(const ComponentError& error) { + return getCommonErrorString(error); +} + diff --git a/Sluift/Helpers.h b/Sluift/Helpers.h new file mode 100644 index 0000000..d04a610 --- /dev/null +++ b/Sluift/Helpers.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2013-2014 Kevin Smith and Remko Tronçon + * Licensed under the GNU General Public License. + * See the COPYING file for more information. + */ + +#pragma once + +#include +#include + +#include + +namespace Swift { + class ClientError; + class ComponentError; + + template std::string getCommonErrorString(T& error); + std::string getErrorString(const ClientError& error); + std::string getErrorString(const ComponentError& error); +} diff --git a/Sluift/SConscript b/Sluift/SConscript index 3cc1f29..5e0a030 100644 --- a/Sluift/SConscript +++ b/Sluift/SConscript @@ -33,11 +33,13 @@ elif env["SCONS_STAGE"] == "build" : "ElementConvertors/StatusConvertor.cpp", "ElementConvertors/StatusShowConvertor.cpp", "ElementConvertors/DelayConvertor.cpp", - "ClientHelpers.cpp", + "Helpers.cpp", "SluiftClient.cpp", + "SluiftComponent.cpp", "Watchdog.cpp", "core.c", "client.cpp", + "component.cpp", "sluift.cpp" ] sluift_sources += env.SConscript("ElementConvertors/SConscript") diff --git a/Sluift/SluiftClient.cpp b/Sluift/SluiftClient.cpp index 9ff9d18..69472b8 100644 --- a/Sluift/SluiftClient.cpp +++ b/Sluift/SluiftClient.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013-2014 Remko Tronçon + * Copyright (c) 2013-2014 Kevin Smith and Remko Tronçon * Licensed under the GNU General Public License. * See the COPYING file for more information. */ @@ -16,7 +16,7 @@ #include #include #include -#include +#include #include using namespace Swift; @@ -77,7 +77,7 @@ void SluiftClient::waitConnected(int timeout) { throw Lua::Exception("Timeout while connecting"); } if (disconnectedError) { - throw Lua::Exception(getClientErrorString(*disconnectedError)); + throw Lua::Exception(getErrorString(*disconnectedError)); } } diff --git a/Sluift/SluiftComponent.cpp b/Sluift/SluiftComponent.cpp new file mode 100644 index 0000000..c08a103 --- /dev/null +++ b/Sluift/SluiftComponent.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2014 Kevin Smith and Remko Tronçon + * Licensed under the GNU General Public License. + * See the COPYING file for more information. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Swift; + +SluiftComponent::SluiftComponent( + const JID& jid, + const std::string& password, + NetworkFactories* networkFactories, + SimpleEventLoop* eventLoop): + networkFactories(networkFactories), + eventLoop(eventLoop), + tracer(NULL) { + component = new Component(jid, password, networkFactories); + component->onError.connect(boost::bind(&SluiftComponent::handleError, this, _1)); + component->onMessageReceived.connect(boost::bind(&SluiftComponent::handleIncomingMessage, this, _1)); + component->onPresenceReceived.connect(boost::bind(&SluiftComponent::handleIncomingPresence, this, _1)); +} + +SluiftComponent::~SluiftComponent() { + delete tracer; + delete component; +} + +void SluiftComponent::connect(const std::string& host, int port) { + disconnectedError = boost::optional(); + component->connect(host, port); +} + +void SluiftComponent::setTraceEnabled(bool b) { + if (b && !tracer) { + tracer = new ComponentXMLTracer(component); + } + else if (!b && tracer) { + delete tracer; + tracer = NULL; + } +} + +void SluiftComponent::waitConnected(int timeout) { + Watchdog watchdog(timeout, networkFactories->getTimerFactory()); + while (!watchdog.getTimedOut() && !disconnectedError && !component->isAvailable()) { + eventLoop->runUntilEvents(); + } + if (watchdog.getTimedOut()) { + component->disconnect(); + throw Lua::Exception("Timeout while connecting"); + } + if (disconnectedError) { + throw Lua::Exception(getErrorString(*disconnectedError)); + } +} + +bool SluiftComponent::isConnected() const { + return component->isAvailable(); +} + +void SluiftComponent::disconnect() { + component->disconnect(); + while (component->isAvailable()) { + eventLoop->runUntilEvents(); + } +} + +void SluiftComponent::setSoftwareVersion(const std::string& name, const std::string& version, const std::string& os) { + component->setSoftwareVersion(name, version); +} + +boost::optional SluiftComponent::getNextEvent( + int timeout, boost::function condition) { + Watchdog watchdog(timeout, networkFactories->getTimerFactory()); + size_t currentIndex = 0; + while (true) { + // Look for pending events in the queue + while (currentIndex < pendingEvents.size()) { + Event event = pendingEvents[currentIndex]; + if (!condition || condition(event)) { + pendingEvents.erase( + pendingEvents.begin() + + boost::numeric_cast(currentIndex)); + return event; + } + ++currentIndex; + } + + // Wait for new events + while (!watchdog.getTimedOut() && currentIndex >= pendingEvents.size() && component->isAvailable()) { + eventLoop->runUntilEvents(); + } + + // Finish if we're disconnected or timed out + if (watchdog.getTimedOut() || !component->isAvailable()) { + return boost::optional(); + } + } +} + +void SluiftComponent::handleIncomingMessage(boost::shared_ptr stanza) { + pendingEvents.push_back(Event(stanza)); +} + +void SluiftComponent::handleIncomingPresence(boost::shared_ptr stanza) { + pendingEvents.push_back(Event(stanza)); +} + +void SluiftComponent::handleRequestResponse(boost::shared_ptr response, boost::shared_ptr error) { + requestResponse = response; + requestError = error; + requestResponseReceived = true; +} + +void SluiftComponent::handleError(const boost::optional& error) { + disconnectedError = error; +} + +Sluift::Response SluiftComponent::doSendRequest(boost::shared_ptr request, int timeout) { + requestResponse.reset(); + requestError.reset(); + requestResponseReceived = false; + request->send(); + + Watchdog watchdog(timeout, networkFactories->getTimerFactory()); + while (!watchdog.getTimedOut() && !requestResponseReceived) { + eventLoop->runUntilEvents(); + } + return Sluift::Response(requestResponse, watchdog.getTimedOut() ? + boost::make_shared(ErrorPayload::RemoteServerTimeout) : requestError); +} diff --git a/Sluift/SluiftComponent.h b/Sluift/SluiftComponent.h new file mode 100644 index 0000000..3d5792b --- /dev/null +++ b/Sluift/SluiftComponent.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2014 Kevin Smith and Remko Tronçon + * Licensed under the GNU General Public License. + * See the COPYING file for more information. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Swift { + struct SluiftGlobals; + class ComponentXMLTracer; + class Component; + class Stanza; + class Payload; + class ErrorPayload; + class JID; + + class SluiftComponent { + public: + struct Event { + enum Type { + MessageType, + PresenceType + }; + + Event(boost::shared_ptr stanza) : type(MessageType), stanza(stanza) {} + Event(boost::shared_ptr stanza) : type(PresenceType), stanza(stanza) {} + + Type type; + + // Message & Presence + boost::shared_ptr stanza; + }; + + SluiftComponent( + const JID& jid, + const std::string& password, + NetworkFactories* networkFactories, + SimpleEventLoop* eventLoop); + ~SluiftComponent(); + + Component* getComponent() { + return component; + } + + void connect(const std::string& host, int port); + void waitConnected(int timeout); + bool isConnected() const; + void setTraceEnabled(bool b); + + template + Sluift::Response sendRequest(REQUEST_TYPE request, int timeout) { + boost::signals::scoped_connection c = request->onResponse.connect( + boost::bind(&SluiftComponent::handleRequestResponse, this, _1, _2)); + return doSendRequest(request, timeout); + } + + template + Sluift::Response sendVoidRequest(REQUEST_TYPE request, int timeout) { + boost::signals::scoped_connection c = request->onResponse.connect( + boost::bind(&SluiftComponent::handleRequestResponse, this, boost::shared_ptr(), _1)); + return doSendRequest(request, timeout); + } + + void disconnect(); + void setSoftwareVersion(const std::string& name, const std::string& version, const std::string& os); + boost::optional getNextEvent(int timeout, + boost::function condition = 0); + + private: + Sluift::Response doSendRequest(boost::shared_ptr request, int timeout); + + void handleIncomingMessage(boost::shared_ptr stanza); + void handleIncomingPresence(boost::shared_ptr stanza); + void handleRequestResponse(boost::shared_ptr response, boost::shared_ptr error); + void handleError(const boost::optional& error); + + private: + NetworkFactories* networkFactories; + SimpleEventLoop* eventLoop; + Component* component; + ComponentXMLTracer* tracer; + bool rosterReceived; + std::deque pendingEvents; + boost::optional disconnectedError; + bool requestResponseReceived; + boost::shared_ptr requestResponse; + boost::shared_ptr requestError; + }; +} diff --git a/Sluift/component.cpp b/Sluift/component.cpp new file mode 100644 index 0000000..0c400b3 --- /dev/null +++ b/Sluift/component.cpp @@ -0,0 +1,467 @@ +/* + * Copyright (c) 2014 Kevin Smith and Remko Tronçon + * Licensed under the GNU General Public License. + * See the COPYING file for more information. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Swift; +namespace lambda = boost::lambda; + +static inline SluiftComponent* getComponent(lua_State* L) { + return *Lua::checkUserData(L, 1); +} + +static inline int getGlobalTimeout(lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.moduleLibIndex); + lua_getfield(L, -1, "timeout"); + int result = boost::numeric_cast(lua_tointeger(L, -1)); + lua_pop(L, 2); + return result; +} + +static void addPayloadsToTable(lua_State* L, const std::vector >& payloads) { + if (!payloads.empty()) { + lua_createtable(L, boost::numeric_cast(payloads.size()), 0); + for (size_t i = 0; i < payloads.size(); ++i) { + Sluift::globals.elementConvertor.convertToLua(L, payloads[i]); + lua_rawseti(L, -2, boost::numeric_cast(i+1)); + } + Lua::registerGetByTypeIndex(L, -1); + lua_setfield(L, -2, "payloads"); + } +} + +static boost::shared_ptr getPayload(lua_State* L, int index) { + if (lua_type(L, index) == LUA_TTABLE) { + return boost::dynamic_pointer_cast(Sluift::globals.elementConvertor.convertFromLua(L, index)); + } + else if (lua_type(L, index) == LUA_TSTRING) { + return boost::make_shared(Lua::checkString(L, index)); + } + else { + return boost::shared_ptr(); + } +} + +static std::vector< boost::shared_ptr > getPayloadsFromTable(lua_State* L, int index) { + index = Lua::absoluteOffset(L, index); + std::vector< boost::shared_ptr > result; + lua_getfield(L, index, "payloads"); + if (lua_istable(L, -1)) { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + boost::shared_ptr payload = getPayload(L, -1); + if (payload) { + result.push_back(payload); + } + } + } + lua_pop(L, 1); + return result; +} + +SLUIFT_LUA_FUNCTION(Component, async_connect) { + SluiftComponent* component = getComponent(L); + + std::string host; + int port = 0; + if (lua_istable(L, 2)) { + if (boost::optional hostString = Lua::getStringField(L, 2, "host")) { + host = *hostString; + } + if (boost::optional portInt = Lua::getIntField(L, 2, "port")) { + port = *portInt; + } + } + component->connect(host, port); + return 0; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, set_trace_enabled, + "Enable/disable tracing of the data sent/received.\n\n.", + "self\n" + "enable a boolean specifying whether to enable/disable tracing", + "" +) { + getComponent(L)->setTraceEnabled(lua_toboolean(L, 1)); + return 0; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, wait_connected, + "Block until the component is connected.\n\nThis is useful after an `async_connect`.", + "self", + "" +) { + getComponent(L)->waitConnected(getGlobalTimeout(L)); + return 0; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, is_connected, + "Checks whether this component is still connected.\n\nReturns a boolean.", + "self\n", + "" +) { + lua_pushboolean(L, getComponent(L)->isConnected()); + return 1; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, disconnect, + "Disconnect from the server", + "self\n", + "" +) { + Sluift::globals.eventLoop.runOnce(); + getComponent(L)->disconnect(); + return 0; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, set_version, + + "Sets the published version of this component.", + + "self", + + "name the name of the component software\n" + "version the version identifier of this component\n" + "os the OS this component is running on\n" +) { + Sluift::globals.eventLoop.runOnce(); + SluiftComponent* component = getComponent(L); + if (boost::shared_ptr version = boost::dynamic_pointer_cast(Sluift::globals.elementConvertor.convertFromLuaUntyped(L, 2, "software_version"))) { + component->setSoftwareVersion(version->getName(), version->getVersion(), version->getOS()); + } + return 0; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, send_message, + "Send a message.", + "self\n" + "to the JID to send the message to\n" + "body the body of the message. Can alternatively be specified using the `body` option\n", + + "to the JID to send the message to\n" + "body the body of the message\n" + "subject the subject of the MUC room to set\n" + "type the type of message to send (`normal`, `chat`, `error`, `groupchat`, `headline`)\n" + "payloads payloads to add to the message\n" +) { + Sluift::globals.eventLoop.runOnce(); + JID to; + boost::optional from; + boost::optional body; + boost::optional subject; + std::vector > payloads; + int index = 2; + Message::Type type = Message::Chat; + if (lua_isstring(L, index)) { + to = std::string(lua_tostring(L, index)); + ++index; + if (lua_isstring(L, index)) { + body = lua_tostring(L, index); + ++index; + } + } + if (lua_istable(L, index)) { + if (boost::optional value = Lua::getStringField(L, index, "to")) { + to = *value; + } + + if (boost::optional value = Lua::getStringField(L, index, "from")) { + from = value; + } + + if (boost::optional value = Lua::getStringField(L, index, "body")) { + body = value; + } + + if (boost::optional value = Lua::getStringField(L, index, "type")) { + type = MessageConvertor::convertMessageTypeFromString(*value); + } + + if (boost::optional value = Lua::getStringField(L, index, "subject")) { + subject = value; + } + + payloads = getPayloadsFromTable(L, index); + } + + if (!to.isValid()) { + throw Lua::Exception("Missing 'to'"); + } + if ((!body || body->empty()) && !subject && payloads.empty()) { + throw Lua::Exception("Missing any of 'body', 'subject' or 'payloads'"); + } + Message::ref message = boost::make_shared(); + message->setTo(to); + if (from && !from->empty()) { + message->setFrom(*from); + } + if (body && !body->empty()) { + message->setBody(*body); + } + if (subject) { + message->setSubject(*subject); + } + message->addPayloads(payloads.begin(), payloads.end()); + message->setType(type); + getComponent(L)->getComponent()->sendMessage(message); + return 0; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, send_presence, + "Send presence.", + + "self\n" + "body the text of the presence. Can alternatively be specified using the `status` option\n", + + "to the JID to send the message to\n" + "from the JID to send the message from\n" + "status the text of the presence\n" + "priority the priority of the presence\n" + "type the type of message to send (`available`, `error`, `probe`, `subscribe`, `subscribed`, `unavailable`, `unsubscribe`, `unsubscribed`)\n" + "payloads payloads to add to the presence\n" +) { + Sluift::globals.eventLoop.runOnce(); + boost::shared_ptr presence = boost::make_shared(); + + int index = 2; + if (lua_isstring(L, index)) { + presence->setStatus(lua_tostring(L, index)); + ++index; + } + if (lua_istable(L, index)) { + if (boost::optional value = Lua::getStringField(L, index, "to")) { + presence->setTo(*value); + } + if (boost::optional value = Lua::getStringField(L, index, "from")) { + presence->setFrom(*value); + } + if (boost::optional value = Lua::getStringField(L, index, "status")) { + presence->setStatus(*value); + } + if (boost::optional value = Lua::getIntField(L, index, "priority")) { + presence->setPriority(*value); + } + if (boost::optional value = Lua::getStringField(L, index, "type")) { + presence->setType(PresenceConvertor::convertPresenceTypeFromString(*value)); + } + std::vector< boost::shared_ptr > payloads = getPayloadsFromTable(L, index); + presence->addPayloads(payloads.begin(), payloads.end()); + } + + getComponent(L)->getComponent()->sendPresence(presence); + lua_pushvalue(L, 1); + return 0; +} + +static int sendQuery(lua_State* L, IQ::Type type) { + SluiftComponent* component = getComponent(L); + + JID to; + if (boost::optional toString = Lua::getStringField(L, 2, "to")) { + to = JID(*toString); + } + + JID from; + if (boost::optional fromString = Lua::getStringField(L, 2, "from")) { + from = JID(*fromString); + } + + int timeout = getGlobalTimeout(L); + if (boost::optional timeoutInt = Lua::getIntField(L, 2, "timeout")) { + timeout = *timeoutInt; + } + + boost::shared_ptr payload; + lua_getfield(L, 2, "query"); + payload = getPayload(L, -1); + lua_pop(L, 1); + + return component->sendRequest( + boost::make_shared< GenericRequest >(type, from, to, payload, component->getComponent()->getIQRouter()), timeout).convertToLuaResult(L); +} + +#define DISPATCH_PUBSUB_PAYLOAD(payloadType, container, response) \ + else if (boost::shared_ptr p = boost::dynamic_pointer_cast(payload)) { \ + return component->sendPubSubRequest(type, to, p, timeout).convertToLuaResult(L); \ + } + +SLUIFT_LUA_FUNCTION(Component, get) { + return sendQuery(L, IQ::Get); +} + +SLUIFT_LUA_FUNCTION(Component, set) { + return sendQuery(L, IQ::Set); +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, send, + "Sends a raw string", + + "self\n" + "data the string to send\n", + + "" +) { + Sluift::globals.eventLoop.runOnce(); + + getComponent(L)->getComponent()->sendData(std::string(Lua::checkString(L, 2))); + lua_pushvalue(L, 1); + return 0; +} + +static void pushEvent(lua_State* L, const SluiftComponent::Event& event) { + switch (event.type) { + case SluiftComponent::Event::MessageType: { + Message::ref message = boost::dynamic_pointer_cast(event.stanza); + Lua::Table result = boost::assign::map_list_of + ("type", boost::make_shared(std::string("message"))) + ("from", boost::make_shared(message->getFrom().toString())) + ("to", boost::make_shared(message->getTo().toString())) + ("body", boost::make_shared(message->getBody())) + ("message_type", boost::make_shared(MessageConvertor::convertMessageTypeToString(message->getType()))); + Lua::pushValue(L, result); + addPayloadsToTable(L, message->getPayloads()); + Lua::registerTableToString(L, -1); + break; + } + case SluiftComponent::Event::PresenceType: { + Presence::ref presence = boost::dynamic_pointer_cast(event.stanza); + Lua::Table result = boost::assign::map_list_of + ("type", boost::make_shared(std::string("presence"))) + ("from", boost::make_shared(presence->getFrom().toString())) + ("to", boost::make_shared(presence->getTo().toString())) + ("status", boost::make_shared(presence->getStatus())) + ("presence_type", boost::make_shared(PresenceConvertor::convertPresenceTypeToString(presence->getType()))); + Lua::pushValue(L, result); + addPayloadsToTable(L, presence->getPayloads()); + Lua::registerTableToString(L, -1); + break; + } + } +} + +struct CallUnaryLuaPredicateOnEvent { + CallUnaryLuaPredicateOnEvent(lua_State* L, int index) : L(L), index(index) { + } + + bool operator()(const SluiftComponent::Event& event) { + lua_pushvalue(L, index); + pushEvent(L, event); + if (lua_pcall(L, 1, 1, 0) != 0) { + throw Lua::Exception(lua_tostring(L, -1)); + } + bool result = lua_toboolean(L, -1); + lua_pop(L, 1); + return result; + } + + lua_State* L; + int index; +}; + + +SLUIFT_LUA_FUNCTION(Component, get_next_event) { + Sluift::globals.eventLoop.runOnce(); + SluiftComponent* component = getComponent(L); + + int timeout = getGlobalTimeout(L); + boost::optional type; + int condition = 0; + if (lua_istable(L, 2)) { + if (boost::optional typeString = Lua::getStringField(L, 2, "type")) { + if (*typeString == "message") { + type = SluiftComponent::Event::MessageType; + } + else if (*typeString == "presence") { + type = SluiftComponent::Event::PresenceType; + } + } + if (boost::optional timeoutInt = Lua::getIntField(L, 2, "timeout")) { + timeout = *timeoutInt; + } + lua_getfield(L, 2, "if"); + if (lua_isfunction(L, -1)) { + condition = Lua::absoluteOffset(L, -1); + } + } + + boost::optional event; + if (condition) { + event = component->getNextEvent(timeout, CallUnaryLuaPredicateOnEvent(L, condition)); + } + else if (type) { + event = component->getNextEvent( + timeout, lambda::bind(&SluiftComponent::Event::type, lambda::_1) == *type); + } + else { + event = component->getNextEvent(timeout); + } + + if (event) { + pushEvent(L, *event); + } + else { + lua_pushnil(L); + } + return 1; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( + Component, jid, + "Returns the JID of this component", + "self\n", + "" +) { + SluiftComponent* component = getComponent(L); + lua_pushstring(L, component->getComponent()->getJID().toString().c_str()); + return 1; +} + +SLUIFT_LUA_FUNCTION(Component, __gc) { + SluiftComponent* component = getComponent(L); + delete component; + return 0; +} diff --git a/Sluift/core.lua b/Sluift/core.lua index 7487de1..ffbb5f9 100644 --- a/Sluift/core.lua +++ b/Sluift/core.lua @@ -144,6 +144,7 @@ end -- Contains help for native methods that we want access to from here local extra_help = {} +local component_extra_help = {} local help_data = {} local help_classes = {} local help_class_metatables = {} @@ -481,6 +482,15 @@ local Client = { Client.__index = Client register_class_table_help(Client, "Client") +_H = { + [[ Component interface ]] +} +local Component = { + _with_prompt = function(component) return component:jid() end +} +Component.__index = Component +register_class_table_help(Component, "Component") + _H = { [[ Interface to communicate with a PubSub service ]] @@ -783,6 +793,196 @@ function Client:pubsub (jid) end register_help(Client.pubsub) + +-------------------------------------------------------------------------------- +-- Component +-------------------------------------------------------------------------------- + +component_extra_help = { + ["Component.get_next_event"] = { + [[ Returns the next event. ]], + parameters = { "self" }, + options = { + type = "The type of event to return (`message`, `presence`). When omitted, all event types are returned.", + timeout = "The amount of time to wait for events.", + ["if"] = "A function to filter events. When this function, called with the event as a parameter, returns true, the event will be returned" + } + }, + ["Component.get"] = { + [[ Sends a `get` query. ]], + parameters = { "self" }, + options = { + to = "The JID of the target to send the query to", + query = "The query to send", + timeout = "The amount of time to wait for the query to finish", + } + }, + ["Component.set"] = { + [[ Sends a `set` query. ]], + parameters = { "self" }, + options = { + to = "The JID of the target to send the query to", + query = "The query to send.", + timeout = "The amount of time to wait for the query to finish.", + } + }, + ["Component.async_connect"] = { + [[ + Connect to the server asynchronously. + + This method immediately returns. + ]], + parameters = { "self" }, + options = { + host = "The host to connect to.", + port = "The port to connect to." + } + } +} + +_H = { + [[ + Connect to the server. + + This method blocks until the connection has been established. + ]], + parameters = { "self" }, + options = component_extra_help["Component.async_connect"].options +} +function Component:connect (...) + local options = parse_options({}, ...) + local f = options.f + self:async_connect(options) + self:wait_connected() + if f then + return call {function() return f(self) end, finally = function() self:disconnect() end} + end + return true +end +register_help(Component.connect) + + +_H = { + [[ + Returns an iterator over all events. + + This function blocks until `timeout` is reached (or blocks forever if it is omitted). + ]], + parameters = { "self" }, + options = component_extra_help["Component.get_next_event"].options +} +function Component:events (options) + local function component_events_iterator(s) + return s['component']:get_next_event(s['options']) + end + return component_events_iterator, {component = self, options = options} +end +register_help(Component.events) + + +_H = { + [[ + Calls `f` for each event. + ]], + parameters = { "self" }, + options = merge_tables(get_help(Component.events).options, { + f = "The functor to call with each event. Required." + }) +} +function Component:for_each_event (...) + local options = parse_options({}, ...) + if not type(options.f) == 'function' then error('Expected function') end + for event in self:events(options) do + local result = options.f(event) + if result then + return result + end + end +end +register_help(Component.for_each_event) + +for method, event_type in pairs({message = 'message', presence = 'presence'}) do + _H = { + "Call `f` for all events of type `" .. event_type .. "`.", + parameters = { "self" }, + options = remove_help_parameters("type", get_help(Component.for_each_event).options) + } + Component['for_each_' .. method] = function (component, ...) + local options = parse_options({}, ...) + options['type'] = event_type + return component:for_each_event (options) + end + register_help(Component['for_each_' .. method]) + + _H = { + "Get the next event of type `" .. event_type .. "`.", + parameters = { "self" }, + options = remove_help_parameters("type", component_extra_help["Component.get_next_event"].options) + } + Component['get_next_' .. method] = function (component, ...) + local options = parse_options({}, ...) + options['type'] = event_type + return component:get_next_event(options) + end + register_help(Component['get_next_' .. method]) +end + +for method, event_type in pairs({messages = 'message'}) do + _H = { + "Returns an iterator over all events of type `" .. event_type .. "`.", + parameters = { "self" }, + options = remove_help_parameters("type", get_help(Component.for_each_event).options) + } + Component[method] = function (component, ...) + local options = parse_options({}, ...) + options['type'] = event_type + return component:events (options) + end + register_help(Component[method]) +end + +_H = { + [[ + Process all pending events + ]], + parameters = { "self" } +} +function Component:process_events () + for event in self:events{timeout=0} do end +end +register_help(Component.process_events) + + +-- +-- Register get_* and set_* convenience methods for some type of queries +-- +-- Example usages: +-- component:get_software_version{to = 'alice@wonderland.lit'} +-- component:set_command{to = 'alice@wonderland.lit', command = { type = 'execute', node = 'uptime' }} +-- +local get_set_shortcuts = { + get = {'software_version', 'disco_items', 'xml', 'dom', 'vcard'}, + set = {'command'} +} +for query_action, query_types in pairs(get_set_shortcuts) do + for _, query_type in ipairs(query_types) do + _H = { + "Sends a `" .. query_action .. "` query of type `" .. query_type .. "`.\n" .. + "Apart from the options below, all top level elements of `" .. query_type .. "` can be passed.", + parameters = { "self" }, + options = remove_help_parameters({"query", "type"}, component_extra_help["Component.get"].options), + } + local method = query_action .. '_' .. query_type + Component[method] = function (component, options) + options = options or {} + if type(options) ~= 'table' then error('Invalid options: ' .. options) end + options['query'] = merge_tables({_type = query_type}, options[query_type] or {}) + return component[query_action](component, options) + end + register_help(Component[method]) + end +end + -------------------------------------------------------------------------------- -- PubSub -------------------------------------------------------------------------------- @@ -1023,6 +1223,7 @@ extra_help['sluift'] = { return { Client = Client, + Component = Component, register_help = register_help, register_class_help = register_class_help, register_table_tostring = register_table_tostring, @@ -1035,6 +1236,7 @@ return { get_help = get_help, help = help, extra_help = extra_help, + component_extra_help = component_extra_help, copy = copy, with = with, create_form = create_form diff --git a/Sluift/sluift.cpp b/Sluift/sluift.cpp index b55649b..2fd1e50 100644 --- a/Sluift/sluift.cpp +++ b/Sluift/sluift.cpp @@ -16,6 +16,7 @@ #include "Watchdog.h" #include #include +#include #include #include #include @@ -88,6 +89,32 @@ SLUIFT_LUA_FUNCTION_WITH_HELP( } SLUIFT_LUA_FUNCTION_WITH_HELP( + Sluift, new_component, + + "Creates a new component.\n\nReturns a @{Component} object.\n", + + "jid The JID to connect as\n" + "passphrase The passphrase to use\n", + + "" +) { + Lua::checkString(L, 1); + JID jid(std::string(Lua::checkString(L, 1))); + std::string password(Lua::checkString(L, 2)); + + SluiftComponent** component = reinterpret_cast(lua_newuserdata(L, sizeof(SluiftComponent*))); + + lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex); + lua_getfield(L, -1, "Component"); + lua_setmetatable(L, -3); + lua_pop(L, 1); + + *component = new SluiftComponent(jid, password, &Sluift::globals.networkFactories, &Sluift::globals.eventLoop); + (*component)->setTraceEnabled(getGlobalDebug(L)); + return 1; +} + +SLUIFT_LUA_FUNCTION_WITH_HELP( Sluift, sha1, "Compute the SHA-1 hash of given data", "data the data to hash", @@ -408,6 +435,16 @@ SLUIFT_API int luaopen_sluift(lua_State* L) { } lua_pop(L, 1); + // Load component metatable + lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex); + std::vector comp_tables = boost::assign::list_of("Component"); + foreach(const std::string& table, comp_tables) { + lua_getfield(L, -1, table.c_str()); + Lua::FunctionRegistry::getInstance().addFunctionsToTable(L, table); + lua_pop(L, 1); + } + lua_pop(L, 1); + // Register documentation for all elements foreach (boost::shared_ptr convertor, Sluift::globals.elementConvertor.getConvertors()) { boost::optional documentation = convertor->getDocumentation(); diff --git a/Swiften/Component/CoreComponent.cpp b/Swiften/Component/CoreComponent.cpp index d2cc7aa..cc6be42 100644 --- a/Swiften/Component/CoreComponent.cpp +++ b/Swiften/Component/CoreComponent.cpp @@ -161,4 +161,8 @@ void CoreComponent::sendPresence(boost::shared_ptr presence) { stanzaChannel_->sendPresence(presence); } +void CoreComponent::sendData(const std::string& data) { + sessionStream_->writeData(data); +} + } diff --git a/Swiften/Component/CoreComponent.h b/Swiften/Component/CoreComponent.h index 63b68f6..e9fdd88 100644 --- a/Swiften/Component/CoreComponent.h +++ b/Swiften/Component/CoreComponent.h @@ -51,6 +51,7 @@ namespace Swift { void sendMessage(boost::shared_ptr); void sendPresence(boost::shared_ptr); + void sendData(const std::string& data); IQRouter* getIQRouter() const { return iqRouter_; -- cgit v0.10.2-6-g49f6