From fdd8755e2363e8d706a3d0bdc2e71f234abdf829 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= Date: Thu, 12 Nov 2009 19:12:47 +0100 Subject: Refactored DNS handling. Connections now fallback on other DNS entries upon failure, taking into account SRV priorities. diff --git a/BuildTools/Git/Hooks/pre-commit b/BuildTools/Git/Hooks/pre-commit old mode 100644 new mode 100755 diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 85db4ac..6614bf7 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -2,11 +2,11 @@ #include -#include "Swiften/Network/DomainNameResolver.h" #include "Swiften/Network/MainBoostIOServiceThread.h" #include "Swiften/Network/BoostIOServiceThread.h" #include "Swiften/Client/ClientSession.h" #include "Swiften/StreamStack/PlatformTLSLayerFactory.h" +#include "Swiften/Network/Connector.h" #include "Swiften/Network/BoostConnectionFactory.h" #include "Swiften/Network/DomainNameResolveException.h" #include "Swiften/TLS/PKCS12Certificate.h" @@ -33,24 +33,22 @@ bool Client::isAvailable() { } void Client::connect() { - assert(!connection_); - DomainNameResolver resolver; - try { - HostAddressPort remote = resolver.resolve(jid_.getDomain().getUTF8String()); - connection_ = connectionFactory_->createConnection(); - connection_->onConnectFinished.connect(boost::bind(&Client::handleConnectionConnectFinished, this, _1)); - connection_->connect(remote); - } - catch (const DomainNameResolveException& e) { - onError(ClientError::DomainNameResolveError); - } + assert(!connector_); + connector_ = boost::shared_ptr(new Connector(jid_.getDomain(), &resolver_, connectionFactory_)); + connector_->onConnectFinished.connect(boost::bind(&Client::handleConnectorFinished, this, _1)); + connector_->start(); } -void Client::handleConnectionConnectFinished(bool error) { - if (error) { +void Client::handleConnectorFinished(boost::shared_ptr connection) { + // TODO: Add domain name resolver error + connector_.reset(); + if (!connection) { onError(ClientError::ConnectionError); } else { + assert(!connection_); + connection_ = connection; + assert(!sessionStream_); sessionStream_ = boost::shared_ptr(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_)); if (!certificate_.isEmpty()) { @@ -78,6 +76,9 @@ void Client::disconnect() { } void Client::closeConnection() { + if (sessionStream_) { + sessionStream_.reset(); + } if (connection_) { connection_->disconnect(); connection_.reset(); @@ -186,11 +187,11 @@ void Client::handleNeedCredentials() { } void Client::handleDataRead(const String& data) { - onDataRead(data); + onDataRead(data); } void Client::handleDataWritten(const String& data) { - onDataWritten(data); + onDataWritten(data); } } diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index 3f7d350..f09c916 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -4,6 +4,7 @@ #include #include +#include "Swiften/Network/PlatformDomainNameResolver.h" #include "Swiften/Base/Error.h" #include "Swiften/Client/ClientSession.h" #include "Swiften/Client/ClientError.h" @@ -22,6 +23,7 @@ namespace Swift { class ConnectionFactory; class ClientSession; class BasicSessionStream; + class Connector; class Client : public StanzaChannel, public IQRouter, public boost::bsignals::trackable { public: @@ -46,7 +48,7 @@ namespace Swift { boost::signal onDataWritten; private: - void handleConnectionConnectFinished(bool error); + void handleConnectorFinished(boost::shared_ptr); void send(boost::shared_ptr); virtual String getNewIQID(); void handleElement(boost::shared_ptr); @@ -58,9 +60,11 @@ namespace Swift { void closeConnection(); private: + PlatformDomainNameResolver resolver_; JID jid_; String password_; IDGenerator idGenerator_; + boost::shared_ptr connector_; ConnectionFactory* connectionFactory_; TLSLayerFactory* tlsLayerFactory_; FullPayloadParserFactoryCollection payloadParserFactories_; diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp new file mode 100644 index 0000000..5b4fe22 --- /dev/null +++ b/Swiften/Network/Connector.cpp @@ -0,0 +1,48 @@ +#include "Swiften/Network/Connector.h" + +#include + +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/DomainNameResolveException.h" + +namespace Swift { + +Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory) { +} + +void Connector::start() { + assert(!currentConnection); + try { + std::vector resolveResult = resolver->resolve(hostname.getUTF8String()); + resolvedHosts = std::deque(resolveResult.begin(), resolveResult.end()); + tryNextHostname(); + } + catch (const DomainNameResolveException&) { + onConnectFinished(boost::shared_ptr()); + } +} + +void Connector::tryNextHostname() { + if (resolvedHosts.empty()) { + onConnectFinished(boost::shared_ptr()); + } + else { + HostAddressPort remote = resolvedHosts.front(); + resolvedHosts.pop_front(); + currentConnection = connectionFactory->createConnection(); + currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1)); + currentConnection->connect(remote); + } +} + +void Connector::handleConnectionConnectFinished(bool error) { + if (error) { + tryNextHostname(); + } + else { + onConnectFinished(currentConnection); + } +} + +}; diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h new file mode 100644 index 0000000..084c416 --- /dev/null +++ b/Swiften/Network/Connector.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +#include "Swiften/Network/Connection.h" +#include "Swiften/Base/String.h" + +namespace Swift { + class DomainNameResolver; + class ConnectionFactory; + + class Connector : public boost::bsignals::trackable { + public: + Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*); + + void start(); + + boost::signal)> onConnectFinished; + + private: + void tryNextHostname(); + void handleConnectionConnectFinished(bool error); + + private: + String hostname; + DomainNameResolver* resolver; + ConnectionFactory* connectionFactory; + std::deque resolvedHosts; + boost::shared_ptr currentConnection; + }; + +}; diff --git a/Swiften/Network/DomainNameResolver.cpp b/Swiften/Network/DomainNameResolver.cpp index 44b3ecf..907dfc9 100644 --- a/Swiften/Network/DomainNameResolver.cpp +++ b/Swiften/Network/DomainNameResolver.cpp @@ -1,176 +1,8 @@ #include "Swiften/Network/DomainNameResolver.h" -#include "Swiften/Base/Platform.h" - -#include -#include -#include -#ifdef SWIFTEN_PLATFORM_WINDOWS -#undef UNICODE -#include -#include -#ifndef DNS_TYPE_SRV -#define DNS_TYPE_SRV 33 -#endif -#else -#include -#include -#include -#endif - -#include "Swiften/Network/DomainNameResolveException.h" -#include "Swiften/Base/String.h" -#include "Swiften/Base/ByteArray.h" namespace Swift { -DomainNameResolver::DomainNameResolver() { -} - DomainNameResolver::~DomainNameResolver() { } -HostAddressPort DomainNameResolver::resolve(const String& domain) { - char* output; - if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { - std::string outputString(output); - free(output); - return resolveDomain(outputString); - } - else { - return resolveDomain(domain.getUTF8String()); - } -} - -HostAddressPort DomainNameResolver::resolveDomain(const std::string& domain) { - try { - return resolveXMPPService(domain); - } - catch (const DomainNameResolveException&) { - } - return HostAddressPort(resolveHostName(domain), 5222); -} - -HostAddressPort DomainNameResolver::resolveXMPPService(const std::string& domain) { - std::string srvQuery = "_xmpp-client._tcp." + domain; - -#if defined(SWIFTEN_PLATFORM_WINDOWS) - DNS_RECORD* responses; - // FIXME: This conversion doesn't work if unicode is deffed above - if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { - throw DomainNameResolveException(); - } - - DNS_RECORD* currentEntry = responses; - while (currentEntry) { - if (currentEntry->wType == DNS_TYPE_SRV) { - int port = currentEntry->Data.SRV.wPort; - try { - // The pNameTarget is actually a PCWSTR, so I would have expected this - // conversion to not work at all, but it does. - // Actually, it doesn't. Fix this and remove explicit cast - // Remove unicode undef above as well - std::string hostname((const char*) currentEntry->Data.SRV.pNameTarget); - HostAddress address = resolveHostName(hostname); - DnsRecordListFree(responses, DnsFreeRecordList); - return HostAddressPort(address, port); - } - catch (const DomainNameResolveException&) { - } - } - currentEntry = currentEntry->pNext; - } - DnsRecordListFree(responses, DnsFreeRecordList); - -#else - - ByteArray response; - response.resize(NS_PACKETSZ); - int responseLength = res_query(const_cast(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast(response.getData()), response.getSize()); - if (responseLength == -1) { - throw DomainNameResolveException(); - } - - // Parse header - HEADER* header = reinterpret_cast(response.getData()); - unsigned char* messageStart = reinterpret_cast(response.getData()); - unsigned char* messageEnd = messageStart + responseLength; - unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; - - // Skip over the queries - int queriesCount = ntohs(header->qdcount); - while (queriesCount > 0) { - int entryLength = dn_skipname(currentEntry, messageEnd); - if (entryLength < 0) { - throw DomainNameResolveException(); - } - currentEntry += entryLength + NS_QFIXEDSZ; - queriesCount--; - } - - // Process the SRV answers - int answersCount = ntohs(header->ancount); - while (answersCount > 0) { - int entryLength = dn_skipname(currentEntry, messageEnd); - currentEntry += entryLength; - currentEntry += NS_RRFIXEDSZ; - - // Uninteresting information - currentEntry += 2; // PRIORITY - currentEntry += 2; // WEIGHT - - // Port - if (currentEntry >= messageEnd) { - throw DomainNameResolveException(); - } - int port = ns_get16(currentEntry); - currentEntry += 2; - - // Hostname - if (currentEntry >= messageEnd) { - throw DomainNameResolveException(); - } - ByteArray entry; - entry.resize(NS_MAXDNAME); - entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize()); - if (entryLength < 0) { - throw DomainNameResolveException(); - } - try { - // Resolve the hostname - std::string hostname(entry.getData(), entryLength); - HostAddress address = resolveHostName(hostname); - return HostAddressPort(address, port); - } - catch (const DomainNameResolveException&) { - } - currentEntry += entryLength; - answersCount--; - } -#endif - - throw DomainNameResolveException(); -} - -HostAddress DomainNameResolver::resolveHostName(const std::string& hostname) { - boost::asio::io_service ioService; - boost::asio::ip::tcp::resolver resolver(ioService); - boost::asio::ip::tcp::resolver::query query(hostname, "5222"); - try { - boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); - if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { - throw DomainNameResolveException(); - } - boost::asio::ip::address address = (*endpointIterator).endpoint().address(); - if (address.is_v4()) { - return HostAddress(&address.to_v4().to_bytes()[0], 4); - } - else { - return HostAddress(&address.to_v6().to_bytes()[0], 16); - } - } - catch (...) { - throw DomainNameResolveException(); - } -} - } diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h index c7736b1..5c83622 100644 --- a/Swiften/Network/DomainNameResolver.h +++ b/Swiften/Network/DomainNameResolver.h @@ -1,10 +1,7 @@ -#ifndef SWIFTEN_DOMAINNAMERESOLVER_H -#define SWIFTEN_DOMAINNAMERESOLVER_H +#pragma once -#include +#include -#include "Swiften/Base/String.h" -#include "Swiften/Network/HostAddress.h" #include "Swiften/Network/HostAddressPort.h" namespace Swift { @@ -12,16 +9,8 @@ namespace Swift { class DomainNameResolver { public: - DomainNameResolver(); virtual ~DomainNameResolver(); - HostAddressPort resolve(const String& domain); - - private: - virtual HostAddressPort resolveDomain(const std::string& domain); - HostAddressPort resolveXMPPService(const std::string& domain); - HostAddress resolveHostName(const std::string& hostName); + virtual std::vector resolve(const String& domain) = 0; }; } - -#endif diff --git a/Swiften/Network/DummyConnection.h b/Swiften/Network/DummyConnection.h new file mode 100644 index 0000000..11281b3 --- /dev/null +++ b/Swiften/Network/DummyConnection.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include + +#include "Swiften/Network/Connection.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/EventOwner.h" + +namespace Swift { + class DummyConnection : + public Connection, + public EventOwner, + public boost::enable_shared_from_this { + + void listen() { + assert(false); + } + + void connect(const HostAddressPort&) { + assert(false); + } + + void disconnect() { + assert(false); + } + + void write(const ByteArray& data) { + onDataWritten(data); + } + + void receive(const ByteArray& data) { + MainEventLoop::postEvent(boost::bind( + boost::ref(onDataRead), ByteArray(data)), shared_from_this()); + } + + boost::signal onDataWritten; + }; +} diff --git a/Swiften/Network/HostAddress.h b/Swiften/Network/HostAddress.h index bf6d2f8..11f8a2b 100644 --- a/Swiften/Network/HostAddress.h +++ b/Swiften/Network/HostAddress.h @@ -18,6 +18,10 @@ namespace Swift { std::string toString() const; + bool operator==(const HostAddress& o) const { + return address_ == o.address_; + } + private: std::vector address_; }; diff --git a/Swiften/Network/HostAddressPort.h b/Swiften/Network/HostAddressPort.h index 8668ae4..d632058 100644 --- a/Swiften/Network/HostAddressPort.h +++ b/Swiften/Network/HostAddressPort.h @@ -17,6 +17,10 @@ namespace Swift { return port_; } + bool operator==(const HostAddressPort& o) const { + return address_ == o.address_ && port_ == o.port_; + } + private: HostAddress address_; int port_; diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp new file mode 100644 index 0000000..3803e7c --- /dev/null +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -0,0 +1,200 @@ +#include "Swiften/Network/PlatformDomainNameResolver.h" + +#include "Swiften/Base/Platform.h" +#include "Swiften/Base/foreach.h" + +#include +#include +#include +#ifdef SWIFTEN_PLATFORM_WINDOWS +#undef UNICODE +#include +#include +#ifndef DNS_TYPE_SRV +#define DNS_TYPE_SRV 33 +#endif +#else +#include +#include +#include +#endif +#include + +#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Base/String.h" +#include "Swiften/Base/ByteArray.h" +#include "Swiften/Network/SRVRecord.h" +#include "Swiften/Network/SRVRecordPriorityComparator.h" + +namespace Swift { + +PlatformDomainNameResolver::PlatformDomainNameResolver() { +} + +std::vector PlatformDomainNameResolver::resolve(const String& domain) { + char* output; + if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { + std::string outputString(output); + free(output); + return resolveDomain(outputString); + } + else { + return resolveDomain(domain.getUTF8String()); + } +} + +std::vector PlatformDomainNameResolver::resolveDomain(const std::string& domain) { + try { + return resolveXMPPService(domain); + } + catch (const DomainNameResolveException&) { + } + std::vector result; + result.push_back(HostAddressPort(resolveHostName(domain), 5222)); + return result; +} + +std::vector PlatformDomainNameResolver::resolveXMPPService(const std::string& domain) { + std::vector records; + + std::string srvQuery = "_xmpp-client._tcp." + domain; + +#if defined(SWIFTEN_PLATFORM_WINDOWS) + DNS_RECORD* responses; + // FIXME: This conversion doesn't work if unicode is deffed above + if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { + throw DomainNameResolveException(); + } + + DNS_RECORD* currentEntry = responses; + while (currentEntry) { + if (currentEntry->wType == DNS_TYPE_SRV) { + SRVRecord record; + record.priority = currentEntry->Data.SRV.wPriority; + record.weight = currentEntry->Data.SRV.wWeight; + record.port = currentEntry->Data.SRV.wPort; + + // The pNameTarget is actually a PCWSTR, so I would have expected this + // conversion to not work at all, but it does. + // Actually, it doesn't. Fix this and remove explicit cast + // Remove unicode undef above as well + record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget); + records.push_back(record); + } + currentEntry = currentEntry->pNext; + } + DnsRecordListFree(responses, DnsFreeRecordList); + +#else + + ByteArray response; + response.resize(NS_PACKETSZ); + int responseLength = res_query(const_cast(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast(response.getData()), response.getSize()); + if (responseLength == -1) { + throw DomainNameResolveException(); + } + + // Parse header + HEADER* header = reinterpret_cast(response.getData()); + unsigned char* messageStart = reinterpret_cast(response.getData()); + unsigned char* messageEnd = messageStart + responseLength; + unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; + + // Skip over the queries + int queriesCount = ntohs(header->qdcount); + while (queriesCount > 0) { + int entryLength = dn_skipname(currentEntry, messageEnd); + if (entryLength < 0) { + throw DomainNameResolveException(); + } + currentEntry += entryLength + NS_QFIXEDSZ; + queriesCount--; + } + + // Process the SRV answers + int answersCount = ntohs(header->ancount); + while (answersCount > 0) { + SRVRecord record; + + int entryLength = dn_skipname(currentEntry, messageEnd); + currentEntry += entryLength; + currentEntry += NS_RRFIXEDSZ; + + // Priority + if (currentEntry + 2 >= messageEnd) { + throw DomainNameResolveException(); + } + record.priority = ns_get16(currentEntry); + currentEntry += 2; + + // Weight + if (currentEntry + 2 >= messageEnd) { + throw DomainNameResolveException(); + } + record.weight = ns_get16(currentEntry); + currentEntry += 2; + + // Port + if (currentEntry + 2 >= messageEnd) { + throw DomainNameResolveException(); + } + record.port = ns_get16(currentEntry); + currentEntry += 2; + + // Hostname + if (currentEntry >= messageEnd) { + throw DomainNameResolveException(); + } + ByteArray entry; + entry.resize(NS_MAXDNAME); + entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize()); + if (entryLength < 0) { + throw DomainNameResolveException(); + } + record.hostname = std::string(entry.getData(), entryLength); + records.push_back(record); + currentEntry += entryLength; + answersCount--; + } +#endif + + // Resolve the hostnames in the records, and build the result list + std::sort(records.begin(), records.end(), SRVRecordPriorityComparator()); + std::vector result; + foreach(const SRVRecord& record, records) { + try { + HostAddress address = resolveHostName(record.hostname); + result.push_back(HostAddressPort(address, record.port)); + } + catch (const DomainNameResolveException&) { + } + } + if (result.empty()) { + throw DomainNameResolveException(); + } + return result; +} + +HostAddress PlatformDomainNameResolver::resolveHostName(const std::string& hostname) { + boost::asio::io_service ioService; + boost::asio::ip::tcp::resolver resolver(ioService); + boost::asio::ip::tcp::resolver::query query(hostname, "5222"); + try { + boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); + if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { + throw DomainNameResolveException(); + } + boost::asio::ip::address address = (*endpointIterator).endpoint().address(); + if (address.is_v4()) { + return HostAddress(&address.to_v4().to_bytes()[0], 4); + } + else { + return HostAddress(&address.to_v6().to_bytes()[0], 16); + } + } + catch (...) { + throw DomainNameResolveException(); + } +} + +} diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h new file mode 100644 index 0000000..651bc97 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameResolver.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "Swiften/Base/String.h" +#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/HostAddress.h" +#include "Swiften/Network/HostAddressPort.h" + +namespace Swift { + class String; + + class PlatformDomainNameResolver : public DomainNameResolver { + public: + PlatformDomainNameResolver(); + + std::vector resolve(const String& domain); + + private: + std::vector resolveDomain(const std::string& domain); + std::vector resolveXMPPService(const std::string& domain); + HostAddress resolveHostName(const std::string& hostName); + }; +} diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index 06d9350..652bda1 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -11,7 +11,10 @@ objects = myenv.StaticObject([ "BoostIOServiceThread.cpp", "ConnectionFactory.cpp", "ConnectionServer.cpp", + "Connector.cpp", "DomainNameResolver.cpp", + "PlatformDomainNameResolver.cpp", + "StaticDomainNameResolver.cpp", "HostAddress.cpp", "Timer.cpp", ]) diff --git a/Swiften/Network/SRVRecord.h b/Swiften/Network/SRVRecord.h new file mode 100644 index 0000000..5a11635 --- /dev/null +++ b/Swiften/Network/SRVRecord.h @@ -0,0 +1,10 @@ +#pragma once + +namespace Swift { + struct SRVRecord { + std::string hostname; + int port; + int priority; + int weight; + }; +} diff --git a/Swiften/Network/SRVRecordPriorityComparator.h b/Swiften/Network/SRVRecordPriorityComparator.h new file mode 100644 index 0000000..fc16597 --- /dev/null +++ b/Swiften/Network/SRVRecordPriorityComparator.h @@ -0,0 +1,11 @@ +#pragma once + +#include "Swiften/Network/SRVRecord.h" + +namespace Swift { + struct SRVRecordPriorityComparator { + bool operator()(const SRVRecord& a, const SRVRecord& b) const { + return a.priority < b.priority; + } + }; +} diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp new file mode 100644 index 0000000..8ca4062 --- /dev/null +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -0,0 +1,28 @@ +#include "Swiften/Network/StaticDomainNameResolver.h" +#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Base/String.h" + +namespace Swift { + +StaticDomainNameResolver::StaticDomainNameResolver() { +} + +std::vector StaticDomainNameResolver::resolve(const String& queriedDomain) { + std::vector result; + + for(DomainCollection::const_iterator i = domains.begin(); i != domains.end(); ++i) { + if (i->first == queriedDomain) { + result.push_back(i->second); + } + } + if (result.empty()) { + throw DomainNameResolveException(); + } + return result; +} + +void StaticDomainNameResolver::addDomain(const String& domain, const HostAddressPort& addressPort) { + domains.push_back(std::make_pair(domain, addressPort)); +} + +} diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h new file mode 100644 index 0000000..8688429 --- /dev/null +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +#include "Swiften/Network/DomainNameResolver.h" + +namespace Swift { + class String; + + class StaticDomainNameResolver : public DomainNameResolver { + public: + StaticDomainNameResolver(); + + virtual std::vector resolve(const String& domain); + + void addDomain(const String& domain, const HostAddressPort& addressPort); + + private: + typedef std::vector< std::pair > DomainCollection; + DomainCollection domains; + }; +} diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp new file mode 100644 index 0000000..32893d8 --- /dev/null +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -0,0 +1,140 @@ +#include +#include + +#include +#include + +#include "Swiften/Network/Connector.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Network/StaticDomainNameResolver.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/DummyEventLoop.h" + +using namespace Swift; + +class ConnectorTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ConnectorTest); + CPPUNIT_TEST(testConnect); + CPPUNIT_TEST(testConnect_NoHosts); + CPPUNIT_TEST(testConnect_FirstHostFails); + CPPUNIT_TEST(testConnect_AllHostsFail); + CPPUNIT_TEST_SUITE_END(); + + public: + ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345) { + } + + void setUp() { + eventLoop = new DummyEventLoop(); + resolver = new StaticDomainNameResolver(); + connectionFactory = new MockConnectionFactory(); + } + + void tearDown() { + delete connectionFactory; + delete resolver; + delete eventLoop; + } + + void testConnect() { + std::auto_ptr testling(createConnector()); + resolver->addDomain("foo.com", host1); + resolver->addDomain("foo.com", host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort)); + } + + void testConnect_NoHosts() { + std::auto_ptr testling(createConnector()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_FirstHostFails() { + std::auto_ptr testling(createConnector()); + resolver->addDomain("foo.com", host1); + resolver->addDomain("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast(connections.size())); + CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort)); + } + + void testConnect_AllHostsFail() { + std::auto_ptr testling(createConnector()); + resolver->addDomain("foo.com", host1); + resolver->addDomain("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + private: + Connector* createConnector() { + Connector* connector = new Connector("foo.com", resolver, connectionFactory); + connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1)); + return connector; + } + + void handleConnectorFinished(boost::shared_ptr connection) { + boost::shared_ptr c(boost::dynamic_pointer_cast(connection)); + if (connection) { + assert(c); + } + connections.push_back(c); + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector& failingPorts) : failingPorts(failingPorts) {} + + void listen() { assert(false); } + void connect(const HostAddressPort& address) { + hostAddressPort = address; + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end())); + } + + void disconnect() { assert(false); } + void write(const ByteArray&) { assert(false); } + + boost::optional hostAddressPort; + std::vector failingPorts; + }; + + struct MockConnectionFactory : public ConnectionFactory { + boost::shared_ptr createConnection() { + return boost::shared_ptr(new MockConnection(failingPorts)); + } + + std::vector failingPorts; + }; + + private: + HostAddressPort host1; + HostAddressPort host2; + DummyEventLoop* eventLoop; + StaticDomainNameResolver* resolver; + MockConnectionFactory* connectionFactory; + std::vector< boost::shared_ptr > connections; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ConnectorTest); diff --git a/Swiften/QA/ClientTest/ClientTest.cpp b/Swiften/QA/ClientTest/ClientTest.cpp index b50a0bf..cf1c161 100644 --- a/Swiften/QA/ClientTest/ClientTest.cpp +++ b/Swiften/QA/ClientTest/ClientTest.cpp @@ -15,12 +15,20 @@ using namespace Swift; SimpleEventLoop eventLoop; Client* client = 0; +bool reconnected = false; bool rosterReceived = false; void handleRosterReceived(boost::shared_ptr) { - rosterReceived = true; - client->disconnect(); - eventLoop.stop(); + if (reconnected) { + rosterReceived = true; + client->disconnect(); + eventLoop.stop(); + } + else { + reconnected = true; + client->disconnect(); + client->connect(); + } } void handleConnected() { diff --git a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp index 8968efd..cb812a1 100644 --- a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp +++ b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp @@ -2,7 +2,7 @@ #include #include "Swiften/Base/String.h" -#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/PlatformDomainNameResolver.h" #include "Swiften/Network/DomainNameResolveException.h" using namespace Swift; @@ -14,13 +14,14 @@ class DomainNameResolverTest : public CppUnit::TestFixture { CPPUNIT_TEST(testResolve_Invalid); //CPPUNIT_TEST(testResolve_IPv6); CPPUNIT_TEST(testResolve_International); + CPPUNIT_TEST(testResolve_Localhost); CPPUNIT_TEST_SUITE_END(); public: DomainNameResolverTest() {} void setUp() { - resolver_ = new DomainNameResolver(); + resolver_ = new PlatformDomainNameResolver(); } void tearDown() { @@ -28,17 +29,22 @@ class DomainNameResolverTest : public CppUnit::TestFixture { } void testResolve_NoSRV() { - HostAddressPort result = resolver_->resolve("xmpp.test.swift.im"); + HostAddressPort result = resolver_->resolve("xmpp.test.swift.im")[0]; CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.0"), result.getAddress().toString()); CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); } void testResolve_SRV() { - HostAddressPort result = resolver_->resolve("xmpp-srv.test.swift.im"); + std::vector result = resolver_->resolve("xmpp-srv.test.swift.im"); - CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.1"), result.getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5000, result.getPort()); + CPPUNIT_ASSERT_EQUAL(3, static_cast(result.size())); + CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.1"), result[0].getAddress().toString()); + CPPUNIT_ASSERT_EQUAL(5000, result[0].getPort()); + CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), result[1].getAddress().toString()); + CPPUNIT_ASSERT_EQUAL(5000, result[1].getPort()); + CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.2"), result[2].getAddress().toString()); + CPPUNIT_ASSERT_EQUAL(5000, result[2].getPort()); } void testResolve_Invalid() { @@ -46,19 +52,25 @@ class DomainNameResolverTest : public CppUnit::TestFixture { } void testResolve_IPv6() { - HostAddressPort result = resolver_->resolve("xmpp-ipv6.test.swift.im"); + HostAddressPort result = resolver_->resolve("xmpp-ipv6.test.swift.im")[0]; CPPUNIT_ASSERT_EQUAL(std::string("0000:0000:0000:0000:0000:ffff:0a00:0104"), result.getAddress().toString()); CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); } void testResolve_International() { - HostAddressPort result = resolver_->resolve("tron\xc3\xa7on.test.swift.im"); + HostAddressPort result = resolver_->resolve("tron\xc3\xa7on.test.swift.im")[0]; CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), result.getAddress().toString()); CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); } + void testResolve_Localhost() { + HostAddressPort result = resolver_->resolve("localhost")[0]; + CPPUNIT_ASSERT_EQUAL(std::string("127.0.0.1"), result.getAddress().toString()); + CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); + } + private: - DomainNameResolver* resolver_; + PlatformDomainNameResolver* resolver_; }; CPPUNIT_TEST_SUITE_REGISTRATION(DomainNameResolverTest); diff --git a/Swiften/SConscript b/Swiften/SConscript index d5ddce4..d896cd8 100644 --- a/Swiften/SConscript +++ b/Swiften/SConscript @@ -121,6 +121,7 @@ env.Append(UNITTEST_SOURCES = [ File("LinkLocal/UnitTest/LinkLocalServiceInfoTest.cpp"), File("LinkLocal/UnitTest/LinkLocalServiceTest.cpp"), File("Network/UnitTest/HostAddressTest.cpp"), + File("Network/UnitTest/ConnectorTest.cpp"), File("Parser/PayloadParsers/UnitTest/BodyParserTest.cpp"), File("Parser/PayloadParsers/UnitTest/DiscoInfoParserTest.cpp"), File("Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp"), diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index e0fbce7..a9a3cb0 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -20,13 +20,13 @@ void BasicSessionStream::initialize() { xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, shared_from_this(), _1)); xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, shared_from_this(), _1)); xmppLayer->onError.connect(boost::bind( - &BasicSessionStream::handleXMPPError, shared_from_this())); - xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, shared_from_this(), _1)); - xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, shared_from_this(), _1)); + &BasicSessionStream::handleXMPPError, shared_from_this())); + xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, shared_from_this(), _1)); + xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, shared_from_this(), _1)); connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionError, shared_from_this(), _1)); connectionLayer = boost::shared_ptr( - new ConnectionLayer(connection)); + new ConnectionLayer(connection)); streamStack = new StreamStack(xmppLayer, connectionLayer); @@ -34,7 +34,7 @@ void BasicSessionStream::initialize() { } BasicSessionStream::~BasicSessionStream() { - delete streamStack; + delete streamStack; } void BasicSessionStream::writeHeader(const ProtocolHeader& header) { @@ -57,7 +57,7 @@ bool BasicSessionStream::isAvailable() { } bool BasicSessionStream::supportsTLSEncryption() { - return tlsLayerFactory && tlsLayerFactory->canCreate(); + return tlsLayerFactory && tlsLayerFactory->canCreate(); } void BasicSessionStream::addTLSEncryption() { @@ -88,7 +88,7 @@ void BasicSessionStream::setWhitespacePingEnabled(bool enabled) { } void BasicSessionStream::resetXMPPParser() { - xmppLayer->resetParser(); + xmppLayer->resetParser(); } void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header) { diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index 0cb50eb..07bae81 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -7,26 +7,26 @@ #include "Swiften/Session/SessionStream.h" namespace Swift { - class TLSLayerFactory; - class TLSLayer; - class WhitespacePingLayer; + class TLSLayerFactory; + class TLSLayer; + class WhitespacePingLayer; class PayloadParserFactoryCollection; class PayloadSerializerCollection; - class StreamStack; + class StreamStack; class XMPPLayer; - class ConnectionLayer; + class ConnectionLayer; - class BasicSessionStream : - public SessionStream, - public boost::enable_shared_from_this { - public: - BasicSessionStream( - boost::shared_ptr connection, - PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers, - TLSLayerFactory* tlsLayerFactory - ); - ~BasicSessionStream(); + class BasicSessionStream : + public SessionStream, + public boost::enable_shared_from_this { + public: + BasicSessionStream( + boost::shared_ptr connection, + PayloadParserFactoryCollection* payloadParserFactories, + PayloadSerializerCollection* payloadSerializers, + TLSLayerFactory* tlsLayerFactory + ); + ~BasicSessionStream(); void initialize(); @@ -43,17 +43,17 @@ namespace Swift { virtual void resetXMPPParser(); - private: + private: void handleConnectionError(const boost::optional& error); - void handleXMPPError(); + void handleXMPPError(); void handleTLSConnected(); - void handleTLSError(); + void handleTLSError(); void handleStreamStartReceived(const ProtocolHeader&); void handleElementReceived(boost::shared_ptr); - void handleDataRead(const ByteArray& data); - void handleDataWritten(const ByteArray& data); + void handleDataRead(const ByteArray& data); + void handleDataWritten(const ByteArray& data); - private: + private: bool available; boost::shared_ptr connection; PayloadParserFactoryCollection* payloadParserFactories; @@ -62,7 +62,7 @@ namespace Swift { boost::shared_ptr xmppLayer; boost::shared_ptr connectionLayer; StreamStack* streamStack; - boost::shared_ptr tlsLayer; - boost::shared_ptr whitespacePingLayer; - }; + boost::shared_ptr tlsLayer; + boost::shared_ptr whitespacePingLayer; + }; } -- cgit v0.10.2-6-g49f6