diff options
author | Remko Tronçon <git@el-tramo.be> | 2009-11-12 18:12:47 (GMT) |
---|---|---|
committer | Remko Tronçon <git@el-tramo.be> | 2009-11-12 18:12:47 (GMT) |
commit | fdd8755e2363e8d706a3d0bdc2e71f234abdf829 (patch) | |
tree | 470401f6f80873c4e1ce5af5cd30ab6837854d04 /Swiften/Network | |
parent | 6a20be61e229255f93d55f13be3346525698237a (diff) | |
download | swift-contrib-fdd8755e2363e8d706a3d0bdc2e71f234abdf829.zip swift-contrib-fdd8755e2363e8d706a3d0bdc2e71f234abdf829.tar.bz2 |
Refactored DNS handling.
Connections now fallback on other DNS entries upon failure,
taking into account SRV priorities.
Diffstat (limited to 'Swiften/Network')
-rw-r--r-- | Swiften/Network/Connector.cpp | 48 | ||||
-rw-r--r-- | Swiften/Network/Connector.h | 34 | ||||
-rw-r--r-- | Swiften/Network/DomainNameResolver.cpp | 168 | ||||
-rw-r--r-- | Swiften/Network/DomainNameResolver.h | 17 | ||||
-rw-r--r-- | Swiften/Network/DummyConnection.h | 40 | ||||
-rw-r--r-- | Swiften/Network/HostAddress.h | 4 | ||||
-rw-r--r-- | Swiften/Network/HostAddressPort.h | 4 | ||||
-rw-r--r-- | Swiften/Network/PlatformDomainNameResolver.cpp | 200 | ||||
-rw-r--r-- | Swiften/Network/PlatformDomainNameResolver.h | 25 | ||||
-rw-r--r-- | Swiften/Network/SConscript | 3 | ||||
-rw-r--r-- | Swiften/Network/SRVRecord.h | 10 | ||||
-rw-r--r-- | Swiften/Network/SRVRecordPriorityComparator.h | 11 | ||||
-rw-r--r-- | Swiften/Network/StaticDomainNameResolver.cpp | 28 | ||||
-rw-r--r-- | Swiften/Network/StaticDomainNameResolver.h | 23 | ||||
-rw-r--r-- | Swiften/Network/UnitTest/ConnectorTest.cpp | 140 |
15 files changed, 573 insertions, 182 deletions
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp new file mode 100644 index 0000000..5b4fe22 --- /dev/null +++ b/Swiften/Network/Connector.cpp @@ -0,0 +1,48 @@ +#include "Swiften/Network/Connector.h" + +#include <boost/bind.hpp> + +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/DomainNameResolveException.h" + +namespace Swift { + +Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory) { +} + +void Connector::start() { + assert(!currentConnection); + try { + std::vector<HostAddressPort> resolveResult = resolver->resolve(hostname.getUTF8String()); + resolvedHosts = std::deque<HostAddressPort>(resolveResult.begin(), resolveResult.end()); + tryNextHostname(); + } + catch (const DomainNameResolveException&) { + onConnectFinished(boost::shared_ptr<Connection>()); + } +} + +void Connector::tryNextHostname() { + if (resolvedHosts.empty()) { + onConnectFinished(boost::shared_ptr<Connection>()); + } + else { + HostAddressPort remote = resolvedHosts.front(); + resolvedHosts.pop_front(); + currentConnection = connectionFactory->createConnection(); + currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1)); + currentConnection->connect(remote); + } +} + +void Connector::handleConnectionConnectFinished(bool error) { + if (error) { + tryNextHostname(); + } + else { + onConnectFinished(currentConnection); + } +} + +}; diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h new file mode 100644 index 0000000..084c416 --- /dev/null +++ b/Swiften/Network/Connector.h @@ -0,0 +1,34 @@ +#pragma once + +#include <deque> +#include <boost/signal.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Network/Connection.h" +#include "Swiften/Base/String.h" + +namespace Swift { + class DomainNameResolver; + class ConnectionFactory; + + class Connector : public boost::bsignals::trackable { + public: + Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*); + + void start(); + + boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished; + + private: + void tryNextHostname(); + void handleConnectionConnectFinished(bool error); + + private: + String hostname; + DomainNameResolver* resolver; + ConnectionFactory* connectionFactory; + std::deque<HostAddressPort> resolvedHosts; + boost::shared_ptr<Connection> currentConnection; + }; + +}; diff --git a/Swiften/Network/DomainNameResolver.cpp b/Swiften/Network/DomainNameResolver.cpp index 44b3ecf..907dfc9 100644 --- a/Swiften/Network/DomainNameResolver.cpp +++ b/Swiften/Network/DomainNameResolver.cpp @@ -1,176 +1,8 @@ #include "Swiften/Network/DomainNameResolver.h" -#include "Swiften/Base/Platform.h" - -#include <stdlib.h> -#include <boost/asio.hpp> -#include <idna.h> -#ifdef SWIFTEN_PLATFORM_WINDOWS -#undef UNICODE -#include <windows.h> -#include <windns.h> -#ifndef DNS_TYPE_SRV -#define DNS_TYPE_SRV 33 -#endif -#else -#include <arpa/nameser.h> -#include <arpa/nameser_compat.h> -#include <resolv.h> -#endif - -#include "Swiften/Network/DomainNameResolveException.h" -#include "Swiften/Base/String.h" -#include "Swiften/Base/ByteArray.h" namespace Swift { -DomainNameResolver::DomainNameResolver() { -} - DomainNameResolver::~DomainNameResolver() { } -HostAddressPort DomainNameResolver::resolve(const String& domain) { - char* output; - if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { - std::string outputString(output); - free(output); - return resolveDomain(outputString); - } - else { - return resolveDomain(domain.getUTF8String()); - } -} - -HostAddressPort DomainNameResolver::resolveDomain(const std::string& domain) { - try { - return resolveXMPPService(domain); - } - catch (const DomainNameResolveException&) { - } - return HostAddressPort(resolveHostName(domain), 5222); -} - -HostAddressPort DomainNameResolver::resolveXMPPService(const std::string& domain) { - std::string srvQuery = "_xmpp-client._tcp." + domain; - -#if defined(SWIFTEN_PLATFORM_WINDOWS) - DNS_RECORD* responses; - // FIXME: This conversion doesn't work if unicode is deffed above - if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { - throw DomainNameResolveException(); - } - - DNS_RECORD* currentEntry = responses; - while (currentEntry) { - if (currentEntry->wType == DNS_TYPE_SRV) { - int port = currentEntry->Data.SRV.wPort; - try { - // The pNameTarget is actually a PCWSTR, so I would have expected this - // conversion to not work at all, but it does. - // Actually, it doesn't. Fix this and remove explicit cast - // Remove unicode undef above as well - std::string hostname((const char*) currentEntry->Data.SRV.pNameTarget); - HostAddress address = resolveHostName(hostname); - DnsRecordListFree(responses, DnsFreeRecordList); - return HostAddressPort(address, port); - } - catch (const DomainNameResolveException&) { - } - } - currentEntry = currentEntry->pNext; - } - DnsRecordListFree(responses, DnsFreeRecordList); - -#else - - ByteArray response; - response.resize(NS_PACKETSZ); - int responseLength = res_query(const_cast<char*>(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize()); - if (responseLength == -1) { - throw DomainNameResolveException(); - } - - // Parse header - HEADER* header = reinterpret_cast<HEADER*>(response.getData()); - unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData()); - unsigned char* messageEnd = messageStart + responseLength; - unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; - - // Skip over the queries - int queriesCount = ntohs(header->qdcount); - while (queriesCount > 0) { - int entryLength = dn_skipname(currentEntry, messageEnd); - if (entryLength < 0) { - throw DomainNameResolveException(); - } - currentEntry += entryLength + NS_QFIXEDSZ; - queriesCount--; - } - - // Process the SRV answers - int answersCount = ntohs(header->ancount); - while (answersCount > 0) { - int entryLength = dn_skipname(currentEntry, messageEnd); - currentEntry += entryLength; - currentEntry += NS_RRFIXEDSZ; - - // Uninteresting information - currentEntry += 2; // PRIORITY - currentEntry += 2; // WEIGHT - - // Port - if (currentEntry >= messageEnd) { - throw DomainNameResolveException(); - } - int port = ns_get16(currentEntry); - currentEntry += 2; - - // Hostname - if (currentEntry >= messageEnd) { - throw DomainNameResolveException(); - } - ByteArray entry; - entry.resize(NS_MAXDNAME); - entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize()); - if (entryLength < 0) { - throw DomainNameResolveException(); - } - try { - // Resolve the hostname - std::string hostname(entry.getData(), entryLength); - HostAddress address = resolveHostName(hostname); - return HostAddressPort(address, port); - } - catch (const DomainNameResolveException&) { - } - currentEntry += entryLength; - answersCount--; - } -#endif - - throw DomainNameResolveException(); -} - -HostAddress DomainNameResolver::resolveHostName(const std::string& hostname) { - boost::asio::io_service ioService; - boost::asio::ip::tcp::resolver resolver(ioService); - boost::asio::ip::tcp::resolver::query query(hostname, "5222"); - try { - boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); - if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { - throw DomainNameResolveException(); - } - boost::asio::ip::address address = (*endpointIterator).endpoint().address(); - if (address.is_v4()) { - return HostAddress(&address.to_v4().to_bytes()[0], 4); - } - else { - return HostAddress(&address.to_v6().to_bytes()[0], 16); - } - } - catch (...) { - throw DomainNameResolveException(); - } -} - } diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h index c7736b1..5c83622 100644 --- a/Swiften/Network/DomainNameResolver.h +++ b/Swiften/Network/DomainNameResolver.h @@ -1,10 +1,7 @@ -#ifndef SWIFTEN_DOMAINNAMERESOLVER_H -#define SWIFTEN_DOMAINNAMERESOLVER_H +#pragma once -#include <string> +#include <vector> -#include "Swiften/Base/String.h" -#include "Swiften/Network/HostAddress.h" #include "Swiften/Network/HostAddressPort.h" namespace Swift { @@ -12,16 +9,8 @@ namespace Swift { class DomainNameResolver { public: - DomainNameResolver(); virtual ~DomainNameResolver(); - HostAddressPort resolve(const String& domain); - - private: - virtual HostAddressPort resolveDomain(const std::string& domain); - HostAddressPort resolveXMPPService(const std::string& domain); - HostAddress resolveHostName(const std::string& hostName); + virtual std::vector<HostAddressPort> resolve(const String& domain) = 0; }; } - -#endif diff --git a/Swiften/Network/DummyConnection.h b/Swiften/Network/DummyConnection.h new file mode 100644 index 0000000..11281b3 --- /dev/null +++ b/Swiften/Network/DummyConnection.h @@ -0,0 +1,40 @@ +#pragma once + +#include <cassert> +#include <boost/bind.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/Network/Connection.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/EventOwner.h" + +namespace Swift { + class DummyConnection : + public Connection, + public EventOwner, + public boost::enable_shared_from_this<DummyConnection> { + + void listen() { + assert(false); + } + + void connect(const HostAddressPort&) { + assert(false); + } + + void disconnect() { + assert(false); + } + + void write(const ByteArray& data) { + onDataWritten(data); + } + + void receive(const ByteArray& data) { + MainEventLoop::postEvent(boost::bind( + boost::ref(onDataRead), ByteArray(data)), shared_from_this()); + } + + boost::signal<void (const ByteArray&)> onDataWritten; + }; +} diff --git a/Swiften/Network/HostAddress.h b/Swiften/Network/HostAddress.h index bf6d2f8..11f8a2b 100644 --- a/Swiften/Network/HostAddress.h +++ b/Swiften/Network/HostAddress.h @@ -18,6 +18,10 @@ namespace Swift { std::string toString() const; + bool operator==(const HostAddress& o) const { + return address_ == o.address_; + } + private: std::vector<unsigned char> address_; }; diff --git a/Swiften/Network/HostAddressPort.h b/Swiften/Network/HostAddressPort.h index 8668ae4..d632058 100644 --- a/Swiften/Network/HostAddressPort.h +++ b/Swiften/Network/HostAddressPort.h @@ -17,6 +17,10 @@ namespace Swift { return port_; } + bool operator==(const HostAddressPort& o) const { + return address_ == o.address_ && port_ == o.port_; + } + private: HostAddress address_; int port_; diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp new file mode 100644 index 0000000..3803e7c --- /dev/null +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -0,0 +1,200 @@ +#include "Swiften/Network/PlatformDomainNameResolver.h" + +#include "Swiften/Base/Platform.h" +#include "Swiften/Base/foreach.h" + +#include <stdlib.h> +#include <boost/asio.hpp> +#include <idna.h> +#ifdef SWIFTEN_PLATFORM_WINDOWS +#undef UNICODE +#include <windows.h> +#include <windns.h> +#ifndef DNS_TYPE_SRV +#define DNS_TYPE_SRV 33 +#endif +#else +#include <arpa/nameser.h> +#include <arpa/nameser_compat.h> +#include <resolv.h> +#endif +#include <algorithm> + +#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Base/String.h" +#include "Swiften/Base/ByteArray.h" +#include "Swiften/Network/SRVRecord.h" +#include "Swiften/Network/SRVRecordPriorityComparator.h" + +namespace Swift { + +PlatformDomainNameResolver::PlatformDomainNameResolver() { +} + +std::vector<HostAddressPort> PlatformDomainNameResolver::resolve(const String& domain) { + char* output; + if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { + std::string outputString(output); + free(output); + return resolveDomain(outputString); + } + else { + return resolveDomain(domain.getUTF8String()); + } +} + +std::vector<HostAddressPort> PlatformDomainNameResolver::resolveDomain(const std::string& domain) { + try { + return resolveXMPPService(domain); + } + catch (const DomainNameResolveException&) { + } + std::vector<HostAddressPort> result; + result.push_back(HostAddressPort(resolveHostName(domain), 5222)); + return result; +} + +std::vector<HostAddressPort> PlatformDomainNameResolver::resolveXMPPService(const std::string& domain) { + std::vector<SRVRecord> records; + + std::string srvQuery = "_xmpp-client._tcp." + domain; + +#if defined(SWIFTEN_PLATFORM_WINDOWS) + DNS_RECORD* responses; + // FIXME: This conversion doesn't work if unicode is deffed above + if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { + throw DomainNameResolveException(); + } + + DNS_RECORD* currentEntry = responses; + while (currentEntry) { + if (currentEntry->wType == DNS_TYPE_SRV) { + SRVRecord record; + record.priority = currentEntry->Data.SRV.wPriority; + record.weight = currentEntry->Data.SRV.wWeight; + record.port = currentEntry->Data.SRV.wPort; + + // The pNameTarget is actually a PCWSTR, so I would have expected this + // conversion to not work at all, but it does. + // Actually, it doesn't. Fix this and remove explicit cast + // Remove unicode undef above as well + record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget); + records.push_back(record); + } + currentEntry = currentEntry->pNext; + } + DnsRecordListFree(responses, DnsFreeRecordList); + +#else + + ByteArray response; + response.resize(NS_PACKETSZ); + int responseLength = res_query(const_cast<char*>(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize()); + if (responseLength == -1) { + throw DomainNameResolveException(); + } + + // Parse header + HEADER* header = reinterpret_cast<HEADER*>(response.getData()); + unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData()); + unsigned char* messageEnd = messageStart + responseLength; + unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; + + // Skip over the queries + int queriesCount = ntohs(header->qdcount); + while (queriesCount > 0) { + int entryLength = dn_skipname(currentEntry, messageEnd); + if (entryLength < 0) { + throw DomainNameResolveException(); + } + currentEntry += entryLength + NS_QFIXEDSZ; + queriesCount--; + } + + // Process the SRV answers + int answersCount = ntohs(header->ancount); + while (answersCount > 0) { + SRVRecord record; + + int entryLength = dn_skipname(currentEntry, messageEnd); + currentEntry += entryLength; + currentEntry += NS_RRFIXEDSZ; + + // Priority + if (currentEntry + 2 >= messageEnd) { + throw DomainNameResolveException(); + } + record.priority = ns_get16(currentEntry); + currentEntry += 2; + + // Weight + if (currentEntry + 2 >= messageEnd) { + throw DomainNameResolveException(); + } + record.weight = ns_get16(currentEntry); + currentEntry += 2; + + // Port + if (currentEntry + 2 >= messageEnd) { + throw DomainNameResolveException(); + } + record.port = ns_get16(currentEntry); + currentEntry += 2; + + // Hostname + if (currentEntry >= messageEnd) { + throw DomainNameResolveException(); + } + ByteArray entry; + entry.resize(NS_MAXDNAME); + entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize()); + if (entryLength < 0) { + throw DomainNameResolveException(); + } + record.hostname = std::string(entry.getData(), entryLength); + records.push_back(record); + currentEntry += entryLength; + answersCount--; + } +#endif + + // Resolve the hostnames in the records, and build the result list + std::sort(records.begin(), records.end(), SRVRecordPriorityComparator()); + std::vector<HostAddressPort> result; + foreach(const SRVRecord& record, records) { + try { + HostAddress address = resolveHostName(record.hostname); + result.push_back(HostAddressPort(address, record.port)); + } + catch (const DomainNameResolveException&) { + } + } + if (result.empty()) { + throw DomainNameResolveException(); + } + return result; +} + +HostAddress PlatformDomainNameResolver::resolveHostName(const std::string& hostname) { + boost::asio::io_service ioService; + boost::asio::ip::tcp::resolver resolver(ioService); + boost::asio::ip::tcp::resolver::query query(hostname, "5222"); + try { + boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); + if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { + throw DomainNameResolveException(); + } + boost::asio::ip::address address = (*endpointIterator).endpoint().address(); + if (address.is_v4()) { + return HostAddress(&address.to_v4().to_bytes()[0], 4); + } + else { + return HostAddress(&address.to_v6().to_bytes()[0], 16); + } + } + catch (...) { + throw DomainNameResolveException(); + } +} + +} diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h new file mode 100644 index 0000000..651bc97 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameResolver.h @@ -0,0 +1,25 @@ +#pragma once + +#include <string> +#include <vector> + +#include "Swiften/Base/String.h" +#include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/HostAddress.h" +#include "Swiften/Network/HostAddressPort.h" + +namespace Swift { + class String; + + class PlatformDomainNameResolver : public DomainNameResolver { + public: + PlatformDomainNameResolver(); + + std::vector<HostAddressPort> resolve(const String& domain); + + private: + std::vector<HostAddressPort> resolveDomain(const std::string& domain); + std::vector<HostAddressPort> resolveXMPPService(const std::string& domain); + HostAddress resolveHostName(const std::string& hostName); + }; +} diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index 06d9350..652bda1 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -11,7 +11,10 @@ objects = myenv.StaticObject([ "BoostIOServiceThread.cpp", "ConnectionFactory.cpp", "ConnectionServer.cpp", + "Connector.cpp", "DomainNameResolver.cpp", + "PlatformDomainNameResolver.cpp", + "StaticDomainNameResolver.cpp", "HostAddress.cpp", "Timer.cpp", ]) diff --git a/Swiften/Network/SRVRecord.h b/Swiften/Network/SRVRecord.h new file mode 100644 index 0000000..5a11635 --- /dev/null +++ b/Swiften/Network/SRVRecord.h @@ -0,0 +1,10 @@ +#pragma once + +namespace Swift { + struct SRVRecord { + std::string hostname; + int port; + int priority; + int weight; + }; +} diff --git a/Swiften/Network/SRVRecordPriorityComparator.h b/Swiften/Network/SRVRecordPriorityComparator.h new file mode 100644 index 0000000..fc16597 --- /dev/null +++ b/Swiften/Network/SRVRecordPriorityComparator.h @@ -0,0 +1,11 @@ +#pragma once + +#include "Swiften/Network/SRVRecord.h" + +namespace Swift { + struct SRVRecordPriorityComparator { + bool operator()(const SRVRecord& a, const SRVRecord& b) const { + return a.priority < b.priority; + } + }; +} diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp new file mode 100644 index 0000000..8ca4062 --- /dev/null +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -0,0 +1,28 @@ +#include "Swiften/Network/StaticDomainNameResolver.h" +#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Base/String.h" + +namespace Swift { + +StaticDomainNameResolver::StaticDomainNameResolver() { +} + +std::vector<HostAddressPort> StaticDomainNameResolver::resolve(const String& queriedDomain) { + std::vector<HostAddressPort> result; + + for(DomainCollection::const_iterator i = domains.begin(); i != domains.end(); ++i) { + if (i->first == queriedDomain) { + result.push_back(i->second); + } + } + if (result.empty()) { + throw DomainNameResolveException(); + } + return result; +} + +void StaticDomainNameResolver::addDomain(const String& domain, const HostAddressPort& addressPort) { + domains.push_back(std::make_pair(domain, addressPort)); +} + +} diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h new file mode 100644 index 0000000..8688429 --- /dev/null +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -0,0 +1,23 @@ +#pragma once + +#include <vector> +#include <map> + +#include "Swiften/Network/DomainNameResolver.h" + +namespace Swift { + class String; + + class StaticDomainNameResolver : public DomainNameResolver { + public: + StaticDomainNameResolver(); + + virtual std::vector<HostAddressPort> resolve(const String& domain); + + void addDomain(const String& domain, const HostAddressPort& addressPort); + + private: + typedef std::vector< std::pair<String, HostAddressPort> > DomainCollection; + DomainCollection domains; + }; +} diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp new file mode 100644 index 0000000..32893d8 --- /dev/null +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -0,0 +1,140 @@ +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> + +#include <boost/optional.hpp> +#include <boost/bind.hpp> + +#include "Swiften/Network/Connector.h" +#include "Swiften/Network/Connection.h" +#include "Swiften/Network/ConnectionFactory.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Network/StaticDomainNameResolver.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/EventLoop/DummyEventLoop.h" + +using namespace Swift; + +class ConnectorTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ConnectorTest); + CPPUNIT_TEST(testConnect); + CPPUNIT_TEST(testConnect_NoHosts); + CPPUNIT_TEST(testConnect_FirstHostFails); + CPPUNIT_TEST(testConnect_AllHostsFail); + CPPUNIT_TEST_SUITE_END(); + + public: + ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345) { + } + + void setUp() { + eventLoop = new DummyEventLoop(); + resolver = new StaticDomainNameResolver(); + connectionFactory = new MockConnectionFactory(); + } + + void tearDown() { + delete connectionFactory; + delete resolver; + delete eventLoop; + } + + void testConnect() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addDomain("foo.com", host1); + resolver->addDomain("foo.com", host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort)); + } + + void testConnect_NoHosts() { + std::auto_ptr<Connector> testling(createConnector()); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_FirstHostFails() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addDomain("foo.com", host1); + resolver->addDomain("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort)); + } + + void testConnect_AllHostsFail() { + std::auto_ptr<Connector> testling(createConnector()); + resolver->addDomain("foo.com", host1); + resolver->addDomain("foo.com", host2); + connectionFactory->failingPorts.push_back(host1); + connectionFactory->failingPorts.push_back(host2); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + private: + Connector* createConnector() { + Connector* connector = new Connector("foo.com", resolver, connectionFactory); + connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1)); + return connector; + } + + void handleConnectorFinished(boost::shared_ptr<Connection> connection) { + boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection)); + if (connection) { + assert(c); + } + connections.push_back(c); + } + + struct MockConnection : public Connection { + public: + MockConnection(const std::vector<HostAddressPort>& failingPorts) : failingPorts(failingPorts) {} + + void listen() { assert(false); } + void connect(const HostAddressPort& address) { + hostAddressPort = address; + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end())); + } + + void disconnect() { assert(false); } + void write(const ByteArray&) { assert(false); } + + boost::optional<HostAddressPort> hostAddressPort; + std::vector<HostAddressPort> failingPorts; + }; + + struct MockConnectionFactory : public ConnectionFactory { + boost::shared_ptr<Connection> createConnection() { + return boost::shared_ptr<Connection>(new MockConnection(failingPorts)); + } + + std::vector<HostAddressPort> failingPorts; + }; + + private: + HostAddressPort host1; + HostAddressPort host2; + DummyEventLoop* eventLoop; + StaticDomainNameResolver* resolver; + MockConnectionFactory* connectionFactory; + std::vector< boost::shared_ptr<MockConnection> > connections; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ConnectorTest); |