diff options
author | Kevin Smith <git@kismith.co.uk> | 2011-11-12 16:56:21 (GMT) |
---|---|---|
committer | Kevin Smith <git@kismith.co.uk> | 2011-12-13 08:17:58 (GMT) |
commit | 81c09a0f6a3e87b078340d7f35d0dea4c03f3a6d (patch) | |
tree | 4371c5808ee26b2b5ed79ace9ccb439ff2988945 /Swiften/Network | |
parent | fd17fe0d239f97cedebe4ceffa234155bd299b68 (diff) | |
download | swift-contrib-81c09a0f6a3e87b078340d7f35d0dea4c03f3a6d.zip swift-contrib-81c09a0f6a3e87b078340d7f35d0dea4c03f3a6d.tar.bz2 |
BOSH Support for Swiften
This adds support for BOSH to Swiften. It does not expose it to Swift.
Release-Notes: Swiften now allows connects over BOSH, if used appropriately.
Diffstat (limited to 'Swiften/Network')
21 files changed, 1618 insertions, 125 deletions
diff --git a/Swiften/Network/BOSHConnection.cpp b/Swiften/Network/BOSHConnection.cpp index 549c652..09548e9 100644 --- a/Swiften/Network/BOSHConnection.cpp +++ b/Swiften/Network/BOSHConnection.cpp @@ -4,129 +4,281 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include "BOSHConnection.h" +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <Swiften/Network/BOSHConnection.h> + #include <boost/bind.hpp> #include <boost/thread.hpp> +#include <boost/lexical_cast.hpp> #include <string> #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Base/Log.h> #include <Swiften/Base/String.h> +#include <Swiften/Base/Concat.h> #include <Swiften/Base/ByteArray.h> #include <Swiften/Network/HostAddressPort.h> -#include <Swiften/Parser/BOSHParser.h> +#include <Swiften/Network/TLSConnection.h> +#include <Swiften/Parser/BOSHBodyExtractor.h> namespace Swift { - BOSHConnection::BOSHConnection(ConnectionFactory* connectionFactory) - : connectionFactory_(connectionFactory), server_(HostAddressPort(HostAddress("0.0.0.0"), 0)), sid_() - { - reopenAfterAction = true; - } +BOSHConnection::BOSHConnection(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory) + : boshURL_(boshURL), + connectionFactory_(connectionFactory), + parserFactory_(parserFactory), + sid_(), + waitingForStartResponse_(false), + pending_(false), + tlsFactory_(tlsFactory), + connectionReady_(false) +{ +} - BOSHConnection::~BOSHConnection() { - if (newConnection_) { - newConnection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - newConnection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - } - if (currentConnection_) { - currentConnection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - currentConnection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - } +BOSHConnection::~BOSHConnection() { + if (connection_) { + connection_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); + connection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); } + disconnect(); +} - void BOSHConnection::connect(const HostAddressPort& server) { - server_ = server; - newConnection_ = connectionFactory_->createConnection(); - newConnection_->onConnectFinished.connect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); - newConnection_->onDataRead.connect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - newConnection_->onDisconnected.connect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - SWIFT_LOG(debug) << "connect to server " << server.getAddress().toString() << ":" << server.getPort() << std::endl; - newConnection_->connect(HostAddressPort(HostAddress("85.10.192.88"), 5280)); - } +void BOSHConnection::connect(const HostAddressPort& server) { + /* FIXME: Redundant parameter */ + Connection::ref rawConnection = connectionFactory_->createConnection(); + connection_ = (boshURL_.getScheme() == "https") ? boost::make_shared<TLSConnection>(rawConnection, tlsFactory_) : rawConnection; + connection_->onConnectFinished.connect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); + connection_->onDataRead.connect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.connect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); + connection_->connect(HostAddressPort(HostAddress(boshURL_.getHost()), boshURL_.getPort())); +} - void BOSHConnection::listen() { - assert(false); +void BOSHConnection::listen() { + assert(false); +} + +void BOSHConnection::disconnect() { + if(connection_) { + connection_->disconnect(); + sid_ = ""; } +} - void BOSHConnection::disconnect() { - if(newConnection_) - newConnection_->disconnect(); +void BOSHConnection::restartStream() { + write(createSafeByteArray(""), true, false); +} + +void BOSHConnection::terminateStream() { + write(createSafeByteArray(""), false, true); +} - if(currentConnection_) - currentConnection_->disconnect(); - } - void BOSHConnection::write(const SafeByteArray& data) { - SWIFT_LOG(debug) << "write data: " << safeByteArrayToString(data) << std::endl; +void BOSHConnection::write(const SafeByteArray& data) { + write(data, false, false); +} + +std::pair<SafeByteArray, size_t> BOSHConnection::createHTTPRequest(const SafeByteArray& data, bool streamRestart, bool terminate, long rid, const std::string& sid, const URL& boshURL) { + size_t size; + std::stringstream content; + SafeByteArray contentTail = createSafeByteArray("</body>"); + std::stringstream header; + + content << "<body rid='" << rid << "' sid='" << sid << "'"; + if (streamRestart) { + content << " xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'"; } + if (terminate) { + content << " type='terminate'"; + } + content << " xmlns='http://jabber.org/protocol/httpbind'>"; - void BOSHConnection::handleConnectionConnectFinished(bool error) { - newConnection_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); - if(error) { - onConnectFinished(true); - return; - } + SafeByteArray safeContent = createSafeByteArray(content.str()); + safeContent.insert(safeContent.end(), data.begin(), data.end()); + safeContent.insert(safeContent.end(), contentTail.begin(), contentTail.end()); - if(sid_.size() == 0) { - // Session Creation Request - std::stringstream content; - std::stringstream header; - - content << "<body content='text/xml; charset=utf-8'" - << " from='ephraim@0x10.de'" - << " hold='1'" - << " to='0x10.de'" - << " ver='1.6'" - << " wait='60'" - << " ack='1'" - << " xml:lang='en'" - << " xmlns='http://jabber.org/protocol/httpbind' />\r\n"; - - header << "POST /http-bind HTTP/1.1\r\n" - << "Host: 0x10.de:5280\r\n" - << "Accept-Encoding: deflate\r\n" - << "Content-Type: text/xml; charset=utf-8\r\n" - << "Content-Length: " << content.str().size() << "\r\n\r\n" - << content.str(); - - SWIFT_LOG(debug) << "request: "; - newConnection_->write(createSafeByteArray(header.str())); - } + size = safeContent.size(); + + header << "POST /" << boshURL.getPath() << " HTTP/1.1\r\n" + << "Host: " << boshURL.getHost() << ":" << boshURL.getPort() << "\r\n" + /*<< "Accept-Encoding: deflate\r\n"*/ + << "Content-Type: text/xml; charset=utf-8\r\n" + << "Content-Length: " << size << "\r\n\r\n"; + + SafeByteArray safeHeader = createSafeByteArray(header.str()); + safeHeader.insert(safeHeader.end(), safeContent.begin(), safeContent.end()); + + return std::pair<SafeByteArray, size_t>(safeHeader, size); +} + +void BOSHConnection::write(const SafeByteArray& data, bool streamRestart, bool terminate) { + assert(connectionReady_); + assert(!sid_.empty()); + + SafeByteArray safeHeader = createHTTPRequest(data, streamRestart, terminate, rid_, sid_, boshURL_).first; + + onBOSHDataWritten(safeHeader); + connection_->write(safeHeader); + pending_ = true; + + SWIFT_LOG(debug) << "write data: " << safeByteArrayToString(safeHeader) << std::endl; +} + +void BOSHConnection::handleConnectionConnectFinished(bool error) { + connection_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectionConnectFinished, shared_from_this(), _1)); + connectionReady_ = !error; + onConnectFinished(error); +} + +void BOSHConnection::startStream(const std::string& to, unsigned long rid) { + assert(connectionReady_); + // Session Creation Request + std::stringstream content; + std::stringstream header; + + content << "<body content='text/xml; charset=utf-8'" + << " hold='1'" + << " to='" << to << "'" + << " rid='" << rid << "'" + << " ver='1.6'" + << " wait='60'" /* FIXME: we probably want this configurable*/ + /*<< " ack='0'" FIXME: support acks */ + << " xml:lang='en'" + << " xmlns:xmpp='urn:xmpp:bosh'" + << " xmpp:version='1.0'" + << " xmlns='http://jabber.org/protocol/httpbind' />"; + + std::string contentString = content.str(); + + header << "POST /" << boshURL_.getPath() << " HTTP/1.1\r\n" + << "Host: " << boshURL_.getHost() << ":" << boshURL_.getPort() << "\r\n" + /*<< "Accept-Encoding: deflate\r\n"*/ + << "Content-Type: text/xml; charset=utf-8\r\n" + << "Content-Length: " << contentString.size() << "\r\n\r\n" + << contentString; + + waitingForStartResponse_ = true; + SafeByteArray safeHeader = createSafeByteArray(header.str()); + onBOSHDataWritten(safeHeader); + connection_->write(safeHeader); + SWIFT_LOG(debug) << "write stream header: " << safeByteArrayToString(safeHeader) << std::endl; +} + +void BOSHConnection::handleDataRead(boost::shared_ptr<SafeByteArray> data) { + onBOSHDataRead(*data.get()); + buffer_ = concat(buffer_, *data.get()); + std::string response = safeByteArrayToString(buffer_); + if (response.find("\r\n\r\n") == std::string::npos) { + onBOSHDataRead(createSafeByteArray("[[Previous read incomplete, pending]]")); + return; + } + + std::string httpCode = response.substr(response.find(" ") + 1, 3); + if (httpCode != "200") { + onHTTPError(httpCode); + return; } - void BOSHConnection::handleDataRead(const SafeByteArray& data) { - std::string response = safeByteArrayToString(data); - assert(response.find("\r\n\r\n") != std::string::npos); - - SWIFT_LOG(debug) << "response: " << response.substr(response.find("\r\n\r\n") + 4) << std::endl; - - BOSHParser parser; - if(parser.parse(response.substr(response.find("\r\n\r\n") + 4))) { - sid_ = parser.getAttribute("sid"); - onConnectFinished(false); - int bodyStartElementLength = 0; - bool inQuote = false; - for(size_t i= 0; i < response.size(); i++) { - if(response.c_str()[i] == '\'' || response.c_str()[i] == '"') { - inQuote = !inQuote; - } - else if(!inQuote && response.c_str()[i] == '>') { - bodyStartElementLength = i + 1; - break; - } + BOSHBodyExtractor parser(parserFactory_, createByteArray(response.substr(response.find("\r\n\r\n") + 4))); + if (parser.getBody()) { + if ((*parser.getBody()).attributes.getAttribute("type") == "terminate") { + BOSHError::Type errorType = parseTerminationCondition((*parser.getBody()).attributes.getAttribute("condition")); + onSessionTerminated(errorType == BOSHError::NoError ? boost::shared_ptr<BOSHError>() : boost::make_shared<BOSHError>(errorType)); + } + buffer_.clear(); + if (waitingForStartResponse_) { + waitingForStartResponse_ = false; + sid_ = (*parser.getBody()).attributes.getAttribute("sid"); + std::string requestsString = (*parser.getBody()).attributes.getAttribute("requests"); + int requests = 2; + if (!requestsString.empty()) { + requests = boost::lexical_cast<size_t>(requestsString); } - SafeByteArray payload = createSafeByteArray(response.substr(bodyStartElementLength, response.size() - bodyStartElementLength - 7)); - SWIFT_LOG(debug) << "payload: " << safeByteArrayToString(payload) << std::endl; - onDataRead(payload); + onSessionStarted(sid_, requests); } + SafeByteArray payload = createSafeByteArray((*parser.getBody()).content); + /* Say we're good to go again, so don't add anything after here in the method */ + pending_ = false; + onXMPPDataRead(payload); } - void BOSHConnection::handleDisconnected(const boost::optional<Error>& error) { - onDisconnected(error); - } +} - HostAddressPort BOSHConnection::getLocalAddress() const { - return newConnection_->getLocalAddress(); +BOSHError::Type BOSHConnection::parseTerminationCondition(const std::string& text) { + BOSHError::Type condition = BOSHError::UndefinedCondition; + if (text == "bad-request") { + condition = BOSHError::BadRequest; + } + else if (text == "host-gone") { + condition = BOSHError::HostGone; + } + else if (text == "host-unknown") { + condition = BOSHError::HostUnknown; + } + else if (text == "improper-addressing") { + condition = BOSHError::ImproperAddressing; + } + else if (text == "internal-server-error") { + condition = BOSHError::InternalServerError; + } + else if (text == "item-not-found") { + condition = BOSHError::ItemNotFound; + } + else if (text == "other-request") { + condition = BOSHError::OtherRequest; + } + else if (text == "policy-violation") { + condition = BOSHError::PolicyViolation; + } + else if (text == "remote-connection-failed") { + condition = BOSHError::RemoteConnectionFailed; + } + else if (text == "remote-stream-error") { + condition = BOSHError::RemoteStreamError; } + else if (text == "see-other-uri") { + condition = BOSHError::SeeOtherURI; + } + else if (text == "system-shutdown") { + condition = BOSHError::SystemShutdown; + } + else if (text == "") { + condition = BOSHError::NoError; + } + return condition; +} + +const std::string& BOSHConnection::getSID() { + return sid_; +} + +void BOSHConnection::setRID(unsigned long rid) { + rid_ = rid; +} + +void BOSHConnection::setSID(const std::string& sid) { + sid_ = sid; +} + +void BOSHConnection::handleDisconnected(const boost::optional<Error>& error) { + onDisconnected(error); + sid_ = ""; + connectionReady_ = false; +} + +HostAddressPort BOSHConnection::getLocalAddress() const { + return connection_->getLocalAddress(); +} + +bool BOSHConnection::isReadyToSend() { + /* Without pipelining you need to not send more without first receiving the response */ + /* With pipelining you can. Assuming we can't, here */ + return connectionReady_ && !pending_ && !waitingForStartResponse_ && !sid_.empty(); +} + } diff --git a/Swiften/Network/BOSHConnection.h b/Swiften/Network/BOSHConnection.h index 0da92ba..283ea10 100644 --- a/Swiften/Network/BOSHConnection.h +++ b/Swiften/Network/BOSHConnection.h @@ -4,6 +4,13 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + #pragma once #include <boost/enable_shared_from_this.hpp> @@ -11,6 +18,9 @@ #include <Swiften/Network/Connection.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Base/String.h> +#include <Swiften/Base/URL.h> +#include <Swiften/Base/Error.h> +#include <Swiften/Session/SessionStream.h> namespace boost { class thread; @@ -21,32 +31,73 @@ namespace boost { namespace Swift { class ConnectionFactory; + class XMLParserFactory; + class TLSContextFactory; + + class BOSHError : public SessionStream::Error { + public: + enum Type {BadRequest, HostGone, HostUnknown, ImproperAddressing, + InternalServerError, ItemNotFound, OtherRequest, PolicyViolation, + RemoteConnectionFailed, RemoteStreamError, SeeOtherURI, SystemShutdown, UndefinedCondition, + NoError}; + BOSHError(Type type) : SessionStream::Error(SessionStream::Error::ConnectionReadError), type(type) {} + Type getType() {return type;} + typedef boost::shared_ptr<BOSHError> ref; + private: + Type type; + + }; + class BOSHConnection : public Connection, public boost::enable_shared_from_this<BOSHConnection> { public: typedef boost::shared_ptr<BOSHConnection> ref; - static ref create(ConnectionFactory* connectionFactory) { - return ref(new BOSHConnection(connectionFactory)); + static ref create(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory) { + return ref(new BOSHConnection(boshURL, connectionFactory, parserFactory, tlsFactory)); } virtual ~BOSHConnection(); virtual void listen(); virtual void connect(const HostAddressPort& address); virtual void disconnect(); virtual void write(const SafeByteArray& data); + virtual HostAddressPort getLocalAddress() const; + const std::string& getSID(); + void setRID(unsigned long rid); + void setSID(const std::string& sid); + void startStream(const std::string& to, unsigned long rid); + void terminateStream(); + bool isReadyToSend(); + void restartStream(); + static std::pair<SafeByteArray, size_t> createHTTPRequest(const SafeByteArray& data, bool streamRestart, bool terminate, long rid, const std::string& sid, const URL& boshURL); + + boost::signal<void (BOSHError::ref)> onSessionTerminated; + boost::signal<void (const std::string& /*sid*/, size_t /*requests*/)> onSessionStarted; + boost::signal<void (const SafeByteArray&)> onXMPPDataRead; + boost::signal<void (const SafeByteArray&)> onBOSHDataRead; + boost::signal<void (const SafeByteArray&)> onBOSHDataWritten; + boost::signal<void (const std::string&)> onHTTPError; private: - BOSHConnection(ConnectionFactory* connectionFactory); + BOSHConnection(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory); + void handleConnectionConnectFinished(bool error); - void handleDataRead(const SafeByteArray& data); + void handleDataRead(boost::shared_ptr<SafeByteArray> data); void handleDisconnected(const boost::optional<Error>& error); + void write(const SafeByteArray& data, bool streamRestart, bool terminate); /* FIXME: refactor */ + BOSHError::Type parseTerminationCondition(const std::string& text); - bool reopenAfterAction; + URL boshURL_; ConnectionFactory* connectionFactory_; - HostAddressPort server_; - boost::shared_ptr<Connection> newConnection_; - boost::shared_ptr<Connection> currentConnection_; + XMLParserFactory* parserFactory_; + boost::shared_ptr<Connection> connection_; std::string sid_; + bool waitingForStartResponse_; + unsigned long rid_; + SafeByteArray buffer_; + bool pending_; + TLSContextFactory* tlsFactory_; + bool connectionReady_; }; } diff --git a/Swiften/Network/BOSHConnectionFactory.cpp b/Swiften/Network/BOSHConnectionFactory.cpp index 4c49cae..7b83034 100644 --- a/Swiften/Network/BOSHConnectionFactory.cpp +++ b/Swiften/Network/BOSHConnectionFactory.cpp @@ -4,18 +4,23 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include "BOSHConnectionFactory.h" +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <Swiften/Network/BOSHConnectionFactory.h> #include <Swiften/Network/BOSHConnection.h> namespace Swift { -BOSHConnectionFactory::BOSHConnectionFactory(ConnectionFactory* connectionFactory) { - connectionFactory_ = connectionFactory; +BOSHConnectionFactory::BOSHConnectionFactory(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* xmlParserFactory, TLSContextFactory* tlsFactory) : boshURL(boshURL), connectionFactory(connectionFactory), xmlParserFactory(xmlParserFactory), tlsFactory(tlsFactory) { } -boost::shared_ptr<Connection> BOSHConnectionFactory::createConnection() { - return BOSHConnection::create(connectionFactory_); +boost::shared_ptr<Connection> BOSHConnectionFactory::createConnection(ConnectionFactory* overrideFactory) { + return BOSHConnection::create(boshURL, overrideFactory != NULL ? overrideFactory : connectionFactory, xmlParserFactory, tlsFactory); } } diff --git a/Swiften/Network/BOSHConnectionFactory.h b/Swiften/Network/BOSHConnectionFactory.h index 7431cf4..3750057 100644 --- a/Swiften/Network/BOSHConnectionFactory.h +++ b/Swiften/Network/BOSHConnectionFactory.h @@ -4,19 +4,40 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + #pragma once +#include <string> + #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Network/HostAddressPort.h> +#include <Swiften/TLS/TLSContextFactory.h> +#include <Swiften/Base/URL.h> namespace Swift { - class BOSHConnectionFactory : public ConnectionFactory { - public: - BOSHConnectionFactory(ConnectionFactory* connectionFactory); - virtual boost::shared_ptr<Connection> createConnection(); +class XMLParserFactory; + +class BOSHConnectionFactory { + public: + BOSHConnectionFactory(const URL& boshURL, ConnectionFactory* connectionFactory, XMLParserFactory* xmlParserFactory, TLSContextFactory* tlsFactory); + + /** + * @param overrideFactory If non-NULL, creates a connection over the given factory instead. + */ + boost::shared_ptr<Connection> createConnection(ConnectionFactory* overrideFactory); + ConnectionFactory* getRawConnectionFactory() {return connectionFactory;} + TLSContextFactory* getTLSContextFactory() {return tlsFactory;} + private: + URL boshURL; + ConnectionFactory* connectionFactory; + XMLParserFactory* xmlParserFactory; + TLSContextFactory* tlsFactory; +}; - private: - ConnectionFactory* connectionFactory_; - }; } diff --git a/Swiften/Network/BOSHConnectionPool.cpp b/Swiften/Network/BOSHConnectionPool.cpp new file mode 100644 index 0000000..6c3ba7e --- /dev/null +++ b/Swiften/Network/BOSHConnectionPool.cpp @@ -0,0 +1,247 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ +#include <Swiften/Network/BOSHConnectionPool.h> + +#include <climits> + +#include <boost/bind.hpp> +#include <boost/lexical_cast.hpp> + +#include <Swiften/Base/foreach.h> +#include <Swiften/Base/SafeString.h> +#include <Swiften/Network/TLSConnectionFactory.h> +#include <Swiften/Network/HTTPConnectProxiedConnectionFactory.h> + +namespace Swift { +BOSHConnectionPool::BOSHConnectionPool(boost::shared_ptr<BOSHConnectionFactory> connectionFactory, const std::string& to, long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword) + : connectionFactory(connectionFactory), + rid(initialRID), + pendingTerminate(false), + to(to), + requestLimit(2), + restartCount(0), + pendingRestart(false) { + tlsConnectionFactory = NULL; + if (boshHTTPConnectProxyURL.empty()) { + connectProxyFactory = NULL; + } + else { + ConnectionFactory* rawFactory = connectionFactory->getRawConnectionFactory(); + if (boshHTTPConnectProxyURL.getScheme() == "https") { + tlsConnectionFactory = new TLSConnectionFactory(connectionFactory->getTLSContextFactory(), rawFactory); + rawFactory = tlsConnectionFactory; + } + connectProxyFactory = new HTTPConnectProxiedConnectionFactory(rawFactory, HostAddressPort(HostAddress(boshHTTPConnectProxyURL.getHost()), boshHTTPConnectProxyURL.getPort()), boshHTTPConnectProxyAuthID, boshHTTPConnectProxyAuthPassword); + } + createConnection(); +} + +BOSHConnectionPool::~BOSHConnectionPool() { + close(); + delete connectProxyFactory; + delete tlsConnectionFactory; +} + +void BOSHConnectionPool::write(const SafeByteArray& data) { + dataQueue.push_back(data); + tryToSendQueuedData(); +} + +void BOSHConnectionPool::handleDataRead(const SafeByteArray& data) { + onXMPPDataRead(data); + tryToSendQueuedData(); /* Will rebalance the connections */ +} + +void BOSHConnectionPool::restartStream() { + BOSHConnection::ref connection = getSuitableConnection(); + if (connection) { + pendingRestart = false; + rid++; + connection->setRID(rid); + connection->restartStream(); + restartCount++; + } + else { + pendingRestart = true; + } +} + +void BOSHConnectionPool::writeFooter() { + pendingTerminate = true; + tryToSendQueuedData(); +} + +void BOSHConnectionPool::close() { + /* TODO: Send a terminate here. */ + std::vector<BOSHConnection::ref> connectionCopies = connections; + foreach (BOSHConnection::ref connection, connectionCopies) { + if (connection) { + connection->disconnect(); + destroyConnection(connection); + } + } +} + +void BOSHConnectionPool::handleSessionStarted(const std::string& sessionID, size_t requests) { + sid = sessionID; + requestLimit = requests; + onSessionStarted(); +} + +void BOSHConnectionPool::handleConnectFinished(bool error, BOSHConnection::ref connection) { + if (error) { + onSessionTerminated(boost::make_shared<BOSHError>(BOSHError::UndefinedCondition)); + /*TODO: We can probably manage to not terminate the stream here and use the rid/ack retry + * logic to just swallow the error and try again (some number of tries). + */ + } + else { + if (sid.empty()) { + connection->startStream(to, rid); + } + if (pendingRestart) { + restartStream(); + } + tryToSendQueuedData(); + } +} + +BOSHConnection::ref BOSHConnectionPool::getSuitableConnection() { + BOSHConnection::ref suitableConnection; + foreach (BOSHConnection::ref connection, connections) { + if (connection->isReadyToSend()) { + suitableConnection = connection; + break; + } + } + + if (!suitableConnection && connections.size() < requestLimit) { + /* This is not a suitable connection because it won't have yet connected and added TLS if needed. */ + BOSHConnection::ref newConnection = createConnection(); + newConnection->setSID(sid); + } + assert(connections.size() <= requestLimit); + assert((!suitableConnection) || suitableConnection->isReadyToSend()); + return suitableConnection; +} + +void BOSHConnectionPool::tryToSendQueuedData() { + if (sid.empty()) { + /* If we've not got as far as stream start yet, pend */ + return; + } + + BOSHConnection::ref suitableConnection = getSuitableConnection(); + bool sent = false; + bool toSend = !dataQueue.empty(); + if (suitableConnection) { + if (toSend) { + rid++; + suitableConnection->setRID(rid); + SafeByteArray data; + foreach (const SafeByteArray& datum, dataQueue) { + data.insert(data.end(), datum.begin(), datum.end()); + } + suitableConnection->write(data); + sent = true; + dataQueue.clear(); + } + else if (pendingTerminate) { + rid++; + suitableConnection->setRID(rid); + suitableConnection->terminateStream(); + sent = true; + onSessionTerminated(boost::shared_ptr<BOSHError>()); + } + } + if (!pendingTerminate) { + /* Ensure there's always a session waiting to read data for us */ + bool pending = false; + foreach (BOSHConnection::ref connection, connections) { + if (connection && !connection->isReadyToSend()) { + pending = true; + } + } + if (!pending) { + if (restartCount >= 1) { + /* Don't open a second connection until we've restarted the stream twice - i.e. we've authed and resource bound.*/ + if (suitableConnection) { + rid++; + suitableConnection->setRID(rid); + suitableConnection->write(createSafeByteArray("")); + } + else { + /* My thought process I went through when writing this, to aid anyone else confused why this can happen... + * + * What to do here? I think this isn't possible. + If you didn't have two connections, suitable would have made one. + If you have two connections and neither is suitable, pending would be true. + If you have a non-pending connection, it's suitable. + + If I decide to do something here, remove assert above. + + Ah! Yes, because there's a period between creating the connection and it being connected. */ + } + } + } + } +} + +void BOSHConnectionPool::handleHTTPError(const std::string& /*errorCode*/) { + handleSessionTerminated(boost::make_shared<BOSHError>(BOSHError::UndefinedCondition)); +} + +void BOSHConnectionPool::handleConnectionDisconnected(const boost::optional<Connection::Error>& error, BOSHConnection::ref connection) { + destroyConnection(connection); + if (false && error) { + handleSessionTerminated(boost::make_shared<BOSHError>(BOSHError::UndefinedCondition)); + } + else { + /* We might have just freed up a connection slot to send with */ + tryToSendQueuedData(); + } +} + +boost::shared_ptr<BOSHConnection> BOSHConnectionPool::createConnection() { + BOSHConnection::ref connection = boost::dynamic_pointer_cast<BOSHConnection>(connectionFactory->createConnection(connectProxyFactory)); + connection->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); + connection->onSessionStarted.connect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); + connection->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); + connection->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); + connection->onDisconnected.connect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); + connection->onConnectFinished.connect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); + connection->onSessionTerminated.connect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); + connection->onHTTPError.connect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); + connection->connect(HostAddressPort(HostAddress("0.0.0.0"), 0)); + connections.push_back(connection); + return connection; +} + +void BOSHConnectionPool::destroyConnection(boost::shared_ptr<BOSHConnection> connection) { + connections.erase(std::remove(connections.begin(), connections.end(), connection), connections.end()); + connection->onXMPPDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); + connection->onSessionStarted.disconnect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); + connection->onBOSHDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); + connection->onBOSHDataWritten.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); + connection->onDisconnected.disconnect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); + connection->onConnectFinished.disconnect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); + connection->onSessionTerminated.disconnect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); + connection->onHTTPError.disconnect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); +} + +void BOSHConnectionPool::handleSessionTerminated(BOSHError::ref error) { + onSessionTerminated(error); +} + +void BOSHConnectionPool::handleBOSHDataRead(const SafeByteArray& data) { + onBOSHDataRead(data); +} + +void BOSHConnectionPool::handleBOSHDataWritten(const SafeByteArray& data) { + onBOSHDataWritten(data); +} + +} diff --git a/Swiften/Network/BOSHConnectionPool.h b/Swiften/Network/BOSHConnectionPool.h new file mode 100644 index 0000000..85e598d --- /dev/null +++ b/Swiften/Network/BOSHConnectionPool.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + +#pragma once + +#include <vector> + +#include <Swiften/Base/SafeString.h> +#include <Swiften/Network/BOSHConnectionFactory.h> +#include <Swiften/Network/BOSHConnection.h> + +namespace Swift { + class HTTPConnectProxiedConnectionFactory; + class TLSConnectionFactory; + class BOSHConnectionPool : public boost::bsignals::trackable { + public: + BOSHConnectionPool(boost::shared_ptr<BOSHConnectionFactory> factory, const std::string& to, long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword); + ~BOSHConnectionPool(); + void write(const SafeByteArray& data); + void writeFooter(); + void close(); + void restartStream(); + + boost::signal<void (BOSHError::ref)> onSessionTerminated; + boost::signal<void ()> onSessionStarted; + boost::signal<void (const SafeByteArray&)> onXMPPDataRead; + boost::signal<void (const SafeByteArray&)> onBOSHDataRead; + boost::signal<void (const SafeByteArray&)> onBOSHDataWritten; + + private: + void handleDataRead(const SafeByteArray& data); + void handleSessionStarted(const std::string& sid, size_t requests); + void handleBOSHDataRead(const SafeByteArray& data); + void handleBOSHDataWritten(const SafeByteArray& data); + void handleSessionTerminated(BOSHError::ref condition); + void handleConnectFinished(bool, BOSHConnection::ref connection); + void handleConnectionDisconnected(const boost::optional<Connection::Error>& error, BOSHConnection::ref connection); + void handleHTTPError(const std::string& errorCode); + + private: + BOSHConnection::ref createConnection(); + void destroyConnection(BOSHConnection::ref connection); + void tryToSendQueuedData(); + BOSHConnection::ref getSuitableConnection(); + + private: + boost::shared_ptr<BOSHConnectionFactory> connectionFactory; + std::vector<BOSHConnection::ref> connections; + std::string sid; + unsigned long rid; + std::vector<SafeByteArray> dataQueue; + bool pendingTerminate; + std::string to; + size_t requestLimit; + int restartCount; + bool pendingRestart; + HTTPConnectProxiedConnectionFactory* connectProxyFactory; + TLSConnectionFactory* tlsConnectionFactory; + }; +} diff --git a/Swiften/Network/BoostNetworkFactories.cpp b/Swiften/Network/BoostNetworkFactories.cpp index 2b4c04b..488e519 100644 --- a/Swiften/Network/BoostNetworkFactories.cpp +++ b/Swiften/Network/BoostNetworkFactories.cpp @@ -17,7 +17,7 @@ namespace Swift { -BoostNetworkFactories::BoostNetworkFactories(EventLoop* eventLoop) { +BoostNetworkFactories::BoostNetworkFactories(EventLoop* eventLoop) : eventLoop(eventLoop){ timerFactory = new BoostTimerFactory(ioServiceThread.getIOService(), eventLoop); connectionFactory = new BoostConnectionFactory(ioServiceThread.getIOService(), eventLoop); domainNameResolver = new PlatformDomainNameResolver(eventLoop); diff --git a/Swiften/Network/BoostNetworkFactories.h b/Swiften/Network/BoostNetworkFactories.h index 3d268d1..c9b12da 100644 --- a/Swiften/Network/BoostNetworkFactories.h +++ b/Swiften/Network/BoostNetworkFactories.h @@ -53,6 +53,10 @@ namespace Swift { return proxyProvider; } + virtual EventLoop* getEventLoop() const { + return eventLoop; + } + private: BoostIOServiceThread ioServiceThread; TimerFactory* timerFactory; @@ -63,5 +67,6 @@ namespace Swift { XMLParserFactory* xmlParserFactory; PlatformTLSFactories* tlsFactories; ProxyProvider* proxyProvider; + EventLoop* eventLoop; }; } diff --git a/Swiften/Network/HTTPConnectProxiedConnection.cpp b/Swiften/Network/HTTPConnectProxiedConnection.cpp index e05a933..3e6c986 100644 --- a/Swiften/Network/HTTPConnectProxiedConnection.cpp +++ b/Swiften/Network/HTTPConnectProxiedConnection.cpp @@ -4,6 +4,13 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + #include <Swiften/Network/HTTPConnectProxiedConnection.h> #include <iostream> @@ -11,15 +18,17 @@ #include <boost/thread.hpp> #include <boost/lexical_cast.hpp> +#include <Swiften/Base/Algorithm.h> #include <Swiften/Base/Log.h> #include <Swiften/Base/String.h> #include <Swiften/Base/ByteArray.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/StringCodecs/Base64.h> using namespace Swift; -HTTPConnectProxiedConnection::HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy) : connectionFactory_(connectionFactory), proxy_(proxy), server_(HostAddressPort(HostAddress("0.0.0.0"), 0)) { +HTTPConnectProxiedConnection::HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy, const SafeString& authID, const SafeString& authPassword) : connectionFactory_(connectionFactory), proxy_(proxy), server_(HostAddressPort(HostAddress("0.0.0.0"), 0)), authID_(authID), authPassword_(authPassword) { connected_ = false; } @@ -65,8 +74,18 @@ void HTTPConnectProxiedConnection::handleConnectionConnectFinished(bool error) { connection_->onConnectFinished.disconnect(boost::bind(&HTTPConnectProxiedConnection::handleConnectionConnectFinished, shared_from_this(), _1)); if (!error) { std::stringstream connect; - connect << "CONNECT " << server_.getAddress().toString() << ":" << server_.getPort() << " HTTP/1.1\r\n\r\n"; - connection_->write(createSafeByteArray(connect.str())); + connect << "CONNECT " << server_.getAddress().toString() << ":" << server_.getPort() << " HTTP/1.1\r\n"; + SafeByteArray data = createSafeByteArray(connect.str()); + if (!authID_.empty() && !authPassword_.empty()) { + append(data, createSafeByteArray("Proxy-Authorization: Basic ")); + SafeByteArray credentials = authID_; + append(credentials, createSafeByteArray(":")); + append(credentials, authPassword_); + append(data, Base64::encode(credentials)); + append(data, createSafeByteArray("\r\n")); + } + append(data, createSafeByteArray("\r\n")); + connection_->write(data); } else { onConnectFinished(true); diff --git a/Swiften/Network/HTTPConnectProxiedConnection.h b/Swiften/Network/HTTPConnectProxiedConnection.h index d3f5b7a..02d3edd 100644 --- a/Swiften/Network/HTTPConnectProxiedConnection.h +++ b/Swiften/Network/HTTPConnectProxiedConnection.h @@ -4,12 +4,20 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + + #pragma once #include <boost/enable_shared_from_this.hpp> #include <Swiften/Network/Connection.h> #include <Swiften/Network/HostAddressPort.h> +#include <Swiften/Base/SafeString.h> namespace boost { class thread; @@ -27,8 +35,8 @@ namespace Swift { ~HTTPConnectProxiedConnection(); - static ref create(ConnectionFactory* connectionFactory, HostAddressPort proxy) { - return ref(new HTTPConnectProxiedConnection(connectionFactory, proxy)); + static ref create(ConnectionFactory* connectionFactory, HostAddressPort proxy, const SafeString& authID, const SafeString& authPassword) { + return ref(new HTTPConnectProxiedConnection(connectionFactory, proxy, authID, authPassword)); } virtual void listen(); @@ -38,7 +46,7 @@ namespace Swift { virtual HostAddressPort getLocalAddress() const; private: - HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy); + HTTPConnectProxiedConnection(ConnectionFactory* connectionFactory, HostAddressPort proxy, const SafeString& authID, const SafeString& authPassword); void handleConnectionConnectFinished(bool error); void handleDataRead(boost::shared_ptr<SafeByteArray> data); @@ -49,6 +57,8 @@ namespace Swift { ConnectionFactory* connectionFactory_; HostAddressPort proxy_; HostAddressPort server_; + SafeByteArray authID_; + SafeByteArray authPassword_; boost::shared_ptr<Connection> connection_; }; } diff --git a/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp b/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp index ab7f18e..6ad0228 100644 --- a/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp +++ b/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp @@ -10,11 +10,15 @@ namespace Swift { -HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy) : connectionFactory_(connectionFactory), proxy_(proxy) { +HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy) : connectionFactory_(connectionFactory), proxy_(proxy), authID_(""), authPassword_("") { +} + + +HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy, const SafeString& authID, const SafeString& authPassword) : connectionFactory_(connectionFactory), proxy_(proxy), authID_(authID), authPassword_(authPassword) { } boost::shared_ptr<Connection> HTTPConnectProxiedConnectionFactory::createConnection() { - return HTTPConnectProxiedConnection::create(connectionFactory_, proxy_); + return HTTPConnectProxiedConnection::create(connectionFactory_, proxy_, authID_, authPassword_); } } diff --git a/Swiften/Network/HTTPConnectProxiedConnectionFactory.h b/Swiften/Network/HTTPConnectProxiedConnectionFactory.h index b475586..ef3af66 100644 --- a/Swiften/Network/HTTPConnectProxiedConnectionFactory.h +++ b/Swiften/Network/HTTPConnectProxiedConnectionFactory.h @@ -8,16 +8,20 @@ #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Network/HostAddressPort.h> +#include <Swiften/Base/SafeString.h> namespace Swift { class HTTPConnectProxiedConnectionFactory : public ConnectionFactory { public: HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy); + HTTPConnectProxiedConnectionFactory(ConnectionFactory* connectionFactory, const HostAddressPort& proxy, const SafeString& authID, const SafeString& authPassword); virtual boost::shared_ptr<Connection> createConnection(); private: ConnectionFactory* connectionFactory_; HostAddressPort proxy_; + SafeString authID_; + SafeString authPassword_; }; } diff --git a/Swiften/Network/NetworkFactories.h b/Swiften/Network/NetworkFactories.h index 6eba2f3..ebb6d62 100644 --- a/Swiften/Network/NetworkFactories.h +++ b/Swiften/Network/NetworkFactories.h @@ -16,6 +16,7 @@ namespace Swift { class TLSContextFactory; class CertificateFactory; class ProxyProvider; + class EventLoop; /** * An interface collecting network factories. @@ -32,5 +33,6 @@ namespace Swift { virtual XMLParserFactory* getXMLParserFactory() const = 0; virtual TLSContextFactory* getTLSContextFactory() const = 0; virtual ProxyProvider* getProxyProvider() const = 0; + virtual EventLoop* getEventLoop() const {}; }; } diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index 399cec8..4a5370f 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -16,7 +16,8 @@ sourceList = [ "BoostConnectionServerFactory.cpp", "BoostIOServiceThread.cpp", "BOSHConnection.cpp", - "BOSHConnectionFactory.cpp" + "BOSHConnectionPool.cpp", + "BOSHConnectionFactory.cpp", "ConnectionFactory.cpp", "ConnectionServer.cpp", "ConnectionServerFactory.cpp", @@ -41,6 +42,8 @@ sourceList = [ "BoostNetworkFactories.cpp", "NetworkEnvironment.cpp", "Timer.cpp", + "TLSConnection.cpp", + "TLSConnectionFactory.cpp", "BoostTimer.cpp", "ProxyProvider.cpp", "NullProxyProvider.cpp", diff --git a/Swiften/Network/TLSConnection.cpp b/Swiften/Network/TLSConnection.cpp new file mode 100644 index 0000000..543ee1e --- /dev/null +++ b/Swiften/Network/TLSConnection.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <Swiften/Network/TLSConnection.h> + +#include <boost/bind.hpp> + +#include <Swiften/Network/HostAddressPort.h> +#include <Swiften/TLS/TLSContext.h> +#include <Swiften/TLS/TLSContextFactory.h> + +namespace Swift { + +TLSConnection::TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory) : connection(connection) { + context = tlsFactory->createTLSContext(); + context->onDataForNetwork.connect(boost::bind(&TLSConnection::handleTLSDataForNetwork, this, _1)); + context->onDataForApplication.connect(boost::bind(&TLSConnection::handleTLSDataForApplication, this, _1)); + context->onConnected.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, false)); + context->onError.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, true)); + + connection->onConnectFinished.connect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + connection->onDataRead.connect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); + connection->onDataWritten.connect(boost::bind(&TLSConnection::handleRawDataWritten, this)); + connection->onDisconnected.connect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); +} + +TLSConnection::~TLSConnection() { + connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + connection->onDataRead.disconnect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); + connection->onDataWritten.disconnect(boost::bind(&TLSConnection::handleRawDataWritten, this)); + connection->onDisconnected.disconnect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); + delete context; +} + +void TLSConnection::handleTLSConnectFinished(bool error) { + onConnectFinished(error); + if (error) { + disconnect(); + } +} + +void TLSConnection::handleTLSDataForNetwork(const SafeByteArray& data) { + connection->write(data); +} + +void TLSConnection::handleTLSDataForApplication(const SafeByteArray& data) { + onDataRead(boost::make_shared<SafeByteArray>(data)); +} + +void TLSConnection::connect(const HostAddressPort& address) { + connection->connect(address); +} + +void TLSConnection::disconnect() { + connection->disconnect(); +} + +void TLSConnection::write(const SafeByteArray& data) { + context->handleDataFromApplication(data); +} + +HostAddressPort TLSConnection::getLocalAddress() const { + return connection->getLocalAddress(); +} + +void TLSConnection::handleRawConnectFinished(bool error) { + connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + if (error) { + onConnectFinished(true); + } + else { + context->connect(); + } +} + +void TLSConnection::handleRawDisconnected(const boost::optional<Error>& error) { + onDisconnected(error); +} + +void TLSConnection::handleRawDataRead(boost::shared_ptr<SafeByteArray> data) { + context->handleDataFromNetwork(*data); +} + +void TLSConnection::handleRawDataWritten() { + onDataWritten(); +} + +} diff --git a/Swiften/Network/TLSConnection.h b/Swiften/Network/TLSConnection.h new file mode 100644 index 0000000..a798393 --- /dev/null +++ b/Swiften/Network/TLSConnection.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> +#include <Swiften/Base/boost_bsignals.h> + +#include <Swiften/Base/SafeByteArray.h> +#include <Swiften/Network/Connection.h> + +namespace Swift { + class HostAddressPort; + class TLSContextFactory; + class TLSContext; + + class TLSConnection : public Connection { + public: + + TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory); + virtual ~TLSConnection(); + + virtual void listen() {assert(false);}; + virtual void connect(const HostAddressPort& address); + virtual void disconnect(); + virtual void write(const SafeByteArray& data); + + virtual HostAddressPort getLocalAddress() const; + + private: + void handleRawConnectFinished(bool error); + void handleRawDisconnected(const boost::optional<Error>& error); + void handleRawDataRead(boost::shared_ptr<SafeByteArray> data); + void handleRawDataWritten(); + void handleTLSConnectFinished(bool error); + void handleTLSDataForNetwork(const SafeByteArray& data); + void handleTLSDataForApplication(const SafeByteArray& data); + private: + TLSContext* context; + Connection::ref connection; + }; +} diff --git a/Swiften/Network/TLSConnectionFactory.cpp b/Swiften/Network/TLSConnectionFactory.cpp new file mode 100644 index 0000000..0c21650 --- /dev/null +++ b/Swiften/Network/TLSConnectionFactory.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <Swiften/Network/TLSConnectionFactory.h> + +#include <boost/shared_ptr.hpp> + +#include <Swiften/Network/TLSConnection.h> + +namespace Swift { + +TLSConnectionFactory::TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory) : contextFactory(contextFactory), connectionFactory(connectionFactory){ + +} + +TLSConnectionFactory::~TLSConnectionFactory() { + +} + + +boost::shared_ptr<Connection> TLSConnectionFactory::createConnection() { + return boost::make_shared<TLSConnection>(connectionFactory->createConnection(), contextFactory); +} + +} diff --git a/Swiften/Network/TLSConnectionFactory.h b/Swiften/Network/TLSConnectionFactory.h new file mode 100644 index 0000000..32757a1 --- /dev/null +++ b/Swiften/Network/TLSConnectionFactory.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <boost/shared_ptr.hpp> + +#include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/TLS/TLSContextFactory.h> + +namespace Swift { + class Connection; + + class TLSConnectionFactory : public ConnectionFactory { + public: + TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory); + virtual ~TLSConnectionFactory(); + + virtual boost::shared_ptr<Connection> createConnection(); + private: + TLSContextFactory* contextFactory; + ConnectionFactory* connectionFactory; + }; +} diff --git a/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp b/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp new file mode 100644 index 0000000..978bf3b --- /dev/null +++ b/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp @@ -0,0 +1,423 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <QA/Checker/IO.h> + +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include <boost/optional.hpp> +#include <boost/bind.hpp> +#include <boost/smart_ptr/make_shared.hpp> +#include <boost/shared_ptr.hpp> +#include <boost/lexical_cast.hpp> + +#include <Swiften/Base/Algorithm.h> +#include <Swiften/Network/Connection.h> +#include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/BOSHConnection.h> +#include <Swiften/Network/BOSHConnectionFactory.h> +#include <Swiften/Network/BOSHConnectionPool.h> +#include <Swiften/Network/HostAddressPort.h> +#include <Swiften/EventLoop/DummyEventLoop.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> + +using namespace Swift; + +typedef boost::shared_ptr<BOSHConnectionPool> PoolRef; + +class BOSHConnectionPoolTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(BOSHConnectionPoolTest); + CPPUNIT_TEST(testConnectionCount_OneWrite); + CPPUNIT_TEST(testConnectionCount_TwoWrites); + CPPUNIT_TEST(testConnectionCount_ThreeWrites); + CPPUNIT_TEST(testConnectionCount_ThreeWrites_ManualConnect); + CPPUNIT_TEST(testConnectionCount_ThreeWritesTwoReads); + CPPUNIT_TEST(testSession); + CPPUNIT_TEST(testWrite_Empty); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + to = "wonderland.lit"; + path = "http-bind"; + port = "5280"; + sid = "MyShinySID"; + initial = "<body wait='60' " + "inactivity='30' " + "polling='5' " + "requests='2' " + "hold='1' " + "maxpause='120' " + "sid='" + sid + "' " + "ver='1.6' " + "from='wonderland.lit' " + "xmlns='http://jabber.org/protocol/httpbind'/>"; + eventLoop = new DummyEventLoop(); + connectionFactory = new MockConnectionFactory(eventLoop); + factory = boost::make_shared<BOSHConnectionFactory>(URL("http", to, 5280, path), connectionFactory, &parserFactory, static_cast<TLSContextFactory*>(NULL)); + sessionTerminated = 0; + sessionStarted = 0; + initialRID = 2349876; + xmppDataRead.clear(); + boshDataRead.clear(); + boshDataWritten.clear(); + } + + void tearDown() { + eventLoop->processEvents(); + delete connectionFactory; + delete eventLoop; + } + + void testConnectionCount_OneWrite() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(0, sessionStarted); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(1, sessionStarted); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + testling->write(createSafeByteArray("<blah/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(1, sessionStarted); + } + + void testConnectionCount_TwoWrites() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + readResponse(initial, connectionFactory->connections[0]); + testling->write(createSafeByteArray("<blah/>")); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + testling->write(createSafeByteArray("<bleh/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWrites() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + readResponse(initial, connectionFactory->connections[0]); + testling->restartStream(); + readResponse("<body/>", connectionFactory->connections[0]); + testling->restartStream(); + readResponse("<body/>", connectionFactory->connections[0]); + testling->write(createSafeByteArray("<blah/>")); + testling->write(createSafeByteArray("<bleh/>")); + testling->write(createSafeByteArray("<bluh/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWrites_ManualConnect() { + connectionFactory->autoFinishConnect = false; + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(0), boshDataWritten.size()); /* Connection not connected yet, can't send data */ + + connectionFactory->connections[0]->onConnectFinished(false); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Connection finished, stream header sent */ + + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Don't respond to initial data with a holding call */ + + testling->restartStream(); + readResponse("<body/>", connectionFactory->connections[0]); + testling->restartStream(); + + + testling->write(createSafeByteArray("<blah/>")); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); /* New connection isn't up yet. */ + + connectionFactory->connections[1]->onConnectFinished(false); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New connection ready. */ + + testling->write(createSafeByteArray("<bleh/>")); + testling->write(createSafeByteArray("<bluh/>")); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New data can't be sent, no free connections. */ + eventLoop->processEvents(); + CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWritesTwoReads() { + boost::shared_ptr<MockConnection> c0; + boost::shared_ptr<MockConnection> c1; + long rid = initialRID; + + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + c0 = connectionFactory->connections[0]; + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ + + rid++; + readResponse(initial, c0); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT(!c0->pending); + + rid++; + testling->restartStream(); + readResponse("<body/>", connectionFactory->connections[0]); + + rid++; + testling->write(createSafeByteArray("<blah/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); /* 0 was waiting for response, open and send on 1 */ + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* data */ + c1 = connectionFactory->connections[1]; + std::string fullBody = "<body rid='" + boost::lexical_cast<std::string>(rid) + "' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'><blah/></body>"; /* check empty write */ + CPPUNIT_ASSERT_EQUAL(fullBody, lastBody()); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + + + rid++; + readResponse("<body xmlns='http://jabber.org/protocol/httpbind'><message><splatploing/></message></body>", c0); /* Doesn't include necessary attributes - as the support is improved this'll start to fail */ + eventLoop->processEvents(); + CPPUNIT_ASSERT(!c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* don't send empty in [0], still have [1] waiting */ + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + + rid++; + readResponse("<body xmlns='http://jabber.org/protocol/httpbind'><message><splatploing><blittlebarg/></splatploing></message></body>", c1); + eventLoop->processEvents(); + CPPUNIT_ASSERT(!c1->pending); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT_EQUAL(st(5), boshDataWritten.size()); /* empty to make room */ + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + + rid++; + testling->write(createSafeByteArray("<bleh/>")); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* data */ + + rid++; + testling->write(createSafeByteArray("<bluh/>")); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* Don't send data, no room */ + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + } + + void testSession() { + to = "prosody.doomsong.co.uk"; + path = "http-bind/"; + factory = boost::make_shared<BOSHConnectionFactory>(URL("http", to, 5280, path), connectionFactory, &parserFactory, static_cast<TLSContextFactory*>(NULL)); + + + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + std::string response = "<body authid='743da605-4c2e-4de1-afac-ac040dd4a940' xmpp:version='1.0' xmlns:stream='http://etherx.jabber.org/streams' xmlns:xmpp='urn:xmpp:xbosh' inactivity='60' wait='60' polling='5' secure='true' hold='1' from='prosody.doomsong.co.uk' ver='1.6' sid='743da605-4c2e-4de1-afac-ac040dd4a940' requests='2' xmlns='http://jabber.org/protocol/httpbind'><stream:features><auth xmlns='http://jabber.org/features/iq-auth'/><mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl'><mechanism>SCRAM-SHA-1</mechanism><mechanism>DIGEST-MD5</mechanism></mechanisms></stream:features></body>"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + std::string send = "<auth xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\" mechanism=\"SCRAM-SHA-1\">biwsbj1hZG1pbixyPWZhOWE5ZDhiLWZmMDctNGE4Yy04N2E3LTg4YWRiNDQxZGUwYg==</auth>"; + testling->write(createSafeByteArray(send)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + response = "<body xmlns='http://jabber.org/protocol/httpbind' sid='743da605-4c2e-4de1-afac-ac040dd4a940' xmlns:stream = 'http://etherx.jabber.org/streams'><challenge xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>cj1mYTlhOWQ4Yi1mZjA3LTRhOGMtODdhNy04OGFkYjQ0MWRlMGJhZmZlMWNhMy1mMDJkLTQ5NzEtYjkyNS0yM2NlNWQ2MDQyMjYscz1OVGd5WkdWaFptTXRaVE15WXkwMFpXUmhMV0ZqTURRdFpqYzRNbUppWmpGa1pqWXgsaT00MDk2</challenge></body>"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + send = "<response xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\">Yz1iaXdzLHI9ZmE5YTlkOGItZmYwNy00YThjLTg3YTctODhhZGI0NDFkZTBiYWZmZTFjYTMtZjAyZC00OTcxLWI5MjUtMjNjZTVkNjA0MjI2LHA9aU11NWt3dDN2VWplU2RqL01Jb3VIRldkZjBnPQ==</response>"; + testling->write(createSafeByteArray(send)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + response = "<body xmlns='http://jabber.org/protocol/httpbind' sid='743da605-4c2e-4de1-afac-ac040dd4a940' xmlns:stream = 'http://etherx.jabber.org/streams'><success xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>dj1YNmNBY3BBOWxHNjNOOXF2bVQ5S0FacERrVm89</success></body>"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + } + + void testWrite_Empty() { + boost::shared_ptr<MockConnection> c0; + + PoolRef testling = createTestling(); + c0 = connectionFactory->connections[0]; + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + + readResponse(initial, c0); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Shouldn't have sent anything extra */ + testling->restartStream(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + readResponse("<body></body>", c0); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + std::string fullBody = "<body rid='" + boost::lexical_cast<std::string>(initialRID + 2) + "' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'></body>"; + std::string response = boshDataWritten[2]; + size_t bodyPosition = response.find("\r\n\r\n"); + CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); + + + } + + private: + + PoolRef createTestling() { + PoolRef pool = boost::make_shared<BOSHConnectionPool>(factory, to, initialRID, URL(), "", ""); + pool->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleXMPPDataRead, this, _1)); + pool->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataRead, this, _1)); + pool->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataWritten, this, _1)); + pool->onSessionStarted.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionStarted, this)); + pool->onSessionTerminated.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionTerminated, this)); + return pool; + } + + std::string lastBody() { + std::string response = boshDataWritten[boshDataWritten.size() - 1]; + size_t bodyPosition = response.find("\r\n\r\n"); + return response.substr(bodyPosition+4); + } + + size_t st(int val) { + return static_cast<size_t>(val); + } + + void handleXMPPDataRead(const SafeByteArray& d) { + xmppDataRead.push_back(safeByteArrayToString(d)); + } + + void handleBOSHDataRead(const SafeByteArray& d) { + boshDataRead.push_back(safeByteArrayToString(d)); + } + + void handleBOSHDataWritten(const SafeByteArray& d) { + boshDataWritten.push_back(safeByteArrayToString(d)); + } + + + void handleSessionStarted() { + sessionStarted++; + } + + void handleSessionTerminated() { + sessionTerminated++; + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop, bool autoFinishConnect) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false), pending(false), autoFinishConnect(autoFinishConnect) { + } + + void listen() { assert(false); } + + void connect(const HostAddressPort& address) { + hostAddressPort = address; + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + if (autoFinishConnect) { + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + + void disconnect() { + disconnected = true; + onDisconnected(boost::optional<Connection::Error>()); + } + + void write(const SafeByteArray& d) { + append(dataWritten, d); + pending = true; + } + + EventLoop* eventLoop; + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + ByteArray dataWritten; + bool disconnected; + bool pending; + bool autoFinishConnect; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop, bool autoFinishConnect = true) : eventLoop(eventLoop), autoFinishConnect(autoFinishConnect) { + } + + boost::shared_ptr<Connection> createConnection() { + boost::shared_ptr<MockConnection> connection = boost::make_shared<MockConnection>(failingPorts, eventLoop, autoFinishConnect); + connections.push_back(connection); + return connection; + } + + EventLoop* eventLoop; + std::vector< boost::shared_ptr<MockConnection> > connections; + std::vector<HostAddressPort> failingPorts; + bool autoFinishConnect; + }; + + void readResponse(const std::string& response, boost::shared_ptr<MockConnection> connection) { + connection->pending = false; + boost::shared_ptr<SafeByteArray> data1 = boost::make_shared<SafeByteArray>(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: ")); + connection->onDataRead(data1); + boost::shared_ptr<SafeByteArray> data2 = boost::make_shared<SafeByteArray>(createSafeByteArray(boost::lexical_cast<std::string>(response.size()))); + connection->onDataRead(data2); + boost::shared_ptr<SafeByteArray> data3 = boost::make_shared<SafeByteArray>(createSafeByteArray("\r\n\r\n")); + connection->onDataRead(data3); + boost::shared_ptr<SafeByteArray> data4 = boost::make_shared<SafeByteArray>(createSafeByteArray(response)); + connection->onDataRead(data4); + } + + std::string fullRequestFor(const std::string& data) { + std::string body = data; + std::string result = "POST /" + path + " HTTP/1.1\r\n" + + "Host: " + to + ":" + port + "\r\n" + + "Content-Type: text/xml; charset=utf-8\r\n" + + "Content-Length: " + boost::lexical_cast<std::string>(body.size()) + "\r\n\r\n" + + body; + return result; + } + + private: + DummyEventLoop* eventLoop; + MockConnectionFactory* connectionFactory; + std::vector<std::string> xmppDataRead; + std::vector<std::string> boshDataRead; + std::vector<std::string> boshDataWritten; + PlatformXMLParserFactory parserFactory; + std::string to; + std::string path; + std::string port; + std::string sid; + std::string initial; + long initialRID; + boost::shared_ptr<BOSHConnectionFactory> factory; + int sessionStarted; + int sessionTerminated; + +}; + + +CPPUNIT_TEST_SUITE_REGISTRATION(BOSHConnectionPoolTest); diff --git a/Swiften/Network/UnitTest/BOSHConnectionTest.cpp b/Swiften/Network/UnitTest/BOSHConnectionTest.cpp new file mode 100644 index 0000000..9215725 --- /dev/null +++ b/Swiften/Network/UnitTest/BOSHConnectionTest.cpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2011 Kevin Smith + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <QA/Checker/IO.h> + +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include <boost/optional.hpp> +#include <boost/bind.hpp> +#include <boost/smart_ptr/make_shared.hpp> +#include <boost/shared_ptr.hpp> +#include <boost/lexical_cast.hpp> + +#include <Swiften/Base/Algorithm.h> +#include <Swiften/Network/Connection.h> +#include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/BOSHConnection.h> +#include <Swiften/Network/HostAddressPort.h> +#include <Swiften/EventLoop/DummyEventLoop.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> + +using namespace Swift; + +class BOSHConnectionTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(BOSHConnectionTest); + CPPUNIT_TEST(testHeader); + CPPUNIT_TEST(testReadiness_ok); + CPPUNIT_TEST(testReadiness_pending); + CPPUNIT_TEST(testReadiness_disconnect); + CPPUNIT_TEST(testReadiness_noSID); + CPPUNIT_TEST(testWrite_Receive); + CPPUNIT_TEST(testWrite_ReceiveTwice); + CPPUNIT_TEST(testRead_Fragment); + CPPUNIT_TEST(testHTTPRequest); + CPPUNIT_TEST(testHTTPRequest_Empty); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + eventLoop = new DummyEventLoop(); + connectionFactory = new MockConnectionFactory(eventLoop); + connectFinished = false; + disconnected = false; + dataRead.clear(); + } + + void tearDown() { + eventLoop->processEvents(); + delete connectionFactory; + delete eventLoop; + } + + void testHeader() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->startStream("wonderland.lit", 1); + std::string initial("<body wait='60' " + "inactivity='30' " + "polling='5' " + "requests='2' " + "hold='1' " + "maxpause='120' " + "sid='MyShinySID' " + "ver='1.6' " + "from='wonderland.lit' " + "xmlns='http://jabber.org/protocol/httpbind'/>"); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("MyShinySID"), sid); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_ok() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("blahhhhh"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_pending() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + testling->write(createSafeByteArray("<mypayload/>")); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + readResponse("<body><blah/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_disconnect() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + connectionFactory->connections[0]->onDisconnected(false); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + } + + + void testReadiness_noSID() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + } + + void testWrite_Receive() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + testling->write(createSafeByteArray("<mypayload/>")); + readResponse("<body><blah/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); + + } + + void testWrite_ReceiveTwice() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + testling->setSID("mySID"); + testling->write(createSafeByteArray("<mypayload/>")); + readResponse("<body><blah/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); + dataRead.clear(); + testling->write(createSafeByteArray("<mypayload2/>")); + readResponse("<body><bleh/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("<bleh/>"), byteArrayToString(dataRead)); + } + + void testRead_Fragment() { + BOSHConnection::ref testling = createTestling(); + testling->connect(HostAddressPort(HostAddress("127.0.0.1"), 5280)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(1), connectionFactory->connections.size()); + boost::shared_ptr<MockConnection> connection = connectionFactory->connections[0]; + boost::shared_ptr<SafeByteArray> data1 = boost::make_shared<SafeByteArray>(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: 64\r\n")); + boost::shared_ptr<SafeByteArray> data2 = boost::make_shared<SafeByteArray>(createSafeByteArray( + "\r\n<body xmlns='http://jabber.org/protocol/httpbind'>" + "<bl")); + boost::shared_ptr<SafeByteArray> data3 = boost::make_shared<SafeByteArray>(createSafeByteArray( + "ah/>" + "</body>")); + connection->onDataRead(data1); + connection->onDataRead(data2); + CPPUNIT_ASSERT(dataRead.empty()); + connection->onDataRead(data3); + CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); + } + + void testHTTPRequest() { + std::string data = "<blah/>"; + std::string sid = "wigglebloom"; + std::string fullBody = "<body xmlns='http://jabber.org/protocol/httpbind' sid='" + sid + "' rid='20'>" + data + "</body>"; + std::pair<SafeByteArray, size_t> http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 20, sid, URL()); + CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); + } + + void testHTTPRequest_Empty() { + std::string data = ""; + std::string sid = "wigglebloomsickle"; + std::string fullBody = "<body rid='42' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'>" + data + "</body>"; + std::pair<SafeByteArray, size_t> http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 42, sid, URL()); + CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); + std::string response = safeByteArrayToString(http.first); + size_t bodyPosition = response.find("\r\n\r\n"); + CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); + } + + private: + + BOSHConnection::ref createTestling() { + BOSHConnection::ref c = BOSHConnection::create(URL("http", "wonderland.lit", 5280, "http-bind"), connectionFactory, &parserFactory, static_cast<TLSContextFactory*>(NULL)); + c->onConnectFinished.connect(boost::bind(&BOSHConnectionTest::handleConnectFinished, this, _1)); + c->onDisconnected.connect(boost::bind(&BOSHConnectionTest::handleDisconnected, this, _1)); + c->onXMPPDataRead.connect(boost::bind(&BOSHConnectionTest::handleDataRead, this, _1)); + c->onSessionStarted.connect(boost::bind(&BOSHConnectionTest::handleSID, this, _1)); + c->setRID(42); + return c; + } + + void handleConnectFinished(bool error) { + connectFinished = true; + connectFinishedWithError = error; + } + + void handleDisconnected(const boost::optional<Connection::Error>& e) { + disconnected = true; + disconnectedError = e; + } + + void handleDataRead(const SafeByteArray& d) { + append(dataRead, d); + } + + void handleSID(const std::string& s) { + sid = s; + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false) { + } + + void listen() { assert(false); } + + void connect(const HostAddressPort& address) { + hostAddressPort = address; + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + + void disconnect() { + disconnected = true; + onDisconnected(boost::optional<Connection::Error>()); + } + + void write(const SafeByteArray& d) { + append(dataWritten, d); + } + + EventLoop* eventLoop; + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + ByteArray dataWritten; + bool disconnected; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop) { + } + + boost::shared_ptr<Connection> createConnection() { + boost::shared_ptr<MockConnection> connection = boost::make_shared<MockConnection>(failingPorts, eventLoop); + connections.push_back(connection); + return connection; + } + + EventLoop* eventLoop; + std::vector< boost::shared_ptr<MockConnection> > connections; + std::vector<HostAddressPort> failingPorts; + }; + + void readResponse(const std::string& response, boost::shared_ptr<MockConnection> connection) { + boost::shared_ptr<SafeByteArray> data1 = boost::make_shared<SafeByteArray>(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: ")); + connection->onDataRead(data1); + boost::shared_ptr<SafeByteArray> data2 = boost::make_shared<SafeByteArray>(createSafeByteArray(boost::lexical_cast<std::string>(response.size()))); + connection->onDataRead(data2); + boost::shared_ptr<SafeByteArray> data3 = boost::make_shared<SafeByteArray>(createSafeByteArray("\r\n\r\n")); + connection->onDataRead(data3); + boost::shared_ptr<SafeByteArray> data4 = boost::make_shared<SafeByteArray>(createSafeByteArray(response)); + connection->onDataRead(data4); + } + + + private: + DummyEventLoop* eventLoop; + MockConnectionFactory* connectionFactory; + bool connectFinished; + bool connectFinishedWithError; + bool disconnected; + boost::optional<Connection::Error> disconnectedError; + ByteArray dataRead; + PlatformXMLParserFactory parserFactory; + std::string sid; + +}; + + +CPPUNIT_TEST_SUITE_REGISTRATION(BOSHConnectionTest); diff --git a/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp b/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp index 133773f..c0252d4 100644 --- a/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp +++ b/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp @@ -163,7 +163,7 @@ class HTTPConnectProxiedConnectionTest : public CppUnit::TestFixture { private: HTTPConnectProxiedConnection::ref createTestling() { - boost::shared_ptr<HTTPConnectProxiedConnection> c = HTTPConnectProxiedConnection::create(connectionFactory, proxyHost); + boost::shared_ptr<HTTPConnectProxiedConnection> c = HTTPConnectProxiedConnection::create(connectionFactory, proxyHost, "", ""); c->onConnectFinished.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleConnectFinished, this, _1)); c->onDisconnected.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDisconnected, this, _1)); c->onDataRead.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDataRead, this, _1)); |