diff options
Diffstat (limited to 'Swiften/Network')
138 files changed, 7131 insertions, 5610 deletions
diff --git a/Swiften/Network/BOSHConnection.cpp b/Swiften/Network/BOSHConnection.cpp index 23772eb..1312a3e 100644 --- a/Swiften/Network/BOSHConnection.cpp +++ b/Swiften/Network/BOSHConnection.cpp @@ -5,295 +5,392 @@ */ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BOSHConnection.h> +#include <string> +#include <thread> + #include <boost/bind.hpp> -#include <boost/thread.hpp> #include <boost/lexical_cast.hpp> -#include <string> +#include <Swiften/Base/ByteArray.h> +#include <Swiften/Base/Concat.h> #include <Swiften/Base/Log.h> #include <Swiften/Base/String.h> -#include <Swiften/Base/Concat.h> -#include <Swiften/Base/ByteArray.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Parser/BOSHBodyExtractor.h> +#include <Swiften/StreamStack/DummyStreamLayer.h> +#include <Swiften/StreamStack/TLSLayer.h> +#include <Swiften/TLS/TLSContext.h> +#include <Swiften/TLS/TLSContextFactory.h> +#include <Swiften/TLS/TLSOptions.h> namespace Swift { -BOSHConnection::BOSHConnection(const URL& boshURL, Connector::ref connector, XMLParserFactory* parserFactory) - : boshURL_(boshURL), - connector_(connector), - parserFactory_(parserFactory), - sid_(), - waitingForStartResponse_(false), - pending_(false), - connectionReady_(false) +BOSHConnection::BOSHConnection(const URL& boshURL, Connector::ref connector, XMLParserFactory* parserFactory, TLSContextFactory* tlsContextFactory, const TLSOptions& tlsOptions) + : boshURL_(boshURL), + connector_(connector), + parserFactory_(parserFactory), + sid_(), + waitingForStartResponse_(false), + rid_(~0ULL), + pending_(false), + connectionReady_(false) { + if (boshURL_.getScheme() == "https") { + auto tlsContext = tlsContextFactory->createTLSContext(tlsOptions); + tlsLayer_ = std::make_shared<TLSLayer>(std::move(tlsContext)); + // The following dummyLayer_ is needed as the TLSLayer will pass the decrypted data to its parent layer. + // The dummyLayer_ will serve as the parent layer. + dummyLayer_ = std::make_shared<DummyStreamLayer>(tlsLayer_.get()); + } } BOSHConnection::~BOSHConnection() { - cancelConnector(); - if (connection_) { - connection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - connection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - } - disconnect(); + cancelConnector(); + if (connection_) { + connection_->onDataRead.disconnect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.disconnect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); + } + BOSHConnection::disconnect(); } void BOSHConnection::connect() { - connector_->onConnectFinished.connect(boost::bind(&BOSHConnection::handleConnectFinished, shared_from_this(), _1)); - connector_->start(); + connector_->onConnectFinished.connect(boost::bind(&BOSHConnection::handleConnectFinished, shared_from_this(), _1)); + connector_->start(); } void BOSHConnection::cancelConnector() { - if (connector_) { - connector_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectFinished, shared_from_this(), _1)); - connector_->stop(); - connector_.reset(); - } + if (connector_) { + connector_->onConnectFinished.disconnect(boost::bind(&BOSHConnection::handleConnectFinished, shared_from_this(), _1)); + connector_->stop(); + connector_.reset(); + } +} + +void BOSHConnection::handleTLSConnected() { + SWIFT_LOG(debug); + onConnectFinished(false); +} + +void BOSHConnection::handleTLSApplicationDataRead(const SafeByteArray& data) { + SWIFT_LOG(debug); + handleDataRead(std::make_shared<SafeByteArray>(data)); +} + +void BOSHConnection::handleTLSNetowrkDataWriteRequest(const SafeByteArray& data) { + SWIFT_LOG(debug); + connection_->write(data); +} + +void BOSHConnection::handleRawDataRead(std::shared_ptr<SafeByteArray> data) { + SWIFT_LOG(debug); + tlsLayer_->handleDataRead(*data.get()); +} + +void BOSHConnection::handleTLSError(std::shared_ptr<TLSError> error) { + SWIFT_LOG(debug) << (error ? error->getMessage() : "Unknown TLS error"); +} + +void BOSHConnection::writeData(const SafeByteArray& data) { + if (tlsLayer_) { + tlsLayer_->writeData(data); + } + else { + connection_->write(data); + } } void BOSHConnection::disconnect() { - if (connection_) { - connection_->disconnect(); - sid_ = ""; - } - else { - /* handleDisconnected takes care of the connector_ as well */ - handleDisconnected(boost::optional<Connection::Error>()); - } + if (connection_) { + connection_->disconnect(); + sid_ = ""; + } + else { + /* handleDisconnected takes care of the connector_ as well */ + handleDisconnected(boost::optional<Connection::Error>()); + } } void BOSHConnection::restartStream() { - write(createSafeByteArray(""), true, false); + write(createSafeByteArray(""), true, false); +} + +bool BOSHConnection::setClientCertificate(CertificateWithKey::ref cert) { + if (tlsLayer_) { + SWIFT_LOG(debug) << "set client certificate"; + return tlsLayer_->setClientCertificate(cert); + } + else { + return false; + } +} + +Certificate::ref BOSHConnection::getPeerCertificate() const { + Certificate::ref peerCertificate; + if (tlsLayer_) { + peerCertificate = tlsLayer_->getPeerCertificate(); + } + return peerCertificate; +} + +std::vector<Certificate::ref> BOSHConnection::getPeerCertificateChain() const { + std::vector<Certificate::ref> peerCertificateChain; + if (tlsLayer_) { + peerCertificateChain = tlsLayer_->getPeerCertificateChain(); + } + return peerCertificateChain; +} + +CertificateVerificationError::ref BOSHConnection::getPeerCertificateVerificationError() const { + CertificateVerificationError::ref verificationError; + if (tlsLayer_) { + verificationError = tlsLayer_->getPeerCertificateVerificationError(); + } + return verificationError; } void BOSHConnection::terminateStream() { - write(createSafeByteArray(""), false, true); + write(createSafeByteArray(""), false, true); } void BOSHConnection::write(const SafeByteArray& data) { - write(data, false, false); + write(data, false, false); } std::pair<SafeByteArray, size_t> BOSHConnection::createHTTPRequest(const SafeByteArray& data, bool streamRestart, bool terminate, unsigned long long rid, const std::string& sid, const URL& boshURL) { - size_t size; - std::stringstream content; - SafeByteArray contentTail = createSafeByteArray("</body>"); - std::stringstream header; - - content << "<body rid='" << rid << "' sid='" << sid << "'"; - if (streamRestart) { - content << " xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'"; - } - if (terminate) { - content << " type='terminate'"; - } - content << " xmlns='http://jabber.org/protocol/httpbind'>"; - - SafeByteArray safeContent = createSafeByteArray(content.str()); - safeContent.insert(safeContent.end(), data.begin(), data.end()); - safeContent.insert(safeContent.end(), contentTail.begin(), contentTail.end()); - - size = safeContent.size(); - - header << "POST " << boshURL.getPath() << " HTTP/1.1\r\n" - << "Host: " << boshURL.getHost(); - if (boshURL.getPort()) { - header << ":" << *boshURL.getPort(); - } - header << "\r\n" - // << "Accept-Encoding: deflate\r\n" - << "Content-Type: text/xml; charset=utf-8\r\n" - << "Content-Length: " << size << "\r\n\r\n"; - - SafeByteArray safeHeader = createSafeByteArray(header.str()); - safeHeader.insert(safeHeader.end(), safeContent.begin(), safeContent.end()); - - return std::pair<SafeByteArray, size_t>(safeHeader, size); + size_t size; + std::stringstream content; + SafeByteArray contentTail = createSafeByteArray("</body>"); + std::stringstream header; + + content << "<body rid='" << rid << "' sid='" << sid << "'"; + if (streamRestart) { + content << " xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'"; + } + if (terminate) { + content << " type='terminate'"; + } + content << " xmlns='http://jabber.org/protocol/httpbind'>"; + + SafeByteArray safeContent = createSafeByteArray(content.str()); + safeContent.insert(safeContent.end(), data.begin(), data.end()); + safeContent.insert(safeContent.end(), contentTail.begin(), contentTail.end()); + + size = safeContent.size(); + + header << "POST " << boshURL.getPath() << " HTTP/1.1\r\n" + << "Host: " << boshURL.getHost(); + if (boshURL.getPort()) { + header << ":" << *boshURL.getPort(); + } + header << "\r\n" + // << "Accept-Encoding: deflate\r\n" + << "Content-Type: text/xml; charset=utf-8\r\n" + << "Content-Length: " << size << "\r\n\r\n"; + + SafeByteArray safeHeader = createSafeByteArray(header.str()); + safeHeader.insert(safeHeader.end(), safeContent.begin(), safeContent.end()); + + return std::pair<SafeByteArray, size_t>(safeHeader, size); } void BOSHConnection::write(const SafeByteArray& data, bool streamRestart, bool terminate) { - assert(connectionReady_); - assert(!sid_.empty()); + assert(connectionReady_); + assert(!sid_.empty()); - SafeByteArray safeHeader = createHTTPRequest(data, streamRestart, terminate, rid_, sid_, boshURL_).first; + SafeByteArray safeHeader = createHTTPRequest(data, streamRestart, terminate, rid_, sid_, boshURL_).first; - onBOSHDataWritten(safeHeader); - connection_->write(safeHeader); - pending_ = true; + onBOSHDataWritten(safeHeader); + writeData(safeHeader); + pending_ = true; - SWIFT_LOG(debug) << "write data: " << safeByteArrayToString(safeHeader) << std::endl; + SWIFT_LOG(debug) << "write data: " << safeByteArrayToString(safeHeader); } void BOSHConnection::handleConnectFinished(Connection::ref connection) { - cancelConnector(); - connectionReady_ = connection; - if (connectionReady_) { - connection_ = connection; - connection_->onDataRead.connect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); - connection_->onDisconnected.connect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); - } - onConnectFinished(!connectionReady_); + cancelConnector(); + connectionReady_ = !!connection; + if (connectionReady_) { + connection_ = connection; + if (tlsLayer_) { + connection_->onDataRead.connect(boost::bind(&BOSHConnection::handleRawDataRead, shared_from_this(), _1)); + connection_->onDisconnected.connect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); + + tlsLayer_->getContext()->onDataForNetwork.connect(boost::bind(&BOSHConnection::handleTLSNetowrkDataWriteRequest, shared_from_this(), _1)); + tlsLayer_->getContext()->onDataForApplication.connect(boost::bind(&BOSHConnection::handleTLSApplicationDataRead, shared_from_this(), _1)); + tlsLayer_->onConnected.connect(boost::bind(&BOSHConnection::handleTLSConnected, shared_from_this())); + tlsLayer_->onError.connect(boost::bind(&BOSHConnection::handleTLSError, shared_from_this(), _1)); + tlsLayer_->connect(); + } + else { + connection_->onDataRead.connect(boost::bind(&BOSHConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.connect(boost::bind(&BOSHConnection::handleDisconnected, shared_from_this(), _1)); + } + } + + if (!connectionReady_ || !tlsLayer_) { + onConnectFinished(!connectionReady_); + } } void BOSHConnection::startStream(const std::string& to, unsigned long long rid) { - assert(connectionReady_); - // Session Creation Request - std::stringstream content; - std::stringstream header; - - content << "<body content='text/xml; charset=utf-8'" - << " hold='1'" - << " to='" << to << "'" - << " rid='" << rid << "'" - << " ver='1.6'" - << " wait='60'" /* FIXME: we probably want this configurable*/ - // << " ack='0'" FIXME: support acks - << " xml:lang='en'" - << " xmlns:xmpp='urn:xmpp:bosh'" - << " xmpp:version='1.0'" - << " xmlns='http://jabber.org/protocol/httpbind' />"; - - std::string contentString = content.str(); - - header << "POST " << boshURL_.getPath() << " HTTP/1.1\r\n" - << "Host: " << boshURL_.getHost(); - if (boshURL_.getPort()) { - header << ":" << *boshURL_.getPort(); - } - header << "\r\n" - // << "Accept-Encoding: deflate\r\n" - << "Content-Type: text/xml; charset=utf-8\r\n" - << "Content-Length: " << contentString.size() << "\r\n\r\n" - << contentString; - - waitingForStartResponse_ = true; - SafeByteArray safeHeader = createSafeByteArray(header.str()); - onBOSHDataWritten(safeHeader); - connection_->write(safeHeader); - SWIFT_LOG(debug) << "write stream header: " << safeByteArrayToString(safeHeader) << std::endl; + assert(connectionReady_); + // Session Creation Request + std::stringstream content; + std::stringstream header; + + content << "<body content='text/xml; charset=utf-8'" + << " hold='1'" + << " to='" << to << "'" + << " rid='" << rid << "'" + << " ver='1.6'" + << " wait='60'" /* FIXME: we probably want this configurable*/ + // << " ack='0'" FIXME: support acks + << " xml:lang='en'" + << " xmlns:xmpp='urn:xmpp:bosh'" + << " xmpp:version='1.0'" + << " xmlns='http://jabber.org/protocol/httpbind' />"; + + std::string contentString = content.str(); + + header << "POST " << boshURL_.getPath() << " HTTP/1.1\r\n" + << "Host: " << boshURL_.getHost(); + if (boshURL_.getPort()) { + header << ":" << *boshURL_.getPort(); + } + header << "\r\n" + // << "Accept-Encoding: deflate\r\n" + << "Content-Type: text/xml; charset=utf-8\r\n" + << "Content-Length: " << contentString.size() << "\r\n\r\n" + << contentString; + + waitingForStartResponse_ = true; + SafeByteArray safeHeader = createSafeByteArray(header.str()); + onBOSHDataWritten(safeHeader); + writeData(safeHeader); + SWIFT_LOG(debug) << "write stream header: " << safeByteArrayToString(safeHeader); } -void BOSHConnection::handleDataRead(boost::shared_ptr<SafeByteArray> data) { - onBOSHDataRead(*data); - buffer_ = concat(buffer_, *data); - std::string response = safeByteArrayToString(buffer_); - if (response.find("\r\n\r\n") == std::string::npos) { - onBOSHDataRead(createSafeByteArray("[[Previous read incomplete, pending]]")); - return; - } - - std::string httpCode = response.substr(response.find(" ") + 1, 3); - if (httpCode != "200") { - onHTTPError(httpCode); - return; - } - - BOSHBodyExtractor parser(parserFactory_, createByteArray(response.substr(response.find("\r\n\r\n") + 4))); - if (parser.getBody()) { - if (parser.getBody()->attributes.getAttribute("type") == "terminate") { - BOSHError::Type errorType = parseTerminationCondition(parser.getBody()->attributes.getAttribute("condition")); - onSessionTerminated(errorType == BOSHError::NoError ? boost::shared_ptr<BOSHError>() : boost::make_shared<BOSHError>(errorType)); - } - buffer_.clear(); - if (waitingForStartResponse_) { - waitingForStartResponse_ = false; - sid_ = parser.getBody()->attributes.getAttribute("sid"); - std::string requestsString = parser.getBody()->attributes.getAttribute("requests"); - size_t requests = 2; - if (!requestsString.empty()) { - try { - requests = boost::lexical_cast<size_t>(requestsString); - } - catch (const boost::bad_lexical_cast&) { - } - } - onSessionStarted(sid_, requests); - } - SafeByteArray payload = createSafeByteArray(parser.getBody()->content); - /* Say we're good to go again, so don't add anything after here in the method */ - pending_ = false; - onXMPPDataRead(payload); - } +void BOSHConnection::handleDataRead(std::shared_ptr<SafeByteArray> data) { + onBOSHDataRead(*data); + buffer_ = concat(buffer_, *data); + std::string response = safeByteArrayToString(buffer_); + if (response.find("\r\n\r\n") == std::string::npos) { + onBOSHDataRead(createSafeByteArray("[[Previous read incomplete, pending]]")); + return; + } + + std::string httpCode = response.substr(response.find(" ") + 1, 3); + if (httpCode != "200") { + onHTTPError(httpCode); + return; + } + + BOSHBodyExtractor parser(parserFactory_, createByteArray(response.substr(response.find("\r\n\r\n") + 4))); + if (parser.getBody()) { + if (parser.getBody()->attributes.getAttribute("type") == "terminate") { + BOSHError::Type errorType = parseTerminationCondition(parser.getBody()->attributes.getAttribute("condition")); + onSessionTerminated(errorType == BOSHError::NoError ? std::shared_ptr<BOSHError>() : std::make_shared<BOSHError>(errorType)); + return; + } + buffer_.clear(); + if (waitingForStartResponse_) { + waitingForStartResponse_ = false; + sid_ = parser.getBody()->attributes.getAttribute("sid"); + std::string requestsString = parser.getBody()->attributes.getAttribute("requests"); + size_t requests = 2; + if (!requestsString.empty()) { + try { + requests = boost::lexical_cast<size_t>(requestsString); + } + catch (const boost::bad_lexical_cast&) { + } + } + onSessionStarted(sid_, requests); + } + SafeByteArray payload = createSafeByteArray(parser.getBody()->content); + /* Say we're good to go again, so don't add anything after here in the method */ + pending_ = false; + onXMPPDataRead(payload); + } } BOSHError::Type BOSHConnection::parseTerminationCondition(const std::string& text) { - BOSHError::Type condition = BOSHError::UndefinedCondition; - if (text == "bad-request") { - condition = BOSHError::BadRequest; - } - else if (text == "host-gone") { - condition = BOSHError::HostGone; - } - else if (text == "host-unknown") { - condition = BOSHError::HostUnknown; - } - else if (text == "improper-addressing") { - condition = BOSHError::ImproperAddressing; - } - else if (text == "internal-server-error") { - condition = BOSHError::InternalServerError; - } - else if (text == "item-not-found") { - condition = BOSHError::ItemNotFound; - } - else if (text == "other-request") { - condition = BOSHError::OtherRequest; - } - else if (text == "policy-violation") { - condition = BOSHError::PolicyViolation; - } - else if (text == "remote-connection-failed") { - condition = BOSHError::RemoteConnectionFailed; - } - else if (text == "remote-stream-error") { - condition = BOSHError::RemoteStreamError; - } - else if (text == "see-other-uri") { - condition = BOSHError::SeeOtherURI; - } - else if (text == "system-shutdown") { - condition = BOSHError::SystemShutdown; - } - else if (text == "") { - condition = BOSHError::NoError; - } - return condition; + BOSHError::Type condition = BOSHError::UndefinedCondition; + if (text == "bad-request") { + condition = BOSHError::BadRequest; + } + else if (text == "host-gone") { + condition = BOSHError::HostGone; + } + else if (text == "host-unknown") { + condition = BOSHError::HostUnknown; + } + else if (text == "improper-addressing") { + condition = BOSHError::ImproperAddressing; + } + else if (text == "internal-server-error") { + condition = BOSHError::InternalServerError; + } + else if (text == "item-not-found") { + condition = BOSHError::ItemNotFound; + } + else if (text == "other-request") { + condition = BOSHError::OtherRequest; + } + else if (text == "policy-violation") { + condition = BOSHError::PolicyViolation; + } + else if (text == "remote-connection-failed") { + condition = BOSHError::RemoteConnectionFailed; + } + else if (text == "remote-stream-error") { + condition = BOSHError::RemoteStreamError; + } + else if (text == "see-other-uri") { + condition = BOSHError::SeeOtherURI; + } + else if (text == "system-shutdown") { + condition = BOSHError::SystemShutdown; + } + else if (text == "") { + condition = BOSHError::NoError; + } + return condition; } const std::string& BOSHConnection::getSID() { - return sid_; + return sid_; } void BOSHConnection::setRID(unsigned long long rid) { - rid_ = rid; + rid_ = rid; } void BOSHConnection::setSID(const std::string& sid) { - sid_ = sid; + sid_ = sid; } void BOSHConnection::handleDisconnected(const boost::optional<Connection::Error>& error) { - cancelConnector(); - onDisconnected(error); - sid_ = ""; - connectionReady_ = false; + cancelConnector(); + onDisconnected(error ? true : false); + sid_ = ""; + connectionReady_ = false; } bool BOSHConnection::isReadyToSend() { - /* Without pipelining you need to not send more without first receiving the response */ - /* With pipelining you can. Assuming we can't, here */ - return connectionReady_ && !pending_ && !waitingForStartResponse_ && !sid_.empty(); + /* Without pipelining you need to not send more without first receiving the response */ + /* With pipelining you can. Assuming we can't, here */ + return connectionReady_ && !pending_ && !waitingForStartResponse_ && !sid_.empty(); } } diff --git a/Swiften/Network/BOSHConnection.h b/Swiften/Network/BOSHConnection.h index 01341cc..f0a946a 100644 --- a/Swiften/Network/BOSHConnection.h +++ b/Swiften/Network/BOSHConnection.h @@ -5,104 +5,115 @@ */ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/enable_shared_from_this.hpp> +#include <memory> #include <Swiften/Base/API.h> +#include <Swiften/Base/Error.h> +#include <Swiften/Base/String.h> +#include <Swiften/Base/URL.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/Connector.h> #include <Swiften/Network/HostAddressPort.h> -#include <Swiften/Base/String.h> -#include <Swiften/Base/URL.h> -#include <Swiften/Base/Error.h> #include <Swiften/Session/SessionStream.h> - -namespace boost { - class thread; - namespace system { - class error_code; - } -} +#include <Swiften/TLS/TLSError.h> class BOSHConnectionTest; namespace Swift { - class XMLParserFactory; - class TLSContextFactory; - - class SWIFTEN_API BOSHError : public SessionStream::SessionStreamError { - public: - enum Type {BadRequest, HostGone, HostUnknown, ImproperAddressing, - InternalServerError, ItemNotFound, OtherRequest, PolicyViolation, - RemoteConnectionFailed, RemoteStreamError, SeeOtherURI, SystemShutdown, UndefinedCondition, - NoError}; - BOSHError(Type type) : SessionStream::SessionStreamError(SessionStream::SessionStreamError::ConnectionReadError), type(type) {} - Type getType() {return type;} - typedef boost::shared_ptr<BOSHError> ref; - private: - Type type; - - }; - - - class SWIFTEN_API BOSHConnection : public boost::enable_shared_from_this<BOSHConnection> { - public: - typedef boost::shared_ptr<BOSHConnection> ref; - static ref create(const URL& boshURL, Connector::ref connector, XMLParserFactory* parserFactory) { - return ref(new BOSHConnection(boshURL, connector, parserFactory)); - } - virtual ~BOSHConnection(); - virtual void connect(); - virtual void disconnect(); - virtual void write(const SafeByteArray& data); - - const std::string& getSID(); - void setRID(unsigned long long rid); - void setSID(const std::string& sid); - void startStream(const std::string& to, unsigned long long rid); - void terminateStream(); - bool isReadyToSend(); - void restartStream(); - - - boost::signal<void (bool /* error */)> onConnectFinished; - boost::signal<void (bool /* error */)> onDisconnected; - boost::signal<void (BOSHError::ref)> onSessionTerminated; - boost::signal<void (const std::string& /*sid*/, size_t /*requests*/)> onSessionStarted; - boost::signal<void (const SafeByteArray&)> onXMPPDataRead; - boost::signal<void (const SafeByteArray&)> onBOSHDataRead; - boost::signal<void (const SafeByteArray&)> onBOSHDataWritten; - boost::signal<void (const std::string&)> onHTTPError; - - private: - friend class ::BOSHConnectionTest; - - BOSHConnection(const URL& boshURL, Connector::ref connector, XMLParserFactory* parserFactory); - - static std::pair<SafeByteArray, size_t> createHTTPRequest(const SafeByteArray& data, bool streamRestart, bool terminate, unsigned long long rid, const std::string& sid, const URL& boshURL); - void handleConnectFinished(Connection::ref); - void handleDataRead(boost::shared_ptr<SafeByteArray> data); - void handleDisconnected(const boost::optional<Connection::Error>& error); - void write(const SafeByteArray& data, bool streamRestart, bool terminate); /* FIXME: refactor */ - BOSHError::Type parseTerminationCondition(const std::string& text); - void cancelConnector(); - - URL boshURL_; - Connector::ref connector_; - XMLParserFactory* parserFactory_; - boost::shared_ptr<Connection> connection_; - std::string sid_; - bool waitingForStartResponse_; - unsigned long long rid_; - SafeByteArray buffer_; - bool pending_; - bool connectionReady_; - }; + class XMLParserFactory; + class TLSContextFactory; + class TLSLayer; + class TLSOptions; + class HighLayer; + + class SWIFTEN_API BOSHError : public SessionStream::SessionStreamError { + public: + enum Type { + BadRequest, HostGone, HostUnknown, ImproperAddressing, + InternalServerError, ItemNotFound, OtherRequest, PolicyViolation, + RemoteConnectionFailed, RemoteStreamError, SeeOtherURI, SystemShutdown, UndefinedCondition, + NoError}; + + BOSHError(Type type) : SessionStream::SessionStreamError(SessionStream::SessionStreamError::ConnectionReadError), type(type) {} + Type getType() const {return type;} + typedef std::shared_ptr<BOSHError> ref; + + private: + Type type; + }; + + class SWIFTEN_API BOSHConnection : public std::enable_shared_from_this<BOSHConnection> { + public: + typedef std::shared_ptr<BOSHConnection> ref; + static ref create(const URL& boshURL, Connector::ref connector, XMLParserFactory* parserFactory, TLSContextFactory* tlsContextFactory, const TLSOptions& tlsOptions) { + return ref(new BOSHConnection(boshURL, connector, parserFactory, tlsContextFactory, tlsOptions)); + } + virtual ~BOSHConnection(); + virtual void connect(); + virtual void disconnect(); + virtual void write(const SafeByteArray& data); + + const std::string& getSID(); + void setRID(unsigned long long rid); + void setSID(const std::string& sid); + void startStream(const std::string& to, unsigned long long rid); + void terminateStream(); + bool isReadyToSend(); + void restartStream(); + + bool setClientCertificate(CertificateWithKey::ref cert); + Certificate::ref getPeerCertificate() const; + std::vector<Certificate::ref> getPeerCertificateChain() const; + CertificateVerificationError::ref getPeerCertificateVerificationError() const; + + boost::signals2::signal<void (bool /* error */)> onConnectFinished; + boost::signals2::signal<void (bool /* error */)> onDisconnected; + boost::signals2::signal<void (BOSHError::ref)> onSessionTerminated; + boost::signals2::signal<void (const std::string& /*sid*/, size_t /*requests*/)> onSessionStarted; + boost::signals2::signal<void (const SafeByteArray&)> onXMPPDataRead; + boost::signals2::signal<void (const SafeByteArray&)> onBOSHDataRead; + boost::signals2::signal<void (const SafeByteArray&)> onBOSHDataWritten; + boost::signals2::signal<void (const std::string&)> onHTTPError; + + private: + friend class ::BOSHConnectionTest; + + BOSHConnection(const URL& boshURL, Connector::ref connector, XMLParserFactory* parserFactory, TLSContextFactory* tlsContextFactory, const TLSOptions& tlsOptions); + + static std::pair<SafeByteArray, size_t> createHTTPRequest(const SafeByteArray& data, bool streamRestart, bool terminate, unsigned long long rid, const std::string& sid, const URL& boshURL); + void handleConnectFinished(Connection::ref); + void handleDataRead(std::shared_ptr<SafeByteArray> data); + void handleDisconnected(const boost::optional<Connection::Error>& error); + void write(const SafeByteArray& data, bool streamRestart, bool terminate); /* FIXME: refactor */ + BOSHError::Type parseTerminationCondition(const std::string& text); + void cancelConnector(); + + void handleTLSConnected(); + void handleTLSApplicationDataRead(const SafeByteArray& data); + void handleTLSNetowrkDataWriteRequest(const SafeByteArray& data); + void handleRawDataRead(std::shared_ptr<SafeByteArray> data); + void handleTLSError(std::shared_ptr<TLSError> error); + void writeData(const SafeByteArray& data); + + URL boshURL_; + Connector::ref connector_; + XMLParserFactory* parserFactory_; + std::shared_ptr<Connection> connection_; + std::shared_ptr<HighLayer> dummyLayer_; + std::shared_ptr<TLSLayer> tlsLayer_; + std::string sid_; + bool waitingForStartResponse_; + unsigned long long rid_; + SafeByteArray buffer_; + bool pending_; + bool connectionReady_; + }; } diff --git a/Swiften/Network/BOSHConnectionPool.cpp b/Swiften/Network/BOSHConnectionPool.cpp index 4517ffb..3a79a16 100644 --- a/Swiften/Network/BOSHConnectionPool.cpp +++ b/Swiften/Network/BOSHConnectionPool.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BOSHConnectionPool.h> @@ -10,258 +10,292 @@ #include <boost/bind.hpp> #include <boost/lexical_cast.hpp> -#include <Swiften/Base/foreach.h> +#include <Swiften/Base/Log.h> #include <Swiften/Base/SafeString.h> -#include <Swiften/Network/TLSConnectionFactory.h> -#include <Swiften/Network/HTTPConnectProxiedConnectionFactory.h> #include <Swiften/Network/CachingDomainNameResolver.h> +#include <Swiften/Network/HTTPConnectProxiedConnectionFactory.h> namespace Swift { -BOSHConnectionPool::BOSHConnectionPool(const URL& boshURL, DomainNameResolver* realResolver, ConnectionFactory* connectionFactoryParameter, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory, TimerFactory* timerFactory, EventLoop* eventLoop, const std::string& to, unsigned long long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword) : - boshURL(boshURL), - connectionFactory(connectionFactoryParameter), - xmlParserFactory(parserFactory), - timerFactory(timerFactory), - rid(initialRID), - pendingTerminate(false), - to(to), - requestLimit(2), - restartCount(0), - pendingRestart(false) { - - if (!boshHTTPConnectProxyURL.isEmpty()) { - if (boshHTTPConnectProxyURL.getScheme() == "https") { - connectionFactory = new TLSConnectionFactory(tlsFactory, connectionFactory); - myConnectionFactories.push_back(connectionFactory); - } - connectionFactory = new HTTPConnectProxiedConnectionFactory(realResolver, connectionFactory, timerFactory, boshHTTPConnectProxyURL.getHost(), URL::getPortOrDefaultPort(boshHTTPConnectProxyURL), boshHTTPConnectProxyAuthID, boshHTTPConnectProxyAuthPassword); - } - if (boshURL.getScheme() == "https") { - connectionFactory = new TLSConnectionFactory(tlsFactory, connectionFactory); - myConnectionFactories.push_back(connectionFactory); - } - resolver = new CachingDomainNameResolver(realResolver, eventLoop); - createConnection(); +BOSHConnectionPool::BOSHConnectionPool(const URL& boshURL, DomainNameResolver* realResolver, ConnectionFactory* connectionFactoryParameter, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory, TimerFactory* timerFactory, EventLoop* eventLoop, const std::string& to, unsigned long long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword, const TLSOptions& tlsOptions, std::shared_ptr<HTTPTrafficFilter> trafficFilter) : + boshURL(boshURL), + connectionFactory(connectionFactoryParameter), + xmlParserFactory(parserFactory), + timerFactory(timerFactory), + rid(initialRID), + pendingTerminate(false), + to(to), + requestLimit(2), + restartCount(0), + pendingRestart(false), + tlsContextFactory_(tlsFactory), + tlsOptions_(tlsOptions) { + + if (!boshHTTPConnectProxyURL.isEmpty()) { + connectionFactory = new HTTPConnectProxiedConnectionFactory(realResolver, connectionFactory, timerFactory, boshHTTPConnectProxyURL.getHost(), URL::getPortOrDefaultPort(boshHTTPConnectProxyURL), boshHTTPConnectProxyAuthID, boshHTTPConnectProxyAuthPassword, trafficFilter); + } + resolver = new CachingDomainNameResolver(realResolver, eventLoop); } BOSHConnectionPool::~BOSHConnectionPool() { - /* Don't do a normal close here. Instead kill things forcibly, as close() or writeFooter() will already have been called */ - std::vector<BOSHConnection::ref> connectionCopies = connections; - foreach (BOSHConnection::ref connection, connectionCopies) { - if (connection) { - destroyConnection(connection); - connection->disconnect(); - } - } - foreach (ConnectionFactory* factory, myConnectionFactories) { - delete factory; - } - delete resolver; + /* Don't do a normal close here. Instead kill things forcibly, as close() or writeFooter() will already have been called */ + std::vector<BOSHConnection::ref> connectionCopies = connections; + for (auto&& connection : connectionCopies) { + if (connection) { + destroyConnection(connection); + connection->disconnect(); + } + } + for (auto factory : myConnectionFactories) { + delete factory; + } + delete resolver; } void BOSHConnectionPool::write(const SafeByteArray& data) { - dataQueue.push_back(data); - tryToSendQueuedData(); + dataQueue.push_back(data); + tryToSendQueuedData(); } void BOSHConnectionPool::handleDataRead(const SafeByteArray& data) { - onXMPPDataRead(data); - tryToSendQueuedData(); /* Will rebalance the connections */ + onXMPPDataRead(data); + tryToSendQueuedData(); /* Will rebalance the connections */ } void BOSHConnectionPool::restartStream() { - BOSHConnection::ref connection = getSuitableConnection(); - if (connection) { - pendingRestart = false; - rid++; - connection->setRID(rid); - connection->restartStream(); - restartCount++; - } - else { - pendingRestart = true; - } + BOSHConnection::ref connection = getSuitableConnection(); + if (connection) { + pendingRestart = false; + rid++; + connection->setRID(rid); + connection->restartStream(); + restartCount++; + } + else { + pendingRestart = true; + } +} + +void BOSHConnectionPool::setTLSCertificate(CertificateWithKey::ref certWithKey) { + clientCertificate = certWithKey; +} + +bool BOSHConnectionPool::isTLSEncrypted() const { + return !pinnedCertificateChain_.empty(); +} + +Certificate::ref BOSHConnectionPool::getPeerCertificate() const { + Certificate::ref peerCertificate; + if (!pinnedCertificateChain_.empty()) { + peerCertificate = pinnedCertificateChain_[0]; + } + return peerCertificate; +} + +std::vector<Certificate::ref> BOSHConnectionPool::getPeerCertificateChain() const { + return pinnedCertificateChain_; +} + +std::shared_ptr<CertificateVerificationError> BOSHConnectionPool::getPeerCertificateVerificationError() const { + return lastVerificationError_; } void BOSHConnectionPool::writeFooter() { - pendingTerminate = true; - tryToSendQueuedData(); + pendingTerminate = true; + tryToSendQueuedData(); +} + +void BOSHConnectionPool::open() { + createConnection(); } void BOSHConnectionPool::close() { - if (!sid.empty()) { - writeFooter(); - } - else { - pendingTerminate = true; - std::vector<BOSHConnection::ref> connectionCopies = connections; - foreach (BOSHConnection::ref connection, connectionCopies) { - if (connection) { - connection->disconnect(); - } - } - } + if (!sid.empty()) { + writeFooter(); + } + else { + pendingTerminate = true; + std::vector<BOSHConnection::ref> connectionCopies = connections; + for (auto&& connection : connectionCopies) { + if (connection) { + connection->disconnect(); + } + } + } } void BOSHConnectionPool::handleSessionStarted(const std::string& sessionID, size_t requests) { - sid = sessionID; - requestLimit = requests; - onSessionStarted(); + sid = sessionID; + requestLimit = requests; + onSessionStarted(); } void BOSHConnectionPool::handleConnectFinished(bool error, BOSHConnection::ref connection) { - if (error) { - onSessionTerminated(boost::make_shared<BOSHError>(BOSHError::UndefinedCondition)); - /*TODO: We can probably manage to not terminate the stream here and use the rid/ack retry - * logic to just swallow the error and try again (some number of tries). - */ - } - else { - if (sid.empty()) { - connection->startStream(to, rid); - } - if (pendingRestart) { - restartStream(); - } - tryToSendQueuedData(); - } + if (error) { + onSessionTerminated(std::make_shared<BOSHError>(BOSHError::UndefinedCondition)); + /*TODO: We can probably manage to not terminate the stream here and use the rid/ack retry + * logic to just swallow the error and try again (some number of tries). + */ + } + else { + if (connection->getPeerCertificate() && pinnedCertificateChain_.empty()) { + pinnedCertificateChain_ = connection->getPeerCertificateChain(); + } + if (!pinnedCertificateChain_.empty()) { + lastVerificationError_ = connection->getPeerCertificateVerificationError(); + onTLSConnectionEstablished(); + } + + if (sid.empty()) { + connection->startStream(to, rid); + } + if (pendingRestart) { + restartStream(); + } + tryToSendQueuedData(); + } } BOSHConnection::ref BOSHConnectionPool::getSuitableConnection() { - BOSHConnection::ref suitableConnection; - foreach (BOSHConnection::ref connection, connections) { - if (connection->isReadyToSend()) { - suitableConnection = connection; - break; - } - } - - if (!suitableConnection && connections.size() < requestLimit) { - /* This is not a suitable connection because it won't have yet connected and added TLS if needed. */ - BOSHConnection::ref newConnection = createConnection(); - newConnection->setSID(sid); - } - assert(connections.size() <= requestLimit); - assert((!suitableConnection) || suitableConnection->isReadyToSend()); - return suitableConnection; + BOSHConnection::ref suitableConnection; + for (auto&& connection : connections) { + if (connection->isReadyToSend()) { + suitableConnection = connection; + break; + } + } + + if (!suitableConnection && connections.size() < requestLimit) { + /* This is not a suitable connection because it won't have yet connected and added TLS if needed. */ + BOSHConnection::ref newConnection = createConnection(); + newConnection->setSID(sid); + } + assert(connections.size() <= requestLimit); + assert((!suitableConnection) || suitableConnection->isReadyToSend()); + return suitableConnection; } void BOSHConnectionPool::tryToSendQueuedData() { - if (sid.empty()) { - /* If we've not got as far as stream start yet, pend */ - return; - } - - BOSHConnection::ref suitableConnection = getSuitableConnection(); - bool toSend = !dataQueue.empty(); - if (suitableConnection) { - if (toSend) { - rid++; - suitableConnection->setRID(rid); - SafeByteArray data; - foreach (const SafeByteArray& datum, dataQueue) { - data.insert(data.end(), datum.begin(), datum.end()); - } - suitableConnection->write(data); - dataQueue.clear(); - } - else if (pendingTerminate) { - rid++; - suitableConnection->setRID(rid); - suitableConnection->terminateStream(); - sid = ""; - close(); - } - } - if (!pendingTerminate) { - /* Ensure there's always a session waiting to read data for us */ - bool pending = false; - foreach (BOSHConnection::ref connection, connections) { - if (connection && !connection->isReadyToSend()) { - pending = true; - } - } - if (!pending) { - if (restartCount >= 1) { - /* Don't open a second connection until we've restarted the stream twice - i.e. we've authed and resource bound.*/ - if (suitableConnection) { - rid++; - suitableConnection->setRID(rid); - suitableConnection->write(createSafeByteArray("")); - } - else { - /* My thought process I went through when writing this, to aid anyone else confused why this can happen... - * - * What to do here? I think this isn't possible. - If you didn't have two connections, suitable would have made one. - If you have two connections and neither is suitable, pending would be true. - If you have a non-pending connection, it's suitable. - - If I decide to do something here, remove assert above. - - Ah! Yes, because there's a period between creating the connection and it being connected. */ - } - } - } - } + if (sid.empty()) { + /* If we've not got as far as stream start yet, pend */ + return; + } + + BOSHConnection::ref suitableConnection = getSuitableConnection(); + bool toSend = !dataQueue.empty(); + if (suitableConnection) { + if (toSend) { + rid++; + suitableConnection->setRID(rid); + SafeByteArray data; + for (const auto& datum : dataQueue) { + data.insert(data.end(), datum.begin(), datum.end()); + } + suitableConnection->write(data); + dataQueue.clear(); + } + else if (pendingTerminate) { + rid++; + suitableConnection->setRID(rid); + suitableConnection->terminateStream(); + sid = ""; + close(); + } + } + if (!pendingTerminate) { + /* Ensure there's always a session waiting to read data for us */ + bool pending = false; + for (auto&& connection : connections) { + if (connection && !connection->isReadyToSend()) { + pending = true; + } + } + if (!pending) { + if (restartCount >= 1) { + /* Don't open a second connection until we've restarted the stream twice - i.e. we've authed and resource bound.*/ + if (suitableConnection) { + rid++; + suitableConnection->setRID(rid); + suitableConnection->write(createSafeByteArray("")); + } + else { + /* My thought process I went through when writing this, to aid anyone else confused why this can happen... + * + * What to do here? I think this isn't possible. + If you didn't have two connections, suitable would have made one. + If you have two connections and neither is suitable, pending would be true. + If you have a non-pending connection, it's suitable. + + If I decide to do something here, remove assert above. + + Ah! Yes, because there's a period between creating the connection and it being connected. */ + } + } + } + } } void BOSHConnectionPool::handleHTTPError(const std::string& /*errorCode*/) { - handleSessionTerminated(boost::make_shared<BOSHError>(BOSHError::UndefinedCondition)); + handleSessionTerminated(std::make_shared<BOSHError>(BOSHError::UndefinedCondition)); } void BOSHConnectionPool::handleConnectionDisconnected(bool/* error*/, BOSHConnection::ref connection) { - destroyConnection(connection); - if (pendingTerminate && sid.empty() && connections.empty()) { - handleSessionTerminated(BOSHError::ref()); - } - //else if (error) { - // handleSessionTerminated(boost::make_shared<BOSHError>(BOSHError::UndefinedCondition)); - //} - else { - /* We might have just freed up a connection slot to send with */ - tryToSendQueuedData(); - } + destroyConnection(connection); + if (pendingTerminate && sid.empty() && connections.empty()) { + handleSessionTerminated(BOSHError::ref()); + } + //else if (error) { + // handleSessionTerminated(std::make_shared<BOSHError>(BOSHError::UndefinedCondition)); + //} + else { + /* We might have just freed up a connection slot to send with */ + tryToSendQueuedData(); + } } -boost::shared_ptr<BOSHConnection> BOSHConnectionPool::createConnection() { - Connector::ref connector = Connector::create(boshURL.getHost(), URL::getPortOrDefaultPort(boshURL), false, resolver, connectionFactory, timerFactory); - BOSHConnection::ref connection = BOSHConnection::create(boshURL, connector, xmlParserFactory); - connection->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); - connection->onSessionStarted.connect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); - connection->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); - connection->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); - connection->onDisconnected.connect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); - connection->onConnectFinished.connect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); - connection->onSessionTerminated.connect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); - connection->onHTTPError.connect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); - connection->connect(); - connections.push_back(connection); - return connection; +std::shared_ptr<BOSHConnection> BOSHConnectionPool::createConnection() { + Connector::ref connector = Connector::create(boshURL.getHost(), URL::getPortOrDefaultPort(boshURL), boost::optional<std::string>(), resolver, connectionFactory, timerFactory); + BOSHConnection::ref connection = BOSHConnection::create(boshURL, connector, xmlParserFactory, tlsContextFactory_, tlsOptions_); + connection->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); + connection->onSessionStarted.connect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); + connection->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); + connection->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); + connection->onDisconnected.connect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); + connection->onConnectFinished.connect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); + connection->onSessionTerminated.connect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); + connection->onHTTPError.connect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); + + if (boshURL.getScheme() == "https") { + bool success = connection->setClientCertificate(clientCertificate); + SWIFT_LOG(debug) << "setClientCertificate, success: " << success; + } + + connection->connect(); + connections.push_back(connection); + return connection; } -void BOSHConnectionPool::destroyConnection(boost::shared_ptr<BOSHConnection> connection) { - connections.erase(std::remove(connections.begin(), connections.end(), connection), connections.end()); - connection->onXMPPDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); - connection->onSessionStarted.disconnect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); - connection->onBOSHDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); - connection->onBOSHDataWritten.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); - connection->onDisconnected.disconnect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); - connection->onConnectFinished.disconnect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); - connection->onSessionTerminated.disconnect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); - connection->onHTTPError.disconnect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); +void BOSHConnectionPool::destroyConnection(std::shared_ptr<BOSHConnection> connection) { + connections.erase(std::remove(connections.begin(), connections.end(), connection), connections.end()); + connection->onXMPPDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleDataRead, this, _1)); + connection->onSessionStarted.disconnect(boost::bind(&BOSHConnectionPool::handleSessionStarted, this, _1, _2)); + connection->onBOSHDataRead.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataRead, this, _1)); + connection->onBOSHDataWritten.disconnect(boost::bind(&BOSHConnectionPool::handleBOSHDataWritten, this, _1)); + connection->onDisconnected.disconnect(boost::bind(&BOSHConnectionPool::handleConnectionDisconnected, this, _1, connection)); + connection->onConnectFinished.disconnect(boost::bind(&BOSHConnectionPool::handleConnectFinished, this, _1, connection)); + connection->onSessionTerminated.disconnect(boost::bind(&BOSHConnectionPool::handleSessionTerminated, this, _1)); + connection->onHTTPError.disconnect(boost::bind(&BOSHConnectionPool::handleHTTPError, this, _1)); } void BOSHConnectionPool::handleSessionTerminated(BOSHError::ref error) { - onSessionTerminated(error); + onSessionTerminated(error); } void BOSHConnectionPool::handleBOSHDataRead(const SafeByteArray& data) { - onBOSHDataRead(data); + onBOSHDataRead(data); } void BOSHConnectionPool::handleBOSHDataWritten(const SafeByteArray& data) { - onBOSHDataWritten(data); + onBOSHDataWritten(data); } } diff --git a/Swiften/Network/BOSHConnectionPool.h b/Swiften/Network/BOSHConnectionPool.h index de707e8..a6956fa 100644 --- a/Swiften/Network/BOSHConnectionPool.h +++ b/Swiften/Network/BOSHConnectionPool.h @@ -1,7 +1,7 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ @@ -12,59 +12,77 @@ #include <Swiften/Base/API.h> #include <Swiften/Base/SafeString.h> #include <Swiften/Network/BOSHConnection.h> +#include <Swiften/TLS/CertificateWithKey.h> +#include <Swiften/TLS/TLSOptions.h> namespace Swift { - class HTTPConnectProxiedConnectionFactory; - class TLSConnectionFactory; - class CachingDomainNameResolver; - class EventLoop; + class CachingDomainNameResolver; + class EventLoop; + class HTTPTrafficFilter; + class TLSContextFactory; + class CachingDomainNameResolver; + class EventLoop; - class SWIFTEN_API BOSHConnectionPool : public boost::bsignals::trackable { - public: - BOSHConnectionPool(const URL& boshURL, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory, TimerFactory* timerFactory, EventLoop* eventLoop, const std::string& to, unsigned long long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword); - ~BOSHConnectionPool(); - void write(const SafeByteArray& data); - void writeFooter(); - void close(); - void restartStream(); + class SWIFTEN_API BOSHConnectionPool : public boost::signals2::trackable { + public: + BOSHConnectionPool(const URL& boshURL, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, XMLParserFactory* parserFactory, TLSContextFactory* tlsFactory, TimerFactory* timerFactory, EventLoop* eventLoop, const std::string& to, unsigned long long initialRID, const URL& boshHTTPConnectProxyURL, const SafeString& boshHTTPConnectProxyAuthID, const SafeString& boshHTTPConnectProxyAuthPassword, const TLSOptions& tlsOptions, std::shared_ptr<HTTPTrafficFilter> trafficFilter = std::shared_ptr<HTTPTrafficFilter>()); + ~BOSHConnectionPool(); - boost::signal<void (BOSHError::ref)> onSessionTerminated; - boost::signal<void ()> onSessionStarted; - boost::signal<void (const SafeByteArray&)> onXMPPDataRead; - boost::signal<void (const SafeByteArray&)> onBOSHDataRead; - boost::signal<void (const SafeByteArray&)> onBOSHDataWritten; + void open(); + void write(const SafeByteArray& data); + void writeFooter(); + void close(); + void restartStream(); - private: - void handleDataRead(const SafeByteArray& data); - void handleSessionStarted(const std::string& sid, size_t requests); - void handleBOSHDataRead(const SafeByteArray& data); - void handleBOSHDataWritten(const SafeByteArray& data); - void handleSessionTerminated(BOSHError::ref condition); - void handleConnectFinished(bool, BOSHConnection::ref connection); - void handleConnectionDisconnected(bool error, BOSHConnection::ref connection); - void handleHTTPError(const std::string& errorCode); + void setTLSCertificate(CertificateWithKey::ref certWithKey); + bool isTLSEncrypted() const; + Certificate::ref getPeerCertificate() const; + std::vector<Certificate::ref> getPeerCertificateChain() const; + std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; - private: - BOSHConnection::ref createConnection(); - void destroyConnection(BOSHConnection::ref connection); - void tryToSendQueuedData(); - BOSHConnection::ref getSuitableConnection(); + boost::signals2::signal<void ()> onTLSConnectionEstablished; + boost::signals2::signal<void (BOSHError::ref)> onSessionTerminated; + boost::signals2::signal<void ()> onSessionStarted; + boost::signals2::signal<void (const SafeByteArray&)> onXMPPDataRead; + boost::signals2::signal<void (const SafeByteArray&)> onBOSHDataRead; + boost::signals2::signal<void (const SafeByteArray&)> onBOSHDataWritten; - private: - URL boshURL; - ConnectionFactory* connectionFactory; - XMLParserFactory* xmlParserFactory; - TimerFactory* timerFactory; - std::vector<BOSHConnection::ref> connections; - std::string sid; - unsigned long long rid; - std::vector<SafeByteArray> dataQueue; - bool pendingTerminate; - std::string to; - size_t requestLimit; - int restartCount; - bool pendingRestart; - std::vector<ConnectionFactory*> myConnectionFactories; - CachingDomainNameResolver* resolver; - }; + private: + void handleDataRead(const SafeByteArray& data); + void handleSessionStarted(const std::string& sid, size_t requests); + void handleBOSHDataRead(const SafeByteArray& data); + void handleBOSHDataWritten(const SafeByteArray& data); + void handleSessionTerminated(BOSHError::ref condition); + void handleConnectFinished(bool, BOSHConnection::ref connection); + void handleConnectionDisconnected(bool error, BOSHConnection::ref connection); + void handleHTTPError(const std::string& errorCode); + + private: + BOSHConnection::ref createConnection(); + void destroyConnection(BOSHConnection::ref connection); + void tryToSendQueuedData(); + BOSHConnection::ref getSuitableConnection(); + + private: + URL boshURL; + ConnectionFactory* connectionFactory; + XMLParserFactory* xmlParserFactory; + TimerFactory* timerFactory; + std::vector<BOSHConnection::ref> connections; + std::string sid; + unsigned long long rid; + std::vector<SafeByteArray> dataQueue; + bool pendingTerminate; + std::string to; + size_t requestLimit; + int restartCount; + bool pendingRestart; + std::vector<ConnectionFactory*> myConnectionFactories; + CachingDomainNameResolver* resolver; + CertificateWithKey::ref clientCertificate; + TLSContextFactory* tlsContextFactory_; + TLSOptions tlsOptions_; + std::vector<std::shared_ptr<Certificate> > pinnedCertificateChain_; + std::shared_ptr<CertificateVerificationError> lastVerificationError_; + }; } diff --git a/Swiften/Network/BoostConnection.cpp b/Swiften/Network/BoostConnection.cpp index 5137c3c..6ae6bf6 100644 --- a/Swiften/Network/BoostConnection.cpp +++ b/Swiften/Network/BoostConnection.cpp @@ -1,28 +1,27 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BoostConnection.h> -#include <iostream> -#include <string> #include <algorithm> -#include <boost/bind.hpp> -#include <boost/thread.hpp> +#include <memory> +#include <mutex> +#include <string> + #include <boost/asio/placeholders.hpp> #include <boost/asio/write.hpp> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/numeric/conversion/cast.hpp> +#include <boost/bind.hpp> -#include <Swiften/Base/Log.h> #include <Swiften/Base/Algorithm.h> -#include <Swiften/EventLoop/EventLoop.h> #include <Swiften/Base/ByteArray.h> -#include <Swiften/Network/HostAddressPort.h> -#include <Swiften/Base/sleep.h> +#include <Swiften/Base/Log.h> #include <Swiften/Base/SafeAllocator.h> +#include <Swiften/Base/sleep.h> +#include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/Network/HostAddressPort.h> namespace Swift { @@ -32,141 +31,147 @@ static const size_t BUFFER_SIZE = 4096; // A reference-counted non-modifiable buffer class. class SharedBuffer { - public: - SharedBuffer(const SafeByteArray& data) : - data_(new std::vector<char, SafeAllocator<char> >(data.begin(), data.end())), - buffer_(boost::asio::buffer(*data_)) { - } - - // ConstBufferSequence requirements. - typedef boost::asio::const_buffer value_type; - typedef const boost::asio::const_buffer* const_iterator; - const boost::asio::const_buffer* begin() const { return &buffer_; } - const boost::asio::const_buffer* end() const { return &buffer_ + 1; } - - private: - boost::shared_ptr< std::vector<char, SafeAllocator<char> > > data_; - boost::asio::const_buffer buffer_; + public: + SharedBuffer(const SafeByteArray& data) : + data_(new std::vector<char, SafeAllocator<char> >(data.begin(), data.end())), + buffer_(boost::asio::buffer(*data_)) { + } + + // ConstBufferSequence requirements. + typedef boost::asio::const_buffer value_type; + typedef const boost::asio::const_buffer* const_iterator; + const boost::asio::const_buffer* begin() const { return &buffer_; } + const boost::asio::const_buffer* end() const { return &buffer_ + 1; } + + private: + std::shared_ptr< std::vector<char, SafeAllocator<char> > > data_; + boost::asio::const_buffer buffer_; }; // ----------------------------------------------------------------------------- -BoostConnection::BoostConnection(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : - eventLoop(eventLoop), ioService(ioService), socket_(*ioService), writing_(false), closeSocketAfterNextWrite_(false) { +BoostConnection::BoostConnection(std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : + eventLoop(eventLoop), ioService(ioService), socket_(*ioService), writing_(false), closeSocketAfterNextWrite_(false) { } BoostConnection::~BoostConnection() { } void BoostConnection::listen() { - doRead(); + doRead(); } void BoostConnection::connect(const HostAddressPort& addressPort) { - boost::asio::ip::tcp::endpoint endpoint( - boost::asio::ip::address::from_string(addressPort.getAddress().toString()), boost::numeric_cast<unsigned short>(addressPort.getPort())); - socket_.async_connect( - endpoint, - boost::bind(&BoostConnection::handleConnectFinished, shared_from_this(), boost::asio::placeholders::error)); + boost::asio::ip::tcp::endpoint endpoint( + boost::asio::ip::address::from_string(addressPort.getAddress().toString()), addressPort.getPort()); + socket_.async_connect( + endpoint, + boost::bind(&BoostConnection::handleConnectFinished, shared_from_this(), boost::asio::placeholders::error)); } void BoostConnection::disconnect() { - //MainEventLoop::removeEventsFromOwner(shared_from_this()); - - // Mac OS X apparently exhibits a problem where closing a socket during a write could potentially go into uninterruptable sleep. - // See e.g. http://bugs.python.org/issue7401 - // We therefore wait until any pending write finishes, which hopefully should fix our hang on exit during close(). - boost::lock_guard<boost::mutex> lock(writeMutex_); - if (writing_) { - closeSocketAfterNextWrite_ = true; - } else { - closeSocket(); - } + //MainEventLoop::removeEventsFromOwner(shared_from_this()); + + // Mac OS X apparently exhibits a problem where closing a socket during a write could potentially go into uninterruptable sleep. + // See e.g. http://bugs.python.org/issue7401 + // We therefore wait until any pending write finishes, which hopefully should fix our hang on exit during close(). + std::lock_guard<std::mutex> lock(writeMutex_); + if (writing_) { + closeSocketAfterNextWrite_ = true; + } else { + closeSocket(); + } } void BoostConnection::closeSocket() { - boost::system::error_code errorCode; - socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_both, errorCode); - socket_.close(); + std::lock_guard<std::mutex> lock(readCloseMutex_); + boost::system::error_code errorCode; + socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_both, errorCode); + socket_.close(); } void BoostConnection::write(const SafeByteArray& data) { - boost::lock_guard<boost::mutex> lock(writeMutex_); - if (!writing_) { - writing_ = true; - doWrite(data); - } - else { - append(writeQueue_, data); - } + std::lock_guard<std::mutex> lock(writeMutex_); + if (!writing_) { + writing_ = true; + doWrite(data); + } + else { + append(writeQueue_, data); + } } void BoostConnection::doWrite(const SafeByteArray& data) { - boost::asio::async_write(socket_, SharedBuffer(data), - boost::bind(&BoostConnection::handleDataWritten, shared_from_this(), boost::asio::placeholders::error)); + boost::asio::async_write(socket_, SharedBuffer(data), + boost::bind(&BoostConnection::handleDataWritten, shared_from_this(), boost::asio::placeholders::error)); } void BoostConnection::handleConnectFinished(const boost::system::error_code& error) { - SWIFT_LOG(debug) << "Connect finished: " << error << std::endl; - if (!error) { - eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), false), shared_from_this()); - doRead(); - } - else if (error != boost::asio::error::operation_aborted) { - eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), true), shared_from_this()); - } + SWIFT_LOG(debug) << "Connect finished: " << error; + if (!error) { + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), false), shared_from_this()); + doRead(); + } + else if (error != boost::asio::error::operation_aborted) { + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), true), shared_from_this()); + } } void BoostConnection::doRead() { - readBuffer_ = boost::make_shared<SafeByteArray>(BUFFER_SIZE); - socket_.async_read_some( - boost::asio::buffer(*readBuffer_), - boost::bind(&BoostConnection::handleSocketRead, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + readBuffer_ = std::make_shared<SafeByteArray>(BUFFER_SIZE); + std::lock_guard<std::mutex> lock(readCloseMutex_); + socket_.async_read_some( + boost::asio::buffer(*readBuffer_), + boost::bind(&BoostConnection::handleSocketRead, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); } void BoostConnection::handleSocketRead(const boost::system::error_code& error, size_t bytesTransferred) { - SWIFT_LOG(debug) << "Socket read " << error << std::endl; - if (!error) { - readBuffer_->resize(bytesTransferred); - eventLoop->postEvent(boost::bind(boost::ref(onDataRead), readBuffer_), shared_from_this()); - doRead(); - } - else if (/*error == boost::asio::error::eof ||*/ error == boost::asio::error::operation_aborted) { - eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), boost::optional<Error>()), shared_from_this()); - } - else { - eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), ReadError), shared_from_this()); - } + SWIFT_LOG(debug) << "Socket read " << error; + if (!error) { + readBuffer_->resize(bytesTransferred); + eventLoop->postEvent(boost::bind(boost::ref(onDataRead), readBuffer_), shared_from_this()); + doRead(); + } + else if (/*error == boost::asio::error::eof ||*/ error == boost::asio::error::operation_aborted) { + eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), boost::optional<Error>()), shared_from_this()); + } + else { + eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), ReadError), shared_from_this()); + } } void BoostConnection::handleDataWritten(const boost::system::error_code& error) { - SWIFT_LOG(debug) << "Data written " << error << std::endl; - if (!error) { - eventLoop->postEvent(boost::ref(onDataWritten), shared_from_this()); - } - else if (/*error == boost::asio::error::eof || */error == boost::asio::error::operation_aborted) { - eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), boost::optional<Error>()), shared_from_this()); - } - else { - eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), WriteError), shared_from_this()); - } - { - boost::lock_guard<boost::mutex> lock(writeMutex_); - if (writeQueue_.empty()) { - writing_ = false; - if (closeSocketAfterNextWrite_) { - closeSocket(); - } - } - else { - doWrite(writeQueue_); - writeQueue_.clear(); - } - } + SWIFT_LOG(debug) << "Data written " << error; + if (!error) { + eventLoop->postEvent(boost::ref(onDataWritten), shared_from_this()); + } + else if (/*error == boost::asio::error::eof || */error == boost::asio::error::operation_aborted) { + eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), boost::optional<Error>()), shared_from_this()); + } + else { + eventLoop->postEvent(boost::bind(boost::ref(onDisconnected), WriteError), shared_from_this()); + } + { + std::lock_guard<std::mutex> lock(writeMutex_); + if (writeQueue_.empty()) { + writing_ = false; + if (closeSocketAfterNextWrite_) { + closeSocket(); + } + } + else { + doWrite(writeQueue_); + writeQueue_.clear(); + } + } } HostAddressPort BoostConnection::getLocalAddress() const { - return HostAddressPort(socket_.local_endpoint()); + return HostAddressPort(socket_.local_endpoint()); +} + +HostAddressPort BoostConnection::getRemoteAddress() const { + return HostAddressPort(socket_.remote_endpoint()); } diff --git a/Swiften/Network/BoostConnection.h b/Swiften/Network/BoostConnection.h index 636853a..c77b933 100644 --- a/Swiften/Network/BoostConnection.h +++ b/Swiften/Network/BoostConnection.h @@ -1,70 +1,75 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <memory> +#include <mutex> + #include <boost/asio/io_service.hpp> #include <boost/asio/ip/tcp.hpp> -#include <boost/enable_shared_from_this.hpp> -#include <boost/thread/mutex.hpp> #include <Swiften/Base/API.h> -#include <Swiften/Network/Connection.h> -#include <Swiften/EventLoop/EventOwner.h> #include <Swiften/Base/SafeByteArray.h> - -namespace boost { - class thread; - namespace system { - class error_code; - } -} +#include <Swiften/EventLoop/EventOwner.h> +#include <Swiften/Network/Connection.h> +#include <Swiften/TLS/Certificate.h> +#include <Swiften/TLS/CertificateVerificationError.h> +#include <Swiften/TLS/CertificateWithKey.h> namespace Swift { - class EventLoop; - - class SWIFTEN_API BoostConnection : public Connection, public EventOwner, public boost::enable_shared_from_this<BoostConnection> { - public: - typedef boost::shared_ptr<BoostConnection> ref; - - ~BoostConnection(); - - static ref create(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) { - return ref(new BoostConnection(ioService, eventLoop)); - } - - virtual void listen(); - virtual void connect(const HostAddressPort& address); - virtual void disconnect(); - virtual void write(const SafeByteArray& data); - - boost::asio::ip::tcp::socket& getSocket() { - return socket_; - } - - HostAddressPort getLocalAddress() const; - - private: - BoostConnection(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); - - void handleConnectFinished(const boost::system::error_code& error); - void handleSocketRead(const boost::system::error_code& error, size_t bytesTransferred); - void handleDataWritten(const boost::system::error_code& error); - void doRead(); - void doWrite(const SafeByteArray& data); - void closeSocket(); - - private: - EventLoop* eventLoop; - boost::shared_ptr<boost::asio::io_service> ioService; - boost::asio::ip::tcp::socket socket_; - boost::shared_ptr<SafeByteArray> readBuffer_; - boost::mutex writeMutex_; - bool writing_; - SafeByteArray writeQueue_; - bool closeSocketAfterNextWrite_; - }; + class EventLoop; + + class SWIFTEN_API BoostConnection : public Connection, public EventOwner, public std::enable_shared_from_this<BoostConnection> { + public: + typedef std::shared_ptr<BoostConnection> ref; + + virtual ~BoostConnection(); + + static ref create(std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) { + return ref(new BoostConnection(ioService, eventLoop)); + } + + virtual void listen(); + virtual void connect(const HostAddressPort& address); + virtual void disconnect(); + virtual void write(const SafeByteArray& data); + + boost::asio::ip::tcp::socket& getSocket() { + return socket_; + } + + virtual HostAddressPort getLocalAddress() const; + virtual HostAddressPort getRemoteAddress() const; + + bool setClientCertificate(CertificateWithKey::ref cert); + + Certificate::ref getPeerCertificate() const; + std::vector<Certificate::ref> getPeerCertificateChain() const; + std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; + + private: + BoostConnection(std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); + + void handleConnectFinished(const boost::system::error_code& error); + void handleSocketRead(const boost::system::error_code& error, size_t bytesTransferred); + void handleDataWritten(const boost::system::error_code& error); + void doRead(); + void doWrite(const SafeByteArray& data); + void closeSocket(); + + private: + EventLoop* eventLoop; + std::shared_ptr<boost::asio::io_service> ioService; + boost::asio::ip::tcp::socket socket_; + std::shared_ptr<SafeByteArray> readBuffer_; + std::mutex writeMutex_; + bool writing_; + SafeByteArray writeQueue_; + bool closeSocketAfterNextWrite_; + std::mutex readCloseMutex_; + }; } diff --git a/Swiften/Network/BoostConnectionFactory.cpp b/Swiften/Network/BoostConnectionFactory.cpp index d5f9fad..9ec30f5 100644 --- a/Swiften/Network/BoostConnectionFactory.cpp +++ b/Swiften/Network/BoostConnectionFactory.cpp @@ -1,19 +1,20 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BoostConnectionFactory.h> + #include <Swiften/Network/BoostConnection.h> namespace Swift { -BoostConnectionFactory::BoostConnectionFactory(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : ioService(ioService), eventLoop(eventLoop) { +BoostConnectionFactory::BoostConnectionFactory(std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : ioService(ioService), eventLoop(eventLoop) { } -boost::shared_ptr<Connection> BoostConnectionFactory::createConnection() { - return BoostConnection::create(ioService, eventLoop); +std::shared_ptr<Connection> BoostConnectionFactory::createConnection() { + return BoostConnection::create(ioService, eventLoop); } } diff --git a/Swiften/Network/BoostConnectionFactory.h b/Swiften/Network/BoostConnectionFactory.h index c0a105b..eef0b45 100644 --- a/Swiften/Network/BoostConnectionFactory.h +++ b/Swiften/Network/BoostConnectionFactory.h @@ -1,27 +1,26 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once #include <boost/asio/io_service.hpp> -#include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Base/API.h> #include <Swiften/Network/BoostConnection.h> +#include <Swiften/Network/ConnectionFactory.h> namespace Swift { - class BoostConnection; - - class BoostConnectionFactory : public ConnectionFactory { - public: - BoostConnectionFactory(boost::shared_ptr<boost::asio::io_service>, EventLoop* eventLoop); + class SWIFTEN_API BoostConnectionFactory : public ConnectionFactory { + public: + BoostConnectionFactory(std::shared_ptr<boost::asio::io_service>, EventLoop* eventLoop); - virtual boost::shared_ptr<Connection> createConnection(); + virtual std::shared_ptr<Connection> createConnection(); - private: - boost::shared_ptr<boost::asio::io_service> ioService; - EventLoop* eventLoop; - }; + private: + std::shared_ptr<boost::asio::io_service> ioService; + EventLoop* eventLoop; + }; } diff --git a/Swiften/Network/BoostConnectionServer.cpp b/Swiften/Network/BoostConnectionServer.cpp index c90b554..dc05172 100644 --- a/Swiften/Network/BoostConnectionServer.cpp +++ b/Swiften/Network/BoostConnectionServer.cpp @@ -1,103 +1,112 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BoostConnectionServer.h> -#include <boost/bind.hpp> -#include <boost/system/system_error.hpp> +#include <boost/asio/ip/v6_only.hpp> #include <boost/asio/placeholders.hpp> +#include <boost/bind.hpp> #include <boost/numeric/conversion/cast.hpp> #include <boost/optional.hpp> +#include <boost/system/error_code.hpp> +#include <boost/system/system_error.hpp> +#include <Swiften/Base/Log.h> #include <Swiften/EventLoop/EventLoop.h> namespace Swift { -BoostConnectionServer::BoostConnectionServer(int port, boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : port_(port), ioService_(ioService), eventLoop(eventLoop), acceptor_(NULL) { +BoostConnectionServer::BoostConnectionServer(unsigned short port, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : port_(port), ioService_(ioService), eventLoop(eventLoop), acceptor_(nullptr) { } -BoostConnectionServer::BoostConnectionServer(const HostAddress &address, int port, boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : address_(address), port_(port), ioService_(ioService), eventLoop(eventLoop), acceptor_(NULL) { +BoostConnectionServer::BoostConnectionServer(const HostAddress &address, unsigned short port, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : address_(address), port_(port), ioService_(ioService), eventLoop(eventLoop), acceptor_(nullptr) { } void BoostConnectionServer::start() { - boost::optional<Error> error = tryStart(); - if (error) { - eventLoop->postEvent(boost::bind(boost::ref(onStopped), *error), shared_from_this()); - } + boost::optional<Error> error = tryStart(); + if (error) { + eventLoop->postEvent(boost::bind(boost::ref(onStopped), *error), shared_from_this()); + } } boost::optional<BoostConnectionServer::Error> BoostConnectionServer::tryStart() { - try { - assert(!acceptor_); - if (address_.isValid()) { - acceptor_ = new boost::asio::ip::tcp::acceptor( - *ioService_, - boost::asio::ip::tcp::endpoint(address_.getRawAddress(), boost::numeric_cast<unsigned short>(port_))); - } - else { - acceptor_ = new boost::asio::ip::tcp::acceptor( - *ioService_, - boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), boost::numeric_cast<unsigned short>(port_))); - } - acceptNextConnection(); - } - catch (const boost::system::system_error& e) { - if (e.code() == boost::asio::error::address_in_use) { - return Conflict; - } - else { - return UnknownError; - } - } - return boost::optional<Error>(); + try { + assert(!acceptor_); + boost::asio::ip::tcp::endpoint endpoint; + if (address_.isValid()) { + endpoint = boost::asio::ip::tcp::endpoint(address_.getRawAddress(), port_); + } + else { + endpoint = boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v6(), port_); + } + acceptor_ = new boost::asio::ip::tcp::acceptor(*ioService_, endpoint); + if (endpoint.protocol() == boost::asio::ip::tcp::v6()) { + boost::system::error_code ec; + acceptor_->set_option(boost::asio::ip::v6_only(false), ec); + SWIFT_LOG_ASSERT(ec, warning) << "IPv4/IPv6 dual-stack support is not supported on this platform."; + } + acceptNextConnection(); + } + catch (const boost::system::system_error& e) { + if (e.code() == boost::asio::error::address_in_use) { + return Conflict; + } + else { + return UnknownError; + } + } + catch (const boost::numeric::bad_numeric_cast&) { + return UnknownError; + } + return boost::optional<Error>(); } void BoostConnectionServer::stop() { - stop(boost::optional<Error>()); + stop(boost::optional<Error>()); } void BoostConnectionServer::stop(boost::optional<Error> e) { - if (acceptor_) { - acceptor_->close(); - delete acceptor_; - acceptor_ = NULL; - } - eventLoop->postEvent(boost::bind(boost::ref(onStopped), e), shared_from_this()); + if (acceptor_) { + acceptor_->close(); + delete acceptor_; + acceptor_ = nullptr; + } + eventLoop->postEvent(boost::bind(boost::ref(onStopped), e), shared_from_this()); } void BoostConnectionServer::acceptNextConnection() { - BoostConnection::ref newConnection(BoostConnection::create(ioService_, eventLoop)); - acceptor_->async_accept(newConnection->getSocket(), - boost::bind(&BoostConnectionServer::handleAccept, shared_from_this(), newConnection, boost::asio::placeholders::error)); + BoostConnection::ref newConnection(BoostConnection::create(ioService_, eventLoop)); + acceptor_->async_accept(newConnection->getSocket(), + boost::bind(&BoostConnectionServer::handleAccept, shared_from_this(), newConnection, boost::asio::placeholders::error)); } -void BoostConnectionServer::handleAccept(boost::shared_ptr<BoostConnection> newConnection, const boost::system::error_code& error) { - if (error) { - eventLoop->postEvent( - boost::bind( - &BoostConnectionServer::stop, shared_from_this(), UnknownError), - shared_from_this()); - } - else { - eventLoop->postEvent( - boost::bind(boost::ref(onNewConnection), newConnection), - shared_from_this()); - newConnection->listen(); - acceptNextConnection(); - } +void BoostConnectionServer::handleAccept(std::shared_ptr<BoostConnection> newConnection, const boost::system::error_code& error) { + if (error) { + eventLoop->postEvent( + boost::bind( + &BoostConnectionServer::stop, shared_from_this(), UnknownError), + shared_from_this()); + } + else { + eventLoop->postEvent( + boost::bind(boost::ref(onNewConnection), newConnection), + shared_from_this()); + newConnection->listen(); + acceptNextConnection(); + } } HostAddressPort BoostConnectionServer::getAddressPort() const { - if (acceptor_) { - return HostAddressPort(acceptor_->local_endpoint()); - } - else { - return HostAddressPort(); - } + if (acceptor_) { + return HostAddressPort(acceptor_->local_endpoint()); + } + else { + return HostAddressPort(); + } } } diff --git a/Swiften/Network/BoostConnectionServer.h b/Swiften/Network/BoostConnectionServer.h index 3ad0450..917d638 100644 --- a/Swiften/Network/BoostConnectionServer.h +++ b/Swiften/Network/BoostConnectionServer.h @@ -1,57 +1,57 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> + #include <boost/asio/io_service.hpp> #include <boost/asio/ip/tcp.hpp> -#include <boost/enable_shared_from_this.hpp> +#include <boost/optional/optional.hpp> +#include <boost/signals2.hpp> #include <Swiften/Base/API.h> -#include <Swiften/Base/boost_bsignals.h> +#include <Swiften/EventLoop/EventOwner.h> #include <Swiften/Network/BoostConnection.h> #include <Swiften/Network/ConnectionServer.h> -#include <Swiften/EventLoop/EventOwner.h> -#include <boost/optional/optional_fwd.hpp> namespace Swift { - class SWIFTEN_API BoostConnectionServer : public ConnectionServer, public EventOwner, public boost::enable_shared_from_this<BoostConnectionServer> { - public: - typedef boost::shared_ptr<BoostConnectionServer> ref; + class SWIFTEN_API BoostConnectionServer : public ConnectionServer, public EventOwner, public std::enable_shared_from_this<BoostConnectionServer> { + public: + typedef std::shared_ptr<BoostConnectionServer> ref; - static ref create(int port, boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) { - return ref(new BoostConnectionServer(port, ioService, eventLoop)); - } + static ref create(unsigned short port, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) { + return ref(new BoostConnectionServer(port, ioService, eventLoop)); + } - static ref create(const HostAddress &address, int port, boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) { - return ref(new BoostConnectionServer(address, port, ioService, eventLoop)); - } + static ref create(const HostAddress &address, unsigned short port, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) { + return ref(new BoostConnectionServer(address, port, ioService, eventLoop)); + } - virtual boost::optional<Error> tryStart(); // FIXME: This should become the new start - virtual void start(); - virtual void stop(); + virtual boost::optional<Error> tryStart(); // FIXME: This should become the new start + virtual void start(); + virtual void stop(); - virtual HostAddressPort getAddressPort() const; + virtual HostAddressPort getAddressPort() const; - boost::signal<void (boost::optional<Error>)> onStopped; + boost::signals2::signal<void (boost::optional<Error>)> onStopped; - private: - BoostConnectionServer(int port, boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); - BoostConnectionServer(const HostAddress &address, int port, boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); + private: + BoostConnectionServer(unsigned short port, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); + BoostConnectionServer(const HostAddress &address, unsigned short port, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); - void stop(boost::optional<Error> e); - void acceptNextConnection(); - void handleAccept(boost::shared_ptr<BoostConnection> newConnection, const boost::system::error_code& error); + void stop(boost::optional<Error> e); + void acceptNextConnection(); + void handleAccept(std::shared_ptr<BoostConnection> newConnection, const boost::system::error_code& error); - private: - HostAddress address_; - int port_; - boost::shared_ptr<boost::asio::io_service> ioService_; - EventLoop* eventLoop; - boost::asio::ip::tcp::acceptor* acceptor_; - }; + private: + HostAddress address_; + unsigned short port_; + std::shared_ptr<boost::asio::io_service> ioService_; + EventLoop* eventLoop; + boost::asio::ip::tcp::acceptor* acceptor_; + }; } diff --git a/Swiften/Network/BoostConnectionServerFactory.cpp b/Swiften/Network/BoostConnectionServerFactory.cpp index 04c614e..6936453 100644 --- a/Swiften/Network/BoostConnectionServerFactory.cpp +++ b/Swiften/Network/BoostConnectionServerFactory.cpp @@ -4,20 +4,27 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include <Swiften/Network/BoostConnectionServerFactory.h> + #include <Swiften/Network/BoostConnectionServer.h> namespace Swift { -BoostConnectionServerFactory::BoostConnectionServerFactory(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : ioService(ioService), eventLoop(eventLoop) { +BoostConnectionServerFactory::BoostConnectionServerFactory(std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : ioService(ioService), eventLoop(eventLoop) { } -boost::shared_ptr<ConnectionServer> BoostConnectionServerFactory::createConnectionServer(int port) { - return BoostConnectionServer::create(port, ioService, eventLoop); +std::shared_ptr<ConnectionServer> BoostConnectionServerFactory::createConnectionServer(unsigned short port) { + return BoostConnectionServer::create(port, ioService, eventLoop); } -boost::shared_ptr<ConnectionServer> BoostConnectionServerFactory::createConnectionServer(const Swift::HostAddress &hostAddress, int port) { - return BoostConnectionServer::create(hostAddress, port, ioService, eventLoop); +std::shared_ptr<ConnectionServer> BoostConnectionServerFactory::createConnectionServer(const Swift::HostAddress &hostAddress, unsigned short port) { + return BoostConnectionServer::create(hostAddress, port, ioService, eventLoop); } } diff --git a/Swiften/Network/BoostConnectionServerFactory.h b/Swiften/Network/BoostConnectionServerFactory.h index 9132b5c..956132b 100644 --- a/Swiften/Network/BoostConnectionServerFactory.h +++ b/Swiften/Network/BoostConnectionServerFactory.h @@ -4,26 +4,33 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <boost/asio/io_service.hpp> -#include <Swiften/Network/ConnectionServerFactory.h> +#include <Swiften/Base/API.h> #include <Swiften/Network/BoostConnectionServer.h> +#include <Swiften/Network/ConnectionServerFactory.h> namespace Swift { - class ConnectionServer; + class ConnectionServer; - class BoostConnectionServerFactory : public ConnectionServerFactory { - public: - BoostConnectionServerFactory(boost::shared_ptr<boost::asio::io_service>, EventLoop* eventLoop); + class SWIFTEN_API BoostConnectionServerFactory : public ConnectionServerFactory { + public: + BoostConnectionServerFactory(std::shared_ptr<boost::asio::io_service>, EventLoop* eventLoop); - virtual boost::shared_ptr<ConnectionServer> createConnectionServer(int port); + virtual std::shared_ptr<ConnectionServer> createConnectionServer(unsigned short port); - virtual boost::shared_ptr<ConnectionServer> createConnectionServer(const Swift::HostAddress &hostAddress, int port); + virtual std::shared_ptr<ConnectionServer> createConnectionServer(const Swift::HostAddress &hostAddress, unsigned short port); - private: - boost::shared_ptr<boost::asio::io_service> ioService; - EventLoop* eventLoop; - }; + private: + std::shared_ptr<boost::asio::io_service> ioService; + EventLoop* eventLoop; + }; } diff --git a/Swiften/Network/BoostIOServiceThread.cpp b/Swiften/Network/BoostIOServiceThread.cpp index c98a653..756e660 100644 --- a/Swiften/Network/BoostIOServiceThread.cpp +++ b/Swiften/Network/BoostIOServiceThread.cpp @@ -1,29 +1,39 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BoostIOServiceThread.h> -#include <boost/smart_ptr/make_shared.hpp> +#include <memory> + +#include <boost/bind.hpp> namespace Swift { -BoostIOServiceThread::BoostIOServiceThread() { - ioService_ = boost::make_shared<boost::asio::io_service>(); - thread_ = new boost::thread(boost::bind(&BoostIOServiceThread::doRun, this)); +BoostIOServiceThread::BoostIOServiceThread(std::shared_ptr<boost::asio::io_service> ioService) { + if (!!ioService) { + ioService_ = ioService; + thread_ = nullptr; + } + else { + ioService_ = std::make_shared<boost::asio::io_service>(); + thread_ = new std::thread(boost::bind(&BoostIOServiceThread::doRun, this)); + } } BoostIOServiceThread::~BoostIOServiceThread() { - ioService_->stop(); - thread_->join(); - delete thread_; + if (thread_) { + ioService_->stop(); + thread_->join(); + delete thread_; + } } void BoostIOServiceThread::doRun() { - boost::asio::io_service::work work(*ioService_); - ioService_->run(); + boost::asio::io_service::work work(*ioService_); + ioService_->run(); } } diff --git a/Swiften/Network/BoostIOServiceThread.h b/Swiften/Network/BoostIOServiceThread.h index d1a5f37..b9183fd 100644 --- a/Swiften/Network/BoostIOServiceThread.h +++ b/Swiften/Network/BoostIOServiceThread.h @@ -1,32 +1,41 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <memory> +#include <thread> + #include <boost/asio/io_service.hpp> -#include <boost/thread/thread.hpp> -#include <boost/shared_ptr.hpp> #include <Swiften/Base/API.h> namespace Swift { - class SWIFTEN_API BoostIOServiceThread { - public: - BoostIOServiceThread(); - ~BoostIOServiceThread(); - - boost::shared_ptr<boost::asio::io_service> getIOService() const { - return ioService_; - } - - private: - void doRun(); - - private: - boost::shared_ptr<boost::asio::io_service> ioService_; - boost::thread* thread_; - }; + class SWIFTEN_API BoostIOServiceThread { + public: + /** + * Construct the object. + * @param ioService If this optional parameter is provided, the behaviour + * of this class changes completely - it no longer spawns its own thread + * and instead acts as a simple wrapper of the io_service. Use this if + * you are re-using an io_service from elsewhere (particularly if you + * are using the BoostASIOEventLoop). + */ + BoostIOServiceThread(std::shared_ptr<boost::asio::io_service> ioService = std::shared_ptr<boost::asio::io_service>()); + ~BoostIOServiceThread(); + + std::shared_ptr<boost::asio::io_service> getIOService() const { + return ioService_; + } + + private: + void doRun(); + + private: + std::shared_ptr<boost::asio::io_service> ioService_; + std::thread* thread_; + }; } diff --git a/Swiften/Network/BoostNetworkFactories.cpp b/Swiften/Network/BoostNetworkFactories.cpp index 72e826a..13a7960 100644 --- a/Swiften/Network/BoostNetworkFactories.cpp +++ b/Swiften/Network/BoostNetworkFactories.cpp @@ -1,24 +1,24 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BoostNetworkFactories.h> -#include <Swiften/Network/BoostTimerFactory.h> -#include <Swiften/Network/BoostConnectionFactory.h> +#include <Swiften/Crypto/CryptoProvider.h> +#include <Swiften/Crypto/PlatformCryptoProvider.h> +#include <Swiften/IDN/IDNConverter.h> +#include <Swiften/IDN/PlatformIDNConverter.h> +#include <Swiften/Network/BoostConnectionFactory.h> #include <Swiften/Network/BoostConnectionServerFactory.h> -#include <Swiften/Network/PlatformNATTraversalWorker.h> -#include <Swiften/Parser/PlatformXMLParserFactory.h> +#include <Swiften/Network/BoostTimerFactory.h> #include <Swiften/Network/NullNATTraverser.h> +#include <Swiften/Network/PlatformNATTraversalWorker.h> #include <Swiften/Network/PlatformNetworkEnvironment.h> -#include <Swiften/TLS/PlatformTLSFactories.h> #include <Swiften/Network/PlatformProxyProvider.h> -#include <Swiften/IDN/PlatformIDNConverter.h> -#include <Swiften/IDN/IDNConverter.h> -#include <Swiften/Crypto/PlatformCryptoProvider.h> -#include <Swiften/Crypto/CryptoProvider.h> +#include <Swiften/Parser/PlatformXMLParserFactory.h> +#include <Swiften/TLS/PlatformTLSFactories.h> #ifdef USE_UNBOUND #include <Swiften/Network/UnboundDomainNameResolver.h> @@ -28,45 +28,44 @@ namespace Swift { -BoostNetworkFactories::BoostNetworkFactories(EventLoop* eventLoop) : eventLoop(eventLoop){ - timerFactory = new BoostTimerFactory(ioServiceThread.getIOService(), eventLoop); - connectionFactory = new BoostConnectionFactory(ioServiceThread.getIOService(), eventLoop); - connectionServerFactory = new BoostConnectionServerFactory(ioServiceThread.getIOService(), eventLoop); +BoostNetworkFactories::BoostNetworkFactories(EventLoop* eventLoop, std::shared_ptr<boost::asio::io_service> ioService) : ioServiceThread(ioService), eventLoop(eventLoop) { + timerFactory = new BoostTimerFactory(ioServiceThread.getIOService(), eventLoop); + connectionFactory = new BoostConnectionFactory(ioServiceThread.getIOService(), eventLoop); + connectionServerFactory = new BoostConnectionServerFactory(ioServiceThread.getIOService(), eventLoop); #ifdef SWIFT_EXPERIMENTAL_FT - natTraverser = new PlatformNATTraversalWorker(eventLoop); + natTraverser = new PlatformNATTraversalWorker(eventLoop); #else - natTraverser = new NullNATTraverser(eventLoop); + natTraverser = new NullNATTraverser(eventLoop); #endif - networkEnvironment = new PlatformNetworkEnvironment(); - xmlParserFactory = new PlatformXMLParserFactory(); - tlsFactories = new PlatformTLSFactories(); - proxyProvider = new PlatformProxyProvider(); - idnConverter = PlatformIDNConverter::create(); + networkEnvironment = new PlatformNetworkEnvironment(); + xmlParserFactory = new PlatformXMLParserFactory(); + tlsFactories = new PlatformTLSFactories(); + proxyProvider = new PlatformProxyProvider(); + idnConverter = PlatformIDNConverter::create(); #ifdef USE_UNBOUND - // TODO: What to do about idnConverter. - domainNameResolver = new UnboundDomainNameResolver(ioServiceThread.getIOService(), eventLoop); + // TODO: What to do about idnConverter. + domainNameResolver = new UnboundDomainNameResolver(idnConverter.get(), ioServiceThread.getIOService(), eventLoop); #else - domainNameResolver = new PlatformDomainNameResolver(idnConverter, eventLoop); + domainNameResolver = new PlatformDomainNameResolver(idnConverter.get(), eventLoop); #endif - cryptoProvider = PlatformCryptoProvider::create(); + cryptoProvider = PlatformCryptoProvider::create(); } BoostNetworkFactories::~BoostNetworkFactories() { - delete cryptoProvider; - delete domainNameResolver; - delete idnConverter; - delete proxyProvider; - delete tlsFactories; - delete xmlParserFactory; - delete networkEnvironment; - delete natTraverser; - delete connectionServerFactory; - delete connectionFactory; - delete timerFactory; + delete cryptoProvider; + delete domainNameResolver; + delete proxyProvider; + delete tlsFactories; + delete xmlParserFactory; + delete networkEnvironment; + delete natTraverser; + delete connectionServerFactory; + delete connectionFactory; + delete timerFactory; } TLSContextFactory* BoostNetworkFactories::getTLSContextFactory() const { - return tlsFactories->getTLSContextFactory(); + return tlsFactories->getTLSContextFactory(); } } diff --git a/Swiften/Network/BoostNetworkFactories.h b/Swiften/Network/BoostNetworkFactories.h index 9c3bab1..33a3584 100644 --- a/Swiften/Network/BoostNetworkFactories.h +++ b/Swiften/Network/BoostNetworkFactories.h @@ -1,89 +1,95 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <memory> + #include <Swiften/Base/API.h> -#include <Swiften/Base/Override.h> -#include <Swiften/Network/NetworkFactories.h> #include <Swiften/Network/BoostIOServiceThread.h> +#include <Swiften/Network/NetworkFactories.h> namespace Swift { - class EventLoop; - class NATTraverser; - class PlatformTLSFactories; - - class SWIFTEN_API BoostNetworkFactories : public NetworkFactories { - public: - BoostNetworkFactories(EventLoop* eventLoop); - ~BoostNetworkFactories(); - - virtual TimerFactory* getTimerFactory() const SWIFTEN_OVERRIDE { - return timerFactory; - } - - virtual ConnectionFactory* getConnectionFactory() const SWIFTEN_OVERRIDE { - return connectionFactory; - } - - BoostIOServiceThread* getIOServiceThread() { - return &ioServiceThread; - } - - DomainNameResolver* getDomainNameResolver() const SWIFTEN_OVERRIDE { - return domainNameResolver; - } - - ConnectionServerFactory* getConnectionServerFactory() const SWIFTEN_OVERRIDE { - return connectionServerFactory; - } - - NetworkEnvironment* getNetworkEnvironment() const SWIFTEN_OVERRIDE { - return networkEnvironment; - } - - NATTraverser* getNATTraverser() const SWIFTEN_OVERRIDE { - return natTraverser; - } - - virtual XMLParserFactory* getXMLParserFactory() const SWIFTEN_OVERRIDE { - return xmlParserFactory; - } - - virtual TLSContextFactory* getTLSContextFactory() const SWIFTEN_OVERRIDE; - - virtual ProxyProvider* getProxyProvider() const SWIFTEN_OVERRIDE { - return proxyProvider; - } - - virtual EventLoop* getEventLoop() const SWIFTEN_OVERRIDE { - return eventLoop; - } - - virtual IDNConverter* getIDNConverter() const SWIFTEN_OVERRIDE { - return idnConverter; - } - - virtual CryptoProvider* getCryptoProvider() const SWIFTEN_OVERRIDE { - return cryptoProvider; - } - - private: - BoostIOServiceThread ioServiceThread; - TimerFactory* timerFactory; - ConnectionFactory* connectionFactory; - DomainNameResolver* domainNameResolver; - ConnectionServerFactory* connectionServerFactory; - NATTraverser* natTraverser; - NetworkEnvironment* networkEnvironment; - XMLParserFactory* xmlParserFactory; - PlatformTLSFactories* tlsFactories; - ProxyProvider* proxyProvider; - EventLoop* eventLoop; - IDNConverter* idnConverter; - CryptoProvider* cryptoProvider; - }; + class EventLoop; + class NATTraverser; + class PlatformTLSFactories; + + class SWIFTEN_API BoostNetworkFactories : public NetworkFactories { + public: + /** + * Construct the network factories, using the provided EventLoop. + * @param ioService If this optional parameter is provided, it will be + * used for the construction of the BoostIOServiceThread. + */ + BoostNetworkFactories(EventLoop* eventLoop, std::shared_ptr<boost::asio::io_service> ioService = std::shared_ptr<boost::asio::io_service>()); + virtual ~BoostNetworkFactories() override; + + virtual TimerFactory* getTimerFactory() const override { + return timerFactory; + } + + virtual ConnectionFactory* getConnectionFactory() const override { + return connectionFactory; + } + + BoostIOServiceThread* getIOServiceThread() { + return &ioServiceThread; + } + + DomainNameResolver* getDomainNameResolver() const override { + return domainNameResolver; + } + + ConnectionServerFactory* getConnectionServerFactory() const override { + return connectionServerFactory; + } + + NetworkEnvironment* getNetworkEnvironment() const override { + return networkEnvironment; + } + + NATTraverser* getNATTraverser() const override { + return natTraverser; + } + + virtual XMLParserFactory* getXMLParserFactory() const override { + return xmlParserFactory; + } + + virtual TLSContextFactory* getTLSContextFactory() const override; + + virtual ProxyProvider* getProxyProvider() const override { + return proxyProvider; + } + + virtual EventLoop* getEventLoop() const override { + return eventLoop; + } + + virtual IDNConverter* getIDNConverter() const override { + return idnConverter.get(); + } + + virtual CryptoProvider* getCryptoProvider() const override { + return cryptoProvider; + } + + private: + BoostIOServiceThread ioServiceThread; + TimerFactory* timerFactory; + ConnectionFactory* connectionFactory; + DomainNameResolver* domainNameResolver; + ConnectionServerFactory* connectionServerFactory; + NATTraverser* natTraverser; + NetworkEnvironment* networkEnvironment; + XMLParserFactory* xmlParserFactory; + PlatformTLSFactories* tlsFactories; + ProxyProvider* proxyProvider; + EventLoop* eventLoop; + std::unique_ptr<IDNConverter> idnConverter; + CryptoProvider* cryptoProvider; + }; } diff --git a/Swiften/Network/BoostTimer.cpp b/Swiften/Network/BoostTimer.cpp index bf042d6..a177504 100644 --- a/Swiften/Network/BoostTimer.cpp +++ b/Swiften/Network/BoostTimer.cpp @@ -1,40 +1,62 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BoostTimer.h> -#include <boost/date_time/posix_time/posix_time.hpp> #include <boost/asio.hpp> #include <boost/bind.hpp> +#include <boost/date_time/posix_time/posix_time.hpp> #include <Swiften/EventLoop/EventLoop.h> namespace Swift { -BoostTimer::BoostTimer(int milliseconds, boost::shared_ptr<boost::asio::io_service> service, EventLoop* eventLoop) : - timeout(milliseconds), ioService(service), timer(*service), eventLoop(eventLoop) { +BoostTimer::BoostTimer(int milliseconds, std::shared_ptr<boost::asio::io_service> service, EventLoop* eventLoop) : + timeout(milliseconds), ioService(service), eventLoop(eventLoop), shuttingDown(false) { + timer.reset(new boost::asio::deadline_timer(*service)); +} + +BoostTimer::~BoostTimer() { + { + std::unique_lock<std::mutex> lockTimer(timerMutex); + timer.reset(); + } } void BoostTimer::start() { - timer.expires_from_now(boost::posix_time::milliseconds(timeout)); - timer.async_wait(boost::bind(&BoostTimer::handleTimerTick, shared_from_this(), boost::asio::placeholders::error)); + { + std::unique_lock<std::mutex> lockTimer(timerMutex); + shuttingDown = false; + timer->expires_from_now(boost::posix_time::milliseconds(timeout)); + timer->async_wait(boost::bind(&BoostTimer::handleTimerTick, shared_from_this(), boost::asio::placeholders::error)); + } } void BoostTimer::stop() { - timer.cancel(); - eventLoop->removeEventsFromOwner(shared_from_this()); + { + std::unique_lock<std::mutex> lockTimer(timerMutex); + shuttingDown = true; + timer->cancel(); + eventLoop->removeEventsFromOwner(shared_from_this()); + } } void BoostTimer::handleTimerTick(const boost::system::error_code& error) { - if (error) { - assert(error == boost::asio::error::operation_aborted); - } - else { - eventLoop->postEvent(boost::bind(boost::ref(onTick)), shared_from_this()); - } + if (error) { + assert(error == boost::asio::error::operation_aborted); + } + else { + { + std::unique_lock<std::mutex> lockTimer(timerMutex); + if (shuttingDown) { + return; + } + eventLoop->postEvent(boost::bind(boost::ref(onTick)), shared_from_this()); + } + } } } diff --git a/Swiften/Network/BoostTimer.h b/Swiften/Network/BoostTimer.h index bfe631b..68ae28c 100644 --- a/Swiften/Network/BoostTimer.h +++ b/Swiften/Network/BoostTimer.h @@ -1,41 +1,49 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/asio/io_service.hpp> +#include <memory> +#include <mutex> + #include <boost/asio/deadline_timer.hpp> -#include <boost/enable_shared_from_this.hpp> +#include <boost/asio/io_service.hpp> +#include <boost/scoped_ptr.hpp> +#include <Swiften/Base/API.h> #include <Swiften/EventLoop/EventOwner.h> #include <Swiften/Network/Timer.h> namespace Swift { - class EventLoop; + class EventLoop; + + class SWIFTEN_API BoostTimer : public Timer, public EventOwner, public std::enable_shared_from_this<BoostTimer> { + public: + typedef std::shared_ptr<BoostTimer> ref; - class BoostTimer : public Timer, public EventOwner, public boost::enable_shared_from_this<BoostTimer> { - public: - typedef boost::shared_ptr<BoostTimer> ref; + virtual ~BoostTimer(); - static ref create(int milliseconds, boost::shared_ptr<boost::asio::io_service> service, EventLoop* eventLoop) { - return ref(new BoostTimer(milliseconds, service, eventLoop)); - } + static ref create(int milliseconds, std::shared_ptr<boost::asio::io_service> service, EventLoop* eventLoop) { + return ref(new BoostTimer(milliseconds, service, eventLoop)); + } - virtual void start(); - virtual void stop(); + virtual void start(); + virtual void stop(); - private: - BoostTimer(int milliseconds, boost::shared_ptr<boost::asio::io_service> service, EventLoop* eventLoop); + private: + BoostTimer(int milliseconds, std::shared_ptr<boost::asio::io_service> service, EventLoop* eventLoop); - void handleTimerTick(const boost::system::error_code& error); + void handleTimerTick(const boost::system::error_code& error); - private: - int timeout; - boost::shared_ptr<boost::asio::io_service> ioService; - boost::asio::deadline_timer timer; - EventLoop* eventLoop; - }; + private: + int timeout; + std::shared_ptr<boost::asio::io_service> ioService; + boost::scoped_ptr<boost::asio::deadline_timer> timer; + std::mutex timerMutex; + EventLoop* eventLoop; + bool shuttingDown; + }; } diff --git a/Swiften/Network/BoostTimerFactory.cpp b/Swiften/Network/BoostTimerFactory.cpp index c0bdb56..ffa9b30 100644 --- a/Swiften/Network/BoostTimerFactory.cpp +++ b/Swiften/Network/BoostTimerFactory.cpp @@ -1,19 +1,20 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/BoostTimerFactory.h> + #include <Swiften/Network/BoostTimer.h> namespace Swift { -BoostTimerFactory::BoostTimerFactory(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : ioService(ioService), eventLoop(eventLoop) { +BoostTimerFactory::BoostTimerFactory(std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : ioService(ioService), eventLoop(eventLoop) { } -boost::shared_ptr<Timer> BoostTimerFactory::createTimer(int milliseconds) { - return BoostTimer::create(milliseconds, ioService, eventLoop); +std::shared_ptr<Timer> BoostTimerFactory::createTimer(int milliseconds) { + return BoostTimer::create(milliseconds, ioService, eventLoop); } } diff --git a/Swiften/Network/BoostTimerFactory.h b/Swiften/Network/BoostTimerFactory.h index 6093db0..1e2139b 100644 --- a/Swiften/Network/BoostTimerFactory.h +++ b/Swiften/Network/BoostTimerFactory.h @@ -1,28 +1,28 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once #include <boost/asio/io_service.hpp> -#include <Swiften/Network/TimerFactory.h> +#include <Swiften/Base/API.h> #include <Swiften/Network/BoostTimer.h> +#include <Swiften/Network/TimerFactory.h> namespace Swift { - class BoostTimer; - class EventLoop; + class EventLoop; - class BoostTimerFactory : public TimerFactory { - public: - BoostTimerFactory(boost::shared_ptr<boost::asio::io_service>, EventLoop* eventLoop); + class SWIFTEN_API BoostTimerFactory : public TimerFactory { + public: + BoostTimerFactory(std::shared_ptr<boost::asio::io_service>, EventLoop* eventLoop); - virtual boost::shared_ptr<Timer> createTimer(int milliseconds); + virtual std::shared_ptr<Timer> createTimer(int milliseconds); - private: - boost::shared_ptr<boost::asio::io_service> ioService; - EventLoop* eventLoop; - }; + private: + std::shared_ptr<boost::asio::io_service> ioService; + EventLoop* eventLoop; + }; } diff --git a/Swiften/Network/CachingDomainNameResolver.cpp b/Swiften/Network/CachingDomainNameResolver.cpp index 4cf8286..8846e09 100644 --- a/Swiften/Network/CachingDomainNameResolver.cpp +++ b/Swiften/Network/CachingDomainNameResolver.cpp @@ -1,12 +1,12 @@ /* - * Copyright (c) 2012 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/CachingDomainNameResolver.h> -#include <boost/smart_ptr/make_shared.hpp> +#include <memory> namespace Swift { @@ -17,14 +17,14 @@ CachingDomainNameResolver::~CachingDomainNameResolver() { } -DomainNameServiceQuery::ref CachingDomainNameResolver::createServiceQuery(const std::string& name) { - //TODO: Cache - return realResolver->createServiceQuery(name); +DomainNameServiceQuery::ref CachingDomainNameResolver::createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { + //TODO: Cache + return realResolver->createServiceQuery(serviceLookupPrefix, domain); } DomainNameAddressQuery::ref CachingDomainNameResolver::createAddressQuery(const std::string& name) { - //TODO: Cache - return realResolver->createAddressQuery(name); + //TODO: Cache + return realResolver->createAddressQuery(name); } } diff --git a/Swiften/Network/CachingDomainNameResolver.h b/Swiften/Network/CachingDomainNameResolver.h index 66b4d68..9339a77 100644 --- a/Swiften/Network/CachingDomainNameResolver.h +++ b/Swiften/Network/CachingDomainNameResolver.h @@ -1,13 +1,14 @@ /* - * Copyright (c) 2012 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> +#include <Swiften/Base/API.h> #include <Swiften/Network/DomainNameResolver.h> #include <Swiften/Network/StaticDomainNameResolver.h> @@ -15,17 +16,17 @@ * FIXME: Does not do any caching yet. */ namespace Swift { - class EventLoop; + class EventLoop; - class CachingDomainNameResolver : public DomainNameResolver { - public: - CachingDomainNameResolver(DomainNameResolver* realResolver, EventLoop* eventLoop); - ~CachingDomainNameResolver(); + class SWIFTEN_API CachingDomainNameResolver : public DomainNameResolver { + public: + CachingDomainNameResolver(DomainNameResolver* realResolver, EventLoop* eventLoop); + ~CachingDomainNameResolver(); - virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& name); - virtual DomainNameAddressQuery::ref createAddressQuery(const std::string& name); + virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain); + virtual DomainNameAddressQuery::ref createAddressQuery(const std::string& name); - private: - DomainNameResolver* realResolver; - }; + private: + DomainNameResolver* realResolver; + }; } diff --git a/Swiften/Network/ChainedConnector.cpp b/Swiften/Network/ChainedConnector.cpp index 8c7c04b..a9210ba 100644 --- a/Swiften/Network/ChainedConnector.cpp +++ b/Swiften/Network/ChainedConnector.cpp @@ -1,87 +1,95 @@ /* - * Copyright (c) 2011 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/ChainedConnector.h> -#include <boost/bind.hpp> #include <typeinfo> +#include <boost/bind.hpp> + #include <Swiften/Base/Log.h> -#include <Swiften/Base/foreach.h> -#include <Swiften/Network/Connector.h> #include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/Connector.h> using namespace Swift; ChainedConnector::ChainedConnector( - const std::string& hostname, - int port, - bool doServiceLookups, - DomainNameResolver* resolver, - const std::vector<ConnectionFactory*>& connectionFactories, - TimerFactory* timerFactory) : - hostname(hostname), - port(port), - doServiceLookups(doServiceLookups), - resolver(resolver), - connectionFactories(connectionFactories), - timerFactory(timerFactory), - timeoutMilliseconds(0) { + const std::string& hostname, + unsigned short port, + const boost::optional<std::string>& serviceLookupPrefix, + DomainNameResolver* resolver, + const std::vector<ConnectionFactory*>& connectionFactories, + TimerFactory* timerFactory) : + hostname(hostname), + port(port), + serviceLookupPrefix(serviceLookupPrefix), + resolver(resolver), + connectionFactories(connectionFactories), + timerFactory(timerFactory), + timeoutMilliseconds(0) { +} + +ChainedConnector::~ChainedConnector() { + if (currentConnector) { + currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2)); + currentConnector->stop(); + currentConnector.reset(); + } } void ChainedConnector::setTimeoutMilliseconds(int milliseconds) { - timeoutMilliseconds = milliseconds; + timeoutMilliseconds = milliseconds; } void ChainedConnector::start() { - SWIFT_LOG(debug) << "Starting queued connector for " << hostname << std::endl; + SWIFT_LOG(debug) << "Starting queued connector for " << hostname; - connectionFactoryQueue = std::deque<ConnectionFactory*>(connectionFactories.begin(), connectionFactories.end()); - tryNextConnectionFactory(); + connectionFactoryQueue = std::deque<ConnectionFactory*>(connectionFactories.begin(), connectionFactories.end()); + tryNextConnectionFactory(); } void ChainedConnector::stop() { - if (currentConnector) { - currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2)); - currentConnector->stop(); - currentConnector.reset(); - } - finish(boost::shared_ptr<Connection>(), boost::shared_ptr<Error>()); + if (currentConnector) { + currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2)); + currentConnector->stop(); + currentConnector.reset(); + } + finish(std::shared_ptr<Connection>(), std::shared_ptr<Error>()); } void ChainedConnector::tryNextConnectionFactory() { - assert(!currentConnector); - if (connectionFactoryQueue.empty()) { - SWIFT_LOG(debug) << "No more connection factories" << std::endl; - finish(boost::shared_ptr<Connection>(), lastError); - } - else { - ConnectionFactory* connectionFactory = connectionFactoryQueue.front(); - SWIFT_LOG(debug) << "Trying next connection factory: " << typeid(*connectionFactory).name() << std::endl; - connectionFactoryQueue.pop_front(); - currentConnector = Connector::create(hostname, port, doServiceLookups, resolver, connectionFactory, timerFactory); - currentConnector->setTimeoutMilliseconds(timeoutMilliseconds); - currentConnector->onConnectFinished.connect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2)); - currentConnector->start(); - } + assert(!currentConnector); + if (connectionFactoryQueue.empty()) { + SWIFT_LOG(debug) << "No more connection factories"; + finish(std::shared_ptr<Connection>(), lastError); + } + else { + ConnectionFactory* connectionFactory = connectionFactoryQueue.front(); + SWIFT_LOG(debug) << "Trying next connection factory: " << typeid(*connectionFactory).name(); + connectionFactoryQueue.pop_front(); + currentConnector = Connector::create(hostname, port, serviceLookupPrefix, resolver, connectionFactory, timerFactory); + currentConnector->setTimeoutMilliseconds(timeoutMilliseconds); + currentConnector->onConnectFinished.connect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2)); + currentConnector->start(); + } } -void ChainedConnector::handleConnectorFinished(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> error) { - SWIFT_LOG(debug) << "Connector finished" << std::endl; - currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2)); - lastError = error; - currentConnector.reset(); - if (connection) { - finish(connection, error); - } - else { - tryNextConnectionFactory(); - } +void ChainedConnector::handleConnectorFinished(std::shared_ptr<Connection> connection, std::shared_ptr<Error> error) { + SWIFT_LOG(debug) << "Connector finished"; + currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2)); + lastError = error; + currentConnector.reset(); + if (connection) { + finish(connection, error); + } + else { + tryNextConnectionFactory(); + } } -void ChainedConnector::finish(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> error) { - onConnectFinished(connection, error); +void ChainedConnector::finish(std::shared_ptr<Connection> connection, std::shared_ptr<Error> error) { + onConnectFinished(connection, error); } diff --git a/Swiften/Network/ChainedConnector.h b/Swiften/Network/ChainedConnector.h index 03462bc..9620293 100644 --- a/Swiften/Network/ChainedConnector.h +++ b/Swiften/Network/ChainedConnector.h @@ -1,52 +1,55 @@ /* - * Copyright (c) 2011 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <deque> +#include <memory> #include <string> #include <vector> -#include <deque> -#include <boost/shared_ptr.hpp> + +#include <boost/optional.hpp> +#include <boost/signals2.hpp> #include <Swiften/Base/API.h> -#include <Swiften/Base/boost_bsignals.h> #include <Swiften/Base/Error.h> namespace Swift { - class Connection; - class Connector; - class ConnectionFactory; - class TimerFactory; - class DomainNameResolver; - - class SWIFTEN_API ChainedConnector { - public: - ChainedConnector(const std::string& hostname, int port, bool doServiceLookups, DomainNameResolver*, const std::vector<ConnectionFactory*>&, TimerFactory*); - - void setTimeoutMilliseconds(int milliseconds); - void start(); - void stop(); - - boost::signal<void (boost::shared_ptr<Connection>, boost::shared_ptr<Error>)> onConnectFinished; - - private: - void finish(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error>); - void tryNextConnectionFactory(); - void handleConnectorFinished(boost::shared_ptr<Connection>, boost::shared_ptr<Error>); - - private: - std::string hostname; - int port; - bool doServiceLookups; - DomainNameResolver* resolver; - std::vector<ConnectionFactory*> connectionFactories; - TimerFactory* timerFactory; - int timeoutMilliseconds; - std::deque<ConnectionFactory*> connectionFactoryQueue; - boost::shared_ptr<Connector> currentConnector; - boost::shared_ptr<Error> lastError; - }; + class Connection; + class Connector; + class ConnectionFactory; + class TimerFactory; + class DomainNameResolver; + + class SWIFTEN_API ChainedConnector { + public: + ChainedConnector(const std::string& hostname, unsigned short port, const boost::optional<std::string>& serviceLookupPrefix, DomainNameResolver*, const std::vector<ConnectionFactory*>&, TimerFactory*); + ~ChainedConnector(); + + void setTimeoutMilliseconds(int milliseconds); + void start(); + void stop(); + + boost::signals2::signal<void (std::shared_ptr<Connection>, std::shared_ptr<Error>)> onConnectFinished; + + private: + void finish(std::shared_ptr<Connection> connection, std::shared_ptr<Error>); + void tryNextConnectionFactory(); + void handleConnectorFinished(std::shared_ptr<Connection>, std::shared_ptr<Error>); + + private: + std::string hostname; + unsigned short port; + boost::optional<std::string> serviceLookupPrefix; + DomainNameResolver* resolver; + std::vector<ConnectionFactory*> connectionFactories; + TimerFactory* timerFactory; + int timeoutMilliseconds; + std::deque<ConnectionFactory*> connectionFactoryQueue; + std::shared_ptr<Connector> currentConnector; + std::shared_ptr<Error> lastError; + }; } diff --git a/Swiften/Network/Connection.cpp b/Swiften/Network/Connection.cpp index 9bb29e1..adf3c6c 100644 --- a/Swiften/Network/Connection.cpp +++ b/Swiften/Network/Connection.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/Connection.h> diff --git a/Swiften/Network/Connection.h b/Swiften/Network/Connection.h index 97c287d..85f33a8 100644 --- a/Swiften/Network/Connection.h +++ b/Swiften/Network/Connection.h @@ -1,43 +1,45 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> -#include <Swiften/Base/boost_bsignals.h> +#include <memory> + +#include <boost/signals2.hpp> #include <Swiften/Base/API.h> #include <Swiften/Base/SafeByteArray.h> namespace Swift { - class HostAddressPort; - - class SWIFTEN_API Connection { - public: - typedef boost::shared_ptr<Connection> ref; - - enum Error { - ReadError, - WriteError - }; - - Connection(); - virtual ~Connection(); - - virtual void listen() = 0; - virtual void connect(const HostAddressPort& address) = 0; - virtual void disconnect() = 0; - virtual void write(const SafeByteArray& data) = 0; - - virtual HostAddressPort getLocalAddress() const = 0; - - public: - boost::signal<void (bool /* error */)> onConnectFinished; - boost::signal<void (const boost::optional<Error>&)> onDisconnected; - boost::signal<void (boost::shared_ptr<SafeByteArray>)> onDataRead; - boost::signal<void ()> onDataWritten; - }; + class HostAddressPort; + + class SWIFTEN_API Connection { + public: + typedef std::shared_ptr<Connection> ref; + + enum Error { + ReadError, + WriteError + }; + + Connection(); + virtual ~Connection(); + + virtual void listen() = 0; + virtual void connect(const HostAddressPort& address) = 0; + virtual void disconnect() = 0; + virtual void write(const SafeByteArray& data) = 0; + + virtual HostAddressPort getLocalAddress() const = 0; + virtual HostAddressPort getRemoteAddress() const = 0; + + public: + boost::signals2::signal<void (bool /* error */)> onConnectFinished; + boost::signals2::signal<void (const boost::optional<Error>&)> onDisconnected; + boost::signals2::signal<void (std::shared_ptr<SafeByteArray>)> onDataRead; + boost::signals2::signal<void ()> onDataWritten; + }; } diff --git a/Swiften/Network/ConnectionFactory.cpp b/Swiften/Network/ConnectionFactory.cpp index 2e38b21..372e029 100644 --- a/Swiften/Network/ConnectionFactory.cpp +++ b/Swiften/Network/ConnectionFactory.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/ConnectionFactory.h> diff --git a/Swiften/Network/ConnectionFactory.h b/Swiften/Network/ConnectionFactory.h index c8be2fc..e749fa3 100644 --- a/Swiften/Network/ConnectionFactory.h +++ b/Swiften/Network/ConnectionFactory.h @@ -1,22 +1,22 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> #include <Swiften/Base/API.h> namespace Swift { - class Connection; + class Connection; - class SWIFTEN_API ConnectionFactory { - public: - virtual ~ConnectionFactory(); + class SWIFTEN_API ConnectionFactory { + public: + virtual ~ConnectionFactory(); - virtual boost::shared_ptr<Connection> createConnection() = 0; - }; + virtual std::shared_ptr<Connection> createConnection() = 0; + }; } diff --git a/Swiften/Network/ConnectionServer.cpp b/Swiften/Network/ConnectionServer.cpp index 78312e7..c2ac241 100644 --- a/Swiften/Network/ConnectionServer.cpp +++ b/Swiften/Network/ConnectionServer.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/ConnectionServer.h> diff --git a/Swiften/Network/ConnectionServer.h b/Swiften/Network/ConnectionServer.h index 2e09348..769ab9f 100644 --- a/Swiften/Network/ConnectionServer.h +++ b/Swiften/Network/ConnectionServer.h @@ -1,37 +1,38 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> + +#include <boost/optional/optional.hpp> +#include <boost/signals2.hpp> #include <Swiften/Base/API.h> -#include <Swiften/Base/boost_bsignals.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/HostAddressPort.h> -#include <boost/optional/optional_fwd.hpp> namespace Swift { - class SWIFTEN_API ConnectionServer { - public: - enum Error { - Conflict, - UnknownError - }; + class SWIFTEN_API ConnectionServer { + public: + enum Error { + Conflict, + UnknownError + }; - virtual ~ConnectionServer(); + virtual ~ConnectionServer(); - virtual HostAddressPort getAddressPort() const = 0; + virtual HostAddressPort getAddressPort() const = 0; - virtual boost::optional<Error> tryStart() = 0; // FIXME: This should become the new start + virtual boost::optional<Error> tryStart() = 0; // FIXME: This should become the new start - virtual void start() = 0; + virtual void start() = 0; - virtual void stop() = 0; + virtual void stop() = 0; - boost::signal<void (boost::shared_ptr<Connection>)> onNewConnection; - }; + boost::signals2::signal<void (std::shared_ptr<Connection>)> onNewConnection; + }; } diff --git a/Swiften/Network/ConnectionServerFactory.h b/Swiften/Network/ConnectionServerFactory.h index df5f912..2ebccc1 100644 --- a/Swiften/Network/ConnectionServerFactory.h +++ b/Swiften/Network/ConnectionServerFactory.h @@ -4,20 +4,28 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> + +#include <Swiften/Base/API.h> namespace Swift { - class ConnectionServer; - class HostAddress; + class ConnectionServer; + class HostAddress; - class ConnectionServerFactory { - public: - virtual ~ConnectionServerFactory(); + class SWIFTEN_API ConnectionServerFactory { + public: + virtual ~ConnectionServerFactory(); - virtual boost::shared_ptr<ConnectionServer> createConnectionServer(int port) = 0; + virtual std::shared_ptr<ConnectionServer> createConnectionServer(unsigned short port) = 0; - virtual boost::shared_ptr<ConnectionServer> createConnectionServer(const Swift::HostAddress& hostAddress, int port) = 0; - }; + virtual std::shared_ptr<ConnectionServer> createConnectionServer(const Swift::HostAddress& hostAddress, unsigned short port) = 0; + }; } diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index a0155cf..e27b95d 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -1,186 +1,203 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/Connector.h> #include <boost/bind.hpp> -#include <iostream> +#include <Swiften/Base/Log.h> #include <Swiften/Network/ConnectionFactory.h> -#include <Swiften/Network/DomainNameResolver.h> #include <Swiften/Network/DomainNameAddressQuery.h> +#include <Swiften/Network/DomainNameResolver.h> +#include <Swiften/Network/HostAddress.h> #include <Swiften/Network/TimerFactory.h> -#include <Swiften/Base/Log.h> namespace Swift { -Connector::Connector(const std::string& hostname, int port, bool doServiceLookups, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), port(port), doServiceLookups(doServiceLookups), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllServices(true), foundSomeDNS(false) { +Connector::Connector(const std::string& hostname, unsigned short port, const boost::optional<std::string>& serviceLookupPrefix, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), port(port), serviceLookupPrefix(serviceLookupPrefix), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllServices(true), foundSomeDNS(false) { } void Connector::setTimeoutMilliseconds(int milliseconds) { - timeoutMilliseconds = milliseconds; + timeoutMilliseconds = milliseconds; } void Connector::start() { - SWIFT_LOG(debug) << "Starting connector for " << hostname << std::endl; - //std::cout << "Connector::start()" << std::endl; - assert(!currentConnection); - assert(!serviceQuery); - assert(!timer); - queriedAllServices = false; - if (doServiceLookups) { - serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname); - serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1)); - if (timeoutMilliseconds > 0) { - timer = timerFactory->createTimer(timeoutMilliseconds); - timer->onTick.connect(boost::bind(&Connector::handleTimeout, shared_from_this())); - } - serviceQuery->run(); - } - else { - queryAddress(hostname); - } + SWIFT_LOG(debug) << "Starting connector for " << hostname; + assert(!currentConnection); + assert(!serviceQuery); + assert(!timer); + auto hostAddress = HostAddress::fromString(hostname); + if (timeoutMilliseconds > 0) { + timer = timerFactory->createTimer(timeoutMilliseconds); + timer->onTick.connect(boost::bind(&Connector::handleTimeout, shared_from_this())); + } + if (serviceLookupPrefix) { + queriedAllServices = false; + serviceQuery = resolver->createServiceQuery(*serviceLookupPrefix, hostname); + serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1)); + serviceQuery->run(); + } + else if (hostAddress) { + // hostname is already a valid address; skip name lookup. + foundSomeDNS = true; + addressQueryResults.push_back(hostAddress.get()); + tryNextAddress(); + } else { + queryAddress(hostname); + } } void Connector::stop() { - finish(boost::shared_ptr<Connection>()); + if (currentConnection) { + currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); + currentConnection->disconnect(); + } + finish(std::shared_ptr<Connection>()); } void Connector::queryAddress(const std::string& hostname) { - assert(!addressQuery); - addressQuery = resolver->createAddressQuery(hostname); - addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2)); - addressQuery->run(); + assert(!addressQuery); + addressQuery = resolver->createAddressQuery(hostname); + addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2)); + addressQuery->run(); } void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { - SWIFT_LOG(debug) << result.size() << " SRV result(s)" << std::endl; - serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); - serviceQuery.reset(); - if (!serviceQueryResults.empty()) { - foundSomeDNS = true; - } - tryNextServiceOrFallback(); + SWIFT_LOG(debug) << result.size() << " SRV result(s)"; + serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); + serviceQuery.reset(); + if (!serviceQueryResults.empty()) { + foundSomeDNS = true; + } + tryNextServiceOrFallback(); } void Connector::tryNextServiceOrFallback() { - if (queriedAllServices) { - SWIFT_LOG(debug) << "Queried all services" << std::endl; - finish(boost::shared_ptr<Connection>()); - } - else if (serviceQueryResults.empty()) { - SWIFT_LOG(debug) << "Falling back on A resolution" << std::endl; - // Fall back on simple address resolving - queriedAllServices = true; - queryAddress(hostname); - } - else { - SWIFT_LOG(debug) << "Querying next address" << std::endl; - queryAddress(serviceQueryResults.front().hostname); - } + if (queriedAllServices) { + SWIFT_LOG(debug) << "Queried all services"; + finish(std::shared_ptr<Connection>()); + } + else if (serviceQueryResults.empty()) { + SWIFT_LOG(debug) << "Falling back on A resolution"; + // Fall back on simple address resolving + queriedAllServices = true; + queryAddress(hostname); + } + else { + SWIFT_LOG(debug) << "Querying next address"; + queryAddress(serviceQueryResults.front().hostname); + } } void Connector::handleAddressQueryResult(const std::vector<HostAddress>& addresses, boost::optional<DomainNameResolveError> error) { - SWIFT_LOG(debug) << addresses.size() << " addresses" << std::endl; - addressQuery.reset(); - if (error || addresses.empty()) { - if (!serviceQueryResults.empty()) { - serviceQueryResults.pop_front(); - } - tryNextServiceOrFallback(); - } - else { - foundSomeDNS = true; - addressQueryResults = std::deque<HostAddress>(addresses.begin(), addresses.end()); - tryNextAddress(); - } + SWIFT_LOG(debug) << addresses.size() << " addresses"; + addressQuery.reset(); + if (error || addresses.empty()) { + if (!serviceQueryResults.empty()) { + serviceQueryResults.pop_front(); + } + tryNextServiceOrFallback(); + } + else { + foundSomeDNS = true; + addressQueryResults = std::deque<HostAddress>(addresses.begin(), addresses.end()); + tryNextAddress(); + } } void Connector::tryNextAddress() { - if (addressQueryResults.empty()) { - SWIFT_LOG(debug) << "Done trying addresses. Moving on." << std::endl; - // Done trying all addresses. Move on to the next host. - if (!serviceQueryResults.empty()) { - serviceQueryResults.pop_front(); - } - tryNextServiceOrFallback(); - } - else { - SWIFT_LOG(debug) << "Trying next address" << std::endl; - HostAddress address = addressQueryResults.front(); - addressQueryResults.pop_front(); - - int connectPort = (port == -1 ? 5222 : port); - if (!serviceQueryResults.empty()) { - connectPort = serviceQueryResults.front().port; - } - - tryConnect(HostAddressPort(address, connectPort)); - } + if (addressQueryResults.empty()) { + SWIFT_LOG(debug) << "Done trying addresses. Moving on."; + // Done trying all addresses. Move on to the next host. + if (!serviceQueryResults.empty()) { + serviceQueryResults.pop_front(); + } + tryNextServiceOrFallback(); + } + else { + SWIFT_LOG(debug) << "Trying next address"; + HostAddress address = addressQueryResults.front(); + addressQueryResults.pop_front(); + + unsigned short connectPort = (port == 0 ? 5222 : port); + if (!serviceQueryResults.empty()) { + connectPort = serviceQueryResults.front().port; + } + + tryConnect(HostAddressPort(address, connectPort)); + } } void Connector::tryConnect(const HostAddressPort& target) { - assert(!currentConnection); - SWIFT_LOG(debug) << "Trying to connect to " << target.getAddress().toString() << ":" << target.getPort() << std::endl; - currentConnection = connectionFactory->createConnection(); - currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); - currentConnection->connect(target); - if (timer) { - timer->start(); - } + assert(!currentConnection); + SWIFT_LOG(debug) << "Trying to connect to " << target.getAddress().toString() << ":" << target.getPort(); + currentConnection = connectionFactory->createConnection(); + currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); + currentConnection->connect(target); + if (timer) { + timer->start(); + } } void Connector::handleConnectionConnectFinished(bool error) { - SWIFT_LOG(debug) << "ConnectFinished: " << (error ? "error" : "success") << std::endl; - if (timer) { - timer->stop(); - timer.reset(); - } - currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); - if (error) { - currentConnection.reset(); - if (!addressQueryResults.empty()) { - tryNextAddress(); - } - else { - if (!serviceQueryResults.empty()) { - serviceQueryResults.pop_front(); - } - tryNextServiceOrFallback(); - } - } - else { - finish(currentConnection); - } + SWIFT_LOG(debug) << "ConnectFinished: " << (error ? "error" : "success"); + if (timer) { + timer->stop(); + timer.reset(); + } + if (!currentConnection) { + // We've hit a race condition where multiple finisheds were on the eventloop queue at once. + // This is particularly likely on macOS where the hourly momentary wakeup while asleep + // can cause both a timeout and an onConnectFinished to be queued sequentially (SWIFT-232). + // Let the first one process as normal, but ignore the second. + return; + } + currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); + if (error) { + currentConnection.reset(); + if (!addressQueryResults.empty()) { + tryNextAddress(); + } + else { + if (!serviceQueryResults.empty()) { + serviceQueryResults.pop_front(); + } + tryNextServiceOrFallback(); + } + } + else { + finish(currentConnection); + } } -void Connector::finish(boost::shared_ptr<Connection> connection) { - if (timer) { - timer->stop(); - timer->onTick.disconnect(boost::bind(&Connector::handleTimeout, shared_from_this())); - timer.reset(); - } - if (serviceQuery) { - serviceQuery->onResult.disconnect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1)); - serviceQuery.reset(); - } - if (addressQuery) { - addressQuery->onResult.disconnect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2)); - addressQuery.reset(); - } - if (currentConnection) { - currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); - currentConnection.reset(); - } - onConnectFinished(connection, (connection || foundSomeDNS) ? boost::shared_ptr<Error>() : boost::make_shared<DomainNameResolveError>()); +void Connector::finish(std::shared_ptr<Connection> connection) { + if (timer) { + timer->stop(); + timer->onTick.disconnect(boost::bind(&Connector::handleTimeout, shared_from_this())); + timer.reset(); + } + if (serviceQuery) { + serviceQuery->onResult.disconnect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1)); + serviceQuery.reset(); + } + if (addressQuery) { + addressQuery->onResult.disconnect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2)); + addressQuery.reset(); + } + if (currentConnection) { + currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); + currentConnection.reset(); + } + onConnectFinished(connection, (connection || foundSomeDNS) ? std::shared_ptr<Error>() : std::make_shared<DomainNameResolveError>()); } void Connector::handleTimeout() { - SWIFT_LOG(debug) << "Timeout" << std::endl; - handleConnectionConnectFinished(true); + SWIFT_LOG(debug) << "Timeout"; + SWIFT_LOG_ASSERT(currentConnection, error) << "Connection not valid but triggered a timeout"; + handleConnectionConnectFinished(true); } } diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h index 49ac271..c76a4af 100644 --- a/Swiften/Network/Connector.h +++ b/Swiften/Network/Connector.h @@ -1,74 +1,81 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once #include <deque> -#include <Swiften/Base/boost_bsignals.h> -#include <boost/shared_ptr.hpp> +#include <memory> +#include <string> + +#include <boost/optional.hpp> +#include <boost/signals2.hpp> #include <Swiften/Base/API.h> -#include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/Network/Connection.h> -#include <Swiften/Network/Timer.h> -#include <Swiften/Network/HostAddressPort.h> -#include <string> #include <Swiften/Network/DomainNameResolveError.h> +#include <Swiften/Network/DomainNameServiceQuery.h> +#include <Swiften/Network/HostAddressPort.h> +#include <Swiften/Network/Timer.h> namespace Swift { - class DomainNameAddressQuery; - class DomainNameResolver; - class ConnectionFactory; - class TimerFactory; + class DomainNameAddressQuery; + class DomainNameResolver; + class ConnectionFactory; + class TimerFactory; - class SWIFTEN_API Connector : public boost::bsignals::trackable, public boost::enable_shared_from_this<Connector> { - public: - typedef boost::shared_ptr<Connector> ref; + class SWIFTEN_API Connector : public boost::signals2::trackable, public std::enable_shared_from_this<Connector> { + public: + typedef std::shared_ptr<Connector> ref; - static Connector::ref create(const std::string& hostname, int port, bool doServiceLookups, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) { - return ref(new Connector(hostname, port, doServiceLookups, resolver, connectionFactory, timerFactory)); - } + static Connector::ref create(const std::string& hostname, unsigned short port, const boost::optional<std::string>& serviceLookupPrefix, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) { + return ref(new Connector(hostname, port, serviceLookupPrefix, resolver, connectionFactory, timerFactory)); + } - void setTimeoutMilliseconds(int milliseconds); - void start(); - void stop(); + void setTimeoutMilliseconds(int milliseconds); + /** + * Start the connection attempt. + * Note that after calling this method, the caller is responsible for calling #stop() + * if it wants to cancel it. Not doing so can leak references. + */ + void start(); + void stop(); - boost::signal<void (boost::shared_ptr<Connection>, boost::shared_ptr<Error>)> onConnectFinished; + boost::signals2::signal<void (std::shared_ptr<Connection>, std::shared_ptr<Error>)> onConnectFinished; - private: - Connector(const std::string& hostname, int port, bool doServiceLookups, DomainNameResolver*, ConnectionFactory*, TimerFactory*); + private: + Connector(const std::string& hostname, unsigned short port, const boost::optional<std::string>& serviceLookupPrefix, DomainNameResolver*, ConnectionFactory*, TimerFactory*); - void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result); - void handleAddressQueryResult(const std::vector<HostAddress>& address, boost::optional<DomainNameResolveError> error); - void queryAddress(const std::string& hostname); + void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result); + void handleAddressQueryResult(const std::vector<HostAddress>& address, boost::optional<DomainNameResolveError> error); + void queryAddress(const std::string& hostname); - void tryNextServiceOrFallback(); - void tryNextAddress(); - void tryConnect(const HostAddressPort& target); + void tryNextServiceOrFallback(); + void tryNextAddress(); + void tryConnect(const HostAddressPort& target); - void handleConnectionConnectFinished(bool error); - void finish(boost::shared_ptr<Connection>); - void handleTimeout(); + void handleConnectionConnectFinished(bool error); + void finish(std::shared_ptr<Connection>); + void handleTimeout(); - private: - std::string hostname; - int port; - bool doServiceLookups; - DomainNameResolver* resolver; - ConnectionFactory* connectionFactory; - TimerFactory* timerFactory; - int timeoutMilliseconds; - boost::shared_ptr<Timer> timer; - boost::shared_ptr<DomainNameServiceQuery> serviceQuery; - std::deque<DomainNameServiceQuery::Result> serviceQueryResults; - boost::shared_ptr<DomainNameAddressQuery> addressQuery; - std::deque<HostAddress> addressQueryResults; - bool queriedAllServices; - boost::shared_ptr<Connection> currentConnection; - bool foundSomeDNS; - }; + private: + std::string hostname; + unsigned short port; + boost::optional<std::string> serviceLookupPrefix; + DomainNameResolver* resolver; + ConnectionFactory* connectionFactory; + TimerFactory* timerFactory; + int timeoutMilliseconds; + std::shared_ptr<Timer> timer; + std::shared_ptr<DomainNameServiceQuery> serviceQuery; + std::deque<DomainNameServiceQuery::Result> serviceQueryResults; + std::shared_ptr<DomainNameAddressQuery> addressQuery; + std::deque<HostAddress> addressQueryResults; + bool queriedAllServices; + std::shared_ptr<Connection> currentConnection; + bool foundSomeDNS; + }; } diff --git a/Swiften/Network/DomainNameAddressQuery.cpp b/Swiften/Network/DomainNameAddressQuery.cpp index 856f301..33d9e99 100644 --- a/Swiften/Network/DomainNameAddressQuery.cpp +++ b/Swiften/Network/DomainNameAddressQuery.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/DomainNameAddressQuery.h> diff --git a/Swiften/Network/DomainNameAddressQuery.h b/Swiften/Network/DomainNameAddressQuery.h index c8ed981..7f89546 100644 --- a/Swiften/Network/DomainNameAddressQuery.h +++ b/Swiften/Network/DomainNameAddressQuery.h @@ -1,27 +1,28 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <Swiften/Base/boost_bsignals.h> +#include <memory> + #include <boost/optional.hpp> -#include <boost/shared_ptr.hpp> +#include <boost/signals2.hpp> #include <Swiften/Network/DomainNameResolveError.h> #include <Swiften/Network/HostAddress.h> namespace Swift { - class DomainNameAddressQuery { - public: - typedef boost::shared_ptr<DomainNameAddressQuery> ref; + class DomainNameAddressQuery { + public: + typedef std::shared_ptr<DomainNameAddressQuery> ref; - virtual ~DomainNameAddressQuery(); + virtual ~DomainNameAddressQuery(); - virtual void run() = 0; + virtual void run() = 0; - boost::signal<void (const std::vector<HostAddress>&, boost::optional<DomainNameResolveError>)> onResult; - }; + boost::signals2::signal<void (const std::vector<HostAddress>&, boost::optional<DomainNameResolveError>)> onResult; + }; } diff --git a/Swiften/Network/DomainNameResolveError.h b/Swiften/Network/DomainNameResolveError.h index aa4441d..7a5eed1 100644 --- a/Swiften/Network/DomainNameResolveError.h +++ b/Swiften/Network/DomainNameResolveError.h @@ -1,16 +1,19 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <Swiften/Base/API.h> #include <Swiften/Base/Error.h> namespace Swift { - class DomainNameResolveError : public Error { - public: - DomainNameResolveError() {} - }; + class SWIFTEN_API DomainNameResolveError : public Error { + public: + DomainNameResolveError() {} + SWIFTEN_DEFAULT_COPY_CONSTRUCTOR(DomainNameResolveError) + SWIFTEN_DEFAULT_COPY_ASSIGMNENT_OPERATOR(DomainNameResolveError) + }; } diff --git a/Swiften/Network/DomainNameResolver.cpp b/Swiften/Network/DomainNameResolver.cpp index 56a9d72..95c08db 100644 --- a/Swiften/Network/DomainNameResolver.cpp +++ b/Swiften/Network/DomainNameResolver.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/DomainNameResolver.h> diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h index 491586a..5fe30d5 100644 --- a/Swiften/Network/DomainNameResolver.h +++ b/Swiften/Network/DomainNameResolver.h @@ -1,26 +1,26 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> #include <string> #include <Swiften/Base/API.h> namespace Swift { - class DomainNameServiceQuery; - class DomainNameAddressQuery; - + class DomainNameServiceQuery; + class DomainNameAddressQuery; - class SWIFTEN_API DomainNameResolver { - public: - virtual ~DomainNameResolver(); - virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& name) = 0; - virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name) = 0; - }; + class SWIFTEN_API DomainNameResolver { + public: + virtual ~DomainNameResolver(); + + virtual std::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) = 0; + virtual std::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name) = 0; + }; } diff --git a/Swiften/Network/DomainNameServiceQuery.cpp b/Swiften/Network/DomainNameServiceQuery.cpp index 6ce1d97..5784dd7 100644 --- a/Swiften/Network/DomainNameServiceQuery.cpp +++ b/Swiften/Network/DomainNameServiceQuery.cpp @@ -1,31 +1,28 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/DomainNameServiceQuery.h> -#include <numeric> #include <cassert> #include <functional> #include <iterator> +#include <numeric> -#include <Swiften/Base/RandomGenerator.h> #include <boost/numeric/conversion/cast.hpp> -#include <boost/lambda/lambda.hpp> -#include <boost/lambda/bind.hpp> -#include <boost/typeof/typeof.hpp> + +#include <Swiften/Base/RandomGenerator.h> using namespace Swift; -namespace lambda = boost::lambda; namespace { - struct ResultPriorityComparator { - bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const { - return a.priority < b.priority; - } - }; + struct ResultPriorityComparator { + bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const { + return a.priority < b.priority; + } + }; } namespace Swift { @@ -34,31 +31,37 @@ DomainNameServiceQuery::~DomainNameServiceQuery() { } void DomainNameServiceQuery::sortResults(std::vector<DomainNameServiceQuery::Result>& queries, RandomGenerator& generator) { - ResultPriorityComparator comparator; - std::sort(queries.begin(), queries.end(), comparator); + ResultPriorityComparator comparator; + std::stable_sort(queries.begin(), queries.end(), comparator); - std::vector<DomainNameServiceQuery::Result>::iterator i = queries.begin(); - while (i != queries.end()) { - std::vector<DomainNameServiceQuery::Result>::iterator next = std::upper_bound(i, queries.end(), *i, comparator); - if (std::distance(i, next) > 1) { - std::vector<int> weights; - std::transform(i, next, std::back_inserter(weights), - /* easy hack to account for '0' weights getting at least some weight */ - lambda::bind(&Result::weight, lambda::_1) + 1); - for (int j = 0; j < boost::numeric_cast<int>(weights.size() - 1); ++j) { - std::vector<int> cumulativeWeights; - std::partial_sum( - weights.begin() + j, - weights.end(), - std::back_inserter(cumulativeWeights)); - int randomNumber = generator.generateRandomInteger(cumulativeWeights.back()); - BOOST_AUTO(selectedIndex, std::lower_bound(cumulativeWeights.begin(), cumulativeWeights.end(), randomNumber) - cumulativeWeights.begin()); - std::swap(i[j], i[j + selectedIndex]); - std::swap(weights.begin()[j], weights.begin()[j + selectedIndex]); - } - } - i = next; - } + std::vector<DomainNameServiceQuery::Result>::iterator i = queries.begin(); + while (i != queries.end()) { + std::vector<DomainNameServiceQuery::Result>::iterator next = std::upper_bound(i, queries.end(), *i, comparator); + if (std::distance(i, next) > 1) { + std::vector<int> weights; + std::transform(i, next, std::back_inserter(weights), [](const DomainNameServiceQuery::Result& result) { + /* easy hack to account for '0' weights getting at least some weight */ + return result.weight + 1; + }); + try { + for (int j = 0; j < boost::numeric_cast<int>(weights.size()) - 1; ++j) { + std::vector<int> cumulativeWeights; + std::partial_sum( + weights.begin() + j, + weights.end(), + std::back_inserter(cumulativeWeights)); + int randomNumber = generator.generateRandomInteger(cumulativeWeights.back()); + auto selectedIndex = std::lower_bound(cumulativeWeights.begin(), cumulativeWeights.end(), randomNumber) - cumulativeWeights.begin(); + std::swap(i[j], i[j + selectedIndex]); + std::swap(weights.begin()[j], weights.begin()[j + selectedIndex]); + } + } + catch (const boost::numeric::bad_numeric_cast&) { + // In the unlikely event of weights.size() being too large, use the list as-is. + } + } + i = next; + } } diff --git a/Swiften/Network/DomainNameServiceQuery.h b/Swiften/Network/DomainNameServiceQuery.h index fdf5b5d..1631b99 100644 --- a/Swiften/Network/DomainNameServiceQuery.h +++ b/Swiften/Network/DomainNameServiceQuery.h @@ -1,40 +1,41 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <Swiften/Base/boost_bsignals.h> -#include <boost/optional.hpp> +#include <memory> +#include <string> #include <vector> -#include <boost/shared_ptr.hpp> -#include <string> +#include <boost/optional.hpp> +#include <boost/signals2.hpp> + #include <Swiften/Base/API.h> #include <Swiften/Network/DomainNameResolveError.h> namespace Swift { - class RandomGenerator; + class RandomGenerator; - class SWIFTEN_API DomainNameServiceQuery { - public: - typedef boost::shared_ptr<DomainNameServiceQuery> ref; + class SWIFTEN_API DomainNameServiceQuery { + public: + typedef std::shared_ptr<DomainNameServiceQuery> ref; - struct Result { - Result(const std::string& hostname = "", int port = -1, int priority = -1, int weight = -1) : hostname(hostname), port(port), priority(priority), weight(weight) {} - std::string hostname; - int port; - int priority; - int weight; - }; + struct Result { + Result(const std::string& hostname = "", unsigned short port = 0, int priority = -1, int weight = -1) : hostname(hostname), port(port), priority(priority), weight(weight) {} + std::string hostname; + unsigned short port; + int priority; + int weight; + }; - virtual ~DomainNameServiceQuery(); + virtual ~DomainNameServiceQuery(); - virtual void run() = 0; - static void sortResults(std::vector<DomainNameServiceQuery::Result>& queries, RandomGenerator& generator); + virtual void run() = 0; + static void sortResults(std::vector<DomainNameServiceQuery::Result>& queries, RandomGenerator& generator); - boost::signal<void (const std::vector<Result>&)> onResult; - }; + boost::signals2::signal<void (const std::vector<Result>&)> onResult; + }; } diff --git a/Swiften/Network/DummyConnection.cpp b/Swiften/Network/DummyConnection.cpp index 09bd06d..3024b21 100644 --- a/Swiften/Network/DummyConnection.cpp +++ b/Swiften/Network/DummyConnection.cpp @@ -1,14 +1,15 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/DummyConnection.h> -#include <boost/bind.hpp> -#include <boost/smart_ptr/make_shared.hpp> #include <cassert> +#include <memory> + +#include <boost/bind.hpp> namespace Swift { @@ -16,15 +17,15 @@ DummyConnection::DummyConnection(EventLoop* eventLoop) : eventLoop(eventLoop) { } void DummyConnection::receive(const SafeByteArray& data) { - eventLoop->postEvent(boost::bind(boost::ref(onDataRead), boost::make_shared<SafeByteArray>(data)), shared_from_this()); + eventLoop->postEvent(boost::bind(boost::ref(onDataRead), std::make_shared<SafeByteArray>(data)), shared_from_this()); } void DummyConnection::listen() { - assert(false); + assert(false); } void DummyConnection::connect(const HostAddressPort&) { - assert(false); + assert(false); } diff --git a/Swiften/Network/DummyConnection.h b/Swiften/Network/DummyConnection.h index 36bf897..e58edf6 100644 --- a/Swiften/Network/DummyConnection.h +++ b/Swiften/Network/DummyConnection.h @@ -1,45 +1,50 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/enable_shared_from_this.hpp> +#include <memory> #include <Swiften/Base/API.h> -#include <Swiften/Network/Connection.h> -#include <Swiften/Network/HostAddressPort.h> #include <Swiften/EventLoop/EventLoop.h> #include <Swiften/EventLoop/EventOwner.h> +#include <Swiften/Network/Connection.h> +#include <Swiften/Network/HostAddressPort.h> namespace Swift { - class SWIFTEN_API DummyConnection : public Connection, public EventOwner, public boost::enable_shared_from_this<DummyConnection> { - public: - DummyConnection(EventLoop* eventLoop); + class SWIFTEN_API DummyConnection : public Connection, public EventOwner, public std::enable_shared_from_this<DummyConnection> { + public: + DummyConnection(EventLoop* eventLoop); + + void listen(); + void connect(const HostAddressPort&); - void listen(); - void connect(const HostAddressPort&); + void disconnect() { + //assert(false); + } - void disconnect() { - //assert(false); - } + void write(const SafeByteArray& data) { + eventLoop->postEvent(boost::ref(onDataWritten), shared_from_this()); + onDataSent(data); + } - void write(const SafeByteArray& data) { - eventLoop->postEvent(boost::ref(onDataWritten), shared_from_this()); - onDataSent(data); - } + void receive(const SafeByteArray& data); - void receive(const SafeByteArray& data); + HostAddressPort getLocalAddress() const { + return localAddress; + } - HostAddressPort getLocalAddress() const { - return localAddress; - } + HostAddressPort getRemoteAddress() const { + return remoteAddress; + } - boost::signal<void (const SafeByteArray&)> onDataSent; + boost::signals2::signal<void (const SafeByteArray&)> onDataSent; - EventLoop* eventLoop; - HostAddressPort localAddress; - }; + EventLoop* eventLoop; + HostAddressPort localAddress; + HostAddressPort remoteAddress; + }; } diff --git a/Swiften/Network/DummyConnectionFactory.h b/Swiften/Network/DummyConnectionFactory.h index e8a294e..d723283 100644 --- a/Swiften/Network/DummyConnectionFactory.h +++ b/Swiften/Network/DummyConnectionFactory.h @@ -4,9 +4,15 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <boost/smart_ptr/make_shared.hpp> +#include <memory> #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Network/DummyConnection.h> @@ -17,13 +23,13 @@ class EventLoop; class DummyConnectionFactory : public ConnectionFactory { public: - DummyConnectionFactory(EventLoop *eventLoop) : eventLoop(eventLoop) {} - virtual ~DummyConnectionFactory() {} - virtual boost::shared_ptr<Connection> createConnection() { - return boost::make_shared<DummyConnection>(eventLoop); - } + DummyConnectionFactory(EventLoop *eventLoop) : eventLoop(eventLoop) {} + virtual ~DummyConnectionFactory() {} + virtual std::shared_ptr<Connection> createConnection() { + return std::make_shared<DummyConnection>(eventLoop); + } private: - EventLoop* eventLoop; + EventLoop* eventLoop; }; } diff --git a/Swiften/Network/DummyConnectionServer.h b/Swiften/Network/DummyConnectionServer.h new file mode 100644 index 0000000..a4fd07f --- /dev/null +++ b/Swiften/Network/DummyConnectionServer.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#pragma once + +#include <memory> + +#include <Swiften/Base/API.h> +#include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/EventLoop/EventOwner.h> +#include <Swiften/Network/ConnectionServer.h> +#include <Swiften/Network/HostAddressPort.h> + +namespace Swift { + class SWIFTEN_API DummyConnectionServer : public ConnectionServer, public EventOwner, public std::enable_shared_from_this<DummyConnectionServer> { + public: + DummyConnectionServer(EventLoop* /*eventLoop*/, unsigned short port) : localAddressPort(HostAddress(), port) {} + DummyConnectionServer(EventLoop* /*eventLoop*/, const Swift::HostAddress& hostAddress, unsigned short port) : localAddressPort(hostAddress, port) {} + virtual ~DummyConnectionServer() {} + + virtual HostAddressPort getAddressPort() const { + return localAddressPort; + } + + virtual boost::optional<Error> tryStart() { + return boost::optional<Error>(); + } + + virtual void start() { + + } + + virtual void stop() { + + } + + private: + HostAddressPort localAddressPort; + }; +} diff --git a/Swiften/Network/DummyConnectionServerFactory.h b/Swiften/Network/DummyConnectionServerFactory.h new file mode 100644 index 0000000..4b25118 --- /dev/null +++ b/Swiften/Network/DummyConnectionServerFactory.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2014-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#pragma once + +#include <memory> + +#include <Swiften/Network/ConnectionServerFactory.h> +#include <Swiften/Network/DummyConnectionServer.h> + +namespace Swift { + +class EventLoop; + +class DummyConnectionServerFactory : public ConnectionServerFactory { +public: + DummyConnectionServerFactory(EventLoop* eventLoop) : eventLoop(eventLoop) {} + virtual ~DummyConnectionServerFactory() {} + + virtual std::shared_ptr<ConnectionServer> createConnectionServer(unsigned short port) { + return std::make_shared<DummyConnectionServer>(eventLoop, port); + } + + virtual std::shared_ptr<ConnectionServer> createConnectionServer(const Swift::HostAddress& hostAddress, unsigned short port) { + return std::make_shared<DummyConnectionServer>(eventLoop, hostAddress, port); + } + +private: + EventLoop* eventLoop; +}; + +} diff --git a/Swiften/Network/DummyTimerFactory.cpp b/Swiften/Network/DummyTimerFactory.cpp index 16428b7..0bad7be 100644 --- a/Swiften/Network/DummyTimerFactory.cpp +++ b/Swiften/Network/DummyTimerFactory.cpp @@ -1,60 +1,59 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/DummyTimerFactory.h> #include <algorithm> -#include <Swiften/Base/foreach.h> #include <Swiften/Network/Timer.h> namespace Swift { class DummyTimerFactory::DummyTimer : public Timer { - public: - DummyTimer(int timeout, DummyTimerFactory* factory) : timeout(timeout), factory(factory), isRunning(false), startTime(~0) { - } - - virtual void start() { - isRunning = true; - startTime = factory->currentTime; - } - - virtual void stop() { - isRunning = false; - } - - int getAlarmTime() const { - return startTime + timeout; - } - - int timeout; - DummyTimerFactory* factory; - bool isRunning; - int startTime; + public: + DummyTimer(int timeout, DummyTimerFactory* factory) : timeout(timeout), factory(factory), isRunning(false), startTime(~0) { + } + + virtual void start() { + isRunning = true; + startTime = factory->currentTime; + } + + virtual void stop() { + isRunning = false; + } + + int getAlarmTime() const { + return startTime + timeout; + } + + int timeout; + DummyTimerFactory* factory; + bool isRunning; + int startTime; }; DummyTimerFactory::DummyTimerFactory() : currentTime(0) { } -boost::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) { - boost::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds, this)); - timers.push_back(timer); - return timer; +std::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) { + std::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds, this)); + timers.push_back(timer); + return timer; } void DummyTimerFactory::setTime(int time) { - assert(time > currentTime); - foreach(boost::shared_ptr<DummyTimer> timer, timers) { - if (timer->getAlarmTime() > currentTime && timer->getAlarmTime() <= time && timer->isRunning) { - timer->onTick(); - } - } - currentTime = time; + assert(time > currentTime); + for (auto&& timer : timers) { + if (timer->getAlarmTime() > currentTime && timer->getAlarmTime() <= time && timer->isRunning) { + timer->onTick(); + } + } + currentTime = time; } } diff --git a/Swiften/Network/DummyTimerFactory.h b/Swiften/Network/DummyTimerFactory.h index 1e9413b..5ccbf93 100644 --- a/Swiften/Network/DummyTimerFactory.h +++ b/Swiften/Network/DummyTimerFactory.h @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once @@ -12,18 +12,18 @@ #include <Swiften/Network/TimerFactory.h> namespace Swift { - class SWIFTEN_API DummyTimerFactory : public TimerFactory { - public: - class DummyTimer; + class SWIFTEN_API DummyTimerFactory : public TimerFactory { + public: + class DummyTimer; - DummyTimerFactory(); + DummyTimerFactory(); - virtual boost::shared_ptr<Timer> createTimer(int milliseconds); - void setTime(int time); + virtual std::shared_ptr<Timer> createTimer(int milliseconds); + void setTime(int time); - private: - friend class DummyTimer; - int currentTime; - std::list<boost::shared_ptr<DummyTimer> > timers; - }; + private: + friend class DummyTimer; + int currentTime; + std::list<std::shared_ptr<DummyTimer> > timers; + }; } diff --git a/Swiften/Network/EnvironmentProxyProvider.cpp b/Swiften/Network/EnvironmentProxyProvider.cpp index 7701da1..6fbf373 100644 --- a/Swiften/Network/EnvironmentProxyProvider.cpp +++ b/Swiften/Network/EnvironmentProxyProvider.cpp @@ -4,45 +4,59 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#include <Swiften/Network/EnvironmentProxyProvider.h> + #include <stdio.h> #include <stdlib.h> + #include <iostream> +#include <boost/numeric/conversion/cast.hpp> + #include <Swiften/Base/Log.h> -#include <Swiften/Network/EnvironmentProxyProvider.h> namespace Swift { EnvironmentProxyProvider::EnvironmentProxyProvider() { - socksProxy = getFromEnv("all_proxy", "socks"); - httpProxy = getFromEnv("http_proxy", "http"); - SWIFT_LOG(debug) << "Environment: SOCKS5 => " << socksProxy.toString() << "; HTTP Connect => " << httpProxy.toString() << std::endl; + socksProxy = getFromEnv("all_proxy", "socks"); + httpProxy = getFromEnv("http_proxy", "http"); + SWIFT_LOG(debug) << "Environment: SOCKS5 => " << socksProxy.toString() << "; HTTP Connect => " << httpProxy.toString(); } HostAddressPort EnvironmentProxyProvider::getHTTPConnectProxy() const { - return httpProxy; + return httpProxy; } HostAddressPort EnvironmentProxyProvider::getSOCKS5Proxy() const { - return socksProxy; + return socksProxy; } HostAddressPort EnvironmentProxyProvider::getFromEnv(const char* envVarName, std::string proxyProtocol) { - char* envVar = NULL; - std::string address; - int port = 0; - - envVar = getenv(envVarName); - - proxyProtocol += "://"; - address = envVar != NULL ? envVar : "0.0.0.0"; - if(envVar != NULL && address.compare(0, proxyProtocol.length(), proxyProtocol) == 0) { - address = address.substr(proxyProtocol.length(), address.length()); - port = atoi(address.substr(address.find(':') + 1, address.length()).c_str()); - address = address.substr(0, address.find(':')); - } - - return HostAddressPort(HostAddress(address), port); + char* envVar = nullptr; + std::string address; + unsigned short port = 0; + + envVar = getenv(envVarName); + + proxyProtocol += "://"; + address = envVar != nullptr ? envVar : "0.0.0.0"; + if(envVar != nullptr && address.compare(0, proxyProtocol.length(), proxyProtocol) == 0) { + address = address.substr(proxyProtocol.length(), address.length()); + try { + port = boost::numeric_cast<unsigned short>(atoi(address.substr(address.find(':') + 1, address.length()).c_str())); + } + catch (boost::numeric::bad_numeric_cast&) { + } + address = address.substr(0, address.find(':')); + } + + return HostAddressPort(HostAddress::fromString(address).get_value_or(HostAddress()), port); } } diff --git a/Swiften/Network/EnvironmentProxyProvider.h b/Swiften/Network/EnvironmentProxyProvider.h index 224d301..59463ff 100644 --- a/Swiften/Network/EnvironmentProxyProvider.h +++ b/Swiften/Network/EnvironmentProxyProvider.h @@ -4,21 +4,28 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once +#include <Swiften/Base/API.h> #include <Swiften/Network/ProxyProvider.h> namespace Swift { - class EnvironmentProxyProvider : public ProxyProvider { - public: - EnvironmentProxyProvider(); - virtual HostAddressPort getHTTPConnectProxy() const; - virtual HostAddressPort getSOCKS5Proxy() const; - private: - HostAddressPort getFromEnv(const char* envVarName, std::string proxyProtocol); - HostAddressPort socksProxy; - HostAddressPort httpProxy; - }; + class SWIFTEN_API EnvironmentProxyProvider : public ProxyProvider { + public: + EnvironmentProxyProvider(); + virtual HostAddressPort getHTTPConnectProxy() const; + virtual HostAddressPort getSOCKS5Proxy() const; + private: + HostAddressPort getFromEnv(const char* envVarName, std::string proxyProtocol); + HostAddressPort socksProxy; + HostAddressPort httpProxy; + }; } diff --git a/Swiften/Network/FakeConnection.cpp b/Swiften/Network/FakeConnection.cpp index be5555c..82b792a 100644 --- a/Swiften/Network/FakeConnection.cpp +++ b/Swiften/Network/FakeConnection.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2014 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/FakeConnection.h> @@ -17,48 +17,48 @@ FakeConnection::~FakeConnection() { } void FakeConnection::listen() { - assert(false); + assert(false); } void FakeConnection::setError(const Error& e) { - error = boost::optional<Error>(e); - state = DisconnectedWithError; - if (connectedTo) { - eventLoop->postEvent( - boost::bind(boost::ref(onDisconnected), error), - shared_from_this()); - } + error = boost::optional<Error>(e); + state = DisconnectedWithError; + if (connectedTo) { + eventLoop->postEvent( + boost::bind(boost::ref(onDisconnected), error), + shared_from_this()); + } } void FakeConnection::connect(const HostAddressPort& address) { - if (delayConnect) { - state = Connecting; - } - else { - if (!error) { - connectedTo = address; - state = Connected; - } - else { - state = DisconnectedWithError; - } - eventLoop->postEvent( - boost::bind(boost::ref(onConnectFinished), error), - shared_from_this()); - } + if (delayConnect) { + state = Connecting; + } + else { + if (!error) { + connectedTo = address; + state = Connected; + } + else { + state = DisconnectedWithError; + } + eventLoop->postEvent( + boost::bind(boost::ref(onConnectFinished), error ? true : false), + shared_from_this()); + } } void FakeConnection::disconnect() { - if (!error) { - state = Disconnected; - } - else { - state = DisconnectedWithError; - } - connectedTo.reset(); - eventLoop->postEvent( - boost::bind(boost::ref(onDisconnected), error), - shared_from_this()); + if (!error) { + state = Disconnected; + } + else { + state = DisconnectedWithError; + } + connectedTo.reset(); + eventLoop->postEvent( + boost::bind(boost::ref(onDisconnected), error), + shared_from_this()); } } diff --git a/Swiften/Network/FakeConnection.h b/Swiften/Network/FakeConnection.h index eca45da..08c1d75 100644 --- a/Swiften/Network/FakeConnection.h +++ b/Swiften/Network/FakeConnection.h @@ -1,60 +1,64 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/optional.hpp> -#include <boost/enable_shared_from_this.hpp> +#include <memory> #include <vector> +#include <boost/optional.hpp> + #include <Swiften/Base/API.h> +#include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/EventLoop/EventOwner.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/HostAddressPort.h> -#include <Swiften/EventLoop/EventOwner.h> -#include <Swiften/EventLoop/EventLoop.h> namespace Swift { - class SWIFTEN_API FakeConnection : - public Connection, - public EventOwner, - public boost::enable_shared_from_this<FakeConnection> { - public: - enum State { - Initial, - Connecting, - Connected, - Disconnected, - DisconnectedWithError - }; - - FakeConnection(EventLoop* eventLoop); - ~FakeConnection(); - - virtual void listen(); - virtual HostAddressPort getLocalAddress() const { - return HostAddressPort(); - } - - void setError(const Error& e); - virtual void connect(const HostAddressPort& address); - virtual void disconnect(); - - virtual void write(const SafeByteArray& data) { - dataWritten.push_back(data); - } - - void setDelayConnect() { - delayConnect = true; - } - - EventLoop* eventLoop; - boost::optional<HostAddressPort> connectedTo; - std::vector<SafeByteArray> dataWritten; - boost::optional<Error> error; - State state; - bool delayConnect; - }; + class SWIFTEN_API FakeConnection : + public Connection, + public EventOwner, + public std::enable_shared_from_this<FakeConnection> { + public: + enum State { + Initial, + Connecting, + Connected, + Disconnected, + DisconnectedWithError + }; + + FakeConnection(EventLoop* eventLoop); + virtual ~FakeConnection(); + + virtual void listen(); + virtual HostAddressPort getLocalAddress() const { + return HostAddressPort(); + } + virtual HostAddressPort getRemoteAddress() const { + return HostAddressPort(); + } + + void setError(const Error& e); + virtual void connect(const HostAddressPort& address); + virtual void disconnect(); + + virtual void write(const SafeByteArray& data) { + dataWritten.push_back(data); + } + + void setDelayConnect() { + delayConnect = true; + } + + EventLoop* eventLoop; + boost::optional<HostAddressPort> connectedTo; + std::vector<SafeByteArray> dataWritten; + boost::optional<Error> error; + State state; + bool delayConnect; + }; } diff --git a/Swiften/Network/GConfProxyProvider.cpp b/Swiften/Network/GConfProxyProvider.cpp index 8d97c68..a2f8adc 100644 --- a/Swiften/Network/GConfProxyProvider.cpp +++ b/Swiften/Network/GConfProxyProvider.cpp @@ -4,55 +4,73 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include <stdio.h> -#include <stdlib.h> +/* + * Copyright (c) 2016-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#include <Swiften/Network/GConfProxyProvider.h> + +#include <cstdio> +#include <cstdlib> + #include <iostream> +extern "C" { #include <gconf/gconf-client.h> +} + +#include <boost/numeric/conversion/cast.hpp> #include <Swiften/Base/Log.h> -#include <Swiften/Network/GConfProxyProvider.h> namespace Swift { GConfProxyProvider::GConfProxyProvider() { - // Ensure static GLib initialization methods are called - static bool glibInitialized = false; - if (!glibInitialized) { - g_type_init(); - glibInitialized = true; - } - - socksProxy = getFromGConf("/system/proxy/socks_host", "/system/proxy/socks_port"); - httpProxy = getFromGConf("/system/http_proxy/host", "/system/http_proxy/port"); - SWIFT_LOG(debug) << "GConf: SOCKS5 => " << socksProxy.toString() << "; HTTP Connect => " << httpProxy.toString() << std::endl; +#if !GLIB_CHECK_VERSION(2,35,0) + // Ensure static GLib initialization methods are called + static bool glibInitialized = false; + if (!glibInitialized) { + g_type_init(); + glibInitialized = true; + } +#endif + + socksProxy = getFromGConf("/system/proxy/socks_host", "/system/proxy/socks_port"); + httpProxy = getFromGConf("/system/http_proxy/host", "/system/http_proxy/port"); + SWIFT_LOG(debug) << "GConf: SOCKS5 => " << socksProxy.toString() << "; HTTP Connect => " << httpProxy.toString(); } HostAddressPort GConfProxyProvider::getHTTPConnectProxy() const { - return httpProxy; + return httpProxy; } HostAddressPort GConfProxyProvider::getSOCKS5Proxy() const { - return socksProxy; + return socksProxy; } HostAddressPort GConfProxyProvider::getFromGConf(const char* gcHost, const char* gcPort) { - std::string address; - int port = 0; - gchar* str; + std::string address; + unsigned short port = 0; + gchar* str; - GConfClient* client = gconf_client_get_default(); + GConfClient* client = gconf_client_get_default(); - str = gconf_client_get_string(client, gcHost, NULL); - port = static_cast<int> (gconf_client_get_int(client, gcPort, NULL)); + str = gconf_client_get_string(client, gcHost, NULL); + try { + port = boost::numeric_cast<unsigned short>(gconf_client_get_int(client, gcPort, NULL)); + } + catch (const boost::numeric::bad_numeric_cast&) { + } - if(str) { - address = static_cast<char*> (str); - g_free(str); - } + if(str) { + address = static_cast<char*> (str); + g_free(str); + } - g_object_unref(client); - return HostAddressPort(HostAddress(address), port); + g_object_unref(client); + return HostAddressPort(HostAddress::fromString(address).get_value_or(HostAddress()), port); } } diff --git a/Swiften/Network/GConfProxyProvider.h b/Swiften/Network/GConfProxyProvider.h index 31f38a9..826a67b 100644 --- a/Swiften/Network/GConfProxyProvider.h +++ b/Swiften/Network/GConfProxyProvider.h @@ -9,16 +9,16 @@ #include <Swiften/Network/ProxyProvider.h> namespace Swift { - class GConfProxyProvider : public ProxyProvider { - public: - GConfProxyProvider(); - virtual HostAddressPort getHTTPConnectProxy() const; - virtual HostAddressPort getSOCKS5Proxy() const; - private: - HostAddressPort getFromGConf(const char* gcHost, const char* gcPort); - HostAddressPort socksProxy; - HostAddressPort httpProxy; - }; + class GConfProxyProvider : public ProxyProvider { + public: + GConfProxyProvider(); + virtual HostAddressPort getHTTPConnectProxy() const; + virtual HostAddressPort getSOCKS5Proxy() const; + private: + HostAddressPort getFromGConf(const char* gcHost, const char* gcPort); + HostAddressPort socksProxy; + HostAddressPort httpProxy; + }; } diff --git a/Swiften/Network/HTTPConnectProxiedConnection.cpp b/Swiften/Network/HTTPConnectProxiedConnection.cpp index a88ded1..e63b8e2 100644 --- a/Swiften/Network/HTTPConnectProxiedConnection.cpp +++ b/Swiften/Network/HTTPConnectProxiedConnection.cpp @@ -5,80 +5,158 @@ */ /* - * Copyright (c) 2011-2012 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/HTTPConnectProxiedConnection.h> #include <iostream> +#include <utility> + +#include <boost/algorithm/string.hpp> #include <boost/bind.hpp> #include <boost/lexical_cast.hpp> #include <Swiften/Base/Algorithm.h> +#include <Swiften/Base/ByteArray.h> #include <Swiften/Base/Log.h> #include <Swiften/Base/String.h> -#include <Swiften/Base/ByteArray.h> -#include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/HTTPTrafficFilter.h> +#include <Swiften/Network/HostAddressPort.h> #include <Swiften/StringCodecs/Base64.h> using namespace Swift; HTTPConnectProxiedConnection::HTTPConnectProxiedConnection( - DomainNameResolver* resolver, - ConnectionFactory* connectionFactory, - TimerFactory* timerFactory, - const std::string& proxyHost, - int proxyPort, - const SafeString& authID, - const SafeString& authPassword) : - ProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort), - authID_(authID), - authPassword_(authPassword) { + DomainNameResolver* resolver, + ConnectionFactory* connectionFactory, + TimerFactory* timerFactory, + const std::string& proxyHost, + unsigned short proxyPort, + const SafeString& authID, + const SafeString& authPassword) : + ProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort), + authID_(authID), + authPassword_(authPassword) { +} + +HTTPConnectProxiedConnection::~HTTPConnectProxiedConnection() { + } +void HTTPConnectProxiedConnection::setHTTPTrafficFilter(std::shared_ptr<HTTPTrafficFilter> trafficFilter) { + trafficFilter_ = trafficFilter; +} void HTTPConnectProxiedConnection::initializeProxy() { - std::stringstream connect; - connect << "CONNECT " << getServer().getAddress().toString() << ":" << getServer().getPort() << " HTTP/1.1\r\n"; - SafeByteArray data = createSafeByteArray(connect.str()); - if (!authID_.empty() && !authPassword_.empty()) { - append(data, createSafeByteArray("Proxy-Authorization: Basic ")); - SafeByteArray credentials = authID_; - append(credentials, createSafeByteArray(":")); - append(credentials, authPassword_); - append(data, Base64::encode(credentials)); - append(data, createSafeByteArray("\r\n")); - } - append(data, createSafeByteArray("\r\n")); - SWIFT_LOG(debug) << "HTTP Proxy send headers: " << byteArrayToString(ByteArray(data.begin(), data.end())) << std::endl; - write(data); + httpResponseBuffer_.clear(); + + std::stringstream connect; + connect << "CONNECT " << getServer().getAddress().toString() << ":" << getServer().getPort() << " HTTP/1.1\r\n"; + SafeByteArray data = createSafeByteArray(connect.str()); + if (!authID_.empty() && !authPassword_.empty()) { + append(data, createSafeByteArray("Proxy-Authorization: Basic ")); + SafeByteArray credentials = authID_; + append(credentials, createSafeByteArray(":")); + append(credentials, authPassword_); + append(data, Base64::encode(credentials)); + append(data, createSafeByteArray("\r\n")); + } + else if (!nextHTTPRequestHeaders_.empty()) { + for (const auto& headerField : nextHTTPRequestHeaders_) { + append(data, createSafeByteArray(headerField.first)); + append(data, createSafeByteArray(": ")); + append(data, createSafeByteArray(headerField.second)); + append(data, createSafeByteArray("\r\n")); + } + + nextHTTPRequestHeaders_.clear(); + } + append(data, createSafeByteArray("\r\n")); + SWIFT_LOG(debug) << "HTTP Proxy send headers: " << byteArrayToString(ByteArray(data.begin(), data.end())); + write(data); } -void HTTPConnectProxiedConnection::handleProxyInitializeData(boost::shared_ptr<SafeByteArray> data) { - SWIFT_LOG(debug) << byteArrayToString(ByteArray(data->begin(), data->end())) << std::endl; - std::vector<std::string> tmp = String::split(byteArrayToString(ByteArray(data->begin(), data->end())), ' '); - if (tmp.size() > 1) { - try { - int status = boost::lexical_cast<int>(tmp[1]); - SWIFT_LOG(debug) << "Proxy Status: " << status << std::endl; - if (status / 100 == 2) { // all 2XX states are OK - setProxyInitializeFinished(true); - } - else { - SWIFT_LOG(debug) << "HTTP Proxy returned an error: " << byteArrayToString(ByteArray(data->begin(), data->end())) << std::endl; - setProxyInitializeFinished(false); - } - } - catch (boost::bad_lexical_cast&) { - SWIFT_LOG(warning) << "Unexpected response: " << tmp[1] << std::endl; - setProxyInitializeFinished(false); - } - } - else { - setProxyInitializeFinished(false); - } +void HTTPConnectProxiedConnection::parseHTTPHeader(const std::string& data, std::string& statusLine, std::vector<std::pair<std::string, std::string> >& headerFields) { + std::istringstream dataStream(data); + + // parse status line + std::getline(dataStream, statusLine); + + // parse fields + std::string headerLine; + std::string::size_type splitIndex; + while (std::getline(dataStream, headerLine) && headerLine != "\r") { + splitIndex = headerLine.find(':', 0); + if (splitIndex != std::string::npos) { + headerFields.push_back(std::pair<std::string, std::string>(headerLine.substr(0, splitIndex), headerLine.substr(splitIndex + 1))); + } + } +} + +void HTTPConnectProxiedConnection::sendHTTPRequest(const std::string& statusLine, const std::vector<std::pair<std::string, std::string> >& headerFields) { + std::stringstream request; + + request << statusLine << "\r\n"; + for (const auto& field : headerFields) { + request << field.first << ":" << field.second << "\r\n"; + } + request << "\r\n"; + write(createSafeByteArray(request.str())); +} + +void HTTPConnectProxiedConnection::handleProxyInitializeData(std::shared_ptr<SafeByteArray> data) { + std::string dataString = byteArrayToString(ByteArray(data->begin(), data->end())); + SWIFT_LOG(debug) << data; + httpResponseBuffer_.append(dataString); + + std::string statusLine; + std::vector<std::pair<std::string, std::string> > headerFields; + + std::string::size_type headerEnd = httpResponseBuffer_.find("\r\n\r\n", 0); + if (headerEnd == std::string::npos) { + if ((httpResponseBuffer_.size() > 4) && (httpResponseBuffer_.substr(0, 4) != "HTTP")) { + setProxyInitializeFinished(false); + } + return; + } + + parseHTTPHeader(httpResponseBuffer_.substr(0, headerEnd), statusLine, headerFields); + + if (trafficFilter_) { + std::vector<std::pair<std::string, std::string> > newHeaderFields = trafficFilter_->filterHTTPResponseHeader(statusLine, headerFields); + if (!newHeaderFields.empty()) { + std::stringstream statusLine; + reconnect(); + nextHTTPRequestHeaders_ = newHeaderFields; + return; + } + } + + std::vector<std::string> tmp = String::split(statusLine, ' '); + if (tmp.size() > 1) { + try { + int status = boost::lexical_cast<int>(tmp[1]); + SWIFT_LOG(debug) << "Proxy Status: " << status; + if (status / 100 == 2) { // all 2XX states are OK + setProxyInitializeFinished(true); + } + else { + SWIFT_LOG(debug) << "HTTP Proxy returned an error: " << httpResponseBuffer_; + setProxyInitializeFinished(false); + } + } + catch (boost::bad_lexical_cast&) { + SWIFT_LOG(warning) << "Unexpected response: " << tmp[1]; + setProxyInitializeFinished(false); + } + } + else { + setProxyInitializeFinished(false); + } + httpResponseBuffer_.clear(); } diff --git a/Swiften/Network/HTTPConnectProxiedConnection.h b/Swiften/Network/HTTPConnectProxiedConnection.h index c209dc1..a83d47c 100644 --- a/Swiften/Network/HTTPConnectProxiedConnection.h +++ b/Swiften/Network/HTTPConnectProxiedConnection.h @@ -5,39 +5,51 @@ */ /* - * Copyright (c) 2011-2012 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <memory> + #include <Swiften/Base/API.h> #include <Swiften/Network/ProxiedConnection.h> namespace Swift { - class DomainNameResolver; - class ConnectionFactory; - class EventLoop; - class TimerFactory; + class ConnectionFactory; + class DomainNameResolver; + class HTTPTrafficFilter; + class TimerFactory; + + class SWIFTEN_API HTTPConnectProxiedConnection : public ProxiedConnection { + public: + typedef std::shared_ptr<HTTPConnectProxiedConnection> ref; + + virtual ~HTTPConnectProxiedConnection(); + + static ref create(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort, const SafeString& authID, const SafeString& authPassword) { + return ref(new HTTPConnectProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort, authID, authPassword)); + } - class SWIFTEN_API HTTPConnectProxiedConnection : public ProxiedConnection { - public: - typedef boost::shared_ptr<HTTPConnectProxiedConnection> ref; + void setHTTPTrafficFilter(std::shared_ptr<HTTPTrafficFilter> trafficFilter); - static ref create(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort, const SafeString& authID, const SafeString& authPassword) { - return ref(new HTTPConnectProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort, authID, authPassword)); - } + private: + HTTPConnectProxiedConnection(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort, const SafeString& authID, const SafeString& authPassword); - private: - HTTPConnectProxiedConnection(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort, const SafeString& authID, const SafeString& authPassword); + virtual void initializeProxy(); + virtual void handleProxyInitializeData(std::shared_ptr<SafeByteArray> data); - virtual void initializeProxy(); - virtual void handleProxyInitializeData(boost::shared_ptr<SafeByteArray> data); + void sendHTTPRequest(const std::string& statusLine, const std::vector<std::pair<std::string, std::string> >& headerFields); + void parseHTTPHeader(const std::string& data, std::string& statusLine, std::vector<std::pair<std::string, std::string> >& headerFields); - private: - SafeByteArray authID_; - SafeByteArray authPassword_; - }; + private: + SafeByteArray authID_; + SafeByteArray authPassword_; + std::shared_ptr<HTTPTrafficFilter> trafficFilter_; + std::string httpResponseBuffer_; + std::vector<std::pair<std::string, std::string> > nextHTTPRequestHeaders_; + }; } diff --git a/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp b/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp index cf4cef5..54b998a 100644 --- a/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp +++ b/Swiften/Network/HTTPConnectProxiedConnectionFactory.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2012 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ /* @@ -16,15 +16,17 @@ namespace Swift { -HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort) : resolver_(resolver), connectionFactory_(connectionFactory), timerFactory_(timerFactory), proxyHost_(proxyHost), proxyPort_(proxyPort), authID_(""), authPassword_("") { +HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort, std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter) : resolver_(resolver), connectionFactory_(connectionFactory), timerFactory_(timerFactory), proxyHost_(proxyHost), proxyPort_(proxyPort), authID_(""), authPassword_(""), httpTrafficFilter_(httpTrafficFilter) { } -HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort, const SafeString& authID, const SafeString& authPassword) : resolver_(resolver), connectionFactory_(connectionFactory), timerFactory_(timerFactory), proxyHost_(proxyHost), proxyPort_(proxyPort), authID_(authID), authPassword_(authPassword) { +HTTPConnectProxiedConnectionFactory::HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort, const SafeString& authID, const SafeString& authPassword, std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter) : resolver_(resolver), connectionFactory_(connectionFactory), timerFactory_(timerFactory), proxyHost_(proxyHost), proxyPort_(proxyPort), authID_(authID), authPassword_(authPassword), httpTrafficFilter_(httpTrafficFilter) { } -boost::shared_ptr<Connection> HTTPConnectProxiedConnectionFactory::createConnection() { - return HTTPConnectProxiedConnection::create(resolver_, connectionFactory_, timerFactory_, proxyHost_, proxyPort_, authID_, authPassword_); +std::shared_ptr<Connection> HTTPConnectProxiedConnectionFactory::createConnection() { + HTTPConnectProxiedConnection::ref proxyConnection = HTTPConnectProxiedConnection::create(resolver_, connectionFactory_, timerFactory_, proxyHost_, proxyPort_, authID_, authPassword_); + proxyConnection->setHTTPTrafficFilter(httpTrafficFilter_); + return proxyConnection; } } diff --git a/Swiften/Network/HTTPConnectProxiedConnectionFactory.h b/Swiften/Network/HTTPConnectProxiedConnectionFactory.h index 3efcecd..7a5f527 100644 --- a/Swiften/Network/HTTPConnectProxiedConnectionFactory.h +++ b/Swiften/Network/HTTPConnectProxiedConnectionFactory.h @@ -1,7 +1,7 @@ /* - * Copyright (c) 2012 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ /* @@ -12,28 +12,31 @@ #pragma once +#include <Swiften/Base/API.h> +#include <Swiften/Base/SafeString.h> #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Network/HostAddressPort.h> -#include <Swiften/Base/SafeString.h> namespace Swift { - class DomainNameResolver; - class TimerFactory; - class EventLoop; - class HTTPConnectProxiedConnectionFactory : public ConnectionFactory { - public: - HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort); - HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort, const SafeString& authID, const SafeString& authPassword); + class DomainNameResolver; + class HTTPTrafficFilter; + class TimerFactory; + + class SWIFTEN_API HTTPConnectProxiedConnectionFactory : public ConnectionFactory { + public: + HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort, std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter = std::shared_ptr<HTTPTrafficFilter>()); + HTTPConnectProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort, const SafeString& authID, const SafeString& authPassword, std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter = std::shared_ptr<HTTPTrafficFilter>()); - virtual boost::shared_ptr<Connection> createConnection(); + virtual std::shared_ptr<Connection> createConnection(); - private: - DomainNameResolver* resolver_; - ConnectionFactory* connectionFactory_; - TimerFactory* timerFactory_; - std::string proxyHost_; - int proxyPort_; - SafeString authID_; - SafeString authPassword_; - }; + private: + DomainNameResolver* resolver_; + ConnectionFactory* connectionFactory_; + TimerFactory* timerFactory_; + std::string proxyHost_; + unsigned short proxyPort_; + SafeString authID_; + SafeString authPassword_; + std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter_; + }; } diff --git a/Swiften/Network/HTTPTrafficFilter.cpp b/Swiften/Network/HTTPTrafficFilter.cpp new file mode 100644 index 0000000..d40fbdf --- /dev/null +++ b/Swiften/Network/HTTPTrafficFilter.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2015 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#include <Swiften/Network/HTTPTrafficFilter.h> + +namespace Swift { + +HTTPTrafficFilter::~HTTPTrafficFilter() { + +} + +} diff --git a/Swiften/Network/HTTPTrafficFilter.h b/Swiften/Network/HTTPTrafficFilter.h new file mode 100644 index 0000000..5c29bd6 --- /dev/null +++ b/Swiften/Network/HTTPTrafficFilter.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2015 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#pragma once + +#include <string> +#include <utility> +#include <vector> + +#include <boost/optional.hpp> + +#include <Swiften/Base/API.h> + +namespace Swift { + +class SWIFTEN_API HTTPTrafficFilter { + public: + virtual ~HTTPTrafficFilter(); + /** + * @brief This method is called by the HTTPConnectPRoxiedConnection on every incoming HTTP response. + * It can be used to insert additional HTTP requests into the HTTP CONNECT proxy initalization process. + * @return A vector of HTTP header fields to use in a new request. If an empty vector is returned, + * no new request will be send and the normal proxy logic continues. + */ + virtual std::vector<std::pair<std::string, std::string> > filterHTTPResponseHeader(const std::string& statusLine, const std::vector<std::pair<std::string, std::string> >& /* responseHeader */) = 0; +}; + +} diff --git a/Swiften/Network/HostAddress.cpp b/Swiften/Network/HostAddress.cpp index ff5c1c4..e82f433 100644 --- a/Swiften/Network/HostAddress.cpp +++ b/Swiften/Network/HostAddress.cpp @@ -1,20 +1,17 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/HostAddress.h> -#include <boost/numeric/conversion/cast.hpp> -#include <boost/lexical_cast.hpp> #include <cassert> -#include <stdexcept> -#include <boost/array.hpp> - -#include <Swiften/Base/foreach.h> +#include <cstring> #include <string> +#include <Swiften/Base/Log.h> + static boost::asio::ip::address localhost4 = boost::asio::ip::address(boost::asio::ip::address_v4::loopback()); static boost::asio::ip::address localhost6 = boost::asio::ip::address(boost::asio::ip::address_v6::loopback()); @@ -23,49 +20,55 @@ namespace Swift { HostAddress::HostAddress() { } -HostAddress::HostAddress(const std::string& address) { - try { - address_ = boost::asio::ip::address::from_string(address); - } - catch (const std::exception&) { - } -} - HostAddress::HostAddress(const unsigned char* address, size_t length) { - assert(length == 4 || length == 16); - if (length == 4) { - boost::asio::ip::address_v4::bytes_type data; - for (size_t i = 0; i < length; ++i) { - data[i] = address[i]; - } - address_ = boost::asio::ip::address(boost::asio::ip::address_v4(data)); - } - else { - boost::asio::ip::address_v6::bytes_type data; - for (size_t i = 0; i < length; ++i) { - data[i] = address[i]; - } - address_ = boost::asio::ip::address(boost::asio::ip::address_v6(data)); - } + assert(length == 4 || length == 16); + if (length == 4) { + boost::asio::ip::address_v4::bytes_type data; + std::memcpy(data.data(), address, length); + address_ = boost::asio::ip::address(boost::asio::ip::address_v4(data)); + } + else { + boost::asio::ip::address_v6::bytes_type data; + std::memcpy(data.data(), address, length); + address_ = boost::asio::ip::address(boost::asio::ip::address_v6(data)); + } } HostAddress::HostAddress(const boost::asio::ip::address& address) : address_(address) { } std::string HostAddress::toString() const { - return address_.to_string(); + std::string addressString; + boost::system::error_code errorCode; + + addressString = address_.to_string(errorCode); + if (errorCode) { + SWIFT_LOG(debug) << "error: " << errorCode.message(); + } + + return addressString; } bool HostAddress::isValid() const { - return !(address_.is_v4() && address_.to_v4().to_ulong() == 0); + return !(address_.is_v4() && address_.to_v4().to_ulong() == 0); } boost::asio::ip::address HostAddress::getRawAddress() const { - return address_; + return address_; } bool HostAddress::isLocalhost() const { - return address_ == localhost4 || address_ == localhost6; + return address_ == localhost4 || address_ == localhost6; +} + +boost::optional<HostAddress> HostAddress::fromString(const std::string& addressString) { + boost::optional<HostAddress> hostAddress; + boost::system::error_code errorCode; + boost::asio::ip::address address = boost::asio::ip::address::from_string(addressString, errorCode); + if (!errorCode) { + hostAddress = HostAddress(address); + } + return hostAddress; } } diff --git a/Swiften/Network/HostAddress.h b/Swiften/Network/HostAddress.h index c62239b..7a22cf4 100644 --- a/Swiften/Network/HostAddress.h +++ b/Swiften/Network/HostAddress.h @@ -1,34 +1,42 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ + #pragma once #include <string> + #include <boost/asio/ip/address.hpp> +#include <boost/optional.hpp> #include <Swiften/Base/API.h> namespace Swift { - class SWIFTEN_API HostAddress { - public: - HostAddress(); - HostAddress(const std::string&); - HostAddress(const unsigned char* address, size_t length); - HostAddress(const boost::asio::ip::address& address); - - std::string toString() const; - boost::asio::ip::address getRawAddress() const; - - bool operator==(const HostAddress& o) const { - return address_ == o.address_; - } - - bool isValid() const; - bool isLocalhost() const; - - private: - boost::asio::ip::address address_; - }; + class SWIFTEN_API HostAddress { + public: + HostAddress(); + HostAddress(const unsigned char* address, size_t length); + HostAddress(const boost::asio::ip::address& address); + + std::string toString() const; + boost::asio::ip::address getRawAddress() const; + + bool operator==(const HostAddress& o) const { + return address_ == o.address_; + } + + bool operator<(const HostAddress& o) const { + return address_ < o.address_; + } + + bool isValid() const; + bool isLocalhost() const; + + static boost::optional<HostAddress> fromString(const std::string& addressString); + + private: + boost::asio::ip::address address_; + }; } diff --git a/Swiften/Network/HostAddressPort.cpp b/Swiften/Network/HostAddressPort.cpp index e2e6012..248be2d 100644 --- a/Swiften/Network/HostAddressPort.cpp +++ b/Swiften/Network/HostAddressPort.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/HostAddressPort.h> @@ -10,14 +10,19 @@ using namespace Swift; -HostAddressPort::HostAddressPort(const HostAddress& address, int port) : address_(address), port_(port) { +HostAddressPort::HostAddressPort(const HostAddress& address, unsigned short port) : address_(address), port_(port) { } HostAddressPort::HostAddressPort(const boost::asio::ip::tcp::endpoint& endpoint) { - address_ = HostAddress(endpoint.address()); - port_ = endpoint.port(); + address_ = HostAddress(endpoint.address()); + port_ = endpoint.port(); } std::string HostAddressPort::toString() const { - return getAddress().toString() + ":" + boost::lexical_cast<std::string>(getPort()); + std::string portAsString; + try { + portAsString = std::to_string(getPort()); + } catch (boost::bad_lexical_cast&) { + } + return getAddress().toString() + ":" + portAsString; } diff --git a/Swiften/Network/HostAddressPort.h b/Swiften/Network/HostAddressPort.h index 68f3a1c..759af01 100644 --- a/Swiften/Network/HostAddressPort.h +++ b/Swiften/Network/HostAddressPort.h @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once @@ -12,31 +12,38 @@ #include <Swiften/Network/HostAddress.h> namespace Swift { - class SWIFTEN_API HostAddressPort { - public: - HostAddressPort(const HostAddress& address = HostAddress(), int port = -1); - HostAddressPort(const boost::asio::ip::tcp::endpoint& endpoint); - - const HostAddress& getAddress() const { - return address_; - } - - int getPort() const { - return port_; - } - - bool operator==(const HostAddressPort& o) const { - return address_ == o.address_ && port_ == o.port_; - } - - bool isValid() const { - return address_.isValid() && port_ > 0; - } - - std::string toString() const; - - private: - HostAddress address_; - int port_; - }; + class SWIFTEN_API HostAddressPort { + public: + HostAddressPort(const HostAddress& address = HostAddress(), unsigned short port = 0); + HostAddressPort(const boost::asio::ip::tcp::endpoint& endpoint); + + const HostAddress& getAddress() const { + return address_; + } + + unsigned short getPort() const { + return port_; + } + + bool operator==(const HostAddressPort& o) const { + return address_ == o.address_ && port_ == o.port_; + } + + bool operator<(const HostAddressPort& o) const { + if (address_ < o.address_) { + return true; + } + return address_ == o.address_ && port_ < o.port_; + } + + bool isValid() const { + return address_.isValid() && port_ > 0; + } + + std::string toString() const; + + private: + HostAddress address_; + unsigned short port_; + }; } diff --git a/Swiften/Network/HostNameOrAddress.cpp b/Swiften/Network/HostNameOrAddress.cpp index bc2737d..5c5e5e0 100644 --- a/Swiften/Network/HostNameOrAddress.cpp +++ b/Swiften/Network/HostNameOrAddress.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2012 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/HostNameOrAddress.h> @@ -9,21 +9,21 @@ using namespace Swift; namespace { - struct ToStringVisitor : public boost::static_visitor<std::string> { - std::string operator()(const HostAddress& address) const { - return address.toString(); - } - - std::string operator()(const std::string & str) const { - return str; - } - }; + struct ToStringVisitor : public boost::static_visitor<std::string> { + std::string operator()(const HostAddress& address) const { + return address.toString(); + } + + std::string operator()(const std::string & str) const { + return str; + } + }; } namespace Swift { std::string toString(const HostNameOrAddress& address) { - return boost::apply_visitor(ToStringVisitor(), address); + return boost::apply_visitor(ToStringVisitor(), address); } } diff --git a/Swiften/Network/HostNameOrAddress.h b/Swiften/Network/HostNameOrAddress.h index f804d15..81c0995 100644 --- a/Swiften/Network/HostNameOrAddress.h +++ b/Swiften/Network/HostNameOrAddress.h @@ -1,16 +1,18 @@ /* - * Copyright (c) 2012 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once #include <string> -#include <Swiften/Network/HostAddress.h> + #include <boost/variant.hpp> +#include <Swiften/Network/HostAddress.h> + namespace Swift { - typedef boost::variant<std::string, HostAddress> HostNameOrAddress; + typedef boost::variant<std::string, HostAddress> HostNameOrAddress; - std::string toString(const HostNameOrAddress& address); + std::string toString(const HostNameOrAddress& address); } diff --git a/Swiften/Network/MacOSXProxyProvider.cpp b/Swiften/Network/MacOSXProxyProvider.cpp index 3456c73..d3b10dd 100644 --- a/Swiften/Network/MacOSXProxyProvider.cpp +++ b/Swiften/Network/MacOSXProxyProvider.cpp @@ -5,9 +5,9 @@ */ /* - * Copyright (c) 2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2013-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Base/Platform.h> @@ -16,9 +16,10 @@ #include <stdio.h> #include <stdlib.h> #include <iostream> -#include <boost/numeric/conversion/cast.hpp> #include <utility> +#include <boost/numeric/conversion/cast.hpp> + #ifndef SWIFTEN_PLATFORM_IPHONE #include <SystemConfiguration/SystemConfiguration.h> #endif @@ -29,49 +30,51 @@ using namespace Swift; #ifndef SWIFTEN_PLATFORM_IPHONE static HostAddressPort getFromDictionary(CFDictionaryRef dict, CFStringRef enabledKey, CFStringRef hostKey, CFStringRef portKey) { - CFNumberRef numberValue = NULL; - HostAddressPort ret = HostAddressPort(HostAddress(), 0); - - if(CFDictionaryGetValueIfPresent(dict, reinterpret_cast<const void*> (enabledKey), reinterpret_cast<const void**> (&numberValue)) == true) { - const int i = 0; - CFNumberRef zero = CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &i); - CFComparisonResult result = CFNumberCompare(numberValue, zero, NULL); - CFRelease(zero); - - if(result != kCFCompareEqualTo) { - int port = 0; - std::string host = ""; - - try { - CFNumberRef numberValue = reinterpret_cast<CFNumberRef> (CFDictionaryGetValue(dict, portKey)); - if(numberValue != NULL) { - CFNumberGetValue(numberValue, kCFNumberIntType, &port); - } - - CFStringRef stringValue = reinterpret_cast<CFStringRef> (CFDictionaryGetValue(dict, hostKey)); - if(stringValue != NULL) { - std::vector<char> buffer; - // length must be +1 for the ending zero; and the Docu of CFStringGetCString tells it like - // if the string is toby the length must be at least 5. - CFIndex length = CFStringGetLength(stringValue) + 1; - buffer.resize(boost::numeric_cast<size_t>(length)); - if(CFStringGetCString(stringValue, &buffer[0], length, kCFStringEncodingMacRoman)) { - for(std::vector<char>::iterator iter = buffer.begin(); iter != buffer.end(); ++iter) { - host += *iter; - } - } - } - } - catch(...) { - std::cerr << "Exception caught ... " << std::endl; - } - - if(host != "" && port != 0) { - ret = HostAddressPort(HostAddress(host), port); - } - } - } - return ret; + CFNumberRef numberValue = nullptr; + HostAddressPort ret = HostAddressPort(HostAddress(), 0); + + if(CFDictionaryGetValueIfPresent(dict, reinterpret_cast<const void*> (enabledKey), reinterpret_cast<const void**> (&numberValue)) == true) { + const int i = 0; + CFNumberRef zero = CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &i); + CFComparisonResult result = CFNumberCompare(numberValue, zero, nullptr); + CFRelease(zero); + + if(result != kCFCompareEqualTo) { + unsigned short port = 0; + std::string host = ""; + + try { + CFNumberRef numberValue = reinterpret_cast<CFNumberRef> (CFDictionaryGetValue(dict, portKey)); + if(numberValue != nullptr) { + int intPort = 0; + CFNumberGetValue(numberValue, kCFNumberIntType, &intPort); + port = boost::numeric_cast<unsigned short>(intPort); + } + + CFStringRef stringValue = reinterpret_cast<CFStringRef> (CFDictionaryGetValue(dict, hostKey)); + if(stringValue != nullptr) { + std::vector<char> buffer; + // length must be +1 for the ending zero; and the Docu of CFStringGetCString tells it like + // if the string is toby the length must be at least 5. + CFIndex length = CFStringGetLength(stringValue) + 1; + buffer.resize(boost::numeric_cast<size_t>(length)); + if(CFStringGetCString(stringValue, &buffer[0], length, kCFStringEncodingMacRoman)) { + for(char& iter : buffer) { + host += iter; + } + } + } + } + catch(...) { + std::cerr << "Exception caught ... " << std::endl; + } + + if(host != "" && port != 0) { + ret = HostAddressPort(HostAddress::fromString(host).get(), port); + } + } + } + return ret; } #endif namespace Swift { @@ -80,27 +83,27 @@ MacOSXProxyProvider::MacOSXProxyProvider() { } HostAddressPort MacOSXProxyProvider::getHTTPConnectProxy() const { - HostAddressPort result; + HostAddressPort result; #ifndef SWIFTEN_PLATFORM_IPHONE - CFDictionaryRef proxies = SCDynamicStoreCopyProxies(NULL); - if(proxies != NULL) { - result = getFromDictionary(proxies, kSCPropNetProxiesHTTPEnable, kSCPropNetProxiesHTTPProxy, kSCPropNetProxiesHTTPPort); - CFRelease(proxies); - } + CFDictionaryRef proxies = SCDynamicStoreCopyProxies(nullptr); + if(proxies != nullptr) { + result = getFromDictionary(proxies, kSCPropNetProxiesHTTPEnable, kSCPropNetProxiesHTTPProxy, kSCPropNetProxiesHTTPPort); + CFRelease(proxies); + } #endif - return result; + return result; } HostAddressPort MacOSXProxyProvider::getSOCKS5Proxy() const { - HostAddressPort result; + HostAddressPort result; #ifndef SWIFTEN_PLATFORM_IPHONE - CFDictionaryRef proxies = SCDynamicStoreCopyProxies(NULL); - if(proxies != NULL) { - result = getFromDictionary(proxies, kSCPropNetProxiesSOCKSEnable, kSCPropNetProxiesSOCKSProxy, kSCPropNetProxiesSOCKSPort); - CFRelease(proxies); - } + CFDictionaryRef proxies = SCDynamicStoreCopyProxies(nullptr); + if(proxies != nullptr) { + result = getFromDictionary(proxies, kSCPropNetProxiesSOCKSEnable, kSCPropNetProxiesSOCKSProxy, kSCPropNetProxiesSOCKSPort); + CFRelease(proxies); + } #endif - return result; + return result; } } diff --git a/Swiften/Network/MacOSXProxyProvider.h b/Swiften/Network/MacOSXProxyProvider.h index 6666d30..56ffd9f 100644 --- a/Swiften/Network/MacOSXProxyProvider.h +++ b/Swiften/Network/MacOSXProxyProvider.h @@ -4,15 +4,23 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <Swiften/Network/ProxyProvider.h> + #include <CoreFoundation/CoreFoundation.h> +#include <Swiften/Network/ProxyProvider.h> + namespace Swift { - class MacOSXProxyProvider : public ProxyProvider { - public: - MacOSXProxyProvider(); - virtual HostAddressPort getHTTPConnectProxy() const; - virtual HostAddressPort getSOCKS5Proxy() const; - }; + class MacOSXProxyProvider : public ProxyProvider { + public: + MacOSXProxyProvider(); + virtual HostAddressPort getHTTPConnectProxy() const; + virtual HostAddressPort getSOCKS5Proxy() const; + }; } diff --git a/Swiften/Network/MiniUPnPInterface.cpp b/Swiften/Network/MiniUPnPInterface.cpp index bfa989f..8425c77 100644 --- a/Swiften/Network/MiniUPnPInterface.cpp +++ b/Swiften/Network/MiniUPnPInterface.cpp @@ -4,107 +4,117 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include <Swiften/Network/MiniUPnPInterface.h> +#include <memory> + #include <miniupnpc.h> #include <upnpcommands.h> #include <upnperrors.h> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/lexical_cast.hpp> #include <Swiften/Base/Log.h> namespace Swift { struct MiniUPnPInterface::Private { - bool isValid; - std::string localAddress; - UPNPDev* deviceList; - UPNPUrls urls; - IGDdatas data; + bool isValid; + std::string localAddress; + UPNPDev* deviceList; + UPNPUrls urls; + IGDdatas data; }; -MiniUPnPInterface::MiniUPnPInterface() : p(boost::make_shared<Private>()) { - p->isValid = false; - int error = 0; - p->deviceList = upnpDiscover(1500 /* timeout in ms */, 0, 0, 0, 0 /* do IPv6? */, &error); - if (!p->deviceList) { - return; - } - - char lanAddress[64]; - if (!UPNP_GetValidIGD(p->deviceList, &p->urls, &p->data, lanAddress, sizeof(lanAddress))) { - return; - } - p->localAddress = std::string(lanAddress); - p->isValid = true; +MiniUPnPInterface::MiniUPnPInterface() : p(new Private()) { + p->isValid = false; + int error = 0; +#if MINIUPNPC_API_VERSION > 13 + p->deviceList = upnpDiscover(1500 /* timeout in ms */, nullptr, nullptr, 0, 0 /* do IPv6? */, 2 /* default TTL */, &error); +#else + p->deviceList = upnpDiscover(1500 /* timeout in ms */, nullptr, nullptr, 0, 0 /* do IPv6? */, &error); +#endif + if (!p->deviceList) { + return; + } + + char lanAddress[64]; + if (!UPNP_GetValidIGD(p->deviceList, &p->urls, &p->data, lanAddress, sizeof(lanAddress))) { + return; + } + p->localAddress = std::string(lanAddress); + p->isValid = true; } MiniUPnPInterface::~MiniUPnPInterface() { - if (p->isValid) { - FreeUPNPUrls(&p->urls); - } - freeUPNPDevlist(p->deviceList); + if (p->isValid) { + FreeUPNPUrls(&p->urls); + } + freeUPNPDevlist(p->deviceList); } boost::optional<HostAddress> MiniUPnPInterface::getPublicIP() { - if (!p->isValid) { - return boost::optional<HostAddress>(); - } - char externalIPAddress[40]; - int ret = UPNP_GetExternalIPAddress(p->urls.controlURL, p->data.first.servicetype, externalIPAddress); - if (ret != UPNPCOMMAND_SUCCESS) { - return boost::optional<HostAddress>(); - } - else { - return HostAddress(std::string(externalIPAddress)); - } + if (!p->isValid) { + return boost::optional<HostAddress>(); + } + char externalIPAddress[40]; + int ret = UPNP_GetExternalIPAddress(p->urls.controlURL, p->data.first.servicetype, externalIPAddress); + if (ret != UPNPCOMMAND_SUCCESS) { + return boost::optional<HostAddress>(); + } + else { + return HostAddress::fromString(std::string(externalIPAddress)); + } } -boost::optional<NATPortMapping> MiniUPnPInterface::addPortForward(int actualLocalPort, int actualPublicPort) { - if (!p->isValid) { - return boost::optional<NATPortMapping>(); - } - - NATPortMapping mapping(actualLocalPort, actualPublicPort, NATPortMapping::TCP); - - std::string publicPort = boost::lexical_cast<std::string>(mapping.getPublicPort()); - std::string localPort = boost::lexical_cast<std::string>(mapping.getLocalPort()); - std::string leaseSeconds = boost::lexical_cast<std::string>(mapping.getLeaseInSeconds()); - - int ret = UPNP_AddPortMapping( - p->urls.controlURL, - p->data.first.servicetype, - publicPort.c_str(), - localPort.c_str(), - p->localAddress.c_str(), - 0, - mapping.getProtocol() == NATPortMapping::TCP ? "TCP" : "UDP", - 0, - leaseSeconds.c_str()); - if (ret == UPNPCOMMAND_SUCCESS) { - return mapping; - } - else { - return boost::optional<NATPortMapping>(); - } +boost::optional<NATPortMapping> MiniUPnPInterface::addPortForward(unsigned short actualLocalPort, unsigned short actualPublicPort) { + if (!p->isValid) { + return boost::optional<NATPortMapping>(); + } + + NATPortMapping mapping(actualLocalPort, actualPublicPort, NATPortMapping::TCP); + + std::string publicPort = std::to_string(mapping.getPublicPort()); + std::string localPort = std::to_string(mapping.getLocalPort()); + std::string leaseSeconds = std::to_string(mapping.getLeaseInSeconds()); + + int ret = UPNP_AddPortMapping( + p->urls.controlURL, + p->data.first.servicetype, + publicPort.c_str(), + localPort.c_str(), + p->localAddress.c_str(), + "Swift", + mapping.getProtocol() == NATPortMapping::TCP ? "TCP" : "UDP", + nullptr, + leaseSeconds.c_str()); + if (ret == UPNPCOMMAND_SUCCESS) { + return mapping; + } + else { + return boost::optional<NATPortMapping>(); + } } bool MiniUPnPInterface::removePortForward(const NATPortMapping& mapping) { - if (!p->isValid) { - return false; - } + if (!p->isValid) { + return false; + } - std::string publicPort = boost::lexical_cast<std::string>(mapping.getPublicPort()); - std::string localPort = boost::lexical_cast<std::string>(mapping.getLocalPort()); - std::string leaseSeconds = boost::lexical_cast<std::string>(mapping.getLeaseInSeconds()); + std::string publicPort = std::to_string(mapping.getPublicPort()); + std::string localPort = std::to_string(mapping.getLocalPort()); + std::string leaseSeconds = std::to_string(mapping.getLeaseInSeconds()); - int ret = UPNP_DeletePortMapping(p->urls.controlURL, p->data.first.servicetype, publicPort.c_str(), mapping.getProtocol() == NATPortMapping::TCP ? "TCP" : "UDP", 0); - return ret == UPNPCOMMAND_SUCCESS; + int ret = UPNP_DeletePortMapping(p->urls.controlURL, p->data.first.servicetype, publicPort.c_str(), mapping.getProtocol() == NATPortMapping::TCP ? "TCP" : "UDP", nullptr); + return ret == UPNPCOMMAND_SUCCESS; } bool MiniUPnPInterface::isAvailable() { - return p->isValid; + return p->isValid; } } diff --git a/Swiften/Network/MiniUPnPInterface.h b/Swiften/Network/MiniUPnPInterface.h index 61d12ca..8c68268 100644 --- a/Swiften/Network/MiniUPnPInterface.h +++ b/Swiften/Network/MiniUPnPInterface.h @@ -1,32 +1,33 @@ /* - * Copyright (c) 2011 Remko Tronçon + * Copyright (c) 2011-2018 Isode Limited. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ #pragma once -#include <boost/optional.hpp> +#include <memory> + #include <boost/noncopyable.hpp> -#include <boost/shared_ptr.hpp> +#include <boost/optional.hpp> #include <Swiften/Network/NATPortMapping.h> #include <Swiften/Network/NATTraversalInterface.h> namespace Swift { - class MiniUPnPInterface : public NATTraversalInterface, boost::noncopyable { - public: - MiniUPnPInterface(); - ~MiniUPnPInterface(); + class MiniUPnPInterface : public NATTraversalInterface, boost::noncopyable { + public: + MiniUPnPInterface(); + virtual ~MiniUPnPInterface(); - virtual bool isAvailable(); + virtual bool isAvailable(); - boost::optional<HostAddress> getPublicIP(); - boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort); - bool removePortForward(const NATPortMapping&); + boost::optional<HostAddress> getPublicIP(); + boost::optional<NATPortMapping> addPortForward(unsigned short localPort, unsigned short publicPort); + bool removePortForward(const NATPortMapping&); - private: - struct Private; - boost::shared_ptr<Private> p; - }; + private: + struct Private; + const std::unique_ptr<Private> p; + }; } diff --git a/Swiften/Network/NATPMPInterface.cpp b/Swiften/Network/NATPMPInterface.cpp index c7a41ff..e20fecd 100644 --- a/Swiften/Network/NATPMPInterface.cpp +++ b/Swiften/Network/NATPMPInterface.cpp @@ -4,10 +4,15 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2014-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include <Swiften/Network/NATPMPInterface.h> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/numeric/conversion/cast.hpp> +#include <memory> #include <Swiften/Base/Log.h> @@ -20,107 +25,117 @@ namespace Swift { struct NATPMPInterface::Private { - natpmp_t natpmp; + natpmp_t natpmp; }; -NATPMPInterface::NATPMPInterface() : p(boost::make_shared<Private>()) { - initnatpmp(&p->natpmp, 0, 0); +NATPMPInterface::NATPMPInterface() : p(new Private()) { + initnatpmp(&p->natpmp, 0, 0); } NATPMPInterface::~NATPMPInterface() { - closenatpmp(&p->natpmp); + closenatpmp(&p->natpmp); } bool NATPMPInterface::isAvailable() { - return getPublicIP(); + return getPublicIP() ? true : false; } boost::optional<HostAddress> NATPMPInterface::getPublicIP() { - if (sendpublicaddressrequest(&p->natpmp) < 0) { - SWIFT_LOG(debug) << "Failed to send NAT-PMP public address request!" << std::endl; - return boost::optional<HostAddress>(); - } - - int r = 0; - natpmpresp_t response; - do { - fd_set fds; - struct timeval timeout; - FD_ZERO(&fds); - FD_SET(p->natpmp.s, &fds); - getnatpmprequesttimeout(&p->natpmp, &timeout); - select(FD_SETSIZE, &fds, NULL, NULL, &timeout); - r = readnatpmpresponseorretry(&p->natpmp, &response); - } while (r == NATPMP_TRYAGAIN); - - if (r == 0) { - return boost::optional<HostAddress>(HostAddress(reinterpret_cast<const unsigned char*>(&(response.pnu.publicaddress.addr)), 4)); - } - else { - SWIFT_LOG(debug) << "Inavlid NAT-PMP response." << std::endl; - return boost::optional<HostAddress>(); - } + if (sendpublicaddressrequest(&p->natpmp) < 0) { + SWIFT_LOG(debug) << "Failed to send NAT-PMP public address request!"; + return boost::optional<HostAddress>(); + } + + int r = 0; + natpmpresp_t response; + do { + fd_set fds; + struct timeval timeout; + FD_ZERO(&fds); + FD_SET(p->natpmp.s, &fds); + getnatpmprequesttimeout(&p->natpmp, &timeout); + + // Limit NAT-PMP timeout to ten seconds. + timeout.tv_sec = 10; + timeout.tv_usec = 0; + + select(FD_SETSIZE, &fds, nullptr, nullptr, &timeout); + r = readnatpmpresponseorretry(&p->natpmp, &response); + } while (false /*r == NATPMP_TRYAGAIN*/); + + if (r == 0) { + return boost::optional<HostAddress>(HostAddress(reinterpret_cast<const unsigned char*>(&(response.pnu.publicaddress.addr)), 4)); + } + else { + SWIFT_LOG(debug) << "Inavlid NAT-PMP response."; + return boost::optional<HostAddress>(); + } } -boost::optional<NATPortMapping> NATPMPInterface::addPortForward(int localPort, int publicPort) { - NATPortMapping mapping(localPort, publicPort, NATPortMapping::TCP); - if (sendnewportmappingrequest( - &p->natpmp, - mapping.getProtocol() == NATPortMapping::TCP ? NATPMP_PROTOCOL_TCP : NATPMP_PROTOCOL_UDP, - boost::numeric_cast<uint16_t>(mapping.getLocalPort()), - boost::numeric_cast<uint16_t>(mapping.getPublicPort()), - boost::numeric_cast<uint32_t>(mapping.getLeaseInSeconds())) < 0) { - SWIFT_LOG(debug) << "Failed to send NAT-PMP port forwarding request!" << std::endl; - return boost::optional<NATPortMapping>(); - } - - int r = 0; - natpmpresp_t response; - do { - fd_set fds; - struct timeval timeout; - FD_ZERO(&fds); - FD_SET(p->natpmp.s, &fds); - getnatpmprequesttimeout(&p->natpmp, &timeout); - select(FD_SETSIZE, &fds, NULL, NULL, &timeout); - r = readnatpmpresponseorretry(&p->natpmp, &response); - } while(r == NATPMP_TRYAGAIN); - - if (r == 0) { - NATPortMapping result(response.pnu.newportmapping.privateport, response.pnu.newportmapping.mappedpublicport, NATPortMapping::TCP, boost::numeric_cast<int>(response.pnu.newportmapping.lifetime)); - return result; - } - else { - SWIFT_LOG(debug) << "Invalid NAT-PMP response." << std::endl; - return boost::optional<NATPortMapping>(); - } +boost::optional<NATPortMapping> NATPMPInterface::addPortForward(unsigned short localPort, unsigned short publicPort) { + NATPortMapping mapping(localPort, publicPort, NATPortMapping::TCP); + if (sendnewportmappingrequest( + &p->natpmp, + mapping.getProtocol() == NATPortMapping::TCP ? NATPMP_PROTOCOL_TCP : NATPMP_PROTOCOL_UDP, + mapping.getLocalPort(), + mapping.getPublicPort(), + mapping.getLeaseInSeconds()) < 0) { + SWIFT_LOG(debug) << "Failed to send NAT-PMP port forwarding request!"; + return boost::optional<NATPortMapping>(); + } + + int r = 0; + natpmpresp_t response; + do { + fd_set fds; + struct timeval timeout; + FD_ZERO(&fds); + FD_SET(p->natpmp.s, &fds); + getnatpmprequesttimeout(&p->natpmp, &timeout); + + // Limit NAT-PMP timeout to ten seconds. + timeout.tv_sec = 10; + timeout.tv_usec = 0; + + select(FD_SETSIZE, &fds, nullptr, nullptr, &timeout); + r = readnatpmpresponseorretry(&p->natpmp, &response); + } while(false /*r == NATPMP_TRYAGAIN*/); + + if (r == 0) { + NATPortMapping result(response.pnu.newportmapping.privateport, response.pnu.newportmapping.mappedpublicport, NATPortMapping::TCP, response.pnu.newportmapping.lifetime); + return result; + } + else { + SWIFT_LOG(debug) << "Invalid NAT-PMP response."; + return boost::optional<NATPortMapping>(); + } } bool NATPMPInterface::removePortForward(const NATPortMapping& mapping) { - if (sendnewportmappingrequest(&p->natpmp, mapping.getProtocol() == NATPortMapping::TCP ? NATPMP_PROTOCOL_TCP : NATPMP_PROTOCOL_UDP, 0, 0, boost::numeric_cast<uint32_t>(mapping.getLocalPort())) < 0) { - SWIFT_LOG(debug) << "Failed to send NAT-PMP remove forwarding request!" << std::endl; - return false; - } - - int r = 0; - natpmpresp_t response; - do { - fd_set fds; - struct timeval timeout; - FD_ZERO(&fds); - FD_SET(p->natpmp.s, &fds); - getnatpmprequesttimeout(&p->natpmp, &timeout); - select(FD_SETSIZE, &fds, NULL, NULL, &timeout); - r = readnatpmpresponseorretry(&p->natpmp, &response); - } while(r == NATPMP_TRYAGAIN); - - if (r == 0) { - return true; - } - else { - SWIFT_LOG(debug) << "Invalid NAT-PMP response." << std::endl; - return false; - } + if (sendnewportmappingrequest(&p->natpmp, mapping.getProtocol() == NATPortMapping::TCP ? NATPMP_PROTOCOL_TCP : NATPMP_PROTOCOL_UDP, mapping.getLocalPort(), 0, 0) < 0) { + SWIFT_LOG(debug) << "Failed to send NAT-PMP remove forwarding request!"; + return false; + } + + int r = 0; + natpmpresp_t response; + do { + fd_set fds; + struct timeval timeout; + FD_ZERO(&fds); + FD_SET(p->natpmp.s, &fds); + getnatpmprequesttimeout(&p->natpmp, &timeout); + select(FD_SETSIZE, &fds, nullptr, nullptr, &timeout); + r = readnatpmpresponseorretry(&p->natpmp, &response); + } while(r == NATPMP_TRYAGAIN); + + if (r == 0) { + return true; + } + else { + SWIFT_LOG(debug) << "Invalid NAT-PMP response."; + return false; + } } diff --git a/Swiften/Network/NATPMPInterface.h b/Swiften/Network/NATPMPInterface.h index e079a59..58d62b6 100644 --- a/Swiften/Network/NATPMPInterface.h +++ b/Swiften/Network/NATPMPInterface.h @@ -1,31 +1,33 @@ /* - * Copyright (c) 2011 Remko Tronçon + * Copyright (c) 2011-2018 Isode Limited. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ #pragma once -#include <boost/optional.hpp> -#include <boost/shared_ptr.hpp> +#include <memory> + #include <boost/noncopyable.hpp> +#include <boost/optional.hpp> + #include <Swiften/Network/NATPortMapping.h> #include <Swiften/Network/NATTraversalInterface.h> namespace Swift { - class NATPMPInterface : public NATTraversalInterface, boost::noncopyable { - public: - NATPMPInterface(); - ~NATPMPInterface(); + class NATPMPInterface : public NATTraversalInterface, boost::noncopyable { + public: + NATPMPInterface(); + virtual ~NATPMPInterface(); - virtual bool isAvailable(); + virtual bool isAvailable(); - virtual boost::optional<HostAddress> getPublicIP(); - virtual boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort); - virtual bool removePortForward(const NATPortMapping&); + virtual boost::optional<HostAddress> getPublicIP(); + virtual boost::optional<NATPortMapping> addPortForward(unsigned short localPort, unsigned short publicPort); + virtual bool removePortForward(const NATPortMapping&); - private: - struct Private; - boost::shared_ptr<Private> p; - }; + private: + struct Private; + const std::unique_ptr<Private> p; + }; } diff --git a/Swiften/Network/NATPortMapping.h b/Swiften/Network/NATPortMapping.h index 0f6bd95..bf0fb1c 100644 --- a/Swiften/Network/NATPortMapping.h +++ b/Swiften/Network/NATPortMapping.h @@ -4,43 +4,50 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once +#include <Swiften/Base/API.h> #include <Swiften/Network/HostAddress.h> namespace Swift { - class NATPortMapping { - public: - enum Protocol { - TCP, - UDP - }; - - NATPortMapping(int localPort, int publicPort, Protocol protocol = TCP, int leaseInSeconds = 60 * 60 * 24) : - publicPort(publicPort), localPort(localPort), protocol(protocol), leaseInSeconds(leaseInSeconds) { - - } - - int getPublicPort() const { - return publicPort; - } - - int getLocalPort() const { - return localPort; - } - - Protocol getProtocol() const { - return protocol; - } - - int getLeaseInSeconds() const { - return leaseInSeconds; - } - - private: - int publicPort; - int localPort; - Protocol protocol; - int leaseInSeconds; - }; + class SWIFTEN_API NATPortMapping { + public: + enum Protocol { + TCP, + UDP + }; + + NATPortMapping(unsigned short localPort, unsigned short publicPort, Protocol protocol = TCP, uint32_t leaseInSeconds = 60 * 60 * 24) : + publicPort(publicPort), localPort(localPort), protocol(protocol), leaseInSeconds(leaseInSeconds) { + + } + + unsigned short getPublicPort() const { + return publicPort; + } + + unsigned short getLocalPort() const { + return localPort; + } + + Protocol getProtocol() const { + return protocol; + } + + uint32_t getLeaseInSeconds() const { + return leaseInSeconds; + } + + private: + unsigned short publicPort; + unsigned short localPort; + Protocol protocol; + uint32_t leaseInSeconds; + }; } diff --git a/Swiften/Network/NATTraversalForwardPortRequest.h b/Swiften/Network/NATTraversalForwardPortRequest.h index 48f85ea..0f9c62c 100644 --- a/Swiften/Network/NATTraversalForwardPortRequest.h +++ b/Swiften/Network/NATTraversalForwardPortRequest.h @@ -4,21 +4,27 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <Swiften/Base/boost_bsignals.h> +#include <boost/signals2.hpp> #include <Swiften/Base/API.h> #include <Swiften/Network/NATPortMapping.h> namespace Swift { - class SWIFTEN_API NATTraversalForwardPortRequest { - public: - virtual ~NATTraversalForwardPortRequest(); + class SWIFTEN_API NATTraversalForwardPortRequest { + public: + virtual ~NATTraversalForwardPortRequest(); - virtual void start() = 0; - virtual void stop() = 0; + virtual void start() = 0; + virtual void stop() = 0; - boost::signal<void (boost::optional<NATPortMapping>)> onResult; - }; + boost::signals2::signal<void (boost::optional<NATPortMapping>)> onResult; + }; } diff --git a/Swiften/Network/NATTraversalGetPublicIPRequest.h b/Swiften/Network/NATTraversalGetPublicIPRequest.h index 1270db3..8b34e0f 100644 --- a/Swiften/Network/NATTraversalGetPublicIPRequest.h +++ b/Swiften/Network/NATTraversalGetPublicIPRequest.h @@ -4,19 +4,27 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <Swiften/Base/boost_bsignals.h> +#include <boost/signals2.hpp> + +#include <Swiften/Base/API.h> #include <Swiften/Network/HostAddress.h> namespace Swift { - class NATTraversalGetPublicIPRequest { - public: - virtual ~NATTraversalGetPublicIPRequest(); + class SWIFTEN_API NATTraversalGetPublicIPRequest { + public: + virtual ~NATTraversalGetPublicIPRequest(); - virtual void start() = 0; - virtual void stop() = 0; + virtual void start() = 0; + virtual void stop() = 0; - boost::signal<void (boost::optional<HostAddress>)> onResult; - }; + boost::signals2::signal<void (boost::optional<HostAddress>)> onResult; + }; } diff --git a/Swiften/Network/NATTraversalInterface.cpp b/Swiften/Network/NATTraversalInterface.cpp index f8a0cc2..18ee843 100644 --- a/Swiften/Network/NATTraversalInterface.cpp +++ b/Swiften/Network/NATTraversalInterface.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011 Remko Tronçon + * Copyright (c) 2011-2016 Isode Limited. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ @@ -8,7 +8,6 @@ #include <Swiften/Base/Log.h> - namespace Swift { NATTraversalInterface::~NATTraversalInterface() { diff --git a/Swiften/Network/NATTraversalInterface.h b/Swiften/Network/NATTraversalInterface.h index c84deba..1655eb6 100644 --- a/Swiften/Network/NATTraversalInterface.h +++ b/Swiften/Network/NATTraversalInterface.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011 Remko Tronçon + * Copyright (c) 2011-2018 Isode Limited. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ @@ -8,17 +8,18 @@ #include <boost/optional.hpp> +#include <Swiften/Base/API.h> #include <Swiften/Network/NATPortMapping.h> namespace Swift { - class NATTraversalInterface { - public: - virtual ~NATTraversalInterface(); + class SWIFTEN_API NATTraversalInterface { + public: + virtual ~NATTraversalInterface(); - virtual bool isAvailable() = 0; + virtual bool isAvailable() = 0; - virtual boost::optional<HostAddress> getPublicIP() = 0; - virtual boost::optional<NATPortMapping> addPortForward(int localPort, int publicPort) = 0; - virtual bool removePortForward(const NATPortMapping&) = 0; - }; + virtual boost::optional<HostAddress> getPublicIP() = 0; + virtual boost::optional<NATPortMapping> addPortForward(unsigned short localPort, unsigned short publicPort) = 0; + virtual bool removePortForward(const NATPortMapping&) = 0; + }; } diff --git a/Swiften/Network/NATTraversalRemovePortForwardingRequest.h b/Swiften/Network/NATTraversalRemovePortForwardingRequest.h index 210cbcb..83235f9 100644 --- a/Swiften/Network/NATTraversalRemovePortForwardingRequest.h +++ b/Swiften/Network/NATTraversalRemovePortForwardingRequest.h @@ -4,32 +4,40 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <Swiften/Base/boost_bsignals.h> +#include <boost/signals2.hpp> + +#include <Swiften/Base/API.h> #include <Swiften/Network/HostAddress.h> namespace Swift { - class NATTraversalRemovePortForwardingRequest { - public: - struct PortMapping { - enum Protocol { - TCP, - UDP - }; - - unsigned int publicPort; - unsigned int localPort; - Protocol protocol; - unsigned long leaseInSeconds; - }; - - public: - virtual ~NATTraversalRemovePortForwardingRequest(); - - virtual void start() = 0; - virtual void stop() = 0; - - boost::signal<void (boost::optional<bool> /* failure */)> onResult; - }; + class SWIFTEN_API NATTraversalRemovePortForwardingRequest { + public: + struct PortMapping { + enum Protocol { + TCP, + UDP + }; + + unsigned short publicPort; + unsigned short localPort; + Protocol protocol; + unsigned long leaseInSeconds; + }; + + public: + virtual ~NATTraversalRemovePortForwardingRequest(); + + virtual void start() = 0; + virtual void stop() = 0; + + boost::signals2::signal<void (boost::optional<bool> /* failure */)> onResult; + }; } diff --git a/Swiften/Network/NATTraverser.cpp b/Swiften/Network/NATTraverser.cpp index 8c628ee..824ae73 100644 --- a/Swiften/Network/NATTraverser.cpp +++ b/Swiften/Network/NATTraverser.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2011 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/NATTraverser.h> diff --git a/Swiften/Network/NATTraverser.h b/Swiften/Network/NATTraverser.h index e48ce26..7f03c03 100644 --- a/Swiften/Network/NATTraverser.h +++ b/Swiften/Network/NATTraverser.h @@ -1,24 +1,26 @@ /* - * Copyright (c) 2011 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> + +#include <Swiften/Base/API.h> namespace Swift { - class NATTraversalGetPublicIPRequest; - class NATTraversalForwardPortRequest; - class NATTraversalRemovePortForwardingRequest; + class NATTraversalGetPublicIPRequest; + class NATTraversalForwardPortRequest; + class NATTraversalRemovePortForwardingRequest; - class NATTraverser { - public: - virtual ~NATTraverser(); + class SWIFTEN_API NATTraverser { + public: + virtual ~NATTraverser(); - virtual boost::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest() = 0; - virtual boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort) = 0; - virtual boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort) = 0; - }; + virtual std::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest() = 0; + virtual std::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(unsigned short localPort, unsigned short publicPort) = 0; + virtual std::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(unsigned short localPort, unsigned short publicPort) = 0; + }; } diff --git a/Swiften/Network/NetworkEnvironment.cpp b/Swiften/Network/NetworkEnvironment.cpp index 52ceb01..87883c1 100644 --- a/Swiften/Network/NetworkEnvironment.cpp +++ b/Swiften/Network/NetworkEnvironment.cpp @@ -1,14 +1,13 @@ /* - * Copyright (c) 2011 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/NetworkEnvironment.h> -#include <Swiften/Network/NetworkInterface.h> #include <Swiften/Network/HostAddress.h> -#include <Swiften/Base/foreach.h> +#include <Swiften/Network/NetworkInterface.h> namespace Swift { @@ -16,17 +15,17 @@ NetworkEnvironment::~NetworkEnvironment() { } HostAddress NetworkEnvironment::getLocalAddress() const { - std::vector<NetworkInterface> networkInterfaces = getNetworkInterfaces(); - foreach (const NetworkInterface& iface, networkInterfaces) { - if (!iface.isLoopback()) { - foreach (const HostAddress& address, iface.getAddresses()) { - if (address.getRawAddress().is_v4()) { - return address; - } - } - } - } - return HostAddress(); + std::vector<NetworkInterface> networkInterfaces = getNetworkInterfaces(); + for (const auto& iface : networkInterfaces) { + if (!iface.isLoopback()) { + for (const auto& address : iface.getAddresses()) { + if (address.getRawAddress().is_v4()) { + return address; + } + } + } + } + return HostAddress(); } } diff --git a/Swiften/Network/NetworkEnvironment.h b/Swiften/Network/NetworkEnvironment.h index 36a2bde..0f68c29 100644 --- a/Swiften/Network/NetworkEnvironment.h +++ b/Swiften/Network/NetworkEnvironment.h @@ -4,21 +4,28 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <vector> +#include <boost/signals2.hpp> + #include <Swiften/Base/API.h> -#include <Swiften/Base/boost_bsignals.h> #include <Swiften/Network/NetworkInterface.h> namespace Swift { - class SWIFTEN_API NetworkEnvironment { - public: - virtual ~NetworkEnvironment(); + class SWIFTEN_API NetworkEnvironment { + public: + virtual ~NetworkEnvironment(); - virtual std::vector<NetworkInterface> getNetworkInterfaces() const = 0; + virtual std::vector<NetworkInterface> getNetworkInterfaces() const = 0; - HostAddress getLocalAddress() const; - }; + HostAddress getLocalAddress() const; + }; } diff --git a/Swiften/Network/NetworkFactories.cpp b/Swiften/Network/NetworkFactories.cpp index 7046fd3..0380c90 100644 --- a/Swiften/Network/NetworkFactories.cpp +++ b/Swiften/Network/NetworkFactories.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/NetworkFactories.h> diff --git a/Swiften/Network/NetworkFactories.h b/Swiften/Network/NetworkFactories.h index dd8e216..f31c448 100644 --- a/Swiften/Network/NetworkFactories.h +++ b/Swiften/Network/NetworkFactories.h @@ -1,44 +1,45 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <Swiften/Base/API.h> + namespace Swift { - class TimerFactory; - class ConnectionFactory; - class DomainNameResolver; - class ConnectionServerFactory; - class NATTraverser; - class XMLParserFactory; - class TLSContextFactory; - class CertificateFactory; - class ProxyProvider; - class EventLoop; - class IDNConverter; - class NetworkEnvironment; - class CryptoProvider; + class ConnectionFactory; + class ConnectionServerFactory; + class CryptoProvider; + class DomainNameResolver; + class EventLoop; + class IDNConverter; + class NATTraverser; + class NetworkEnvironment; + class ProxyProvider; + class TLSContextFactory; + class TimerFactory; + class XMLParserFactory; - /** - * An interface collecting network factories. - */ - class NetworkFactories { - public: - virtual ~NetworkFactories(); + /** + * An interface collecting network factories. + */ + class SWIFTEN_API NetworkFactories { + public: + virtual ~NetworkFactories(); - virtual TimerFactory* getTimerFactory() const = 0; - virtual ConnectionFactory* getConnectionFactory() const = 0; - virtual DomainNameResolver* getDomainNameResolver() const = 0; - virtual ConnectionServerFactory* getConnectionServerFactory() const = 0; - virtual NATTraverser* getNATTraverser() const = 0; - virtual NetworkEnvironment* getNetworkEnvironment() const = 0; - virtual XMLParserFactory* getXMLParserFactory() const = 0; - virtual TLSContextFactory* getTLSContextFactory() const = 0; - virtual ProxyProvider* getProxyProvider() const = 0; - virtual EventLoop* getEventLoop() const = 0; - virtual IDNConverter* getIDNConverter() const = 0; - virtual CryptoProvider* getCryptoProvider() const = 0; - }; + virtual TimerFactory* getTimerFactory() const = 0; + virtual ConnectionFactory* getConnectionFactory() const = 0; + virtual DomainNameResolver* getDomainNameResolver() const = 0; + virtual ConnectionServerFactory* getConnectionServerFactory() const = 0; + virtual NATTraverser* getNATTraverser() const = 0; + virtual NetworkEnvironment* getNetworkEnvironment() const = 0; + virtual XMLParserFactory* getXMLParserFactory() const = 0; + virtual TLSContextFactory* getTLSContextFactory() const = 0; + virtual ProxyProvider* getProxyProvider() const = 0; + virtual EventLoop* getEventLoop() const = 0; + virtual IDNConverter* getIDNConverter() const = 0; + virtual CryptoProvider* getCryptoProvider() const = 0; + }; } diff --git a/Swiften/Network/NetworkInterface.h b/Swiften/Network/NetworkInterface.h index 1d302cb..91aefc4 100644 --- a/Swiften/Network/NetworkInterface.h +++ b/Swiften/Network/NetworkInterface.h @@ -4,37 +4,44 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <vector> +#include <Swiften/Base/API.h> #include <Swiften/Network/HostAddress.h> namespace Swift { - class NetworkInterface { - public: - NetworkInterface(const std::string& name, bool loopback) : name(name), loopback(loopback) { - } - - void addAddress(const HostAddress& address) { - addresses.push_back(address); - } - - const std::vector<HostAddress>& getAddresses() const { - return addresses; - } - - const std::string& getName() const { - return name; - } - - bool isLoopback() const { - return loopback; - } - - private: - std::string name; - bool loopback; - std::vector<HostAddress> addresses; - }; + class SWIFTEN_API NetworkInterface { + public: + NetworkInterface(const std::string& name, bool loopback) : name(name), loopback(loopback) { + } + + void addAddress(const HostAddress& address) { + addresses.push_back(address); + } + + const std::vector<HostAddress>& getAddresses() const { + return addresses; + } + + const std::string& getName() const { + return name; + } + + bool isLoopback() const { + return loopback; + } + + private: + std::string name; + bool loopback; + std::vector<HostAddress> addresses; + }; } diff --git a/Swiften/Network/NullNATTraversalInterface.h b/Swiften/Network/NullNATTraversalInterface.h index 72a4a08..eabc197 100644 --- a/Swiften/Network/NullNATTraversalInterface.h +++ b/Swiften/Network/NullNATTraversalInterface.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011 Remko Tronçon + * Copyright (c) 2011 Isode Limited. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ @@ -11,22 +11,22 @@ #include <Swiften/Network/NATTraversalInterface.h> namespace Swift { - class NullNATTraversalInterface : public NATTraversalInterface { - public: - virtual bool isAvailable() { - return true; - } + class NullNATTraversalInterface : public NATTraversalInterface { + public: + virtual bool isAvailable() { + return true; + } - virtual boost::optional<HostAddress> getPublicIP() { - return boost::optional<HostAddress>(); - } + virtual boost::optional<HostAddress> getPublicIP() { + return boost::optional<HostAddress>(); + } - virtual boost::optional<NATPortMapping> addPortForward(int, int) { - return boost::optional<NATPortMapping>(); - } + virtual boost::optional<NATPortMapping> addPortForward(unsigned short, unsigned short) { + return boost::optional<NATPortMapping>(); + } - virtual bool removePortForward(const NATPortMapping&) { - return false; - } - }; + virtual bool removePortForward(const NATPortMapping&) { + return false; + } + }; } diff --git a/Swiften/Network/NullNATTraverser.cpp b/Swiften/Network/NullNATTraverser.cpp index 43fcd08..0b9464e 100644 --- a/Swiften/Network/NullNATTraverser.cpp +++ b/Swiften/Network/NullNATTraverser.cpp @@ -1,82 +1,83 @@ /* - * Copyright (c) 2011 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/NullNATTraverser.h> -#include <boost/smart_ptr/make_shared.hpp> +#include <memory> + #include <boost/bind.hpp> -#include <Swiften/Network/NATTraversalGetPublicIPRequest.h> +#include <Swiften/EventLoop/EventLoop.h> #include <Swiften/Network/NATTraversalForwardPortRequest.h> +#include <Swiften/Network/NATTraversalGetPublicIPRequest.h> #include <Swiften/Network/NATTraversalRemovePortForwardingRequest.h> -#include <Swiften/EventLoop/EventLoop.h> namespace Swift { class NullNATTraversalGetPublicIPRequest : public NATTraversalGetPublicIPRequest { - public: - NullNATTraversalGetPublicIPRequest(EventLoop* eventLoop) : eventLoop(eventLoop) { - } + public: + NullNATTraversalGetPublicIPRequest(EventLoop* eventLoop) : eventLoop(eventLoop) { + } - virtual void start() { - eventLoop->postEvent(boost::bind(boost::ref(onResult), boost::optional<HostAddress>())); - } + virtual void start() { + eventLoop->postEvent(boost::bind(boost::ref(onResult), boost::optional<HostAddress>())); + } - virtual void stop() { - } + virtual void stop() { + } - private: - EventLoop* eventLoop; + private: + EventLoop* eventLoop; }; class NullNATTraversalForwardPortRequest : public NATTraversalForwardPortRequest { - public: - NullNATTraversalForwardPortRequest(EventLoop* eventLoop) : eventLoop(eventLoop) { - } + public: + NullNATTraversalForwardPortRequest(EventLoop* eventLoop) : eventLoop(eventLoop) { + } - virtual void start() { - eventLoop->postEvent(boost::bind(boost::ref(onResult), boost::optional<NATPortMapping>())); - } + virtual void start() { + eventLoop->postEvent(boost::bind(boost::ref(onResult), boost::optional<NATPortMapping>())); + } - virtual void stop() { - } + virtual void stop() { + } - private: - EventLoop* eventLoop; + private: + EventLoop* eventLoop; }; class NullNATTraversalRemovePortForwardingRequest : public NATTraversalRemovePortForwardingRequest { - public: - NullNATTraversalRemovePortForwardingRequest(EventLoop* eventLoop) : eventLoop(eventLoop) { - } + public: + NullNATTraversalRemovePortForwardingRequest(EventLoop* eventLoop) : eventLoop(eventLoop) { + } - virtual void start() { - eventLoop->postEvent(boost::bind(boost::ref(onResult), boost::optional<bool>(true))); - } + virtual void start() { + eventLoop->postEvent(boost::bind(boost::ref(onResult), boost::optional<bool>(true))); + } - virtual void stop() { - } + virtual void stop() { + } - private: - EventLoop* eventLoop; + private: + EventLoop* eventLoop; }; NullNATTraverser::NullNATTraverser(EventLoop* eventLoop) : eventLoop(eventLoop) { } -boost::shared_ptr<NATTraversalGetPublicIPRequest> NullNATTraverser::createGetPublicIPRequest() { - return boost::make_shared<NullNATTraversalGetPublicIPRequest>(eventLoop); +std::shared_ptr<NATTraversalGetPublicIPRequest> NullNATTraverser::createGetPublicIPRequest() { + return std::make_shared<NullNATTraversalGetPublicIPRequest>(eventLoop); } -boost::shared_ptr<NATTraversalForwardPortRequest> NullNATTraverser::createForwardPortRequest(int, int) { - return boost::make_shared<NullNATTraversalForwardPortRequest>(eventLoop); +std::shared_ptr<NATTraversalForwardPortRequest> NullNATTraverser::createForwardPortRequest(unsigned short, unsigned short) { + return std::make_shared<NullNATTraversalForwardPortRequest>(eventLoop); } -boost::shared_ptr<NATTraversalRemovePortForwardingRequest> NullNATTraverser::createRemovePortForwardingRequest(int, int) { - return boost::make_shared<NullNATTraversalRemovePortForwardingRequest>(eventLoop); +std::shared_ptr<NATTraversalRemovePortForwardingRequest> NullNATTraverser::createRemovePortForwardingRequest(unsigned short, unsigned short) { + return std::make_shared<NullNATTraversalRemovePortForwardingRequest>(eventLoop); } } diff --git a/Swiften/Network/NullNATTraverser.h b/Swiften/Network/NullNATTraverser.h index 5775a9b..2f975bf 100644 --- a/Swiften/Network/NullNATTraverser.h +++ b/Swiften/Network/NullNATTraverser.h @@ -1,7 +1,7 @@ /* - * Copyright (c) 2011 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once @@ -9,17 +9,17 @@ #include <Swiften/Network/NATTraverser.h> namespace Swift { - class EventLoop; + class EventLoop; - class NullNATTraverser : public NATTraverser { - public: - NullNATTraverser(EventLoop* eventLoop); + class NullNATTraverser : public NATTraverser { + public: + NullNATTraverser(EventLoop* eventLoop); - boost::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest(); - boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort); - boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort); + std::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest(); + std::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(unsigned short localPort, unsigned short publicPort); + std::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(unsigned short localPort, unsigned short publicPort); - private: - EventLoop* eventLoop; - }; + private: + EventLoop* eventLoop; + }; } diff --git a/Swiften/Network/NullProxyProvider.cpp b/Swiften/Network/NullProxyProvider.cpp index 3b9d94d..32a1b9d 100644 --- a/Swiften/Network/NullProxyProvider.cpp +++ b/Swiften/Network/NullProxyProvider.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011 Remko Tronçon + * Copyright (c) 2011 Isode Limited. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ @@ -12,9 +12,9 @@ NullProxyProvider::NullProxyProvider() { } HostAddressPort NullProxyProvider::getHTTPConnectProxy() const { - return HostAddressPort(); + return HostAddressPort(); } HostAddressPort NullProxyProvider::getSOCKS5Proxy() const { - return HostAddressPort(); + return HostAddressPort(); } diff --git a/Swiften/Network/NullProxyProvider.h b/Swiften/Network/NullProxyProvider.h index 544bea2..ae7aaab 100644 --- a/Swiften/Network/NullProxyProvider.h +++ b/Swiften/Network/NullProxyProvider.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011 Remko Tronçon + * Copyright (c) 2011 Isode Limited. * Licensed under the simplified BSD license. * See Documentation/Licenses/BSD-simplified.txt for more information. */ @@ -9,11 +9,11 @@ #include <Swiften/Network/ProxyProvider.h> namespace Swift { - class NullProxyProvider : public ProxyProvider { - public: - NullProxyProvider(); + class NullProxyProvider : public ProxyProvider { + public: + NullProxyProvider(); - virtual HostAddressPort getHTTPConnectProxy() const; - virtual HostAddressPort getSOCKS5Proxy() const; - }; + virtual HostAddressPort getHTTPConnectProxy() const; + virtual HostAddressPort getSOCKS5Proxy() const; + }; } diff --git a/Swiften/Network/PlatformDomainNameAddressQuery.cpp b/Swiften/Network/PlatformDomainNameAddressQuery.cpp index ec7e663..2d72146 100644 --- a/Swiften/Network/PlatformDomainNameAddressQuery.cpp +++ b/Swiften/Network/PlatformDomainNameAddressQuery.cpp @@ -1,58 +1,70 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2015 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/PlatformDomainNameAddressQuery.h> #include <boost/asio/ip/tcp.hpp> -#include <Swiften/Network/PlatformDomainNameResolver.h> #include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/Network/PlatformDomainNameResolver.h> namespace Swift { -PlatformDomainNameAddressQuery::PlatformDomainNameAddressQuery(const std::string& host, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), hostname(host), eventLoop(eventLoop) { +PlatformDomainNameAddressQuery::PlatformDomainNameAddressQuery(const boost::optional<std::string>& host, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), hostnameValid(false), eventLoop(eventLoop) { + if (!!host) { + hostname = *host; + hostnameValid = true; + } +} + +PlatformDomainNameAddressQuery::~PlatformDomainNameAddressQuery() { + } void PlatformDomainNameAddressQuery::run() { - getResolver()->addQueryToQueue(shared_from_this()); + getResolver()->addQueryToQueue(shared_from_this()); } void PlatformDomainNameAddressQuery::runBlocking() { - //std::cout << "PlatformDomainNameResolver::doRun()" << std::endl; - boost::asio::ip::tcp::resolver resolver(ioService); - boost::asio::ip::tcp::resolver::query query(hostname, "5222"); - try { - //std::cout << "PlatformDomainNameResolver::doRun(): Resolving" << std::endl; - boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); - //std::cout << "PlatformDomainNameResolver::doRun(): Resolved" << std::endl; - if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { - //std::cout << "PlatformDomainNameResolver::doRun(): Error 1" << std::endl; - emitError(); - } - else { - std::vector<HostAddress> results; - for ( ; endpointIterator != boost::asio::ip::tcp::resolver::iterator(); ++endpointIterator) { - boost::asio::ip::address address = (*endpointIterator).endpoint().address(); - results.push_back(address.is_v4() ? HostAddress(&address.to_v4().to_bytes()[0], 4) : HostAddress(&address.to_v6().to_bytes()[0], 16)); - } - - //std::cout << "PlatformDomainNameResolver::doRun(): Success" << std::endl; - eventLoop->postEvent( - boost::bind(boost::ref(onResult), results, boost::optional<DomainNameResolveError>()), - shared_from_this()); - } - } - catch (...) { - //std::cout << "PlatformDomainNameResolver::doRun(): Error 2" << std::endl; - emitError(); - } + if (!hostnameValid) { + emitError(); + return; + } + //std::cout << "PlatformDomainNameResolver::doRun()" << std::endl; + boost::asio::ip::tcp::resolver resolver(ioService); + boost::asio::ip::tcp::resolver::query query(hostname, "5222", boost::asio::ip::resolver_query_base::passive); + try { + //std::cout << "PlatformDomainNameResolver::doRun(): Resolving" << std::endl; + boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); + //std::cout << "PlatformDomainNameResolver::doRun(): Resolved" << std::endl; + if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { + //std::cout << "PlatformDomainNameResolver::doRun(): Error 1" << std::endl; + emitError(); + } + else { + std::vector<HostAddress> results; + for ( ; endpointIterator != boost::asio::ip::tcp::resolver::iterator(); ++endpointIterator) { + boost::asio::ip::address address = (*endpointIterator).endpoint().address(); + results.push_back(address.is_v4() ? HostAddress(&address.to_v4().to_bytes()[0], 4) : HostAddress(&address.to_v6().to_bytes()[0], 16)); + } + + //std::cout << "PlatformDomainNameResolver::doRun(): Success" << std::endl; + eventLoop->postEvent( + boost::bind(boost::ref(onResult), results, boost::optional<DomainNameResolveError>()), + shared_from_this()); + } + } + catch (...) { + //std::cout << "PlatformDomainNameResolver::doRun(): Error 2" << std::endl; + emitError(); + } } void PlatformDomainNameAddressQuery::emitError() { - eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), shared_from_this()); + eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), shared_from_this()); } } diff --git a/Swiften/Network/PlatformDomainNameAddressQuery.h b/Swiften/Network/PlatformDomainNameAddressQuery.h index e1dc05f..6cb3e0a 100644 --- a/Swiften/Network/PlatformDomainNameAddressQuery.h +++ b/Swiften/Network/PlatformDomainNameAddressQuery.h @@ -1,38 +1,41 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <memory> +#include <string> + #include <boost/asio/io_service.hpp> -#include <boost/enable_shared_from_this.hpp> +#include <Swiften/EventLoop/EventOwner.h> #include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/PlatformDomainNameQuery.h> -#include <Swiften/EventLoop/EventOwner.h> -#include <string> namespace Swift { - class PlatformDomainNameResolver; - class EventLoop; - - class PlatformDomainNameAddressQuery : public DomainNameAddressQuery, public PlatformDomainNameQuery, public boost::enable_shared_from_this<PlatformDomainNameAddressQuery>, public EventOwner { - public: - PlatformDomainNameAddressQuery(const std::string& host, EventLoop* eventLoop, PlatformDomainNameResolver*); - - void run(); - - private: - void runBlocking(); - void emitError(); - - private: - boost::asio::io_service ioService; - std::string hostname; - EventLoop* eventLoop; - }; + class PlatformDomainNameResolver; + class EventLoop; + + class PlatformDomainNameAddressQuery : public DomainNameAddressQuery, public PlatformDomainNameQuery, public std::enable_shared_from_this<PlatformDomainNameAddressQuery>, public EventOwner { + public: + PlatformDomainNameAddressQuery(const boost::optional<std::string>& host, EventLoop* eventLoop, PlatformDomainNameResolver*); + virtual ~PlatformDomainNameAddressQuery(); + + void run(); + + private: + void runBlocking(); + void emitError(); + + private: + boost::asio::io_service ioService; + std::string hostname; + bool hostnameValid; + EventLoop* eventLoop; + }; } diff --git a/Swiften/Network/PlatformDomainNameQuery.h b/Swiften/Network/PlatformDomainNameQuery.h index bbfb1d1..a279f20 100644 --- a/Swiften/Network/PlatformDomainNameQuery.h +++ b/Swiften/Network/PlatformDomainNameQuery.h @@ -1,31 +1,31 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> namespace Swift { - class PlatformDomainNameResolver; + class PlatformDomainNameResolver; - class PlatformDomainNameQuery { - public: - typedef boost::shared_ptr<PlatformDomainNameQuery> ref; + class PlatformDomainNameQuery { + public: + typedef std::shared_ptr<PlatformDomainNameQuery> ref; - PlatformDomainNameQuery(PlatformDomainNameResolver* resolver) : resolver(resolver) {} - virtual ~PlatformDomainNameQuery() {} + PlatformDomainNameQuery(PlatformDomainNameResolver* resolver) : resolver(resolver) {} + virtual ~PlatformDomainNameQuery() {} - virtual void runBlocking() = 0; + virtual void runBlocking() = 0; - protected: - PlatformDomainNameResolver* getResolver() { - return resolver; - } + protected: + PlatformDomainNameResolver* getResolver() { + return resolver; + } - private: - PlatformDomainNameResolver* resolver; - }; + private: + PlatformDomainNameResolver* resolver; + }; } diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp index 677f1d5..40e385d 100644 --- a/Swiften/Network/PlatformDomainNameResolver.cpp +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -1,76 +1,80 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/PlatformDomainNameResolver.h> -// Putting this early on, because some system types conflict with thread -#include <Swiften/Network/PlatformDomainNameServiceQuery.h> - +#include <algorithm> +#include <mutex> #include <string> +#include <thread> #include <vector> + #include <boost/bind.hpp> -#include <boost/thread.hpp> -#include <algorithm> -#include <string> +#include <Swiften/EventLoop/EventLoop.h> #include <Swiften/IDN/IDNConverter.h> +#include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/HostAddress.h> -#include <Swiften/EventLoop/EventLoop.h> #include <Swiften/Network/HostAddressPort.h> -#include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/PlatformDomainNameAddressQuery.h> +#include <Swiften/Network/PlatformDomainNameServiceQuery.h> using namespace Swift; namespace Swift { PlatformDomainNameResolver::PlatformDomainNameResolver(IDNConverter* idnConverter, EventLoop* eventLoop) : idnConverter(idnConverter), eventLoop(eventLoop), stopRequested(false) { - thread = new boost::thread(boost::bind(&PlatformDomainNameResolver::run, this)); + thread = new std::thread(boost::bind(&PlatformDomainNameResolver::run, this)); } PlatformDomainNameResolver::~PlatformDomainNameResolver() { - stopRequested = true; - addQueryToQueue(boost::shared_ptr<PlatformDomainNameQuery>()); - thread->join(); - delete thread; + stopRequested = true; + addQueryToQueue(std::shared_ptr<PlatformDomainNameQuery>()); + thread->join(); + delete thread; } -boost::shared_ptr<DomainNameServiceQuery> PlatformDomainNameResolver::createServiceQuery(const std::string& name) { - return boost::shared_ptr<DomainNameServiceQuery>(new PlatformDomainNameServiceQuery(idnConverter->getIDNAEncoded(name), eventLoop, this)); +std::shared_ptr<DomainNameServiceQuery> PlatformDomainNameResolver::createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { + boost::optional<std::string> encodedDomain = idnConverter->getIDNAEncoded(domain); + std::string result; + if (encodedDomain) { + result = serviceLookupPrefix + *encodedDomain; + } + return std::make_shared<PlatformDomainNameServiceQuery>(result, eventLoop, this); } -boost::shared_ptr<DomainNameAddressQuery> PlatformDomainNameResolver::createAddressQuery(const std::string& name) { - return boost::shared_ptr<DomainNameAddressQuery>(new PlatformDomainNameAddressQuery(idnConverter->getIDNAEncoded(name), eventLoop, this)); +std::shared_ptr<DomainNameAddressQuery> PlatformDomainNameResolver::createAddressQuery(const std::string& name) { + return std::make_shared<PlatformDomainNameAddressQuery>(idnConverter->getIDNAEncoded(name), eventLoop, this); } void PlatformDomainNameResolver::run() { - while (!stopRequested) { - PlatformDomainNameQuery::ref query; - { - boost::unique_lock<boost::mutex> lock(queueMutex); - while (queue.empty()) { - queueNonEmpty.wait(lock); - } - query = queue.front(); - queue.pop_front(); - } - // Check whether we don't have a non-null query (used to stop the - // resolver) - if (query) { - query->runBlocking(); - } - } + while (!stopRequested) { + PlatformDomainNameQuery::ref query; + { + std::unique_lock<std::mutex> lock(queueMutex); + while (queue.empty()) { + queueNonEmpty.wait(lock); + } + query = queue.front(); + queue.pop_front(); + } + // Check whether we don't have a non-null query (used to stop the + // resolver) + if (query) { + query->runBlocking(); + } + } } void PlatformDomainNameResolver::addQueryToQueue(PlatformDomainNameQuery::ref query) { - { - boost::lock_guard<boost::mutex> lock(queueMutex); - queue.push_back(query); - } - queueNonEmpty.notify_one(); + { + std::lock_guard<std::mutex> lock(queueMutex); + queue.push_back(query); + } + queueNonEmpty.notify_one(); } } diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h index 25d87cf..4ddb999 100644 --- a/Swiften/Network/PlatformDomainNameResolver.h +++ b/Swiften/Network/PlatformDomainNameResolver.h @@ -1,47 +1,48 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once +#include <condition_variable> #include <deque> -#include <boost/thread/thread.hpp> -#include <boost/thread/mutex.hpp> -#include <boost/thread/condition_variable.hpp> +#include <mutex> +#include <thread> #include <Swiften/Base/API.h> +#include <Swiften/Base/Atomic.h> +#include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/DomainNameResolver.h> -#include <Swiften/Network/PlatformDomainNameQuery.h> #include <Swiften/Network/DomainNameServiceQuery.h> -#include <Swiften/Network/DomainNameAddressQuery.h> +#include <Swiften/Network/PlatformDomainNameQuery.h> namespace Swift { - class IDNConverter; - class EventLoop; - - class SWIFTEN_API PlatformDomainNameResolver : public DomainNameResolver { - public: - PlatformDomainNameResolver(IDNConverter* idnConverter, EventLoop* eventLoop); - ~PlatformDomainNameResolver(); - - virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& name); - virtual DomainNameAddressQuery::ref createAddressQuery(const std::string& name); - - private: - void run(); - void addQueryToQueue(PlatformDomainNameQuery::ref); - - private: - friend class PlatformDomainNameServiceQuery; - friend class PlatformDomainNameAddressQuery; - IDNConverter* idnConverter; - EventLoop* eventLoop; - bool stopRequested; - boost::thread* thread; - std::deque<PlatformDomainNameQuery::ref> queue; - boost::mutex queueMutex; - boost::condition_variable queueNonEmpty; - }; + class IDNConverter; + class EventLoop; + + class SWIFTEN_API PlatformDomainNameResolver : public DomainNameResolver { + public: + PlatformDomainNameResolver(IDNConverter* idnConverter, EventLoop* eventLoop); + virtual ~PlatformDomainNameResolver(); + + virtual DomainNameServiceQuery::ref createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain); + virtual DomainNameAddressQuery::ref createAddressQuery(const std::string& name); + + private: + void run(); + void addQueryToQueue(PlatformDomainNameQuery::ref); + + private: + friend class PlatformDomainNameServiceQuery; + friend class PlatformDomainNameAddressQuery; + IDNConverter* idnConverter; + EventLoop* eventLoop; + Atomic<bool> stopRequested; + std::thread* thread; + std::deque<PlatformDomainNameQuery::ref> queue; + std::mutex queueMutex; + std::condition_variable queueNonEmpty; + }; } diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp index 5788d2f..2ff14e1 100644 --- a/Swiften/Network/PlatformDomainNameServiceQuery.cpp +++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <boost/asio.hpp> @@ -29,145 +29,164 @@ #include <Swiften/Base/ByteArray.h> #include <Swiften/EventLoop/EventLoop.h> -#include <Swiften/Base/foreach.h> -#include <Swiften/Base/BoostRandomGenerator.h> #include <Swiften/Base/Log.h> +#include <Swiften/Base/StdRandomGenerator.h> #include <Swiften/Network/PlatformDomainNameResolver.h> using namespace Swift; namespace Swift { -PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const std::string& service, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), eventLoop(eventLoop), service(service) { +PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const boost::optional<std::string>& serviceName, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), eventLoop(eventLoop), serviceValid(false) { + if (!!serviceName) { + service = *serviceName; + serviceValid = true; + } +} + +PlatformDomainNameServiceQuery::~PlatformDomainNameServiceQuery() { + } void PlatformDomainNameServiceQuery::run() { - getResolver()->addQueryToQueue(shared_from_this()); + getResolver()->addQueryToQueue(shared_from_this()); } void PlatformDomainNameServiceQuery::runBlocking() { - SWIFT_LOG(debug) << "Querying " << service << std::endl; + if (!serviceValid) { + emitError(); + return; + } + + SWIFT_LOG(debug) << "Querying " << service; - std::vector<DomainNameServiceQuery::Result> records; + std::vector<DomainNameServiceQuery::Result> records; #if defined(SWIFTEN_PLATFORM_WINDOWS) - DNS_RECORD* responses; - // FIXME: This conversion doesn't work if unicode is deffed above - if (DnsQuery(service.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { - emitError(); - return; - } - - DNS_RECORD* currentEntry = responses; - while (currentEntry) { - if (currentEntry->wType == DNS_TYPE_SRV) { - DomainNameServiceQuery::Result record; - record.priority = currentEntry->Data.SRV.wPriority; - record.weight = currentEntry->Data.SRV.wWeight; - record.port = currentEntry->Data.SRV.wPort; - - // The pNameTarget is actually a PCWSTR, so I would have expected this - // conversion to not work at all, but it does. - // Actually, it doesn't. Fix this and remove explicit cast - // Remove unicode undef above as well - record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget); - records.push_back(record); - } - currentEntry = currentEntry->pNext; - } - DnsRecordListFree(responses, DnsFreeRecordList); + DNS_RECORD* responses; + // FIXME: This conversion doesn't work if unicode is deffed above + if (DnsQuery(service.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { + emitError(); + return; + } + + DNS_RECORD* currentEntry = responses; + while (currentEntry) { + if (currentEntry->wType == DNS_TYPE_SRV) { + DomainNameServiceQuery::Result record; + record.priority = currentEntry->Data.SRV.wPriority; + record.weight = currentEntry->Data.SRV.wWeight; + record.port = currentEntry->Data.SRV.wPort; + + // The pNameTarget is actually a PCWSTR, so I would have expected this + // conversion to not work at all, but it does. + // Actually, it doesn't. Fix this and remove explicit cast + // Remove unicode undef above as well + record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget); + records.push_back(record); + } + currentEntry = currentEntry->pNext; + } + DnsRecordListFree(responses, DnsFreeRecordList); #else - // Make sure we reinitialize the domain list every time - res_init(); - - ByteArray response; - response.resize(NS_PACKETSZ); - int responseLength = res_query(const_cast<char*>(service.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(vecptr(response)), response.size()); - if (responseLength == -1) { - SWIFT_LOG(debug) << "Error" << std::endl; - emitError(); - return; - } - - // Parse header - HEADER* header = reinterpret_cast<HEADER*>(vecptr(response)); - unsigned char* messageStart = vecptr(response); - unsigned char* messageEnd = messageStart + responseLength; - unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; - - // Skip over the queries - int queriesCount = ntohs(header->qdcount); - while (queriesCount > 0) { - int entryLength = dn_skipname(currentEntry, messageEnd); - if (entryLength < 0) { - emitError(); - return; - } - currentEntry += entryLength + NS_QFIXEDSZ; - queriesCount--; - } - - // Process the SRV answers - int answersCount = ntohs(header->ancount); - while (answersCount > 0) { - DomainNameServiceQuery::Result record; - - int entryLength = dn_skipname(currentEntry, messageEnd); - currentEntry += entryLength; - currentEntry += NS_RRFIXEDSZ; - - // Priority - if (currentEntry + 2 >= messageEnd) { - emitError(); - return; - } - record.priority = boost::numeric_cast<int>(ns_get16(currentEntry)); - currentEntry += 2; - - // Weight - if (currentEntry + 2 >= messageEnd) { - emitError(); - return; - } - record.weight = boost::numeric_cast<int>(ns_get16(currentEntry)); - currentEntry += 2; - - // Port - if (currentEntry + 2 >= messageEnd) { - emitError(); - return; - } - record.port = boost::numeric_cast<int>(ns_get16(currentEntry)); - currentEntry += 2; - - // Hostname - if (currentEntry >= messageEnd) { - emitError(); - return; - } - ByteArray entry; - entry.resize(NS_MAXDNAME); - entryLength = dn_expand(messageStart, messageEnd, currentEntry, reinterpret_cast<char*>(vecptr(entry)), entry.size()); - if (entryLength < 0) { - emitError(); - return; - } - record.hostname = std::string(reinterpret_cast<const char*>(vecptr(entry))); - records.push_back(record); - currentEntry += entryLength; - answersCount--; - } + // Make sure we reinitialize the domain list every time + res_init(); + + ByteArray response; + response.resize(NS_PACKETSZ); + int responseLength = res_query(const_cast<char*>(service.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(vecptr(response)), response.size()); + if (responseLength == -1) { + SWIFT_LOG(debug) << "Error"; + emitError(); + return; + } + + // Parse header + HEADER* header = reinterpret_cast<HEADER*>(vecptr(response)); + unsigned char* messageStart = vecptr(response); + unsigned char* messageEnd = messageStart + responseLength; + unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; + + // Skip over the queries + int queriesCount = ntohs(header->qdcount); + while (queriesCount > 0) { + int entryLength = dn_skipname(currentEntry, messageEnd); + if (entryLength < 0) { + emitError(); + return; + } + currentEntry += entryLength + NS_QFIXEDSZ; + queriesCount--; + } + + // Process the SRV answers + int answersCount = ntohs(header->ancount); + while (answersCount > 0) { + DomainNameServiceQuery::Result record; + + int entryLength = dn_skipname(currentEntry, messageEnd); + currentEntry += entryLength; + currentEntry += NS_RRFIXEDSZ; + + try { + // Priority + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.priority = boost::numeric_cast<int>(ns_get16(currentEntry)); + currentEntry += 2; + + // Weight + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.weight = boost::numeric_cast<int>(ns_get16(currentEntry)); + currentEntry += 2; + + // Port + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.port = boost::numeric_cast<unsigned short>(ns_get16(currentEntry)); + currentEntry += 2; + + // Hostname + if (currentEntry >= messageEnd) { + emitError(); + return; + } + } + catch (const boost::numeric::bad_numeric_cast&) { + emitError(); + return; + } + + ByteArray entry; + entry.resize(NS_MAXDNAME); + entryLength = dn_expand(messageStart, messageEnd, currentEntry, reinterpret_cast<char*>(vecptr(entry)), entry.size()); + if (entryLength < 0) { + emitError(); + return; + } + record.hostname = std::string(reinterpret_cast<const char*>(vecptr(entry))); + records.push_back(record); + currentEntry += entryLength; + answersCount--; + } #endif - BoostRandomGenerator generator; - DomainNameServiceQuery::sortResults(records, generator); - //std::cout << "Sending out " << records.size() << " SRV results " << std::endl; - eventLoop->postEvent(boost::bind(boost::ref(onResult), records), shared_from_this()); + StdRandomGenerator generator; + DomainNameServiceQuery::sortResults(records, generator); + //std::cout << "Sending out " << records.size() << " SRV results " << std::endl; + eventLoop->postEvent(boost::bind(boost::ref(onResult), records), shared_from_this()); } void PlatformDomainNameServiceQuery::emitError() { - eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector<DomainNameServiceQuery::Result>()), shared_from_this()); + eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector<DomainNameServiceQuery::Result>()), shared_from_this()); } } diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.h b/Swiften/Network/PlatformDomainNameServiceQuery.h index 310e639..0d690f3 100644 --- a/Swiften/Network/PlatformDomainNameServiceQuery.h +++ b/Swiften/Network/PlatformDomainNameServiceQuery.h @@ -1,33 +1,35 @@ /* - * Copyright (c) 2010-2013 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/enable_shared_from_this.hpp> +#include <memory> +#include <string> -#include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/EventLoop/EventOwner.h> -#include <string> +#include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/Network/PlatformDomainNameQuery.h> namespace Swift { - class EventLoop; + class EventLoop; - class PlatformDomainNameServiceQuery : public DomainNameServiceQuery, public PlatformDomainNameQuery, public boost::enable_shared_from_this<PlatformDomainNameServiceQuery>, public EventOwner { - public: - PlatformDomainNameServiceQuery(const std::string& service, EventLoop* eventLoop, PlatformDomainNameResolver* resolver); + class PlatformDomainNameServiceQuery : public DomainNameServiceQuery, public PlatformDomainNameQuery, public std::enable_shared_from_this<PlatformDomainNameServiceQuery>, public EventOwner { + public: + PlatformDomainNameServiceQuery(const boost::optional<std::string>& serviceName, EventLoop* eventLoop, PlatformDomainNameResolver* resolver); + virtual ~PlatformDomainNameServiceQuery(); - virtual void run(); + virtual void run(); - private: - void runBlocking(); - void emitError(); + private: + void runBlocking(); + void emitError(); - private: - EventLoop* eventLoop; - std::string service; - }; + private: + EventLoop* eventLoop; + std::string service; + bool serviceValid; + }; } diff --git a/Swiften/Network/PlatformNATTraversalWorker.cpp b/Swiften/Network/PlatformNATTraversalWorker.cpp index 133b006..5431379 100644 --- a/Swiften/Network/PlatformNATTraversalWorker.cpp +++ b/Swiften/Network/PlatformNATTraversalWorker.cpp @@ -4,13 +4,21 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include "PlatformNATTraversalWorker.h" +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#include <Swiften/Network/PlatformNATTraversalWorker.h> + +#include <memory> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/enable_shared_from_this.hpp> #include <boost/numeric/conversion/cast.hpp> #include <Swiften/Base/Log.h> +#include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/EventLoop/EventOwner.h> #include <Swiften/Network/NATTraversalGetPublicIPRequest.h> #include <Swiften/Network/NATTraversalForwardPortRequest.h> #include <Swiften/Network/NATTraversalRemovePortForwardingRequest.h> @@ -23,179 +31,194 @@ namespace Swift { -class PlatformNATTraversalRequest : public boost::enable_shared_from_this<PlatformNATTraversalRequest> { - public: - typedef boost::shared_ptr<PlatformNATTraversalRequest> ref; +class PlatformNATTraversalRequest : public std::enable_shared_from_this<PlatformNATTraversalRequest>, public EventOwner { + public: + typedef std::shared_ptr<PlatformNATTraversalRequest> ref; - public: - PlatformNATTraversalRequest(PlatformNATTraversalWorker* worker) : worker(worker) { - } + public: + PlatformNATTraversalRequest(PlatformNATTraversalWorker* worker) : worker(worker) { + } - virtual ~PlatformNATTraversalRequest() { - } + virtual ~PlatformNATTraversalRequest() { + } - virtual void doRun() { - worker->addRequestToQueue(shared_from_this()); - } + virtual void doRun() { + worker->addRequestToQueue(shared_from_this()); + } - NATTraversalInterface* getNATTraversalInterface() const { - return worker->getNATTraversalInterface(); - } + NATTraversalInterface* getNATTraversalInterface() const { + return worker->getNATTraversalInterface(); + } + EventLoop* getEventLoop() const { + return worker->getEventLoop(); + } - virtual void runBlocking() = 0; - private: - PlatformNATTraversalWorker* worker; + virtual void runBlocking() = 0; + + private: + PlatformNATTraversalWorker* worker; }; class PlatformNATTraversalGetPublicIPRequest : public NATTraversalGetPublicIPRequest, public PlatformNATTraversalRequest { - public: - PlatformNATTraversalGetPublicIPRequest(PlatformNATTraversalWorker* worker) : PlatformNATTraversalRequest(worker) { - } + public: + PlatformNATTraversalGetPublicIPRequest(PlatformNATTraversalWorker* worker) : PlatformNATTraversalRequest(worker) { + } + + virtual ~PlatformNATTraversalGetPublicIPRequest() { + } - virtual void start() { - doRun(); - } + virtual void start() { + doRun(); + } - virtual void stop() { - // TODO - } + virtual void stop() { + onResult.disconnect_all_slots(); + } - virtual void runBlocking() { - onResult(getNATTraversalInterface()->getPublicIP()); - } + virtual void runBlocking() { + getEventLoop()->postEvent(boost::bind(boost::ref(onResult), getNATTraversalInterface()->getPublicIP()), shared_from_this()); + } }; class PlatformNATTraversalForwardPortRequest : public NATTraversalForwardPortRequest, public PlatformNATTraversalRequest { - public: - PlatformNATTraversalForwardPortRequest(PlatformNATTraversalWorker* worker, unsigned int localIP, unsigned int publicIP) : PlatformNATTraversalRequest(worker), localIP(localIP), publicIP(publicIP) { - } + public: + PlatformNATTraversalForwardPortRequest(PlatformNATTraversalWorker* worker, unsigned short localPort, unsigned short publicPort) : PlatformNATTraversalRequest(worker), localPort(localPort), publicPort(publicPort) { + } - virtual void start() { - doRun(); - } + virtual ~PlatformNATTraversalForwardPortRequest() { + } - virtual void stop() { - // TODO - } + virtual void start() { + doRun(); + } - virtual void runBlocking() { - onResult(getNATTraversalInterface()->addPortForward(boost::numeric_cast<int>(localIP), boost::numeric_cast<int>(publicIP))); - } + virtual void stop() { + onResult.disconnect_all_slots(); + } - private: - unsigned int localIP; - unsigned int publicIP; + virtual void runBlocking() { + getEventLoop()->postEvent(boost::bind(boost::ref(onResult), getNATTraversalInterface()->addPortForward(localPort, publicPort)), shared_from_this()); + } + + private: + unsigned short localPort; + unsigned short publicPort; }; class PlatformNATTraversalRemovePortForwardingRequest : public NATTraversalRemovePortForwardingRequest, public PlatformNATTraversalRequest { - public: - PlatformNATTraversalRemovePortForwardingRequest(PlatformNATTraversalWorker* worker, const NATPortMapping& mapping) : PlatformNATTraversalRequest(worker), mapping(mapping) { - } + public: + PlatformNATTraversalRemovePortForwardingRequest(PlatformNATTraversalWorker* worker, const NATPortMapping& mapping) : PlatformNATTraversalRequest(worker), mapping(mapping) { + } + + virtual ~PlatformNATTraversalRemovePortForwardingRequest() { + } - virtual void start() { - doRun(); - } + virtual void start() { + doRun(); + } - virtual void stop() { - // TODO - } + virtual void stop() { + onResult.disconnect_all_slots(); + } - virtual void runBlocking() { - onResult(getNATTraversalInterface()->removePortForward(mapping)); - } + virtual void runBlocking() { + getEventLoop()->postEvent(boost::bind(boost::ref(onResult), getNATTraversalInterface()->removePortForward(mapping)), shared_from_this()); + } - private: - NATPortMapping mapping; + private: + NATPortMapping mapping; }; -PlatformNATTraversalWorker::PlatformNATTraversalWorker(EventLoop* eventLoop) : eventLoop(eventLoop), stopRequested(false), natPMPSupported(boost::logic::indeterminate), natPMPInterface(NULL), miniUPnPSupported(boost::logic::indeterminate), miniUPnPInterface(NULL) { - nullNATTraversalInterface = new NullNATTraversalInterface(); - // FIXME: This should be done from start(), and the current start() should be an internal method - thread = new boost::thread(boost::bind(&PlatformNATTraversalWorker::start, this)); +PlatformNATTraversalWorker::PlatformNATTraversalWorker(EventLoop* eventLoop) : eventLoop(eventLoop), stopRequested(false), natPMPSupported(boost::logic::indeterminate), natPMPInterface(nullptr), miniUPnPSupported(boost::logic::indeterminate), miniUPnPInterface(nullptr) { + nullNATTraversalInterface = new NullNATTraversalInterface(); + // FIXME: This should be done from start(), and the current start() should be an internal method + thread = new std::thread(boost::bind(&PlatformNATTraversalWorker::start, this)); } PlatformNATTraversalWorker::~PlatformNATTraversalWorker() { - stopRequested = true; - addRequestToQueue(boost::shared_ptr<PlatformNATTraversalRequest>()); - thread->join(); - delete thread; + stopRequested = true; + addRequestToQueue(std::shared_ptr<PlatformNATTraversalRequest>()); + thread->join(); + delete thread; #ifdef HAVE_LIBNATPMP - delete natPMPInterface; + delete natPMPInterface; #endif #ifdef HAVE_LIBMINIUPNPC - delete miniUPnPInterface; + delete miniUPnPInterface; #endif - delete nullNATTraversalInterface; + delete nullNATTraversalInterface; } NATTraversalInterface* PlatformNATTraversalWorker::getNATTraversalInterface() const { #ifdef HAVE_LIBMINIUPNPC - if (boost::logic::indeterminate(miniUPnPSupported)) { - miniUPnPInterface = new MiniUPnPInterface(); - miniUPnPSupported = miniUPnPInterface->isAvailable(); - } - if (miniUPnPSupported) { - return miniUPnPInterface; - } + if (boost::logic::indeterminate(miniUPnPSupported)) { + miniUPnPInterface = new MiniUPnPInterface(); + miniUPnPSupported = miniUPnPInterface->isAvailable(); + } + SWIFT_LOG(debug) << "UPnP NAT traversal supported: " << static_cast<bool>(miniUPnPSupported); + if (miniUPnPSupported) { + return miniUPnPInterface; + } #endif #ifdef HAVE_LIBNATPMP - if (boost::logic::indeterminate(natPMPSupported)) { - natPMPInterface = new NATPMPInterface(); - natPMPSupported = natPMPInterface->isAvailable(); - } - if (natPMPSupported) { - return natPMPInterface; - } + if (boost::logic::indeterminate(natPMPSupported)) { + natPMPInterface = new NATPMPInterface(); + natPMPSupported = natPMPInterface->isAvailable(); + } + SWIFT_LOG(debug) << "NAT-PMP NAT traversal supported: " << static_cast<bool>(natPMPSupported); + if (natPMPSupported) { + return natPMPInterface; + } #endif - return nullNATTraversalInterface; + return nullNATTraversalInterface; } -boost::shared_ptr<NATTraversalGetPublicIPRequest> PlatformNATTraversalWorker::createGetPublicIPRequest() { - return boost::make_shared<PlatformNATTraversalGetPublicIPRequest>(this); +std::shared_ptr<NATTraversalGetPublicIPRequest> PlatformNATTraversalWorker::createGetPublicIPRequest() { + return std::make_shared<PlatformNATTraversalGetPublicIPRequest>(this); } -boost::shared_ptr<NATTraversalForwardPortRequest> PlatformNATTraversalWorker::createForwardPortRequest(int localPort, int publicPort) { - return boost::make_shared<PlatformNATTraversalForwardPortRequest>(this, localPort, publicPort); +std::shared_ptr<NATTraversalForwardPortRequest> PlatformNATTraversalWorker::createForwardPortRequest(unsigned short localPort, unsigned short publicPort) { + return std::make_shared<PlatformNATTraversalForwardPortRequest>(this, localPort, publicPort); } -boost::shared_ptr<NATTraversalRemovePortForwardingRequest> PlatformNATTraversalWorker::createRemovePortForwardingRequest(int localPort, int publicPort) { - NATPortMapping mapping(localPort, publicPort, NATPortMapping::TCP); // FIXME - return boost::make_shared<PlatformNATTraversalRemovePortForwardingRequest>(this, mapping); +std::shared_ptr<NATTraversalRemovePortForwardingRequest> PlatformNATTraversalWorker::createRemovePortForwardingRequest(unsigned short localPort, unsigned short publicPort) { + NATPortMapping mapping(localPort, publicPort, NATPortMapping::TCP); // FIXME + return std::make_shared<PlatformNATTraversalRemovePortForwardingRequest>(this, mapping); } void PlatformNATTraversalWorker::start() { - while (!stopRequested) { - PlatformNATTraversalRequest::ref request; - { - boost::unique_lock<boost::mutex> lock(queueMutex); - while (queue.empty()) { - queueNonEmpty.wait(lock); - } - request = queue.front(); - queue.pop_front(); - } - // Check whether we don't have a non-null request (used to stop the - // worker) - if (request) { - request->runBlocking(); - } - } + while (!stopRequested) { + PlatformNATTraversalRequest::ref request; + { + std::unique_lock<std::mutex> lock(queueMutex); + while (queue.empty()) { + queueNonEmpty.wait(lock); + } + request = queue.front(); + queue.pop_front(); + } + // Check whether we don't have a non-null request (used to stop the + // worker) + if (request) { + request->runBlocking(); + } + } } void PlatformNATTraversalWorker::stop() { - // TODO + // TODO } void PlatformNATTraversalWorker::addRequestToQueue(PlatformNATTraversalRequest::ref request) { - { - boost::lock_guard<boost::mutex> lock(queueMutex); - queue.push_back(request); - } - queueNonEmpty.notify_one(); + { + std::lock_guard<std::mutex> lock(queueMutex); + queue.push_back(request); + } + queueNonEmpty.notify_one(); } } diff --git a/Swiften/Network/PlatformNATTraversalWorker.h b/Swiften/Network/PlatformNATTraversalWorker.h index 6148705..368798e 100644 --- a/Swiften/Network/PlatformNATTraversalWorker.h +++ b/Swiften/Network/PlatformNATTraversalWorker.h @@ -4,60 +4,71 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once +#include <condition_variable> #include <deque> -#include <boost/optional.hpp> -#include <boost/thread/thread.hpp> -#include <boost/thread/mutex.hpp> -#include <boost/thread/condition_variable.hpp> +#include <mutex> +#include <thread> + #include <boost/logic/tribool.hpp> +#include <boost/optional.hpp> #include <Swiften/Base/API.h> -#include <Swiften/Network/NATTraverser.h> +#include <Swiften/Base/Atomic.h> #include <Swiften/Network/HostAddressPort.h> +#include <Swiften/Network/NATTraverser.h> #include <Swiften/Network/NullNATTraversalInterface.h> namespace Swift { - class EventLoop; - class NATTraversalGetPublicIPRequest; - class NATTraversalForwardPortRequest; - class NATTraversalRemovePortForwardingRequest; - class PlatformNATTraversalRequest; - class NATPMPInterface; - class MiniUPnPInterface; - class NATTraversalInterface; - class NATPortMapping; - - class SWIFTEN_API PlatformNATTraversalWorker : public NATTraverser { - friend class PlatformNATTraversalRequest; - - public: - PlatformNATTraversalWorker(EventLoop* eventLoop); - ~PlatformNATTraversalWorker(); - - boost::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest(); - boost::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(int localPort, int publicPort); - boost::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(int localPort, int publicPort); - - private: - NATTraversalInterface* getNATTraversalInterface() const; - void addRequestToQueue(boost::shared_ptr<PlatformNATTraversalRequest>); - void start(); - void stop(); - - private: - EventLoop* eventLoop; - bool stopRequested; - boost::thread* thread; - std::deque<boost::shared_ptr<PlatformNATTraversalRequest> > queue; - boost::mutex queueMutex; - boost::condition_variable queueNonEmpty; - - NullNATTraversalInterface* nullNATTraversalInterface; - mutable boost::logic::tribool natPMPSupported; - mutable NATPMPInterface* natPMPInterface; - mutable boost::logic::tribool miniUPnPSupported; - mutable MiniUPnPInterface* miniUPnPInterface; - }; + class EventLoop; + class NATTraversalGetPublicIPRequest; + class NATTraversalForwardPortRequest; + class NATTraversalRemovePortForwardingRequest; + class PlatformNATTraversalRequest; + class NATPMPInterface; + class MiniUPnPInterface; + class NATTraversalInterface; + + class SWIFTEN_API PlatformNATTraversalWorker : public NATTraverser { + friend class PlatformNATTraversalRequest; + + public: + PlatformNATTraversalWorker(EventLoop* eventLoop); + virtual ~PlatformNATTraversalWorker(); + + std::shared_ptr<NATTraversalGetPublicIPRequest> createGetPublicIPRequest(); + std::shared_ptr<NATTraversalForwardPortRequest> createForwardPortRequest(unsigned short localPort, unsigned short publicPort); + std::shared_ptr<NATTraversalRemovePortForwardingRequest> createRemovePortForwardingRequest(unsigned short localPort, unsigned short publicPort); + + private: + NATTraversalInterface* getNATTraversalInterface() const; + void addRequestToQueue(std::shared_ptr<PlatformNATTraversalRequest>); + void start(); + void stop(); + + EventLoop* getEventLoop() const { + return eventLoop; + } + + private: + EventLoop* eventLoop; + Atomic<bool> stopRequested; + std::thread* thread; + std::deque<std::shared_ptr<PlatformNATTraversalRequest> > queue; + std::mutex queueMutex; + std::condition_variable queueNonEmpty; + + NullNATTraversalInterface* nullNATTraversalInterface; + mutable boost::logic::tribool natPMPSupported; + mutable NATPMPInterface* natPMPInterface; + mutable boost::logic::tribool miniUPnPSupported; + mutable MiniUPnPInterface* miniUPnPInterface; + }; } diff --git a/Swiften/Network/PlatformNetworkEnvironment.h b/Swiften/Network/PlatformNetworkEnvironment.h index c6b945e..ff29491 100644 --- a/Swiften/Network/PlatformNetworkEnvironment.h +++ b/Swiften/Network/PlatformNetworkEnvironment.h @@ -11,16 +11,22 @@ #if defined(SWIFTEN_PLATFORM_MACOSX) #include <Swiften/Network/UnixNetworkEnvironment.h> namespace Swift { - typedef UnixNetworkEnvironment PlatformNetworkEnvironment; + typedef UnixNetworkEnvironment PlatformNetworkEnvironment; } #elif defined(SWIFTEN_PLATFORM_WIN32) #include <Swiften/Network/WindowsNetworkEnvironment.h> namespace Swift { - typedef WindowsNetworkEnvironment PlatformNetworkEnvironment; + typedef WindowsNetworkEnvironment PlatformNetworkEnvironment; } +#elif defined(SWIFTEN_PLATFORM_SOLARIS) +#include <Swiften/Network/SolarisNetworkEnvironment.h> +namespace Swift { + typedef SolarisNetworkEnvironment PlatformNetworkEnvironment; +} + #else #include <Swiften/Network/UnixNetworkEnvironment.h> namespace Swift { - typedef UnixNetworkEnvironment PlatformNetworkEnvironment; + typedef UnixNetworkEnvironment PlatformNetworkEnvironment; } #endif diff --git a/Swiften/Network/PlatformProxyProvider.h b/Swiften/Network/PlatformProxyProvider.h index 1a0a1c6..c63e718 100644 --- a/Swiften/Network/PlatformProxyProvider.h +++ b/Swiften/Network/PlatformProxyProvider.h @@ -11,16 +11,16 @@ #if defined(SWIFTEN_PLATFORM_MACOSX) #include <Swiften/Network/MacOSXProxyProvider.h> namespace Swift { - typedef MacOSXProxyProvider PlatformProxyProvider; + typedef MacOSXProxyProvider PlatformProxyProvider; } #elif defined(SWIFTEN_PLATFORM_WIN32) #include <Swiften/Network/WindowsProxyProvider.h> namespace Swift { - typedef WindowsProxyProvider PlatformProxyProvider; + typedef WindowsProxyProvider PlatformProxyProvider; } #else #include <Swiften/Network/UnixProxyProvider.h> namespace Swift { - typedef UnixProxyProvider PlatformProxyProvider; + typedef UnixProxyProvider PlatformProxyProvider; } #endif diff --git a/Swiften/Network/ProxiedConnection.cpp b/Swiften/Network/ProxiedConnection.cpp index 8bf12d3..0c5cda6 100644 --- a/Swiften/Network/ProxiedConnection.cpp +++ b/Swiften/Network/ProxiedConnection.cpp @@ -1,112 +1,127 @@ /* - * Copyright (c) 2012 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ - #include <Swiften/Network/ProxiedConnection.h> -#include <iostream> #include <boost/bind.hpp> #include <Swiften/Base/ByteArray.h> -#include <Swiften/Network/HostAddressPort.h> +#include <Swiften/Base/Log.h> #include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/HostAddressPort.h> using namespace Swift; ProxiedConnection::ProxiedConnection( - DomainNameResolver* resolver, - ConnectionFactory* connectionFactory, - TimerFactory* timerFactory, - const std::string& proxyHost, - int proxyPort) : - resolver_(resolver), - connectionFactory_(connectionFactory), - timerFactory_(timerFactory), - proxyHost_(proxyHost), - proxyPort_(proxyPort), - server_(HostAddressPort(HostAddress("0.0.0.0"), 0)) { - connected_ = false; + DomainNameResolver* resolver, + ConnectionFactory* connectionFactory, + TimerFactory* timerFactory, + const std::string& proxyHost, + unsigned short proxyPort) : + resolver_(resolver), + connectionFactory_(connectionFactory), + timerFactory_(timerFactory), + proxyHost_(proxyHost), + proxyPort_(proxyPort), + server_(HostAddressPort(HostAddress::fromString("0.0.0.0").get(), 0)) { + connected_ = false; } ProxiedConnection::~ProxiedConnection() { - cancelConnector(); - if (connection_) { - connection_->onDataRead.disconnect(boost::bind(&ProxiedConnection::handleDataRead, shared_from_this(), _1)); - connection_->onDisconnected.disconnect(boost::bind(&ProxiedConnection::handleDisconnected, shared_from_this(), _1)); - } - if (connected_) { - std::cerr << "Warning: Connection was still established." << std::endl; - } + cancelConnector(); + if (connection_) { + connection_->onDataRead.disconnect(boost::bind(&ProxiedConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.disconnect(boost::bind(&ProxiedConnection::handleDisconnected, shared_from_this(), _1)); + } + if (connected_) { + SWIFT_LOG(warning) << "Connection was still established."; + } } void ProxiedConnection::cancelConnector() { - if (connector_) { - connector_->onConnectFinished.disconnect(boost::bind(&ProxiedConnection::handleConnectFinished, shared_from_this(), _1)); - connector_->stop(); - connector_.reset(); - } + if (connector_) { + connector_->onConnectFinished.disconnect(boost::bind(&ProxiedConnection::handleConnectFinished, shared_from_this(), _1)); + connector_->stop(); + connector_.reset(); + } } void ProxiedConnection::connect(const HostAddressPort& server) { - server_ = server; + server_ = server; - connector_ = Connector::create(proxyHost_, proxyPort_, false, resolver_, connectionFactory_, timerFactory_); - connector_->onConnectFinished.connect(boost::bind(&ProxiedConnection::handleConnectFinished, shared_from_this(), _1)); - connector_->start(); + connector_ = Connector::create(proxyHost_, proxyPort_, boost::optional<std::string>(), resolver_, connectionFactory_, timerFactory_); + connector_->onConnectFinished.connect(boost::bind(&ProxiedConnection::handleConnectFinished, shared_from_this(), _1)); + connector_->start(); } void ProxiedConnection::listen() { - assert(false); - connection_->listen(); + assert(false); + connection_->listen(); } void ProxiedConnection::disconnect() { - connected_ = false; - connection_->disconnect(); + cancelConnector(); + connected_ = false; + if (connection_) { + connection_->disconnect(); + } } void ProxiedConnection::handleDisconnected(const boost::optional<Error>& error) { - onDisconnected(error); + onDisconnected(error); } void ProxiedConnection::write(const SafeByteArray& data) { - connection_->write(data); + connection_->write(data); } void ProxiedConnection::handleConnectFinished(Connection::ref connection) { - cancelConnector(); - if (connection) { - connection_ = connection; - connection_->onDataRead.connect(boost::bind(&ProxiedConnection::handleDataRead, shared_from_this(), _1)); - connection_->onDisconnected.connect(boost::bind(&ProxiedConnection::handleDisconnected, shared_from_this(), _1)); - - initializeProxy(); - } - else { - onConnectFinished(true); - } + cancelConnector(); + if (connection) { + connection_ = connection; + connection_->onDataRead.connect(boost::bind(&ProxiedConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.connect(boost::bind(&ProxiedConnection::handleDisconnected, shared_from_this(), _1)); + + initializeProxy(); + } + else { + onConnectFinished(true); + } } -void ProxiedConnection::handleDataRead(boost::shared_ptr<SafeByteArray> data) { - if (!connected_) { - handleProxyInitializeData(data); - } - else { - onDataRead(data); - } +void ProxiedConnection::handleDataRead(std::shared_ptr<SafeByteArray> data) { + if (!connected_) { + handleProxyInitializeData(data); + } + else { + onDataRead(data); + } } HostAddressPort ProxiedConnection::getLocalAddress() const { - return connection_->getLocalAddress(); + return connection_->getLocalAddress(); +} + +HostAddressPort ProxiedConnection::getRemoteAddress() const { + return connection_->getRemoteAddress(); } void ProxiedConnection::setProxyInitializeFinished(bool success) { - connected_ = success; - if (!success) { - disconnect(); - } - onConnectFinished(!success); + connected_ = success; + if (!success) { + disconnect(); + } + onConnectFinished(!success); +} + +void ProxiedConnection::reconnect() { + if (connected_) { + connection_->onDataRead.disconnect(boost::bind(&ProxiedConnection::handleDataRead, shared_from_this(), _1)); + connection_->onDisconnected.disconnect(boost::bind(&ProxiedConnection::handleDisconnected, shared_from_this(), _1)); + connection_->disconnect(); + } + connect(server_); } diff --git a/Swiften/Network/ProxiedConnection.h b/Swiften/Network/ProxiedConnection.h index aa8df38..f79845a 100644 --- a/Swiften/Network/ProxiedConnection.h +++ b/Swiften/Network/ProxiedConnection.h @@ -1,67 +1,64 @@ /* - * Copyright (c) 2012 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/enable_shared_from_this.hpp> +#include <memory> +#include <Swiften/Base/API.h> +#include <Swiften/Base/SafeString.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/Connector.h> #include <Swiften/Network/HostAddressPort.h> -#include <Swiften/Base/SafeString.h> - -namespace boost { - class thread; - namespace system { - class error_code; - } -} namespace Swift { - class ConnectionFactory; + class ConnectionFactory; + + class SWIFTEN_API ProxiedConnection : public Connection, public std::enable_shared_from_this<ProxiedConnection> { + public: + ProxiedConnection(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort); + virtual ~ProxiedConnection(); - class ProxiedConnection : public Connection, public boost::enable_shared_from_this<ProxiedConnection> { - public: - ProxiedConnection(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort); - ~ProxiedConnection(); + virtual void listen(); + virtual void connect(const HostAddressPort& address); + virtual void disconnect(); + virtual void write(const SafeByteArray& data); - virtual void listen(); - virtual void connect(const HostAddressPort& address); - virtual void disconnect(); - virtual void write(const SafeByteArray& data); + virtual HostAddressPort getLocalAddress() const; + virtual HostAddressPort getRemoteAddress() const; - virtual HostAddressPort getLocalAddress() const; + private: + void handleConnectFinished(Connection::ref connection); + void handleDataRead(std::shared_ptr<SafeByteArray> data); + void handleDisconnected(const boost::optional<Error>& error); + void cancelConnector(); - private: - void handleConnectFinished(Connection::ref connection); - void handleDataRead(boost::shared_ptr<SafeByteArray> data); - void handleDisconnected(const boost::optional<Error>& error); - void cancelConnector(); + protected: + void setProxyInitializeFinished(bool success); - protected: - void setProxyInitializeFinished(bool success); + virtual void initializeProxy() = 0; + virtual void handleProxyInitializeData(std::shared_ptr<SafeByteArray> data) = 0; - virtual void initializeProxy() = 0; - virtual void handleProxyInitializeData(boost::shared_ptr<SafeByteArray> data) = 0; + const HostAddressPort& getServer() const { + return server_; + } - const HostAddressPort& getServer() const { - return server_; - } + void reconnect(); - private: - bool connected_; - DomainNameResolver* resolver_; - ConnectionFactory* connectionFactory_; - TimerFactory* timerFactory_; - std::string proxyHost_; - int proxyPort_; - HostAddressPort server_; - Connector::ref connector_; - boost::shared_ptr<Connection> connection_; - }; + private: + bool connected_; + DomainNameResolver* resolver_; + ConnectionFactory* connectionFactory_; + TimerFactory* timerFactory_; + std::string proxyHost_; + unsigned short proxyPort_; + HostAddressPort server_; + Connector::ref connector_; + std::shared_ptr<Connection> connection_; + }; } diff --git a/Swiften/Network/ProxyProvider.cpp b/Swiften/Network/ProxyProvider.cpp index fe235b1..dd07d3a 100644 --- a/Swiften/Network/ProxyProvider.cpp +++ b/Swiften/Network/ProxyProvider.cpp @@ -4,7 +4,13 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include "ProxyProvider.h" +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#include <Swiften/Network/ProxyProvider.h> namespace Swift { diff --git a/Swiften/Network/ProxyProvider.h b/Swiften/Network/ProxyProvider.h index 9a1ccee..bf737c0 100644 --- a/Swiften/Network/ProxyProvider.h +++ b/Swiften/Network/ProxyProvider.h @@ -4,19 +4,27 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once + #include <map> -#include <Swiften/Network/HostAddressPort.h> +#include <Swiften/Base/API.h> #include <Swiften/Base/String.h> +#include <Swiften/Network/HostAddressPort.h> namespace Swift { - class ProxyProvider { - public: - ProxyProvider(); - virtual ~ProxyProvider(); - virtual HostAddressPort getHTTPConnectProxy() const = 0; - virtual HostAddressPort getSOCKS5Proxy() const = 0; - }; + class SWIFTEN_API ProxyProvider { + public: + ProxyProvider(); + virtual ~ProxyProvider(); + virtual HostAddressPort getHTTPConnectProxy() const = 0; + virtual HostAddressPort getSOCKS5Proxy() const = 0; + }; } diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index ea0fb62..9c2a134 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -3,104 +3,109 @@ Import("swiften_env") myenv = swiften_env.Clone() if myenv.get("unbound", False) : - myenv.MergeFlags(myenv.get("UNBOUND_FLAGS", {})) - myenv.MergeFlags(myenv.get("LDNS_FLAGS", {})) + myenv.MergeFlags(myenv.get("UNBOUND_FLAGS", {})) + myenv.MergeFlags(myenv.get("LDNS_FLAGS", {})) sourceList = [ - "ProxiedConnection.cpp", - "HTTPConnectProxiedConnection.cpp", - "HTTPConnectProxiedConnectionFactory.cpp", - "SOCKS5ProxiedConnection.cpp", - "SOCKS5ProxiedConnectionFactory.cpp", - "BoostConnection.cpp", - "BoostConnectionFactory.cpp", - "BoostConnectionServer.cpp", - "BoostConnectionServerFactory.cpp", - "BoostIOServiceThread.cpp", - "BOSHConnection.cpp", - "BOSHConnectionPool.cpp", - "CachingDomainNameResolver.cpp", - "ConnectionFactory.cpp", - "ConnectionServer.cpp", - "ConnectionServerFactory.cpp", - "DummyConnection.cpp", - "FakeConnection.cpp", - "ChainedConnector.cpp", - "Connector.cpp", - "Connection.cpp", - "TimerFactory.cpp", - "DummyTimerFactory.cpp", - "BoostTimerFactory.cpp", - "DomainNameResolver.cpp", - "DomainNameAddressQuery.cpp", - "DomainNameServiceQuery.cpp", - "StaticDomainNameResolver.cpp", - "HostAddress.cpp", - "HostAddressPort.cpp", - "HostNameOrAddress.cpp", - "NetworkFactories.cpp", - "BoostNetworkFactories.cpp", - "NetworkEnvironment.cpp", - "Timer.cpp", - "TLSConnection.cpp", - "TLSConnectionFactory.cpp", - "BoostTimer.cpp", - "ProxyProvider.cpp", - "NullProxyProvider.cpp", - "NATTraverser.cpp", - "NullNATTraverser.cpp", - "NATTraversalGetPublicIPRequest.cpp", - "NATTraversalForwardPortRequest.cpp", - "NATTraversalRemovePortForwardingRequest.cpp", - "NATTraversalInterface.cpp", - ] + "ProxiedConnection.cpp", + "HTTPConnectProxiedConnection.cpp", + "HTTPConnectProxiedConnectionFactory.cpp", + "SOCKS5ProxiedConnection.cpp", + "SOCKS5ProxiedConnectionFactory.cpp", + "BoostConnection.cpp", + "BoostConnectionFactory.cpp", + "BoostConnectionServer.cpp", + "BoostConnectionServerFactory.cpp", + "BoostIOServiceThread.cpp", + "BOSHConnection.cpp", + "BOSHConnectionPool.cpp", + "CachingDomainNameResolver.cpp", + "ConnectionFactory.cpp", + "ConnectionServer.cpp", + "ConnectionServerFactory.cpp", + "DummyConnection.cpp", + "FakeConnection.cpp", + "ChainedConnector.cpp", + "Connector.cpp", + "Connection.cpp", + "TimerFactory.cpp", + "DummyTimerFactory.cpp", + "BoostTimerFactory.cpp", + "DomainNameResolver.cpp", + "DomainNameAddressQuery.cpp", + "DomainNameServiceQuery.cpp", + "StaticDomainNameResolver.cpp", + "HostAddress.cpp", + "HostAddressPort.cpp", + "HostNameOrAddress.cpp", + "NetworkFactories.cpp", + "BoostNetworkFactories.cpp", + "NetworkEnvironment.cpp", + "Timer.cpp", + "TLSConnection.cpp", + "TLSConnectionFactory.cpp", + "BoostTimer.cpp", + "ProxyProvider.cpp", + "NullProxyProvider.cpp", + "NATTraverser.cpp", + "NullNATTraverser.cpp", + "NATTraversalGetPublicIPRequest.cpp", + "NATTraversalForwardPortRequest.cpp", + "NATTraversalRemovePortForwardingRequest.cpp", + "NATTraversalInterface.cpp", + "HTTPTrafficFilter.cpp", + ] if myenv.get("unbound", False) : - myenv.Append(CPPDEFINES = "USE_UNBOUND") - sourceList.append("UnboundDomainNameResolver.cpp") + myenv.Append(CPPDEFINES = "USE_UNBOUND") + sourceList.append("UnboundDomainNameResolver.cpp") else : - sourceList.append("PlatformDomainNameResolver.cpp") - sourceList.append("PlatformDomainNameServiceQuery.cpp") - sourceList.append("PlatformDomainNameAddressQuery.cpp") + sourceList.append("PlatformDomainNameResolver.cpp") + sourceList.append("PlatformDomainNameServiceQuery.cpp") + sourceList.append("PlatformDomainNameAddressQuery.cpp") if myenv["PLATFORM"] == "darwin" and myenv["target"] != "android": - myenv.Append(FRAMEWORKS = ["CoreServices", "SystemConfiguration"]) - sourceList += [ "MacOSXProxyProvider.cpp" ] - sourceList += [ "UnixNetworkEnvironment.cpp" ] + myenv.Append(FRAMEWORKS = ["CoreServices", "SystemConfiguration"]) + sourceList += [ "MacOSXProxyProvider.cpp" ] + sourceList += [ "UnixNetworkEnvironment.cpp" ] elif myenv["PLATFORM"] == "win32" : - sourceList += [ "WindowsProxyProvider.cpp" ] - sourceList += [ "WindowsNetworkEnvironment.cpp" ] + sourceList += [ "WindowsProxyProvider.cpp" ] + sourceList += [ "WindowsNetworkEnvironment.cpp" ] +elif myenv["PLATFORM"] == "sunos" : + sourceList += [ "UnixProxyProvider.cpp" ] + sourceList += [ "SolarisNetworkEnvironment.cpp" ] + sourceList += [ "EnvironmentProxyProvider.cpp" ] else : - sourceList += [ "UnixNetworkEnvironment.cpp" ] - sourceList += [ "UnixProxyProvider.cpp" ] - sourceList += [ "EnvironmentProxyProvider.cpp" ] - if myenv.get("HAVE_GCONF", 0) : - myenv.Append(CPPDEFINES = "HAVE_GCONF") - myenv.MergeFlags(myenv["GCONF_FLAGS"]) - sourceList += [ "GConfProxyProvider.cpp" ] + sourceList += [ "UnixNetworkEnvironment.cpp" ] + sourceList += [ "UnixProxyProvider.cpp" ] + sourceList += [ "EnvironmentProxyProvider.cpp" ] + if myenv.get("HAVE_GCONF", 0) : + myenv.Append(CPPDEFINES = "HAVE_GCONF") + myenv.MergeFlags(myenv["GCONF_FLAGS"]) + sourceList += [ "GConfProxyProvider.cpp" ] objects = myenv.SwiftenObject(sourceList) -if myenv["experimental"] : - # LibNATPMP classes - if myenv.get("HAVE_LIBNATPMP", False) : - natpmp_env = myenv.Clone() - natpmp_env.Append(CPPDEFINES = natpmp_env["LIBNATPMP_FLAGS"].get("INTERNAL_CPPDEFINES", [])) - myenv.Append(CPPDEFINES = ["HAVE_LIBNATPMP"]) - objects += natpmp_env.SwiftenObject([ - "NATPMPInterface.cpp", - ]) +if myenv["experimental_ft"] : + # LibNATPMP classes + if myenv.get("HAVE_LIBNATPMP", False) : + natpmp_env = myenv.Clone() + natpmp_env.Append(CPPDEFINES = natpmp_env["LIBNATPMP_FLAGS"].get("INTERNAL_CPPDEFINES", [])) + myenv.Append(CPPDEFINES = ["HAVE_LIBNATPMP"]) + objects += natpmp_env.SwiftenObject([ + "NATPMPInterface.cpp", + ]) - # LibMINIUPnP classes - if myenv.get("HAVE_LIBMINIUPNPC", False) : - upnp_env = myenv.Clone() - upnp_env.Append(CPPDEFINES = upnp_env["LIBMINIUPNPC_FLAGS"].get("INTERNAL_CPPDEFINES", [])) - myenv.Append(CPPDEFINES = ["HAVE_LIBMINIUPNPC"]) - objects += upnp_env.SwiftenObject([ - "MiniUPnPInterface.cpp", - ]) - objects += myenv.SwiftenObject([ - "PlatformNATTraversalWorker.cpp", - ]) + # LibMINIUPnP classes + if myenv.get("HAVE_LIBMINIUPNPC", False) : + upnp_env = myenv.Clone() + upnp_env.Append(CPPDEFINES = upnp_env["LIBMINIUPNPC_FLAGS"].get("INTERNAL_CPPDEFINES", [])) + myenv.Append(CPPDEFINES = ["HAVE_LIBMINIUPNPC"]) + objects += upnp_env.SwiftenObject([ + "MiniUPnPInterface.cpp", + ]) + objects += myenv.SwiftenObject([ + "PlatformNATTraversalWorker.cpp", + ]) swiften_env.Append(SWIFTEN_OBJECTS = [objects]) diff --git a/Swiften/Network/SOCKS5ProxiedConnection.cpp b/Swiften/Network/SOCKS5ProxiedConnection.cpp index a9243d6..c76b6e6 100644 --- a/Swiften/Network/SOCKS5ProxiedConnection.cpp +++ b/Swiften/Network/SOCKS5ProxiedConnection.cpp @@ -4,112 +4,117 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2014-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include <Swiften/Network/SOCKS5ProxiedConnection.h> -#include <iostream> #include <boost/bind.hpp> -#include <boost/thread.hpp> -#include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Base/ByteArray.h> #include <Swiften/Base/Log.h> #include <Swiften/Base/String.h> -#include <Swiften/Base/ByteArray.h> +#include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Network/HostAddressPort.h> using namespace Swift; SOCKS5ProxiedConnection::SOCKS5ProxiedConnection( - DomainNameResolver* resolver, - ConnectionFactory* connectionFactory, - TimerFactory* timerFactory, - const std::string& proxyHost, - int proxyPort) : - ProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort) { - } + DomainNameResolver* resolver, + ConnectionFactory* connectionFactory, + TimerFactory* timerFactory, + const std::string& proxyHost, + unsigned short proxyPort) : + ProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort), + proxyState_(Initial) { + } void SOCKS5ProxiedConnection::initializeProxy() { - proxyState_ = ProxyAuthenticating; - SafeByteArray socksConnect; - socksConnect.push_back(0x05); // VER = SOCKS5 = 0x05 - socksConnect.push_back(0x01); // Number of authentication methods after this byte. - socksConnect.push_back(0x00); // 0x00 == no authentication - // buffer.push_back(0x01); // 0x01 == GSSAPI - // buffer.push_back(0x02); // 0x02 == Username/Password - // rest see RFC 1928 (http://tools.ietf.org/html/rfc1928) - write(socksConnect); + proxyState_ = ProxyAuthenticating; + SafeByteArray socksConnect; + socksConnect.push_back(0x05); // VER = SOCKS5 = 0x05 + socksConnect.push_back(0x01); // Number of authentication methods after this byte. + socksConnect.push_back(0x00); // 0x00 == no authentication + // buffer.push_back(0x01); // 0x01 == GSSAPI + // buffer.push_back(0x02); // 0x02 == Username/Password + // rest see RFC 1928 (http://tools.ietf.org/html/rfc1928) + write(socksConnect); } -void SOCKS5ProxiedConnection::handleProxyInitializeData(boost::shared_ptr<SafeByteArray> data) { - SafeByteArray socksConnect; - boost::asio::ip::address rawAddress = getServer().getAddress().getRawAddress(); - assert(rawAddress.is_v4() || rawAddress.is_v6()); +void SOCKS5ProxiedConnection::handleProxyInitializeData(std::shared_ptr<SafeByteArray> data) { + SafeByteArray socksConnect; + boost::asio::ip::address rawAddress = getServer().getAddress().getRawAddress(); + assert(rawAddress.is_v4() || rawAddress.is_v6()); + + if (proxyState_ == ProxyAuthenticating) { + SWIFT_LOG(debug) << "ProxyAuthenticating response received, reply with the connect BYTEs"; + unsigned char choosenMethod = static_cast<unsigned char> ((*data)[1]); + if ((*data)[0] == 0x05 && choosenMethod != 0xFF) { + switch(choosenMethod) { // use the correct Method + case 0x00: + try { + proxyState_ = ProxyConnecting; + socksConnect.push_back(0x05); // VER = SOCKS5 = 0x05 + socksConnect.push_back(0x01); // Construct a TCP connection. (CMD) + socksConnect.push_back(0x00); // reserved. + socksConnect.push_back(rawAddress.is_v4() ? 0x01 : 0x04); // IPv4 == 0x01, Hostname == 0x02, IPv6 == 0x04. (ATYP) + size_t size = rawAddress.is_v4() ? rawAddress.to_v4().to_bytes().size() : rawAddress.to_v6().to_bytes().size(); + for (size_t s = 0; s < size; s++) { + unsigned char uc; + if(rawAddress.is_v4()) { + uc = rawAddress.to_v4().to_bytes()[s]; // the address. + } + else { + uc = rawAddress.to_v6().to_bytes()[s]; // the address. + } + socksConnect.push_back(uc); - if (proxyState_ == ProxyAuthenticating) { - SWIFT_LOG(debug) << "ProxyAuthenticating response received, reply with the connect BYTEs" << std::endl; - unsigned char choosenMethod = static_cast<unsigned char> ((*data)[1]); - if ((*data)[0] == 0x05 && choosenMethod != 0xFF) { - switch(choosenMethod) { // use the correct Method - case 0x00: - try { - proxyState_ = ProxyConnecting; - socksConnect.push_back(0x05); // VER = SOCKS5 = 0x05 - socksConnect.push_back(0x01); // Construct a TCP connection. (CMD) - socksConnect.push_back(0x00); // reserved. - socksConnect.push_back(rawAddress.is_v4() ? 0x01 : 0x04); // IPv4 == 0x01, Hostname == 0x02, IPv6 == 0x04. (ATYP) - size_t size = rawAddress.is_v4() ? rawAddress.to_v4().to_bytes().size() : rawAddress.to_v6().to_bytes().size(); - for (size_t s = 0; s < size; s++) { - unsigned char uc; - if(rawAddress.is_v4()) { - uc = rawAddress.to_v4().to_bytes()[s]; // the address. - } - else { - uc = rawAddress.to_v6().to_bytes()[s]; // the address. - } - socksConnect.push_back(uc); - - } - socksConnect.push_back(static_cast<unsigned char> ((getServer().getPort() >> 8) & 0xFF)); // highbyte of the port. - socksConnect.push_back(static_cast<unsigned char> (getServer().getPort() & 0xFF)); // lowbyte of the port. - write(socksConnect); - return; - } - catch(...) { - std::cerr << "exception caught" << std::endl; - } - write(socksConnect); - break; - default: - setProxyInitializeFinished(true); - break; - } - return; - } - setProxyInitializeFinished(false); - } - else if (proxyState_ == ProxyConnecting) { - SWIFT_LOG(debug) << "Connect response received, check if successfully." << std::endl; - SWIFT_LOG(debug) << "Errorbyte: 0x" << std::hex << static_cast<int> ((*data)[1]) << std::dec << std::endl; - /* + } + socksConnect.push_back(static_cast<unsigned char> ((getServer().getPort() >> 8) & 0xFF)); // highbyte of the port. + socksConnect.push_back(static_cast<unsigned char> (getServer().getPort() & 0xFF)); // lowbyte of the port. + write(socksConnect); + return; + } + catch(...) { + SWIFT_LOG(error) << "exception caught"; + } + write(socksConnect); + break; + default: + setProxyInitializeFinished(true); + break; + } + return; + } + setProxyInitializeFinished(false); + } + else if (proxyState_ == ProxyConnecting) { + SWIFT_LOG(debug) << "Connect response received, check if successfully."; + SWIFT_LOG(debug) << "Errorbyte: 0x" << std::hex << static_cast<int> ((*data)[1]) << std::dec; + /* - data.at(1) can be one of the following: - 0x00 succeeded - 0x01 general SOCKS server failure - 0x02 connection not allowed by ruleset - 0x03 Network unreachable - 0x04 Host unreachable - 0x05 Connection refused - 0x06 TTL expired - 0x07 Command not supported (CMD) - 0x08 Address type not supported (ATYP) - 0x09 bis 0xFF unassigned - */ - if ((*data)[0] == 0x05 && (*data)[1] == 0x0) { - SWIFT_LOG(debug) << "Successfully connected the server via the proxy." << std::endl; - setProxyInitializeFinished(true); - } - else { - std::cerr << "SOCKS Proxy returned an error: " << std::hex << (*data)[1] << std::endl; - setProxyInitializeFinished(false); - } - } + data.at(1) can be one of the following: + 0x00 succeeded + 0x01 general SOCKS server failure + 0x02 connection not allowed by ruleset + 0x03 Network unreachable + 0x04 Host unreachable + 0x05 Connection refused + 0x06 TTL expired + 0x07 Command not supported (CMD) + 0x08 Address type not supported (ATYP) + 0x09 bis 0xFF unassigned + */ + if ((*data)[0] == 0x05 && (*data)[1] == 0x0) { + SWIFT_LOG(debug) << "Successfully connected the server via the proxy."; + setProxyInitializeFinished(true); + } + else { + SWIFT_LOG(error) << "SOCKS Proxy returned an error: " << std::hex << (*data)[1]; + setProxyInitializeFinished(false); + } + } } diff --git a/Swiften/Network/SOCKS5ProxiedConnection.h b/Swiften/Network/SOCKS5ProxiedConnection.h index 7906879..515c5b7 100644 --- a/Swiften/Network/SOCKS5ProxiedConnection.h +++ b/Swiften/Network/SOCKS5ProxiedConnection.h @@ -4,33 +4,41 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once +#include <Swiften/Base/API.h> #include <Swiften/Network/ProxiedConnection.h> namespace Swift { - class ConnectionFactory; - class DomainNameResolver; - class TimerFactory; - - class SOCKS5ProxiedConnection : public ProxiedConnection { - public: - typedef boost::shared_ptr<SOCKS5ProxiedConnection> ref; - - static ref create(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort) { - return ref(new SOCKS5ProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort)); - } - - private: - SOCKS5ProxiedConnection(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort); - - virtual void initializeProxy(); - virtual void handleProxyInitializeData(boost::shared_ptr<SafeByteArray> data); - - private: - enum { - ProxyAuthenticating = 0, - ProxyConnecting - } proxyState_; - }; + class ConnectionFactory; + class DomainNameResolver; + class TimerFactory; + + class SWIFTEN_API SOCKS5ProxiedConnection : public ProxiedConnection { + public: + typedef std::shared_ptr<SOCKS5ProxiedConnection> ref; + + static ref create(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort) { + return ref(new SOCKS5ProxiedConnection(resolver, connectionFactory, timerFactory, proxyHost, proxyPort)); + } + + private: + SOCKS5ProxiedConnection(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort); + + virtual void initializeProxy(); + virtual void handleProxyInitializeData(std::shared_ptr<SafeByteArray> data); + + private: + enum { + Initial = 0, + ProxyAuthenticating, + ProxyConnecting + } proxyState_; + }; } diff --git a/Swiften/Network/SOCKS5ProxiedConnectionFactory.cpp b/Swiften/Network/SOCKS5ProxiedConnectionFactory.cpp index af99034..abd7718 100644 --- a/Swiften/Network/SOCKS5ProxiedConnectionFactory.cpp +++ b/Swiften/Network/SOCKS5ProxiedConnectionFactory.cpp @@ -4,17 +4,23 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include <Swiften/Network/SOCKS5ProxiedConnectionFactory.h> #include <Swiften/Network/SOCKS5ProxiedConnection.h> namespace Swift { -SOCKS5ProxiedConnectionFactory::SOCKS5ProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort) : resolver_(resolver), connectionFactory_(connectionFactory), timerFactory_(timerFactory), proxyHost_(proxyHost), proxyPort_(proxyPort) { +SOCKS5ProxiedConnectionFactory::SOCKS5ProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort) : resolver_(resolver), connectionFactory_(connectionFactory), timerFactory_(timerFactory), proxyHost_(proxyHost), proxyPort_(proxyPort) { } -boost::shared_ptr<Connection> SOCKS5ProxiedConnectionFactory::createConnection() { - return SOCKS5ProxiedConnection::create(resolver_, connectionFactory_, timerFactory_, proxyHost_, proxyPort_); +std::shared_ptr<Connection> SOCKS5ProxiedConnectionFactory::createConnection() { + return SOCKS5ProxiedConnection::create(resolver_, connectionFactory_, timerFactory_, proxyHost_, proxyPort_); } } diff --git a/Swiften/Network/SOCKS5ProxiedConnectionFactory.h b/Swiften/Network/SOCKS5ProxiedConnectionFactory.h index 4c5c585..47ae9a3 100644 --- a/Swiften/Network/SOCKS5ProxiedConnectionFactory.h +++ b/Swiften/Network/SOCKS5ProxiedConnectionFactory.h @@ -4,27 +4,34 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2015-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once +#include <Swiften/Base/API.h> #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/HostNameOrAddress.h> namespace Swift { - class DomainNameResolver; - class TimerFactory; + class DomainNameResolver; + class TimerFactory; - class SOCKS5ProxiedConnectionFactory : public ConnectionFactory { - public: - SOCKS5ProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, int proxyPort); + class SWIFTEN_API SOCKS5ProxiedConnectionFactory : public ConnectionFactory { + public: + SOCKS5ProxiedConnectionFactory(DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, const std::string& proxyHost, unsigned short proxyPort); - virtual boost::shared_ptr<Connection> createConnection(); + virtual std::shared_ptr<Connection> createConnection(); - private: - DomainNameResolver* resolver_; - ConnectionFactory* connectionFactory_; - TimerFactory* timerFactory_; - std::string proxyHost_; - int proxyPort_; - }; + private: + DomainNameResolver* resolver_; + ConnectionFactory* connectionFactory_; + TimerFactory* timerFactory_; + std::string proxyHost_; + unsigned short proxyPort_; + }; } diff --git a/Swiften/Network/SolarisNetworkEnvironment.cpp b/Swiften/Network/SolarisNetworkEnvironment.cpp new file mode 100644 index 0000000..db8c740 --- /dev/null +++ b/Swiften/Network/SolarisNetworkEnvironment.cpp @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2011 Tobias Markmann + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +/* + * Copyright (c) 2013-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#include <Swiften/Network/SolarisNetworkEnvironment.h> + +#include <errno.h> +#include <stdlib.h> +#include <string.h> + +#include <map> +#include <string> +#include <vector> + +#include <boost/optional.hpp> +#include <boost/signals2.hpp> + +#include <net/if.h> +#include <sys/socket.h> +#include <sys/sockio.h> +#include <sys/types.h> +#include <unistd.h> + +#include <Swiften/Network/HostAddress.h> +#include <Swiften/Network/NetworkInterface.h> + +/* + * Copyright (c) 2006 WIDE Project. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the project nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE PROJECT AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE PROJECT OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#undef ifa_broadaddr +#undef ifa_dstaddr +struct ifaddrs { + struct ifaddrs *ifa_next; /* Pointer to next struct */ + char *ifa_name; /* Interface name */ + uint64_t ifa_flags; /* Interface flags */ + struct sockaddr *ifa_addr; /* Interface address */ + struct sockaddr *ifa_netmask; /* Interface netmask */ + struct sockaddr *ifa_dstaddr; /* P2P interface destination */ +}; +#define ifa_broadaddr ifa_dstaddr + +static int +get_lifreq(int fd, struct lifreq **ifr_ret) +{ + struct lifnum lifn; + struct lifconf lifc; + struct lifreq *lifrp; + + lifn.lifn_family = AF_UNSPEC; + lifn.lifn_flags = 0; + if (ioctl(fd, SIOCGLIFNUM, &lifn) == -1) + lifn.lifn_count = 16; + else + lifn.lifn_count += 16; + + for (;;) { + lifc.lifc_len = lifn.lifn_count * sizeof (*lifrp); + lifrp = (struct lifreq *) malloc(lifc.lifc_len); + if (lifrp == NULL) + return (-1); + + lifc.lifc_family = AF_UNSPEC; + lifc.lifc_flags = 0; + lifc.lifc_buf = (char *)lifrp; + if (ioctl(fd, SIOCGLIFCONF, &lifc) == -1) { + free(lifrp); + if (errno == EINVAL) { + lifn.lifn_count <<= 1; + continue; + } + (void) close(fd); + return (-1); + } + if (lifc.lifc_len < (lifn.lifn_count - 1) * sizeof (*lifrp)) + break; + free(lifrp); + lifn.lifn_count <<= 1; + } + (void) close(fd); + + *ifr_ret = lifrp; + + return (lifc.lifc_len / sizeof (*lifrp)); +} + +static size_t +nbytes(const struct lifreq *lifrp, int nlif, size_t socklen) +{ + size_t len = 0; + size_t slen; + + while (nlif > 0) { + slen = strlen(lifrp->lifr_name) + 1; + len += sizeof (struct ifaddrs) + ((slen + 3) & ~3); + len += 3 * socklen; + lifrp++; + nlif--; + } + return (len); +} + +static struct sockaddr * +addrcpy(struct sockaddr_storage *addr, char **bufp) +{ + char *buf = *bufp; + size_t len; + + len = addr->ss_family == AF_INET ? sizeof (struct sockaddr_in) : + sizeof (struct sockaddr_in6); + (void) memcpy(buf, addr, len); + *bufp = buf + len; + return ((struct sockaddr *)buf); +} + +static int +populate(struct ifaddrs *ifa, int fd, struct lifreq *lifrp, int nlif, int af, + char **bufp) +{ + char *buf = *bufp; + size_t slen; + + while (nlif > 0) { + ifa->ifa_next = (nlif > 1) ? ifa + 1 : NULL; + (void) strcpy(ifa->ifa_name = buf, lifrp->lifr_name); + slen = strlen(lifrp->lifr_name) + 1; + buf += (slen + 3) & ~3; + if (ioctl(fd, SIOCGLIFFLAGS, lifrp) == -1) + ifa->ifa_flags = 0; + else + ifa->ifa_flags = lifrp->lifr_flags; + if (ioctl(fd, SIOCGLIFADDR, lifrp) == -1) + ifa->ifa_addr = NULL; + else + ifa->ifa_addr = addrcpy(&lifrp->lifr_addr, &buf); + if (ioctl(fd, SIOCGLIFNETMASK, lifrp) == -1) + ifa->ifa_netmask = NULL; + else + ifa->ifa_netmask = addrcpy(&lifrp->lifr_addr, &buf); + if (ifa->ifa_flags & IFF_POINTOPOINT) { + if (ioctl(fd, SIOCGLIFDSTADDR, lifrp) == -1) + ifa->ifa_dstaddr = NULL; + else + ifa->ifa_dstaddr = + addrcpy(&lifrp->lifr_dstaddr, &buf); + } else if (ifa->ifa_flags & IFF_BROADCAST) { + if (ioctl(fd, SIOCGLIFBRDADDR, lifrp) == -1) + ifa->ifa_broadaddr = NULL; + else + ifa->ifa_broadaddr = + addrcpy(&lifrp->lifr_broadaddr, &buf); + } else { + ifa->ifa_dstaddr = NULL; + } + + ifa++; + nlif--; + lifrp++; + } + *bufp = buf; + return (0); +} + +static int +getifaddrs(struct ifaddrs **ifap) +{ + int fd4, fd6; + int nif4, nif6 = 0; + struct lifreq *ifr4 = NULL; + struct lifreq *ifr6 = NULL; + struct ifaddrs *ifa = NULL; + char *buf; + + if ((fd4 = socket(AF_INET, SOCK_DGRAM, 0)) == -1) + return (-1); + if ((fd6 = socket(AF_INET6, SOCK_DGRAM, 0)) == -1 && + errno != EAFNOSUPPORT) { + (void) close(fd4); + return (-1); + } + + if ((nif4 = get_lifreq(fd4, &ifr4)) == -1 || + (fd6 != -1 && (nif6 = get_lifreq(fd6, &ifr6)) == -1)) + goto failure; + + if (nif4 == 0 && nif6 == 0) { + *ifap = NULL; + return (0); + } + + ifa = (struct ifaddrs *) malloc(nbytes(ifr4, nif4, sizeof (struct sockaddr_in)) + + nbytes(ifr6, nif6, sizeof (struct sockaddr_in6))); + if (ifa == NULL) + goto failure; + + buf = (char *)(ifa + nif4 + nif6); + + if (populate(ifa, fd4, ifr4, nif4, AF_INET, &buf) == -1) + goto failure; + if (nif4 > 0 && nif6 > 0) + ifa[nif4 - 1].ifa_next = ifa + nif4; + if (populate(ifa + nif4, fd6, ifr6, nif6, AF_INET6, &buf) == -1) + goto failure; + + return (0); + +failure: + free(ifa); + (void) close(fd4); + if (fd6 != -1) + (void) close(fd6); + free(ifr4); + free(ifr6); + return (-1); +} + +static void +freeifaddrs(struct ifaddrs *ifa) +{ + free(ifa); +} + +/* End WIDE Project code */ + +namespace Swift { + +std::vector<NetworkInterface> SolarisNetworkEnvironment::getNetworkInterfaces() const { + std::map<std::string, NetworkInterface> interfaces; + + ifaddrs* addrs = 0; + int ret = getifaddrs(&addrs); + if (ret != 0) { + return std::vector<NetworkInterface>(); + } + + for (ifaddrs* a = addrs; a != 0; a = a->ifa_next) { + std::string name(a->ifa_name); + boost::optional<HostAddress> address; + if (a->ifa_addr->sa_family == PF_INET) { + sockaddr_in* sa = reinterpret_cast<sockaddr_in*>(a->ifa_addr); + address = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin_addr)), 4); + } + else if (a->ifa_addr->sa_family == PF_INET6) { + sockaddr_in6* sa = reinterpret_cast<sockaddr_in6*>(a->ifa_addr); + address = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin6_addr)), 16); + } + if (address) { + std::map<std::string, NetworkInterface>::iterator i = interfaces.insert(std::make_pair(name, NetworkInterface(name, a->ifa_flags & IFF_LOOPBACK))).first; + i->second.addAddress(*address); + } + } + + freeifaddrs(addrs); + + std::vector<NetworkInterface> result; + for (std::map<std::string,NetworkInterface>::const_iterator i = interfaces.begin(); i != interfaces.end(); ++i) { + result.push_back(i->second); + } + return result; +} + +} diff --git a/Swiften/Network/SolarisNetworkEnvironment.h b/Swiften/Network/SolarisNetworkEnvironment.h new file mode 100644 index 0000000..199fc6f --- /dev/null +++ b/Swiften/Network/SolarisNetworkEnvironment.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2011 Tobias Markmann + * Licensed under the simplified BSD license. + * See Documentation/Licenses/BSD-simplified.txt for more information. + */ + +/* + * Copyright (c) 2013-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#pragma once + +#include <vector> + +#include <boost/signals2.hpp> + +#include <Swiften/Network/NetworkEnvironment.h> +#include <Swiften/Network/NetworkInterface.h> + +namespace Swift { + + class SolarisNetworkEnvironment : public NetworkEnvironment { + public: + std::vector<NetworkInterface> getNetworkInterfaces() const; + }; + +} diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp index ee18ee5..eca6687 100644 --- a/Swiften/Network/StaticDomainNameResolver.cpp +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -1,80 +1,85 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/StaticDomainNameResolver.h> +#include <string> + #include <boost/bind.hpp> #include <boost/lexical_cast.hpp> -#include <Swiften/Network/DomainNameResolveError.h> #include <Swiften/EventLoop/EventOwner.h> -#include <string> +#include <Swiften/Network/DomainNameResolveError.h> using namespace Swift; namespace { - struct ServiceQuery : public DomainNameServiceQuery, public boost::enable_shared_from_this<ServiceQuery> { - ServiceQuery(const std::string& service, Swift::StaticDomainNameResolver* resolver, EventLoop* eventLoop, boost::shared_ptr<EventOwner> owner) : eventLoop(eventLoop), service(service), resolver(resolver), owner(owner) {} - - virtual void run() { - if (!resolver->getIsResponsive()) { - return; - } - std::vector<DomainNameServiceQuery::Result> results; - for(StaticDomainNameResolver::ServicesCollection::const_iterator i = resolver->getServices().begin(); i != resolver->getServices().end(); ++i) { - if (i->first == service) { - results.push_back(i->second); - } - } - eventLoop->postEvent(boost::bind(&ServiceQuery::emitOnResult, shared_from_this(), results), owner); - } - - void emitOnResult(std::vector<DomainNameServiceQuery::Result> results) { - onResult(results); - } - - EventLoop* eventLoop; - std::string service; - StaticDomainNameResolver* resolver; - boost::shared_ptr<EventOwner> owner; - }; - - struct AddressQuery : public DomainNameAddressQuery, public boost::enable_shared_from_this<AddressQuery> { - AddressQuery(const std::string& host, StaticDomainNameResolver* resolver, EventLoop* eventLoop, boost::shared_ptr<EventOwner> owner) : eventLoop(eventLoop), host(host), resolver(resolver), owner(owner) {} - - virtual void run() { - if (!resolver->getIsResponsive()) { - return; - } - StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host); - if (i != resolver->getAddresses().end()) { - eventLoop->postEvent( - boost::bind(&AddressQuery::emitOnResult, shared_from_this(), i->second, boost::optional<DomainNameResolveError>())); - } - else { - eventLoop->postEvent(boost::bind(&AddressQuery::emitOnResult, shared_from_this(), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), owner); - } - } - - void emitOnResult(std::vector<HostAddress> results, boost::optional<DomainNameResolveError> error) { - onResult(results, error); - } - - EventLoop* eventLoop; - std::string host; - StaticDomainNameResolver* resolver; - boost::shared_ptr<EventOwner> owner; - }; + struct ServiceQuery : public DomainNameServiceQuery, public std::enable_shared_from_this<ServiceQuery> { + ServiceQuery(const std::string& service, Swift::StaticDomainNameResolver* resolver, EventLoop* eventLoop, std::shared_ptr<EventOwner> owner) : eventLoop(eventLoop), service(service), resolver(resolver), owner(owner) {} + + virtual void run() { + if (!resolver->getIsResponsive()) { + return; + } + std::vector<DomainNameServiceQuery::Result> results; + for(const auto& i : resolver->getServices()) { + if (i.first == service) { + results.push_back(i.second); + } + } + eventLoop->postEvent(boost::bind(&ServiceQuery::emitOnResult, shared_from_this(), results), owner); + } + + void emitOnResult(std::vector<DomainNameServiceQuery::Result> results) { + onResult(results); + } + + EventLoop* eventLoop; + std::string service; + StaticDomainNameResolver* resolver; + std::shared_ptr<EventOwner> owner; + }; + + struct AddressQuery : public DomainNameAddressQuery, public std::enable_shared_from_this<AddressQuery> { + AddressQuery(const std::string& host, StaticDomainNameResolver* resolver, EventLoop* eventLoop, std::shared_ptr<EventOwner> owner) : eventLoop(eventLoop), host(host), resolver(resolver), owner(owner) {} + + virtual void run() { + if (!resolver->getIsResponsive()) { + return; + } + if (auto address = HostAddress::fromString(host)) { + // IP Literals should resolve to themselves + resolver->addAddress(host, *address); + } + StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host); + if (i != resolver->getAddresses().end()) { + eventLoop->postEvent( + boost::bind(&AddressQuery::emitOnResult, shared_from_this(), i->second, boost::optional<DomainNameResolveError>())); + } + else { + eventLoop->postEvent(boost::bind(&AddressQuery::emitOnResult, shared_from_this(), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), owner); + } + } + + void emitOnResult(std::vector<HostAddress> results, boost::optional<DomainNameResolveError> error) { + onResult(results, error); + } + + EventLoop* eventLoop; + std::string host; + StaticDomainNameResolver* resolver; + std::shared_ptr<EventOwner> owner; + }; } class StaticDomainNameResolverEventOwner : public EventOwner { - public: - ~StaticDomainNameResolverEventOwner() { + public: + ~StaticDomainNameResolverEventOwner() { - } + } }; @@ -84,36 +89,36 @@ StaticDomainNameResolver::StaticDomainNameResolver(EventLoop* eventLoop) : event } StaticDomainNameResolver::~StaticDomainNameResolver() { - eventLoop->removeEventsFromOwner(owner); + eventLoop->removeEventsFromOwner(owner); } void StaticDomainNameResolver::addAddress(const std::string& domain, const HostAddress& address) { - addresses[domain].push_back(address); + addresses[domain].push_back(address); } void StaticDomainNameResolver::addService(const std::string& service, const DomainNameServiceQuery::Result& result) { - services.push_back(std::make_pair(service, result)); + services.push_back(std::make_pair(service, result)); } void StaticDomainNameResolver::addXMPPClientService(const std::string& domain, const HostAddressPort& address) { - static int hostid = 0; - std::string hostname(std::string("host-") + boost::lexical_cast<std::string>(hostid)); - hostid++; + static int hostid = 0; + std::string hostname(std::string("host-") + std::to_string(hostid)); + hostid++; - addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, address.getPort(), 0, 0)); - addAddress(hostname, address.getAddress()); + addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, address.getPort(), 0, 0)); + addAddress(hostname, address.getAddress()); } -void StaticDomainNameResolver::addXMPPClientService(const std::string& domain, const std::string& hostname, int port) { - addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, port, 0, 0)); +void StaticDomainNameResolver::addXMPPClientService(const std::string& domain, const std::string& hostname, unsigned short port) { + addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, port, 0, 0)); } -boost::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const std::string& name) { - return boost::shared_ptr<DomainNameServiceQuery>(new ServiceQuery(name, this, eventLoop, owner)); +std::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { + return std::make_shared<ServiceQuery>(serviceLookupPrefix + domain, this, eventLoop, owner); } -boost::shared_ptr<DomainNameAddressQuery> StaticDomainNameResolver::createAddressQuery(const std::string& name) { - return boost::shared_ptr<DomainNameAddressQuery>(new AddressQuery(name, this, eventLoop, owner)); +std::shared_ptr<DomainNameAddressQuery> StaticDomainNameResolver::createAddressQuery(const std::string& name) { + return std::make_shared<AddressQuery>(name, this, eventLoop, owner); } } diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h index 386179b..2064046 100644 --- a/Swiften/Network/StaticDomainNameResolver.h +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -1,60 +1,61 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <vector> #include <map> +#include <memory> +#include <vector> #include <Swiften/Base/API.h> -#include <Swiften/Network/HostAddress.h> -#include <Swiften/Network/HostAddressPort.h> +#include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/DomainNameResolver.h> #include <Swiften/Network/DomainNameServiceQuery.h> -#include <Swiften/Network/DomainNameAddressQuery.h> -#include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/Network/HostAddress.h> +#include <Swiften/Network/HostAddressPort.h> namespace Swift { - class SWIFTEN_API StaticDomainNameResolver : public DomainNameResolver { - public: - typedef std::map<std::string, std::vector<HostAddress> > AddressesMap; - typedef std::vector< std::pair<std::string, DomainNameServiceQuery::Result> > ServicesCollection; - - public: - StaticDomainNameResolver(EventLoop* eventLoop); - ~StaticDomainNameResolver(); - - void addAddress(const std::string& domain, const HostAddress& address); - void addService(const std::string& service, const DomainNameServiceQuery::Result& result); - void addXMPPClientService(const std::string& domain, const HostAddressPort&); - void addXMPPClientService(const std::string& domain, const std::string& host, int port); - - const AddressesMap& getAddresses() const { - return addresses; - } - - const ServicesCollection& getServices() const { - return services; - } - - bool getIsResponsive() const { - return isResponsive; - } - - void setIsResponsive(bool b) { - isResponsive = b; - } - - virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& name); - virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name); - private: - EventLoop* eventLoop; - bool isResponsive; - AddressesMap addresses; - ServicesCollection services; - boost::shared_ptr<EventOwner> owner; - }; + class SWIFTEN_API StaticDomainNameResolver : public DomainNameResolver { + public: + typedef std::map<std::string, std::vector<HostAddress> > AddressesMap; + typedef std::vector< std::pair<std::string, DomainNameServiceQuery::Result> > ServicesCollection; + + public: + StaticDomainNameResolver(EventLoop* eventLoop); + virtual ~StaticDomainNameResolver(); + + void addAddress(const std::string& domain, const HostAddress& address); + void addService(const std::string& service, const DomainNameServiceQuery::Result& result); + void addXMPPClientService(const std::string& domain, const HostAddressPort&); + void addXMPPClientService(const std::string& domain, const std::string& host, unsigned short port); + + const AddressesMap& getAddresses() const { + return addresses; + } + + const ServicesCollection& getServices() const { + return services; + } + + bool getIsResponsive() const { + return isResponsive; + } + + void setIsResponsive(bool b) { + isResponsive = b; + } + + virtual std::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain); + virtual std::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name); + private: + EventLoop* eventLoop; + bool isResponsive; + AddressesMap addresses; + ServicesCollection services; + std::shared_ptr<EventOwner> owner; + }; } diff --git a/Swiften/Network/TLSConnection.cpp b/Swiften/Network/TLSConnection.cpp index 543ee1e..82bf114 100644 --- a/Swiften/Network/TLSConnection.cpp +++ b/Swiften/Network/TLSConnection.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/TLSConnection.h> @@ -14,78 +14,85 @@ namespace Swift { -TLSConnection::TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory) : connection(connection) { - context = tlsFactory->createTLSContext(); - context->onDataForNetwork.connect(boost::bind(&TLSConnection::handleTLSDataForNetwork, this, _1)); - context->onDataForApplication.connect(boost::bind(&TLSConnection::handleTLSDataForApplication, this, _1)); - context->onConnected.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, false)); - context->onError.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, true)); - - connection->onConnectFinished.connect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); - connection->onDataRead.connect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); - connection->onDataWritten.connect(boost::bind(&TLSConnection::handleRawDataWritten, this)); - connection->onDisconnected.connect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); +TLSConnection::TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory, const TLSOptions& tlsOptions) : connection(connection) { + context = tlsFactory->createTLSContext(tlsOptions); + context->onDataForNetwork.connect(boost::bind(&TLSConnection::handleTLSDataForNetwork, this, _1)); + context->onDataForApplication.connect(boost::bind(&TLSConnection::handleTLSDataForApplication, this, _1)); + context->onConnected.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, false)); + context->onError.connect(boost::bind(&TLSConnection::handleTLSConnectFinished, this, true)); + + connection->onConnectFinished.connect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + connection->onDataRead.connect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); + connection->onDataWritten.connect(boost::bind(&TLSConnection::handleRawDataWritten, this)); + connection->onDisconnected.connect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); } TLSConnection::~TLSConnection() { - connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); - connection->onDataRead.disconnect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); - connection->onDataWritten.disconnect(boost::bind(&TLSConnection::handleRawDataWritten, this)); - connection->onDisconnected.disconnect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); - delete context; + connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + connection->onDataRead.disconnect(boost::bind(&TLSConnection::handleRawDataRead, this, _1)); + connection->onDataWritten.disconnect(boost::bind(&TLSConnection::handleRawDataWritten, this)); + connection->onDisconnected.disconnect(boost::bind(&TLSConnection::handleRawDisconnected, this, _1)); } void TLSConnection::handleTLSConnectFinished(bool error) { - onConnectFinished(error); - if (error) { - disconnect(); - } + onConnectFinished(error); + if (error) { + disconnect(); + } } void TLSConnection::handleTLSDataForNetwork(const SafeByteArray& data) { - connection->write(data); + connection->write(data); } void TLSConnection::handleTLSDataForApplication(const SafeByteArray& data) { - onDataRead(boost::make_shared<SafeByteArray>(data)); + onDataRead(std::make_shared<SafeByteArray>(data)); } void TLSConnection::connect(const HostAddressPort& address) { - connection->connect(address); + connection->connect(address); } void TLSConnection::disconnect() { - connection->disconnect(); + connection->disconnect(); } void TLSConnection::write(const SafeByteArray& data) { - context->handleDataFromApplication(data); + context->handleDataFromApplication(data); } HostAddressPort TLSConnection::getLocalAddress() const { - return connection->getLocalAddress(); + return connection->getLocalAddress(); +} + +HostAddressPort TLSConnection::getRemoteAddress() const { + return connection->getRemoteAddress(); +} + +TLSContext* TLSConnection::getTLSContext() const { + return context.get(); } void TLSConnection::handleRawConnectFinished(bool error) { - connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); - if (error) { - onConnectFinished(true); - } - else { - context->connect(); - } + connection->onConnectFinished.disconnect(boost::bind(&TLSConnection::handleRawConnectFinished, this, _1)); + if (error) { + onConnectFinished(true); + } + else { + context->connect(); + } } void TLSConnection::handleRawDisconnected(const boost::optional<Error>& error) { - onDisconnected(error); + onDisconnected(error); } -void TLSConnection::handleRawDataRead(boost::shared_ptr<SafeByteArray> data) { - context->handleDataFromNetwork(*data); +void TLSConnection::handleRawDataRead(std::shared_ptr<SafeByteArray> data) { + context->handleDataFromNetwork(*data); } void TLSConnection::handleRawDataWritten() { - onDataWritten(); + onDataWritten(); } } diff --git a/Swiften/Network/TLSConnection.h b/Swiften/Network/TLSConnection.h index 60f73ea..1ab1ec6 100644 --- a/Swiften/Network/TLSConnection.h +++ b/Swiften/Network/TLSConnection.h @@ -1,46 +1,52 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> -#include <boost/enable_shared_from_this.hpp> -#include <Swiften/Base/boost_bsignals.h> +#include <memory> +#include <boost/signals2.hpp> + +#include <Swiften/Base/API.h> #include <Swiften/Base/SafeByteArray.h> #include <Swiften/Network/Connection.h> +#include <Swiften/TLS/TLSOptions.h> namespace Swift { - class HostAddressPort; - class TLSContextFactory; - class TLSContext; - - class TLSConnection : public Connection { - public: - - TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory); - virtual ~TLSConnection(); - - virtual void listen() {assert(false);} - virtual void connect(const HostAddressPort& address); - virtual void disconnect(); - virtual void write(const SafeByteArray& data); - - virtual HostAddressPort getLocalAddress() const; - - private: - void handleRawConnectFinished(bool error); - void handleRawDisconnected(const boost::optional<Error>& error); - void handleRawDataRead(boost::shared_ptr<SafeByteArray> data); - void handleRawDataWritten(); - void handleTLSConnectFinished(bool error); - void handleTLSDataForNetwork(const SafeByteArray& data); - void handleTLSDataForApplication(const SafeByteArray& data); - private: - TLSContext* context; - Connection::ref connection; - }; + class HostAddressPort; + class TLSContextFactory; + class TLSContext; + + class SWIFTEN_API TLSConnection : public Connection { + public: + + TLSConnection(Connection::ref connection, TLSContextFactory* tlsFactory, const TLSOptions&); + virtual ~TLSConnection(); + + virtual void listen() {assert(false);} + virtual void connect(const HostAddressPort& address); + virtual void disconnect(); + virtual void write(const SafeByteArray& data); + + virtual HostAddressPort getLocalAddress() const; + virtual HostAddressPort getRemoteAddress() const; + + TLSContext* getTLSContext() const; + + private: + void handleRawConnectFinished(bool error); + void handleRawDisconnected(const boost::optional<Error>& error); + void handleRawDataRead(std::shared_ptr<SafeByteArray> data); + void handleRawDataWritten(); + void handleTLSConnectFinished(bool error); + void handleTLSDataForNetwork(const SafeByteArray& data); + void handleTLSDataForApplication(const SafeByteArray& data); + + private: + std::unique_ptr<TLSContext> context; + Connection::ref connection; + }; } diff --git a/Swiften/Network/TLSConnectionFactory.cpp b/Swiften/Network/TLSConnectionFactory.cpp index 0c21650..b311c7d 100644 --- a/Swiften/Network/TLSConnectionFactory.cpp +++ b/Swiften/Network/TLSConnectionFactory.cpp @@ -1,18 +1,18 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/TLSConnectionFactory.h> -#include <boost/shared_ptr.hpp> +#include <memory> #include <Swiften/Network/TLSConnection.h> namespace Swift { -TLSConnectionFactory::TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory) : contextFactory(contextFactory), connectionFactory(connectionFactory){ +TLSConnectionFactory::TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory, const TLSOptions& o) : contextFactory(contextFactory), connectionFactory(connectionFactory), options_(o) { } @@ -21,8 +21,8 @@ TLSConnectionFactory::~TLSConnectionFactory() { } -boost::shared_ptr<Connection> TLSConnectionFactory::createConnection() { - return boost::make_shared<TLSConnection>(connectionFactory->createConnection(), contextFactory); +std::shared_ptr<Connection> TLSConnectionFactory::createConnection() { + return std::make_shared<TLSConnection>(connectionFactory->createConnection(), contextFactory, options_); } } diff --git a/Swiften/Network/TLSConnectionFactory.h b/Swiften/Network/TLSConnectionFactory.h index 32757a1..148e345 100644 --- a/Swiften/Network/TLSConnectionFactory.h +++ b/Swiften/Network/TLSConnectionFactory.h @@ -1,27 +1,30 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> +#include <Swiften/Base/API.h> #include <Swiften/Network/ConnectionFactory.h> #include <Swiften/TLS/TLSContextFactory.h> +#include <Swiften/TLS/TLSOptions.h> namespace Swift { - class Connection; + class Connection; - class TLSConnectionFactory : public ConnectionFactory { - public: - TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory); - virtual ~TLSConnectionFactory(); + class SWIFTEN_API TLSConnectionFactory : public ConnectionFactory { + public: + TLSConnectionFactory(TLSContextFactory* contextFactory, ConnectionFactory* connectionFactory, const TLSOptions&); + virtual ~TLSConnectionFactory(); - virtual boost::shared_ptr<Connection> createConnection(); - private: - TLSContextFactory* contextFactory; - ConnectionFactory* connectionFactory; - }; + virtual std::shared_ptr<Connection> createConnection(); + private: + TLSContextFactory* contextFactory; + ConnectionFactory* connectionFactory; + TLSOptions options_; + }; } diff --git a/Swiften/Network/Timer.cpp b/Swiften/Network/Timer.cpp index 3efbd3b..f1d16bb 100644 --- a/Swiften/Network/Timer.cpp +++ b/Swiften/Network/Timer.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/Timer.h> diff --git a/Swiften/Network/Timer.h b/Swiften/Network/Timer.h index d08cf3c..977ed89 100644 --- a/Swiften/Network/Timer.h +++ b/Swiften/Network/Timer.h @@ -1,41 +1,42 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <Swiften/Base/boost_bsignals.h> +#include <boost/signals2.hpp> + #include <Swiften/Base/API.h> namespace Swift { - /** - * A class for triggering an event after a given period. - */ - class SWIFTEN_API Timer { - public: - typedef boost::shared_ptr<Timer> ref; - - virtual ~Timer(); - - /** - * Starts the timer. - * - * After the given period, onTick() will be called. - */ - virtual void start() = 0; - - /** - * Cancels the timer. - * - * If the timer was started, onTick() will no longer be called. - */ - virtual void stop() = 0; - - /** - * Emitted when the timer expires. - */ - boost::signal<void ()> onTick; - }; + /** + * A class for triggering an event after a given period. + */ + class SWIFTEN_API Timer { + public: + typedef std::shared_ptr<Timer> ref; + + virtual ~Timer(); + + /** + * Starts the timer. + * + * After the given period, onTick() will be called. + */ + virtual void start() = 0; + + /** + * Cancels the timer. + * + * If the timer was started, onTick() will no longer be called. + */ + virtual void stop() = 0; + + /** + * Emitted when the timer expires. + */ + boost::signals2::signal<void ()> onTick; + }; } diff --git a/Swiften/Network/TimerFactory.cpp b/Swiften/Network/TimerFactory.cpp index 3fb807c..86b8b2b 100644 --- a/Swiften/Network/TimerFactory.cpp +++ b/Swiften/Network/TimerFactory.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <Swiften/Network/TimerFactory.h> diff --git a/Swiften/Network/TimerFactory.h b/Swiften/Network/TimerFactory.h index 62850bc..ebb1b6e 100644 --- a/Swiften/Network/TimerFactory.h +++ b/Swiften/Network/TimerFactory.h @@ -1,21 +1,21 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #pragma once -#include <boost/shared_ptr.hpp> +#include <memory> #include <Swiften/Base/API.h> #include <Swiften/Network/Timer.h> namespace Swift { - class SWIFTEN_API TimerFactory { - public: - virtual ~TimerFactory(); + class SWIFTEN_API TimerFactory { + public: + virtual ~TimerFactory(); - virtual Timer::ref createTimer(int milliseconds) = 0; - }; + virtual Timer::ref createTimer(int milliseconds) = 0; + }; } diff --git a/Swiften/Network/UnboundDomainNameResolver.cpp b/Swiften/Network/UnboundDomainNameResolver.cpp index d986385..21bc697 100755..100644 --- a/Swiften/Network/UnboundDomainNameResolver.cpp +++ b/Swiften/Network/UnboundDomainNameResolver.cpp @@ -4,238 +4,249 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include "UnboundDomainNameResolver.h" +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ +#include <Swiften/Network/UnboundDomainNameResolver.h> + +#include <memory> #include <vector> #include <boost/bind.hpp> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/enable_shared_from_this.hpp> + +#include <arpa/inet.h> +#include <ldns/ldns.h> +#include <unbound.h> +#include <unistd.h> #include <Swiften/Base/Log.h> #include <Swiften/EventLoop/EventLoop.h> +#include <Swiften/IDN/IDNConverter.h> #include <Swiften/Network/DomainNameAddressQuery.h> #include <Swiften/Network/DomainNameResolveError.h> #include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/Network/HostAddress.h> #include <Swiften/Network/TimerFactory.h> -#include <arpa/inet.h> -#include <unbound.h> -#include <ldns/ldns.h> -#include <unistd.h> - namespace Swift { class UnboundQuery { - public: - UnboundQuery(UnboundDomainNameResolver* resolver, ub_ctx* context) : resolver(resolver), ubContext(context) {} - virtual ~UnboundQuery() {} - virtual void handleResult(int err, ub_result* result) = 0; - protected: - UnboundDomainNameResolver* resolver; - ub_ctx* ubContext; + public: + UnboundQuery(UnboundDomainNameResolver* resolver, ub_ctx* context) : resolver(resolver), ubContext(context) {} + virtual ~UnboundQuery() {} + virtual void handleResult(int err, ub_result* result) = 0; + protected: + UnboundDomainNameResolver* resolver; + ub_ctx* ubContext; }; struct UnboundWrapperHelper { - UnboundWrapperHelper(UnboundDomainNameResolver* resolver, boost::shared_ptr<UnboundQuery> query) : resolver(resolver), query(query) {} - UnboundDomainNameResolver* resolver; - boost::shared_ptr<UnboundQuery> query; + UnboundWrapperHelper(UnboundDomainNameResolver* resolver, std::shared_ptr<UnboundQuery> query) : resolver(resolver), query(query) {} + UnboundDomainNameResolver* resolver; + std::shared_ptr<UnboundQuery> query; }; -class UnboundDomainNameServiceQuery : public DomainNameServiceQuery, public UnboundQuery, public boost::enable_shared_from_this<UnboundDomainNameServiceQuery> { - public: - UnboundDomainNameServiceQuery(UnboundDomainNameResolver* resolver, ub_ctx* context, std::string name) : UnboundQuery(resolver, context), name(name) { - } - - virtual ~UnboundDomainNameServiceQuery() { } - - virtual void run() { - int retval; - UnboundWrapperHelper* helper = new UnboundWrapperHelper(resolver, shared_from_this()); - - retval = ub_resolve_async(ubContext, const_cast<char*>(name.c_str()), LDNS_RR_TYPE_SRV, - 1 /* CLASS IN (internet) */, - helper, UnboundDomainNameResolver::unbound_callback_wrapper, NULL); - if(retval != 0) { - SWIFT_LOG(debug) << "resolve error: " << ub_strerror(retval) << std::endl; - delete helper; - } - } - - void handleResult(int err, struct ub_result* result) { - std::vector<DomainNameServiceQuery::Result> serviceRecords; - - if(err != 0) { - SWIFT_LOG(debug) << "resolve error: " << ub_strerror(err) << std::endl; - } else { - if(result->havedata) { - ldns_pkt* replyPacket = 0; - ldns_buffer* buffer = ldns_buffer_new(1024); - if (buffer && ldns_wire2pkt(&replyPacket, static_cast<const uint8_t*>(result->answer_packet), result->answer_len) == LDNS_STATUS_OK) { - ldns_rr_list* rrList = ldns_pkt_answer(replyPacket); - for (size_t n = 0; n < ldns_rr_list_rr_count(rrList); n++) { - ldns_rr* rr = ldns_rr_list_rr(rrList, n); - if ((ldns_rr_get_type(rr) != LDNS_RR_TYPE_SRV) || - (ldns_rr_get_class(rr) != LDNS_RR_CLASS_IN) || - (ldns_rr_rd_count(rr) != 4)) { - continue; - } - - DomainNameServiceQuery::Result serviceRecord; - serviceRecord.priority = ldns_rdf2native_int16(ldns_rr_rdf(rr, 0)); - serviceRecord.weight = ldns_rdf2native_int16(ldns_rr_rdf(rr, 1)); - serviceRecord.port = ldns_rdf2native_int16(ldns_rr_rdf(rr, 2)); - - ldns_buffer_rewind(buffer); - if ((ldns_rdf2buffer_str_dname(buffer, ldns_rr_rdf(rr, 3)) != LDNS_STATUS_OK) || - (ldns_buffer_position(buffer) < 2) || - !ldns_buffer_reserve(buffer, 1)) { - // either name invalid, empty or buffer to small - continue; - } - char terminator = 0; - ldns_buffer_write(buffer, &terminator, sizeof(terminator)); - - serviceRecord.hostname = std::string(reinterpret_cast<char*>(ldns_buffer_at(buffer, 0))); - serviceRecords.push_back(serviceRecord); - SWIFT_LOG(debug) << "hostname " << serviceRecord.hostname << " added" << std::endl; - } - } - if (replyPacket) ldns_pkt_free(replyPacket); - if (buffer) ldns_buffer_free(buffer); - } - } - - ub_resolve_free(result); - onResult(serviceRecords); - } - - private: - std::string name; +class UnboundDomainNameServiceQuery : public DomainNameServiceQuery, public UnboundQuery, public std::enable_shared_from_this<UnboundDomainNameServiceQuery> { + public: + UnboundDomainNameServiceQuery(UnboundDomainNameResolver* resolver, ub_ctx* context, std::string name) : UnboundQuery(resolver, context), name(name) { + } + + virtual ~UnboundDomainNameServiceQuery() { } + + virtual void run() { + int retval; + UnboundWrapperHelper* helper = new UnboundWrapperHelper(resolver, shared_from_this()); + + retval = ub_resolve_async(ubContext, const_cast<char*>(name.c_str()), LDNS_RR_TYPE_SRV, + 1 /* CLASS IN (internet) */, + helper, UnboundDomainNameResolver::unbound_callback_wrapper, NULL); + if(retval != 0) { + SWIFT_LOG(debug) << "resolve error: " << ub_strerror(retval); + delete helper; + } + } + + void handleResult(int err, struct ub_result* result) { + std::vector<DomainNameServiceQuery::Result> serviceRecords; + + if(err != 0) { + SWIFT_LOG(debug) << "resolve error: " << ub_strerror(err); + } else { + if(result->havedata) { + ldns_pkt* replyPacket = 0; + ldns_buffer* buffer = ldns_buffer_new(1024); + if (buffer && ldns_wire2pkt(&replyPacket, static_cast<const uint8_t*>(result->answer_packet), result->answer_len) == LDNS_STATUS_OK) { + ldns_rr_list* rrList = ldns_pkt_answer(replyPacket); + for (size_t n = 0; n < ldns_rr_list_rr_count(rrList); n++) { + ldns_rr* rr = ldns_rr_list_rr(rrList, n); + if ((ldns_rr_get_type(rr) != LDNS_RR_TYPE_SRV) || + (ldns_rr_get_class(rr) != LDNS_RR_CLASS_IN) || + (ldns_rr_rd_count(rr) != 4)) { + continue; + } + + DomainNameServiceQuery::Result serviceRecord; + serviceRecord.priority = ldns_rdf2native_int16(ldns_rr_rdf(rr, 0)); + serviceRecord.weight = ldns_rdf2native_int16(ldns_rr_rdf(rr, 1)); + serviceRecord.port = ldns_rdf2native_int16(ldns_rr_rdf(rr, 2)); + + ldns_buffer_rewind(buffer); + if ((ldns_rdf2buffer_str_dname(buffer, ldns_rr_rdf(rr, 3)) != LDNS_STATUS_OK) || + (ldns_buffer_position(buffer) < 2) || + !ldns_buffer_reserve(buffer, 1)) { + // either name invalid, empty or buffer to small + continue; + } + char terminator = 0; + ldns_buffer_write(buffer, &terminator, sizeof(terminator)); + + serviceRecord.hostname = std::string(reinterpret_cast<char*>(ldns_buffer_at(buffer, 0))); + serviceRecords.push_back(serviceRecord); + SWIFT_LOG(debug) << "hostname " << serviceRecord.hostname << " added"; + } + } + if (replyPacket) ldns_pkt_free(replyPacket); + if (buffer) ldns_buffer_free(buffer); + } + } + + ub_resolve_free(result); + onResult(serviceRecords); + } + + private: + std::string name; }; -class UnboundDomainNameAddressQuery : public DomainNameAddressQuery, public UnboundQuery, public boost::enable_shared_from_this<UnboundDomainNameAddressQuery> { - public: - UnboundDomainNameAddressQuery(UnboundDomainNameResolver* resolver, ub_ctx* context, std::string name) : UnboundQuery(resolver, context), name(name) { - } - - virtual ~UnboundDomainNameAddressQuery() { } - - virtual void run() { - int retval; - UnboundWrapperHelper* helper = new UnboundWrapperHelper(resolver, shared_from_this()); - - //FIXME: support AAAA queries in some way - retval = ub_resolve_async(ubContext, const_cast<char*>(name.c_str()), LDNS_RR_TYPE_A, - 1 /* CLASS IN (internet) */, - helper, UnboundDomainNameResolver::unbound_callback_wrapper, NULL); - if(retval != 0) { - SWIFT_LOG(debug) << "resolve error: " << ub_strerror(retval) << std::endl; - delete helper; - } - } - - void handleResult(int err, struct ub_result* result) { - std::vector<HostAddress> addresses; - boost::optional<DomainNameResolveError> error; - SWIFT_LOG(debug) << "Result for: " << name << std::endl; - - if(err != 0) { - SWIFT_LOG(debug) << "resolve error: " << ub_strerror(err) << std::endl; - error = DomainNameResolveError(); - } else { - if(result->havedata) { - for(int i=0; result->data[i]; i++) { - char address[100]; - const char* addressStr = 0; - if ((addressStr = inet_ntop(AF_INET, result->data[i], address, 100))) { - SWIFT_LOG(debug) << "IPv4 address: " << addressStr << std::endl; - addresses.push_back(HostAddress(std::string(addressStr))); - } else if ((addressStr = inet_ntop(AF_INET6, result->data[i], address, 100))) { - SWIFT_LOG(debug) << "IPv6 address: " << addressStr << std::endl; - addresses.push_back(HostAddress(std::string(addressStr))); - } else { - SWIFT_LOG(debug) << "inet_ntop() failed" << std::endl; - error = DomainNameResolveError(); - } - } - } else { - error = DomainNameResolveError(); - } - } - - ub_resolve_free(result); - onResult(addresses, error); - } - - private: - std::string name; +class UnboundDomainNameAddressQuery : public DomainNameAddressQuery, public UnboundQuery, public std::enable_shared_from_this<UnboundDomainNameAddressQuery> { + public: + UnboundDomainNameAddressQuery(UnboundDomainNameResolver* resolver, ub_ctx* context, std::string name) : UnboundQuery(resolver, context), name(name) { + } + + virtual ~UnboundDomainNameAddressQuery() { } + + virtual void run() { + int retval; + UnboundWrapperHelper* helper = new UnboundWrapperHelper(resolver, shared_from_this()); + + //FIXME: support AAAA queries in some way + retval = ub_resolve_async(ubContext, const_cast<char*>(name.c_str()), LDNS_RR_TYPE_A, + 1 /* CLASS IN (internet) */, + helper, UnboundDomainNameResolver::unbound_callback_wrapper, NULL); + if(retval != 0) { + SWIFT_LOG(debug) << "resolve error: " << ub_strerror(retval); + delete helper; + } + } + + void handleResult(int err, struct ub_result* result) { + std::vector<HostAddress> addresses; + boost::optional<DomainNameResolveError> error; + SWIFT_LOG(debug) << "Result for: " << name; + + if(err != 0) { + SWIFT_LOG(debug) << "resolve error: " << ub_strerror(err); + error = DomainNameResolveError(); + } else { + if(result->havedata) { + for(int i=0; result->data[i]; i++) { + char address[100]; + const char* addressStr = 0; + if ((addressStr = inet_ntop(AF_INET, result->data[i], address, 100))) { + SWIFT_LOG(debug) << "IPv4 address: " << addressStr; + addresses.push_back(HostAddress(std::string(addressStr))); + } else if ((addressStr = inet_ntop(AF_INET6, result->data[i], address, 100))) { + SWIFT_LOG(debug) << "IPv6 address: " << addressStr; + addresses.push_back(HostAddress(std::string(addressStr))); + } else { + SWIFT_LOG(debug) << "inet_ntop() failed"; + error = DomainNameResolveError(); + } + } + } else { + error = DomainNameResolveError(); + } + } + + ub_resolve_free(result); + onResult(addresses, error); + } + + private: + std::string name; }; -UnboundDomainNameResolver::UnboundDomainNameResolver(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : ioService(ioService), ubDescriptior(*ioService), eventLoop(eventLoop) { - ubContext = ub_ctx_create(); - if(!ubContext) { - SWIFT_LOG(debug) << "could not create unbound context" << std::endl; - } - eventOwner = boost::make_shared<EventOwner>(); +UnboundDomainNameResolver::UnboundDomainNameResolver(IDNConverter* idnConverter, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop) : idnConverter(idnConverter), ioService(ioService), ubDescriptior(*ioService), eventLoop(eventLoop) { + ubContext = ub_ctx_create(); + if(!ubContext) { + SWIFT_LOG(debug) << "could not create unbound context"; + } + eventOwner = std::make_shared<EventOwner>(); - ub_ctx_async(ubContext, true); + ub_ctx_async(ubContext, true); - int ret; + int ret; - /* read /etc/resolv.conf for DNS proxy settings (from DHCP) */ - if( (ret=ub_ctx_resolvconf(ubContext, const_cast<char*>("/etc/resolv.conf"))) != 0) { - SWIFT_LOG(debug) << "error reading resolv.conf: " << ub_strerror(ret) << ". errno says: " << strerror(errno) << std::endl; - } - /* read /etc/hosts for locally supplied host addresses */ - if( (ret=ub_ctx_hosts(ubContext, const_cast<char*>("/etc/hosts"))) != 0) { - SWIFT_LOG(debug) << "error reading hosts: " << ub_strerror(ret) << ". errno says: " << strerror(errno) << std::endl; - } + /* read /etc/resolv.conf for DNS proxy settings (from DHCP) */ + if( (ret=ub_ctx_resolvconf(ubContext, const_cast<char*>("/etc/resolv.conf"))) != 0) { + SWIFT_LOG(error) << "error reading resolv.conf: " << ub_strerror(ret) << ". errno says: " << strerror(errno); + } + /* read /etc/hosts for locally supplied host addresses */ + if( (ret=ub_ctx_hosts(ubContext, const_cast<char*>("/etc/hosts"))) != 0) { + SWIFT_LOG(error) << "error reading hosts: " << ub_strerror(ret) << ". errno says: " << strerror(errno); + } - ubDescriptior.assign(ub_fd(ubContext)); + ubDescriptior.assign(ub_fd(ubContext)); - ubDescriptior.async_read_some(boost::asio::null_buffers(), boost::bind(&UnboundDomainNameResolver::handleUBSocketReadable, this, boost::asio::placeholders::error)); + ubDescriptior.async_read_some(boost::asio::null_buffers(), boost::bind(&UnboundDomainNameResolver::handleUBSocketReadable, this, boost::asio::placeholders::error)); } UnboundDomainNameResolver::~UnboundDomainNameResolver() { - eventLoop->removeEventsFromOwner(eventOwner); - if (ubContext) { - ub_ctx_delete(ubContext); - } + eventLoop->removeEventsFromOwner(eventOwner); + if (ubContext) { + ub_ctx_delete(ubContext); + } } -void UnboundDomainNameResolver::unbound_callback(boost::shared_ptr<UnboundQuery> query, int err, ub_result* result) { - query->handleResult(err, result); +void UnboundDomainNameResolver::unbound_callback(std::shared_ptr<UnboundQuery> query, int err, ub_result* result) { + query->handleResult(err, result); } void UnboundDomainNameResolver::unbound_callback_wrapper(void* data, int err, ub_result* result) { - UnboundWrapperHelper* helper = static_cast<UnboundWrapperHelper*>(data); - UnboundDomainNameResolver* resolver = helper->resolver; - resolver->unbound_callback(helper->query, err, result); - delete helper; + UnboundWrapperHelper* helper = static_cast<UnboundWrapperHelper*>(data); + UnboundDomainNameResolver* resolver = helper->resolver; + resolver->unbound_callback(helper->query, err, result); + delete helper; } void UnboundDomainNameResolver::handleUBSocketReadable(boost::system::error_code) { - eventLoop->postEvent(boost::bind(&UnboundDomainNameResolver::processData, this), eventOwner); - ubDescriptior.async_read_some(boost::asio::null_buffers(), boost::bind(&UnboundDomainNameResolver::handleUBSocketReadable, this, boost::asio::placeholders::error)); + eventLoop->postEvent(boost::bind(&UnboundDomainNameResolver::processData, this), eventOwner); + ubDescriptior.async_read_some(boost::asio::null_buffers(), boost::bind(&UnboundDomainNameResolver::handleUBSocketReadable, this, boost::asio::placeholders::error)); } void UnboundDomainNameResolver::processData() { - if (ub_poll(ubContext)) { - int ret = ub_process(ubContext); - if(ret != 0) { - SWIFT_LOG(debug) << "resolve error: " << ub_strerror(ret) << std::endl; - } - } + if (ub_poll(ubContext)) { + int ret = ub_process(ubContext); + if(ret != 0) { + SWIFT_LOG(debug) << "resolve error: " << ub_strerror(ret); + } + } } -boost::shared_ptr<DomainNameServiceQuery> UnboundDomainNameResolver::createServiceQuery(const std::string& name) { - return boost::make_shared<UnboundDomainNameServiceQuery>(this, ubContext, name); +std::shared_ptr<DomainNameServiceQuery> UnboundDomainNameResolver::createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain) { + boost::optional<std::string> encodedDomain = idnConverter->getIDNAEncoded(domain); + std::string result; + if (encodedDomain) { + result = serviceLookupPrefix + *encodedDomain; + } + return std::make_shared<UnboundDomainNameServiceQuery>(this, ubContext, result); } -boost::shared_ptr<DomainNameAddressQuery> UnboundDomainNameResolver::createAddressQuery(const std::string& name) { - return boost::make_shared<UnboundDomainNameAddressQuery>(this, ubContext, name); +std::shared_ptr<DomainNameAddressQuery> UnboundDomainNameResolver::createAddressQuery(const std::string& name) { + return std::make_shared<UnboundDomainNameAddressQuery>(this, ubContext, idnConverter->getIDNAEncoded(name).get_value_or("")); } } diff --git a/Swiften/Network/UnboundDomainNameResolver.h b/Swiften/Network/UnboundDomainNameResolver.h index 0db8a66..988a415 100755..100644 --- a/Swiften/Network/UnboundDomainNameResolver.h +++ b/Swiften/Network/UnboundDomainNameResolver.h @@ -4,48 +4,55 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once -#include <Swiften/Network/DomainNameResolver.h> -#include <Swiften/Network/Timer.h> -#include <Swiften/EventLoop/EventOwner.h> +#include <memory> -#include <boost/shared_ptr.hpp> -#include <boost/enable_shared_from_this.hpp> #include <boost/asio.hpp> +#include <Swiften/EventLoop/EventOwner.h> +#include <Swiften/Network/DomainNameResolver.h> +#include <Swiften/Network/Timer.h> + struct ub_ctx; struct ub_result; namespace Swift { - class EventLoop; - class TimerFactory; + class EventLoop; + class IDNConverter; - class UnboundDomainNameResolver; - class UnboundQuery; + class UnboundDomainNameResolver; + class UnboundQuery; - class UnboundDomainNameResolver : public DomainNameResolver, public EventOwner, public boost::enable_shared_from_this<UnboundDomainNameResolver> { - public: - UnboundDomainNameResolver(boost::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); - virtual ~UnboundDomainNameResolver(); + class UnboundDomainNameResolver : public DomainNameResolver, public EventOwner, public std::enable_shared_from_this<UnboundDomainNameResolver> { + public: + UnboundDomainNameResolver(IDNConverter* idnConverter, std::shared_ptr<boost::asio::io_service> ioService, EventLoop* eventLoop); + virtual ~UnboundDomainNameResolver(); - virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& name); - virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name); + virtual std::shared_ptr<DomainNameServiceQuery> createServiceQuery(const std::string& serviceLookupPrefix, const std::string& domain); + virtual std::shared_ptr<DomainNameAddressQuery> createAddressQuery(const std::string& name); - static void unbound_callback_wrapper(void* data, int err, ub_result* result); + static void unbound_callback_wrapper(void* data, int err, ub_result* result); - private: - void unbound_callback(boost::shared_ptr<UnboundQuery> query, int err, ub_result* result); + private: + void unbound_callback(std::shared_ptr<UnboundQuery> query, int err, ub_result* result); - void handleUBSocketReadable(boost::system::error_code); - void processData(); + void handleUBSocketReadable(boost::system::error_code); + void processData(); - private: - boost::shared_ptr<EventOwner> eventOwner; - boost::shared_ptr<boost::asio::io_service> ioService; - boost::asio::posix::stream_descriptor ubDescriptior; - EventLoop* eventLoop; - ub_ctx* ubContext; - }; + private: + IDNConverter* idnConverter; + std::shared_ptr<EventOwner> eventOwner; + std::shared_ptr<boost::asio::io_service> ioService; + boost::asio::posix::stream_descriptor ubDescriptior; + EventLoop* eventLoop; + ub_ctx* ubContext; + }; } diff --git a/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp b/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp index 8a63fcb..4aeaf24 100644 --- a/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp +++ b/Swiften/Network/UnitTest/BOSHConnectionPoolTest.cpp @@ -1,469 +1,468 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ +#include <memory> + +#include <boost/bind.hpp> +#include <boost/lexical_cast.hpp> +#include <boost/optional.hpp> + #include <QA/Checker/IO.h> #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> -#include <boost/optional.hpp> -#include <boost/bind.hpp> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/shared_ptr.hpp> -#include <boost/lexical_cast.hpp> - #include <Swiften/Base/Algorithm.h> -#include <Swiften/Network/Connection.h> -#include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/EventLoop/DummyEventLoop.h> #include <Swiften/Network/BOSHConnection.h> #include <Swiften/Network/BOSHConnectionPool.h> +#include <Swiften/Network/Connection.h> +#include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/StaticDomainNameResolver.h> -#include <Swiften/Network/DummyTimerFactory.h> -#include <Swiften/EventLoop/DummyEventLoop.h> #include <Swiften/Parser/PlatformXMLParserFactory.h> - - +#include <Swiften/TLS/TLSOptions.h> using namespace Swift; -typedef boost::shared_ptr<BOSHConnectionPool> PoolRef; +typedef std::shared_ptr<BOSHConnectionPool> PoolRef; class BOSHConnectionPoolTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(BOSHConnectionPoolTest); - CPPUNIT_TEST(testConnectionCount_OneWrite); - CPPUNIT_TEST(testConnectionCount_TwoWrites); - CPPUNIT_TEST(testConnectionCount_ThreeWrites); - CPPUNIT_TEST(testConnectionCount_ThreeWrites_ManualConnect); - CPPUNIT_TEST(testConnectionCount_ThreeWritesTwoReads); - CPPUNIT_TEST(testSession); - CPPUNIT_TEST(testWrite_Empty); - CPPUNIT_TEST_SUITE_END(); - - public: - void setUp() { - to = "wonderland.lit"; - path = "/http-bind"; - port = "5280"; - sid = "MyShinySID"; - initial = "<body wait='60' " - "inactivity='30' " - "polling='5' " - "requests='2' " - "hold='1' " - "maxpause='120' " - "sid='" + sid + "' " - "ver='1.6' " - "from='wonderland.lit' " - "xmlns='http://jabber.org/protocol/httpbind'/>"; - eventLoop = new DummyEventLoop(); - connectionFactory = new MockConnectionFactory(eventLoop); - boshURL = URL("http", to, 5280, path); - sessionTerminated = 0; - sessionStarted = 0; - initialRID = 2349876; - xmppDataRead.clear(); - boshDataRead.clear(); - boshDataWritten.clear(); - resolver = new StaticDomainNameResolver(eventLoop); - resolver->addAddress(to, HostAddress("127.0.0.1")); - timerFactory = new DummyTimerFactory(); - } - - void tearDown() { - eventLoop->processEvents(); - delete connectionFactory; - delete resolver; - delete timerFactory; - delete eventLoop; - } - - void testConnectionCount_OneWrite() { - PoolRef testling = createTestling(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(0, sessionStarted); - readResponse(initial, connectionFactory->connections[0]); - CPPUNIT_ASSERT_EQUAL(1, sessionStarted); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - testling->write(createSafeByteArray("<blah/>")); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - CPPUNIT_ASSERT_EQUAL(1, sessionStarted); - } - - void testConnectionCount_TwoWrites() { - PoolRef testling = createTestling(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - eventLoop->processEvents(); - readResponse(initial, connectionFactory->connections[0]); - eventLoop->processEvents(); - testling->write(createSafeByteArray("<blah/>")); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - testling->write(createSafeByteArray("<bleh/>")); - eventLoop->processEvents(); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); - } - - void testConnectionCount_ThreeWrites() { - PoolRef testling = createTestling(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - eventLoop->processEvents(); - readResponse(initial, connectionFactory->connections[0]); - testling->restartStream(); - readResponse("<body/>", connectionFactory->connections[0]); - testling->restartStream(); - readResponse("<body/>", connectionFactory->connections[0]); - testling->write(createSafeByteArray("<blah/>")); - testling->write(createSafeByteArray("<bleh/>")); - testling->write(createSafeByteArray("<bluh/>")); - eventLoop->processEvents(); - CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); - } - - void testConnectionCount_ThreeWrites_ManualConnect() { - connectionFactory->autoFinishConnect = false; - PoolRef testling = createTestling(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - CPPUNIT_ASSERT_EQUAL(st(0), boshDataWritten.size()); /* Connection not connected yet, can't send data */ - - connectionFactory->connections[0]->onConnectFinished(false); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Connection finished, stream header sent */ - - readResponse(initial, connectionFactory->connections[0]); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Don't respond to initial data with a holding call */ - - testling->restartStream(); - eventLoop->processEvents(); - readResponse("<body/>", connectionFactory->connections[0]); - eventLoop->processEvents(); - testling->restartStream(); - eventLoop->processEvents(); - - - testling->write(createSafeByteArray("<blah/>")); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); - CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); /* New connection isn't up yet. */ - - connectionFactory->connections[1]->onConnectFinished(false); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New connection ready. */ - - testling->write(createSafeByteArray("<bleh/>")); - eventLoop->processEvents(); - testling->write(createSafeByteArray("<bluh/>")); - CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New data can't be sent, no free connections. */ - eventLoop->processEvents(); - CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); - } - - void testConnectionCount_ThreeWritesTwoReads() { - boost::shared_ptr<MockConnection> c0; - boost::shared_ptr<MockConnection> c1; - unsigned long long rid = initialRID; - - PoolRef testling = createTestling(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - c0 = connectionFactory->connections[0]; - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ - - rid++; - readResponse(initial, c0); - CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - CPPUNIT_ASSERT(!c0->pending); - - rid++; - testling->restartStream(); - eventLoop->processEvents(); - readResponse("<body/>", connectionFactory->connections[0]); - - rid++; - testling->write(createSafeByteArray("<blah/>")); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); /* 0 was waiting for response, open and send on 1 */ - CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* data */ - c1 = connectionFactory->connections[1]; - std::string fullBody = "<body rid='" + boost::lexical_cast<std::string>(rid) + "' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'><blah/></body>"; /* check empty write */ - CPPUNIT_ASSERT_EQUAL(fullBody, lastBody()); - CPPUNIT_ASSERT(c0->pending); - CPPUNIT_ASSERT(c1->pending); - - - rid++; - readResponse("<body xmlns='http://jabber.org/protocol/httpbind'><message><splatploing/></message></body>", c0); /* Doesn't include necessary attributes - as the support is improved this'll start to fail */ - eventLoop->processEvents(); - CPPUNIT_ASSERT(!c0->pending); - CPPUNIT_ASSERT(c1->pending); - CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* don't send empty in [0], still have [1] waiting */ - CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); - - rid++; - readResponse("<body xmlns='http://jabber.org/protocol/httpbind'><message><splatploing><blittlebarg/></splatploing></message></body>", c1); - eventLoop->processEvents(); - CPPUNIT_ASSERT(!c1->pending); - CPPUNIT_ASSERT(c0->pending); - CPPUNIT_ASSERT_EQUAL(st(5), boshDataWritten.size()); /* empty to make room */ - CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); - - rid++; - testling->write(createSafeByteArray("<bleh/>")); - eventLoop->processEvents(); - CPPUNIT_ASSERT(c0->pending); - CPPUNIT_ASSERT(c1->pending); - CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* data */ - - rid++; - testling->write(createSafeByteArray("<bluh/>")); - CPPUNIT_ASSERT(c0->pending); - CPPUNIT_ASSERT(c1->pending); - CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* Don't send data, no room */ - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); - } - - void testSession() { - to = "prosody.doomsong.co.uk"; - resolver->addAddress("prosody.doomsong.co.uk", HostAddress("127.0.0.1")); - path = "/http-bind/"; - boshURL = URL("http", to, 5280, path); - - PoolRef testling = createTestling(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - std::string response = "<body authid='743da605-4c2e-4de1-afac-ac040dd4a940' xmpp:version='1.0' xmlns:stream='http://etherx.jabber.org/streams' xmlns:xmpp='urn:xmpp:xbosh' inactivity='60' wait='60' polling='5' secure='true' hold='1' from='prosody.doomsong.co.uk' ver='1.6' sid='743da605-4c2e-4de1-afac-ac040dd4a940' requests='2' xmlns='http://jabber.org/protocol/httpbind'><stream:features><auth xmlns='http://jabber.org/features/iq-auth'/><mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl'><mechanism>SCRAM-SHA-1</mechanism><mechanism>DIGEST-MD5</mechanism></mechanisms></stream:features></body>"; - readResponse(response, connectionFactory->connections[0]); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - std::string send = "<auth xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\" mechanism=\"SCRAM-SHA-1\">biwsbj1hZG1pbixyPWZhOWE5ZDhiLWZmMDctNGE4Yy04N2E3LTg4YWRiNDQxZGUwYg==</auth>"; - testling->write(createSafeByteArray(send)); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - response = "<body xmlns='http://jabber.org/protocol/httpbind' sid='743da605-4c2e-4de1-afac-ac040dd4a940' xmlns:stream = 'http://etherx.jabber.org/streams'><challenge xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>cj1mYTlhOWQ4Yi1mZjA3LTRhOGMtODdhNy04OGFkYjQ0MWRlMGJhZmZlMWNhMy1mMDJkLTQ5NzEtYjkyNS0yM2NlNWQ2MDQyMjYscz1OVGd5WkdWaFptTXRaVE15WXkwMFpXUmhMV0ZqTURRdFpqYzRNbUppWmpGa1pqWXgsaT00MDk2</challenge></body>"; - readResponse(response, connectionFactory->connections[0]); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - send = "<response xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\">Yz1iaXdzLHI9ZmE5YTlkOGItZmYwNy00YThjLTg3YTctODhhZGI0NDFkZTBiYWZmZTFjYTMtZjAyZC00OTcxLWI5MjUtMjNjZTVkNjA0MjI2LHA9aU11NWt3dDN2VWplU2RqL01Jb3VIRldkZjBnPQ==</response>"; - testling->write(createSafeByteArray(send)); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - response = "<body xmlns='http://jabber.org/protocol/httpbind' sid='743da605-4c2e-4de1-afac-ac040dd4a940' xmlns:stream = 'http://etherx.jabber.org/streams'><success xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>dj1YNmNBY3BBOWxHNjNOOXF2bVQ5S0FacERrVm89</success></body>"; - readResponse(response, connectionFactory->connections[0]); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - testling->restartStream(); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - response = "<body xmpp:version='1.0' xmlns:stream='http://etherx.jabber.org/streams' xmlns:xmpp='urn:xmpp:xbosh' inactivity='60' wait='60' polling='5' secure='true' hold='1' from='prosody.doomsong.co.uk' ver='1.6' sid='743da605-4c2e-4de1-afac-ac040dd4a940' requests='2' xmlns='http://jabber.org/protocol/httpbind'><stream:features><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'><required/></bind><session xmlns='urn:ietf:params:xml:ns:xmpp-session'><optional/></session><sm xmlns='urn:xmpp:sm:2'><optional/></sm></stream:features></body>"; - readResponse(response, connectionFactory->connections[0]); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(5), boshDataWritten.size()); /* Now we've authed (restarted) we should be keeping one query in flight so the server can reply to us at any time it wants. */ - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - - send = "<body rid='2821988967416214' sid='cf663f6b94279d4f' xmlns='http://jabber.org/protocol/httpbind'><iq id='session-bind' type='set'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'><resource>d5a9744036cd20a0</resource></bind></iq></body>"; - testling->write(createSafeByteArray(send)); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); - CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); /* and as it keeps one in flight, it's needed to open a second to send these data */ - - } - - void testWrite_Empty() { - boost::shared_ptr<MockConnection> c0; - - PoolRef testling = createTestling(); - CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); - c0 = connectionFactory->connections[0]; - - readResponse(initial, c0); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Shouldn't have sent anything extra */ - eventLoop->processEvents(); - testling->restartStream(); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); - readResponse("<body></body>", c0); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); - std::string fullBody = "<body rid='" + boost::lexical_cast<std::string>(initialRID + 2) + "' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'></body>"; - std::string response = boshDataWritten[2]; - size_t bodyPosition = response.find("\r\n\r\n"); - CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); - - - } - - private: - - PoolRef createTestling() { - BOSHConnectionPool* a = new BOSHConnectionPool(boshURL, resolver, connectionFactory, &parserFactory, static_cast<TLSContextFactory*>(NULL), timerFactory, eventLoop, to, initialRID, URL(), SafeString(""), SafeString("")); - PoolRef pool(a); - //FIXME: Remko - why does the above work, but the below fail? - //PoolRef pool = boost::make_shared<BOSHConnectionPool>(boshURL, resolver, connectionFactory, &parserFactory, static_cast<TLSContextFactory*>(NULL), timerFactory, eventLoop, to, initialRID, URL(), SafeString(""), SafeString("")); - pool->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleXMPPDataRead, this, _1)); - pool->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataRead, this, _1)); - pool->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataWritten, this, _1)); - pool->onSessionStarted.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionStarted, this)); - pool->onSessionTerminated.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionTerminated, this)); - eventLoop->processEvents(); - eventLoop->processEvents(); - return pool; - } - - std::string lastBody() { - std::string response = boshDataWritten[boshDataWritten.size() - 1]; - size_t bodyPosition = response.find("\r\n\r\n"); - return response.substr(bodyPosition+4); - } - - size_t st(int val) { - return static_cast<size_t>(val); - } - - void handleXMPPDataRead(const SafeByteArray& d) { - xmppDataRead.push_back(safeByteArrayToString(d)); - } - - void handleBOSHDataRead(const SafeByteArray& d) { - boshDataRead.push_back(safeByteArrayToString(d)); - } - - void handleBOSHDataWritten(const SafeByteArray& d) { - boshDataWritten.push_back(safeByteArrayToString(d)); - } - - - void handleSessionStarted() { - sessionStarted++; - } - - void handleSessionTerminated() { - sessionTerminated++; - } - - struct MockConnection : public Connection { - public: - MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop, bool autoFinishConnect) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false), pending(false), autoFinishConnect(autoFinishConnect) { - } - - void listen() { assert(false); } - - void connect(const HostAddressPort& address) { - hostAddressPort = address; - bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); - if (autoFinishConnect) { - eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); - } - } - - HostAddressPort getLocalAddress() const { return HostAddressPort(); } - - void disconnect() { - disconnected = true; - onDisconnected(boost::optional<Connection::Error>()); - } - - void write(const SafeByteArray& d) { - append(dataWritten, d); - pending = true; - } - - EventLoop* eventLoop; - boost::optional<HostAddressPort> hostAddressPort; - std::vector<HostAddressPort> failingPorts; - ByteArray dataWritten; - bool disconnected; - bool pending; - bool autoFinishConnect; - }; - - struct MockConnectionFactory : public ConnectionFactory { - MockConnectionFactory(EventLoop* eventLoop, bool autoFinishConnect = true) : eventLoop(eventLoop), autoFinishConnect(autoFinishConnect) { - } - - boost::shared_ptr<Connection> createConnection() { - boost::shared_ptr<MockConnection> connection = boost::make_shared<MockConnection>(failingPorts, eventLoop, autoFinishConnect); - connections.push_back(connection); - return connection; - } - - EventLoop* eventLoop; - std::vector< boost::shared_ptr<MockConnection> > connections; - std::vector<HostAddressPort> failingPorts; - bool autoFinishConnect; - }; - - void readResponse(const std::string& response, boost::shared_ptr<MockConnection> connection) { - connection->pending = false; - boost::shared_ptr<SafeByteArray> data1 = boost::make_shared<SafeByteArray>(createSafeByteArray( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/xml; charset=utf-8\r\n" - "Access-Control-Allow-Origin: *\r\n" - "Access-Control-Allow-Headers: Content-Type\r\n" - "Content-Length: ")); - connection->onDataRead(data1); - boost::shared_ptr<SafeByteArray> data2 = boost::make_shared<SafeByteArray>(createSafeByteArray(boost::lexical_cast<std::string>(response.size()))); - connection->onDataRead(data2); - boost::shared_ptr<SafeByteArray> data3 = boost::make_shared<SafeByteArray>(createSafeByteArray("\r\n\r\n")); - connection->onDataRead(data3); - boost::shared_ptr<SafeByteArray> data4 = boost::make_shared<SafeByteArray>(createSafeByteArray(response)); - connection->onDataRead(data4); - } - - std::string fullRequestFor(const std::string& data) { - std::string body = data; - std::string result = "POST /" + path + " HTTP/1.1\r\n" - + "Host: " + to + ":" + port + "\r\n" - + "Content-Type: text/xml; charset=utf-8\r\n" - + "Content-Length: " + boost::lexical_cast<std::string>(body.size()) + "\r\n\r\n" - + body; - return result; - } - - private: - URL boshURL; - DummyEventLoop* eventLoop; - MockConnectionFactory* connectionFactory; - std::vector<std::string> xmppDataRead; - std::vector<std::string> boshDataRead; - std::vector<std::string> boshDataWritten; - PlatformXMLParserFactory parserFactory; - StaticDomainNameResolver* resolver; - TimerFactory* timerFactory; - std::string to; - std::string path; - std::string port; - std::string sid; - std::string initial; - unsigned long long initialRID; - int sessionStarted; - int sessionTerminated; + CPPUNIT_TEST_SUITE(BOSHConnectionPoolTest); + CPPUNIT_TEST(testConnectionCount_OneWrite); + CPPUNIT_TEST(testConnectionCount_TwoWrites); + CPPUNIT_TEST(testConnectionCount_ThreeWrites); + CPPUNIT_TEST(testConnectionCount_ThreeWrites_ManualConnect); + CPPUNIT_TEST(testConnectionCount_ThreeWritesTwoReads); + CPPUNIT_TEST(testSession); + CPPUNIT_TEST(testWrite_Empty); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + to = "wonderland.lit"; + path = "/http-bind"; + port = "5280"; + sid = "MyShinySID"; + initial = "<body wait='60' " + "inactivity='30' " + "polling='5' " + "requests='2' " + "hold='1' " + "maxpause='120' " + "sid='" + sid + "' " + "ver='1.6' " + "from='wonderland.lit' " + "xmlns='http://jabber.org/protocol/httpbind'/>"; + eventLoop = new DummyEventLoop(); + connectionFactory = new MockConnectionFactory(eventLoop); + boshURL = URL("http", to, 5280, path); + sessionTerminated = 0; + sessionStarted = 0; + initialRID = 2349876; + xmppDataRead.clear(); + boshDataRead.clear(); + boshDataWritten.clear(); + resolver = new StaticDomainNameResolver(eventLoop); + resolver->addAddress(to, HostAddress::fromString("127.0.0.1").get()); + timerFactory = new DummyTimerFactory(); + } + + void tearDown() { + eventLoop->processEvents(); + delete connectionFactory; + delete resolver; + delete timerFactory; + delete eventLoop; + } + + void testConnectionCount_OneWrite() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(0, sessionStarted); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(1, sessionStarted); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + testling->write(createSafeByteArray("<blah/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(1, sessionStarted); + } + + void testConnectionCount_TwoWrites() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + readResponse(initial, connectionFactory->connections[0]); + eventLoop->processEvents(); + testling->write(createSafeByteArray("<blah/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + testling->write(createSafeByteArray("<bleh/>")); + eventLoop->processEvents(); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWrites() { + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + readResponse(initial, connectionFactory->connections[0]); + testling->restartStream(); + readResponse("<body/>", connectionFactory->connections[0]); + testling->restartStream(); + readResponse("<body/>", connectionFactory->connections[0]); + testling->write(createSafeByteArray("<blah/>")); + testling->write(createSafeByteArray("<bleh/>")); + testling->write(createSafeByteArray("<bluh/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWrites_ManualConnect() { + connectionFactory->autoFinishConnect = false; + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(0), boshDataWritten.size()); /* Connection not connected yet, can't send data */ + + connectionFactory->connections[0]->onConnectFinished(false); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Connection finished, stream header sent */ + + readResponse(initial, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Don't respond to initial data with a holding call */ + + testling->restartStream(); + eventLoop->processEvents(); + readResponse("<body/>", connectionFactory->connections[0]); + eventLoop->processEvents(); + testling->restartStream(); + eventLoop->processEvents(); + + + testling->write(createSafeByteArray("<blah/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); /* New connection isn't up yet. */ + + connectionFactory->connections[1]->onConnectFinished(false); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New connection ready. */ + + testling->write(createSafeByteArray("<bleh/>")); + eventLoop->processEvents(); + testling->write(createSafeByteArray("<bluh/>")); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* New data can't be sent, no free connections. */ + eventLoop->processEvents(); + CPPUNIT_ASSERT(st(2) >= connectionFactory->connections.size()); + } + + void testConnectionCount_ThreeWritesTwoReads() { + std::shared_ptr<MockConnection> c0; + std::shared_ptr<MockConnection> c1; + unsigned long long rid = initialRID; + + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + c0 = connectionFactory->connections[0]; + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ + + rid++; + readResponse(initial, c0); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + CPPUNIT_ASSERT(!c0->pending); + + rid++; + testling->restartStream(); + eventLoop->processEvents(); + readResponse("<body/>", connectionFactory->connections[0]); + + rid++; + testling->write(createSafeByteArray("<blah/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); /* 0 was waiting for response, open and send on 1 */ + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* data */ + c1 = connectionFactory->connections[1]; + std::string fullBody = "<body rid='" + std::to_string(rid) + "' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'><blah/></body>"; /* check empty write */ + CPPUNIT_ASSERT_EQUAL(fullBody, lastBody()); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + + + rid++; + readResponse("<body xmlns='http://jabber.org/protocol/httpbind'><message><splatploing/></message></body>", c0); /* Doesn't include necessary attributes - as the support is improved this'll start to fail */ + eventLoop->processEvents(); + CPPUNIT_ASSERT(!c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); /* don't send empty in [0], still have [1] waiting */ + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + + rid++; + readResponse("<body xmlns='http://jabber.org/protocol/httpbind'><message><splatploing><blittlebarg/></splatploing></message></body>", c1); + eventLoop->processEvents(); + CPPUNIT_ASSERT(!c1->pending); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT_EQUAL(st(5), boshDataWritten.size()); /* empty to make room */ + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + + rid++; + testling->write(createSafeByteArray("<bleh/>")); + eventLoop->processEvents(); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* data */ + + rid++; + testling->write(createSafeByteArray("<bluh/>")); + CPPUNIT_ASSERT(c0->pending); + CPPUNIT_ASSERT(c1->pending); + CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); /* Don't send data, no room */ + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); + } + + void testSession() { + to = "prosody.doomsong.co.uk"; + resolver->addAddress("prosody.doomsong.co.uk", HostAddress::fromString("127.0.0.1").get()); + path = "/http-bind/"; + boshURL = URL("http", to, 5280, path); + + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* header*/ + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + std::string response = "<body authid='743da605-4c2e-4de1-afac-ac040dd4a940' xmpp:version='1.0' xmlns:stream='http://etherx.jabber.org/streams' xmlns:xmpp='urn:xmpp:xbosh' inactivity='60' wait='60' polling='5' secure='true' hold='1' from='prosody.doomsong.co.uk' ver='1.6' sid='743da605-4c2e-4de1-afac-ac040dd4a940' requests='2' xmlns='http://jabber.org/protocol/httpbind'><stream:features><auth xmlns='http://jabber.org/features/iq-auth'/><mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl'><mechanism>SCRAM-SHA-1</mechanism><mechanism>DIGEST-MD5</mechanism></mechanisms></stream:features></body>"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + std::string send = "<auth xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\" mechanism=\"SCRAM-SHA-1\">biwsbj1hZG1pbixyPWZhOWE5ZDhiLWZmMDctNGE4Yy04N2E3LTg4YWRiNDQxZGUwYg==</auth>"; + testling->write(createSafeByteArray(send)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + response = "<body xmlns='http://jabber.org/protocol/httpbind' sid='743da605-4c2e-4de1-afac-ac040dd4a940' xmlns:stream = 'http://etherx.jabber.org/streams'><challenge xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>cj1mYTlhOWQ4Yi1mZjA3LTRhOGMtODdhNy04OGFkYjQ0MWRlMGJhZmZlMWNhMy1mMDJkLTQ5NzEtYjkyNS0yM2NlNWQ2MDQyMjYscz1OVGd5WkdWaFptTXRaVE15WXkwMFpXUmhMV0ZqTURRdFpqYzRNbUppWmpGa1pqWXgsaT00MDk2</challenge></body>"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + send = "<response xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\">Yz1iaXdzLHI9ZmE5YTlkOGItZmYwNy00YThjLTg3YTctODhhZGI0NDFkZTBiYWZmZTFjYTMtZjAyZC00OTcxLWI5MjUtMjNjZTVkNjA0MjI2LHA9aU11NWt3dDN2VWplU2RqL01Jb3VIRldkZjBnPQ==</response>"; + testling->write(createSafeByteArray(send)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + response = "<body xmlns='http://jabber.org/protocol/httpbind' sid='743da605-4c2e-4de1-afac-ac040dd4a940' xmlns:stream = 'http://etherx.jabber.org/streams'><success xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>dj1YNmNBY3BBOWxHNjNOOXF2bVQ5S0FacERrVm89</success></body>"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + testling->restartStream(); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(4), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + response = "<body xmpp:version='1.0' xmlns:stream='http://etherx.jabber.org/streams' xmlns:xmpp='urn:xmpp:xbosh' inactivity='60' wait='60' polling='5' secure='true' hold='1' from='prosody.doomsong.co.uk' ver='1.6' sid='743da605-4c2e-4de1-afac-ac040dd4a940' requests='2' xmlns='http://jabber.org/protocol/httpbind'><stream:features><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'><required/></bind><session xmlns='urn:ietf:params:xml:ns:xmpp-session'><optional/></session><sm xmlns='urn:xmpp:sm:2'><optional/></sm></stream:features></body>"; + readResponse(response, connectionFactory->connections[0]); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(5), boshDataWritten.size()); /* Now we've authed (restarted) we should be keeping one query in flight so the server can reply to us at any time it wants. */ + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + + send = "<body rid='2821988967416214' sid='cf663f6b94279d4f' xmlns='http://jabber.org/protocol/httpbind'><iq id='session-bind' type='set'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'><resource>d5a9744036cd20a0</resource></bind></iq></body>"; + testling->write(createSafeByteArray(send)); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(6), boshDataWritten.size()); + CPPUNIT_ASSERT_EQUAL(st(2), connectionFactory->connections.size()); /* and as it keeps one in flight, it's needed to open a second to send these data */ + + } + + void testWrite_Empty() { + std::shared_ptr<MockConnection> c0; + + PoolRef testling = createTestling(); + CPPUNIT_ASSERT_EQUAL(st(1), connectionFactory->connections.size()); + c0 = connectionFactory->connections[0]; + + readResponse(initial, c0); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(1), boshDataWritten.size()); /* Shouldn't have sent anything extra */ + eventLoop->processEvents(); + testling->restartStream(); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(2), boshDataWritten.size()); + readResponse("<body></body>", c0); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(st(3), boshDataWritten.size()); + std::string fullBody = "<body rid='" + std::to_string(initialRID + 2) + "' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'></body>"; + std::string response = boshDataWritten[2]; + size_t bodyPosition = response.find("\r\n\r\n"); + CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); + + + } + + private: + + PoolRef createTestling() { + // make_shared is limited to 9 arguments; instead new is used here. + PoolRef pool = PoolRef(new BOSHConnectionPool(boshURL, resolver, connectionFactory, &parserFactory, static_cast<TLSContextFactory*>(nullptr), timerFactory, eventLoop, to, initialRID, URL(), SafeString(""), SafeString(""), TLSOptions())); + pool->open(); + pool->onXMPPDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleXMPPDataRead, this, _1)); + pool->onBOSHDataRead.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataRead, this, _1)); + pool->onBOSHDataWritten.connect(boost::bind(&BOSHConnectionPoolTest::handleBOSHDataWritten, this, _1)); + pool->onSessionStarted.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionStarted, this)); + pool->onSessionTerminated.connect(boost::bind(&BOSHConnectionPoolTest::handleSessionTerminated, this)); + eventLoop->processEvents(); + eventLoop->processEvents(); + return pool; + } + + std::string lastBody() { + std::string response = boshDataWritten[boshDataWritten.size() - 1]; + size_t bodyPosition = response.find("\r\n\r\n"); + return response.substr(bodyPosition+4); + } + + size_t st(int val) { + return static_cast<size_t>(val); + } + + void handleXMPPDataRead(const SafeByteArray& d) { + xmppDataRead.push_back(safeByteArrayToString(d)); + } + + void handleBOSHDataRead(const SafeByteArray& d) { + boshDataRead.push_back(safeByteArrayToString(d)); + } + + void handleBOSHDataWritten(const SafeByteArray& d) { + boshDataWritten.push_back(safeByteArrayToString(d)); + } + + + void handleSessionStarted() { + sessionStarted++; + } + + void handleSessionTerminated() { + sessionTerminated++; + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop, bool autoFinishConnect) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false), pending(false), autoFinishConnect(autoFinishConnect) { + } + + void listen() { assert(false); } + + void connect(const HostAddressPort& address) { + hostAddressPort = address; + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + if (autoFinishConnect) { + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + HostAddressPort getRemoteAddress() const { return HostAddressPort(); } + + void disconnect() { + disconnected = true; + onDisconnected(boost::optional<Connection::Error>()); + } + + void write(const SafeByteArray& d) { + append(dataWritten, d); + pending = true; + } + + EventLoop* eventLoop; + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + ByteArray dataWritten; + bool disconnected; + bool pending; + bool autoFinishConnect; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop, bool autoFinishConnect = true) : eventLoop(eventLoop), autoFinishConnect(autoFinishConnect) { + } + + std::shared_ptr<Connection> createConnection() { + std::shared_ptr<MockConnection> connection = std::make_shared<MockConnection>(failingPorts, eventLoop, autoFinishConnect); + connections.push_back(connection); + return connection; + } + + EventLoop* eventLoop; + std::vector< std::shared_ptr<MockConnection> > connections; + std::vector<HostAddressPort> failingPorts; + bool autoFinishConnect; + }; + + void readResponse(const std::string& response, std::shared_ptr<MockConnection> connection) { + connection->pending = false; + std::shared_ptr<SafeByteArray> data1 = std::make_shared<SafeByteArray>(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: ")); + connection->onDataRead(data1); + std::shared_ptr<SafeByteArray> data2 = std::make_shared<SafeByteArray>(createSafeByteArray(std::to_string(response.size()))); + connection->onDataRead(data2); + std::shared_ptr<SafeByteArray> data3 = std::make_shared<SafeByteArray>(createSafeByteArray("\r\n\r\n")); + connection->onDataRead(data3); + std::shared_ptr<SafeByteArray> data4 = std::make_shared<SafeByteArray>(createSafeByteArray(response)); + connection->onDataRead(data4); + } + + std::string fullRequestFor(const std::string& data) { + std::string body = data; + std::string result = "POST /" + path + " HTTP/1.1\r\n" + + "Host: " + to + ":" + port + "\r\n" + + "Content-Type: text/xml; charset=utf-8\r\n" + + "Content-Length: " + std::to_string(body.size()) + "\r\n\r\n" + + body; + return result; + } + + private: + URL boshURL; + DummyEventLoop* eventLoop; + MockConnectionFactory* connectionFactory; + std::vector<std::string> xmppDataRead; + std::vector<std::string> boshDataRead; + std::vector<std::string> boshDataWritten; + PlatformXMLParserFactory parserFactory; + StaticDomainNameResolver* resolver; + TimerFactory* timerFactory; + std::string to; + std::string path; + std::string port; + std::string sid; + std::string initial; + unsigned long long initialRID; + int sessionStarted; + int sessionTerminated; }; diff --git a/Swiften/Network/UnitTest/BOSHConnectionTest.cpp b/Swiften/Network/UnitTest/BOSHConnectionTest.cpp index 7ef0249..17d8333 100644 --- a/Swiften/Network/UnitTest/BOSHConnectionTest.cpp +++ b/Swiften/Network/UnitTest/BOSHConnectionTest.cpp @@ -1,300 +1,346 @@ /* - * Copyright (c) 2011 Kevin Smith - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2011-2017 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ +#include <memory> + +#include <boost/bind.hpp> +#include <boost/lexical_cast.hpp> +#include <boost/optional.hpp> + #include <QA/Checker/IO.h> #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> -#include <boost/optional.hpp> -#include <boost/bind.hpp> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/shared_ptr.hpp> -#include <boost/lexical_cast.hpp> - #include <Swiften/Base/Algorithm.h> +#include <Swiften/EventLoop/DummyEventLoop.h> +#include <Swiften/Network/BOSHConnection.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/ConnectionFactory.h> -#include <Swiften/Network/BOSHConnection.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/StaticDomainNameResolver.h> -#include <Swiften/Network/DummyTimerFactory.h> -#include <Swiften/EventLoop/DummyEventLoop.h> #include <Swiften/Parser/PlatformXMLParserFactory.h> +#include <Swiften/TLS/TLSOptions.h> using namespace Swift; class BOSHConnectionTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(BOSHConnectionTest); - CPPUNIT_TEST(testHeader); - CPPUNIT_TEST(testReadiness_ok); - CPPUNIT_TEST(testReadiness_pending); - CPPUNIT_TEST(testReadiness_disconnect); - CPPUNIT_TEST(testReadiness_noSID); - CPPUNIT_TEST(testWrite_Receive); - CPPUNIT_TEST(testWrite_ReceiveTwice); - CPPUNIT_TEST(testRead_Fragment); - CPPUNIT_TEST(testHTTPRequest); - CPPUNIT_TEST(testHTTPRequest_Empty); - CPPUNIT_TEST_SUITE_END(); - - public: - void setUp() { - eventLoop = new DummyEventLoop(); - connectionFactory = new MockConnectionFactory(eventLoop); - resolver = new StaticDomainNameResolver(eventLoop); - timerFactory = new DummyTimerFactory(); - connectFinished = false; - disconnected = false; - disconnectedError = false; - dataRead.clear(); - } - - void tearDown() { - eventLoop->processEvents(); - delete connectionFactory; - delete resolver; - delete timerFactory; - delete eventLoop; - } - - void testHeader() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - testling->startStream("wonderland.lit", 1); - std::string initial("<body wait='60' " - "inactivity='30' " - "polling='5' " - "requests='2' " - "hold='1' " - "maxpause='120' " - "sid='MyShinySID' " - "ver='1.6' " - "from='wonderland.lit' " - "xmlns='http://jabber.org/protocol/httpbind'/>"); - readResponse(initial, connectionFactory->connections[0]); - CPPUNIT_ASSERT_EQUAL(std::string("MyShinySID"), sid); - CPPUNIT_ASSERT(testling->isReadyToSend()); - } - - void testReadiness_ok() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - testling->setSID("blahhhhh"); - CPPUNIT_ASSERT(testling->isReadyToSend()); - } - - void testReadiness_pending() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - testling->setSID("mySID"); - CPPUNIT_ASSERT(testling->isReadyToSend()); - testling->write(createSafeByteArray("<mypayload/>")); - CPPUNIT_ASSERT(!testling->isReadyToSend()); - readResponse("<body><blah/></body>", connectionFactory->connections[0]); - CPPUNIT_ASSERT(testling->isReadyToSend()); - } - - void testReadiness_disconnect() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - testling->setSID("mySID"); - CPPUNIT_ASSERT(testling->isReadyToSend()); - connectionFactory->connections[0]->onDisconnected(false); - CPPUNIT_ASSERT(!testling->isReadyToSend()); - } - - - void testReadiness_noSID() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - CPPUNIT_ASSERT(!testling->isReadyToSend()); - } - - void testWrite_Receive() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - testling->setSID("mySID"); - testling->write(createSafeByteArray("<mypayload/>")); - readResponse("<body><blah/></body>", connectionFactory->connections[0]); - CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); - - } - - void testWrite_ReceiveTwice() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - testling->setSID("mySID"); - testling->write(createSafeByteArray("<mypayload/>")); - readResponse("<body><blah/></body>", connectionFactory->connections[0]); - CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); - dataRead.clear(); - testling->write(createSafeByteArray("<mypayload2/>")); - readResponse("<body><bleh/></body>", connectionFactory->connections[0]); - CPPUNIT_ASSERT_EQUAL(std::string("<bleh/>"), byteArrayToString(dataRead)); - } - - void testRead_Fragment() { - BOSHConnection::ref testling = createTestling(); - testling->connect(); - eventLoop->processEvents(); - CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(1), connectionFactory->connections.size()); - boost::shared_ptr<MockConnection> connection = connectionFactory->connections[0]; - boost::shared_ptr<SafeByteArray> data1 = boost::make_shared<SafeByteArray>(createSafeByteArray( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/xml; charset=utf-8\r\n" - "Access-Control-Allow-Origin: *\r\n" - "Access-Control-Allow-Headers: Content-Type\r\n" - "Content-Length: 64\r\n")); - boost::shared_ptr<SafeByteArray> data2 = boost::make_shared<SafeByteArray>(createSafeByteArray( - "\r\n<body xmlns='http://jabber.org/protocol/httpbind'>" - "<bl")); - boost::shared_ptr<SafeByteArray> data3 = boost::make_shared<SafeByteArray>(createSafeByteArray( - "ah/>" - "</body>")); - connection->onDataRead(data1); - connection->onDataRead(data2); - CPPUNIT_ASSERT(dataRead.empty()); - connection->onDataRead(data3); - CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); - } - - void testHTTPRequest() { - std::string data = "<blah/>"; - std::string sid = "wigglebloom"; - std::string fullBody = "<body xmlns='http://jabber.org/protocol/httpbind' sid='" + sid + "' rid='20'>" + data + "</body>"; - std::pair<SafeByteArray, size_t> http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 20, sid, URL()); - CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); - } - - void testHTTPRequest_Empty() { - std::string data = ""; - std::string sid = "wigglebloomsickle"; - std::string fullBody = "<body rid='42' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'>" + data + "</body>"; - std::pair<SafeByteArray, size_t> http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 42, sid, URL()); - CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); - std::string response = safeByteArrayToString(http.first); - size_t bodyPosition = response.find("\r\n\r\n"); - CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); - } - - private: - - BOSHConnection::ref createTestling() { - resolver->addAddress("wonderland.lit", HostAddress("127.0.0.1")); - Connector::ref connector = Connector::create("wonderland.lit", 5280, false, resolver, connectionFactory, timerFactory); - BOSHConnection::ref c = BOSHConnection::create(URL("http", "wonderland.lit", 5280, "/http-bind"), connector, &parserFactory); - c->onConnectFinished.connect(boost::bind(&BOSHConnectionTest::handleConnectFinished, this, _1)); - c->onDisconnected.connect(boost::bind(&BOSHConnectionTest::handleDisconnected, this, _1)); - c->onXMPPDataRead.connect(boost::bind(&BOSHConnectionTest::handleDataRead, this, _1)); - c->onSessionStarted.connect(boost::bind(&BOSHConnectionTest::handleSID, this, _1)); - c->setRID(42); - return c; - } - - void handleConnectFinished(bool error) { - connectFinished = true; - connectFinishedWithError = error; - } - - void handleDisconnected(bool e) { - disconnected = true; - disconnectedError = e; - } - - void handleDataRead(const SafeByteArray& d) { - append(dataRead, d); - } - - void handleSID(const std::string& s) { - sid = s; - } - - struct MockConnection : public Connection { - public: - MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false) { - } - - void listen() { assert(false); } - - void connect(const HostAddressPort& address) { - hostAddressPort = address; - bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); - eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); - } - - HostAddressPort getLocalAddress() const { return HostAddressPort(); } - - void disconnect() { - disconnected = true; - onDisconnected(boost::optional<Connection::Error>()); - } - - void write(const SafeByteArray& d) { - append(dataWritten, d); - } - - EventLoop* eventLoop; - boost::optional<HostAddressPort> hostAddressPort; - std::vector<HostAddressPort> failingPorts; - ByteArray dataWritten; - bool disconnected; - }; - - struct MockConnectionFactory : public ConnectionFactory { - MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop) { - } - - boost::shared_ptr<Connection> createConnection() { - boost::shared_ptr<MockConnection> connection = boost::make_shared<MockConnection>(failingPorts, eventLoop); - connections.push_back(connection); - return connection; - } - - EventLoop* eventLoop; - std::vector< boost::shared_ptr<MockConnection> > connections; - std::vector<HostAddressPort> failingPorts; - }; - - void readResponse(const std::string& response, boost::shared_ptr<MockConnection> connection) { - boost::shared_ptr<SafeByteArray> data1 = boost::make_shared<SafeByteArray>(createSafeByteArray( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/xml; charset=utf-8\r\n" - "Access-Control-Allow-Origin: *\r\n" - "Access-Control-Allow-Headers: Content-Type\r\n" - "Content-Length: ")); - connection->onDataRead(data1); - boost::shared_ptr<SafeByteArray> data2 = boost::make_shared<SafeByteArray>(createSafeByteArray(boost::lexical_cast<std::string>(response.size()))); - connection->onDataRead(data2); - boost::shared_ptr<SafeByteArray> data3 = boost::make_shared<SafeByteArray>(createSafeByteArray("\r\n\r\n")); - connection->onDataRead(data3); - boost::shared_ptr<SafeByteArray> data4 = boost::make_shared<SafeByteArray>(createSafeByteArray(response)); - connection->onDataRead(data4); - } - - - private: - DummyEventLoop* eventLoop; - MockConnectionFactory* connectionFactory; - bool connectFinished; - bool connectFinishedWithError; - bool disconnected; - bool disconnectedError; - ByteArray dataRead; - PlatformXMLParserFactory parserFactory; - StaticDomainNameResolver* resolver; - TimerFactory* timerFactory; - std::string sid; + CPPUNIT_TEST_SUITE(BOSHConnectionTest); + CPPUNIT_TEST(testHeader); + CPPUNIT_TEST(testReadiness_ok); + CPPUNIT_TEST(testReadiness_pending); + CPPUNIT_TEST(testReadiness_disconnect); + CPPUNIT_TEST(testReadiness_noSID); + CPPUNIT_TEST(testWrite_Receive); + CPPUNIT_TEST(testWrite_ReceiveTwice); + CPPUNIT_TEST(testRead_Fragment); + CPPUNIT_TEST(testHTTPRequest); + CPPUNIT_TEST(testHTTPRequest_Empty); + CPPUNIT_TEST(testTerminate); + CPPUNIT_TEST(testTerminateWithAdditionalData); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + eventLoop = new DummyEventLoop(); + connectionFactory = new MockConnectionFactory(eventLoop); + resolver = new StaticDomainNameResolver(eventLoop); + timerFactory = new DummyTimerFactory(); + tlsContextFactory = nullptr; + connectFinished = false; + disconnected = false; + disconnectedError = false; + sessionTerminatedError.reset(); + dataRead.clear(); + } + + void tearDown() { + eventLoop->processEvents(); + delete connectionFactory; + delete resolver; + delete timerFactory; + delete eventLoop; + } + + void testHeader() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->startStream("wonderland.lit", 1); + std::string initial("<body wait='60' " + "inactivity='30' " + "polling='5' " + "requests='2' " + "hold='1' " + "maxpause='120' " + "sid='MyShinySID' " + "ver='1.6' " + "from='wonderland.lit' " + "xmlns='http://jabber.org/protocol/httpbind'/>"); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("MyShinySID"), sid); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_ok() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->setSID("blahhhhh"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_pending() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->setSID("mySID"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + testling->write(createSafeByteArray("<mypayload/>")); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + readResponse("<body><blah/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT(testling->isReadyToSend()); + } + + void testReadiness_disconnect() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->setSID("mySID"); + CPPUNIT_ASSERT(testling->isReadyToSend()); + connectionFactory->connections[0]->onDisconnected(boost::optional<Connection::Error>()); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + } + + + void testReadiness_noSID() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + CPPUNIT_ASSERT(!testling->isReadyToSend()); + } + + void testWrite_Receive() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->setSID("mySID"); + testling->write(createSafeByteArray("<mypayload/>")); + readResponse("<body><blah/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); + + } + + void testWrite_ReceiveTwice() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->setSID("mySID"); + testling->write(createSafeByteArray("<mypayload/>")); + readResponse("<body><blah/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); + dataRead.clear(); + testling->write(createSafeByteArray("<mypayload2/>")); + readResponse("<body><bleh/></body>", connectionFactory->connections[0]); + CPPUNIT_ASSERT_EQUAL(std::string("<bleh/>"), byteArrayToString(dataRead)); + } + + void testRead_Fragment() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(1), connectionFactory->connections.size()); + std::shared_ptr<MockConnection> connection = connectionFactory->connections[0]; + std::shared_ptr<SafeByteArray> data1 = std::make_shared<SafeByteArray>(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: 64\r\n")); + std::shared_ptr<SafeByteArray> data2 = std::make_shared<SafeByteArray>(createSafeByteArray( + "\r\n<body xmlns='http://jabber.org/protocol/httpbind'>" + "<bl")); + std::shared_ptr<SafeByteArray> data3 = std::make_shared<SafeByteArray>(createSafeByteArray( + "ah/>" + "</body>")); + connection->onDataRead(data1); + connection->onDataRead(data2); + CPPUNIT_ASSERT(dataRead.empty()); + connection->onDataRead(data3); + CPPUNIT_ASSERT_EQUAL(std::string("<blah/>"), byteArrayToString(dataRead)); + } + + void testHTTPRequest() { + std::string data = "<blah/>"; + std::string sid = "wigglebloom"; + std::string fullBody = "<body xmlns='http://jabber.org/protocol/httpbind' sid='" + sid + "' rid='20'>" + data + "</body>"; + std::pair<SafeByteArray, size_t> http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 20, sid, URL()); + CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); + } + + void testHTTPRequest_Empty() { + std::string data = ""; + std::string sid = "wigglebloomsickle"; + std::string fullBody = "<body rid='42' sid='" + sid + "' xmlns='http://jabber.org/protocol/httpbind'>" + data + "</body>"; + std::pair<SafeByteArray, size_t> http = BOSHConnection::createHTTPRequest(createSafeByteArray(data), false, false, 42, sid, URL()); + CPPUNIT_ASSERT_EQUAL(fullBody.size(), http.second); + std::string response = safeByteArrayToString(http.first); + size_t bodyPosition = response.find("\r\n\r\n"); + CPPUNIT_ASSERT_EQUAL(fullBody, response.substr(bodyPosition+4)); + } + + void testTerminate() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->startStream("localhost", 1); + std::string initial("<body xmlns=\"http://jabber.org/protocol/httpbind\" " + "condition=\"bad-request\" " + "type=\"terminate\">" + "</body>"); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT(sessionTerminatedError); + CPPUNIT_ASSERT_EQUAL(BOSHError::BadRequest, sessionTerminatedError->getType()); + CPPUNIT_ASSERT_EQUAL(true, dataRead.empty()); + } + + // On a BOSH error no additional data may be emitted. + void testTerminateWithAdditionalData() { + BOSHConnection::ref testling = createTestling(); + testling->connect(); + eventLoop->processEvents(); + testling->startStream("localhost", 1); + std::string initial("<body xmlns=\"http://jabber.org/protocol/httpbind\" " + "condition=\"bad-request\" " + "type=\"terminate\">" + "<text>an error message</text>" + "</body>"); + readResponse(initial, connectionFactory->connections[0]); + CPPUNIT_ASSERT(sessionTerminatedError); + CPPUNIT_ASSERT_EQUAL(BOSHError::BadRequest, sessionTerminatedError->getType()); + CPPUNIT_ASSERT_EQUAL(true, dataRead.empty()); + } + + + private: + + BOSHConnection::ref createTestling() { + resolver->addAddress("wonderland.lit", HostAddress::fromString("127.0.0.1").get()); + Connector::ref connector = Connector::create("wonderland.lit", 5280, boost::optional<std::string>(), resolver, connectionFactory, timerFactory); + BOSHConnection::ref c = BOSHConnection::create(URL("http", "wonderland.lit", 5280, "/http-bind"), connector, &parserFactory, tlsContextFactory, TLSOptions()); + c->onConnectFinished.connect(boost::bind(&BOSHConnectionTest::handleConnectFinished, this, _1)); + c->onDisconnected.connect(boost::bind(&BOSHConnectionTest::handleDisconnected, this, _1)); + c->onXMPPDataRead.connect(boost::bind(&BOSHConnectionTest::handleDataRead, this, _1)); + c->onSessionStarted.connect(boost::bind(&BOSHConnectionTest::handleSID, this, _1)); + c->onSessionTerminated.connect(boost::bind(&BOSHConnectionTest::handleSessionTerminated, this, _1)); + c->setRID(42); + return c; + } + + void handleConnectFinished(bool error) { + connectFinished = true; + connectFinishedWithError = error; + } + + void handleDisconnected(bool e) { + disconnected = true; + disconnectedError = e; + } + + void handleDataRead(const SafeByteArray& d) { + append(dataRead, d); + } + + void handleSID(const std::string& s) { + sid = s; + } + + void handleSessionTerminated(BOSHError::ref error) { + sessionTerminatedError = error; + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false) { + } + + void listen() { assert(false); } + + void connect(const HostAddressPort& address) { + hostAddressPort = address; + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + HostAddressPort getRemoteAddress() const { return HostAddressPort(); } + + void disconnect() { + disconnected = true; + onDisconnected(boost::optional<Connection::Error>()); + } + + void write(const SafeByteArray& d) { + append(dataWritten, d); + } + + EventLoop* eventLoop; + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + ByteArray dataWritten; + bool disconnected; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop) { + } + + std::shared_ptr<Connection> createConnection() { + std::shared_ptr<MockConnection> connection = std::make_shared<MockConnection>(failingPorts, eventLoop); + connections.push_back(connection); + return connection; + } + + EventLoop* eventLoop; + std::vector< std::shared_ptr<MockConnection> > connections; + std::vector<HostAddressPort> failingPorts; + }; + + void readResponse(const std::string& response, std::shared_ptr<MockConnection> connection) { + std::shared_ptr<SafeByteArray> data1 = std::make_shared<SafeByteArray>(createSafeByteArray( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/xml; charset=utf-8\r\n" + "Access-Control-Allow-Origin: *\r\n" + "Access-Control-Allow-Headers: Content-Type\r\n" + "Content-Length: ")); + connection->onDataRead(data1); + std::shared_ptr<SafeByteArray> data2 = std::make_shared<SafeByteArray>(createSafeByteArray(std::to_string(response.size()))); + connection->onDataRead(data2); + std::shared_ptr<SafeByteArray> data3 = std::make_shared<SafeByteArray>(createSafeByteArray("\r\n\r\n")); + connection->onDataRead(data3); + std::shared_ptr<SafeByteArray> data4 = std::make_shared<SafeByteArray>(createSafeByteArray(response)); + connection->onDataRead(data4); + } + + + private: + DummyEventLoop* eventLoop; + MockConnectionFactory* connectionFactory; + bool connectFinished; + bool connectFinishedWithError; + bool disconnected; + bool disconnectedError; + BOSHError::ref sessionTerminatedError; + ByteArray dataRead; + PlatformXMLParserFactory parserFactory; + StaticDomainNameResolver* resolver; + TimerFactory* timerFactory; + TLSContextFactory* tlsContextFactory; + std::string sid; }; diff --git a/Swiften/Network/UnitTest/ChainedConnectorTest.cpp b/Swiften/Network/UnitTest/ChainedConnectorTest.cpp index 9abed57..2d78cd7 100644 --- a/Swiften/Network/UnitTest/ChainedConnectorTest.cpp +++ b/Swiften/Network/UnitTest/ChainedConnectorTest.cpp @@ -1,186 +1,188 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ -#include <cppunit/extensions/HelperMacros.h> -#include <cppunit/extensions/TestFactoryRegistry.h> +#include <memory> #include <boost/bind.hpp> -#include <boost/smart_ptr/make_shared.hpp> +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include <Swiften/EventLoop/DummyEventLoop.h> #include <Swiften/Network/ChainedConnector.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/DomainNameResolveError.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/StaticDomainNameResolver.h> -#include <Swiften/Network/DummyTimerFactory.h> -#include <Swiften/EventLoop/DummyEventLoop.h> -#include <Swiften/Network/DomainNameResolveError.h> using namespace Swift; class ChainedConnectorTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(ChainedConnectorTest); - CPPUNIT_TEST(testConnect_FirstConnectorSucceeds); - CPPUNIT_TEST(testConnect_SecondConnectorSucceeds); - CPPUNIT_TEST(testConnect_NoConnectorSucceeds); - CPPUNIT_TEST(testConnect_NoDNS); - CPPUNIT_TEST(testStop); - CPPUNIT_TEST_SUITE_END(); - - public: - void setUp() { - error.reset(); - host = HostAddressPort(HostAddress("1.1.1.1"), 1234); - eventLoop = new DummyEventLoop(); - resolver = new StaticDomainNameResolver(eventLoop); - resolver->addXMPPClientService("foo.com", host); - connectionFactory1 = new MockConnectionFactory(eventLoop, 1); - connectionFactory2 = new MockConnectionFactory(eventLoop, 2); - timerFactory = new DummyTimerFactory(); - } - - void tearDown() { - delete timerFactory; - delete connectionFactory2; - delete connectionFactory1; - delete resolver; - delete eventLoop; - } - - void testConnect_FirstConnectorSucceeds() { - boost::shared_ptr<ChainedConnector> testling(createConnector()); - connectionFactory1->connects = true; - connectionFactory2->connects = false; - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT_EQUAL(1, boost::dynamic_pointer_cast<MockConnection>(connections[0])->id); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_SecondConnectorSucceeds() { - boost::shared_ptr<ChainedConnector> testling(createConnector()); - connectionFactory1->connects = false; - connectionFactory2->connects = true; - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT_EQUAL(2, boost::dynamic_pointer_cast<MockConnection>(connections[0])->id); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_NoConnectorSucceeds() { - boost::shared_ptr<ChainedConnector> testling(createConnector()); - connectionFactory1->connects = false; - connectionFactory2->connects = false; - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_NoDNS() { - /* Reset resolver so there's no record */ - delete resolver; - resolver = new StaticDomainNameResolver(eventLoop); - boost::shared_ptr<ChainedConnector> testling(createConnector()); - connectionFactory1->connects = false; - connectionFactory2->connects = false; - - testling->start(); - //testling->stop(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testStop() { - boost::shared_ptr<ChainedConnector> testling(createConnector()); - connectionFactory1->connects = true; - connectionFactory2->connects = false; - - testling->start(); - testling->stop(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - } - - private: - boost::shared_ptr<ChainedConnector> createConnector() { - std::vector<ConnectionFactory*> factories; - factories.push_back(connectionFactory1); - factories.push_back(connectionFactory2); - boost::shared_ptr<ChainedConnector> connector = boost::make_shared<ChainedConnector>("foo.com", -1, true, resolver, factories, timerFactory); - connector->onConnectFinished.connect(boost::bind(&ChainedConnectorTest::handleConnectorFinished, this, _1, _2)); - return connector; - } - - void handleConnectorFinished(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> resultError) { - error = resultError; - boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection)); - if (connection) { - assert(c); - } - connections.push_back(c); - } - - struct MockConnection : public Connection { - public: - MockConnection(bool connects, int id, EventLoop* eventLoop) : connects(connects), id(id), eventLoop(eventLoop) { - } - - void listen() { assert(false); } - void connect(const HostAddressPort&) { - eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), !connects)); - } - - HostAddressPort getLocalAddress() const { return HostAddressPort(); } - void disconnect() { assert(false); } - void write(const SafeByteArray&) { assert(false); } - - bool connects; - int id; - EventLoop* eventLoop; - }; - - struct MockConnectionFactory : public ConnectionFactory { - MockConnectionFactory(EventLoop* eventLoop, int id) : eventLoop(eventLoop), connects(true), id(id) { - } - - boost::shared_ptr<Connection> createConnection() { - return boost::make_shared<MockConnection>(connects, id, eventLoop); - } - - EventLoop* eventLoop; - bool connects; - int id; - }; - - private: - HostAddressPort host; - DummyEventLoop* eventLoop; - StaticDomainNameResolver* resolver; - MockConnectionFactory* connectionFactory1; - MockConnectionFactory* connectionFactory2; - DummyTimerFactory* timerFactory; - std::vector< boost::shared_ptr<MockConnection> > connections; - boost::shared_ptr<Error> error; + CPPUNIT_TEST_SUITE(ChainedConnectorTest); + CPPUNIT_TEST(testConnect_FirstConnectorSucceeds); + CPPUNIT_TEST(testConnect_SecondConnectorSucceeds); + CPPUNIT_TEST(testConnect_NoConnectorSucceeds); + CPPUNIT_TEST(testConnect_NoDNS); + CPPUNIT_TEST(testStop); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + error.reset(); + host = HostAddressPort(HostAddress::fromString("1.1.1.1").get(), 1234); + eventLoop = new DummyEventLoop(); + resolver = new StaticDomainNameResolver(eventLoop); + resolver->addXMPPClientService("foo.com", host); + connectionFactory1 = new MockConnectionFactory(eventLoop, 1); + connectionFactory2 = new MockConnectionFactory(eventLoop, 2); + timerFactory = new DummyTimerFactory(); + } + + void tearDown() { + delete timerFactory; + delete connectionFactory2; + delete connectionFactory1; + delete resolver; + delete eventLoop; + } + + void testConnect_FirstConnectorSucceeds() { + std::shared_ptr<ChainedConnector> testling(createConnector()); + connectionFactory1->connects = true; + connectionFactory2->connects = false; + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT_EQUAL(1, std::dynamic_pointer_cast<MockConnection>(connections[0])->id); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_SecondConnectorSucceeds() { + std::shared_ptr<ChainedConnector> testling(createConnector()); + connectionFactory1->connects = false; + connectionFactory2->connects = true; + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT_EQUAL(2, std::dynamic_pointer_cast<MockConnection>(connections[0])->id); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_NoConnectorSucceeds() { + std::shared_ptr<ChainedConnector> testling(createConnector()); + connectionFactory1->connects = false; + connectionFactory2->connects = false; + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_NoDNS() { + /* Reset resolver so there's no record */ + delete resolver; + resolver = new StaticDomainNameResolver(eventLoop); + std::shared_ptr<ChainedConnector> testling(createConnector()); + connectionFactory1->connects = false; + connectionFactory2->connects = false; + + testling->start(); + //testling->stop(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testStop() { + std::shared_ptr<ChainedConnector> testling(createConnector()); + connectionFactory1->connects = true; + connectionFactory2->connects = false; + + testling->start(); + testling->stop(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + private: + std::shared_ptr<ChainedConnector> createConnector() { + std::vector<ConnectionFactory*> factories; + factories.push_back(connectionFactory1); + factories.push_back(connectionFactory2); + std::shared_ptr<ChainedConnector> connector = std::make_shared<ChainedConnector>("foo.com", -1, boost::optional<std::string>("_xmpp-client._tcp."), resolver, factories, timerFactory); + connector->onConnectFinished.connect(boost::bind(&ChainedConnectorTest::handleConnectorFinished, this, _1, _2)); + return connector; + } + + void handleConnectorFinished(std::shared_ptr<Connection> connection, std::shared_ptr<Error> resultError) { + error = resultError; + std::shared_ptr<MockConnection> c(std::dynamic_pointer_cast<MockConnection>(connection)); + if (connection) { + assert(c); + } + connections.push_back(c); + } + + struct MockConnection : public Connection { + public: + MockConnection(bool connects, int id, EventLoop* eventLoop) : connects(connects), id(id), eventLoop(eventLoop) { + } + + void listen() { assert(false); } + void connect(const HostAddressPort&) { + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), !connects)); + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + HostAddressPort getRemoteAddress() const { return HostAddressPort(); } + void disconnect() { assert(false); } + void write(const SafeByteArray&) { assert(false); } + + bool connects; + int id; + EventLoop* eventLoop; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop, int id) : eventLoop(eventLoop), connects(true), id(id) { + } + + std::shared_ptr<Connection> createConnection() { + return std::make_shared<MockConnection>(connects, id, eventLoop); + } + + EventLoop* eventLoop; + bool connects; + int id; + }; + + private: + HostAddressPort host; + DummyEventLoop* eventLoop; + StaticDomainNameResolver* resolver; + MockConnectionFactory* connectionFactory1; + MockConnectionFactory* connectionFactory2; + DummyTimerFactory* timerFactory; + std::vector< std::shared_ptr<MockConnection> > connections; + std::shared_ptr<Error> error; }; CPPUNIT_TEST_SUITE_REGISTRATION(ChainedConnectorTest); diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp index fe18340..065911d 100644 --- a/Swiften/Network/UnitTest/ConnectorTest.cpp +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -1,378 +1,398 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ +#include <boost/bind.hpp> +#include <boost/optional.hpp> + #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> -#include <boost/optional.hpp> -#include <boost/bind.hpp> - -#include <Swiften/Network/Connector.h> +#include <Swiften/EventLoop/DummyEventLoop.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/Connector.h> +#include <Swiften/Network/DomainNameAddressQuery.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/StaticDomainNameResolver.h> -#include <Swiften/Network/DummyTimerFactory.h> -#include <Swiften/EventLoop/DummyEventLoop.h> -#include <Swiften/Network/DomainNameAddressQuery.h> using namespace Swift; class ConnectorTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(ConnectorTest); - CPPUNIT_TEST(testConnect); - CPPUNIT_TEST(testConnect_NoServiceLookups); - CPPUNIT_TEST(testConnect_NoServiceLookups_DefaultPort); - CPPUNIT_TEST(testConnect_FirstAddressHostFails); - CPPUNIT_TEST(testConnect_NoSRVHost); - CPPUNIT_TEST(testConnect_NoHosts); - CPPUNIT_TEST(testConnect_FirstSRVHostFails); - CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost); - CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost); - CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail); - //CPPUNIT_TEST(testConnect_TimeoutDuringResolve); - CPPUNIT_TEST(testConnect_TimeoutDuringConnectToOnlyCandidate); - CPPUNIT_TEST(testConnect_TimeoutDuringConnectToCandidateFallsBack); - CPPUNIT_TEST(testConnect_NoTimeout); - CPPUNIT_TEST(testStop_DuringSRVQuery); - CPPUNIT_TEST(testStop_Timeout); - CPPUNIT_TEST_SUITE_END(); - - public: - void setUp() { - host1 = HostAddressPort(HostAddress("1.1.1.1"), 1234); - host2 = HostAddressPort(HostAddress("2.2.2.2"), 2345); - host3 = HostAddressPort(HostAddress("3.3.3.3"), 5222); - eventLoop = new DummyEventLoop(); - resolver = new StaticDomainNameResolver(eventLoop); - connectionFactory = new MockConnectionFactory(eventLoop); - timerFactory = new DummyTimerFactory(); - } - - void tearDown() { - delete timerFactory; - delete connectionFactory; - delete resolver; - delete eventLoop; - } - - void testConnect() { - Connector::ref testling(createConnector()); - resolver->addXMPPClientService("foo.com", host1); - resolver->addXMPPClientService("foo.com", host2); - resolver->addAddress("foo.com", host3.getAddress()); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort)); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_NoServiceLookups() { - Connector::ref testling(createConnector(4321, false)); - resolver->addXMPPClientService("foo.com", host1); - resolver->addXMPPClientService("foo.com", host2); - resolver->addAddress("foo.com", host3.getAddress()); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(host3.getAddress() == (*(connections[0]->hostAddressPort)).getAddress()); - CPPUNIT_ASSERT(4321 == (*(connections[0]->hostAddressPort)).getPort()); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_NoServiceLookups_DefaultPort() { - Connector::ref testling(createConnector(-1, false)); - resolver->addXMPPClientService("foo.com", host1); - resolver->addXMPPClientService("foo.com", host2); - resolver->addAddress("foo.com", host3.getAddress()); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(host3.getAddress() == (*(connections[0]->hostAddressPort)).getAddress()); - CPPUNIT_ASSERT_EQUAL(5222, (*(connections[0]->hostAddressPort)).getPort()); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_NoSRVHost() { - Connector::ref testling(createConnector()); - resolver->addAddress("foo.com", host3.getAddress()); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_FirstAddressHostFails() { - Connector::ref testling(createConnector()); - - HostAddress address1("1.1.1.1"); - HostAddress address2("2.2.2.2"); - resolver->addXMPPClientService("foo.com", "host-foo.com", 1234); - resolver->addAddress("host-foo.com", address1); - resolver->addAddress("host-foo.com", address2); - connectionFactory->failingPorts.push_back(HostAddressPort(address1, 1234)); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(HostAddressPort(address2, 1234) == *(connections[0]->hostAddressPort)); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_NoHosts() { - Connector::ref testling(createConnector()); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_FirstSRVHostFails() { - Connector::ref testling(createConnector()); - resolver->addXMPPClientService("foo.com", host1); - resolver->addXMPPClientService("foo.com", host2); - connectionFactory->failingPorts.push_back(host1); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort)); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_AllSRVHostsFailWithoutFallbackHost() { - Connector::ref testling(createConnector()); - resolver->addXMPPClientService("foo.com", host1); - resolver->addXMPPClientService("foo.com", host2); - connectionFactory->failingPorts.push_back(host1); - connectionFactory->failingPorts.push_back(host2); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_AllSRVHostsFailWithFallbackHost() { - Connector::ref testling(createConnector()); - resolver->addXMPPClientService("foo.com", host1); - resolver->addXMPPClientService("foo.com", host2); - resolver->addAddress("foo.com", host3.getAddress()); - connectionFactory->failingPorts.push_back(host1); - connectionFactory->failingPorts.push_back(host2); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_SRVAndFallbackHostsFail() { - Connector::ref testling(createConnector()); - resolver->addXMPPClientService("foo.com", host1); - resolver->addAddress("foo.com", host3.getAddress()); - connectionFactory->failingPorts.push_back(host1); - connectionFactory->failingPorts.push_back(host3); - - testling->start(); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - /*void testConnect_TimeoutDuringResolve() { - Connector::ref testling(createConnector()); - testling->setTimeoutMilliseconds(10); - resolver->setIsResponsive(false); - - testling->start(); - eventLoop->processEvents(); - timerFactory->setTime(10); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - CPPUNIT_ASSERT(!connections[0]); - }*/ - - void testConnect_TimeoutDuringConnectToOnlyCandidate() { - Connector::ref testling(createConnector()); - testling->setTimeoutMilliseconds(10); - resolver->addXMPPClientService("foo.com", host1); - connectionFactory->isResponsive = false; - - testling->start(); - eventLoop->processEvents(); - timerFactory->setTime(10); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testConnect_TimeoutDuringConnectToCandidateFallsBack() { - Connector::ref testling(createConnector()); - testling->setTimeoutMilliseconds(10); - - resolver->addXMPPClientService("foo.com", "host-foo.com", 1234); - HostAddress address1("1.1.1.1"); - resolver->addAddress("host-foo.com", address1); - HostAddress address2("2.2.2.2"); - resolver->addAddress("host-foo.com", address2); - - connectionFactory->isResponsive = false; - testling->start(); - eventLoop->processEvents(); - connectionFactory->isResponsive = true; - timerFactory->setTime(10); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(HostAddressPort(address2, 1234) == *(connections[0]->hostAddressPort)); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - - void testConnect_NoTimeout() { - Connector::ref testling(createConnector()); - testling->setTimeoutMilliseconds(10); - resolver->addXMPPClientService("foo.com", host1); - - testling->start(); - eventLoop->processEvents(); - timerFactory->setTime(10); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(connections[0]); - CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testStop_DuringSRVQuery() { - Connector::ref testling(createConnector()); - resolver->addXMPPClientService("foo.com", host1); - - testling->start(); - testling->stop(); - - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error)); - } - - void testStop_Timeout() { - Connector::ref testling(createConnector()); - testling->setTimeoutMilliseconds(10); - resolver->addXMPPClientService("foo.com", host1); - - testling->start(); - testling->stop(); - - eventLoop->processEvents(); - timerFactory->setTime(10); - eventLoop->processEvents(); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); - CPPUNIT_ASSERT(!connections[0]); - } - - - private: - Connector::ref createConnector(int port = -1, bool doServiceLookups = true) { - Connector::ref connector = Connector::create("foo.com", port, doServiceLookups, resolver, connectionFactory, timerFactory); - connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1, _2)); - return connector; - } - - void handleConnectorFinished(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> resultError) { - boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection)); - if (connection) { - assert(c); - } - connections.push_back(c); - error = resultError; - } - - struct MockConnection : public Connection { - public: - MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), isResponsive(isResponsive) {} - - void listen() { assert(false); } - void connect(const HostAddressPort& address) { - hostAddressPort = address; - if (isResponsive) { - bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); - eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); - } - } - - HostAddressPort getLocalAddress() const { return HostAddressPort(); } - void disconnect() { assert(false); } - void write(const SafeByteArray&) { assert(false); } - - EventLoop* eventLoop; - boost::optional<HostAddressPort> hostAddressPort; - std::vector<HostAddressPort> failingPorts; - bool isResponsive; - }; - - struct MockConnectionFactory : public ConnectionFactory { - MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop), isResponsive(true) { - } - - boost::shared_ptr<Connection> createConnection() { - return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive, eventLoop)); - } - - EventLoop* eventLoop; - bool isResponsive; - std::vector<HostAddressPort> failingPorts; - }; - - private: - HostAddressPort host1; - HostAddressPort host2; - HostAddressPort host3; - DummyEventLoop* eventLoop; - StaticDomainNameResolver* resolver; - MockConnectionFactory* connectionFactory; - DummyTimerFactory* timerFactory; - std::vector< boost::shared_ptr<MockConnection> > connections; - boost::shared_ptr<Error> error; + CPPUNIT_TEST_SUITE(ConnectorTest); + CPPUNIT_TEST(testConnect); + CPPUNIT_TEST(testConnect_NoServiceLookups); + CPPUNIT_TEST(testConnect_NoServiceLookups_DefaultPort); + CPPUNIT_TEST(testConnect_OnlyLiteral); + CPPUNIT_TEST(testConnect_FirstAddressHostFails); + CPPUNIT_TEST(testConnect_NoSRVHost); + CPPUNIT_TEST(testConnect_NoHosts); + CPPUNIT_TEST(testConnect_FirstSRVHostFails); + CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost); + CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost); + CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail); + //CPPUNIT_TEST(testConnect_TimeoutDuringResolve); + CPPUNIT_TEST(testConnect_TimeoutDuringConnectToOnlyCandidate); + CPPUNIT_TEST(testConnect_TimeoutDuringConnectToCandidateFallsBack); + CPPUNIT_TEST(testConnect_NoTimeout); + CPPUNIT_TEST(testStop_DuringSRVQuery); + CPPUNIT_TEST(testStop_Timeout); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + host1 = HostAddressPort(HostAddress::fromString("1.1.1.1").get(), 1234); + host2 = HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345); + host3 = HostAddressPort(HostAddress::fromString("3.3.3.3").get(), 5222); + eventLoop = new DummyEventLoop(); + resolver = new StaticDomainNameResolver(eventLoop); + connectionFactory = new MockConnectionFactory(eventLoop); + timerFactory = new DummyTimerFactory(); + } + + void tearDown() { + delete timerFactory; + delete connectionFactory; + delete resolver; + delete eventLoop; + } + + void testConnect() { + Connector::ref testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + resolver->addAddress("foo.com", host3.getAddress()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort)); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_NoServiceLookups() { + Connector::ref testling(createConnector(4321, boost::optional<std::string>())); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + resolver->addAddress("foo.com", host3.getAddress()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host3.getAddress() == (*(connections[0]->hostAddressPort)).getAddress()); + CPPUNIT_ASSERT(4321 == (*(connections[0]->hostAddressPort)).getPort()); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_NoServiceLookups_DefaultPort() { + Connector::ref testling(createConnector(0, boost::optional<std::string>())); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + resolver->addAddress("foo.com", host3.getAddress()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host3.getAddress() == (*(connections[0]->hostAddressPort)).getAddress()); + CPPUNIT_ASSERT_EQUAL(static_cast<unsigned short>(5222), (*(connections[0]->hostAddressPort)).getPort()); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_NoSRVHost() { + Connector::ref testling(createConnector()); + resolver->addAddress("foo.com", host3.getAddress()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_OnlyLiteral() { + auto testling = Connector::create("1.1.1.1", 1234, boost::none, resolver, connectionFactory, timerFactory); + testling->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1, _2)); + + auto address1 = HostAddress::fromString("1.1.1.1").get(); + connectionFactory->failingPorts.push_back(HostAddressPort(address1, 1234)); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connectionFactory->createdConnections.size())); + } + + void testConnect_FirstAddressHostFails() { + Connector::ref testling(createConnector()); + + auto address1 = HostAddress::fromString("1.1.1.1").get(); + auto address2 = HostAddress::fromString("2.2.2.2").get(); + resolver->addXMPPClientService("foo.com", "host-foo.com", 1234); + resolver->addAddress("host-foo.com", address1); + resolver->addAddress("host-foo.com", address2); + connectionFactory->failingPorts.push_back(HostAddressPort(address1, 1234)); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(HostAddressPort(address2, 1234) == *(connections[0]->hostAddressPort)); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_NoHosts() { + Connector::ref testling(createConnector()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_FirstSRVHostFails() { + Connector::ref testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort)); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_AllSRVHostsFailWithoutFallbackHost() { + Connector::ref testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_AllSRVHostsFailWithFallbackHost() { + Connector::ref testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); + resolver->addAddress("foo.com", host3.getAddress()); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_SRVAndFallbackHostsFail() { + Connector::ref testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + resolver->addAddress("foo.com", host3.getAddress()); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host3); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + /*void testConnect_TimeoutDuringResolve() { + Connector::ref testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->setIsResponsive(false); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<DomainNameResolveError>(error)); + CPPUNIT_ASSERT(!connections[0]); + }*/ + + void testConnect_TimeoutDuringConnectToOnlyCandidate() { + Connector::ref testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + connectionFactory->isResponsive = false; + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testConnect_TimeoutDuringConnectToCandidateFallsBack() { + Connector::ref testling(createConnector()); + testling->setTimeoutMilliseconds(10); + + auto address2 = HostAddress::fromString("2.2.2.2").get(); + + resolver->addXMPPClientService("foo.com", "host-foo.com", 1234); + resolver->addAddress("host-foo.com", HostAddress::fromString("1.1.1.1").get()); + resolver->addAddress("host-foo.com", address2); + + connectionFactory->isResponsive = false; + testling->start(); + eventLoop->processEvents(); + connectionFactory->isResponsive = true; + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(HostAddressPort(address2, 1234) == *(connections[0]->hostAddressPort)); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + + void testConnect_NoTimeout() { + Connector::ref testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(!std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testStop_DuringSRVQuery() { + Connector::ref testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + + testling->start(); + testling->stop(); + + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<DomainNameResolveError>(error)); + } + + void testStop_Timeout() { + Connector::ref testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + + testling->start(); + testling->stop(); + + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + + private: + Connector::ref createConnector(unsigned short port = 0, boost::optional<std::string> serviceLookupPrefix = boost::optional<std::string>("_xmpp-client._tcp.")) { + Connector::ref connector = Connector::create("foo.com", port, serviceLookupPrefix, resolver, connectionFactory, timerFactory); + connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1, _2)); + return connector; + } + + void handleConnectorFinished(std::shared_ptr<Connection> connection, std::shared_ptr<Error> resultError) { + std::shared_ptr<MockConnection> c(std::dynamic_pointer_cast<MockConnection>(connection)); + if (connection) { + assert(c); + } + connections.push_back(c); + error = resultError; + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), isResponsive(isResponsive) {} + + void listen() { assert(false); } + void connect(const HostAddressPort& address) { + hostAddressPort = address; + if (isResponsive) { + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + HostAddressPort getRemoteAddress() const { return HostAddressPort(); } + void disconnect() { assert(false); } + void write(const SafeByteArray&) { assert(false); } + + EventLoop* eventLoop; + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + bool isResponsive; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop), isResponsive(true) { + } + + std::shared_ptr<Connection> createConnection() { + auto connection = std::make_shared<MockConnection>(failingPorts, isResponsive, eventLoop); + createdConnections.push_back(connection); + return connection; + } + + EventLoop* eventLoop; + bool isResponsive; + std::vector<HostAddressPort> failingPorts; + std::vector<std::shared_ptr<MockConnection>> createdConnections; + }; + + private: + HostAddressPort host1; + HostAddressPort host2; + HostAddressPort host3; + DummyEventLoop* eventLoop; + StaticDomainNameResolver* resolver; + MockConnectionFactory* connectionFactory; + DummyTimerFactory* timerFactory; + std::vector< std::shared_ptr<MockConnection> > connections; + std::shared_ptr<Error> error; }; diff --git a/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp b/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp index 53b9413..7042b27 100644 --- a/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp +++ b/Swiften/Network/UnitTest/DomainNameServiceQueryTest.cpp @@ -1,7 +1,7 @@ /* - * Copyright (c) 2012 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2012-2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ #include <QA/Checker/IO.h> @@ -9,73 +9,73 @@ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> -#include <Swiften/Network/DomainNameServiceQuery.h> #include <Swiften/Base/RandomGenerator.h> +#include <Swiften/Network/DomainNameServiceQuery.h> using namespace Swift; namespace { - struct RandomGenerator1 : public RandomGenerator { - virtual int generateRandomInteger(int) { - return 0; - } - }; + struct RandomGenerator1 : public RandomGenerator { + virtual int generateRandomInteger(int) { + return 0; + } + }; - struct RandomGenerator2 : public RandomGenerator { - virtual int generateRandomInteger(int i) { - return i; - } - }; + struct RandomGenerator2 : public RandomGenerator { + virtual int generateRandomInteger(int i) { + return i; + } + }; } class DomainNameServiceQueryTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(DomainNameServiceQueryTest); - CPPUNIT_TEST(testSortResults_Random1); - CPPUNIT_TEST(testSortResults_Random2); - CPPUNIT_TEST_SUITE_END(); + CPPUNIT_TEST_SUITE(DomainNameServiceQueryTest); + CPPUNIT_TEST(testSortResults_Random1); + CPPUNIT_TEST(testSortResults_Random2); + CPPUNIT_TEST_SUITE_END(); - public: - void testSortResults_Random1() { - std::vector<DomainNameServiceQuery::Result> results; - results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1)); - results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10)); - results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1)); - results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20)); - results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1)); - results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10)); + public: + void testSortResults_Random1() { + std::vector<DomainNameServiceQuery::Result> results; + results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1)); + results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10)); + results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1)); + results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20)); + results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1)); + results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10)); - RandomGenerator1 generator; - DomainNameServiceQuery::sortResults(results, generator); + RandomGenerator1 generator; + DomainNameServiceQuery::sortResults(results, generator); - CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[1].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[2].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[3].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[4].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[5].hostname); - } + CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[1].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[2].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[3].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[4].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[5].hostname); + } - void testSortResults_Random2() { - std::vector<DomainNameServiceQuery::Result> results; - results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1)); - results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10)); - results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1)); - results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20)); - results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1)); - results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10)); - results.push_back(DomainNameServiceQuery::Result("server7.com", 5222, 3, 40)); + void testSortResults_Random2() { + std::vector<DomainNameServiceQuery::Result> results; + results.push_back(DomainNameServiceQuery::Result("server1.com", 5222, 5, 1)); + results.push_back(DomainNameServiceQuery::Result("server2.com", 5222, 3, 10)); + results.push_back(DomainNameServiceQuery::Result("server3.com", 5222, 6, 1)); + results.push_back(DomainNameServiceQuery::Result("server4.com", 5222, 3, 20)); + results.push_back(DomainNameServiceQuery::Result("server5.com", 5222, 2, 1)); + results.push_back(DomainNameServiceQuery::Result("server6.com", 5222, 3, 10)); + results.push_back(DomainNameServiceQuery::Result("server7.com", 5222, 3, 40)); - RandomGenerator2 generator; - DomainNameServiceQuery::sortResults(results, generator); + RandomGenerator2 generator; + DomainNameServiceQuery::sortResults(results, generator); - CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server7.com"), results[1].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[2].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[3].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[4].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[5].hostname); - CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[6].hostname); - } + CPPUNIT_ASSERT_EQUAL(std::string("server5.com"), results[0].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server7.com"), results[1].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server2.com"), results[2].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server4.com"), results[3].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server6.com"), results[4].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server1.com"), results[5].hostname); + CPPUNIT_ASSERT_EQUAL(std::string("server3.com"), results[6].hostname); + } }; diff --git a/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp b/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp index 134748f..e9268b0 100644 --- a/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp +++ b/Swiften/Network/UnitTest/HTTPConnectProxiedConnectionTest.cpp @@ -1,257 +1,437 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2019 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ + +#include <memory> + +#include <boost/algorithm/string.hpp> +#include <boost/bind.hpp> +#include <boost/lexical_cast.hpp> +#include <boost/optional.hpp> + #include <QA/Checker/IO.h> #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> -#include <boost/optional.hpp> -#include <boost/bind.hpp> -#include <boost/smart_ptr/make_shared.hpp> -#include <boost/shared_ptr.hpp> - #include <Swiften/Base/Algorithm.h> +#include <Swiften/Base/Log.h> +#include <Swiften/EventLoop/DummyEventLoop.h> #include <Swiften/Network/Connection.h> #include <Swiften/Network/ConnectionFactory.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Network/HTTPConnectProxiedConnection.h> +#include <Swiften/Network/HTTPTrafficFilter.h> #include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/StaticDomainNameResolver.h> -#include <Swiften/Network/DummyTimerFactory.h> -#include <Swiften/EventLoop/DummyEventLoop.h> using namespace Swift; +namespace { + class ExampleHTTPTrafficFilter : public HTTPTrafficFilter { + public: + ExampleHTTPTrafficFilter() {} + virtual ~ExampleHTTPTrafficFilter() {} + + virtual std::vector<std::pair<std::string, std::string> > filterHTTPResponseHeader(const std::string& /* statusLine */, const std::vector<std::pair<std::string, std::string> >& response) { + filterResponses.push_back(response); + SWIFT_LOG(debug); + return filterResponseReturn; + } + + std::vector<std::vector<std::pair<std::string, std::string> > > filterResponses; + + std::vector<std::pair<std::string, std::string> > filterResponseReturn; + }; + + class ProxyAuthenticationHTTPTrafficFilter : public HTTPTrafficFilter { + static std::string to_lower(const std::string& str) { + std::string lower = str; + boost::algorithm::to_lower(lower); + return lower; + } + + public: + ProxyAuthenticationHTTPTrafficFilter() {} + virtual ~ProxyAuthenticationHTTPTrafficFilter() {} + + virtual std::vector<std::pair<std::string, std::string> > filterHTTPResponseHeader(const std::string& statusLine, const std::vector<std::pair<std::string, std::string> >& response) { + std::vector<std::pair<std::string, std::string> > filterResponseReturn; + std::vector<std::string> statusLineFields; + boost::split(statusLineFields, statusLine, boost::is_any_of(" "), boost::token_compress_on); + + int statusCode = boost::lexical_cast<int>(statusLineFields[1]); + if (statusCode == 407) { + for (const auto& field : response) { + if (to_lower(field.first) == to_lower("Proxy-Authenticate")) { + if (field.second.size() >= 6 && field.second.substr(0, 6) == " NTLM ") { + filterResponseReturn.push_back(std::pair<std::string, std::string>("Proxy-Authorization", "NTLM TlRMTVNTUAADAAAAGAAYAHIAAAAYABgAigAAABIAEgBIAAAABgAGAFoAAAASABIVNTUAADAAYAAAABAAEACiAAAANYKI4gUBKAoAAAAPTABBAEIAUwBNAE8ASwBFADMAXwBxAGEATABBAEIAUwBNAE8ASwBFADMA0NKq8HYYhj8AAAAAAAAAAAAAAAAAAAAAOIiih3mR+AkyM4r99sy1mdFonCu2ILODro1WTTrJ4b4JcXEzUBA2Ig==")); + return filterResponseReturn; + } + else if (field.second.size() >= 5 && field.second.substr(0, 5) == " NTLM") { + filterResponseReturn.push_back(std::pair<std::string, std::string>("Proxy-Authorization", "NTLM TlRMTVNTUAABAAAAt7II4gkACQAxAAAACQAJACgAAAVNTUAADAAFASgKAAAAD0xBQlNNT0tFM1dPUktHUk9VUA==")); + return filterResponseReturn; + } + } + } + + return filterResponseReturn; + } + else { + return std::vector<std::pair<std::string, std::string> >(); + } + } + }; +} + class HTTPConnectProxiedConnectionTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(HTTPConnectProxiedConnectionTest); - CPPUNIT_TEST(testConnect_CreatesConnectionToProxy); - CPPUNIT_TEST(testConnect_SendsConnectRequest); - CPPUNIT_TEST(testConnect_ReceiveConnectResponse); - CPPUNIT_TEST(testConnect_ReceiveMalformedConnectResponse); - CPPUNIT_TEST(testConnect_ReceiveErrorConnectResponse); - CPPUNIT_TEST(testConnect_ReceiveDataAfterConnect); - CPPUNIT_TEST(testWrite_AfterConnect); - CPPUNIT_TEST(testDisconnect_AfterConnectRequest); - CPPUNIT_TEST(testDisconnect_AfterConnect); - CPPUNIT_TEST_SUITE_END(); - - public: - void setUp() { - proxyHost = "doo.bah"; - proxyPort = 1234; - proxyHostAddress = HostAddressPort(HostAddress("1.1.1.1"), proxyPort); - host = HostAddressPort(HostAddress("2.2.2.2"), 2345); - eventLoop = new DummyEventLoop(); - resolver = new StaticDomainNameResolver(eventLoop); - resolver->addAddress(proxyHost, proxyHostAddress.getAddress()); - timerFactory = new DummyTimerFactory(); - connectionFactory = new MockConnectionFactory(eventLoop); - connectFinished = false; - disconnected = false; - } - - void tearDown() { - delete timerFactory; - delete connectionFactory; - delete resolver; - delete eventLoop; - } - - void connect(HTTPConnectProxiedConnection::ref connection, const HostAddressPort& to) { - connection->connect(to); - eventLoop->processEvents(); - eventLoop->processEvents(); - eventLoop->processEvents(); - } - - void testConnect_CreatesConnectionToProxy() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - - connect(testling, host); - - CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connectionFactory->connections.size())); - CPPUNIT_ASSERT(connectionFactory->connections[0]->hostAddressPort); - CPPUNIT_ASSERT(proxyHostAddress == *connectionFactory->connections[0]->hostAddressPort); - CPPUNIT_ASSERT(!connectFinished); - } - - void testConnect_SendsConnectRequest() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - - CPPUNIT_ASSERT_EQUAL(createByteArray("CONNECT 2.2.2.2:2345 HTTP/1.1\r\n\r\n"), connectionFactory->connections[0]->dataWritten); - } - - void testConnect_ReceiveConnectResponse() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - - connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); - eventLoop->processEvents(); - - CPPUNIT_ASSERT(connectFinished); - CPPUNIT_ASSERT(!connectFinishedWithError); - CPPUNIT_ASSERT(dataRead.empty()); - } - - void testConnect_ReceiveMalformedConnectResponse() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - - connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("FLOOP")); - eventLoop->processEvents(); - - CPPUNIT_ASSERT(connectFinished); - CPPUNIT_ASSERT(connectFinishedWithError); - CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); - } - - void testConnect_ReceiveErrorConnectResponse() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - - connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 401 Unauthorized\r\n\r\n")); - eventLoop->processEvents(); - - CPPUNIT_ASSERT(connectFinished); - CPPUNIT_ASSERT(connectFinishedWithError); - CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); - } - - void testConnect_ReceiveDataAfterConnect() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); - eventLoop->processEvents(); - - connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("abcdef")); - - CPPUNIT_ASSERT_EQUAL(createByteArray("abcdef"), dataRead); - } - - void testWrite_AfterConnect() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); - eventLoop->processEvents(); - connectionFactory->connections[0]->dataWritten.clear(); - - testling->write(createSafeByteArray("abcdef")); - - CPPUNIT_ASSERT_EQUAL(createByteArray("abcdef"), connectionFactory->connections[0]->dataWritten); - } - - void testDisconnect_AfterConnectRequest() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - - testling->disconnect(); - - CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); - CPPUNIT_ASSERT(disconnected); - CPPUNIT_ASSERT(!disconnectedError); - } - - void testDisconnect_AfterConnect() { - HTTPConnectProxiedConnection::ref testling(createTestling()); - connect(testling, HostAddressPort(HostAddress("2.2.2.2"), 2345)); - connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); - eventLoop->processEvents(); - - testling->disconnect(); - - CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); - CPPUNIT_ASSERT(disconnected); - CPPUNIT_ASSERT(!disconnectedError); - } - - private: - HTTPConnectProxiedConnection::ref createTestling() { - boost::shared_ptr<HTTPConnectProxiedConnection> c = HTTPConnectProxiedConnection::create(resolver, connectionFactory, timerFactory, proxyHost, proxyPort, "", ""); - c->onConnectFinished.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleConnectFinished, this, _1)); - c->onDisconnected.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDisconnected, this, _1)); - c->onDataRead.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDataRead, this, _1)); - return c; - } - - void handleConnectFinished(bool error) { - connectFinished = true; - connectFinishedWithError = error; - } - - void handleDisconnected(const boost::optional<Connection::Error>& e) { - disconnected = true; - disconnectedError = e; - } - - void handleDataRead(boost::shared_ptr<SafeByteArray> d) { - append(dataRead, *d); - } - - struct MockConnection : public Connection { - public: - MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false) { - } - - void listen() { assert(false); } - - void connect(const HostAddressPort& address) { - hostAddressPort = address; - bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); - eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); - } - - HostAddressPort getLocalAddress() const { return HostAddressPort(); } - - void disconnect() { - disconnected = true; - onDisconnected(boost::optional<Connection::Error>()); - } - - void write(const SafeByteArray& d) { - append(dataWritten, d); - } - - EventLoop* eventLoop; - boost::optional<HostAddressPort> hostAddressPort; - std::vector<HostAddressPort> failingPorts; - ByteArray dataWritten; - bool disconnected; - }; - - struct MockConnectionFactory : public ConnectionFactory { - MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop) { - } - - boost::shared_ptr<Connection> createConnection() { - boost::shared_ptr<MockConnection> connection = boost::make_shared<MockConnection>(failingPorts, eventLoop); - connections.push_back(connection); - return connection; - } - - EventLoop* eventLoop; - std::vector< boost::shared_ptr<MockConnection> > connections; - std::vector<HostAddressPort> failingPorts; - }; - - private: - std::string proxyHost; - HostAddressPort proxyHostAddress; - int proxyPort; - HostAddressPort host; - DummyEventLoop* eventLoop; - StaticDomainNameResolver* resolver; - MockConnectionFactory* connectionFactory; - TimerFactory* timerFactory; - std::vector< boost::shared_ptr<MockConnection> > connections; - bool connectFinished; - bool connectFinishedWithError; - bool disconnected; - boost::optional<Connection::Error> disconnectedError; - ByteArray dataRead; + CPPUNIT_TEST_SUITE(HTTPConnectProxiedConnectionTest); + CPPUNIT_TEST(testConnect_CreatesConnectionToProxy); + CPPUNIT_TEST(testConnect_SendsConnectRequest); + CPPUNIT_TEST(testConnect_ReceiveConnectResponse); + CPPUNIT_TEST(testConnect_ReceiveConnectChunkedResponse); + CPPUNIT_TEST(testConnect_ReceiveMalformedConnectResponse); + CPPUNIT_TEST(testConnect_ReceiveErrorConnectResponse); + CPPUNIT_TEST(testConnect_ReceiveDataAfterConnect); + CPPUNIT_TEST(testWrite_AfterConnect); + CPPUNIT_TEST(testDisconnect_AfterConnectRequest); + CPPUNIT_TEST(testDisconnect_AfterConnect); + CPPUNIT_TEST(testTrafficFilter); + CPPUNIT_TEST(testTrafficFilterNoConnectionReuse); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + proxyHost = "doo.bah"; + proxyPort = 1234; + proxyHostAddress = HostAddressPort(HostAddress::fromString("1.1.1.1").get(), proxyPort); + host = HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345); + eventLoop = new DummyEventLoop(); + resolver = new StaticDomainNameResolver(eventLoop); + resolver->addAddress(proxyHost, proxyHostAddress.getAddress()); + timerFactory = new DummyTimerFactory(); + connectionFactory = new MockConnectionFactory(eventLoop); + connectFinished = false; + connectFinishedWithError = false; + disconnected = false; + } + + void tearDown() { + delete timerFactory; + delete connectionFactory; + delete resolver; + delete eventLoop; + } + + void connect(HTTPConnectProxiedConnection::ref connection, const HostAddressPort& to) { + connection->connect(to); + eventLoop->processEvents(); + eventLoop->processEvents(); + eventLoop->processEvents(); + } + + void testConnect_CreatesConnectionToProxy() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + + connect(testling, host); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connectionFactory->connections.size())); + CPPUNIT_ASSERT(connectionFactory->connections[0]->hostAddressPort); + CPPUNIT_ASSERT(proxyHostAddress == *connectionFactory->connections[0]->hostAddressPort); + CPPUNIT_ASSERT(!connectFinished); + } + + void testConnect_SendsConnectRequest() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + CPPUNIT_ASSERT_EQUAL(createByteArray("CONNECT 2.2.2.2:2345 HTTP/1.1\r\n\r\n"), connectionFactory->connections[0]->dataWritten); + } + + void testConnect_ReceiveConnectResponse() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); + eventLoop->processEvents(); + + CPPUNIT_ASSERT(connectFinished); + CPPUNIT_ASSERT(!connectFinishedWithError); + CPPUNIT_ASSERT(dataRead.empty()); + } + + void testConnect_ReceiveConnectChunkedResponse() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 ")); + eventLoop->processEvents(); + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("200 Connection established\r\n\r\n")); + eventLoop->processEvents(); + + CPPUNIT_ASSERT(connectFinished); + CPPUNIT_ASSERT(!connectFinishedWithError); + CPPUNIT_ASSERT(dataRead.empty()); + } + + + void testConnect_ReceiveMalformedConnectResponse() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("FLOOP")); + eventLoop->processEvents(); + + CPPUNIT_ASSERT(connectFinished); + CPPUNIT_ASSERT(connectFinishedWithError); + CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); + } + + void testConnect_ReceiveErrorConnectResponse() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 401 Unauthorized\r\n\r\n")); + eventLoop->processEvents(); + + CPPUNIT_ASSERT(connectFinished); + CPPUNIT_ASSERT(connectFinishedWithError); + CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); + } + + void testConnect_ReceiveDataAfterConnect() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); + eventLoop->processEvents(); + + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("abcdef")); + + CPPUNIT_ASSERT_EQUAL(createByteArray("abcdef"), dataRead); + } + + void testWrite_AfterConnect() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); + eventLoop->processEvents(); + connectionFactory->connections[0]->dataWritten.clear(); + + testling->write(createSafeByteArray("abcdef")); + + CPPUNIT_ASSERT_EQUAL(createByteArray("abcdef"), connectionFactory->connections[0]->dataWritten); + } + + void testDisconnect_AfterConnectRequest() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + testling->disconnect(); + + CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); + CPPUNIT_ASSERT(disconnected); + CPPUNIT_ASSERT(!disconnectedError); + } + + void testDisconnect_AfterConnect() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef("HTTP/1.0 200 Connection established\r\n\r\n")); + eventLoop->processEvents(); + + testling->disconnect(); + + CPPUNIT_ASSERT(connectionFactory->connections[0]->disconnected); + CPPUNIT_ASSERT(disconnected); + CPPUNIT_ASSERT(!disconnectedError); + } + + void testTrafficFilter() { + HTTPConnectProxiedConnection::ref testling(createTestling()); + + std::shared_ptr<ExampleHTTPTrafficFilter> httpTrafficFilter = std::make_shared<ExampleHTTPTrafficFilter>(); + + testling->setHTTPTrafficFilter(httpTrafficFilter); + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + // set a default response so the server response is answered by the traffic filter + httpTrafficFilter->filterResponseReturn.clear(); + httpTrafficFilter->filterResponseReturn.push_back(std::pair<std::string, std::string>("Authorization", "Negotiate a87421000492aa874209af8bc028")); + + connectionFactory->connections[0]->dataWritten.clear(); + + // test chunked response + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef( + "HTTP/1.0 401 Unauthorized\r\n")); + eventLoop->processEvents(); + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef( + "WWW-Authenticate: Negotiate\r\n" + "\r\n")); + eventLoop->processEvents(); + + + // verify that the traffic filter got called and answered with its response + CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(1), httpTrafficFilter->filterResponses.size()); + CPPUNIT_ASSERT_EQUAL(std::string("WWW-Authenticate"), httpTrafficFilter->filterResponses[0][0].first); + + // remove the default response from the traffic filter + httpTrafficFilter->filterResponseReturn.clear(); + eventLoop->processEvents(); + + // verify that the traffic filter answer is send over the wire + CPPUNIT_ASSERT_EQUAL(createByteArray("CONNECT 2.2.2.2:2345 HTTP/1.1\r\nAuthorization: Negotiate a87421000492aa874209af8bc028\r\n\r\n"), connectionFactory->connections[1]->dataWritten); + + // verify that after without the default response, the traffic filter is skipped, authentication proceeds and traffic goes right through + connectionFactory->connections[1]->dataWritten.clear(); + testling->write(createSafeByteArray("abcdef")); + CPPUNIT_ASSERT_EQUAL(createByteArray("abcdef"), connectionFactory->connections[1]->dataWritten); + } + + void testTrafficFilterNoConnectionReuse() { + HTTPConnectProxiedConnection::ref testling = createTestling(); + + std::shared_ptr<ProxyAuthenticationHTTPTrafficFilter> httpTrafficFilter = std::make_shared<ProxyAuthenticationHTTPTrafficFilter>(); + testling->setHTTPTrafficFilter(httpTrafficFilter); + + connect(testling, HostAddressPort(HostAddress::fromString("2.2.2.2").get(), 2345)); + + // First HTTP CONNECT request assumes the proxy will work. + CPPUNIT_ASSERT_EQUAL(createByteArray("CONNECT 2.2.2.2:2345 HTTP/1.1\r\n" + "\r\n"), connectionFactory->connections[0]->dataWritten); + + // First reply presents initiator with authentication options. + connectionFactory->connections[0]->onDataRead(createSafeByteArrayRef( + "HTTP/1.0 407 ProxyAuthentication Required\r\n" + "proxy-Authenticate: Negotiate\r\n" + "Proxy-Authenticate: Kerberos\r\n" + "proxy-Authenticate: NTLM\r\n" + "\r\n")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(false, connectFinished); + CPPUNIT_ASSERT_EQUAL(false, connectFinishedWithError); + + // The HTTP proxy responds with code 407, so the traffic filter should inject the authentication response on a new connection. + CPPUNIT_ASSERT_EQUAL(createByteArray("CONNECT 2.2.2.2:2345 HTTP/1.1\r\n" + "Proxy-Authorization: NTLM TlRMTVNTUAABAAAAt7II4gkACQAxAAAACQAJACgAAAVNTUAADAAFASgKAAAAD0xBQlNNT0tFM1dPUktHUk9VUA==\r\n" + "\r\n"), connectionFactory->connections[1]->dataWritten); + + // The proxy responds with another authentication step. + connectionFactory->connections[1]->onDataRead(createSafeByteArrayRef( + "HTTP/1.0 407 ProxyAuthentication Required\r\n" + "Proxy-Authenticate: NTLM TlRMTVNTUAACAAAAEAAQADgAAAA1goriluCDYHcYI/sAAAAAAAAAAFQAVABIAAAABQLODgAAAA9TAFAASQBSAEkAVAAxAEIAAgAQAFMAUABJAFIASQBUADEAQgABABAAUwBQAEkAUgBJAFQAMQBCAAQAEABzAHAAaQByAGkAdAAxAGIAAwAQAHMAcABpAHIAaQB0ADEAYgAAAAAA\r\n" + "\r\n")); + eventLoop->processEvents(); + CPPUNIT_ASSERT_EQUAL(false, connectFinished); + CPPUNIT_ASSERT_EQUAL(false, connectFinishedWithError); + + // Last HTTP request that should succeed. Further traffic will go over the connection of this request. + CPPUNIT_ASSERT_EQUAL(createByteArray("CONNECT 2.2.2.2:2345 HTTP/1.1\r\n" + "Proxy-Authorization: NTLM TlRMTVNTUAADAAAAGAAYAHIAAAAYABgAigAAABIAEgBIAAAABgAGAFoAAAASABIVNTUAADAAYAAAABAAEACiAAAANYKI4gUBKAoAAAAPTABBAEIAUwBNAE8ASwBFADMAXwBxAGEATABBAEIAUwBNAE8ASwBFADMA0NKq8HYYhj8AAAAAAAAAAAAAAAAAAAAAOIiih3mR+AkyM4r99sy1mdFonCu2ILODro1WTTrJ4b4JcXEzUBA2Ig==\r\n" + "\r\n"), connectionFactory->connections[2]->dataWritten); + + connectionFactory->connections[2]->onDataRead(createSafeByteArrayRef( + "HTTP/1.0 200 OK\r\n" + "\r\n")); + eventLoop->processEvents(); + + // The HTTP CONNECT proxy initialization finished without error. + CPPUNIT_ASSERT_EQUAL(true, connectFinished); + CPPUNIT_ASSERT_EQUAL(false, connectFinishedWithError); + + // Further traffic is written directly, without interception of the filter. + connectionFactory->connections[2]->dataWritten.clear(); + testling->write(createSafeByteArray("This is some basic data traffic.")); + CPPUNIT_ASSERT_EQUAL(createByteArray("This is some basic data traffic."), connectionFactory->connections[2]->dataWritten); + } + + private: + HTTPConnectProxiedConnection::ref createTestling() { + std::shared_ptr<HTTPConnectProxiedConnection> c = HTTPConnectProxiedConnection::create(resolver, connectionFactory, timerFactory, proxyHost, proxyPort, "", ""); + c->onConnectFinished.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleConnectFinished, this, _1)); + c->onDisconnected.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDisconnected, this, _1)); + c->onDataRead.connect(boost::bind(&HTTPConnectProxiedConnectionTest::handleDataRead, this, _1)); + return c; + } + + void handleConnectFinished(bool error) { + connectFinished = true; + connectFinishedWithError = error; + } + + void handleDisconnected(const boost::optional<Connection::Error>& e) { + disconnected = true; + disconnectedError = e; + } + + void handleDataRead(std::shared_ptr<SafeByteArray> d) { + append(dataRead, *d); + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), disconnected(false) { + } + + void listen() { assert(false); } + + void connect(const HostAddressPort& address) { + hostAddressPort = address; + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail)); + } + + HostAddressPort getLocalAddress() const { return HostAddressPort(); } + HostAddressPort getRemoteAddress() const { return HostAddressPort(); } + + void disconnect() { + disconnected = true; + onDisconnected(boost::optional<Connection::Error>()); + } + + void write(const SafeByteArray& d) { + append(dataWritten, d); + } + + EventLoop* eventLoop; + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + ByteArray dataWritten; + bool disconnected; + }; + + struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop) { + } + + std::shared_ptr<Connection> createConnection() { + std::shared_ptr<MockConnection> connection = std::make_shared<MockConnection>(failingPorts, eventLoop); + connections.push_back(connection); + SWIFT_LOG(debug) << "new connection created"; + return connection; + } + + EventLoop* eventLoop; + std::vector< std::shared_ptr<MockConnection> > connections; + std::vector<HostAddressPort> failingPorts; + }; + + private: + std::string proxyHost; + HostAddressPort proxyHostAddress; + unsigned short proxyPort; + HostAddressPort host; + DummyEventLoop* eventLoop; + StaticDomainNameResolver* resolver; + MockConnectionFactory* connectionFactory; + TimerFactory* timerFactory; + bool connectFinished; + bool connectFinishedWithError; + bool disconnected; + boost::optional<Connection::Error> disconnectedError; + ByteArray dataRead; }; CPPUNIT_TEST_SUITE_REGISTRATION(HTTPConnectProxiedConnectionTest); diff --git a/Swiften/Network/UnitTest/HostAddressTest.cpp b/Swiften/Network/UnitTest/HostAddressTest.cpp index b2511a8..bd345a7 100644 --- a/Swiften/Network/UnitTest/HostAddressTest.cpp +++ b/Swiften/Network/UnitTest/HostAddressTest.cpp @@ -1,66 +1,95 @@ /* - * Copyright (c) 2010 Remko Tronçon - * Licensed under the GNU General Public License v3. - * See Documentation/Licenses/GPLv3.txt for more information. + * Copyright (c) 2010-2018 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. */ +#include <string> + #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <Swiften/Network/HostAddress.h> -#include <string> +#include <Swiften/Network/HostAddressPort.h> using namespace Swift; class HostAddressTest : public CppUnit::TestFixture { - CPPUNIT_TEST_SUITE(HostAddressTest); - CPPUNIT_TEST(testConstructor); - CPPUNIT_TEST(testConstructor_Invalid); - CPPUNIT_TEST(testConstructor_InvalidString); - CPPUNIT_TEST(testToString); - CPPUNIT_TEST(testToString_IPv6); - CPPUNIT_TEST(testToString_Invalid); - CPPUNIT_TEST_SUITE_END(); - - public: - void testConstructor() { - HostAddress testling("192.168.1.254"); - - CPPUNIT_ASSERT_EQUAL(std::string("192.168.1.254"), testling.toString()); - CPPUNIT_ASSERT(testling.isValid()); - } - - void testConstructor_Invalid() { - HostAddress testling; - - CPPUNIT_ASSERT(!testling.isValid()); - } - - void testConstructor_InvalidString() { - HostAddress testling("invalid"); - - CPPUNIT_ASSERT(!testling.isValid()); - } - - void testToString() { - unsigned char address[4] = {10, 0, 1, 253}; - HostAddress testling(address, 4); - - CPPUNIT_ASSERT_EQUAL(std::string("10.0.1.253"), testling.toString()); - } - - void testToString_IPv6() { - unsigned char address[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17}; - HostAddress testling(address, 16); - - CPPUNIT_ASSERT_EQUAL(std::string("102:304:506:708:90a:b0c:d0e:f11"), testling.toString()); - } - - void testToString_Invalid() { - HostAddress testling; - - CPPUNIT_ASSERT_EQUAL(std::string("0.0.0.0"), testling.toString()); - } + CPPUNIT_TEST_SUITE(HostAddressTest); + CPPUNIT_TEST(testConstructor); + CPPUNIT_TEST(testConstructor_Invalid); + CPPUNIT_TEST(testConstructor_InvalidString); + CPPUNIT_TEST(testToString); + CPPUNIT_TEST(testToString_IPv6); + CPPUNIT_TEST(testToString_Invalid); + CPPUNIT_TEST(testComparison); + CPPUNIT_TEST_SUITE_END(); + + public: + void testConstructor() { + auto testling = HostAddress::fromString("192.168.1.254"); + + CPPUNIT_ASSERT_EQUAL(std::string("192.168.1.254"), testling->toString()); + CPPUNIT_ASSERT(testling->isValid()); + } + + void testConstructor_Invalid() { + HostAddress testling; + + CPPUNIT_ASSERT(!testling.isValid()); + } + + void testConstructor_InvalidString() { + auto testling = HostAddress::fromString("invalid"); + + CPPUNIT_ASSERT(!testling); + } + + void testToString() { + unsigned char address[4] = {10, 0, 1, 253}; + HostAddress testling(address, 4); + + CPPUNIT_ASSERT_EQUAL(std::string("10.0.1.253"), testling.toString()); + } + + void testToString_IPv6() { + unsigned char address[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17}; + HostAddress testling(address, 16); + + CPPUNIT_ASSERT_EQUAL(std::string("102:304:506:708:90a:b0c:d0e:f11"), testling.toString()); + } + + void testToString_Invalid() { + HostAddress testling; + + CPPUNIT_ASSERT_EQUAL(std::string("0.0.0.0"), testling.toString()); + } + + void testComparison() { + auto ha127_0_0_1 = *HostAddress::fromString("127.0.0.1"); + auto ha127_0_0_2 = *HostAddress::fromString("127.0.0.2"); + auto ha127_0_1_0 = *HostAddress::fromString("127.0.1.0"); + + CPPUNIT_ASSERT(ha127_0_0_1 < ha127_0_0_2); + CPPUNIT_ASSERT(ha127_0_0_2 < ha127_0_1_0); + CPPUNIT_ASSERT(!(ha127_0_0_1 < ha127_0_0_1)); + CPPUNIT_ASSERT(!(ha127_0_0_2 < ha127_0_0_1)); + CPPUNIT_ASSERT(!(ha127_0_0_2 == ha127_0_0_1)); + CPPUNIT_ASSERT(ha127_0_0_1 == ha127_0_0_1); + + auto hap_127_0_0_1__1 = HostAddressPort(ha127_0_0_1, 1); + auto hap_127_0_0_1__2 = HostAddressPort(ha127_0_0_1, 2); + auto hap_127_0_0_2__1 = HostAddressPort(ha127_0_0_2, 1); + auto hap_127_0_0_2__2 = HostAddressPort(ha127_0_0_2, 2); + + CPPUNIT_ASSERT(hap_127_0_0_1__1 < hap_127_0_0_1__2); + CPPUNIT_ASSERT(!(hap_127_0_0_1__1 < hap_127_0_0_1__1)); + CPPUNIT_ASSERT(!(hap_127_0_0_1__1 == hap_127_0_0_1__2)); + CPPUNIT_ASSERT(hap_127_0_0_1__1 == hap_127_0_0_1__1); + CPPUNIT_ASSERT(!(hap_127_0_0_1__2 == hap_127_0_0_1__1)); + CPPUNIT_ASSERT(hap_127_0_0_1__2 < hap_127_0_0_2__1); + CPPUNIT_ASSERT(hap_127_0_0_2__1 < hap_127_0_0_2__2); + } }; CPPUNIT_TEST_SUITE_REGISTRATION(HostAddressTest); diff --git a/Swiften/Network/UnixNetworkEnvironment.cpp b/Swiften/Network/UnixNetworkEnvironment.cpp index e1fdc88..dc90589 100644 --- a/Swiften/Network/UnixNetworkEnvironment.cpp +++ b/Swiften/Network/UnixNetworkEnvironment.cpp @@ -4,60 +4,68 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include <Swiften/Network/UnixNetworkEnvironment.h> +#include <map> #include <string> #include <vector> -#include <map> + #include <boost/optional.hpp> -#include <sys/types.h> -#include <sys/socket.h> +#include <boost/signals2.hpp> + #include <arpa/inet.h> #include <net/if.h> +#include <sys/socket.h> +#include <sys/types.h> #ifndef __ANDROID__ #include <ifaddrs.h> #endif -#include <Swiften/Base/boost_bsignals.h> #include <Swiften/Network/HostAddress.h> #include <Swiften/Network/NetworkInterface.h> namespace Swift { std::vector<NetworkInterface> UnixNetworkEnvironment::getNetworkInterfaces() const { - std::map<std::string, NetworkInterface> interfaces; + std::map<std::string, NetworkInterface> interfaces; #ifndef __ANDROID__ - ifaddrs* addrs = 0; - int ret = getifaddrs(&addrs); - if (ret != 0) { - return std::vector<NetworkInterface>(); - } - - for (ifaddrs* a = addrs; a != 0; a = a->ifa_next) { - std::string name(a->ifa_name); - boost::optional<HostAddress> address; - if (a->ifa_addr->sa_family == PF_INET) { - sockaddr_in* sa = reinterpret_cast<sockaddr_in*>(a->ifa_addr); - address = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin_addr)), 4); - } - else if (a->ifa_addr->sa_family == PF_INET6) { - sockaddr_in6* sa = reinterpret_cast<sockaddr_in6*>(a->ifa_addr); - address = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin6_addr)), 16); - } - if (address && !address->isLocalhost()) { - std::map<std::string, NetworkInterface>::iterator i = interfaces.insert(std::make_pair(name, NetworkInterface(name, a->ifa_flags & IFF_LOOPBACK))).first; - i->second.addAddress(*address); - } - } - - freeifaddrs(addrs); + ifaddrs* addrs = nullptr; + int ret = getifaddrs(&addrs); + if (ret != 0) { + return std::vector<NetworkInterface>(); + } + + for (ifaddrs* a = addrs; a != nullptr; a = a->ifa_next) { + std::string name(a->ifa_name); + boost::optional<HostAddress> address; + if (a->ifa_addr->sa_family == PF_INET) { + sockaddr_in* sa = reinterpret_cast<sockaddr_in*>(a->ifa_addr); + address = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin_addr)), 4); + } + else if (a->ifa_addr->sa_family == PF_INET6) { + sockaddr_in6* sa = reinterpret_cast<sockaddr_in6*>(a->ifa_addr); + address = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin6_addr)), 16); + } + if (address && !address->isLocalhost()) { + std::map<std::string, NetworkInterface>::iterator i = interfaces.insert(std::make_pair(name, NetworkInterface(name, a->ifa_flags & IFF_LOOPBACK))).first; + i->second.addAddress(*address); + } + } + + freeifaddrs(addrs); #endif - std::vector<NetworkInterface> result; - for (std::map<std::string,NetworkInterface>::const_iterator i = interfaces.begin(); i != interfaces.end(); ++i) { - result.push_back(i->second); - } - return result; + std::vector<NetworkInterface> result; + for (std::map<std::string,NetworkInterface>::const_iterator i = interfaces.begin(); i != interfaces.end(); ++i) { + result.push_back(i->second); + } + return result; } } diff --git a/Swiften/Network/UnixNetworkEnvironment.h b/Swiften/Network/UnixNetworkEnvironment.h index 8b51cae..89a01ab 100644 --- a/Swiften/Network/UnixNetworkEnvironment.h +++ b/Swiften/Network/UnixNetworkEnvironment.h @@ -4,11 +4,17 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <vector> -#include <Swiften/Base/boost_bsignals.h> +#include <boost/signals2.hpp> #include <Swiften/Network/NetworkEnvironment.h> #include <Swiften/Network/NetworkInterface.h> @@ -16,8 +22,8 @@ namespace Swift { class UnixNetworkEnvironment : public NetworkEnvironment { - public: - std::vector<NetworkInterface> getNetworkInterfaces() const; + public: + std::vector<NetworkInterface> getNetworkInterfaces() const; }; } diff --git a/Swiften/Network/UnixProxyProvider.cpp b/Swiften/Network/UnixProxyProvider.cpp index 4ca9311..854d501 100644 --- a/Swiften/Network/UnixProxyProvider.cpp +++ b/Swiften/Network/UnixProxyProvider.cpp @@ -4,12 +4,14 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ -#include <stdio.h> -#include <stdlib.h> -#include <iostream> +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ -#include <Swiften/Base/foreach.h> #include <Swiften/Network/UnixProxyProvider.h> + #if defined(HAVE_GCONF) # include "Swiften/Network/GConfProxyProvider.h" #endif @@ -17,48 +19,50 @@ namespace Swift { UnixProxyProvider::UnixProxyProvider() : - gconfProxyProvider(0), - environmentProxyProvider() + gconfProxyProvider(nullptr), + environmentProxyProvider() { #if defined(HAVE_GCONF) - gconfProxyProvider = new GConfProxyProvider(); + gconfProxyProvider = new GConfProxyProvider(); #endif } UnixProxyProvider::~UnixProxyProvider() { #if defined(HAVE_GCONF) - delete gconfProxyProvider; + delete gconfProxyProvider; +#else + (void)gconfProxyProvider; #endif } HostAddressPort UnixProxyProvider::getSOCKS5Proxy() const { - HostAddressPort proxy; + HostAddressPort proxy; #if defined(HAVE_GCONF) - proxy = gconfProxyProvider->getSOCKS5Proxy(); - if(proxy.isValid()) { - return proxy; - } + proxy = gconfProxyProvider->getSOCKS5Proxy(); + if(proxy.isValid()) { + return proxy; + } #endif - proxy = environmentProxyProvider.getSOCKS5Proxy(); - if(proxy.isValid()) { - return proxy; - } - return HostAddressPort(HostAddress(), 0); + proxy = environmentProxyProvider.getSOCKS5Proxy(); + if(proxy.isValid()) { + return proxy; + } + return HostAddressPort(HostAddress(), 0); } HostAddressPort UnixProxyProvider::getHTTPConnectProxy() const { - HostAddressPort proxy; + HostAddressPort proxy; #if defined(HAVE_GCONF) - proxy = gconfProxyProvider->getHTTPConnectProxy(); - if(proxy.isValid()) { - return proxy; - } + proxy = gconfProxyProvider->getHTTPConnectProxy(); + if(proxy.isValid()) { + return proxy; + } #endif - proxy = environmentProxyProvider.getHTTPConnectProxy(); - if(proxy.isValid()) { - return proxy; - } - return HostAddressPort(HostAddress(), 0); + proxy = environmentProxyProvider.getHTTPConnectProxy(); + if(proxy.isValid()) { + return proxy; + } + return HostAddressPort(HostAddress(), 0); } diff --git a/Swiften/Network/UnixProxyProvider.h b/Swiften/Network/UnixProxyProvider.h index 37a4d05..1721480 100644 --- a/Swiften/Network/UnixProxyProvider.h +++ b/Swiften/Network/UnixProxyProvider.h @@ -4,23 +4,29 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <Swiften/Network/EnvironmentProxyProvider.h> namespace Swift { - class GConfProxyProvider; + class GConfProxyProvider; - class UnixProxyProvider : public ProxyProvider { - public: - UnixProxyProvider(); - ~UnixProxyProvider(); + class UnixProxyProvider : public ProxyProvider { + public: + UnixProxyProvider(); + virtual ~UnixProxyProvider(); - virtual HostAddressPort getHTTPConnectProxy() const; - virtual HostAddressPort getSOCKS5Proxy() const; + virtual HostAddressPort getHTTPConnectProxy() const; + virtual HostAddressPort getSOCKS5Proxy() const; - private: - GConfProxyProvider* gconfProxyProvider; - EnvironmentProxyProvider environmentProxyProvider; - }; + private: + GConfProxyProvider* gconfProxyProvider; + EnvironmentProxyProvider environmentProxyProvider; + }; } diff --git a/Swiften/Network/WindowsNetworkEnvironment.cpp b/Swiften/Network/WindowsNetworkEnvironment.cpp index e2d1966..e90a5c6 100644 --- a/Swiften/Network/WindowsNetworkEnvironment.cpp +++ b/Swiften/Network/WindowsNetworkEnvironment.cpp @@ -4,63 +4,70 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #include <Swiften/Network/WindowsNetworkEnvironment.h> +#include <map> #include <string> #include <vector> -#include <map> + #include <boost/optional.hpp> -#include <Swiften/Network/HostAddress.h> -#include <Swiften/Network/NetworkInterface.h> -#include <Swiften/Base/foreach.h> -#include <Swiften/Base/ByteArray.h> -#include <winsock2.h> #include <iphlpapi.h> +#include <winsock2.h> + +#include <Swiften/Base/ByteArray.h> +#include <Swiften/Network/HostAddress.h> +#include <Swiften/Network/NetworkInterface.h> namespace Swift { std::vector<NetworkInterface> WindowsNetworkEnvironment::getNetworkInterfaces() const { - std::vector<NetworkInterface> result; + std::vector<NetworkInterface> result; - ByteArray adapters; - ULONG bufferSize = 0; - ULONG ret; - ULONG flags = GAA_FLAG_INCLUDE_ALL_INTERFACES | GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER; - while ((ret = GetAdaptersAddresses(AF_UNSPEC, flags, NULL, reinterpret_cast<IP_ADAPTER_ADDRESSES*>(vecptr(adapters)), &bufferSize)) == ERROR_BUFFER_OVERFLOW) { - adapters.resize(bufferSize); - }; - if (ret != ERROR_SUCCESS) { - return result; - } + ByteArray adapters; + ULONG bufferSize = 0; + ULONG ret; + ULONG flags = GAA_FLAG_INCLUDE_ALL_INTERFACES | GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER; + while ((ret = GetAdaptersAddresses(AF_UNSPEC, flags, NULL, reinterpret_cast<IP_ADAPTER_ADDRESSES*>(vecptr(adapters)), &bufferSize)) == ERROR_BUFFER_OVERFLOW) { + adapters.resize(bufferSize); + }; + if (ret != ERROR_SUCCESS) { + return result; + } - std::map<std::string,NetworkInterface> interfaces; - for (IP_ADAPTER_ADDRESSES* adapter = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(vecptr(adapters)); adapter; adapter = adapter->Next) { - std::string name(adapter->AdapterName); - if (adapter->OperStatus != IfOperStatusUp) { - continue; - } - for (IP_ADAPTER_UNICAST_ADDRESS* address = adapter->FirstUnicastAddress; address; address = address->Next) { - boost::optional<HostAddress> hostAddress; - if (address->Address.lpSockaddr->sa_family == PF_INET) { - sockaddr_in* sa = reinterpret_cast<sockaddr_in*>(address->Address.lpSockaddr); - hostAddress = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin_addr)), 4); - } - else if (address->Address.lpSockaddr->sa_family == PF_INET6) { - sockaddr_in6* sa = reinterpret_cast<sockaddr_in6*>(address->Address.lpSockaddr); - hostAddress = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin6_addr)), 16); - } - if (hostAddress && !hostAddress->isLocalhost()) { - std::map<std::string, NetworkInterface>::iterator i = interfaces.insert(std::make_pair(name, NetworkInterface(name, false))).first; - i->second.addAddress(*hostAddress); - } - } - } + std::map<std::string,NetworkInterface> interfaces; + for (IP_ADAPTER_ADDRESSES* adapter = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(vecptr(adapters)); adapter; adapter = adapter->Next) { + std::string name(adapter->AdapterName); + if (adapter->OperStatus != IfOperStatusUp) { + continue; + } + for (IP_ADAPTER_UNICAST_ADDRESS* address = adapter->FirstUnicastAddress; address; address = address->Next) { + boost::optional<HostAddress> hostAddress; + if (address->Address.lpSockaddr->sa_family == PF_INET) { + sockaddr_in* sa = reinterpret_cast<sockaddr_in*>(address->Address.lpSockaddr); + hostAddress = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin_addr)), 4); + } + else if (address->Address.lpSockaddr->sa_family == PF_INET6) { + sockaddr_in6* sa = reinterpret_cast<sockaddr_in6*>(address->Address.lpSockaddr); + hostAddress = HostAddress(reinterpret_cast<const unsigned char*>(&(sa->sin6_addr)), 16); + } + if (hostAddress && !hostAddress->isLocalhost()) { + std::map<std::string, NetworkInterface>::iterator i = interfaces.insert(std::make_pair(name, NetworkInterface(name, false))).first; + i->second.addAddress(*hostAddress); + } + } + } - for (std::map<std::string,NetworkInterface>::const_iterator i = interfaces.begin(); i != interfaces.end(); ++i) { - result.push_back(i->second); - } - return result; + for (const auto& interface : interfaces) { + result.push_back(interface.second); + } + return result; } } diff --git a/Swiften/Network/WindowsNetworkEnvironment.h b/Swiften/Network/WindowsNetworkEnvironment.h index 18996ed..81de826 100644 --- a/Swiften/Network/WindowsNetworkEnvironment.h +++ b/Swiften/Network/WindowsNetworkEnvironment.h @@ -4,17 +4,24 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <vector> +#include <boost/signals2.hpp> + #include <Swiften/Base/API.h> -#include <Swiften/Base/boost_bsignals.h> #include <Swiften/Network/NetworkEnvironment.h> namespace Swift { - class SWIFTEN_API WindowsNetworkEnvironment : public NetworkEnvironment { - public: - std::vector<NetworkInterface> getNetworkInterfaces() const; - }; + class SWIFTEN_API WindowsNetworkEnvironment : public NetworkEnvironment { + public: + std::vector<NetworkInterface> getNetworkInterfaces() const; + }; } diff --git a/Swiften/Network/WindowsProxyProvider.cpp b/Swiften/Network/WindowsProxyProvider.cpp index 3ae43e0..13fdb25 100644 --- a/Swiften/Network/WindowsProxyProvider.cpp +++ b/Swiften/Network/WindowsProxyProvider.cpp @@ -4,110 +4,122 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + +#include <Swiften/Network/WindowsProxyProvider.h> + +#include <math.h> #include <stdio.h> #include <stdlib.h> -#include <math.h> -#include <iostream> -#include <boost/lexical_cast.hpp> -#include <Swiften/Base/log.h> -#include <Swiften/Base/foreach.h> -#include <Swiften/Network/WindowsProxyProvider.h> -#include <Swiften/Base/ByteArray.h> +#include <boost/lexical_cast.hpp> +#include <boost/numeric/conversion/cast.hpp> #include <windows.h> +#include <Swiften/Base/ByteArray.h> +#include <Swiften/Base/Log.h> +#include <Swiften/Network/HostAddress.h> +#include <Swiften/Network/HostAddressPort.h> + namespace Swift { WindowsProxyProvider::WindowsProxyProvider() : ProxyProvider() { - HKEY hKey = (HKEY)INVALID_HANDLE_VALUE; - long result; - - result = RegOpenKeyEx(HKEY_CURRENT_USER, "Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings", 0, KEY_READ, &hKey); - if (result == ERROR_SUCCESS && hKey != INVALID_HANDLE_VALUE && proxyEnabled(hKey)) { - DWORD dataType = REG_SZ; - DWORD dataSize = 0; - ByteArray dataBuffer; - - result = RegQueryValueEx(hKey, "ProxyServer", NULL, &dataType, NULL, &dataSize); - if(result != ERROR_SUCCESS) { - return; - } - dataBuffer.resize(dataSize); - result = RegQueryValueEx(hKey, "ProxyServer", NULL, &dataType, reinterpret_cast<BYTE*>(vecptr(dataBuffer)), &dataSize); - if(result == ERROR_SUCCESS) { - std::vector<std::string> proxies = String::split(byteArrayToString(dataBuffer), ';'); - std::pair<std::string, std::string> protocolAndProxy; - foreach(std::string proxy, proxies) { - if(proxy.find('=') != std::string::npos) { - protocolAndProxy = String::getSplittedAtFirst(proxy, '='); - SWIFT_LOG(debug) << "Found proxy: " << protocolAndProxy.first << " => " << protocolAndProxy.second << std::endl; - if(protocolAndProxy.first.compare("socks") == 0) { - socksProxy = getAsHostAddressPort(protocolAndProxy.second); - } - else if (protocolAndProxy.first.compare("http") == 0) { - httpProxy = getAsHostAddressPort(protocolAndProxy.second); - } - } - } - } - } + HKEY hKey = (HKEY)INVALID_HANDLE_VALUE; + long result; + + result = RegOpenKeyEx(HKEY_CURRENT_USER, "Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings", 0, KEY_READ, &hKey); + if (result == ERROR_SUCCESS && hKey != INVALID_HANDLE_VALUE && proxyEnabled(hKey)) { + DWORD dataType = REG_SZ; + DWORD dataSize = 0; + ByteArray dataBuffer; + + result = RegQueryValueEx(hKey, "ProxyServer", NULL, &dataType, NULL, &dataSize); + if(result != ERROR_SUCCESS) { + return; + } + dataBuffer.resize(dataSize); + result = RegQueryValueEx(hKey, "ProxyServer", NULL, &dataType, reinterpret_cast<BYTE*>(vecptr(dataBuffer)), &dataSize); + if(result == ERROR_SUCCESS) { + std::vector<std::string> proxies = String::split(byteArrayToString(dataBuffer), ';'); + std::pair<std::string, std::string> protocolAndProxy; + for(auto&& proxy : proxies) { + if(proxy.find('=') != std::string::npos) { + protocolAndProxy = String::getSplittedAtFirst(proxy, '='); + SWIFT_LOG(debug) << "Found proxy: " << protocolAndProxy.first << " => " << protocolAndProxy.second; + if(protocolAndProxy.first.compare("socks") == 0) { + socksProxy = getAsHostAddressPort(protocolAndProxy.second); + } + else if (protocolAndProxy.first.compare("http") == 0) { + httpProxy = getAsHostAddressPort(protocolAndProxy.second); + } + } + } + } + } } HostAddressPort WindowsProxyProvider::getHTTPConnectProxy() const { - return httpProxy; + return httpProxy; } HostAddressPort WindowsProxyProvider::getSOCKS5Proxy() const { - return socksProxy; + return socksProxy; } HostAddressPort WindowsProxyProvider::getAsHostAddressPort(std::string proxy) { - HostAddressPort ret(HostAddress(), 0); - - try { - std::pair<std::string, std::string> tmp; - int port = 0; - tmp = String::getSplittedAtFirst(proxy, ':'); - // .c_str() is needed as tmp.second can include a \0 char which will end in an exception of the lexical cast. - // with .c_str() the \0 will not be part of the string which is to be casted - port = boost::lexical_cast<int> (tmp.second.c_str()); - ret = HostAddressPort(HostAddress(tmp.first), port); - } - catch(...) { - std::cerr << "Exception occured while parsing windows proxy \"getHostAddressPort\"." << std::endl; - } - - return ret; + HostAddressPort ret(HostAddress(), 0); + + try { + std::pair<std::string, std::string> tmp; + unsigned short port = 0; + tmp = String::getSplittedAtFirst(proxy, ':'); + // .c_str() is needed as tmp.second can include a \0 char which will end in an exception of the lexical cast. + // with .c_str() the \0 will not be part of the string which is to be casted + port = boost::numeric_cast<unsigned short>(boost::lexical_cast<int> (tmp.second.c_str())); + ret = HostAddressPort(HostAddress::fromString(tmp.first).get(), port); + } + catch(...) { + SWIFT_LOG(error) << "Exception occured while parsing windows proxy \"getHostAddressPort\"."; + } + + return ret; } bool WindowsProxyProvider::proxyEnabled(HKEY hKey) const { - bool ret = false; - long result; - DWORD dataType = REG_DWORD; - DWORD dataSize = 0; - DWORD data = 0; - ByteArray dataBuffer; - - if(hKey == INVALID_HANDLE_VALUE) - return ret; - - result = RegQueryValueEx(hKey, "ProxyEnable", NULL, &dataType, NULL, &dataSize); - if(result != ERROR_SUCCESS) - return ret; - - dataBuffer.resize(dataSize); - result = RegQueryValueEx(hKey, "ProxyEnable", NULL, &dataType, reinterpret_cast<BYTE*>(vecptr(dataBuffer)), &dataSize); - if(result != ERROR_SUCCESS) - return ret; - - for(size_t t = 0; t < dataBuffer.size(); t++) { - data += static_cast<int> (dataBuffer[t]) * pow(256, static_cast<double>(t)); - } - return (data == 1); + bool ret = false; + long result; + DWORD dataType = REG_DWORD; + DWORD dataSize = 0; + DWORD data = 0; + ByteArray dataBuffer; + + if(hKey == INVALID_HANDLE_VALUE) { + return ret; + } + + result = RegQueryValueEx(hKey, "ProxyEnable", NULL, &dataType, NULL, &dataSize); + if(result != ERROR_SUCCESS) { + return ret; + } + + dataBuffer.resize(dataSize); + result = RegQueryValueEx(hKey, "ProxyEnable", NULL, &dataType, reinterpret_cast<BYTE*>(vecptr(dataBuffer)), &dataSize); + if(result != ERROR_SUCCESS) { + return ret; + } + + for(size_t t = 0; t < dataBuffer.size(); t++) { + data += static_cast<int> (dataBuffer[t]) * pow(256, static_cast<double>(t)); + } + return (data == 1); } } diff --git a/Swiften/Network/WindowsProxyProvider.h b/Swiften/Network/WindowsProxyProvider.h index 12aa18d..0ca897d 100644 --- a/Swiften/Network/WindowsProxyProvider.h +++ b/Swiften/Network/WindowsProxyProvider.h @@ -4,21 +4,28 @@ * See Documentation/Licenses/BSD-simplified.txt for more information. */ +/* + * Copyright (c) 2016 Isode Limited. + * All rights reserved. + * See the COPYING file for more information. + */ + #pragma once #include <Swiften/Base/API.h> +#include <Swiften/Network/HostAddressPort.h> #include <Swiften/Network/ProxyProvider.h> namespace Swift { - class SWIFTEN_API WindowsProxyProvider : public ProxyProvider { - public: - WindowsProxyProvider(); - virtual HostAddressPort getHTTPConnectProxy() const; - virtual HostAddressPort getSOCKS5Proxy() const; - private: - HostAddressPort getAsHostAddressPort(std::string proxy); - bool proxyEnabled(HKEY hKey) const; - HostAddressPort socksProxy; - HostAddressPort httpProxy; - }; + class SWIFTEN_API WindowsProxyProvider : public ProxyProvider { + public: + WindowsProxyProvider(); + virtual HostAddressPort getHTTPConnectProxy() const; + virtual HostAddressPort getSOCKS5Proxy() const; + private: + HostAddressPort getAsHostAddressPort(std::string proxy); + bool proxyEnabled(HKEY hKey) const; + HostAddressPort socksProxy; + HostAddressPort httpProxy; + }; } |