From ca3f25d09a703ff7c27267a5591ce5379886e1c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= Date: Tue, 1 Mar 2011 21:17:23 +0100 Subject: Some more Sluift enhancements. diff --git a/BuildTools/Eclipse/Swift (Mac OS X).launch b/BuildTools/Eclipse/Swift (Mac OS X).launch index 1dc569f..ba5e5be 100644 --- a/BuildTools/Eclipse/Swift (Mac OS X).launch +++ b/BuildTools/Eclipse/Swift (Mac OS X).launch @@ -17,7 +17,7 @@ - + diff --git a/QA/UnitTest/Unit Tests.launch b/QA/UnitTest/Unit Tests.launch new file mode 100644 index 0000000..f58a6b0 --- /dev/null +++ b/QA/UnitTest/Unit Tests.launch @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/Sluift/Examples/Login.lua b/Sluift/Examples/Login.lua index 52c1521..d93e990 100644 --- a/Sluift/Examples/Login.lua +++ b/Sluift/Examples/Login.lua @@ -5,7 +5,8 @@ -- -- This script logs into an XMPP server, and sends initial presence --- Useful as initialization script for an interactive session ('-i') +-- Useful as initialization script for an interactive session ('-i'), +-- or as a starting point for scripts. -- -- The following environment variables are used: -- * SLUIFT_JID, SWIFT_PASS: JID and password to log in with @@ -17,7 +18,7 @@ sluift.debug = os.getenv("SLUIFT_DEBUG") or false print("Connecting " .. os.getenv("SLUIFT_JID") .. " ...") c = sluift.new_client(os.getenv("SLUIFT_JID"), os.getenv("SLUIFT_PASS")) -c:set_options(os.getenv("SLUIFT_OPTIONS") or {}) +c:set_options({compress = false, tls = false}) c:connect() c:send_presence("") diff --git a/Sluift/ResponseSink.h b/Sluift/ResponseSink.h new file mode 100644 index 0000000..042d6e0 --- /dev/null +++ b/Sluift/ResponseSink.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +#include + +namespace Swift { + template + class ResponseSink { + public: + ResponseSink() : responseReceived(false) { + } + + bool hasResponse() const { + return responseReceived; + } + + boost::shared_ptr getResponsePayload() const { + return payload; + } + + ErrorPayload::ref getResponseError() const { + return error; + } + + void operator()(boost::shared_ptr payload, ErrorPayload::ref error) { + this->payload = payload; + this->error = error; + this->responseReceived = true; + } + + private: + bool responseReceived; + boost::shared_ptr payload; + ErrorPayload::ref error; + }; +} diff --git a/Sluift/SConscript b/Sluift/SConscript index 44fabdf..816c234 100644 --- a/Sluift/SConscript +++ b/Sluift/SConscript @@ -6,12 +6,15 @@ if env["SCONS_STAGE"] == "build" : myenv.UseFlags(env["SWIFTEN_FLAGS"]) myenv.UseFlags(env["SWIFTEN_DEP_FLAGS"]) myenv["SHLIBPREFIX"] = "" - if myenv["PLATFORM"] == "win32" : myenv.Append(CPPDEFINES = ["SLUIFT_BUILD_DLL"]) elif myenv["PLATFORM"] == "darwin" : myenv["SHLIBSUFFIX"] = ".so" + sluift_lib = myenv.StaticLibrary("SluiftCore", [ + "sluift.cpp" + ]); + def patchLua(env, target, source) : f = open(source[0].abspath, "r") contents = f.read() @@ -23,20 +26,16 @@ if env["SCONS_STAGE"] == "build" : f.close() sluift_bin_env = myenv.Clone() + sluift_bin_env.Append(LIBS = sluift_lib) sluift_bin_env.Command("lua.c", ["#/3rdParty/Lua/src/lua.c"], env.Action(patchLua, cmdstr = "$GENCOMSTR")) if sluift_bin_env.get("HAVE_READLINE", False) : sluift_bin_env.Append(CPPDEFINES = ["LUA_USE_READLINE"]) sluift_bin_env.MergeFlags(sluift_bin_env["READLINE_FLAGS"]) env["SLUIFT"] = sluift_bin_env.Program("sluift", [ - "sluift.cpp", "lua.c", "linit.c", ]) - # Create a copy of sluift.cpp to avoid conflicting targets - # Ideally, we would use variants for this - myenv.InstallAs("sluift_dll.cpp", "sluift.cpp") - myenv.SharedLibrary("sluift", [ - "sluift_dll.cpp", - ]) - + sluift_dll_env = myenv.Clone() + sluift_dll_env.Append(LIBS = sluift_lib) + sluift_dll_env.SharedLibrary("sluift", []); diff --git a/Sluift/SluiftException.h b/Sluift/SluiftException.h new file mode 100644 index 0000000..92326b6 --- /dev/null +++ b/Sluift/SluiftException.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +#include + +namespace Swift { + class SluiftException { + public: + SluiftException(const std::string& reason) : reason(reason) { + } + + SluiftException(const ClientError& error) { + std::string reason("Disconnected: "); + switch(error.getType()) { + case ClientError::UnknownError: reason += "Unknown Error"; break; + case ClientError::DomainNameResolveError: reason += "Unable to find server"; break; + case ClientError::ConnectionError: reason += "Error connecting to server"; break; + case ClientError::ConnectionReadError: reason += "Error while receiving server data"; break; + case ClientError::ConnectionWriteError: reason += "Error while sending data to the server"; break; + case ClientError::XMLError: reason += "Error parsing server data"; break; + case ClientError::AuthenticationFailedError: reason += "Login/password invalid"; break; + case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break; + case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break; + case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break; + case ClientError::UnexpectedElementError: reason += "Unexpected response"; break; + case ClientError::ResourceBindError: reason += "Error binding resource"; break; + case ClientError::SessionStartError: reason += "Error starting session"; break; + case ClientError::StreamError: reason += "Stream error"; break; + case ClientError::TLSError: reason += "Encryption error"; break; + case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break; + case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break; + case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break; + case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break; + case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break; + case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break; + case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break; + case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break; + case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break; + case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break; + case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break; + case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break; + case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break; + } + } + + const std::string& getReason() const { + return reason; + } + + private: + std::string reason; + }; +} diff --git a/Sluift/Watchdog.h b/Sluift/Watchdog.h new file mode 100644 index 0000000..95b6971 --- /dev/null +++ b/Sluift/Watchdog.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +namespace Swift { + class Watchdog { + public: + Watchdog(int timeout, TimerFactory* timerFactory) : timedOut(false) { + if (timeout > 0) { + timer = timerFactory->createTimer(timeout); + timer->start(); + timer->onTick.connect(boost::bind(&Watchdog::handleTimerTick, this)); + } + else if (timeout == 0) { + timedOut = true; + } + } + + ~Watchdog() { + if (timer) { + timer->stop(); + } + } + + bool getTimedOut() const { + return timedOut; + } + + private: + void handleTimerTick() { + timedOut = true; + } + + private: + Timer::ref timer; + bool timedOut; + }; +} diff --git a/Sluift/sluift.cpp b/Sluift/sluift.cpp index 1ce9642..da2c93b 100644 --- a/Sluift/sluift.cpp +++ b/Sluift/sluift.cpp @@ -20,10 +20,14 @@ extern "C" { #include #include #include -#include #include #include +#include #include +#include +#include "Watchdog.h" +#include "SluiftException.h" +#include "ResponseSink.h" using namespace Swift; @@ -42,51 +46,6 @@ bool debug = false; SimpleEventLoop eventLoop; BoostNetworkFactories networkFactories(&eventLoop); -class Watchdog { - public: - Watchdog(int timeout) : timedOut(false) { - if (timeout > 0) { - timer = networkFactories.getTimerFactory()->createTimer(timeout); - timer->start(); - timer->onTick.connect(boost::bind(&Watchdog::handleTimerTick, this)); - } - else if (timeout == 0) { - timedOut = true; - } - } - - ~Watchdog() { - if (timer) { - timer->stop(); - } - } - - bool getTimedOut() const { - return timedOut; - } - - private: - void handleTimerTick() { - timedOut = true; - } - - private: - Timer::ref timer; - bool timedOut; -}; - -class SluiftException { - public: - SluiftException(const std::string& reason) : reason(reason) { - } - - const std::string& getReason() const { - return reason; - } - - private: - std::string reason; -}; class SluiftClient { public: @@ -137,6 +96,17 @@ class SluiftClient { client->sendPresence(boost::make_shared(status)); } + std::string sendQuery(const JID& jid, IQ::Type type, const std::string& data) { + rawRequestResponse.reset(); + RawRequest::ref request = RawRequest::create(type, jid, data, client->getIQRouter()); + request->onResponse.connect(boost::bind(&SluiftClient::handleRawRequestResponse, this, _1)); + request->send(); + while (!rawRequestResponse) { + eventLoop.runUntilEvents(); + } + return *rawRequestResponse; + } + void disconnect() { client->disconnect(); while (client->isActive()) { @@ -149,25 +119,23 @@ class SluiftClient { } boost::optional getSoftwareVersion(const JID& jid) { + ResponseSink sink; GetSoftwareVersionRequest::ref request = GetSoftwareVersionRequest::create(jid, client->getIQRouter()); - request->onResponse.connect(boost::bind(&SluiftClient::handleSoftwareVersionResponse, this, _1, _2)); - softwareVersion.reset(); - error.reset(); + request->onResponse.connect(boost::ref(sink)); request->send(); - while (!softwareVersion && !error) { + while (!sink.hasResponse()) { eventLoop.runUntilEvents(); } - return softwareVersion; + return sink.getResponsePayload() ? *sink.getResponsePayload() : boost::optional(); } Stanza::ref getNextEvent(int timeout) { - eventLoop.runOnce(); if (!pendingEvents.empty()) { Stanza::ref event = pendingEvents.front(); pendingEvents.pop_front(); return event; } - Watchdog watchdog(timeout); + Watchdog watchdog(timeout, networkFactories.getTimerFactory()); while (!watchdog.getTimedOut() && pendingEvents.empty()) { eventLoop.runUntilEvents(); } @@ -198,52 +166,13 @@ class SluiftClient { rosterReceived = true; } - void handleSoftwareVersionResponse(boost::shared_ptr version, ErrorPayload::ref error) { - if (error) { - this->error = error; - } - else if (version) { - this->softwareVersion = *version; - } - else { - this->softwareVersion = SoftwareVersion("", "", ""); - } + void handleRawRequestResponse(const std::string& response) { + rawRequestResponse = response; } void handleDisconnected(const boost::optional& error) { if (error) { - std::string reason("Disconnected: "); - switch(error->getType()) { - case ClientError::UnknownError: reason += "Unknown Error"; break; - case ClientError::DomainNameResolveError: reason += "Unable to find server"; break; - case ClientError::ConnectionError: reason += "Error connecting to server"; break; - case ClientError::ConnectionReadError: reason += "Error while receiving server data"; break; - case ClientError::ConnectionWriteError: reason += "Error while sending data to the server"; break; - case ClientError::XMLError: reason += "Error parsing server data"; break; - case ClientError::AuthenticationFailedError: reason += "Login/password invalid"; break; - case ClientError::CompressionFailedError: reason += "Error while compressing stream"; break; - case ClientError::ServerVerificationFailedError: reason += "Server verification failed"; break; - case ClientError::NoSupportedAuthMechanismsError: reason += "Authentication mechanisms not supported"; break; - case ClientError::UnexpectedElementError: reason += "Unexpected response"; break; - case ClientError::ResourceBindError: reason += "Error binding resource"; break; - case ClientError::SessionStartError: reason += "Error starting session"; break; - case ClientError::StreamError: reason += "Stream error"; break; - case ClientError::TLSError: reason += "Encryption error"; break; - case ClientError::ClientCertificateLoadError: reason += "Error loading certificate (Invalid password?)"; break; - case ClientError::ClientCertificateError: reason += "Certificate not authorized"; break; - case ClientError::UnknownCertificateError: reason += "Unknown certificate"; break; - case ClientError::CertificateExpiredError: reason += "Certificate has expired"; break; - case ClientError::CertificateNotYetValidError: reason += "Certificate is not yet valid"; break; - case ClientError::CertificateSelfSignedError: reason += "Certificate is self-signed"; break; - case ClientError::CertificateRejectedError: reason += "Certificate has been rejected"; break; - case ClientError::CertificateUntrustedError: reason += "Certificate is not trusted"; break; - case ClientError::InvalidCertificatePurposeError: reason += "Certificate cannot be used for encrypting your connection"; break; - case ClientError::CertificatePathLengthExceededError: reason += "Certificate path length constraint exceeded"; break; - case ClientError::InvalidCertificateSignatureError: reason += "Invalid certificate signature"; break; - case ClientError::InvalidCAError: reason += "Invalid Certificate Authority"; break; - case ClientError::InvalidServerIdentityError: reason += "Certificate does not match the host identity"; break; - } - throw SluiftException(reason); + throw SluiftException(*error); } } @@ -251,9 +180,8 @@ class SluiftClient { Client* client; ClientXMLTracer* tracer; bool rosterReceived; - boost::optional softwareVersion; - ErrorPayload::ref error; std::deque pendingEvents; + boost::optional rawRequestResponse; }; /******************************************************************************* @@ -400,6 +328,43 @@ static int sluift_client_send_presence(lua_State *L) { return 0; } +static int sluift_client_get(lua_State *L) { + SluiftClient* client = getClient(L); + JID jid; + std::string data; + if (lua_type(L, 3) != LUA_TNONE) { + jid = JID(std::string(luaL_checkstring(L, 2))); + data = std::string(luaL_checkstring(L, 3)); + } + else { + data = std::string(luaL_checkstring(L, 2)); + } + std::string result = client->sendQuery(jid, IQ::Get, data); + lua_pushstring(L, result.c_str()); + return 1; +} + +static int sluift_client_set(lua_State *L) { + SluiftClient* client = getClient(L); + JID jid; + std::string data; + if (lua_type(L, 3) != LUA_TNONE) { + jid = JID(std::string(luaL_checkstring(L, 2))); + data = std::string(luaL_checkstring(L, 3)); + } + else { + data = std::string(luaL_checkstring(L, 2)); + } + std::string result = client->sendQuery(jid, IQ::Set, data); + lua_pushstring(L, result.c_str()); + return 1; +} + +static int sluift_client_send(lua_State *L) { + getClient(L)->getClient()->sendData(std::string(luaL_checkstring(L, 2))); + return 0; +} + static int sluift_client_set_options(lua_State* L) { SluiftClient* client = getClient(L); luaL_checktype(L, 2, LUA_TTABLE); @@ -407,7 +372,11 @@ static int sluift_client_set_options(lua_State* L) { if (!lua_isnil(L, -1)) { client->getClient()->setUseStreamCompression(lua_toboolean(L, -1)); } - lua_pop(L, -1); + lua_getfield(L, 2, "tls"); + if (!lua_isnil(L, -1)) { + bool useTLS = lua_toboolean(L, -1); + client->getClient()->setUseTLS(useTLS ? Client::UseTLSWhenAvailable : Client::NeverUseTLS); + } return 0; } @@ -437,6 +406,8 @@ static void pushEvent(lua_State* L, Stanza::ref event) { static int sluift_client_for_event(lua_State *L) { try { + eventLoop.runOnce(); + SluiftClient* client = getClient(L); luaL_checktype(L, 2, LUA_TFUNCTION); int timeout = -1; @@ -472,6 +443,8 @@ static int sluift_client_for_event(lua_State *L) { static int sluift_client_get_next_event(lua_State *L) { try { + eventLoop.runOnce(); + SluiftClient* client = getClient(L); int timeout = -1; if (lua_type(L, 2) != LUA_TNONE) { @@ -500,6 +473,9 @@ static const luaL_reg sluift_client_functions[] = { {"disconnect", sluift_client_disconnect}, {"send_message", sluift_client_send_message}, {"send_presence", sluift_client_send_presence}, + {"get", sluift_client_get}, + {"set", sluift_client_set}, + {"send", sluift_client_send}, {"set_version", sluift_client_set_version}, {"get_roster", sluift_client_get_roster}, {"get_version", sluift_client_get_version}, @@ -515,15 +491,37 @@ static const luaL_reg sluift_client_functions[] = { ******************************************************************************/ static int sluift_new_client(lua_State *L) { - JID jid(std::string(luaL_checkstring(L, 1))); - std::string password(luaL_checkstring(L, 2)); + try { + JID jid(std::string(luaL_checkstring(L, 1))); + std::string password(luaL_checkstring(L, 2)); - SluiftClient** client = reinterpret_cast(lua_newuserdata(L, sizeof(SluiftClient*))); - luaL_getmetatable(L, SLUIFT_CLIENT); - lua_setmetatable(L, -2); + SluiftClient** client = reinterpret_cast(lua_newuserdata(L, sizeof(SluiftClient*))); + luaL_getmetatable(L, SLUIFT_CLIENT); + lua_setmetatable(L, -2); - *client = new SluiftClient(jid, password); - return 1; + *client = new SluiftClient(jid, password); + return 1; + } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } +} + +static int sluift_sleep(lua_State *L) { + try { + eventLoop.runOnce(); + + int timeout = luaL_checknumber(L, 1); + Watchdog watchdog(timeout, networkFactories.getTimerFactory()); + while (!watchdog.getTimedOut()) { + Swift::sleep(std::min(100, timeout)); + eventLoop.runOnce(); + } + return 0; + } + catch (const SluiftException& e) { + return luaL_error(L, e.getReason().c_str()); + } } static int sluift_index(lua_State *L) { @@ -550,6 +548,7 @@ static int sluift_newindex(lua_State *L) { static const luaL_reg sluift_functions[] = { {"new_client", sluift_new_client}, + {"sleep", sluift_sleep}, {NULL, NULL} }; diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 8d9e678..e1c1d8e 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -50,6 +50,7 @@ ClientSession::ClientSession( stream(stream), allowPLAINOverNonTLS(false), useStreamCompression(true), + useTLS(UseTLSWhenAvailable), needSessionStart(false), needResourceBind(false), needAcking(false), @@ -170,7 +171,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { return; } - if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption()) { + if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { state = WaitingForEncrypt; stream->writeElement(boost::shared_ptr(new StartTLSRequest())); } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index ee3992d..25ee694 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -57,6 +57,11 @@ namespace Swift { Error(Type type) : type(type) {} }; + enum UseTLS { + NeverUseTLS, + UseTLSWhenAvailable + }; + ~ClientSession(); static boost::shared_ptr create(const JID& jid, boost::shared_ptr stream) { @@ -75,6 +80,11 @@ namespace Swift { useStreamCompression = b; } + void setUseTLS(UseTLS b) { + useTLS = b; + } + + bool getStreamManagementEnabled() const { return stanzaAckRequester_; } @@ -139,6 +149,7 @@ namespace Swift { boost::shared_ptr stream; bool allowPLAINOverNonTLS; bool useStreamCompression; + UseTLS useTLS; bool needSessionStart; bool needResourceBind; bool needAcking; diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index a199b16..f0c5333 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -22,7 +22,7 @@ namespace Swift { -CoreClient::CoreClient(const JID& jid, const std::string& password, NetworkFactories* networkFactories) : jid_(jid), password_(password), networkFactories(networkFactories), useStreamCompression(true), disconnectRequested_(false), certificateTrustChecker(NULL) { +CoreClient::CoreClient(const JID& jid, const std::string& password, NetworkFactories* networkFactories) : jid_(jid), password_(password), networkFactories(networkFactories), useStreamCompression(true), useTLS(UseTLSWhenAvailable), disconnectRequested_(false), certificateTrustChecker(NULL) { stanzaChannel_ = new ClientSessionStanzaChannel(); stanzaChannel_->onMessageReceived.connect(boost::bind(&CoreClient::handleMessageReceived, this, _1)); stanzaChannel_->onPresenceReceived.connect(boost::bind(&CoreClient::handlePresenceReceived, this, _1)); @@ -83,6 +83,14 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr connectio session_ = ClientSession::create(jid_, sessionStream_); session_->setCertificateTrustChecker(certificateTrustChecker); session_->setUseStreamCompression(useStreamCompression); + switch(useTLS) { + case UseTLSWhenAvailable: + session_->setUseTLS(ClientSession::UseTLSWhenAvailable); + break; + case NeverUseTLS: + session_->setUseTLS(ClientSession::NeverUseTLS); + break; + } stanzaChannel_->setSession(session_); session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); @@ -242,6 +250,10 @@ void CoreClient::sendPresence(boost::shared_ptr presence) { stanzaChannel_->sendPresence(presence); } +void CoreClient::sendData(const std::string& data) { + sessionStream_->writeData(data); +} + bool CoreClient::isActive() const { return (session_ && !session_->isFinished()) || connector_; } @@ -267,5 +279,9 @@ void CoreClient::setUseStreamCompression(bool b) { useStreamCompression = b; } +void CoreClient::setUseTLS(UseTLS b) { + useTLS = b; +} + } diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h index ee73396..eb9c42c 100644 --- a/Swiften/Client/CoreClient.h +++ b/Swiften/Client/CoreClient.h @@ -48,6 +48,11 @@ namespace Swift { */ class CoreClient : public Entity { public: + enum UseTLS { + NeverUseTLS, + UseTLSWhenAvailable + }; + /** * Constructs a client for the given JID with the given password. * The given eventLoop will be used to post events to. @@ -83,6 +88,11 @@ namespace Swift { void sendPresence(Presence::ref); /** + * Sends raw, unchecked data. + */ + void sendData(const std::string& data); + + /** * Returns the IQ router for this client. */ IQRouter* getIQRouter() const { @@ -148,6 +158,11 @@ namespace Swift { */ void setUseStreamCompression(bool b); + /** + * Sets whether TLS encryption should be used. + */ + void setUseTLS(UseTLS useTLS); + public: /** * Emitted when the client was disconnected from the network. @@ -213,6 +228,7 @@ namespace Swift { std::string password_; NetworkFactories* networkFactories; bool useStreamCompression; + UseTLS useTLS; ClientSessionStanzaChannel* stanzaChannel_; IQRouter* iqRouter_; Connector::ref connector_; diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index 21c0ffb..756287c 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -329,6 +329,9 @@ class ClientSessionTest : public CppUnit::TestFixture { receivedEvents.push_back(Event(element)); } + virtual void writeData(const std::string&) { + } + virtual bool supportsTLSEncryption() { return canTLSEncrypt; } diff --git a/Swiften/Component/UnitTest/ComponentSessionTest.cpp b/Swiften/Component/UnitTest/ComponentSessionTest.cpp index af8962a..953973c 100644 --- a/Swiften/Component/UnitTest/ComponentSessionTest.cpp +++ b/Swiften/Component/UnitTest/ComponentSessionTest.cpp @@ -115,6 +115,9 @@ class ComponentSessionTest : public CppUnit::TestFixture { receivedEvents.push_back(Event(element)); } + virtual void writeData(const std::string&) { + } + virtual bool supportsTLSEncryption() { return false; } diff --git a/Swiften/Elements/RawXMLPayload.h b/Swiften/Elements/RawXMLPayload.h index b042b95..e583c12 100644 --- a/Swiften/Elements/RawXMLPayload.h +++ b/Swiften/Elements/RawXMLPayload.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010 Remko Tronçon + * Copyright (c) 2011 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ @@ -7,12 +7,13 @@ #pragma once #include -#include "Swiften/Elements/Payload.h" + +#include namespace Swift { class RawXMLPayload : public Payload { public: - RawXMLPayload() {} + RawXMLPayload(const std::string& data = "") : rawXML_(data) {} void setRawXML(const std::string& data) { rawXML_ = data; diff --git a/Swiften/Queries/RawRequest.h b/Swiften/Queries/RawRequest.h new file mode 100644 index 0000000..477952f --- /dev/null +++ b/Swiften/Queries/RawRequest.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2011 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace Swift { + class RawRequest : public Request { + public: + typedef boost::shared_ptr ref; + + static ref create(IQ::Type type, const JID& recipient, const std::string& data, IQRouter* router) { + return ref(new RawRequest(type, recipient, data, router)); + } + + boost::signal onResponse; + + private: + RawRequest(IQ::Type type, const JID& receiver, const std::string& data, IQRouter* router) : Request(type, receiver, boost::make_shared(data), router) { + } + + virtual void handleResponse(Payload::ref payload, ErrorPayload::ref error) { + if (error) { + onResponse(ErrorSerializer().serializePayload(error)); + } + else { + assert(payload); + PayloadSerializer* serializer = serializers.getPayloadSerializer(payload); + assert(serializer); + onResponse(serializer->serialize(payload)); + } + } + + private: + FullPayloadSerializerCollection serializers; + }; +} diff --git a/Swiften/Queries/Request.cpp b/Swiften/Queries/Request.cpp index 35475c1..6c47725 100644 --- a/Swiften/Queries/Request.cpp +++ b/Swiften/Queries/Request.cpp @@ -6,6 +6,7 @@ #include "Swiften/Queries/Request.h" #include "Swiften/Queries/IQRouter.h" +#include namespace Swift { @@ -40,7 +41,11 @@ bool Request::handleIQ(boost::shared_ptr iq) { bool handled = false; if (sent_ && iq->getID() == id_) { if (iq->getType() == IQ::Result) { - handleResponse(iq->getPayloadOfSameType(payload_), ErrorPayload::ref()); + boost::shared_ptr payload = iq->getPayloadOfSameType(payload_); + if (!payload && boost::dynamic_pointer_cast(payload_) && !iq->getPayloads().empty()) { + payload = iq->getPayloads().front(); + } + handleResponse(payload, ErrorPayload::ref()); } else { ErrorPayload::ref errorPayload = iq->getPayload(); diff --git a/Swiften/Queries/UnitTest/RequestTest.cpp b/Swiften/Queries/UnitTest/RequestTest.cpp index e99149e..34d07c9 100644 --- a/Swiften/Queries/UnitTest/RequestTest.cpp +++ b/Swiften/Queries/UnitTest/RequestTest.cpp @@ -8,11 +8,13 @@ #include #include #include +#include #include "Swiften/Queries/GenericRequest.h" #include "Swiften/Queries/IQRouter.h" #include "Swiften/Queries/DummyIQChannel.h" #include "Swiften/Elements/Payload.h" +#include using namespace Swift; @@ -25,6 +27,8 @@ class RequestTest : public CppUnit::TestFixture { CPPUNIT_TEST(testHandleIQ_Error); CPPUNIT_TEST(testHandleIQ_ErrorWithoutPayload); CPPUNIT_TEST(testHandleIQ_BeforeSend); + CPPUNIT_TEST(testHandleIQ_DifferentPayload); + CPPUNIT_TEST(testHandleIQ_RawXMLPayload); CPPUNIT_TEST_SUITE_END(); public: @@ -34,7 +38,26 @@ class RequestTest : public CppUnit::TestFixture { std::string text_; }; - typedef GenericRequest MyRequest; + struct MyOtherPayload : public Payload { + }; + + class MyRequest : public Request { + public: + MyRequest( + IQ::Type type, + const JID& receiver, + boost::shared_ptr payload, + IQRouter* router) : + Request(type, receiver, payload, router) { + } + + virtual void handleResponse(boost::shared_ptr payload, ErrorPayload::ref error) { + onResponse(payload, error); + } + + public: + boost::signal, ErrorPayload::ref)> onResponse; + }; public: void setUp() { @@ -132,6 +155,33 @@ class RequestTest : public CppUnit::TestFixture { CPPUNIT_ASSERT_EQUAL(0, static_cast(channel_->iqs_.size())); } + void testHandleIQ_DifferentPayload() { + MyRequest testling(IQ::Get, JID("foo@bar.com/baz"), payload_, router_); + testling.onResponse.connect(boost::bind(&RequestTest::handleDifferentResponse, this, _1, _2)); + testling.send(); + + responsePayload_ = boost::make_shared(); + channel_->onIQReceived(createResponse("test-id")); + + CPPUNIT_ASSERT_EQUAL(1, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast(channel_->iqs_.size())); + } + + void testHandleIQ_RawXMLPayload() { + payload_ = boost::make_shared(""); + MyRequest testling(IQ::Get, JID("foo@bar.com/baz"), payload_, router_); + testling.onResponse.connect(boost::bind(&RequestTest::handleRawXMLResponse, this, _1, _2)); + testling.send(); + + responsePayload_ = boost::make_shared(); + channel_->onIQReceived(createResponse("test-id")); + + CPPUNIT_ASSERT_EQUAL(1, responsesReceived_); + CPPUNIT_ASSERT_EQUAL(0, static_cast(receivedErrors.size())); + CPPUNIT_ASSERT_EQUAL(1, static_cast(channel_->iqs_.size())); + } + private: void handleResponse(boost::shared_ptr p, ErrorPayload::ref e) { if (e) { @@ -145,6 +195,19 @@ class RequestTest : public CppUnit::TestFixture { } } + void handleDifferentResponse(boost::shared_ptr p, ErrorPayload::ref e) { + CPPUNIT_ASSERT(!e); + CPPUNIT_ASSERT(!p); + ++responsesReceived_; + } + + void handleRawXMLResponse(boost::shared_ptr p, ErrorPayload::ref e) { + CPPUNIT_ASSERT(!e); + CPPUNIT_ASSERT(p); + CPPUNIT_ASSERT(boost::dynamic_pointer_cast(p)); + ++responsesReceived_; + } + boost::shared_ptr createResponse(const std::string& id) { boost::shared_ptr iq(new IQ(IQ::Result)); iq->addPayload(responsePayload_); diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index 1736f80..ddb833e 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -88,6 +88,11 @@ void BasicSessionStream::writeFooter() { xmppLayer->writeFooter(); } +void BasicSessionStream::writeData(const std::string& data) { + assert(available); + xmppLayer->writeData(data); +} + void BasicSessionStream::close() { connection->disconnect(); } diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index 747177a..b3f7421 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -42,6 +42,7 @@ namespace Swift { virtual void writeHeader(const ProtocolHeader& header); virtual void writeElement(boost::shared_ptr); virtual void writeFooter(); + virtual void writeData(const std::string& data); virtual void addZLibCompression(); diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index 017d3d4..2753878 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -43,6 +43,7 @@ namespace Swift { virtual void writeHeader(const ProtocolHeader& header) = 0; virtual void writeFooter() = 0; virtual void writeElement(boost::shared_ptr) = 0; + virtual void writeData(const std::string& data) = 0; virtual void addZLibCompression() = 0; -- cgit v0.10.2-6-g49f6