From 81c09a0f6a3e87b078340d7f35d0dea4c03f3a6d Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Sat, 12 Nov 2011 16:56:21 +0000 Subject: BOSH Support for Swiften This adds support for BOSH to Swiften. It does not expose it to Swift. Release-Notes: Swiften now allows connects over BOSH, if used appropriately. diff --git a/Swift/Controllers/MainController.cpp b/Swift/Controllers/MainController.cpp index 2f9c42e..7ed53a2 100644 --- a/Swift/Controllers/MainController.cpp +++ b/Swift/Controllers/MainController.cpp @@ -504,12 +504,6 @@ void MainController::performLoginFromCachedCredentials() { ClientOptions clientOptions; clientOptions.forgetPassword = eagleMode_; clientOptions.useTLS = eagleMode_ ? ClientOptions::RequireTLS : ClientOptions::UseTLSWhenAvailable; - if (clientJID.getDomain() == "wonderland.lit") { - clientOptions.boshURL = URL("http", "192.168.1.185", 5280, "http-bind/"); - } - else if (clientJID.getDomain() == "prosody.doomsong.co.uk") { - clientOptions.boshURL = URL("http", "192.168.1.130", 5280, "http-bind/"); - } client_->connect(clientOptions); } diff --git a/Swiften/Base/URL.h b/Swiften/Base/URL.h new file mode 100644 index 0000000..7a5aa59 --- /dev/null +++ b/Swiften/Base/URL.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +namespace Swift { + +class URL { + public: + + URL() : scheme(""), user(""), password(""), host(""), port(-1), path(""), isEmpty(true) { + } + + URL(const std::string& scheme, const std::string& host, int port, const std::string& path) : scheme(scheme), user(), password(), host(host), port(port), path(path), isEmpty(false) { + + } + + /** + * Whether the URL is empty. + */ + bool empty() const { + return isEmpty; + } + + /** + * Scheme used for the URL (http, https etc.) + */ + const std::string& getScheme() const { + return scheme; + } + + /** + * Hostname + */ + const std::string& getHost() const { + return host; + } + + /** + * Port number + */ + int getPort() const { + return port; + } + + /** + * Path + */ + const std::string& getPath() const { + return path; + } + + + + private: + std::string scheme; + std::string user; + std::string password; + std::string host; + int port; + std::string path; + bool isEmpty; + }; +} diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h index 3b51a87..06bf947 100644 --- a/Swiften/Client/ClientOptions.h +++ b/Swiften/Client/ClientOptions.h @@ -6,6 +6,9 @@ #pragma once +#include +#include + namespace Swift { struct ClientOptions { enum UseTLS { @@ -14,7 +17,7 @@ namespace Swift { RequireTLS }; - ClientOptions() : useStreamCompression(true), useTLS(UseTLSWhenAvailable), allowPLAINWithoutTLS(false), useStreamResumption(false), forgetPassword(false), useAcks(true) { + ClientOptions() : useStreamCompression(true), useTLS(UseTLSWhenAvailable), allowPLAINWithoutTLS(false), useStreamResumption(false), forgetPassword(false), useAcks(true), boshHTTPConnectProxyAuthID(""), boshHTTPConnectProxyAuthPassword("") { } /** @@ -61,5 +64,28 @@ namespace Swift { * Default: true */ bool useAcks; + + /** + * If non-empty, use BOSH instead of direct TCP, with the given URL. + * The host currently needs to be specified by IP, rather than hostname. + * Default: empty (no BOSH) + */ + URL boshURL; + + /** + * If non-empty, BOSH connections will try to connect over this HTTP CONNECT + * proxy instead of directly. + * Must be specified by IP, rather than hostname. + * Default: empty (no proxy) + */ + URL boshHTTPConnectProxyURL; + + /** + * If this and matching Password are non-empty, BOSH connections over + * HTTP CONNECT proxies will use these credentials for proxy access. + * Default: empty (no authentication needed by the proxy) + */ + SafeString boshHTTPConnectProxyAuthID; + SafeString boshHTTPConnectProxyAuthPassword; }; } diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 55e0bc2..bfc9313 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -181,7 +181,7 @@ void ClientSession::handleElement(boost::shared_ptr element) { else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) { finishSession(Error::NoSupportedAuthMechanismsError); } - else if (useStreamCompression && streamFeatures->hasCompressionMethod("zlib")) { + else if (useStreamCompression && stream->supportsZLibCompression() && streamFeatures->hasCompressionMethod("zlib")) { state = Compressing; stream->writeElement(boost::make_shared("zlib")); } diff --git a/Swiften/Client/ClientXMLTracer.cpp b/Swiften/Client/ClientXMLTracer.cpp index c1093eb..405e3d1 100644 --- a/Swiften/Client/ClientXMLTracer.cpp +++ b/Swiften/Client/ClientXMLTracer.cpp @@ -11,7 +11,7 @@ namespace Swift { -ClientXMLTracer::ClientXMLTracer(CoreClient* client) { +ClientXMLTracer::ClientXMLTracer(CoreClient* client, bool bosh) : bosh(bosh) { beautifier = new XMLBeautifier(true, true); client->onDataRead.connect(boost::bind(&ClientXMLTracer::printData, this, '<', _1)); client->onDataWritten.connect(boost::bind(&ClientXMLTracer::printData, this, '>', _1)); @@ -23,7 +23,20 @@ ClientXMLTracer::~ClientXMLTracer() { void ClientXMLTracer::printData(char direction, const SafeByteArray& data) { printLine(direction); - std::cerr << beautifier->beautify(byteArrayToString(ByteArray(data.begin(), data.end()))) << std::endl; + if (bosh) { + std::string line = byteArrayToString(ByteArray(data.begin(), data.end())); + size_t endOfHTTP = line.find("\r\n\r\n"); + if (false && endOfHTTP != std::string::npos) { + /* Disabled because it swallows bits of XML (namespaces, if I recall) */ + std::cerr << line.substr(0, endOfHTTP) << std::endl << beautifier->beautify(line.substr(endOfHTTP)) << std::endl; + } + else { + std::cerr << line << std::endl; + } + } + else { + std::cerr << beautifier->beautify(byteArrayToString(ByteArray(data.begin(), data.end()))) << std::endl; + } } void ClientXMLTracer::printLine(char c) { diff --git a/Swiften/Client/ClientXMLTracer.h b/Swiften/Client/ClientXMLTracer.h index 0752faa..67040c4 100644 --- a/Swiften/Client/ClientXMLTracer.h +++ b/Swiften/Client/ClientXMLTracer.h @@ -13,7 +13,7 @@ namespace Swift { class ClientXMLTracer { public: - ClientXMLTracer(CoreClient* client); + ClientXMLTracer(CoreClient* client, bool bosh = false); ~ClientXMLTracer(); private: void printData(char direction, const SafeByteArray& data); @@ -21,5 +21,6 @@ namespace Swift { private: XMLBeautifier *beautifier; + bool bosh; }; } diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index 08f31a0..cef2b24 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2010 Remko Tronçon + * Copyright (c) 2010-2011 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -70,15 +71,52 @@ void CoreClient::connect(const std::string& host) { proxyConnectionFactories.push_back(new HTTPConnectProxiedConnectionFactory(networkFactories->getConnectionFactory(), networkFactories->getProxyProvider()->getHTTPConnectProxy())); } std::vector connectionFactories(proxyConnectionFactories); - // connectionFactories.push_back(networkFactories->getConnectionFactory()); - connectionFactories.push_back(new BOSHConnectionFactory(networkFactories->getConnectionFactory())); + if (options.boshURL.empty()) { + connectionFactories.push_back(networkFactories->getConnectionFactory()); + connector_ = boost::make_shared(host, networkFactories->getDomainNameResolver(), connectionFactories, networkFactories->getTimerFactory()); + connector_->onConnectFinished.connect(boost::bind(&CoreClient::handleConnectorFinished, this, _1)); + connector_->setTimeoutMilliseconds(60*1000); + connector_->start(); + } + else { + /* Autodiscovery of which proxy works is largely ok with a TCP session, because this is a one-off. With BOSH + * it would be quite painful given that potentially every stanza could be sent on a new connection. + */ + //sessionStream_ = boost::make_shared(boost::make_shared(options.boshURL, networkFactories->getConnectionFactory(), networkFactories->getXMLParserFactory(), networkFactories->getTLSContextFactory()), getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory(), networkFactories->getEventLoop(), host, options.boshHTTPConnectProxyURL, options.boshHTTPConnectProxyAuthID, options.boshHTTPConnectProxyAuthPassword); + sessionStream_ = boost::shared_ptr(new BOSHSessionStream(boost::make_shared(options.boshURL, networkFactories->getConnectionFactory(), networkFactories->getXMLParserFactory(), networkFactories->getTLSContextFactory()), getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory(), networkFactories->getEventLoop(), host, options.boshHTTPConnectProxyURL, options.boshHTTPConnectProxyAuthID, options.boshHTTPConnectProxyAuthPassword)); + sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); + sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); + bindSessionToStream(); + } + +} - connector_ = boost::make_shared(host, networkFactories->getDomainNameResolver(), connectionFactories, networkFactories->getTimerFactory()); - connector_->onConnectFinished.connect(boost::bind(&CoreClient::handleConnectorFinished, this, _1)); - connector_->setTimeoutMilliseconds(60*1000); - connector_->start(); +void CoreClient::bindSessionToStream() { + session_ = ClientSession::create(jid_, sessionStream_); + session_->setCertificateTrustChecker(certificateTrustChecker); + session_->setUseStreamCompression(options.useStreamCompression); + session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); + switch(options.useTLS) { + case ClientOptions::UseTLSWhenAvailable: + session_->setUseTLS(ClientSession::UseTLSWhenAvailable); + break; + case ClientOptions::NeverUseTLS: + session_->setUseTLS(ClientSession::NeverUseTLS); + break; + case ClientOptions::RequireTLS: + session_->setUseTLS(ClientSession::RequireTLS); + break; + } + session_->setUseAcks(options.useAcks); + stanzaChannel_->setSession(session_); + session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); + session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); + session_->start(); } +/** + * Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream. + */ void CoreClient::handleConnectorFinished(boost::shared_ptr connection) { resetConnector(); if (!connection) { @@ -99,26 +137,7 @@ void CoreClient::handleConnectorFinished(boost::shared_ptr connectio sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); - session_ = ClientSession::create(jid_, sessionStream_); - session_->setCertificateTrustChecker(certificateTrustChecker); - session_->setUseStreamCompression(options.useStreamCompression); - session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); - switch(options.useTLS) { - case ClientOptions::UseTLSWhenAvailable: - session_->setUseTLS(ClientSession::UseTLSWhenAvailable); - break; - case ClientOptions::NeverUseTLS: - session_->setUseTLS(ClientSession::NeverUseTLS); - break; - case ClientOptions::RequireTLS: - session_->setUseTLS(ClientSession::RequireTLS); - break; - } - session_->setUseAcks(options.useAcks); - stanzaChannel_->setSession(session_); - session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); - session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); - session_->start(); + bindSessionToStream(); } } @@ -339,9 +358,14 @@ void CoreClient::resetSession() { sessionStream_->onDataRead.disconnect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.disconnect(boost::bind(&CoreClient::handleDataWritten, this, _1)); - sessionStream_.reset(); - connection_->disconnect(); + if (connection_) { + connection_->disconnect(); + } + else if (boost::dynamic_pointer_cast(sessionStream_)) { + sessionStream_->close(); + } + sessionStream_.reset(); connection_.reset(); } diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h index 3c089c1..c231fdc 100644 --- a/Swiften/Client/CoreClient.h +++ b/Swiften/Client/CoreClient.h @@ -29,7 +29,7 @@ namespace Swift { class ClientSession; class StanzaChannel; class Stanza; - class BasicSessionStream; + class SessionStream; class CertificateTrustChecker; class NetworkFactories; class ClientSessionStanzaChannel; @@ -207,6 +207,7 @@ namespace Swift { void handleMessageReceived(boost::shared_ptr); void handleStanzaAcked(boost::shared_ptr); void purgePassword(); + void bindSessionToStream(); void resetConnector(); void resetSession(); @@ -222,7 +223,7 @@ namespace Swift { boost::shared_ptr connector_; std::vector proxyConnectionFactories; boost::shared_ptr connection_; - boost::shared_ptr sessionStream_; + boost::shared_ptr sessionStream_; boost::shared_ptr session_; std::string certificate_; bool disconnectRequested_; diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index e9d1b21..22db8fc 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -403,6 +403,10 @@ class ClientSessionTest : public CppUnit::TestFixture { return boost::shared_ptr(); } + virtual bool supportsZLibCompression() { + return true; + } + virtual void addZLibCompression() { compressed = true; } diff --git a/Swiften/Component/UnitTest/ComponentSessionTest.cpp b/Swiften/Component/UnitTest/ComponentSessionTest.cpp index c27ade5..1541cce 100644 --- a/Swiften/Component/UnitTest/ComponentSessionTest.cpp +++ b/Swiften/Component/UnitTest/ComponentSessionTest.cpp @@ -142,6 +142,10 @@ class ComponentSessionTest : public CppUnit::TestFixture { return boost::shared_ptr(); } + virtual bool supportsZLibCompression() { + return true; + } + virtual void addZLibCompression() { assert(false); } diff --git a/Swiften/Network/BOSHConnection.cpp b/Swiften/Network/BOSHConnection.cpp index 549c652..09548e9 100644 --- a/Swiften/Network/BOSHConnection.cpp +++ b/Swiften/Network/BOSHConnection.cpp @@ -4,129 +4,281 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include "BOSHConnection.h" +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + #include #include +#include #include #include #include #include +#include #include #include -#include +#include +#include namespace Swift { - BOSHConnection::BOSHConnection(ConnectionFactory* connectionFactory) - : connectionFactory_(connectionFactory), server_(HostAddressPort(HostAddress("0.0.0.0"), 0)), sid_() - { - reopenAfterAction = true; - } +BOSHConnection::BOSHConnection(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory) + : boshURL_(boshURL), + connectionFactory_(connectionFactory), + parserFactory_(parserFactory), + sid_(), + waitingForStartResponse_(false), + pending_(false), + tlsFactory_(tlsFactory), + connectionReady_(false) +{ +} - BOSHConnection::~BOSHConnection() { - if (newConnection_) { - newConnection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - newConnection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - } - if (currentConnection_) { - currentConnection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - currentConnection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - } +BOSHConnection::~BOSHConnection() { + if (connection_) { + connection_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); + connection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); } + disconnect(); +} - void BOSHConnection::connect(const HostAddressPort& server) { - server_ = server; - newConnection_ = connectionFactory_->createConnection(); - newConnection_->onConnectFinished.connect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); - newConnection_->onDataRead.connect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - newConnection_->onDisconnected.connect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - SWIFT_LOG(debug) << "connect to server " << server.getAddress().toString() << ":" << server.getPort() << std::endl; - newConnection_->connect(HostAddressPort(HostAddress("85.10.192.88"), 5280)); - } +void BOSHConnection::connect(const HostAddressPort& server) { + /* FIXME: Redundant parameter */ + Connection::ref rawConnection = connectionFactory_->createConnection(); + connection_ = (boshURL_.getScheme() == "https") ? boost::make_shared(rawConnection, tlsFactory_) : rawConnection; + connection_->onConnectFinished.connect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); + connection_->onDataRead.connect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.connect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); + connection_->connect(HostAddressPort(HostAddress(boshURL_.getHost()), boshURL_.getPort())); +} - void BOSHConnection::listen() { - assert(false); +void BOSHConnection::listen() { + assert(false); +} + +void BOSHConnection::disconnect() { + if(connection_) { + connection_->disconnect(); + sid_ = ""; } +} - void BOSHConnection::disconnect() { - if(newConnection_) - newConnection_->disconnect(); +void BOSHConnection::restartStream() { + write(createSafeByteArray(""), true, false); +} + +void BOSHConnection::terminateStream() { + write(createSafeByteArray(""), false, true); +} - if(currentConnection_) - currentConnection_->disconnect(); - } - void BOSHConnection::write(const SafeByteArray& data) { - SWIFT_LOG(debug) << "write data: " << safeByteArrayToString(data) << std::endl; +void BOSHConnection::write(const SafeByteArray& data) { + write(data, false, false); +} + +std::pair BOSHConnection::createHTTPRequest(const SafeByteArray& data, bool streamRestart, bool terminate, long rid, const std::string& sid, const URL& boshURL) { + size_t size; + std::stringstream content; + SafeByteArray contentTail = createSafeByteArray(""); + std::stringstream header; + + content << ""; - void BOSHConnection::handleConnectionConnectFinished(bool error) { - newConnection_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); - if(error) { - onConnectFinished(true); - return; - } + SafeByteArray safeContent = createSafeByteArray(content.str()); + safeContent.insert(safeContent.end(), data.begin(), data.end()); + safeContent.insert(safeContent.end(), contentTail.begin(), contentTail.end()); - if(sid_.size() == 0) { - // Session Creation Request - std::stringstream content; - std::stringstream header; - - content << "\r\n"; - - header << "POST /http-bind HTTP/1.1\r\n" - << "Host: 0x10.de:5280\r\n" - << "Accept-Encoding: deflate\r\n" - << "Content-Type: text/xml; charset=utf-8\r\n" - << "Content-Length: " << content.str().size() << "\r\n\r\n" - << content.str(); - - SWIFT_LOG(debug) << "request: "; - newConnection_->write(createSafeByteArray(header.str())); - } + size = safeContent.size(); + + header << "POST /" << boshURL.getPath() << " HTTP/1.1\r\n" + << "Host: " << boshURL.getHost() << ":" << boshURL.getPort() << "\r\n" + /*<< "Accept-Encoding: deflate\r\n"*/ + << "Content-Type: text/xml; charset=utf-8\r\n" + << "Content-Length: " << size << "\r\n\r\n"; + + SafeByteArray safeHeader = createSafeByteArray(header.str()); + safeHeader.insert(safeHeader.end(), safeContent.begin(), safeContent.end()); + + return std::pair(safeHeader, size); +} + +void BOSHConnection::write(const SafeByteArray& data, bool streamRestart, bool terminate) { + assert(connectionReady_); + assert(!sid_.empty()); + + SafeByteArray safeHeader = createHTTPRequest(data, streamRestart, terminate, rid_, sid_, boshURL_).first; + + onBOSHDataWritten(safeHeader); + connection_->write(safeHeader); + pending_ = true; + + SWIFT_LOG(debug) << "write data: " << safeByteArrayToString(safeHeader) << std::endl; +} + +void BOSHConnection::handleConnectionConnectFinished(bool error) { + connection_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); + connectionReady_ = !error; + onConnectFinished(error); +} + +void BOSHConnection::startStream(const std::string& to, unsigned long rid) { + assert(connectionReady_); + // Session Creation Request + std::stringstream content; + std::stringstream header; + + content << ""; + + std::string contentString = content.str(); + + header << "POST /" << boshURL_.getPath() << " HTTP/1.1\r\n" + << "Host: " << boshURL_.getHost() << ":" << boshURL_.getPort() << "\r\n" + /*<< "Accept-Encoding: deflate\r\n"*/ + << "Content-Type: text/xml; charset=utf-8\r\n" + << "Content-Length: " << contentString.size() << "\r\n\r\n" + << contentString; + + waitingForStartResponse_ = true; + SafeByteArray safeHeader = createSafeByteArray(header.str()); + onBOSHDataWritten(safeHeader); + connection_->write(safeHeader); + SWIFT_LOG(debug) << "write stream header: " << safeByteArrayToString(safeHeader) << std::endl; +} + +void BOSHConnection::handleDataRead(boost::shared_ptr data) { + onBOSHDataRead(*data.get()); + buffer_ = concat(buffer_, *data.get()); + std::string response = safeByteArrayToString(buffer_); + if (response.find("\r\n\r\n") == std::string::npos) { + onBOSHDataRead(createSafeByteArray("[[Previous read incomplete, pending]]")); + return; + } + + std::string httpCode = response.substr(response.find(" ") + 1, 3); + if (httpCode != "200") { + onHTTPError(httpCode); + return; } - void BOSHConnection::handleDataRead(const SafeByteArray& data) { - std::string response = safeByteArrayToString(data); - assert(response.find("\r\n\r\n") != std::string::npos); - - SWIFT_LOG(debug) << "response: " << response.substr(response.find("\r\n\r\n") + 4) << std::endl; - - BOSHParser parser; - if(parser.parse(response.substr(response.find("\r\n\r\n") + 4))) { - sid_ = parser.getAttribute("sid"); - onConnectFinished(false); - int bodyStartElementLength = 0; - bool inQuote = false; - for(size_t i= 0; i < response.size(); i++) { - if(response.c_str()[i] == '\'' || response.c_str()[i] == '"') { - inQuote = !inQuote; - } - else if(!inQuote && response.c_str()[i] == '>') { - bodyStartElementLength = i + 1; - break; - } + BOSHBodyExtractor parser(parserFactory_, createByteArray(response.substr(response.find("\r\n\r\n") + 4))); + if (parser.getBody()) { + if ((*parser.getBody()).attributes.getAttribute("type") == "terminate") { + BOSHError::Type errorType = parseTerminationCondition((*parser.getBody()).attributes.getAttribute("condition")); + onSessionTerminated(errorType == BOSHError::NoError ? boost::shared_ptr() : boost::make_shared(errorType)); + } + buffer_.clear(); + if (waitingForStartResponse_) { + waitingForStartResponse_ = false; + sid_ = (*parser.getBody()).attributes.getAttribute("sid"); + std::string requestsString = (*parser.getBody()).attributes.getAttribute("requests"); + int requests = 2; + if (!requestsString.empty()) { + requests = boost::lexical_cast(requestsString); } - SafeByteArray payload = createSafeByteArray(response.substr(bodyStartElementLength, response.size() - bodyStartElementLength - 7)); - SWIFT_LOG(debug) << "payload: " << safeByteArrayToString(payload) << std::endl; - onDataRead(payload); + onSessionStarted(sid_, requests); } + SafeByteArray payload = createSafeByteArray((*parser.getBody()).content); + /* Say we're good to go again, so don't add anything after here in the method */ + pending_ = false; + onXMPPDataRead(payload); } - void BOSHConnection::handleDisconnected(const boost::optional& error) { - onDisconnected(error); - } +} - HostAddressPort BOSHConnection::getLocalAddress() const { - return newConnection_->getLocalAddress(); +BOSHError::Type BOSHConnection::parseTerminationCondition(const std::string& text) { + BOSHError::Type condition = BOSHError::UndefinedCondition; + if (text == "bad-request") { + condition = BOSHError::BadRequest; + } + else if (text == "host-gone") { + condition = BOSHError::HostGone; + } + else if (text == "host-unknown") { + condition = BOSHError::HostUnknown; + } + else if (text == "improper-addressing") { + condition = BOSHError::ImproperAddressing; + } + else if (text == "internal-server-error") { + condition = BOSHError::InternalServerError; + } + else if (text == "item-not-found") { + condition = BOSHError::ItemNotFound; + } + else if (text == "other-request") { + condition = BOSHError::OtherRequest; + } + else if (text == "policy-violation") { + condition = BOSHError::PolicyViolation; + } + else if (text == "remote-connection-failed") { + condition = BOSHError::RemoteConnectionFailed; + } + else if (text == "remote-stream-error") { + condition = BOSHError::RemoteStreamError; } + else if (text == "see-other-uri") { + condition = BOSHError::SeeOtherURI; + } + else if (text == "system-shutdown") { + condition = BOSHError::SystemShutdown; + } + else if (text == "") { + condition = BOSHError::NoError; + } + return condition; +} + +const std::string& BOSHConnection::getSID() { + return sid_; +} + +void BOSHConnection::setRID(unsigned long rid) { + rid_ = rid; +} + +void BOSHConnection::setSID(const std::string& sid) { + sid_ = sid; +} + +void BOSHConnection::handleDisconnected(const boost::optional& error) { + onDisconnected(error); + sid_ = ""; + connectionReady_ = false; +} + +HostAddressPort BOSHConnection::getLocalAddress() const { + return connection_->getLocalAddress(); +} + +bool BOSHConnection::isReadyToSend() { + /* Without pipelining you need to not send more without first receiving the response */ + /* With pipelining you can. Assuming we can't, here */ + return connectionReady_ && !pending_ && !waitingForStartResponse_ && !sid_.empty(); +} + } diff --git a/Swiften/Network/BOSHConnection.h b/Swiften/Network/BOSHConnection.h index 0da92ba..283ea10 100644 --- a/Swiften/Network/BOSHConnection.h +++ b/Swiften/Network/BOSHConnection.h @@ -4,6 +4,13 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + #pragma once #include @@ -11,6 +18,9 @@ #include #include #include +#include +#include +#include namespace boost { class thread; @@ -21,32 +31,73 @@ namespace boost { namespace Swift { class ConnectionFactory; + class XMLParserFactory; + class TLSContextFactory; + + class BOSHError : public SessionStream::Error { + public: + enum Type {BadRequest, HostGone, HostUnknown, ImproperAddressing, + InternalServerError, ItemNotFound, OtherRequest, PolicyViolation, + RemoteConnectionFailed, RemoteStreamError, SeeOtherURI, SystemShutdown, UndefinedCondition, + NoError}; + BOSHError(Type type) : SessionStream::Error(SessionStream::Error::ConnectionReadError), type(type) {} + Type getType() {return type;} + typedef boost::shared_ptr ref; + private: + Type type; + + }; + class BOSHConnection : public Connection, public boost::enable_shared_from_this { public: typedef boost::shared_ptr ref; - static ref create(ConnectionFactory* connectionFactory) { - return ref(new BOSHConnection(connectionFactory)); + static ref create(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory) { + return ref(new BOSHConnection(boshURL, connectionFactory, parserFactory, tlsFactory)); } virtual ~BOSHConnection(); virtual void listen(); virtual void connect(const HostAddressPort& address); virtual void disconnect(); virtual void write(const SafeByteArray& data); + virtual HostAddressPort getLocalAddress() const; + const std::string& getSID(); + void setRID(unsigned long rid); + void setSID(const std::string& sid); + void startStream(const std::string& to, unsigned long rid); + void terminateStream(); + bool isReadyToSend(); + void restartStream(); + static std::pair createHTTPRequest(const SafeByteArray& data, bool streamRestart, bool terminate, long rid, const std::string& sid, const URL& boshURL); + + boost::signal onSessionTerminated; + boost::signal onSessionStarted; + boost::signal onXMPPDataRead; + boost::signal onBOSHDataRead; + boost::signal onBOSHDataWritten; + boost::signal onHTTPError; private: - BOSHConnection(ConnectionFactory* connectionFactory); + BOSHConnection(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory); + void handleConnectionConnectFinished(bool error); - void handleDataRead(const SafeByteArray& data); + void handleDataRead(boost::shared_ptr data); void handleDisconnected(const boost::optional& error); + void write(const SafeByteArray& data, bool streamRestart, bool terminate); /* FIXME: refactor */ + BOSHError::Type parseTerminationCondition(const std::string& text); - bool reopenAfterAction; + URL boshURL_; ConnectionFactory* connectionFactory_; - HostAddressPort server_; - boost::shared_ptr newConnection_; - boost::shared_ptr currentConnection_; + XMLParserFactory* parserFactory_; + boost::shared_ptr connection_; std::string sid_; + bool waitingForStartResponse_; + unsigned long rid_; + SafeByteArray buffer_; + bool pending_; + TLSContextFactory* tlsFactory_; + bool connectionReady_; }; } diff --git a/Swiften/Network/BOSHConnectionFactory.cpp b/Swiften/Network/BOSHConnectionFactory.cpp index 4c49cae..7b83034 100644 --- a/Swiften/Network/BOSHConnectionFactory.cpp +++ b/Swiften/Network/BOSHConnectionFactory.cpp @@ -4,18 +4,23 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include "BOSHConnectionFactory.h" +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include #include namespace Swift { -BOSHConnectionFactory::BOSHConnectionFactory(ConnectionFactory* connectionFactory) { - connectionFactory_ = connectionFactory; +BOSHConnectionFactory::BOSHConnectionFactory(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* xmlParserFactory, TLSContextFactory* tlsFactory) : boshURL(boshURL), connectionFactory(connectionFactory), xmlParserFactory(xmlParserFactory), tlsFactory(tlsFactory) { } -boost::shared_ptr BOSHConnectionFactory::createConnection() { - return BOSHConnection::create(connectionFactory_); +boost::shared_ptr BOSHConnectionFactory::createConnection(ConnectionFactory* overrideFactory) { + return BOSHConnection::create(boshURL, overrideFactory != NULL ? overrideFactory : connectionFactory, xmlParserFactory, tlsFactory); } } diff --git a/Swiften/Network/BOSHConnectionFactory.h b/Swiften/Network/BOSHConnectionFactory.h index 7431cf4..3750057 100644 --- a/Swiften/Network/BOSHConnectionFactory.h +++ b/Swiften/Network/BOSHConnectionFactory.h @@ -4,19 +4,40 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + #pragma once +#include + #include #include +#include +#include namespace Swift { - class BOSHConnectionFactory : public ConnectionFactory { - public: - BOSHConnectionFactory(ConnectionFactory* connectionFactory); - virtual boost::shared_ptr createConnection(); +class XMLParserFactory; + +class BOSHConnectionFactory { + public: + BOSHConnectionFactory(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* xmlParserFactory, TLSContextFactory* tlsFactory); + + /** + * @param overrideFactory If non-NULL, creates a connection over the given factory instead. + */ + boost::shared_ptr createConnection(ConnectionFactory* overrideFactory); + ConnectionFactory* getRawConnectionFactory() {return connectionFactory;} + TLSContextFactory* getTLSContextFactory() {return tlsFactory;} + private: + URL boshURL; + ConnectionFactory* connectionFactory; + XMLParserFactory* xmlParserFactory; + TLSContextFactory* tlsFactory; +}; - private: - ConnectionFactory* connectionFactory_; - }; } diff --git a/Swiften/Network/BOSHConnectionPool.cpp b/Swiften/Network/BOSHConnectionPool.cpp new file mode 100644 index 0000000..6c3ba7e --- /dev/null +++ b/Swiften/Network/BOSHConnectionPool.cpp @@ -0,0 +1,247 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ +#include + +#include + +#include +#include + +#include +#include +#include +#include + +namespace Swift { +BOSHConnectionPool::BOSHConnectionPool(boost::shared_ptr connectionFactory, const std::string& to, long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword) + : connectionFactory(connectionFactory), + rid(initialRID), + pendingTerminate(false), + to(to), + requestLimit(2), + restartCount(0), + pendingRestart(false) { + tlsConnectionFactory = NULL; + if (boshHTTPConnectProxyURL.empty()) { + connectProxyFactory = NULL; + } + else { + ConnectionFactory* rawFactory = connectionFactory->getRawConnectionFactory(); + if (boshHTTPConnectProxyURL.getScheme() == "https") { + tlsConnectionFactory = new TLSConnectionFactory(connectionFactory->getTLSContextFactory(), rawFactory); + rawFactory = tlsConnectionFactory; + } + connectProxyFactory = new HTTPConnectProxiedConnectionFactory(rawFactory, HostAddressPort(HostAddress(boshHTTPConnectProxyURL.getHost()), boshHTTPConnectProxyURL.getPort()), boshHTTPConnectProxyAuthID, boshHTTPConnectProxyAuthPassword); + } + createConnection(); +} + +BOSHConnectionPool::~BOSHConnectionPool() { + close(); + delete connectProxyFactory; + delete tlsConnectionFactory; +} + +void BOSHConnectionPool::write(const SafeByteArray& data) { + dataQueue.push_back(data); + tryToSendQueuedData(); +} + +void BOSHConnectionPool::handleDataRead(const SafeByteArray& data) { + onXMPPDataRead(data); + tryToSendQueuedData(); /* Will rebalance the connections */ +} + +void BOSHConnectionPool::restartStream() { + BOSHConnection::ref connection = getSuitableConnection(); + if (connection) { + pendingRestart = false; + rid++; + connection->setRID(rid); + connection->restartStream(); + restartCount++; + } + else { + pendingRestart = true; + } +} + +void BOSHConnectionPool::writeFooter() { + pendingTerminate = true; + tryToSendQueuedData(); +} + +void BOSHConnectionPool::close() { + /* TODO: Send a terminate here. */ + std::vector connectionCopies = connections; + foreach (BOSHConnection::ref connection, connectionCopies) { + if (connection) { + connection->disconnect(); + destroyConnection(connection); + } + } +} + +void BOSHConnectionPool::handleSessionStarted(const std::string& sessionID, size_t requests) { + sid = sessionID; + requestLimit = requests; + onSessionStarted(); +} + +void BOSHConnectionPool::handleConnectFinished(bool error, BOSHConnection::ref connection) { + if (error) { + onSessionTerminated(boost::make_shared(BOSHError::UndefinedCondition)); + /*TODO: We can probably manage to not terminate the stream here and use the rid/ack retry + * logic to just swallow the error and try again (some number of tries). + */ + } + else { + if (sid.empty()) { + connection->startStream(to, rid); + } + if (pendingRestart) { + restartStream(); + } + tryToSendQueuedData(); + } +} + +BOSHConnection::ref BOSHConnectionPool::getSuitableConnection() { + BOSHConnection::ref suitableConnection; + foreach (BOSHConnection::ref connection, connections) { + if (connection->isReadyToSend()) { + suitableConnection = connection; + break; + } + } + + if (!suitableConnection && connections.size() < requestLimit) { + /* This is not a suitable connection because it won't have yet connected and added TLS if needed. */ + BOSHConnection::ref newConnection = createConnection(); + newConnection->setSID(sid); + } + assert(connections.size() <= requestLimit); + assert((!suitableConnection) || suitableConnection->isReadyToSend()); + return suitableConnection; +} + +void BOSHConnectionPool::tryToSendQueuedData() { + if (sid.empty()) { + /* If we've not got as far as stream start yet, pend */ + return; + } + + BOSHConnection::ref suitableConnection = getSuitableConnection(); + bool sent = false; + bool toSend = !dataQueue.empty(); + if (suitableConnection) { + if (toSend) { + rid++; + suitableConnection->setRID(rid); + SafeByteArray data; + foreach (const SafeByteArray& datum, dataQueue) { + data.insert(data.end(), datum.begin(), datum.end()); + } + suitableConnection->write(data); + sent = true; + dataQueue.clear(); + } + else if (pendingTerminate) { + rid++; + suitableConnection->setRID(rid); + suitableConnection->terminateStream(); + sent = true; + onSessionTerminated(boost::shared_ptr()); + } + } + if (!pendingTerminate) { + /* Ensure there's always a session waiting to read data for us */ + bool pending = false; + foreach (BOSHConnection::ref connection, connections) { + if (connection && !connection->isReadyToSend()) { + pending = true; + } + } + if (!pending) { + if (restartCount >= 1) { + /* Don't open a second connection until we've restarted the stream twice - i.e. we've authed and resource bound.*/ + if (suitableConnection) { + rid++; + suitableConnection->setRID(rid); + suitableConnection->write(createSafeByteArray("")); + } + else { + /* My thought process I went through when writing this, to aid anyone else confused why this can happen... + * + * What to do here? I think this isn't possible. + If you didn't have two connections, suitable would have made one. + If you have two connections and neither is suitable, pending would be true. + If you have a non-pending connection, it's suitable. + + If I decide to do something here, remove assert above. + + Ah! Yes, because there's a period between creating the connection and it being connected. */ + } + } + } + } +} + +void BOSHConnectionPool::handleHTTPError(const std::string& /*errorCode*/) { + handleSessionTerminated(boost::make_shared(BOSHError::UndefinedCondition)); +} + +void BOSHConnectionPool::handleConnectionDisconnected(const boost::optional& error, BOSHConnection::ref connection) { + destroyConnection(connection); + if (false && error) { + handleSessionTerminated(boost::make_shared(BOSHError::UndefinedCondition)); + } + else { + /* We might have just freed up a connection slot to send with */ + tryToSendQueuedData(); + } +} + +boost::shared_ptr BOSHConnectionPool::createConnection() { + BOSHConnection::ref connection = boost::dynamic_pointer_cast(connectionFactory->createConnection(connectProxyFactory)); + connection->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); + connection->onSessionStarted.connect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); + connection->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); + connection->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); + connection->onDisconnected.connect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); + connection->onConnectFinished.connect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); + connection->onSessionTerminated.connect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); + connection->onHTTPError.connect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); + connection->connect(HostAddressPort(HostAddress("0.0.0.0"), 0)); + connections.push_back(connection); + return connection; +} + +void BOSHConnectionPool::destroyConnection(boost::shared_ptr connection) { + connections.erase(std::remove(connections.begin(), connections.end(), connection), connections.end()); + connection->onXMPPDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); + connection->onSessionStarted.disconnect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); + connection->onBOSHDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); + connection->onBOSHDataWritten.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); + connection->onDisconnected.disconnect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); + connection->onConnectFinished.disconnect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); + connection->onSessionTerminated.disconnect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); + connection->onHTTPError.disconnect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); +} + +void BOSHConnectionPool::handleSessionTerminated(BOSHError::ref error) { + onSessionTerminated(error); +} + +void BOSHConnectionPool::handleBOSHDataRead(const SafeByteArray& data) { + onBOSHDataRead(data); +} + +void BOSHConnectionPool::handleBOSHDataWritten(const SafeByteArray& data) { + onBOSHDataWritten(data); +} + +} diff --git a/Swiften/Network/BOSHConnectionPool.h b/Swiften/Network/BOSHConnectionPool.h new file mode 100644 index 0000000..85e598d --- /dev/null +++ b/Swiften/Network/BOSHConnectionPool.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + +#pragma once + +#include + +#include +#include +#include + +namespace Swift { + class HTTPConnectProxiedConnectionFactory; + class TLSConnectionFactory; + class BOSHConnectionPool : public boost::bsignals::trackable { + public: + BOSHConnectionPool(boost::shared_ptr factory, const std::string& to, long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword); + ~BOSHConnectionPool(); + void write(const SafeByteArray& data); + void writeFooter(); + void close(); + void restartStream(); + + boost::signal onSessionTerminated; + boost::signal onSessionStarted; + boost::signal onXMPPDataRead; + boost::signal onBOSHDataRead; + boost::signal onBOSHDataWritten; + + private: + void handleDataRead(const SafeByteArray& data); + void handleSessionStarted(const std::string& sid, size_t requests); + void handleBOSHDataRead(const SafeByteArray& data); + void handleBOSHDataWritten(const SafeByteArray& data); + void handleSessionTerminated(BOSHError::ref condition); + void handleConnectFinished(bool, BOSHConnection::ref connection); + void handleConnectionDisconnected(const boost::optional& error, BOSHConnection::ref connection); + void handleHTTPError(const std::string& errorCode); + + private: + BOSHConnection::ref createConnection(); + void destroyConnection(BOSHConnection::ref connection); + void tryToSendQueuedData(); + BOSHConnection::ref getSuitableConnection(); + + private: + boost::shared_ptr connectionFactory; + std::vector connections; + std::string sid; + unsigned long rid; + std::vector dataQueue; + bool pendingTerminate; + std::string to; + size_t requestLimit; + int restartCount; + bool pendingRestart; + HTTPConnectProxiedConnectionFactory* connectProxyFactory; + TLSConnectionFactory* tlsConnectionFactory; + }; +} diff --git a/Swiften/Network/BoostNetworkFactories.cpp b/Swiften/Network/BoostNetworkFactories.cpp index 2b4c04b..488e519 100644 --- a/Swiften/Network/BoostNetworkFactories.cpp +++ b/Swiften/Network/BoostNetworkFactories.cpp @@ -17,7 +17,7 @@ namespace Swift { -BoostNetworkFactories::BoostNetworkFactories(EventLoop* eventLoop) { +BoostNetworkFactories::BoostNetworkFactories(EventLoop* eventLoop) : eventLoop(eventLoop){ timerFactory = new BoostTimerFactory(ioServiceThread.getIOService(), eventLoop); connectionFactory = new BoostConnectionFactory(ioServiceThread.getIOService(), eventLoop); domainNameResolver = new PlatformDomainNameResolver(eventLoop); diff --git a/Swiften/Network/BoostNetworkFactories.h b/Swiften/Network/BoostNetworkFactories.h index 3d268d1..c9b12da 100644 --- a/Swiften/Network/BoostNetworkFactories.h +++ b/Swiften/Network/BoostNetworkFactories.h @@ -53,6 +53,10 @@ namespace Swift { return proxyProvider; } + virtual EventLoop* getEventLoop() const { + return eventLoop; + } + private: BoostIOServiceThread ioServiceThread; TimerFactory* timerFactory; @@ -63,5 +67,6 @@ namespace Swift { XMLParserFactory* xmlParserFactory; PlatformTLSFactories* tlsFactories; ProxyProvider* proxyProvider; + EventLoop* eventLoop; }; } diff --git a/Swiften/Network/HTTPConnectProxiedConnection.cpp b/Swiften/Network/HTTPConnectProxiedConnection.cpp index e05a933..3e6c986 100644 --- a/Swiften/Network/HTTPConnectProxiedConnection.cpp +++ b/Swiften/Network/HTTPConnectProxiedConnection.cpp @@ -4,6 +4,13 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + #include #include @@ -11,15 +18,17 @@ #include #include +#include #include #include #include #include #include +#include using namespace Swift; -HTTPConnectProxiedConnection::HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy) : connectionFactory_(connectionFactory), proxy_(proxy), server_(HostAddressPort(HostAddress("0.0.0.0"), 0)) { +HTTPConnectProxiedConnection::HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy, const SafeString& authID, const SafeString& authPassword) : connectionFactory_(connectionFactory), proxy_(proxy), server_(HostAddressPort(HostAddress("0.0.0.0"), 0)), authID_(authID), authPassword_(authPassword) { connected_ = false; } @@ -65,8 +74,18 @@ void HTTPConnectProxiedConnection::handleConnectionConnectFinished(bool error) { connection_->onConnectFinished.disconnect(boost::bind(&HTTPConnectProxiedConnection::handleConnectionConnectFinished, shared_from_this(), _1)); if (!error) { std::stringstream connect; - connect << "CONNECT " << server_.getAddress().toString() << ":" << server_.getPort() << " HTTP/1.1\r\n\r\n"; - connection_->write(createSafeByteArray(connect.str())); + connect << "CONNECT " << server_.getAddress().toString() << ":" << server_.getPort() << " HTTP/1.1\r\n"; + SafeByteArray data = createSafeByteArray(connect.str()); + if (!authID_.empty() && !authPassword_.empty()) { + append(data, createSafeByteArray("Proxy-Authorization: Basic ")); + SafeByteArray credentials = authID_; + append(credentials, createSafeByteArray(":")); + append(credentials, authPassword_); + append(data, Base64::encode(credentials)); + append(data, createSafeByteArray("\r\n")); + } + append(data, createSafeByteArray("\r\n")); + connection_->write(data); } else { onConnectFinished(true); diff --git a/Swiften/Network/HTTPConnectProxiedConnection.h b/Swiften/Network/HTTPConnectProxiedConnection.h index d3f5b7a..02d3edd 100644 --- a/Swiften/Network/HTTPConnectProxiedConnection.h +++ b/Swiften/Network/HTTPConnectProxiedConnection.h @@ -4,12 +4,20 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + #pragma once #include #include #include +#include namespace boost { class thread; @@ -27,8 +35,8 @@ namespace Swift { ~HTTPConnectProxiedConnection(); - static ref create(ConnectionFactory* connectionFactory, HostAddressPort proxy) { - return ref(new HTTPConnectProxiedConnection(connectionFactory, proxy)); + static ref create(ConnectionFactory* connectionFactory, HostAddressPort proxy, const SafeString& authID, const SafeString& authPassword) { + return ref(new HTTPConnectProxiedConnection(connectionFactory, proxy, authID, authPassword)); } virtual void listen(); @@ -38,7 +46,7 @@ namespace Swift { virtual HostAddressPort getLocalAddress() const; private: - HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy); + HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy, const SafeString& authID, const SafeString& authPassword); void handleConnectionConnectFinished(bool error); void handleDataRead(boost::shared_ptr data); @@ -49,6 +57,8 @@ namespace Swift { ConnectionFactory* connectionFactory_; HostAddressPort proxy_; HostAddressPort server_; + SafeByteArray authID_; + SafeByteArray authPassword_; boost::shared_ptr connection_; }; } diff --git a/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp b/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp index ab7f18e..6ad0228 100644 --- a/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp +++ b/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp @@ -10,11 +10,15 @@ namespace Swift { -HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy) : connectionFactory_(connectionFactory), proxy_(proxy) { +HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy) : connectionFactory_(connectionFactory), proxy_(proxy), authID_(""), authPassword_("") { +} + + +HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy, const SafeString& authID, const SafeString& authPassword) : connectionFactory_(connectionFactory), proxy_(proxy), authID_(authID), authPassword_(authPassword) { } boost::shared_ptr HTTPConnectProxiedConnectionFactory::createConnection() { - return HTTPConnectProxiedConnection::create(connectionFactory_, proxy_); + return HTTPConnectProxiedConnection::create(connectionFactory_, proxy_, authID_, authPassword_); } } diff --git a/Swiften/Network/HTTPConnectProxiedConnectionFactory.h b/Swiften/Network/HTTPConnectProxiedConnectionFactory.h index b475586..ef3af66 100644 --- a/Swiften/Network/HTTPConnectProxiedConnectionFactory.h +++ b/Swiften/Network/HTTPConnectProxiedConnectionFactory.h @@ -8,16 +8,20 @@ #include #include +#include namespace Swift { class HTTPConnectProxiedConnectionFactory : public ConnectionFactory { public: HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy); + HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy, const SafeString& authID, const SafeString& authPassword); virtual boost::shared_ptr createConnection(); private: ConnectionFactory* connectionFactory_; HostAddressPort proxy_; + SafeString authID_; + SafeString authPassword_; }; } diff --git a/Swiften/Network/NetworkFactories.h b/Swiften/Network/NetworkFactories.h index 6eba2f3..ebb6d62 100644 --- a/Swiften/Network/NetworkFactories.h +++ b/Swiften/Network/NetworkFactories.h @@ -16,6 +16,7 @@ namespace Swift { class TLSContextFactory; class CertificateFactory; class ProxyProvider; + class EventLoop; /** * An interface collecting network factories. @@ -32,5 +33,6 @@ namespace Swift { virtual XMLParserFactory* getXMLParserFactory() const = 0; virtual TLSContextFactory* getTLSContextFactory() const = 0; virtual ProxyProvider* getProxyProvider() const = 0; + virtual EventLoop* getEventLoop() const {}; }; } diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index 399cec8..4a5370f 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -16,7 +16,8 @@ sourceList = [ "BoostConnectionServerFactory.cpp", "BoostIOServiceThread.cpp", "BOSHConnection.cpp", - "BOSHConnectionFactory.cpp" + "BOSHConnectionPool.cpp", + "BOSHConnectionFactory.cpp", "ConnectionFactory.cpp", "ConnectionServer.cpp", "ConnectionServerFactory.cpp", @@ -41,6 +42,8 @@ sourceList = [ "BoostNetworkFactories.cpp", "NetworkEnvironment.cpp", "Timer.cpp", + "TLSConnection.cpp", + "TLSConnectionFactory.cpp", "BoostTimer.cpp", "ProxyProvider.cpp", "NullProxyProvider.cpp", diff --git a/Swiften/Network/TLSConnection.cpp b/Swiften/Network/TLSConnection.cpp new file mode 100644 index 0000000..543ee1e --- /dev/null +++ b/Swiften/Network/TLSConnection.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + +#include + +#include +#include +#include + +namespace Swift { + +TLSConnection::TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory) : connection(connection) { + context = tlsFactory->createTLSContext(); + context->onDataForNetwork.connect(boost::bind(&TLSConnection::handleTLSDataForNetwork, this, _1)); + context->onDataForApplication.connect(boost::bind(&TLSConnection::handleTLSDataForApplication, this, _1)); + context->onConnected.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, false)); + context->onError.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, true)); + + connection->onConnectFinished.connect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + connection->onDataRead.connect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); + connection->onDataWritten.connect(boost::bind(&TLSConnection::handleRawDataWritten, this)); + connection->onDisconnected.connect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); +} + +TLSConnection::~TLSConnection() { + connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + connection->onDataRead.disconnect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); + connection->onDataWritten.disconnect(boost::bind(&TLSConnection::handleRawDataWritten, this)); + connection->onDisconnected.disconnect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); + delete context; +} + +void TLSConnection::handleTLSConnectFinished(bool error) { + onConnectFinished(error); + if (error) { + disconnect(); + } +} + +void TLSConnection::handleTLSDataForNetwork(const SafeByteArray& data) { + connection->write(data); +} + +void TLSConnection::handleTLSDataForApplication(const SafeByteArray& data) { + onDataRead(boost::make_shared(data)); +} + +void TLSConnection::connect(const HostAddressPort& address) { + connection->connect(address); +} + +void TLSConnection::disconnect() { + connection->disconnect(); +} + +void TLSConnection::write(const SafeByteArray& data) { + context->handleDataFromApplication(data); +} + +HostAddressPort TLSConnection::getLocalAddress() const { + return connection->getLocalAddress(); +} + +void TLSConnection::handleRawConnectFinished(bool error) { + connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + if (error) { + onConnectFinished(true); + } + else { + context->connect(); + } +} + +void TLSConnection::handleRawDisconnected(const boost::optional& error) { + onDisconnected(error); +} + +void TLSConnection::handleRawDataRead(boost::shared_ptr data) { + context->handleDataFromNetwork(*data); +} + +void TLSConnection::handleRawDataWritten() { + onDataWritten(); +} + +} diff --git a/Swiften/Network/TLSConnection.h b/Swiften/Network/TLSConnection.h new file mode 100644 index 0000000..a798393 --- /dev/null +++ b/Swiften/Network/TLSConnection.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace Swift { + class HostAddressPort; + class TLSContextFactory; + class TLSContext; + + class TLSConnection : public Connection { + public: + + TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory); + virtual ~TLSConnection(); + + virtual void listen() {assert(false);}; + virtual void connect(const HostAddressPort& address); + virtual void disconnect(); + virtual void write(const SafeByteArray& data); + + virtual HostAddressPort getLocalAddress() const; + + private: + void handleRawConnectFinished(bool error); + void handleRawDisconnected(const boost::optional& error); + void handleRawDataRead(boost::shared_ptr data); + void handleRawDataWritten(); + void handleTLSConnectFinished(bool error); + void handleTLSDataForNetwork(const SafeByteArray& data); + void handleTLSDataForApplication(const SafeByteArray& data); + private: + TLSContext* context; + Connection::ref connection; + }; +} diff --git a/Swiften/Network/TLSConnectionFactory.cpp b/Swiften/Network/TLSConnectionFactory.cpp new file mode 100644 index 0000000..0c21650 --- /dev/null +++ b/Swiften/Network/TLSConnectionFactory.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + +#include + +#include + +namespace Swift { + +TLSConnectionFactory::TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory) : contextFactory(contextFactory), connectionFactory(connectionFactory){ + +} + +TLSConnectionFactory::~TLSConnectionFactory() { + +} + + +boost::shared_ptr TLSConnectionFactory::createConnection() { + return boost::make_shared(connectionFactory->createConnection(), contextFactory); +} + +} diff --git a/Swiften/Network/TLSConnectionFactory.h b/Swiften/Network/TLSConnectionFactory.h new file mode 100644 index 0000000..32757a1 --- /dev/null +++ b/Swiften/Network/TLSConnectionFactory.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +#include +#include + +namespace Swift { + class Connection; + + class TLSConnectionFactory : public ConnectionFactory { + public: + TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory); + virtual ~TLSConnectionFactory(); + + virtual boost::shared_ptr createConnection(); + private: + TLSContextFactory* contextFactory; + ConnectionFactory* connectionFactory; + }; +} diff --git a/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp b/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp new file mode 100644 index 0000000..978bf3b --- /dev/null +++ b/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp @@ -0,0 +1,423 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Swift; + +typedef boost::shared_ptr PoolRef; + +class BOSHConnectionPoolTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(BOSHConnectionPoolTest); + CPPUNIT_TEST(testConnectionCount_OneWrite); + CPPUNIT_TEST(testConnectionCount_TwoWrites); + CPPUNIT_TEST(testConnectionCount_ThreeWrites); + CPPUNIT_TEST(testConnectionCount_ThreeWrites_ManualConnect); + CPPUNIT_TEST(testConnectionCount_ThreeWritesTwoReads); + CPPUNIT_TEST(testSession); + CPPUNIT_TEST(testWrite_Empty); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + to = "wonderland.lit"; + path = "http-bind"; + port = "5280"; + sid = "MyShinySID"; + initial = ""; + eventLoop = new DummyEventLoop(); + connectionFactory = new MockConnectionFactory(eventLoop); + factory = boost::make_shared(URL("http", to, 5280, path), connectionFactory, &parserFactory, static_cast(NULL)); + sessionTerminated = 0; + sessionStarted = 0; + initialRID = 2349876; + xmppDataRead.clear(); + boshDataRead.clear(); + boshDataWritten.clear(); + } + + void tearDown() { + eventLoop->processEvents(); + delete connectionFactory; + delete eventLoop; + } + + void testConnectionCount_OneWrite() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(0, sessionStarted); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(1, sessionStarted); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + testling->write(createSafeByteArray("")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(1, sessionStarted); + } + + void testConnectionCount_TwoWrites() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + readResponse(initial, connectionFactory->connections[0]); + testling->write(createSafeByteArray("")); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + testling->write(createSafeByteArray("")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWrites() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + readResponse(initial, connectionFactory->connections[0]); + testling->restartStream(); + readResponse("", connectionFactory->connections[0]); + testling->restartStream(); + readResponse("", connectionFactory->connections[0]); + testling->write(createSafeByteArray("")); + testling->write(createSafeByteArray("")); + testling->write(createSafeByteArray("")); + eventLoop->processEvents(); + CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWrites_ManualConnect() { + connectionFactory->autoFinishConnect = false; + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(0), boshDataWritten.size()); /* Connection not connected yet, can't send data */ + + connectionFactory->connections[0]->onConnectFinished(false); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Connection finished, stream header sent */ + + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Don't respond to initial data with a holding call */ + + testling->restartStream(); + readResponse("", connectionFactory->connections[0]); + testling->restartStream(); + + + testling->write(createSafeByteArray("")); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); /* New connection isn't up yet. */ + + connectionFactory->connections[1]->onConnectFinished(false); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New connection ready. */ + + testling->write(createSafeByteArray("")); + testling->write(createSafeByteArray("")); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New data can't be sent, no free connections. */ + eventLoop->processEvents(); + CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWritesTwoReads() { + boost::shared_ptr c0; + boost::shared_ptr c1; + long rid = initialRID; + + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + c0 = connectionFactory->connections[0]; + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ + + rid++; + readResponse(initial, c0); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT(!c0->pending); + + rid++; + testling->restartStream(); + readResponse("", connectionFactory->connections[0]); + + rid++; + testling->write(createSafeByteArray("")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); /* 0 was waiting for response, open and send on 1 */ + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* data */ + c1 = connectionFactory->connections[1]; + std::string fullBody = ""; /* check empty write */ + CPPUNIT_ASSERT_EQUAL(fullBody, lastBody()); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + + + rid++; + readResponse("", c0); /* Doesn't include necessary attributes - as the support is improved this'll start to fail */ + eventLoop->processEvents(); + CPPUNIT_ASSERT(!c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* don't send empty in [0], still have [1] waiting */ + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + + rid++; + readResponse("", c1); + eventLoop->processEvents(); + CPPUNIT_ASSERT(!c1->pending); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT_EQUAL(st(5), boshDataWritten.size()); /* empty to make room */ + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + + rid++; + testling->write(createSafeByteArray("")); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* data */ + + rid++; + testling->write(createSafeByteArray("")); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* Don't send data, no room */ + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + } + + void testSession() { + to = "prosody.doomsong.co.uk"; + path = "http-bind/"; + factory = boost::make_shared(URL("http", to, 5280, path), connectionFactory, &parserFactory, static_cast(NULL)); + + + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + std::string response = "SCRAM-SHA-1DIGEST-MD5"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + std::string send = "biwsbj1hZG1pbixyPWZhOWE5ZDhiLWZmMDctNGE4Yy04N2E3LTg4YWRiNDQxZGUwYg=="; + testling->write(createSafeByteArray(send)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + response = "cj1mYTlhOWQ4Yi1mZjA3LTRhOGMtODdhNy04OGFkYjQ0MWRlMGJhZmZlMWNhMy1mMDJkLTQ5NzEtYjkyNS0yM2NlNWQ2MDQyMjYscz1OVGd5WkdWaFptTXRaVE15WXkwMFpXUmhMV0ZqTURRdFpqYzRNbUppWmpGa1pqWXgsaT00MDk2"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + send = "Yz1iaXdzLHI9ZmE5YTlkOGItZmYwNy00YThjLTg3YTctODhhZGI0NDFkZTBiYWZmZTFjYTMtZjAyZC00OTcxLWI5MjUtMjNjZTVkNjA0MjI2LHA9aU11NWt3dDN2VWplU2RqL01Jb3VIRldkZjBnPQ=="; + testling->write(createSafeByteArray(send)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + response = "dj1YNmNBY3BBOWxHNjNOOXF2bVQ5S0FacERrVm89"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + } + + void testWrite_Empty() { + boost::shared_ptr c0; + + PoolRef testling = createTestling(); + c0 = connectionFactory->connections[0]; + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + + readResponse(initial, c0); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Shouldn't have sent anything extra */ + testling->restartStream(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + readResponse("", c0); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + std::string fullBody = ""; + std::string response = boshDataWritten[2]; + size_t bodyPosition = response.find("\r\n\r\n"); + CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); + + + } + + private: + + PoolRef createTestling() { + PoolRef pool = boost::make_shared(factory, to, initialRID, URL(), "", ""); + pool->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleXMPPDataRead, this, _1)); + pool->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataRead, this, _1)); + pool->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataWritten, this, _1)); + pool->onSessionStarted.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionStarted, this)); + pool->onSessionTerminated.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionTerminated, this)); + return pool; + } + + std::string lastBody() { + std::string response = boshDataWritten[boshDataWritten.size() - 1]; + size_t bodyPosition = response.find("\r\n\r\n"); + return response.substr(bodyPosition+4); + } + + size_t st(int val) { + return static_cast(val); + } + + void handleXMPPDataRead(const SafeByteArray& d) { + xmppDataRead.push_back(safeByteArrayToString(d)); + } + + void handleBOSHDataRead(const SafeByteArray& d) { + boshDataRead.push_back(safeByteArrayToString(d)); + } + + void handleBOSHDataWritten(const SafeByteArray& d) { + boshDataWritten.push_back(safeByteArrayToString(d)); + } + + + void handleSessionStarted() { + sessionStarted++; + } + + void handleSessionTerminated() { + sessionTerminated++; + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector& failingPorts, EventLoop* eventLoop, bool autoFinishConnect) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false), pending(false), autoFinishConnect(autoFinishConnect) { + } + + void listen() { assert(false); } + + void connect(const HostAddressPort& address) { + hostAddressPort = address; + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + if (autoFinishConnect) { + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + + void disconnect() { + disconnected = true; + onDisconnected(boost::optional()); + } + + void write(const SafeByteArray& d) { + append(dataWritten, d); + pending = true; + } + + EventLoop* eventLoop; + boost::optional hostAddressPort; + std::vector failingPorts; + ByteArray dataWritten; + bool disconnected; + bool pending; + bool autoFinishConnect; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop, bool autoFinishConnect = true) : eventLoop(eventLoop), autoFinishConnect(autoFinishConnect) { + } + + boost::shared_ptr createConnection() { + boost::shared_ptr connection = boost::make_shared(failingPorts, eventLoop, autoFinishConnect); + connections.push_back(connection); + return connection; + } + + EventLoop* eventLoop; + std::vector< boost::shared_ptr > connections; + std::vector failingPorts; + bool autoFinishConnect; + }; + + void readResponse(const std::string& response, boost::shared_ptr connection) { + connection->pending = false; + boost::shared_ptr data1 = boost::make_shared(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: ")); + connection->onDataRead(data1); + boost::shared_ptr data2 = boost::make_shared(createSafeByteArray(boost::lexical_cast(response.size()))); + connection->onDataRead(data2); + boost::shared_ptr data3 = boost::make_shared(createSafeByteArray("\r\n\r\n")); + connection->onDataRead(data3); + boost::shared_ptr data4 = boost::make_shared(createSafeByteArray(response)); + connection->onDataRead(data4); + } + + std::string fullRequestFor(const std::string& data) { + std::string body = data; + std::string result = "POST /" + path + " HTTP/1.1\r\n" + + "Host: " + to + ":" + port + "\r\n" + + "Content-Type: text/xml; charset=utf-8\r\n" + + "Content-Length: " + boost::lexical_cast(body.size()) + "\r\n\r\n" + + body; + return result; + } + + private: + DummyEventLoop* eventLoop; + MockConnectionFactory* connectionFactory; + std::vector xmppDataRead; + std::vector boshDataRead; + std::vector boshDataWritten; + PlatformXMLParserFactory parserFactory; + std::string to; + std::string path; + std::string port; + std::string sid; + std::string initial; + long initialRID; + boost::shared_ptr factory; + int sessionStarted; + int sessionTerminated; + +}; + + +CPPUNIT_TEST_SUITE_REGISTRATION(BOSHConnectionPoolTest); diff --git a/Swiften/Network/UnitTest/BOSHConnectionTest.cpp b/Swiften/Network/UnitTest/BOSHConnectionTest.cpp new file mode 100644 index 0000000..9215725 --- /dev/null +++ b/Swiften/Network/UnitTest/BOSHConnectionTest.cpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace Swift; + +class BOSHConnectionTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(BOSHConnectionTest); + CPPUNIT_TEST(testHeader); + CPPUNIT_TEST(testReadiness_ok); + CPPUNIT_TEST(testReadiness_pending); + CPPUNIT_TEST(testReadiness_disconnect); + CPPUNIT_TEST(testReadiness_noSID); + CPPUNIT_TEST(testWrite_Receive); + CPPUNIT_TEST(testWrite_ReceiveTwice); + CPPUNIT_TEST(testRead_Fragment); + CPPUNIT_TEST(testHTTPRequest); + CPPUNIT_TEST(testHTTPRequest_Empty); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + eventLoop = new DummyEventLoop(); + connectionFactory = new MockConnectionFactory(eventLoop); + connectFinished = false; + disconnected = false; + dataRead.clear(); + } + + void tearDown() { + eventLoop->processEvents(); + delete connectionFactory; + delete eventLoop; + } + + void testHeader() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->startStream("wonderland.lit", 1); + std::string initial(""); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("MyShinySID"), sid); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_ok() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("blahhhhh"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_pending() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + testling->write(createSafeByteArray("")); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + readResponse("", connectionFactory->connections[0]); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_disconnect() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + connectionFactory->connections[0]->onDisconnected(false); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + } + + + void testReadiness_noSID() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + } + + void testWrite_Receive() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + testling->write(createSafeByteArray("")); + readResponse("", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string(""), byteArrayToString(dataRead)); + + } + + void testWrite_ReceiveTwice() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + testling->write(createSafeByteArray("")); + readResponse("", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string(""), byteArrayToString(dataRead)); + dataRead.clear(); + testling->write(createSafeByteArray("")); + readResponse("", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string(""), byteArrayToString(dataRead)); + } + + void testRead_Fragment() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(static_cast(1), connectionFactory->connections.size()); + boost::shared_ptr connection = connectionFactory->connections[0]; + boost::shared_ptr data1 = boost::make_shared(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: 64\r\n")); + boost::shared_ptr data2 = boost::make_shared(createSafeByteArray( + "\r\n" + " data3 = boost::make_shared(createSafeByteArray( + "ah/>" + "")); + connection->onDataRead(data1); + connection->onDataRead(data2); + CPPUNIT_ASSERT(dataRead.empty()); + connection->onDataRead(data3); + CPPUNIT_ASSERT_EQUAL(std::string(""), byteArrayToString(dataRead)); + } + + void testHTTPRequest() { + std::string data = ""; + std::string sid = "wigglebloom"; + std::string fullBody = "" + data + ""; + std::pair http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 20, sid, URL()); + CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); + } + + void testHTTPRequest_Empty() { + std::string data = ""; + std::string sid = "wigglebloomsickle"; + std::string fullBody = "" + data + ""; + std::pair http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 42, sid, URL()); + CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); + std::string response = safeByteArrayToString(http.first); + size_t bodyPosition = response.find("\r\n\r\n"); + CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); + } + + private: + + BOSHConnection::ref createTestling() { + BOSHConnection::ref c = BOSHConnection::create(URL("http", "wonderland.lit", 5280, "http-bind"), connectionFactory, &parserFactory, static_cast(NULL)); + c->onConnectFinished.connect(boost::bind(&BOSHConnectionTest::handleConnectFinished, this, _1)); + c->onDisconnected.connect(boost::bind(&BOSHConnectionTest::handleDisconnected, this, _1)); + c->onXMPPDataRead.connect(boost::bind(&BOSHConnectionTest::handleDataRead, this, _1)); + c->onSessionStarted.connect(boost::bind(&BOSHConnectionTest::handleSID, this, _1)); + c->setRID(42); + return c; + } + + void handleConnectFinished(bool error) { + connectFinished = true; + connectFinishedWithError = error; + } + + void handleDisconnected(const boost::optional& e) { + disconnected = true; + disconnectedError = e; + } + + void handleDataRead(const SafeByteArray& d) { + append(dataRead, d); + } + + void handleSID(const std::string& s) { + sid = s; + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector& failingPorts, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false) { + } + + void listen() { assert(false); } + + void connect(const HostAddressPort& address) { + hostAddressPort = address; + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + + void disconnect() { + disconnected = true; + onDisconnected(boost::optional()); + } + + void write(const SafeByteArray& d) { + append(dataWritten, d); + } + + EventLoop* eventLoop; + boost::optional hostAddressPort; + std::vector failingPorts; + ByteArray dataWritten; + bool disconnected; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop) { + } + + boost::shared_ptr createConnection() { + boost::shared_ptr connection = boost::make_shared(failingPorts, eventLoop); + connections.push_back(connection); + return connection; + } + + EventLoop* eventLoop; + std::vector< boost::shared_ptr > connections; + std::vector failingPorts; + }; + + void readResponse(const std::string& response, boost::shared_ptr connection) { + boost::shared_ptr data1 = boost::make_shared(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: ")); + connection->onDataRead(data1); + boost::shared_ptr data2 = boost::make_shared(createSafeByteArray(boost::lexical_cast(response.size()))); + connection->onDataRead(data2); + boost::shared_ptr data3 = boost::make_shared(createSafeByteArray("\r\n\r\n")); + connection->onDataRead(data3); + boost::shared_ptr data4 = boost::make_shared(createSafeByteArray(response)); + connection->onDataRead(data4); + } + + + private: + DummyEventLoop* eventLoop; + MockConnectionFactory* connectionFactory; + bool connectFinished; + bool connectFinishedWithError; + bool disconnected; + boost::optional disconnectedError; + ByteArray dataRead; + PlatformXMLParserFactory parserFactory; + std::string sid; + +}; + + +CPPUNIT_TEST_SUITE_REGISTRATION(BOSHConnectionTest); diff --git a/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp b/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp index 133773f..c0252d4 100644 --- a/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp +++ b/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp @@ -163,7 +163,7 @@ class HTTPConnectProxiedConnectionTest : public CppUnit::TestFixture { private: HTTPConnectProxiedConnection::ref createTestling() { - boost::shared_ptr c = HTTPConnectProxiedConnection::create(connectionFactory, proxyHost); + boost::shared_ptr c = HTTPConnectProxiedConnection::create(connectionFactory, proxyHost, "", ""); c->onConnectFinished.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleConnectFinished, this, _1)); c->onDisconnected.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDisconnected, this, _1)); c->onDataRead.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDataRead, this, _1)); diff --git a/Swiften/Parser/BOSHBodyExtractor.cpp b/Swiften/Parser/BOSHBodyExtractor.cpp index d8759a3..eeebe8a 100644 --- a/Swiften/Parser/BOSHBodyExtractor.cpp +++ b/Swiften/Parser/BOSHBodyExtractor.cpp @@ -124,6 +124,8 @@ BOSHBodyExtractor::BOSHBodyExtractor(XMLParserFactory* parserFactory, const Byte BOSHBodyParserClient parserClient(this); boost::shared_ptr parser(parserFactory->createXMLParser(&parserClient)); if (!parser->parse(std::string(reinterpret_cast(vecptr(data)), std::distance(data.begin(), i)))) { + /* TODO: This needs to be only validating the BOSH element, so that XMPP parsing errors are caught at + the correct higher layer */ body = boost::optional(); return; } diff --git a/Swiften/Parser/BOSHParser.cpp b/Swiften/Parser/BOSHParser.cpp deleted file mode 100644 index 9fb218a..0000000 --- a/Swiften/Parser/BOSHParser.cpp +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2011 Thilo Cestonaro - * Licensed under the simplified BSD license. - * See Documentation/Licenses/BSD-simplified.txt for more information. - */ - -#include - -#include -#include -#include - -namespace Swift { - -BOSHParser::BOSHParser() : - xmlParser_(0), - level_(-1), - parseErrorOccurred_(false) -{ - xmlParser_ = PlatformXMLParserFactory().createXMLParser(this); -} - -BOSHParser::~BOSHParser() { - delete xmlParser_; -} - -bool BOSHParser::parse(const std::string& data) { - bool xmlParseResult = xmlParser_->parse(data); - return xmlParseResult && !parseErrorOccurred_; -} - -void BOSHParser::handleStartElement(const std::string& /*element*/, const std::string& /*ns*/, const AttributeMap& attributes) { - if (!parseErrorOccurred_) { - if (level_ == BoshTopLevel) { - boshBodyAttributes_ = attributes; - } - } - ++level_; -} - -void BOSHParser::handleEndElement(const std::string& /*element*/, const std::string& /*ns*/) { - assert(level_ > BoshTopLevel); - --level_; - if (!parseErrorOccurred_) { - - } -} - -void BOSHParser::handleCharacterData(const std::string& /*data*/) { - if (!parseErrorOccurred_) { - - } -} - -} diff --git a/Swiften/Parser/BOSHParser.h b/Swiften/Parser/BOSHParser.h deleted file mode 100644 index 69b3d13..0000000 --- a/Swiften/Parser/BOSHParser.h +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2011 Thilo Cestonaro - * Licensed under the simplified BSD license. - * See Documentation/Licenses/BSD-simplified.txt for more information. - */ - -#pragma once - -#include - -#include -#include - -namespace Swift { - class XMLParser; - - class BOSHParser : public XMLParserClient, boost::noncopyable { - public: - BOSHParser(); - ~BOSHParser(); - - bool parse(const std::string&); - - std::string getAttribute(const std::string& attribute, const std::string& ns = "") const { - return boshBodyAttributes_.getAttribute(attribute, ns); - } - private: - virtual void handleStartElement( - const std::string& element, - const std::string& ns, - const AttributeMap& attributes); - virtual void handleEndElement(const std::string& element, const std::string& ns); - virtual void handleCharacterData(const std::string& data); - - private: - AttributeMap boshBodyAttributes_; - XMLParser* xmlParser_; - enum Level { - BoshTopLevel = -1, - TopLevel = 0, - StreamLevel = 1, - ElementLevel = 2 - }; - int level_; - bool parseErrorOccurred_; - }; -} diff --git a/Swiften/Parser/SConscript b/Swiften/Parser/SConscript index dd19238..e4c2778 100644 --- a/Swiften/Parser/SConscript +++ b/Swiften/Parser/SConscript @@ -11,7 +11,6 @@ sources = [ "AuthChallengeParser.cpp", "AuthSuccessParser.cpp", "AuthResponseParser.cpp", - "BOSHParser.cpp", "CompressParser.cpp", "ElementParser.cpp", "IQParser.cpp", diff --git a/Swiften/QA/ClientTest/ClientTest.cpp b/Swiften/QA/ClientTest/ClientTest.cpp index 4515893..3b8734e 100644 --- a/Swiften/QA/ClientTest/ClientTest.cpp +++ b/Swiften/QA/ClientTest/ClientTest.cpp @@ -28,12 +28,13 @@ enum TestStage { Reconnect }; TestStage stage; +ClientOptions options; void handleDisconnected(boost::optional e) { std::cout << "Disconnected: " << e << std::endl; if (stage == FirstConnect) { stage = Reconnect; - client->connect(); + client->connect(options); } else { eventLoop.stop(); @@ -66,13 +67,22 @@ int main(int, char**) { return -1; } + char* boshHost = getenv("SWIFT_CLIENTTEST_BOSH_HOST"); + char* boshPort = getenv("SWIFT_CLIENTTEST_BOSH_PORT"); + char* boshPath = getenv("SWIFT_CLIENTTEST_BOSH_PATH"); + + if (boshHost && boshPort && boshPath) { + std::cout << "Using BOSH with URL: http://" << boshHost << ":" << boshPort << "/" << boshPath << std::endl; + options.boshURL = URL("http", boshHost, atoi(boshPort), boshPath); + } + client = new Swift::Client(JID(jid), std::string(pass), &networkFactories); - ClientXMLTracer* tracer = new ClientXMLTracer(client); + ClientXMLTracer* tracer = new ClientXMLTracer(client, !options.boshURL.empty()); client->onConnected.connect(&handleConnected); client->onDisconnected.connect(boost::bind(&handleDisconnected, _1)); client->setAlwaysTrustCertificates(); stage = FirstConnect; - client->connect(); + client->connect(options); { Timer::ref timer = networkFactories.getTimerFactory()->createTimer(60000); diff --git a/Swiften/SConscript b/Swiften/SConscript index 8c3ad42..9e61fc6 100644 --- a/Swiften/SConscript +++ b/Swiften/SConscript @@ -191,6 +191,7 @@ if env["SCONS_STAGE"] == "build" : "Session/SessionTracer.cpp", "Session/SessionStream.cpp", "Session/BasicSessionStream.cpp", + "Session/BOSHSessionStream.cpp", "StringCodecs/Base64.cpp", "StringCodecs/SHA1.cpp", "StringCodecs/SHA256.cpp", @@ -285,6 +286,8 @@ if env["SCONS_STAGE"] == "build" : File("Network/UnitTest/ConnectorTest.cpp"), File("Network/UnitTest/ChainedConnectorTest.cpp"), File("Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp"), + File("Network/UnitTest/BOSHConnectionTest.cpp"), + File("Network/UnitTest/BOSHConnectionPoolTest.cpp"), File("Parser/PayloadParsers/UnitTest/BodyParserTest.cpp"), File("Parser/PayloadParsers/UnitTest/DiscoInfoParserTest.cpp"), File("Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp"), diff --git a/Swiften/Session/BOSHSessionStream.cpp b/Swiften/Session/BOSHSessionStream.cpp new file mode 100644 index 0000000..95390f4 --- /dev/null +++ b/Swiften/Session/BOSHSessionStream.cpp @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Swift { + +BOSHSessionStream::BOSHSessionStream( + boost::shared_ptr connectionFactory, /*FIXME: probably rip out*/ + PayloadParserFactoryCollection* payloadParserFactories, + PayloadSerializerCollection* payloadSerializers, + TLSContextFactory* tlsContextFactory, + TimerFactory* timerFactory, + XMLParserFactory* xmlParserFactory, + EventLoop* eventLoop, + const std::string& to, + const URL& boshHTTPConnectProxyURL, + const SafeString& boshHTTPConnectProxyAuthID, + const SafeString& boshHTTPConnectProxyAuthPassword) : + available(false), + payloadParserFactories(payloadParserFactories), + payloadSerializers(payloadSerializers), + tlsContextFactory(tlsContextFactory), + timerFactory(timerFactory), + xmlParserFactory(xmlParserFactory), + eventLoop(eventLoop), + firstHeader(true) { + + boost::mt19937 random; + boost::uniform_int<> dist(0, LONG_MAX); + random.seed(time(NULL)); + boost::variate_generator > randomRID(random, dist); + long initialRID = randomRID(); + + connectionPool = new BOSHConnectionPool(connectionFactory, to, initialRID, boshHTTPConnectProxyURL, boshHTTPConnectProxyAuthID, boshHTTPConnectProxyAuthPassword); + connectionPool->onSessionTerminated.connect(boost::bind(&BOSHSessionStream::handlePoolSessionTerminated, this, _1)); + connectionPool->onSessionStarted.connect(boost::bind(&BOSHSessionStream::handlePoolSessionStarted, this)); + connectionPool->onXMPPDataRead.connect(boost::bind(&BOSHSessionStream::handlePoolXMPPDataRead, this, _1)); + connectionPool->onBOSHDataRead.connect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataRead, this, _1)); + connectionPool->onBOSHDataWritten.connect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataWritten, this, _1)); + + xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, ClientStreamType); + xmppLayer->onStreamStart.connect(boost::bind(&BOSHSessionStream::handleStreamStartReceived, this, _1)); + xmppLayer->onElement.connect(boost::bind(&BOSHSessionStream::handleElementReceived, this, _1)); + xmppLayer->onError.connect(boost::bind(&BOSHSessionStream::handleXMPPError, this)); + xmppLayer->onWriteData.connect(boost::bind(&BOSHSessionStream::handleXMPPLayerDataWritten, this, _1)); + + available = true; +} + +BOSHSessionStream::~BOSHSessionStream() { + close(); + connectionPool->onSessionTerminated.disconnect(boost::bind(&BOSHSessionStream::handlePoolSessionTerminated, this, _1)); + connectionPool->onSessionStarted.disconnect(boost::bind(&BOSHSessionStream::handlePoolSessionStarted, this)); + connectionPool->onXMPPDataRead.disconnect(boost::bind(&BOSHSessionStream::handlePoolXMPPDataRead, this, _1)); + connectionPool->onBOSHDataRead.disconnect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataRead, this, _1)); + connectionPool->onBOSHDataWritten.disconnect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataWritten, this, _1)); + delete connectionPool; + connectionPool = NULL; + xmppLayer->onStreamStart.disconnect(boost::bind(&BOSHSessionStream::handleStreamStartReceived, this, _1)); + xmppLayer->onElement.disconnect(boost::bind(&BOSHSessionStream::handleElementReceived, this, _1)); + xmppLayer->onError.disconnect(boost::bind(&BOSHSessionStream::handleXMPPError, this)); + xmppLayer->onWriteData.disconnect(boost::bind(&BOSHSessionStream::handleXMPPLayerDataWritten, this, _1)); + delete xmppLayer; + xmppLayer = NULL; +} + +void BOSHSessionStream::handlePoolXMPPDataRead(const SafeByteArray& data) { + xmppLayer->handleDataRead(data); +} + +void BOSHSessionStream::writeElement(boost::shared_ptr element) { + assert(available); + xmppLayer->writeElement(element); +} + +void BOSHSessionStream::writeFooter() { + connectionPool->writeFooter(); +} + +void BOSHSessionStream::writeData(const std::string& data) { + assert(available); + xmppLayer->writeData(data); +} + +void BOSHSessionStream::close() { + connectionPool->close(); +} + +bool BOSHSessionStream::isOpen() { + return available; +} + +bool BOSHSessionStream::supportsTLSEncryption() { + return false; +} + +void BOSHSessionStream::addTLSEncryption() { + assert(available); +} + +bool BOSHSessionStream::isTLSEncrypted() { + return false; +} + +Certificate::ref BOSHSessionStream::getPeerCertificate() const { + return Certificate::ref(); +} + +boost::shared_ptr BOSHSessionStream::getPeerCertificateVerificationError() const { + return boost::shared_ptr(); +} + +ByteArray BOSHSessionStream::getTLSFinishMessage() const { + return ByteArray(); +} + +bool BOSHSessionStream::supportsZLibCompression() { + return false; +} + +void BOSHSessionStream::addZLibCompression() { + +} + +void BOSHSessionStream::setWhitespacePingEnabled(bool /*enabled*/) { + return; +} + +void BOSHSessionStream::resetXMPPParser() { + xmppLayer->resetParser(); +} + +void BOSHSessionStream::handleStreamStartReceived(const ProtocolHeader& header) { + onStreamStartReceived(header); +} + +void BOSHSessionStream::handleElementReceived(boost::shared_ptr element) { + onElementReceived(element); +} + +void BOSHSessionStream::handleXMPPError() { + available = false; + onClosed(boost::shared_ptr(new Error(Error::ParseError))); +} + +void BOSHSessionStream::handlePoolSessionStarted() { + fakeStreamHeaderReceipt(); +} + +void BOSHSessionStream::handlePoolSessionTerminated(BOSHError::ref error) { + eventLoop->postEvent(boost::bind(&BOSHSessionStream::fakeStreamFooterReceipt, this, error), shared_from_this()); +} + +void BOSHSessionStream::writeHeader(const ProtocolHeader& header) { + streamHeader = header; + /*First time we're told to do this, don't (the sending of the initial header is handled on connect) + On subsequent requests we should restart the stream the BOSH way. + */ + if (!firstHeader) { + eventLoop->postEvent(boost::bind(&BOSHSessionStream::fakeStreamHeaderReceipt, this), shared_from_this()); + eventLoop->postEvent(boost::bind(&BOSHConnectionPool::restartStream, connectionPool), shared_from_this()); + } + firstHeader = false; +} + + +void BOSHSessionStream::fakeStreamHeaderReceipt() { + std::stringstream header; + header << ""; + + xmppLayer->handleDataRead(createSafeByteArray(header.str())); +} + +void BOSHSessionStream::fakeStreamFooterReceipt(BOSHError::ref error) { + std::string footer(""); + xmppLayer->handleDataRead(createSafeByteArray(footer)); + onClosed(error); +} + +void BOSHSessionStream::handleXMPPLayerDataWritten(const SafeByteArray& data) { + eventLoop->postEvent(boost::bind(&BOSHConnectionPool::write, connectionPool, data), shared_from_this()); +} + +void BOSHSessionStream::handlePoolBOSHDataRead(const SafeByteArray& data) { + onDataRead(data); +} + +void BOSHSessionStream::handlePoolBOSHDataWritten(const SafeByteArray& data) { + onDataWritten(data); +} + +}; diff --git a/Swiften/Session/BOSHSessionStream.h b/Swiften/Session/BOSHSessionStream.h new file mode 100644 index 0000000..75c1f2a --- /dev/null +++ b/Swiften/Session/BOSHSessionStream.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace Swift { + class TimerFactory; + class PayloadParserFactoryCollection; + class PayloadSerializerCollection; + class StreamStack; + class XMPPLayer; + class ConnectionLayer; + class CompressionLayer; + class XMLParserFactory; + class TLSContextFactory; + class EventLoop; + + class BOSHSessionStream : public SessionStream, public EventOwner, public boost::enable_shared_from_this { + public: + BOSHSessionStream( + boost::shared_ptr connectionFactory, + PayloadParserFactoryCollection* payloadParserFactories, + PayloadSerializerCollection* payloadSerializers, + TLSContextFactory* tlsContextFactory, + TimerFactory* whitespacePingLayerFactory, + XMLParserFactory* xmlParserFactory, + EventLoop* eventLoop, + const std::string& to, + const URL& boshHTTPConnectProxyURL, + const SafeString& boshHTTPConnectProxyAuthID, + const SafeString& boshHTTPConnectProxyAuthPassword + ); + ~BOSHSessionStream(); + + virtual void close(); + virtual bool isOpen(); + + virtual void writeHeader(const ProtocolHeader& header); + virtual void writeElement(boost::shared_ptr); + virtual void writeFooter(); + virtual void writeData(const std::string& data); + + virtual bool supportsZLibCompression(); + virtual void addZLibCompression(); + + virtual bool supportsTLSEncryption(); + virtual void addTLSEncryption(); + virtual bool isTLSEncrypted(); + virtual Certificate::ref getPeerCertificate() const; + virtual boost::shared_ptr getPeerCertificateVerificationError() const; + virtual ByteArray getTLSFinishMessage() const; + + virtual void setWhitespacePingEnabled(bool); + + virtual void resetXMPPParser(); + + private: + void handleXMPPError(); + void handleStreamStartReceived(const ProtocolHeader&); + void handleElementReceived(boost::shared_ptr); + void handlePoolXMPPDataRead(const SafeByteArray& data); + void handleXMPPLayerDataWritten(const SafeByteArray& data); + void handlePoolSessionStarted(); + void handlePoolBOSHDataRead(const SafeByteArray& data); + void handlePoolBOSHDataWritten(const SafeByteArray& data); + void handlePoolSessionTerminated(BOSHError::ref condition); + + private: + void fakeStreamHeaderReceipt(); + void fakeStreamFooterReceipt(BOSHError::ref error); + + private: + BOSHConnectionPool* connectionPool; + bool available; + PayloadParserFactoryCollection* payloadParserFactories; + PayloadSerializerCollection* payloadSerializers; + TLSContextFactory* tlsContextFactory; + TimerFactory* timerFactory; + XMLParserFactory* xmlParserFactory; + XMPPLayer* xmppLayer; + ProtocolHeader streamHeader; + EventLoop* eventLoop; + bool firstHeader; + }; + +} diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index 07a04b8..70bbeea 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -136,6 +136,10 @@ ByteArray BasicSessionStream::getTLSFinishMessage() const { return tlsLayer->getContext()->getFinishMessage(); } +bool BasicSessionStream::supportsZLibCompression() { + return true; +} + void BasicSessionStream::addZLibCompression() { compressionLayer = new CompressionLayer(); streamStack->addLayer(compressionLayer); diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index 2ed5ac6..b0c4331 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -47,6 +47,7 @@ namespace Swift { virtual void writeFooter(); virtual void writeData(const std::string& data); + virtual bool supportsZLibCompression(); virtual void addZLibCompression(); virtual bool supportsTLSEncryption(); diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index e6b9469..096f185 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -46,6 +46,7 @@ namespace Swift { virtual void writeElement(boost::shared_ptr) = 0; virtual void writeData(const std::string& data) = 0; + virtual bool supportsZLibCompression() = 0; virtual void addZLibCompression() = 0; virtual bool supportsTLSEncryption() = 0; diff --git a/Swiften/StreamStack/XMPPLayer.h b/Swiften/StreamStack/XMPPLayer.h index 9be00b2..81f0457 100644 --- a/Swiften/StreamStack/XMPPLayer.h +++ b/Swiften/StreamStack/XMPPLayer.h @@ -23,8 +23,10 @@ namespace Swift { class XMPPSerializer; class PayloadSerializerCollection; class XMLParserFactory; + class BOSHSessionStream; class XMPPLayer : public XMPPParserClient, public HighLayer, boost::noncopyable { + friend class BOSHSessionStream; public: XMPPLayer( PayloadParserFactoryCollection* payloadParserFactories, -- cgit v0.10.2-6-g49f6