summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Sluift/ClientHelpers.h18
-rw-r--r--Sluift/Examples/Component.lua55
-rw-r--r--Sluift/Helpers.cpp (renamed from Sluift/ClientHelpers.cpp)32
-rw-r--r--Sluift/Helpers.h21
-rw-r--r--Sluift/SConscript4
-rw-r--r--Sluift/SluiftClient.cpp6
-rw-r--r--Sluift/SluiftComponent.cpp145
-rw-r--r--Sluift/SluiftComponent.h108
-rw-r--r--Sluift/component.cpp467
-rw-r--r--Sluift/core.lua202
-rw-r--r--Sluift/sluift.cpp37
-rw-r--r--Swiften/Component/CoreComponent.cpp4
-rw-r--r--Swiften/Component/CoreComponent.h1
13 files changed, 1068 insertions, 32 deletions
diff --git a/Sluift/ClientHelpers.h b/Sluift/ClientHelpers.h
deleted file mode 100644
index eb78ba6..0000000
--- a/Sluift/ClientHelpers.h
+++ /dev/null
@@ -1,18 +0,0 @@
-/*
- * Copyright (c) 2013 Remko Tronçon
- * Licensed under the GNU General Public License.
- * See the COPYING file for more information.
- */
-
-#pragma once
-
-#include <Swiften/Base/Override.h>
-#include <Swiften/Base/API.h>
-
-#include <string>
-
-namespace Swift {
- class ClientError;
-
- std::string getClientErrorString(const ClientError& error);
-}
diff --git a/Sluift/Examples/Component.lua b/Sluift/Examples/Component.lua
new file mode 100644
index 0000000..b5d6539
--- /dev/null
+++ b/Sluift/Examples/Component.lua
@@ -0,0 +1,55 @@
+--[[
+ Copyright (c) 2014 Edwin Mons and Remko Tronçon
+ Licensed under the GNU General Public License v3.
+ See Documentation/Licenses/GPLv3.txt for more information.
+--]]
+
+--[[
+
+ Component example.
+
+ This script connects to an XMPP server as a component, and listens to
+ messages received.
+
+ The following environment variables are used:
+ * SLUIFT_COMP_DOMAIN: Component domain name
+ * SLUIFT_COMP_SECRET: Component secret
+ * SLUIFT_COMP_HOST: XMPP server host name
+ * SLUIFT_COMP_PORT: XMPP server component port
+ * SLUIFT_JID: Recipient of presence and initial message
+ * SLUIFT_DEBUG: Sets whether debugging should be turned on
+
+--]]
+
+require "sluift"
+
+sluift.debug = os.getenv("SLUIFT_DEBUG") or false
+
+config = {
+ domain = os.getenv('SLUIFT_COMP_DOMAIN'),
+ secret = os.getenv('SLUIFT_COMP_SECRET'),
+ host = os.getenv('SLUIFT_COMP_HOST'),
+ port = os.getenv('SLUIFT_COMP_PORT'),
+ jid = os.getenv('SLUIFT_JID')
+}
+
+-- Create the component, and connect
+comp = sluift.new_component(config.domain, config.secret);
+comp:connect(config)
+
+-- Send initial presence and message
+-- Assumes the remote client already has this component user on his roster
+comp:send_presence{from='root@' .. config.domain, to=config.jid}
+comp:send_message{from='root@' .. config.domain, to=config.jid, body='Component active'}
+
+-- Listen for messages, and reply if one is received
+for message in comp:messages() do
+ print("Received a message from " .. message.from)
+ comp:send_message{to=message.from, from=message.to, body='I received: ' .. message['body']}
+
+ -- Send out a ping to demonstrate we can do more than just send messages
+ comp:get{to=message.from, query='<ping xmlns="urn:xmpp:ping"/>'}
+end
+
+comp:disconnect()
+
diff --git a/Sluift/ClientHelpers.cpp b/Sluift/Helpers.cpp
index 8e07112..29e2b04 100644
--- a/Sluift/ClientHelpers.cpp
+++ b/Sluift/Helpers.cpp
@@ -1,50 +1,62 @@
/*
- * Copyright (c) 2013 Remko Tronçon
+ * Copyright (c) 2013-2014 Kevin Smith and Remko Tronçon
* Licensed under the GNU General Public License.
* See the COPYING file for more information.
*/
-#include <Sluift/ClientHelpers.h>
+#include <Sluift/Helpers.h>
#include <Swiften/Client/ClientError.h>
+#include <Swiften/Component/ComponentError.h>
using namespace Swift;
-std::string Swift::getClientErrorString(const ClientError& error) {
+template<class T> std::string Swift::getCommonErrorString(T& error) {
std::string reason = "Disconnected: ";
switch(error.getType()) {
- case ClientError::UnknownError: reason += "Unknown Error"; break;
+ case T::UnknownError: reason += "Unknown Error"; break;
+ case T::ConnectionError: reason += "Error connecting to server"; break;
+ case T::ConnectionReadError: reason += "Error while receiving server data"; break;
+ case T::ConnectionWriteError: reason += "Error while sending data to the server"; break;
+ case T::XMLError: reason += "Error parsing server data"; break;
+ case T::AuthenticationFailedError: reason += "Login/password invalid"; break;
+ case T::UnexpectedElementError: reason += "Unexpected response"; break;
+ }
+ return reason;
+}
+
+std::string Swift::getErrorString(const ClientError& error) {
+ std::string reason = getCommonErrorString(error);
+ switch(error.getType()) {
case ClientError::DomainNameResolveError: reason += "Unable to find server"; break;
- case ClientError::ConnectionError: reason += "Error connecting to server"; break;
- case ClientError::ConnectionReadError: reason += "Error while receiving server data"; break;
- case ClientError::ConnectionWriteError: reason += "Error while sending data to the server"; break;
- case ClientError::XMLError: reason += "Error parsing server data"; break;
- case ClientError::AuthenticationFailedError: reason += "Login/password invalid"; break;
case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break;
case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break;
case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break;
- case ClientError::UnexpectedElementError: reason += "Unexpected response"; break;
case ClientError::ResourceBindError: reason += "Error binding resource"; break;
case ClientError::RevokedError: reason += "Certificate got revoked"; break;
case ClientError::RevocationCheckFailedError: reason += "Failed to do revokation check"; break;
case ClientError::SessionStartError: reason += "Error starting session"; break;
case ClientError::StreamError: reason += "Stream error"; break;
case ClientError::TLSError: reason += "Encryption error"; break;
case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break;
case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break;
case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break;
case ClientError::CertificateCardRemoved: reason += "Certificate card removed"; break;
case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break;
case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break;
case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break;
case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break;
case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break;
case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break;
case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break;
case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break;
case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break;
case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break;
}
return reason;
}
+std::string Swift::getErrorString(const ComponentError& error) {
+ return getCommonErrorString(error);
+}
+
diff --git a/Sluift/Helpers.h b/Sluift/Helpers.h
new file mode 100644
index 0000000..d04a610
--- /dev/null
+++ b/Sluift/Helpers.h
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) 2013-2014 Kevin Smith and Remko Tronçon
+ * Licensed under the GNU General Public License.
+ * See the COPYING file for more information.
+ */
+
+#pragma once
+
+#include <Swiften/Base/Override.h>
+#include <Swiften/Base/API.h>
+
+#include <string>
+
+namespace Swift {
+ class ClientError;
+ class ComponentError;
+
+ template<typename T> std::string getCommonErrorString(T& error);
+ std::string getErrorString(const ClientError& error);
+ std::string getErrorString(const ComponentError& error);
+}
diff --git a/Sluift/SConscript b/Sluift/SConscript
index 3cc1f29..5e0a030 100644
--- a/Sluift/SConscript
+++ b/Sluift/SConscript
@@ -1,87 +1,89 @@
import Version, os.path
Import(["env"])
if env["SCONS_STAGE"] == "build" and not GetOption("help") and not env.get("HAVE_LUA", 0) :
print "Warning: Lua was not found. Sluift will not be built."
if "Sluift" in env["PROJECTS"] :
env["PROJECTS"].remove("Sluift")
elif env["SCONS_STAGE"] == "build" :
sluift_sources = [
"Lua/Value.cpp",
"Lua/Exception.cpp",
"Lua/Check.cpp",
"Lua/FunctionRegistration.cpp",
"Lua/FunctionRegistry.cpp",
"Lua/LuaUtils.cpp",
"LuaElementConvertors.cpp",
"LuaElementConvertor.cpp",
"Response.cpp",
"ElementConvertors/BodyConvertor.cpp",
"ElementConvertors/VCardUpdateConvertor.cpp",
"ElementConvertors/PubSubEventConvertor.cpp",
"ElementConvertors/RawXMLElementConvertor.cpp",
"ElementConvertors/DOMElementConvertor.cpp",
"ElementConvertors/DefaultElementConvertor.cpp",
"ElementConvertors/DiscoInfoConvertor.cpp",
"ElementConvertors/DiscoItemsConvertor.cpp",
"ElementConvertors/FormConvertor.cpp",
"ElementConvertors/SoftwareVersionConvertor.cpp",
"ElementConvertors/VCardConvertor.cpp",
"ElementConvertors/CommandConvertor.cpp",
"ElementConvertors/StatusConvertor.cpp",
"ElementConvertors/StatusShowConvertor.cpp",
"ElementConvertors/DelayConvertor.cpp",
- "ClientHelpers.cpp",
+ "Helpers.cpp",
"SluiftClient.cpp",
+ "SluiftComponent.cpp",
"Watchdog.cpp",
"core.c",
"client.cpp",
+ "component.cpp",
"sluift.cpp"
]
sluift_sources += env.SConscript("ElementConvertors/SConscript")
sluift_env = env.Clone()
sluift_env.UseFlags(env.get("LUA_FLAGS", {}))
sluift_env.UseFlags(env["SWIFTEN_FLAGS"])
sluift_env.UseFlags(env["SWIFTEN_DEP_FLAGS"])
# Support compilation on both Lua 5.1 and Lua 5.2
sluift_env.Append(CPPDEFINES = ["LUA_COMPAT_ALL"])
if sluift_env["PLATFORM"] == "win32" :
sluift_env.Append(CPPDEFINES = ["SLUIFT_BUILD_DLL"])
if sluift_env["PLATFORM"] == "darwin" and os.path.isdir("/Applications/iTunes.app") :
sluift_env.Append(FRAMEWORKS = ["ScriptingBridge"])
sluift_env.Command("iTunes.h", "/Applications/iTunes.app",
"sdef ${SOURCE} | sdp -fh --basename iTunes -o ${TARGET.dir}")
sluift_env.Append(CPPDEFINES = ["HAVE_ITUNES"])
sluift_sources += ["ITunesInterface.mm"]
# Generate Version.h
version_header = "#pragma once\n\n"
version_header += "#define SLUIFT_VERSION_STRING \"" + Version.getBuildVersion(env.Dir("#").abspath, "sluift") + "\"\n"
sluift_env.WriteVal("Version.h", sluift_env.Value(version_header))
# Generate core.c
def generate_embedded_lua(env, target, source) :
f = open(source[0].abspath, "r")
data = f.read()
f.close()
data_bytes = bytearray(data)
f = open(target[0].abspath, "w")
f.write('#include <stddef.h>\n')
f.write('const size_t ' + source[0].name.replace(".", "_") + "_size = " + str(len(data_bytes)) + ";\n")
f.write('const char ' + source[0].name.replace(".", "_") + "[] = {" + ", ".join([str(b) for b in data_bytes]) + "};\n")
f.close()
sluift_env.Command("core.c", ["core.lua"], env.Action(generate_embedded_lua, cmdstr="$GENCOMSTR"))
sluift_env.WriteVal("dll.c", sluift_env.Value(""))
sluift_sources = [env.File(x) for x in sluift_sources]
for sluift_variant in ['dll', 'exe'] :
SConscript(["SConscript.variant"], variant_dir = sluift_variant,
duplicate = 0,
exports = ['sluift_sources', 'sluift_variant', 'sluift_env'])
diff --git a/Sluift/SluiftClient.cpp b/Sluift/SluiftClient.cpp
index 9ff9d18..69472b8 100644
--- a/Sluift/SluiftClient.cpp
+++ b/Sluift/SluiftClient.cpp
@@ -1,181 +1,181 @@
/*
- * Copyright (c) 2013-2014 Remko Tronçon
+ * Copyright (c) 2013-2014 Kevin Smith and Remko Tronçon
* Licensed under the GNU General Public License.
* See the COPYING file for more information.
*/
#include <Sluift/SluiftClient.h>
#include <boost/numeric/conversion/cast.hpp>
#include <Swiften/Client/ClientXMLTracer.h>
#include <Swiften/Client/Client.h>
#include <Swiften/Roster/XMPPRoster.h>
#include <Sluift/SluiftGlobals.h>
#include <Sluift/Lua/Exception.h>
#include <Swiften/Elements/Message.h>
#include <Swiften/Elements/PubSubEvent.h>
#include <Swiften/Queries/RawRequest.h>
-#include <Sluift/ClientHelpers.h>
+#include <Sluift/Helpers.h>
#include <Swiften/Elements/Presence.h>
using namespace Swift;
SluiftClient::SluiftClient(
const JID& jid,
const std::string& password,
NetworkFactories* networkFactories,
SimpleEventLoop* eventLoop) :
networkFactories(networkFactories),
eventLoop(eventLoop),
tracer(NULL) {
client = new Client(jid, password, networkFactories);
client->setAlwaysTrustCertificates();
client->onDisconnected.connect(boost::bind(&SluiftClient::handleDisconnected, this, _1));
client->onMessageReceived.connect(boost::bind(&SluiftClient::handleIncomingMessage, this, _1));
client->onPresenceReceived.connect(boost::bind(&SluiftClient::handleIncomingPresence, this, _1));
client->getPubSubManager()->onEvent.connect(boost::bind(&SluiftClient::handleIncomingPubSubEvent, this, _1, _2));
client->getRoster()->onInitialRosterPopulated.connect(boost::bind(&SluiftClient::handleInitialRosterPopulated, this));
}
SluiftClient::~SluiftClient() {
delete tracer;
delete client;
}
void SluiftClient::connect() {
rosterReceived = false;
disconnectedError = boost::optional<ClientError>();
client->connect(options);
}
void SluiftClient::connect(const std::string& host, int port) {
rosterReceived = false;
options.manualHostname = host;
options.manualPort = port;
disconnectedError = boost::optional<ClientError>();
client->connect(options);
}
void SluiftClient::setTraceEnabled(bool b) {
if (b && !tracer) {
tracer = new ClientXMLTracer(client, options.boshURL.isEmpty()? false: true);
}
else if (!b && tracer) {
delete tracer;
tracer = NULL;
}
}
void SluiftClient::waitConnected(int timeout) {
Watchdog watchdog(timeout, networkFactories->getTimerFactory());
while (!watchdog.getTimedOut() && client->isActive() && !client->isAvailable()) {
eventLoop->runUntilEvents();
}
if (watchdog.getTimedOut()) {
client->disconnect();
throw Lua::Exception("Timeout while connecting");
}
if (disconnectedError) {
- throw Lua::Exception(getClientErrorString(*disconnectedError));
+ throw Lua::Exception(getErrorString(*disconnectedError));
}
}
bool SluiftClient::isConnected() const {
return client->isAvailable();
}
void SluiftClient::disconnect() {
client->disconnect();
while (client->isActive()) {
eventLoop->runUntilEvents();
}
}
void SluiftClient::setSoftwareVersion(const std::string& name, const std::string& version, const std::string& os) {
client->setSoftwareVersion(name, version, os);
}
boost::optional<SluiftClient::Event> SluiftClient::getNextEvent(
int timeout, boost::function<bool (const Event&)> condition) {
Watchdog watchdog(timeout, networkFactories->getTimerFactory());
size_t currentIndex = 0;
while (true) {
// Look for pending events in the queue
while (currentIndex < pendingEvents.size()) {
Event event = pendingEvents[currentIndex];
if (!condition || condition(event)) {
pendingEvents.erase(
pendingEvents.begin()
+ boost::numeric_cast<int>(currentIndex));
return event;
}
++currentIndex;
}
// Wait for new events
while (!watchdog.getTimedOut() && currentIndex >= pendingEvents.size() && client->isActive()) {
eventLoop->runUntilEvents();
}
// Finish if we're disconnected or timed out
if (watchdog.getTimedOut() || !client->isActive()) {
return boost::optional<Event>();
}
}
}
std::vector<XMPPRosterItem> SluiftClient::getRoster() {
if (!rosterReceived) {
// If we haven't requested it yet, request it for the first time
client->requestRoster();
}
while (!rosterReceived) {
eventLoop->runUntilEvents();
}
return client->getRoster()->getItems();
}
void SluiftClient::handleIncomingMessage(boost::shared_ptr<Message> stanza) {
if (stanza->getPayload<PubSubEvent>()) {
// Already handled by pubsub manager
return;
}
pendingEvents.push_back(Event(stanza));
}
void SluiftClient::handleIncomingPresence(boost::shared_ptr<Presence> stanza) {
pendingEvents.push_back(Event(stanza));
}
void SluiftClient::handleIncomingPubSubEvent(const JID& from, boost::shared_ptr<PubSubEventPayload> event) {
pendingEvents.push_back(Event(from, event));
}
void SluiftClient::handleInitialRosterPopulated() {
rosterReceived = true;
}
void SluiftClient::handleRequestResponse(boost::shared_ptr<Payload> response, boost::shared_ptr<ErrorPayload> error) {
requestResponse = response;
requestError = error;
requestResponseReceived = true;
}
void SluiftClient::handleDisconnected(const boost::optional<ClientError>& error) {
disconnectedError = error;
}
Sluift::Response SluiftClient::doSendRequest(boost::shared_ptr<Request> request, int timeout) {
requestResponse.reset();
requestError.reset();
requestResponseReceived = false;
request->send();
Watchdog watchdog(timeout, networkFactories->getTimerFactory());
while (!watchdog.getTimedOut() && !requestResponseReceived) {
eventLoop->runUntilEvents();
}
return Sluift::Response(requestResponse, watchdog.getTimedOut() ?
boost::make_shared<ErrorPayload>(ErrorPayload::RemoteServerTimeout) : requestError);
}
diff --git a/Sluift/SluiftComponent.cpp b/Sluift/SluiftComponent.cpp
new file mode 100644
index 0000000..c08a103
--- /dev/null
+++ b/Sluift/SluiftComponent.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2014 Kevin Smith and Remko Tronçon
+ * Licensed under the GNU General Public License.
+ * See the COPYING file for more information.
+ */
+
+#include <Sluift/SluiftComponent.h>
+
+#include <boost/numeric/conversion/cast.hpp>
+
+#include <Swiften/Component/ComponentXMLTracer.h>
+#include <Swiften/Component/Component.h>
+#include <Swiften/Roster/XMPPRoster.h>
+#include <Sluift/SluiftGlobals.h>
+#include <Sluift/Lua/Exception.h>
+#include <Swiften/Elements/Message.h>
+#include <Swiften/Queries/RawRequest.h>
+#include <Sluift/Helpers.h>
+#include <Swiften/Elements/Presence.h>
+
+using namespace Swift;
+
+SluiftComponent::SluiftComponent(
+ const JID& jid,
+ const std::string& password,
+ NetworkFactories* networkFactories,
+ SimpleEventLoop* eventLoop):
+ networkFactories(networkFactories),
+ eventLoop(eventLoop),
+ tracer(NULL) {
+ component = new Component(jid, password, networkFactories);
+ component->onError.connect(boost::bind(&SluiftComponent::handleError, this, _1));
+ component->onMessageReceived.connect(boost::bind(&SluiftComponent::handleIncomingMessage, this, _1));
+ component->onPresenceReceived.connect(boost::bind(&SluiftComponent::handleIncomingPresence, this, _1));
+}
+
+SluiftComponent::~SluiftComponent() {
+ delete tracer;
+ delete component;
+}
+
+void SluiftComponent::connect(const std::string& host, int port) {
+ disconnectedError = boost::optional<ComponentError>();
+ component->connect(host, port);
+}
+
+void SluiftComponent::setTraceEnabled(bool b) {
+ if (b && !tracer) {
+ tracer = new ComponentXMLTracer(component);
+ }
+ else if (!b && tracer) {
+ delete tracer;
+ tracer = NULL;
+ }
+}
+
+void SluiftComponent::waitConnected(int timeout) {
+ Watchdog watchdog(timeout, networkFactories->getTimerFactory());
+ while (!watchdog.getTimedOut() && !disconnectedError && !component->isAvailable()) {
+ eventLoop->runUntilEvents();
+ }
+ if (watchdog.getTimedOut()) {
+ component->disconnect();
+ throw Lua::Exception("Timeout while connecting");
+ }
+ if (disconnectedError) {
+ throw Lua::Exception(getErrorString(*disconnectedError));
+ }
+}
+
+bool SluiftComponent::isConnected() const {
+ return component->isAvailable();
+}
+
+void SluiftComponent::disconnect() {
+ component->disconnect();
+ while (component->isAvailable()) {
+ eventLoop->runUntilEvents();
+ }
+}
+
+void SluiftComponent::setSoftwareVersion(const std::string& name, const std::string& version, const std::string& os) {
+ component->setSoftwareVersion(name, version);
+}
+
+boost::optional<SluiftComponent::Event> SluiftComponent::getNextEvent(
+ int timeout, boost::function<bool (const Event&)> condition) {
+ Watchdog watchdog(timeout, networkFactories->getTimerFactory());
+ size_t currentIndex = 0;
+ while (true) {
+ // Look for pending events in the queue
+ while (currentIndex < pendingEvents.size()) {
+ Event event = pendingEvents[currentIndex];
+ if (!condition || condition(event)) {
+ pendingEvents.erase(
+ pendingEvents.begin()
+ + boost::numeric_cast<int>(currentIndex));
+ return event;
+ }
+ ++currentIndex;
+ }
+
+ // Wait for new events
+ while (!watchdog.getTimedOut() && currentIndex >= pendingEvents.size() && component->isAvailable()) {
+ eventLoop->runUntilEvents();
+ }
+
+ // Finish if we're disconnected or timed out
+ if (watchdog.getTimedOut() || !component->isAvailable()) {
+ return boost::optional<Event>();
+ }
+ }
+}
+
+void SluiftComponent::handleIncomingMessage(boost::shared_ptr<Message> stanza) {
+ pendingEvents.push_back(Event(stanza));
+}
+
+void SluiftComponent::handleIncomingPresence(boost::shared_ptr<Presence> stanza) {
+ pendingEvents.push_back(Event(stanza));
+}
+
+void SluiftComponent::handleRequestResponse(boost::shared_ptr<Payload> response, boost::shared_ptr<ErrorPayload> error) {
+ requestResponse = response;
+ requestError = error;
+ requestResponseReceived = true;
+}
+
+void SluiftComponent::handleError(const boost::optional<ComponentError>& error) {
+ disconnectedError = error;
+}
+
+Sluift::Response SluiftComponent::doSendRequest(boost::shared_ptr<Request> request, int timeout) {
+ requestResponse.reset();
+ requestError.reset();
+ requestResponseReceived = false;
+ request->send();
+
+ Watchdog watchdog(timeout, networkFactories->getTimerFactory());
+ while (!watchdog.getTimedOut() && !requestResponseReceived) {
+ eventLoop->runUntilEvents();
+ }
+ return Sluift::Response(requestResponse, watchdog.getTimedOut() ?
+ boost::make_shared<ErrorPayload>(ErrorPayload::RemoteServerTimeout) : requestError);
+}
diff --git a/Sluift/SluiftComponent.h b/Sluift/SluiftComponent.h
new file mode 100644
index 0000000..3d5792b
--- /dev/null
+++ b/Sluift/SluiftComponent.h
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2014 Kevin Smith and Remko Tronçon
+ * Licensed under the GNU General Public License.
+ * See the COPYING file for more information.
+ */
+
+#pragma once
+
+#include <deque>
+#include <boost/optional.hpp>
+#include <boost/bind.hpp>
+#include <boost/function.hpp>
+
+#include <Swiften/Client/ClientOptions.h>
+#include <Swiften/Elements/IQ.h>
+#include <Swiften/Elements/Message.h>
+#include <Swiften/Elements/Presence.h>
+#include <Swiften/Queries/GenericRequest.h>
+#include <Swiften/Roster/XMPPRosterItem.h>
+#include <Swiften/Component/ComponentError.h>
+#include <Swiften/Network/NetworkFactories.h>
+#include <Swiften/Component/Component.h>
+#include <Swiften/EventLoop/SimpleEventLoop.h>
+#include <Sluift/Watchdog.h>
+#include <Sluift/Response.h>
+
+namespace Swift {
+ struct SluiftGlobals;
+ class ComponentXMLTracer;
+ class Component;
+ class Stanza;
+ class Payload;
+ class ErrorPayload;
+ class JID;
+
+ class SluiftComponent {
+ public:
+ struct Event {
+ enum Type {
+ MessageType,
+ PresenceType
+ };
+
+ Event(boost::shared_ptr<Message> stanza) : type(MessageType), stanza(stanza) {}
+ Event(boost::shared_ptr<Presence> stanza) : type(PresenceType), stanza(stanza) {}
+
+ Type type;
+
+ // Message & Presence
+ boost::shared_ptr<Stanza> stanza;
+ };
+
+ SluiftComponent(
+ const JID& jid,
+ const std::string& password,
+ NetworkFactories* networkFactories,
+ SimpleEventLoop* eventLoop);
+ ~SluiftComponent();
+
+ Component* getComponent() {
+ return component;
+ }
+
+ void connect(const std::string& host, int port);
+ void waitConnected(int timeout);
+ bool isConnected() const;
+ void setTraceEnabled(bool b);
+
+ template<typename REQUEST_TYPE>
+ Sluift::Response sendRequest(REQUEST_TYPE request, int timeout) {
+ boost::signals::scoped_connection c = request->onResponse.connect(
+ boost::bind(&SluiftComponent::handleRequestResponse, this, _1, _2));
+ return doSendRequest(request, timeout);
+ }
+
+ template<typename REQUEST_TYPE>
+ Sluift::Response sendVoidRequest(REQUEST_TYPE request, int timeout) {
+ boost::signals::scoped_connection c = request->onResponse.connect(
+ boost::bind(&SluiftComponent::handleRequestResponse, this, boost::shared_ptr<Payload>(), _1));
+ return doSendRequest(request, timeout);
+ }
+
+ void disconnect();
+ void setSoftwareVersion(const std::string& name, const std::string& version, const std::string& os);
+ boost::optional<SluiftComponent::Event> getNextEvent(int timeout,
+ boost::function<bool (const Event&)> condition = 0);
+
+ private:
+ Sluift::Response doSendRequest(boost::shared_ptr<Request> request, int timeout);
+
+ void handleIncomingMessage(boost::shared_ptr<Message> stanza);
+ void handleIncomingPresence(boost::shared_ptr<Presence> stanza);
+ void handleRequestResponse(boost::shared_ptr<Payload> response, boost::shared_ptr<ErrorPayload> error);
+ void handleError(const boost::optional<ComponentError>& error);
+
+ private:
+ NetworkFactories* networkFactories;
+ SimpleEventLoop* eventLoop;
+ Component* component;
+ ComponentXMLTracer* tracer;
+ bool rosterReceived;
+ std::deque<Event> pendingEvents;
+ boost::optional<ComponentError> disconnectedError;
+ bool requestResponseReceived;
+ boost::shared_ptr<Payload> requestResponse;
+ boost::shared_ptr<ErrorPayload> requestError;
+ };
+}
diff --git a/Sluift/component.cpp b/Sluift/component.cpp
new file mode 100644
index 0000000..0c400b3
--- /dev/null
+++ b/Sluift/component.cpp
@@ -0,0 +1,467 @@
+/*
+ * Copyright (c) 2014 Kevin Smith and Remko Tronçon
+ * Licensed under the GNU General Public License.
+ * See the COPYING file for more information.
+ */
+
+#include <boost/lambda/lambda.hpp>
+#include <boost/lambda/bind.hpp>
+#include <boost/assign/list_of.hpp>
+#include <iostream>
+
+#include <Sluift/SluiftComponent.h>
+#include <Swiften/JID/JID.h>
+#include <Swiften/Elements/SoftwareVersion.h>
+#include <Swiften/Elements/Message.h>
+#include <Swiften/Elements/Presence.h>
+#include <Swiften/Elements/RawXMLPayload.h>
+#include <Swiften/Elements/RosterItemPayload.h>
+#include <Swiften/Elements/RosterPayload.h>
+#include <Swiften/Elements/DiscoInfo.h>
+#include <Swiften/Elements/MAMQuery.h>
+#include <Swiften/Queries/GenericRequest.h>
+#include <Swiften/Presence/PresenceSender.h>
+#include <Swiften/Roster/XMPPRoster.h>
+#include <Swiften/Roster/SetRosterRequest.h>
+#include <Swiften/Presence/SubscriptionManager.h>
+#include <Swiften/Roster/XMPPRosterItem.h>
+#include <Swiften/Queries/IQRouter.h>
+#include <Swiften/Queries/Requests/GetSoftwareVersionRequest.h>
+#include <Sluift/Lua/FunctionRegistration.h>
+#include <Swiften/Base/foreach.h>
+#include <Swiften/Base/IDGenerator.h>
+#include <Sluift/Lua/Check.h>
+#include <Sluift/Lua/Value.h>
+#include <Sluift/Lua/Exception.h>
+#include <Sluift/Lua/LuaUtils.h>
+#include <Sluift/globals.h>
+#include <Sluift/ElementConvertors/StanzaConvertor.h>
+#include <Sluift/ElementConvertors/IQConvertor.h>
+#include <Sluift/ElementConvertors/PresenceConvertor.h>
+#include <Sluift/ElementConvertors/MessageConvertor.h>
+
+using namespace Swift;
+namespace lambda = boost::lambda;
+
+static inline SluiftComponent* getComponent(lua_State* L) {
+ return *Lua::checkUserData<SluiftComponent>(L, 1);
+}
+
+static inline int getGlobalTimeout(lua_State* L) {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.moduleLibIndex);
+ lua_getfield(L, -1, "timeout");
+ int result = boost::numeric_cast<int>(lua_tointeger(L, -1));
+ lua_pop(L, 2);
+ return result;
+}
+
+static void addPayloadsToTable(lua_State* L, const std::vector<boost::shared_ptr<Payload> >& payloads) {
+ if (!payloads.empty()) {
+ lua_createtable(L, boost::numeric_cast<int>(payloads.size()), 0);
+ for (size_t i = 0; i < payloads.size(); ++i) {
+ Sluift::globals.elementConvertor.convertToLua(L, payloads[i]);
+ lua_rawseti(L, -2, boost::numeric_cast<int>(i+1));
+ }
+ Lua::registerGetByTypeIndex(L, -1);
+ lua_setfield(L, -2, "payloads");
+ }
+}
+
+static boost::shared_ptr<Payload> getPayload(lua_State* L, int index) {
+ if (lua_type(L, index) == LUA_TTABLE) {
+ return boost::dynamic_pointer_cast<Payload>(Sluift::globals.elementConvertor.convertFromLua(L, index));
+ }
+ else if (lua_type(L, index) == LUA_TSTRING) {
+ return boost::make_shared<RawXMLPayload>(Lua::checkString(L, index));
+ }
+ else {
+ return boost::shared_ptr<Payload>();
+ }
+}
+
+static std::vector< boost::shared_ptr<Payload> > getPayloadsFromTable(lua_State* L, int index) {
+ index = Lua::absoluteOffset(L, index);
+ std::vector< boost::shared_ptr<Payload> > result;
+ lua_getfield(L, index, "payloads");
+ if (lua_istable(L, -1)) {
+ for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+ boost::shared_ptr<Payload> payload = getPayload(L, -1);
+ if (payload) {
+ result.push_back(payload);
+ }
+ }
+ }
+ lua_pop(L, 1);
+ return result;
+}
+
+SLUIFT_LUA_FUNCTION(Component, async_connect) {
+ SluiftComponent* component = getComponent(L);
+
+ std::string host;
+ int port = 0;
+ if (lua_istable(L, 2)) {
+ if (boost::optional<std::string> hostString = Lua::getStringField(L, 2, "host")) {
+ host = *hostString;
+ }
+ if (boost::optional<int> portInt = Lua::getIntField(L, 2, "port")) {
+ port = *portInt;
+ }
+ }
+ component->connect(host, port);
+ return 0;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, set_trace_enabled,
+ "Enable/disable tracing of the data sent/received.\n\n.",
+ "self\n"
+ "enable a boolean specifying whether to enable/disable tracing",
+ ""
+) {
+ getComponent(L)->setTraceEnabled(lua_toboolean(L, 1));
+ return 0;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, wait_connected,
+ "Block until the component is connected.\n\nThis is useful after an `async_connect`.",
+ "self",
+ ""
+) {
+ getComponent(L)->waitConnected(getGlobalTimeout(L));
+ return 0;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, is_connected,
+ "Checks whether this component is still connected.\n\nReturns a boolean.",
+ "self\n",
+ ""
+) {
+ lua_pushboolean(L, getComponent(L)->isConnected());
+ return 1;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, disconnect,
+ "Disconnect from the server",
+ "self\n",
+ ""
+) {
+ Sluift::globals.eventLoop.runOnce();
+ getComponent(L)->disconnect();
+ return 0;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, set_version,
+
+ "Sets the published version of this component.",
+
+ "self",
+
+ "name the name of the component software\n"
+ "version the version identifier of this component\n"
+ "os the OS this component is running on\n"
+) {
+ Sluift::globals.eventLoop.runOnce();
+ SluiftComponent* component = getComponent(L);
+ if (boost::shared_ptr<SoftwareVersion> version = boost::dynamic_pointer_cast<SoftwareVersion>(Sluift::globals.elementConvertor.convertFromLuaUntyped(L, 2, "software_version"))) {
+ component->setSoftwareVersion(version->getName(), version->getVersion(), version->getOS());
+ }
+ return 0;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, send_message,
+ "Send a message.",
+ "self\n"
+ "to the JID to send the message to\n"
+ "body the body of the message. Can alternatively be specified using the `body` option\n",
+
+ "to the JID to send the message to\n"
+ "body the body of the message\n"
+ "subject the subject of the MUC room to set\n"
+ "type the type of message to send (`normal`, `chat`, `error`, `groupchat`, `headline`)\n"
+ "payloads payloads to add to the message\n"
+) {
+ Sluift::globals.eventLoop.runOnce();
+ JID to;
+ boost::optional<std::string> from;
+ boost::optional<std::string> body;
+ boost::optional<std::string> subject;
+ std::vector<boost::shared_ptr<Payload> > payloads;
+ int index = 2;
+ Message::Type type = Message::Chat;
+ if (lua_isstring(L, index)) {
+ to = std::string(lua_tostring(L, index));
+ ++index;
+ if (lua_isstring(L, index)) {
+ body = lua_tostring(L, index);
+ ++index;
+ }
+ }
+ if (lua_istable(L, index)) {
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "to")) {
+ to = *value;
+ }
+
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "from")) {
+ from = value;
+ }
+
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "body")) {
+ body = value;
+ }
+
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "type")) {
+ type = MessageConvertor::convertMessageTypeFromString(*value);
+ }
+
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "subject")) {
+ subject = value;
+ }
+
+ payloads = getPayloadsFromTable(L, index);
+ }
+
+ if (!to.isValid()) {
+ throw Lua::Exception("Missing 'to'");
+ }
+ if ((!body || body->empty()) && !subject && payloads.empty()) {
+ throw Lua::Exception("Missing any of 'body', 'subject' or 'payloads'");
+ }
+ Message::ref message = boost::make_shared<Message>();
+ message->setTo(to);
+ if (from && !from->empty()) {
+ message->setFrom(*from);
+ }
+ if (body && !body->empty()) {
+ message->setBody(*body);
+ }
+ if (subject) {
+ message->setSubject(*subject);
+ }
+ message->addPayloads(payloads.begin(), payloads.end());
+ message->setType(type);
+ getComponent(L)->getComponent()->sendMessage(message);
+ return 0;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, send_presence,
+ "Send presence.",
+
+ "self\n"
+ "body the text of the presence. Can alternatively be specified using the `status` option\n",
+
+ "to the JID to send the message to\n"
+ "from the JID to send the message from\n"
+ "status the text of the presence\n"
+ "priority the priority of the presence\n"
+ "type the type of message to send (`available`, `error`, `probe`, `subscribe`, `subscribed`, `unavailable`, `unsubscribe`, `unsubscribed`)\n"
+ "payloads payloads to add to the presence\n"
+) {
+ Sluift::globals.eventLoop.runOnce();
+ boost::shared_ptr<Presence> presence = boost::make_shared<Presence>();
+
+ int index = 2;
+ if (lua_isstring(L, index)) {
+ presence->setStatus(lua_tostring(L, index));
+ ++index;
+ }
+ if (lua_istable(L, index)) {
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "to")) {
+ presence->setTo(*value);
+ }
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "from")) {
+ presence->setFrom(*value);
+ }
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "status")) {
+ presence->setStatus(*value);
+ }
+ if (boost::optional<int> value = Lua::getIntField(L, index, "priority")) {
+ presence->setPriority(*value);
+ }
+ if (boost::optional<std::string> value = Lua::getStringField(L, index, "type")) {
+ presence->setType(PresenceConvertor::convertPresenceTypeFromString(*value));
+ }
+ std::vector< boost::shared_ptr<Payload> > payloads = getPayloadsFromTable(L, index);
+ presence->addPayloads(payloads.begin(), payloads.end());
+ }
+
+ getComponent(L)->getComponent()->sendPresence(presence);
+ lua_pushvalue(L, 1);
+ return 0;
+}
+
+static int sendQuery(lua_State* L, IQ::Type type) {
+ SluiftComponent* component = getComponent(L);
+
+ JID to;
+ if (boost::optional<std::string> toString = Lua::getStringField(L, 2, "to")) {
+ to = JID(*toString);
+ }
+
+ JID from;
+ if (boost::optional<std::string> fromString = Lua::getStringField(L, 2, "from")) {
+ from = JID(*fromString);
+ }
+
+ int timeout = getGlobalTimeout(L);
+ if (boost::optional<int> timeoutInt = Lua::getIntField(L, 2, "timeout")) {
+ timeout = *timeoutInt;
+ }
+
+ boost::shared_ptr<Payload> payload;
+ lua_getfield(L, 2, "query");
+ payload = getPayload(L, -1);
+ lua_pop(L, 1);
+
+ return component->sendRequest(
+ boost::make_shared< GenericRequest<Payload> >(type, from, to, payload, component->getComponent()->getIQRouter()), timeout).convertToLuaResult(L);
+}
+
+#define DISPATCH_PUBSUB_PAYLOAD(payloadType, container, response) \
+ else if (boost::shared_ptr<payloadType> p = boost::dynamic_pointer_cast<payloadType>(payload)) { \
+ return component->sendPubSubRequest(type, to, p, timeout).convertToLuaResult(L); \
+ }
+
+SLUIFT_LUA_FUNCTION(Component, get) {
+ return sendQuery(L, IQ::Get);
+}
+
+SLUIFT_LUA_FUNCTION(Component, set) {
+ return sendQuery(L, IQ::Set);
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, send,
+ "Sends a raw string",
+
+ "self\n"
+ "data the string to send\n",
+
+ ""
+) {
+ Sluift::globals.eventLoop.runOnce();
+
+ getComponent(L)->getComponent()->sendData(std::string(Lua::checkString(L, 2)));
+ lua_pushvalue(L, 1);
+ return 0;
+}
+
+static void pushEvent(lua_State* L, const SluiftComponent::Event& event) {
+ switch (event.type) {
+ case SluiftComponent::Event::MessageType: {
+ Message::ref message = boost::dynamic_pointer_cast<Message>(event.stanza);
+ Lua::Table result = boost::assign::map_list_of
+ ("type", boost::make_shared<Lua::Value>(std::string("message")))
+ ("from", boost::make_shared<Lua::Value>(message->getFrom().toString()))
+ ("to", boost::make_shared<Lua::Value>(message->getTo().toString()))
+ ("body", boost::make_shared<Lua::Value>(message->getBody()))
+ ("message_type", boost::make_shared<Lua::Value>(MessageConvertor::convertMessageTypeToString(message->getType())));
+ Lua::pushValue(L, result);
+ addPayloadsToTable(L, message->getPayloads());
+ Lua::registerTableToString(L, -1);
+ break;
+ }
+ case SluiftComponent::Event::PresenceType: {
+ Presence::ref presence = boost::dynamic_pointer_cast<Presence>(event.stanza);
+ Lua::Table result = boost::assign::map_list_of
+ ("type", boost::make_shared<Lua::Value>(std::string("presence")))
+ ("from", boost::make_shared<Lua::Value>(presence->getFrom().toString()))
+ ("to", boost::make_shared<Lua::Value>(presence->getTo().toString()))
+ ("status", boost::make_shared<Lua::Value>(presence->getStatus()))
+ ("presence_type", boost::make_shared<Lua::Value>(PresenceConvertor::convertPresenceTypeToString(presence->getType())));
+ Lua::pushValue(L, result);
+ addPayloadsToTable(L, presence->getPayloads());
+ Lua::registerTableToString(L, -1);
+ break;
+ }
+ }
+}
+
+struct CallUnaryLuaPredicateOnEvent {
+ CallUnaryLuaPredicateOnEvent(lua_State* L, int index) : L(L), index(index) {
+ }
+
+ bool operator()(const SluiftComponent::Event& event) {
+ lua_pushvalue(L, index);
+ pushEvent(L, event);
+ if (lua_pcall(L, 1, 1, 0) != 0) {
+ throw Lua::Exception(lua_tostring(L, -1));
+ }
+ bool result = lua_toboolean(L, -1);
+ lua_pop(L, 1);
+ return result;
+ }
+
+ lua_State* L;
+ int index;
+};
+
+
+SLUIFT_LUA_FUNCTION(Component, get_next_event) {
+ Sluift::globals.eventLoop.runOnce();
+ SluiftComponent* component = getComponent(L);
+
+ int timeout = getGlobalTimeout(L);
+ boost::optional<SluiftComponent::Event::Type> type;
+ int condition = 0;
+ if (lua_istable(L, 2)) {
+ if (boost::optional<std::string> typeString = Lua::getStringField(L, 2, "type")) {
+ if (*typeString == "message") {
+ type = SluiftComponent::Event::MessageType;
+ }
+ else if (*typeString == "presence") {
+ type = SluiftComponent::Event::PresenceType;
+ }
+ }
+ if (boost::optional<int> timeoutInt = Lua::getIntField(L, 2, "timeout")) {
+ timeout = *timeoutInt;
+ }
+ lua_getfield(L, 2, "if");
+ if (lua_isfunction(L, -1)) {
+ condition = Lua::absoluteOffset(L, -1);
+ }
+ }
+
+ boost::optional<SluiftComponent::Event> event;
+ if (condition) {
+ event = component->getNextEvent(timeout, CallUnaryLuaPredicateOnEvent(L, condition));
+ }
+ else if (type) {
+ event = component->getNextEvent(
+ timeout, lambda::bind(&SluiftComponent::Event::type, lambda::_1) == *type);
+ }
+ else {
+ event = component->getNextEvent(timeout);
+ }
+
+ if (event) {
+ pushEvent(L, *event);
+ }
+ else {
+ lua_pushnil(L);
+ }
+ return 1;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Component, jid,
+ "Returns the JID of this component",
+ "self\n",
+ ""
+) {
+ SluiftComponent* component = getComponent(L);
+ lua_pushstring(L, component->getComponent()->getJID().toString().c_str());
+ return 1;
+}
+
+SLUIFT_LUA_FUNCTION(Component, __gc) {
+ SluiftComponent* component = getComponent(L);
+ delete component;
+ return 0;
+}
diff --git a/Sluift/core.lua b/Sluift/core.lua
index 7487de1..ffbb5f9 100644
--- a/Sluift/core.lua
+++ b/Sluift/core.lua
@@ -1,1041 +1,1243 @@
--[[
Copyright (c) 2013-2014 Remko Tronçon
Licensed under the GNU General Public License.
See the COPYING file for more information.
--]]
local sluift = select(1, ...)
local _G = _G
local pairs, ipairs, print, tostring, type, error, assert, next, rawset, xpcall, unpack, io = pairs, ipairs, print, tostring, type, error, assert, next, rawset, xpcall, unpack, io
local setmetatable, getmetatable = setmetatable, getmetatable
local string = require "string"
local table = require "table"
local debug = require "debug"
_ENV = nil
--------------------------------------------------------------------------------
-- Table utility methods
--------------------------------------------------------------------------------
local function table_value_tostring(value)
local result = tostring(value)
if type(value) == 'number' then return result
elseif type(value) == 'boolean' then return result
elseif type(value) == 'string' then return "'" .. result .. "'"
else return '<' .. result .. '>'
end
end
local function table_tostring(table, print_functions, indent, accumulator, history)
local INDENT = ' '
local accumulator = accumulator or ''
local history = history or {}
local indent = indent or ''
accumulator = accumulator .. '{'
history[table] = true
local is_first = true
for key, value in pairs(table) do
if print_functions or type(value) ~= 'function' then
if not is_first then
accumulator = accumulator .. ','
end
is_first = false
accumulator = accumulator .. '\n' .. indent .. INDENT .. '[' .. table_value_tostring(key) .. '] = '
if type(value) == 'table' then
if history[value] then
accumulator = accumulator .. "..."
else
accumulator = table_tostring(value, print_functions, indent .. INDENT, accumulator, history)
end
else
accumulator = accumulator .. table_value_tostring(value)
end
end
end
history[table] = false
if not is_first then
accumulator = accumulator .. '\n' .. indent
end
accumulator = accumulator .. '}'
return accumulator
end
local function register_table_tostring(table, print_functions)
if type(table) == 'table' then
local metatable = getmetatable(table)
if not metatable then
metatable = {}
setmetatable(table, metatable)
end
if print_functions then
metatable.__tostring = function(table) return table_tostring(table, true) end
else
metatable.__tostring = table_tostring
end
end
end
-- FIXME: Not really a good or efficiant equals, but does the trick for now
local function table_equals(t1, t2)
return tostring(t1) == tostring(t2)
end
local function register_table_equals(table)
if type(table) == 'table' then
local metatable = getmetatable(table)
if not metatable then
metatable = {}
setmetatable(table, metatable)
end
metatable.__eq = table_equals
end
end
local function merge_tables(...)
local result = {}
for _, table in ipairs({...}) do
for k, v in pairs(table) do
result[k] = v
end
end
return result
end
local function copy(object)
if type(object) == 'table' then
local copy = {}
for key, value in pairs(object) do
copy[key] = value
end
return copy
else
return object
end
end
local function clear(table)
setmetatable(table, nil)
for key, value in pairs(table) do
rawset(table, key, nil)
end
end
local function trim(string)
return string:gsub("^%s*(.-)%s*$", "%1")
end
local function keys(table)
local result = {}
for key in pairs(table) do
result[#result+1] = key
end
return result
end
local function insert_all(table, values)
for _, value in pairs(values) do
table[#table+1] = value
end
end
--------------------------------------------------------------------------------
-- Help
--------------------------------------------------------------------------------
-- Contains help for native methods that we want access to from here
local extra_help = {}
+local component_extra_help = {}
local help_data = {}
local help_classes = {}
local help_class_metatables = {}
local _H
local function get_synopsis(description)
return description:gsub("[\n\r].*", "")
end
local function format_description(text)
local result = {}
local trim_whitespace
for line in (text .. "\n"):gmatch"(.-)\n" do
if not trim_whitespace and line:find('[^%s]') then
trim_whitespace = line:match("^(%s*)")
end
if trim_whitespace then
line = line:gsub("^" .. trim_whitespace, "")
end
table.insert(result, line)
end
return trim(table.concat(result, "\n"))
end
local function strip_links(text)
return text:gsub("(@{(%w*)})", "`%2`")
end
local function register_help(target, help)
assert(target)
if not help then
help = _H
end
assert(help)
-- Transform description into canonical representation
local parameters = {}
for _, parameter in pairs(help.parameters or {}) do
local parameter_description = parameter[2]
if parameter_description and #parameter_description == 0 then
parameter_description = nil
end
if type(parameter) == "table" then
parameters[#parameters+1] = { name = parameter[1], description = parameter_description }
else
parameters[#parameters+1] = { name = parameter }
end
end
local options = {}
for option_name, option_description in pairs(help.options or {}) do
if type(option_description) == "table" then
options[#options+1] = { name = option_description.name, description = option_description.description }
else
options[#options+1] = { name = option_name, description = option_description }
end
end
local description = format_description(help[1] or help.description or "")
local synopsis = get_synopsis(description)
if #description == 0 then
synopsis = nil
description = nil
end
local data = {
description = description,
synopsis = synopsis,
parameters = parameters,
options = options,
classes = help.classes
}
register_table_tostring(data, true)
help_data[target] = data
end
local function register_class_help(class, help)
help_classes[#help_classes+1] = class
register_help(class, help)
end
local function register_class_table_help(target, class, help)
register_help(target, help)
help_class_metatables[class] = target
register_class_help(class, help)
end
_H = {
[[
Retrieves the help information from `target`.
Returns a table with the following fields:
- `description`: the description of `target`
- `parameters`: an array of parameters of `target` represented as tables with `name` and `description` fields.
- `options`: an array of options (named parameters) of `target` represented as tables with `name` and
`description` fields.
- `methods`: an array of methods
- `fields`: an array of fields
]],
parameters = { {"target", "The target to retrieve help of"} }
}
local function get_help(target)
if not target then error("Nil argument or argument missing") end
local help = help_data[target] or help_data[getmetatable(target)] or {}
-- Collect child methods and fields
local children = {}
if type(target) == "table" then children = target end
local mt
if type(target) == "string" then
mt = help_class_metatables[target]
else
mt = getmetatable(target)
end
if mt and type(mt.__index) == "table" then
children = merge_tables(children, mt.__index)
end
local methods = {}
local fields = {}
for name, value in pairs(children) do
if name:sub(1, 1) ~= "_" then
if type(value) == "function" then
methods[#methods+1] = { name = name, ref = value }
else
fields[#fields+1] = { name = name, description = nil }
end
end
end
if next(methods) ~= nil then
help.methods = methods
end
if next(fields) ~= nil then
help.fields = fields
end
if next(help) then
return help
else
return nil
end
end
register_help(get_help)
_H = {
[[
Prints the help of `target`.
`target` can be any object. When `target` is a string, prints the help of the class with
the given name.
]],
parameters = { {"target", "The target to retrieve help of"} }
}
local function help(target)
print()
if not target then
print("Call `help(target)` to get the help of a specific `target`.")
print("`target` can be any object. When `target` is a string, prints")
print("the help of the class with the given name.")
print()
print("For general information about sluift, type:")
print(" help(sluift)")
print()
return
end
local data = get_help(target)
if not data then
print("No help available\n")
return
end
-- Collect help of children
local methods = {}
for _, method in pairs(data.methods or {}) do
local description
local method_help = get_help(method.ref)
if method_help and method_help.description then
description = method_help.synopsis
end
methods[#methods+1] = { name = method.name, description = description }
end
local fields = copy(data.fields or {})
table.sort(methods, function (a, b) return (a.name or "") < (b.name or "") end)
table.sort(fields, function (a, b) return (a.name or "") < (b.name or "") end)
local classes = {}
for _, class in pairs(data.classes or {}) do
classes[#classes+1] = { name = class, description = get_help(class).synopsis }
end
print(strip_links(data.description) or "(No description available)")
for _, p in ipairs({
{"Parameters", data.parameters}, {"Options", data.options}, {"Methods", methods}, {"Fields", fields}, {"Classes", classes}}) do
if p[2] and next(p[2]) ~= nil then
print()
print(p[1] .. ":")
for _, parameter in ipairs(p[2]) do
if parameter.description then
print(" " .. parameter.name .. ": " .. strip_links(parameter.description))
else
print(" " .. parameter.name)
end
end
end
end
print()
end
register_help(help)
--------------------------------------------------------------------------------
-- Utility methods
--------------------------------------------------------------------------------
_H = {
[[ Perform a shallow copy of `object`. ]],
parameters = {{"object", "the object to copy"}}
}
register_help(copy)
_H = {
[[ Pretty-print a table ]],
parameters = {{"table", "the table to print"}}
}
local function tprint(table)
print(table_tostring(table, true))
end
register_help(tprint)
local function remove_help_parameters(elements, table)
if type(elements) ~= "table" then
elements = {elements}
end
local result = copy(table)
for k, v in ipairs(table) do
for _, element in ipairs(elements) do
if v.name == element then
result[k] = nil
end
end
end
return result
end
local function parse_options(unnamed_parameters, arg1, arg2)
local options = {}
local f
if type(arg1) == 'table' then
options = arg1
f = arg2
elseif type(arg1) == 'function' then
f = arg1
end
options.f = f or options.f
return copy(options)
end
local function get_by_type(table, typ)
for _, v in ipairs(table) do
if v['_type'] == typ then
return v
end
end
end
local function register_get_by_type_index(table)
if type(table) == 'table' then
local metatable = getmetatable(table)
if not metatable then
metatable = {}
setmetatable(table, metatable)
end
metatable.__index = get_by_type
end
return table
end
local function call(options)
local f = options[1]
local result = { xpcall(f, debug.traceback) }
if options.finally then
options.finally()
end
if result[1] then
table.remove(result, 1)
return unpack(result)
else
error(result[2])
end
end
local function read_file(file)
local f = io.open(file, 'rb')
local result = f:read('*all')
f:close()
return result
end
_H = {
[[ Generate a form table, suitable for PubSubConfiguration and MAMQuery ]],
parameters = { {"fields", "The fields that will be converted into a form table"},
{"form_type", "If specified, add a form_type field with this value"},
{"type", "Form type, e.g. 'submit'"} }
}
local function create_form(...)
local options = parse_options({}, ...)
local result = { fields = {} }
-- FIXME: make nicer when parse_options binds positional arguments to names
if options.fields then
for var, value in pairs(options.fields) do
result.fields[#result.fields+1] = { name = var, value = value }
end
elseif options[1] then
for var, value in pairs(options[1]) do
result.fields[#result.fields+1] = { name = var, value = value }
end
end
if options.form_type then
result.fields[#result.fields+1] = { name = 'FORM_TYPE', value = options.form_type }
end
result['type'] = options.type
return result
end
--------------------------------------------------------------------------------
-- Metatables
--------------------------------------------------------------------------------
_H = {
[[ Client interface ]]
}
local Client = {
_with_prompt = function(client) return client:jid() end
}
Client.__index = Client
register_class_table_help(Client, "Client")
+_H = {
+ [[ Component interface ]]
+}
+local Component = {
+ _with_prompt = function(component) return component:jid() end
+}
+Component.__index = Component
+register_class_table_help(Component, "Component")
+
_H = {
[[ Interface to communicate with a PubSub service ]]
}
local PubSub = {}
PubSub.__index = PubSub
register_class_table_help(PubSub, "PubSub")
_H = {
[[ Interface to communicate with a PubSub node on a service ]]
}
local PubSubNode = {}
PubSubNode.__index = PubSubNode
register_class_table_help(PubSubNode, "PubSubNode")
--------------------------------------------------------------------------------
-- with
--------------------------------------------------------------------------------
local original_G
local function with (target, f)
-- Dynamic scope
if f then
with(target)
return call{f, finally = function() with() end}
end
-- No scope
if target then
if not original_G then
original_G = copy(_G)
setmetatable(original_G, getmetatable(_G))
clear(_G)
end
setmetatable(_G, {
__index = function(_, key)
local value = target[key]
if value then
if type(value) == 'function' then
-- Add 'self' argument to all functions
return function(...) return value(target, ...) end
else
return value
end
else
return original_G[key]
end
end,
__newindex = original_G,
_completions = function ()
local result = {}
if type(target) == "table" then
insert_all(result, keys(target))
end
local mt = getmetatable(target)
if mt and type(mt.__index) == 'table' then
insert_all(result, keys(mt.__index))
end
insert_all(result, keys(original_G))
return result
end
})
-- Set prompt
local prompt = nil
-- Try '_with_prompt' in metatable
local target_metatable = getmetatable(target)
if target_metatable then
if type(target_metatable._with_prompt) == "function" then
prompt = target_metatable._with_prompt(target)
else
prompt = target_metatable._with_prompt
end
end
if not prompt then
-- Use tostring()
local target_string = tostring(target)
if string.len(target_string) > 25 then
prompt = string.sub(target_string, 0, 22) .. "..."
else
prompt = target_string
end
end
rawset(_G, "_PROMPT", prompt .. '> ')
else
-- Reset _G
clear(_G)
for key, value in pairs(original_G) do
_G[key] = value
end
setmetatable(_G, original_G)
end
end
--------------------------------------------------------------------------------
-- Client
--------------------------------------------------------------------------------
extra_help = {
["Client.get_next_event"] = {
[[ Returns the next event. ]],
parameters = { "self" },
options = {
type = "The type of event to return (`message`, `presence`, `pubsub`). When omitted, all event types are returned.",
timeout = "The amount of time to wait for events.",
["if"] = "A function to filter events. When this function, called with the event as a parameter, returns true, the event will be returned"
}
},
["Client.get"] = {
[[ Sends a `get` query. ]],
parameters = { "self" },
options = {
to = "The JID of the target to send the query to",
query = "The query to send",
timeout = "The amount of time to wait for the query to finish",
}
},
["Client.set"] = {
[[ Sends a `set` query. ]],
parameters = { "self" },
options = {
to = "The JID of the target to send the query to",
query = "The query to send.",
timeout = "The amount of time to wait for the query to finish.",
}
},
["Client.async_connect"] = {
[[
Connect to the server asynchronously.
This method immediately returns.
]],
parameters = { "self" },
options = {
host = "The host to connect to. When omitted, is determined by resolving the client JID.",
port = "The port to connect to. When omitted, is determined by resolving the client JID."
}
}
}
_H = {
[[
Connect to the server.
This method blocks until the connection has been established.
]],
parameters = { "self" },
options = extra_help["Client.async_connect"].options
}
function Client:connect (...)
local options = parse_options({}, ...)
local f = options.f
self:async_connect(options)
self:wait_connected()
if f then
return call {function() return f(self) end, finally = function() self:disconnect() end}
end
return true
end
register_help(Client.connect)
_H = {
[[
Returns an iterator over all events.
This function blocks until `timeout` is reached (or blocks forever if it is omitted).
]],
parameters = { "self" },
options = extra_help["Client.get_next_event"].options
}
function Client:events (options)
local function client_events_iterator(s)
return s['client']:get_next_event(s['options'])
end
return client_events_iterator, {client = self, options = options}
end
register_help(Client.events)
_H = {
[[
Calls `f` for each event.
]],
parameters = { "self" },
options = merge_tables(get_help(Client.events).options, {
f = "The functor to call with each event. Required."
})
}
function Client:for_each_event (...)
local options = parse_options({}, ...)
if not type(options.f) == 'function' then error('Expected function') end
for event in self:events(options) do
local result = options.f(event)
if result then
return result
end
end
end
register_help(Client.for_each_event)
for method, event_type in pairs({message = 'message', presence = 'presence', pubsub_event = 'pubsub'}) do
_H = {
"Call `f` for all events of type `" .. event_type .. "`.",
parameters = { "self" },
options = remove_help_parameters("type", get_help(Client.for_each_event).options)
}
Client['for_each_' .. method] = function (client, ...)
local options = parse_options({}, ...)
options['type'] = event_type
return client:for_each_event (options)
end
register_help(Client['for_each_' .. method])
_H = {
"Get the next event of type `" .. event_type .. "`.",
parameters = { "self" },
options = remove_help_parameters("type", extra_help["Client.get_next_event"].options)
}
Client['get_next_' .. method] = function (client, ...)
local options = parse_options({}, ...)
options['type'] = event_type
return client:get_next_event(options)
end
register_help(Client['get_next_' .. method])
end
for method, event_type in pairs({messages = 'message', pubsub_events = 'pubsub'}) do
_H = {
"Returns an iterator over all events of type `" .. event_type .. "`.",
parameters = { "self" },
options = remove_help_parameters("type", get_help(Client.for_each_event).options)
}
Client[method] = function (client, ...)
local options = parse_options({}, ...)
options['type'] = event_type
return client:events (options)
end
register_help(Client[method])
end
_H = {
[[
Process all pending events
]],
parameters = { "self" }
}
function Client:process_events ()
for event in self:events{timeout=0} do end
end
register_help(Client.process_events)
--
-- Register get_* and set_* convenience methods for some type of queries
--
-- Example usages:
-- client:get_software_version{to = 'alice@wonderland.lit'}
-- client:set_command{to = 'alice@wonderland.lit', command = { type = 'execute', node = 'uptime' }}
--
local get_set_shortcuts = {
get = {'software_version', 'disco_items', 'xml', 'dom', 'vcard', 'mam'},
set = {'command', 'mam'}
}
for query_action, query_types in pairs(get_set_shortcuts) do
for _, query_type in ipairs(query_types) do
_H = {
"Sends a `" .. query_action .. "` query of type `" .. query_type .. "`.\n" ..
"Apart from the options below, all top level elements of `" .. query_type .. "` can be passed.",
parameters = { "self" },
options = remove_help_parameters({"query", "type"}, extra_help["Client.get"].options),
}
local method = query_action .. '_' .. query_type
Client[method] = function (client, options)
options = options or {}
if type(options) ~= 'table' then error('Invalid options: ' .. options) end
options['query'] = merge_tables({_type = query_type}, options[query_type] or {})
return client[query_action](client, options)
end
register_help(Client[method])
end
end
_H = {
[[ Returns a @{PubSub} object for communicating with the PubSub service at `jid`. ]],
parameters = {
"self",
{"jid", "The JID of the PubSub service"}
}
}
function Client:pubsub (jid)
local result = { client = self, jid = jid }
setmetatable(result, PubSub)
return result
end
register_help(Client.pubsub)
+
+--------------------------------------------------------------------------------
+-- Component
+--------------------------------------------------------------------------------
+
+component_extra_help = {
+ ["Component.get_next_event"] = {
+ [[ Returns the next event. ]],
+ parameters = { "self" },
+ options = {
+ type = "The type of event to return (`message`, `presence`). When omitted, all event types are returned.",
+ timeout = "The amount of time to wait for events.",
+ ["if"] = "A function to filter events. When this function, called with the event as a parameter, returns true, the event will be returned"
+ }
+ },
+ ["Component.get"] = {
+ [[ Sends a `get` query. ]],
+ parameters = { "self" },
+ options = {
+ to = "The JID of the target to send the query to",
+ query = "The query to send",
+ timeout = "The amount of time to wait for the query to finish",
+ }
+ },
+ ["Component.set"] = {
+ [[ Sends a `set` query. ]],
+ parameters = { "self" },
+ options = {
+ to = "The JID of the target to send the query to",
+ query = "The query to send.",
+ timeout = "The amount of time to wait for the query to finish.",
+ }
+ },
+ ["Component.async_connect"] = {
+ [[
+ Connect to the server asynchronously.
+
+ This method immediately returns.
+ ]],
+ parameters = { "self" },
+ options = {
+ host = "The host to connect to.",
+ port = "The port to connect to."
+ }
+ }
+}
+
+_H = {
+ [[
+ Connect to the server.
+
+ This method blocks until the connection has been established.
+ ]],
+ parameters = { "self" },
+ options = component_extra_help["Component.async_connect"].options
+}
+function Component:connect (...)
+ local options = parse_options({}, ...)
+ local f = options.f
+ self:async_connect(options)
+ self:wait_connected()
+ if f then
+ return call {function() return f(self) end, finally = function() self:disconnect() end}
+ end
+ return true
+end
+register_help(Component.connect)
+
+
+_H = {
+ [[
+ Returns an iterator over all events.
+
+ This function blocks until `timeout` is reached (or blocks forever if it is omitted).
+ ]],
+ parameters = { "self" },
+ options = component_extra_help["Component.get_next_event"].options
+}
+function Component:events (options)
+ local function component_events_iterator(s)
+ return s['component']:get_next_event(s['options'])
+ end
+ return component_events_iterator, {component = self, options = options}
+end
+register_help(Component.events)
+
+
+_H = {
+ [[
+ Calls `f` for each event.
+ ]],
+ parameters = { "self" },
+ options = merge_tables(get_help(Component.events).options, {
+ f = "The functor to call with each event. Required."
+ })
+}
+function Component:for_each_event (...)
+ local options = parse_options({}, ...)
+ if not type(options.f) == 'function' then error('Expected function') end
+ for event in self:events(options) do
+ local result = options.f(event)
+ if result then
+ return result
+ end
+ end
+end
+register_help(Component.for_each_event)
+
+for method, event_type in pairs({message = 'message', presence = 'presence'}) do
+ _H = {
+ "Call `f` for all events of type `" .. event_type .. "`.",
+ parameters = { "self" },
+ options = remove_help_parameters("type", get_help(Component.for_each_event).options)
+ }
+ Component['for_each_' .. method] = function (component, ...)
+ local options = parse_options({}, ...)
+ options['type'] = event_type
+ return component:for_each_event (options)
+ end
+ register_help(Component['for_each_' .. method])
+
+ _H = {
+ "Get the next event of type `" .. event_type .. "`.",
+ parameters = { "self" },
+ options = remove_help_parameters("type", component_extra_help["Component.get_next_event"].options)
+ }
+ Component['get_next_' .. method] = function (component, ...)
+ local options = parse_options({}, ...)
+ options['type'] = event_type
+ return component:get_next_event(options)
+ end
+ register_help(Component['get_next_' .. method])
+end
+
+for method, event_type in pairs({messages = 'message'}) do
+ _H = {
+ "Returns an iterator over all events of type `" .. event_type .. "`.",
+ parameters = { "self" },
+ options = remove_help_parameters("type", get_help(Component.for_each_event).options)
+ }
+ Component[method] = function (component, ...)
+ local options = parse_options({}, ...)
+ options['type'] = event_type
+ return component:events (options)
+ end
+ register_help(Component[method])
+end
+
+_H = {
+ [[
+ Process all pending events
+ ]],
+ parameters = { "self" }
+}
+function Component:process_events ()
+ for event in self:events{timeout=0} do end
+end
+register_help(Component.process_events)
+
+
+--
+-- Register get_* and set_* convenience methods for some type of queries
+--
+-- Example usages:
+-- component:get_software_version{to = 'alice@wonderland.lit'}
+-- component:set_command{to = 'alice@wonderland.lit', command = { type = 'execute', node = 'uptime' }}
+--
+local get_set_shortcuts = {
+ get = {'software_version', 'disco_items', 'xml', 'dom', 'vcard'},
+ set = {'command'}
+}
+for query_action, query_types in pairs(get_set_shortcuts) do
+ for _, query_type in ipairs(query_types) do
+ _H = {
+ "Sends a `" .. query_action .. "` query of type `" .. query_type .. "`.\n" ..
+ "Apart from the options below, all top level elements of `" .. query_type .. "` can be passed.",
+ parameters = { "self" },
+ options = remove_help_parameters({"query", "type"}, component_extra_help["Component.get"].options),
+ }
+ local method = query_action .. '_' .. query_type
+ Component[method] = function (component, options)
+ options = options or {}
+ if type(options) ~= 'table' then error('Invalid options: ' .. options) end
+ options['query'] = merge_tables({_type = query_type}, options[query_type] or {})
+ return component[query_action](component, options)
+ end
+ register_help(Component[method])
+ end
+end
+
--------------------------------------------------------------------------------
-- PubSub
--------------------------------------------------------------------------------
local function process_pubsub_event (event)
if event._type == 'pubsub_event_items' then
-- Add 'item' shortcut to payload of first item
event.item = event.items and event.items[1] and
event.items[1].data and event.items[1].data[1]
end
end
function PubSub:list_nodes (options)
return self.client:get_disco_items(merge_tables({to = self.jid}, options))
end
function PubSub:node (node)
local result = { client = self.client, jid = self.jid, node = node }
setmetatable(result, PubSubNode)
return result
end
local simple_pubsub_queries = {
get_default_configuration = 'pubsub_owner_default',
get_subscriptions = 'pubsub_subscriptions',
get_affiliations = 'pubsub_affiliations',
get_default_subscription_options = 'pubsub_default',
}
for method, query_type in pairs(simple_pubsub_queries) do
PubSub[method] = function (service, options)
options = options or {}
return service.client:query_pubsub(merge_tables(
{ type = 'get', to = service.jid, query = { _type = query_type } },
options))
end
end
for _, method in ipairs({'events', 'get_next_event', 'for_each_event'}) do
PubSub[method] = function (node, ...)
local options = parse_options({}, ...)
options['if'] = function (event)
return event.type == 'pubsub' and event.from == node.jid and event.node == node
end
return node.client[method](node.client, options)
end
end
--------------------------------------------------------------------------------
-- PubSubNode
--------------------------------------------------------------------------------
local function pubsub_node_configuration_to_form(configuration)
return create_form{configuration, form_type="http://jabber.org/protocol/pubsub#node_config", type="submit"}
end
function PubSubNode:list_items (options)
return self.client:get_disco_items(merge_tables({to = self.jid, disco_items = { node = self.node }}, options))
end
local simple_pubsub_node_queries = {
get_configuration = 'pubsub_owner_configure',
get_subscriptions = 'pubsub_subscriptions',
get_affiliations = 'pubsub_affiliations',
get_owner_subscriptions = 'pubsub_owner_subscriptions',
get_owner_affiliations = 'pubsub_owner_affiliations',
get_default_subscription_options = 'pubsub_default',
}
for method, query_type in pairs(simple_pubsub_node_queries) do
PubSubNode[method] = function (node, options)
return node.client:query_pubsub(merge_tables({
type = 'get', to = node.jid, query = {
_type = query_type, node = node.node
}}, options))
end
end
function PubSubNode:get_items (...)
local options = parse_options({}, ...)
local items = options.items or {}
if options.maximum_items then
items = merge_tables({maximum_items = options.maximum_items}, items)
end
items = merge_tables({_type = 'pubsub_items', node = self.node}, items)
return self.client:query_pubsub(merge_tables({
type = 'get', to = self.jid, query = items}, options))
end
function PubSubNode:get_item (...)
local options = parse_options({}, ...)
if not type(options.id) == 'string' then error('Expected ID') end
return self:get_items{items = {{id = options.id}}}
end
function PubSubNode:create (options)
options = options or {}
local configure
if options['configuration'] then
configure = { data = pubsub_node_configuration_to_form(options['configuration']) }
end
return self.client:query_pubsub(merge_tables(
{ type = 'set', to = self.jid, query = {
_type = 'pubsub_create', node = self.node, configure = configure }
}, options))
end
function PubSubNode:delete (options)
options = options or {}
local redirect
if options['redirect'] then
redirect = {uri = options['redirect']}
end
return self.client:query_pubsub(merge_tables({ type = 'set', to = self.jid, query = {
_type = 'pubsub_owner_delete', node = self.node, redirect = redirect
}}, options))
end
function PubSubNode:set_configuration(options)
options = options or {}
local configuration = pubsub_node_configuration_to_form(options['configuration'])
return self.client:query_pubsub(merge_tables(
{ type = 'set', to = self.jid, query = {
_type = 'pubsub_owner_configure', node = self.node, data = configuration }
}, options))
end
function PubSubNode:set_owner_affiliations(...)
local options = parse_options({}, ...)
return self.client:query_pubsub(merge_tables({
type = 'set', to = self.jid, query = merge_tables({
_type = 'pubsub_owner_affiliations', node = self.node,
}, options.affiliations)}, options))
end
function PubSubNode:subscribe(...)
local options = parse_options({}, ...)
local jid = options.jid or sluift.jid.to_bare(self.client:jid())
return self.client:query_pubsub(merge_tables(
{ type = 'set', to = self.jid, query = {
_type = 'pubsub_subscribe', node = self.node, jid = jid }
}, options))
end
function PubSubNode:unsubscribe(options)
options = options or {}
return self.client:query_pubsub(merge_tables(
{ type = 'set', to = self.jid, query = {
_type = 'pubsub_unsubscribe', node = self.node, jid = options['jid'],
subscription_id = 'subscription_id'}
}, options))
end
function PubSubNode:get_subscription_options (options)
return self.client:query_pubsub(merge_tables(
{ type = 'get', to = self.jid, query = {
_type = 'pubsub_options', node = self.node, jid = options['jid'] }
}, options))
end
function PubSubNode:publish(...)
local options = parse_options({}, ...)
local items = options.items or {}
if options.item then
if type(options.item) == 'string' or options.item._type then
items = {{id = options.id, data = { options.item } }}
options.id = nil
else
items = { options.item }
end
options.item = nil
end
return self.client:query_pubsub(merge_tables(
{ type = 'set', to = self.jid, query = {
_type = 'pubsub_publish', node = self.node, items = items }
}, options))
end
function PubSubNode:retract(...)
local options = parse_options({}, ...)
local items = options.items
if options.id then
items = {{id = options.id}}
end
return self.client:query_pubsub(merge_tables(
{ type = 'set', to = self.jid, query = {
_type = 'pubsub_retract', node = self.node, items = items, notify = options['notify']
}}, options))
end
function PubSubNode:purge(...)
local options = parse_options({}, ...)
return self.client:query_pubsub(merge_tables(
{ type = 'set', to = self.jid, query = {
_type = 'pubsub_owner_purge', node = self.node
}}, options))
end
-- Iterators over events
for _, method in ipairs({'events', 'get_next_event', 'for_each_event'}) do
PubSubNode[method] = function (node, ...)
local options = parse_options({}, ...)
options['if'] = function (event)
return event.type == 'pubsub' and event.from == node.jid and event.node == node.node
end
return node.client[method](node.client, options)
end
end
--------------------------------------------------------------------------------
-- Service discovery
--------------------------------------------------------------------------------
local disco = {
features = {
DISCO_INFO = 'http://jabber.org/protocol/disco#info',
COMMANDS = 'http://jabber.org/protocol/commands',
USER_LOCATION = 'http://jabber.org/protocol/geoloc',
USER_TUNE = 'http://jabber.org/protocol/tune',
USER_AVATAR_METADATA = 'urn:xmpp:avatar:metadata',
USER_ACTIVITY = 'http://jabber.org/protocol/activity',
USER_PROFILE = 'urn:xmpp:tmp:profile'
}
}
--------------------------------------------------------------------------------
_H = nil
extra_help['sluift'] = {
[[
This module provides methods for XMPP communication.
The main entry point of this module is the `new_client` method, which creates a
new client for communicating with an XMPP server.
]],
classes = help_classes
}
return {
Client = Client,
+ Component = Component,
register_help = register_help,
register_class_help = register_class_help,
register_table_tostring = register_table_tostring,
register_table_equals = register_table_equals,
register_get_by_type_index = register_get_by_type_index,
process_pubsub_event = process_pubsub_event,
tprint = tprint,
read_file = read_file,
disco = disco,
get_help = get_help,
help = help,
extra_help = extra_help,
+ component_extra_help = component_extra_help,
copy = copy,
with = with,
create_form = create_form
}
diff --git a/Sluift/sluift.cpp b/Sluift/sluift.cpp
index b55649b..2fd1e50 100644
--- a/Sluift/sluift.cpp
+++ b/Sluift/sluift.cpp
@@ -1,423 +1,460 @@
/*
* Copyright (c) 2011-2014 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <Sluift/sluift.h>
#include <lua.hpp>
#include <string>
#include <boost/bind.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/assign/list_of.hpp>
#include "Watchdog.h"
#include <Sluift/Lua/Check.h>
#include <Sluift/SluiftClient.h>
+#include <Sluift/SluiftComponent.h>
#include <Sluift/globals.h>
#include <Sluift/Lua/Exception.h>
#include <Sluift/Lua/LuaUtils.h>
#include <Sluift/Lua/FunctionRegistration.h>
#include <Swiften/Base/sleep.h>
#include <Swiften/Base/foreach.h>
#include <Swiften/Base/IDGenerator.h>
#include <Swiften/Parser/PayloadParsers/UnitTest/PayloadsParserTester.h>
#include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h>
#include <Swiften/Serializer/PayloadSerializer.h>
#include <Swiften/TLS/Certificate.h>
#include <Swiften/TLS/CertificateFactory.h>
#include <Sluift/LuaElementConvertor.h>
#include <Sluift/Lua/Debug.h>
#include <Swiften/StringCodecs/Base64.h>
#include <Swiften/StringCodecs/Hexify.h>
#include <Swiften/IDN/IDNConverter.h>
#include <Swiften/Crypto/CryptoProvider.h>
#include <Swiften/Crypto/PlatformCryptoProvider.h>
#include <Sluift/ITunesInterface.h>
using namespace Swift;
namespace Swift {
namespace Sluift {
SluiftGlobals globals;
}
}
extern "C" const char core_lua[];
extern "C" size_t core_lua_size;
static inline bool getGlobalDebug(lua_State* L) {
lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.moduleLibIndex);
lua_getfield(L, -1, "debug");
int result = lua_toboolean(L, -1);
lua_pop(L, 2);
return result;
}
/*******************************************************************************
* Module functions
******************************************************************************/
SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, new_client,
"Creates a new client.\n\nReturns a @{Client} object.\n",
"jid The JID to connect as\n"
"passphrase The passphrase to use\n",
""
) {
Lua::checkString(L, 1);
JID jid(std::string(Lua::checkString(L, 1)));
std::string password(Lua::checkString(L, 2));
SluiftClient** client = reinterpret_cast<SluiftClient**>(lua_newuserdata(L, sizeof(SluiftClient*)));
lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex);
lua_getfield(L, -1, "Client");
lua_setmetatable(L, -3);
lua_pop(L, 1);
*client = new SluiftClient(jid, password, &Sluift::globals.networkFactories, &Sluift::globals.eventLoop);
(*client)->setTraceEnabled(getGlobalDebug(L));
return 1;
}
SLUIFT_LUA_FUNCTION_WITH_HELP(
+ Sluift, new_component,
+
+ "Creates a new component.\n\nReturns a @{Component} object.\n",
+
+ "jid The JID to connect as\n"
+ "passphrase The passphrase to use\n",
+
+ ""
+) {
+ Lua::checkString(L, 1);
+ JID jid(std::string(Lua::checkString(L, 1)));
+ std::string password(Lua::checkString(L, 2));
+
+ SluiftComponent** component = reinterpret_cast<SluiftComponent**>(lua_newuserdata(L, sizeof(SluiftComponent*)));
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex);
+ lua_getfield(L, -1, "Component");
+ lua_setmetatable(L, -3);
+ lua_pop(L, 1);
+
+ *component = new SluiftComponent(jid, password, &Sluift::globals.networkFactories, &Sluift::globals.eventLoop);
+ (*component)->setTraceEnabled(getGlobalDebug(L));
+ return 1;
+}
+
+SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, sha1,
"Compute the SHA-1 hash of given data",
"data the data to hash",
""
) {
static boost::shared_ptr<CryptoProvider> crypto(PlatformCryptoProvider::create());
if (!lua_isstring(L, 1)) {
throw Lua::Exception("Expected string");
}
size_t len;
const char* data = lua_tolstring(L, 1, &len);
ByteArray result = crypto->getSHA1Hash(createByteArray(data, len));
lua_pushlstring(L, reinterpret_cast<char*>(vecptr(result)), result.size());
return 1;
}
SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, sleep,
"Sleeps for the given time.",
"milliseconds the amount of milliseconds to sleep",
""
) {
Sluift::globals.eventLoop.runOnce();
int timeout = Lua::checkIntNumber(L, 1);
Watchdog watchdog(timeout, Sluift::globals.networkFactories.getTimerFactory());
while (!watchdog.getTimedOut()) {
Swift::sleep(boost::numeric_cast<unsigned int>(std::min(100, timeout)));
Sluift::globals.eventLoop.runOnce();
}
return 0;
}
SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, new_uuid,
"Generates a new UUID", "", ""
) {
lua_pushstring(L, IDGenerator().generateID().c_str());
return 1;
}
SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, from_xml,
"Convert a raw XML string into a structured representation.",
"string the string to convert",
""
) {
PayloadsParserTester parser;
if (!parser.parse(Lua::checkString(L, 1))) {
throw Lua::Exception("Error in XML");
}
return Sluift::globals.elementConvertor.convertToLua(L, parser.getPayload());
}
SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, to_xml,
"Convert a structured element into XML.",
"element the element to convert",
""
) {
static FullPayloadSerializerCollection serializers;
boost::shared_ptr<Payload> payload = boost::dynamic_pointer_cast<Payload>(Sluift::globals.elementConvertor.convertFromLua(L, 1));
if (!payload) {
throw Lua::Exception("Unrecognized XML");
}
PayloadSerializer* serializer = serializers.getPayloadSerializer(payload);
if (!payload) {
throw Lua::Exception("Unrecognized XML");
}
lua_pushstring(L, serializer->serialize(payload).c_str());
return 1;
}
SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, hexify,
"Convert binary data into hexadecimal format.",
"data the data to convert",
""
) {
if (!lua_isstring(L, 1)) {
throw Lua::Exception("Expected string");
}
size_t len;
const char* data = lua_tolstring(L, 1, &len);
lua_pushstring(L, Hexify::hexify(createByteArray(data, len)).c_str());
return 1;
}
SLUIFT_LUA_FUNCTION_WITH_HELP(
Sluift, unhexify,
"Convert hexadecimal data into binary data.",
"data the data in hexadecimal format",
""
) {
if (!lua_isstring(L, 1)) {
throw Lua::Exception("Expected string");
}
ByteArray result = Hexify::unhexify(lua_tostring(L, 1));
lua_pushlstring(L, reinterpret_cast<char*>(vecptr(result)), result.size());
return 1;
}
/*******************************************************************************
* Crypto functions
******************************************************************************/
SLUIFT_LUA_FUNCTION_WITH_HELP(
Crypto, new_certificate,
"Creates a new X.509 certificate from DER data.\n",
"der the DER-encoded certificate data",
"") {
ByteArray certData(Lua::checkByteArray(L, 1));
Certificate::ref cert(Sluift::globals.tlsFactories.getCertificateFactory()->createCertificateFromDER(certData));
lua_createtable(L, 0, 0);
lua_pushstring(L, cert->getSubjectName().c_str());
lua_setfield(L, -2, "subject_name");
lua_pushstring(L, Certificate::getSHA1Fingerprint(cert, Sluift::globals.networkFactories.getCryptoProvider()).c_str());
lua_setfield(L, -2, "sha1_fingerprint");
Lua::pushStringArray(L, cert->getCommonNames());
lua_setfield(L, -2, "common_names");
Lua::pushStringArray(L, cert->getSRVNames());
lua_setfield(L, -2, "srv_names");
Lua::pushStringArray(L, cert->getDNSNames());
lua_setfield(L, -2, "dns_names");
Lua::pushStringArray(L, cert->getXMPPAddresses());
lua_setfield(L, -2, "xmpp_addresses");
Lua::registerTableToString(L, -1);
return 1;
}
/*******************************************************************************
* JID Functions
******************************************************************************/
SLUIFT_LUA_FUNCTION(JID, to_bare) {
JID jid(std::string(Lua::checkString(L, 1)));
lua_pushstring(L, jid.toBare().toString().c_str());
return 1;
}
SLUIFT_LUA_FUNCTION(JID, node) {
JID jid(std::string(Lua::checkString(L, 1)));
lua_pushstring(L, jid.getNode().c_str());
return 1;
}
SLUIFT_LUA_FUNCTION(JID, domain) {
JID jid(std::string(Lua::checkString(L, 1)));
lua_pushstring(L, jid.getDomain().c_str());
return 1;
}
SLUIFT_LUA_FUNCTION(JID, resource) {
JID jid(std::string(Lua::checkString(L, 1)));
lua_pushstring(L, jid.getResource().c_str());
return 1;
}
SLUIFT_LUA_FUNCTION(JID, escape_node) {
lua_pushstring(L, JID::getEscapedNode(Lua::checkString(L, 1)).c_str());
return 1;
}
/*******************************************************************************
* Base64 Functions
******************************************************************************/
SLUIFT_LUA_FUNCTION(Base64, encode) {
if (!lua_isstring(L, 1)) {
throw Lua::Exception("Expected string");
}
size_t len;
const char* data = lua_tolstring(L, 1, &len);
lua_pushstring(L, Base64::encode(createByteArray(data, len)).c_str());
return 1;
}
SLUIFT_LUA_FUNCTION(Base64, decode) {
if (!lua_isstring(L, 1)) {
throw Lua::Exception("Expected string");
}
ByteArray result = Base64::decode(lua_tostring(L, 1));
lua_pushlstring(L, reinterpret_cast<char*>(vecptr(result)), result.size());
return 1;
}
/*******************************************************************************
* IDN Functions
******************************************************************************/
SLUIFT_LUA_FUNCTION(IDN, encode) {
IDNConverter* converter = Sluift::globals.networkFactories.getIDNConverter();
lua_pushstring(L, converter->getIDNAEncoded(Lua::checkString(L, 1)).c_str());
return 1;
}
SLUIFT_LUA_FUNCTION(IDN, stringprep) {
IDNConverter* converter = Sluift::globals.networkFactories.getIDNConverter();
IDNConverter::StringPrepProfile profile;
std::string profileString = Lua::checkString(L, 2);
if (profileString == "nameprep") {
profile = IDNConverter::NamePrep;
}
else if (profileString == "xmpp_nodeprep") {
profile = IDNConverter::XMPPNodePrep;
}
else if (profileString == "xmpp_resourceprep") {
profile = IDNConverter::XMPPResourcePrep;
}
else if (profileString == "saslprep") {
profile = IDNConverter::SASLPrep;
}
else {
throw Lua::Exception("Invalid profile");
}
try {
lua_pushstring(L, converter->getStringPrepared(Lua::checkString(L, 1), profile).c_str());
}
catch (const std::exception&) {
throw Lua::Exception("Error");
}
return 1;
}
/*******************************************************************************
* iTunes Functions
******************************************************************************/
#ifdef HAVE_ITUNES
SLUIFT_LUA_FUNCTION(iTunes, get_current_track) {
boost::optional<ITunesInterface::Track> track = Sluift::globals.iTunes.getCurrentTrack();
if (!track) {
return 0;
}
lua_createtable(L, 0, 0);
lua_pushstring(L, track->artist.c_str());
lua_setfield(L, -2, "artist");
lua_pushstring(L, track->name.c_str());
lua_setfield(L, -2, "name");
lua_pushstring(L, track->album.c_str());
lua_setfield(L, -2, "album");
lua_pushinteger(L, track->trackNumber);
lua_setfield(L, -2, "track_number");
lua_pushnumber(L, track->duration);
lua_setfield(L, -2, "duration");
lua_pushinteger(L, track->rating);
lua_setfield(L, -2, "rating");
Lua::registerTableToString(L, -1);
Lua::registerTableEquals(L, -1);
return 1;
}
#endif
/*******************************************************************************
* Module registration
******************************************************************************/
static const luaL_Reg sluift_functions[] = { {NULL, NULL} };
SLUIFT_API int luaopen_sluift(lua_State* L) {
// Initialize & store the module table
luaL_register(L, lua_tostring(L, 1), sluift_functions);
lua_pushinteger(L, -1);
lua_setfield(L, -2, "timeout");
lua_pushboolean(L, 0);
lua_setfield(L, -2, "debug");
lua_pushvalue(L, -1);
Sluift::globals.moduleLibIndex = luaL_ref(L, LUA_REGISTRYINDEX);
// Load core lib code
if (luaL_loadbuffer(L, core_lua, core_lua_size, "core.lua") != 0) {
lua_error(L);
}
lua_pushvalue(L, -2);
lua_call(L, 1, 1);
Sluift::globals.coreLibIndex = luaL_ref(L, LUA_REGISTRYINDEX);
// Register functions
Lua::FunctionRegistry::getInstance().addFunctionsToTable(L, "Sluift");
Lua::FunctionRegistry::getInstance().createFunctionTable(L, "JID");
lua_setfield(L, -2, "jid");
Lua::FunctionRegistry::getInstance().createFunctionTable(L, "Base64");
lua_setfield(L, -2, "base64");
Lua::FunctionRegistry::getInstance().createFunctionTable(L, "IDN");
lua_setfield(L, -2, "idn");
Lua::FunctionRegistry::getInstance().createFunctionTable(L, "Crypto");
lua_setfield(L, -2, "crypto");
#ifdef HAVE_ITUNES
Lua::FunctionRegistry::getInstance().createFunctionTable(L, "iTunes");
lua_setfield(L, -2, "itunes");
#endif
// Register convenience functions
lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex);
std::vector<std::string> coreLibExports = boost::assign::list_of
("tprint")("disco")("help")("get_help")("copy")("with")("read_file")("create_form");
foreach (const std::string& coreLibExport, coreLibExports) {
lua_getfield(L, -1, coreLibExport.c_str());
lua_setfield(L, -3, coreLibExport.c_str());
}
lua_pop(L, 1);
// Load client metatable
lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex);
std::vector<std::string> tables = boost::assign::list_of("Client");
foreach(const std::string& table, tables) {
lua_getfield(L, -1, table.c_str());
Lua::FunctionRegistry::getInstance().addFunctionsToTable(L, table);
lua_pop(L, 1);
}
lua_pop(L, 1);
+ // Load component metatable
+ lua_rawgeti(L, LUA_REGISTRYINDEX, Sluift::globals.coreLibIndex);
+ std::vector<std::string> comp_tables = boost::assign::list_of("Component");
+ foreach(const std::string& table, comp_tables) {
+ lua_getfield(L, -1, table.c_str());
+ Lua::FunctionRegistry::getInstance().addFunctionsToTable(L, table);
+ lua_pop(L, 1);
+ }
+ lua_pop(L, 1);
+
// Register documentation for all elements
foreach (boost::shared_ptr<LuaElementConvertor> convertor, Sluift::globals.elementConvertor.getConvertors()) {
boost::optional<LuaElementConvertor::Documentation> documentation = convertor->getDocumentation();
if (documentation) {
Lua::registerClassHelp(L, documentation->className, documentation->description);
}
}
// Register global documentation
Lua::registerExtraHelp(L, -1, "sluift");
return 1;
}
diff --git a/Swiften/Component/CoreComponent.cpp b/Swiften/Component/CoreComponent.cpp
index d2cc7aa..cc6be42 100644
--- a/Swiften/Component/CoreComponent.cpp
+++ b/Swiften/Component/CoreComponent.cpp
@@ -1,164 +1,168 @@
/*
* Copyright (c) 2010-2013 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <Swiften/Component/CoreComponent.h>
#include <boost/bind.hpp>
#include <iostream>
#include <Swiften/Component/ComponentSession.h>
#include <Swiften/Network/Connector.h>
#include <Swiften/Network/NetworkFactories.h>
#include <Swiften/TLS/PKCS12Certificate.h>
#include <Swiften/Session/BasicSessionStream.h>
#include <Swiften/Queries/IQRouter.h>
#include <Swiften/Base/IDGenerator.h>
#include <Swiften/Component/ComponentSessionStanzaChannel.h>
namespace Swift {
CoreComponent::CoreComponent(const JID& jid, const std::string& secret, NetworkFactories* networkFactories) : networkFactories(networkFactories), jid_(jid), secret_(secret), disconnectRequested_(false) {
stanzaChannel_ = new ComponentSessionStanzaChannel();
stanzaChannel_->onMessageReceived.connect(boost::ref(onMessageReceived));
stanzaChannel_->onPresenceReceived.connect(boost::ref(onPresenceReceived));
stanzaChannel_->onAvailableChanged.connect(boost::bind(&CoreComponent::handleStanzaChannelAvailableChanged, this, _1));
iqRouter_ = new IQRouter(stanzaChannel_);
iqRouter_->setFrom(jid);
}
CoreComponent::~CoreComponent() {
if (session_ || connection_) {
std::cerr << "Warning: Component not disconnected properly" << std::endl;
}
delete iqRouter_;
stanzaChannel_->onAvailableChanged.disconnect(boost::bind(&CoreComponent::handleStanzaChannelAvailableChanged, this, _1));
stanzaChannel_->onMessageReceived.disconnect(boost::ref(onMessageReceived));
stanzaChannel_->onPresenceReceived.disconnect(boost::ref(onPresenceReceived));
delete stanzaChannel_;
}
void CoreComponent::connect(const std::string& host, int port) {
assert(!connector_);
connector_ = ComponentConnector::create(host, port, networkFactories->getDomainNameResolver(), networkFactories->getConnectionFactory(), networkFactories->getTimerFactory());
connector_->onConnectFinished.connect(boost::bind(&CoreComponent::handleConnectorFinished, this, _1));
connector_->setTimeoutMilliseconds(60*1000);
connector_->start();
}
void CoreComponent::handleConnectorFinished(boost::shared_ptr<Connection> connection) {
connector_->onConnectFinished.disconnect(boost::bind(&CoreComponent::handleConnectorFinished, this, _1));
connector_.reset();
if (!connection) {
if (!disconnectRequested_) {
onError(ComponentError::ConnectionError);
}
}
else {
assert(!connection_);
connection_ = connection;
assert(!sessionStream_);
sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(ComponentStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), NULL, networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory()));
sessionStream_->onDataRead.connect(boost::bind(&CoreComponent::handleDataRead, this, _1));
sessionStream_->onDataWritten.connect(boost::bind(&CoreComponent::handleDataWritten, this, _1));
session_ = ComponentSession::create(jid_, secret_, sessionStream_, networkFactories->getCryptoProvider());
stanzaChannel_->setSession(session_);
session_->onFinished.connect(boost::bind(&CoreComponent::handleSessionFinished, this, _1));
session_->start();
}
}
void CoreComponent::disconnect() {
// FIXME: We should be able to do without this boolean. We just have to make sure we can tell the difference between
// connector finishing without a connection due to an error or because of a disconnect.
disconnectRequested_ = true;
if (session_) {
session_->finish();
}
else if (connector_) {
connector_->stop();
assert(!session_);
}
//assert(!session_); /* commenting out until we have time to refactor to be like CoreClient */
//assert(!sessionStream_);
//assert(!connector_);
disconnectRequested_ = false;
}
void CoreComponent::handleSessionFinished(boost::shared_ptr<Error> error) {
session_->onFinished.disconnect(boost::bind(&CoreComponent::handleSessionFinished, this, _1));
session_.reset();
sessionStream_->onDataRead.disconnect(boost::bind(&CoreComponent::handleDataRead, this, _1));
sessionStream_->onDataWritten.disconnect(boost::bind(&CoreComponent::handleDataWritten, this, _1));
sessionStream_.reset();
connection_->disconnect();
connection_.reset();
if (error) {
ComponentError componentError;
if (boost::shared_ptr<ComponentSession::Error> actualError = boost::dynamic_pointer_cast<ComponentSession::Error>(error)) {
switch(actualError->type) {
case ComponentSession::Error::AuthenticationFailedError:
componentError = ComponentError(ComponentError::AuthenticationFailedError);
break;
case ComponentSession::Error::UnexpectedElementError:
componentError = ComponentError(ComponentError::UnexpectedElementError);
break;
}
}
else if (boost::shared_ptr<SessionStream::SessionStreamError> actualError = boost::dynamic_pointer_cast<SessionStream::SessionStreamError>(error)) {
switch(actualError->type) {
case SessionStream::SessionStreamError::ParseError:
componentError = ComponentError(ComponentError::XMLError);
break;
case SessionStream::SessionStreamError::TLSError:
assert(false);
componentError = ComponentError(ComponentError::UnknownError);
break;
case SessionStream::SessionStreamError::InvalidTLSCertificateError:
assert(false);
componentError = ComponentError(ComponentError::UnknownError);
break;
case SessionStream::SessionStreamError::ConnectionReadError:
componentError = ComponentError(ComponentError::ConnectionReadError);
break;
case SessionStream::SessionStreamError::ConnectionWriteError:
componentError = ComponentError(ComponentError::ConnectionWriteError);
break;
}
}
onError(componentError);
}
}
void CoreComponent::handleDataRead(const SafeByteArray& data) {
onDataRead(data);
}
void CoreComponent::handleDataWritten(const SafeByteArray& data) {
onDataWritten(data);
}
void CoreComponent::handleStanzaChannelAvailableChanged(bool available) {
if (available) {
onConnected();
}
}
void CoreComponent::sendMessage(boost::shared_ptr<Message> message) {
stanzaChannel_->sendMessage(message);
}
void CoreComponent::sendPresence(boost::shared_ptr<Presence> presence) {
stanzaChannel_->sendPresence(presence);
}
+void CoreComponent::sendData(const std::string& data) {
+ sessionStream_->writeData(data);
+}
+
}
diff --git a/Swiften/Component/CoreComponent.h b/Swiften/Component/CoreComponent.h
index 63b68f6..e9fdd88 100644
--- a/Swiften/Component/CoreComponent.h
+++ b/Swiften/Component/CoreComponent.h
@@ -1,102 +1,103 @@
/*
* Copyright (c) 2010-2013 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <boost/shared_ptr.hpp>
#include <Swiften/Base/API.h>
#include <Swiften/Base/boost_bsignals.h>
#include <Swiften/Base/Error.h>
#include <Swiften/Component/ComponentConnector.h>
#include <Swiften/Component/ComponentSession.h>
#include <Swiften/Component/ComponentError.h>
#include <Swiften/Elements/Presence.h>
#include <Swiften/Elements/Message.h>
#include <Swiften/JID/JID.h>
#include <string>
#include <Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h>
#include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h>
#include <Swiften/Component/ComponentSessionStanzaChannel.h>
#include <Swiften/Entity/Entity.h>
#include <Swiften/Base/SafeByteArray.h>
namespace Swift {
class EventLoop;
class IQRouter;
class NetworkFactories;
class ComponentSession;
class BasicSessionStream;
/**
* The central class for communicating with an XMPP server as a component.
*
* This class is responsible for setting up the connection with the XMPP
* server and authenticating the component.
*
* This class can be used directly in your application, although the Component
* subclass provides more functionality and interfaces, and is better suited
* for most needs.
*/
class SWIFTEN_API CoreComponent : public Entity {
public:
CoreComponent(const JID& jid, const std::string& secret, NetworkFactories* networkFactories);
~CoreComponent();
void connect(const std::string& host, int port);
void disconnect();
void sendMessage(boost::shared_ptr<Message>);
void sendPresence(boost::shared_ptr<Presence>);
+ void sendData(const std::string& data);
IQRouter* getIQRouter() const {
return iqRouter_;
}
StanzaChannel* getStanzaChannel() const {
return stanzaChannel_;
}
bool isAvailable() const {
return stanzaChannel_->isAvailable();
}
/**
* Returns the JID of the component
*/
const JID& getJID() const {
return jid_;
}
public:
boost::signal<void (const ComponentError&)> onError;
boost::signal<void ()> onConnected;
boost::signal<void (const SafeByteArray&)> onDataRead;
boost::signal<void (const SafeByteArray&)> onDataWritten;
boost::signal<void (boost::shared_ptr<Message>)> onMessageReceived;
boost::signal<void (boost::shared_ptr<Presence>) > onPresenceReceived;
private:
void handleConnectorFinished(boost::shared_ptr<Connection>);
void handleStanzaChannelAvailableChanged(bool available);
void handleSessionFinished(boost::shared_ptr<Error>);
void handleDataRead(const SafeByteArray&);
void handleDataWritten(const SafeByteArray&);
private:
NetworkFactories* networkFactories;
JID jid_;
std::string secret_;
ComponentSessionStanzaChannel* stanzaChannel_;
IQRouter* iqRouter_;
ComponentConnector::ref connector_;
boost::shared_ptr<Connection> connection_;
boost::shared_ptr<BasicSessionStream> sessionStream_;
boost::shared_ptr<ComponentSession> session_;
bool disconnectRequested_;
};
}