summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/Network')
-rw-r--r--Swiften/Network/Connector.cpp48
-rw-r--r--Swiften/Network/Connector.h34
-rw-r--r--Swiften/Network/DomainNameResolver.cpp168
-rw-r--r--Swiften/Network/DomainNameResolver.h17
-rw-r--r--Swiften/Network/DummyConnection.h40
-rw-r--r--Swiften/Network/HostAddress.h4
-rw-r--r--Swiften/Network/HostAddressPort.h4
-rw-r--r--Swiften/Network/PlatformDomainNameResolver.cpp200
-rw-r--r--Swiften/Network/PlatformDomainNameResolver.h25
-rw-r--r--Swiften/Network/SConscript3
-rw-r--r--Swiften/Network/SRVRecord.h10
-rw-r--r--Swiften/Network/SRVRecordPriorityComparator.h11
-rw-r--r--Swiften/Network/StaticDomainNameResolver.cpp28
-rw-r--r--Swiften/Network/StaticDomainNameResolver.h23
-rw-r--r--Swiften/Network/UnitTest/ConnectorTest.cpp140
15 files changed, 573 insertions, 182 deletions
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp
new file mode 100644
index 0000000..5b4fe22
--- /dev/null
+++ b/Swiften/Network/Connector.cpp
@@ -0,0 +1,48 @@
+#include "Swiften/Network/Connector.h"
+
+#include <boost/bind.hpp>
+
+#include "Swiften/Network/ConnectionFactory.h"
+#include "Swiften/Network/DomainNameResolver.h"
+#include "Swiften/Network/DomainNameResolveException.h"
+
+namespace Swift {
+
+Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory) {
+}
+
+void Connector::start() {
+ assert(!currentConnection);
+ try {
+ std::vector<HostAddressPort> resolveResult = resolver->resolve(hostname.getUTF8String());
+ resolvedHosts = std::deque<HostAddressPort>(resolveResult.begin(), resolveResult.end());
+ tryNextHostname();
+ }
+ catch (const DomainNameResolveException&) {
+ onConnectFinished(boost::shared_ptr<Connection>());
+ }
+}
+
+void Connector::tryNextHostname() {
+ if (resolvedHosts.empty()) {
+ onConnectFinished(boost::shared_ptr<Connection>());
+ }
+ else {
+ HostAddressPort remote = resolvedHosts.front();
+ resolvedHosts.pop_front();
+ currentConnection = connectionFactory->createConnection();
+ currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1));
+ currentConnection->connect(remote);
+ }
+}
+
+void Connector::handleConnectionConnectFinished(bool error) {
+ if (error) {
+ tryNextHostname();
+ }
+ else {
+ onConnectFinished(currentConnection);
+ }
+}
+
+};
diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h
new file mode 100644
index 0000000..084c416
--- /dev/null
+++ b/Swiften/Network/Connector.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <deque>
+#include <boost/signal.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Network/Connection.h"
+#include "Swiften/Base/String.h"
+
+namespace Swift {
+ class DomainNameResolver;
+ class ConnectionFactory;
+
+ class Connector : public boost::bsignals::trackable {
+ public:
+ Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*);
+
+ void start();
+
+ boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished;
+
+ private:
+ void tryNextHostname();
+ void handleConnectionConnectFinished(bool error);
+
+ private:
+ String hostname;
+ DomainNameResolver* resolver;
+ ConnectionFactory* connectionFactory;
+ std::deque<HostAddressPort> resolvedHosts;
+ boost::shared_ptr<Connection> currentConnection;
+ };
+
+};
diff --git a/Swiften/Network/DomainNameResolver.cpp b/Swiften/Network/DomainNameResolver.cpp
index 44b3ecf..907dfc9 100644
--- a/Swiften/Network/DomainNameResolver.cpp
+++ b/Swiften/Network/DomainNameResolver.cpp
@@ -1,176 +1,8 @@
#include "Swiften/Network/DomainNameResolver.h"
-#include "Swiften/Base/Platform.h"
-
-#include <stdlib.h>
-#include <boost/asio.hpp>
-#include <idna.h>
-#ifdef SWIFTEN_PLATFORM_WINDOWS
-#undef UNICODE
-#include <windows.h>
-#include <windns.h>
-#ifndef DNS_TYPE_SRV
-#define DNS_TYPE_SRV 33
-#endif
-#else
-#include <arpa/nameser.h>
-#include <arpa/nameser_compat.h>
-#include <resolv.h>
-#endif
-
-#include "Swiften/Network/DomainNameResolveException.h"
-#include "Swiften/Base/String.h"
-#include "Swiften/Base/ByteArray.h"
namespace Swift {
-DomainNameResolver::DomainNameResolver() {
-}
-
DomainNameResolver::~DomainNameResolver() {
}
-HostAddressPort DomainNameResolver::resolve(const String& domain) {
- char* output;
- if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) {
- std::string outputString(output);
- free(output);
- return resolveDomain(outputString);
- }
- else {
- return resolveDomain(domain.getUTF8String());
- }
-}
-
-HostAddressPort DomainNameResolver::resolveDomain(const std::string& domain) {
- try {
- return resolveXMPPService(domain);
- }
- catch (const DomainNameResolveException&) {
- }
- return HostAddressPort(resolveHostName(domain), 5222);
-}
-
-HostAddressPort DomainNameResolver::resolveXMPPService(const std::string& domain) {
- std::string srvQuery = "_xmpp-client._tcp." + domain;
-
-#if defined(SWIFTEN_PLATFORM_WINDOWS)
- DNS_RECORD* responses;
- // FIXME: This conversion doesn't work if unicode is deffed above
- if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) {
- throw DomainNameResolveException();
- }
-
- DNS_RECORD* currentEntry = responses;
- while (currentEntry) {
- if (currentEntry->wType == DNS_TYPE_SRV) {
- int port = currentEntry->Data.SRV.wPort;
- try {
- // The pNameTarget is actually a PCWSTR, so I would have expected this
- // conversion to not work at all, but it does.
- // Actually, it doesn't. Fix this and remove explicit cast
- // Remove unicode undef above as well
- std::string hostname((const char*) currentEntry->Data.SRV.pNameTarget);
- HostAddress address = resolveHostName(hostname);
- DnsRecordListFree(responses, DnsFreeRecordList);
- return HostAddressPort(address, port);
- }
- catch (const DomainNameResolveException&) {
- }
- }
- currentEntry = currentEntry->pNext;
- }
- DnsRecordListFree(responses, DnsFreeRecordList);
-
-#else
-
- ByteArray response;
- response.resize(NS_PACKETSZ);
- int responseLength = res_query(const_cast<char*>(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize());
- if (responseLength == -1) {
- throw DomainNameResolveException();
- }
-
- // Parse header
- HEADER* header = reinterpret_cast<HEADER*>(response.getData());
- unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData());
- unsigned char* messageEnd = messageStart + responseLength;
- unsigned char* currentEntry = messageStart + NS_HFIXEDSZ;
-
- // Skip over the queries
- int queriesCount = ntohs(header->qdcount);
- while (queriesCount > 0) {
- int entryLength = dn_skipname(currentEntry, messageEnd);
- if (entryLength < 0) {
- throw DomainNameResolveException();
- }
- currentEntry += entryLength + NS_QFIXEDSZ;
- queriesCount--;
- }
-
- // Process the SRV answers
- int answersCount = ntohs(header->ancount);
- while (answersCount > 0) {
- int entryLength = dn_skipname(currentEntry, messageEnd);
- currentEntry += entryLength;
- currentEntry += NS_RRFIXEDSZ;
-
- // Uninteresting information
- currentEntry += 2; // PRIORITY
- currentEntry += 2; // WEIGHT
-
- // Port
- if (currentEntry >= messageEnd) {
- throw DomainNameResolveException();
- }
- int port = ns_get16(currentEntry);
- currentEntry += 2;
-
- // Hostname
- if (currentEntry >= messageEnd) {
- throw DomainNameResolveException();
- }
- ByteArray entry;
- entry.resize(NS_MAXDNAME);
- entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize());
- if (entryLength < 0) {
- throw DomainNameResolveException();
- }
- try {
- // Resolve the hostname
- std::string hostname(entry.getData(), entryLength);
- HostAddress address = resolveHostName(hostname);
- return HostAddressPort(address, port);
- }
- catch (const DomainNameResolveException&) {
- }
- currentEntry += entryLength;
- answersCount--;
- }
-#endif
-
- throw DomainNameResolveException();
-}
-
-HostAddress DomainNameResolver::resolveHostName(const std::string& hostname) {
- boost::asio::io_service ioService;
- boost::asio::ip::tcp::resolver resolver(ioService);
- boost::asio::ip::tcp::resolver::query query(hostname, "5222");
- try {
- boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query);
- if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) {
- throw DomainNameResolveException();
- }
- boost::asio::ip::address address = (*endpointIterator).endpoint().address();
- if (address.is_v4()) {
- return HostAddress(&address.to_v4().to_bytes()[0], 4);
- }
- else {
- return HostAddress(&address.to_v6().to_bytes()[0], 16);
- }
- }
- catch (...) {
- throw DomainNameResolveException();
- }
-}
-
}
diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h
index c7736b1..5c83622 100644
--- a/Swiften/Network/DomainNameResolver.h
+++ b/Swiften/Network/DomainNameResolver.h
@@ -1,10 +1,7 @@
-#ifndef SWIFTEN_DOMAINNAMERESOLVER_H
-#define SWIFTEN_DOMAINNAMERESOLVER_H
+#pragma once
-#include <string>
+#include <vector>
-#include "Swiften/Base/String.h"
-#include "Swiften/Network/HostAddress.h"
#include "Swiften/Network/HostAddressPort.h"
namespace Swift {
@@ -12,16 +9,8 @@ namespace Swift {
class DomainNameResolver {
public:
- DomainNameResolver();
virtual ~DomainNameResolver();
- HostAddressPort resolve(const String& domain);
-
- private:
- virtual HostAddressPort resolveDomain(const std::string& domain);
- HostAddressPort resolveXMPPService(const std::string& domain);
- HostAddress resolveHostName(const std::string& hostName);
+ virtual std::vector<HostAddressPort> resolve(const String& domain) = 0;
};
}
-
-#endif
diff --git a/Swiften/Network/DummyConnection.h b/Swiften/Network/DummyConnection.h
new file mode 100644
index 0000000..11281b3
--- /dev/null
+++ b/Swiften/Network/DummyConnection.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include <cassert>
+#include <boost/bind.hpp>
+#include <boost/enable_shared_from_this.hpp>
+
+#include "Swiften/Network/Connection.h"
+#include "Swiften/EventLoop/MainEventLoop.h"
+#include "Swiften/EventLoop/EventOwner.h"
+
+namespace Swift {
+ class DummyConnection :
+ public Connection,
+ public EventOwner,
+ public boost::enable_shared_from_this<DummyConnection> {
+
+ void listen() {
+ assert(false);
+ }
+
+ void connect(const HostAddressPort&) {
+ assert(false);
+ }
+
+ void disconnect() {
+ assert(false);
+ }
+
+ void write(const ByteArray& data) {
+ onDataWritten(data);
+ }
+
+ void receive(const ByteArray& data) {
+ MainEventLoop::postEvent(boost::bind(
+ boost::ref(onDataRead), ByteArray(data)), shared_from_this());
+ }
+
+ boost::signal<void (const ByteArray&)> onDataWritten;
+ };
+}
diff --git a/Swiften/Network/HostAddress.h b/Swiften/Network/HostAddress.h
index bf6d2f8..11f8a2b 100644
--- a/Swiften/Network/HostAddress.h
+++ b/Swiften/Network/HostAddress.h
@@ -18,6 +18,10 @@ namespace Swift {
std::string toString() const;
+ bool operator==(const HostAddress& o) const {
+ return address_ == o.address_;
+ }
+
private:
std::vector<unsigned char> address_;
};
diff --git a/Swiften/Network/HostAddressPort.h b/Swiften/Network/HostAddressPort.h
index 8668ae4..d632058 100644
--- a/Swiften/Network/HostAddressPort.h
+++ b/Swiften/Network/HostAddressPort.h
@@ -17,6 +17,10 @@ namespace Swift {
return port_;
}
+ bool operator==(const HostAddressPort& o) const {
+ return address_ == o.address_ && port_ == o.port_;
+ }
+
private:
HostAddress address_;
int port_;
diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp
new file mode 100644
index 0000000..3803e7c
--- /dev/null
+++ b/Swiften/Network/PlatformDomainNameResolver.cpp
@@ -0,0 +1,200 @@
+#include "Swiften/Network/PlatformDomainNameResolver.h"
+
+#include "Swiften/Base/Platform.h"
+#include "Swiften/Base/foreach.h"
+
+#include <stdlib.h>
+#include <boost/asio.hpp>
+#include <idna.h>
+#ifdef SWIFTEN_PLATFORM_WINDOWS
+#undef UNICODE
+#include <windows.h>
+#include <windns.h>
+#ifndef DNS_TYPE_SRV
+#define DNS_TYPE_SRV 33
+#endif
+#else
+#include <arpa/nameser.h>
+#include <arpa/nameser_compat.h>
+#include <resolv.h>
+#endif
+#include <algorithm>
+
+#include "Swiften/Network/DomainNameResolveException.h"
+#include "Swiften/Base/String.h"
+#include "Swiften/Base/ByteArray.h"
+#include "Swiften/Network/SRVRecord.h"
+#include "Swiften/Network/SRVRecordPriorityComparator.h"
+
+namespace Swift {
+
+PlatformDomainNameResolver::PlatformDomainNameResolver() {
+}
+
+std::vector<HostAddressPort> PlatformDomainNameResolver::resolve(const String& domain) {
+ char* output;
+ if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) {
+ std::string outputString(output);
+ free(output);
+ return resolveDomain(outputString);
+ }
+ else {
+ return resolveDomain(domain.getUTF8String());
+ }
+}
+
+std::vector<HostAddressPort> PlatformDomainNameResolver::resolveDomain(const std::string& domain) {
+ try {
+ return resolveXMPPService(domain);
+ }
+ catch (const DomainNameResolveException&) {
+ }
+ std::vector<HostAddressPort> result;
+ result.push_back(HostAddressPort(resolveHostName(domain), 5222));
+ return result;
+}
+
+std::vector<HostAddressPort> PlatformDomainNameResolver::resolveXMPPService(const std::string& domain) {
+ std::vector<SRVRecord> records;
+
+ std::string srvQuery = "_xmpp-client._tcp." + domain;
+
+#if defined(SWIFTEN_PLATFORM_WINDOWS)
+ DNS_RECORD* responses;
+ // FIXME: This conversion doesn't work if unicode is deffed above
+ if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) {
+ throw DomainNameResolveException();
+ }
+
+ DNS_RECORD* currentEntry = responses;
+ while (currentEntry) {
+ if (currentEntry->wType == DNS_TYPE_SRV) {
+ SRVRecord record;
+ record.priority = currentEntry->Data.SRV.wPriority;
+ record.weight = currentEntry->Data.SRV.wWeight;
+ record.port = currentEntry->Data.SRV.wPort;
+
+ // The pNameTarget is actually a PCWSTR, so I would have expected this
+ // conversion to not work at all, but it does.
+ // Actually, it doesn't. Fix this and remove explicit cast
+ // Remove unicode undef above as well
+ record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget);
+ records.push_back(record);
+ }
+ currentEntry = currentEntry->pNext;
+ }
+ DnsRecordListFree(responses, DnsFreeRecordList);
+
+#else
+
+ ByteArray response;
+ response.resize(NS_PACKETSZ);
+ int responseLength = res_query(const_cast<char*>(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize());
+ if (responseLength == -1) {
+ throw DomainNameResolveException();
+ }
+
+ // Parse header
+ HEADER* header = reinterpret_cast<HEADER*>(response.getData());
+ unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData());
+ unsigned char* messageEnd = messageStart + responseLength;
+ unsigned char* currentEntry = messageStart + NS_HFIXEDSZ;
+
+ // Skip over the queries
+ int queriesCount = ntohs(header->qdcount);
+ while (queriesCount > 0) {
+ int entryLength = dn_skipname(currentEntry, messageEnd);
+ if (entryLength < 0) {
+ throw DomainNameResolveException();
+ }
+ currentEntry += entryLength + NS_QFIXEDSZ;
+ queriesCount--;
+ }
+
+ // Process the SRV answers
+ int answersCount = ntohs(header->ancount);
+ while (answersCount > 0) {
+ SRVRecord record;
+
+ int entryLength = dn_skipname(currentEntry, messageEnd);
+ currentEntry += entryLength;
+ currentEntry += NS_RRFIXEDSZ;
+
+ // Priority
+ if (currentEntry + 2 >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ record.priority = ns_get16(currentEntry);
+ currentEntry += 2;
+
+ // Weight
+ if (currentEntry + 2 >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ record.weight = ns_get16(currentEntry);
+ currentEntry += 2;
+
+ // Port
+ if (currentEntry + 2 >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ record.port = ns_get16(currentEntry);
+ currentEntry += 2;
+
+ // Hostname
+ if (currentEntry >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ ByteArray entry;
+ entry.resize(NS_MAXDNAME);
+ entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize());
+ if (entryLength < 0) {
+ throw DomainNameResolveException();
+ }
+ record.hostname = std::string(entry.getData(), entryLength);
+ records.push_back(record);
+ currentEntry += entryLength;
+ answersCount--;
+ }
+#endif
+
+ // Resolve the hostnames in the records, and build the result list
+ std::sort(records.begin(), records.end(), SRVRecordPriorityComparator());
+ std::vector<HostAddressPort> result;
+ foreach(const SRVRecord& record, records) {
+ try {
+ HostAddress address = resolveHostName(record.hostname);
+ result.push_back(HostAddressPort(address, record.port));
+ }
+ catch (const DomainNameResolveException&) {
+ }
+ }
+ if (result.empty()) {
+ throw DomainNameResolveException();
+ }
+ return result;
+}
+
+HostAddress PlatformDomainNameResolver::resolveHostName(const std::string& hostname) {
+ boost::asio::io_service ioService;
+ boost::asio::ip::tcp::resolver resolver(ioService);
+ boost::asio::ip::tcp::resolver::query query(hostname, "5222");
+ try {
+ boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query);
+ if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) {
+ throw DomainNameResolveException();
+ }
+ boost::asio::ip::address address = (*endpointIterator).endpoint().address();
+ if (address.is_v4()) {
+ return HostAddress(&address.to_v4().to_bytes()[0], 4);
+ }
+ else {
+ return HostAddress(&address.to_v6().to_bytes()[0], 16);
+ }
+ }
+ catch (...) {
+ throw DomainNameResolveException();
+ }
+}
+
+}
diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h
new file mode 100644
index 0000000..651bc97
--- /dev/null
+++ b/Swiften/Network/PlatformDomainNameResolver.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "Swiften/Base/String.h"
+#include "Swiften/Network/DomainNameResolver.h"
+#include "Swiften/Network/HostAddress.h"
+#include "Swiften/Network/HostAddressPort.h"
+
+namespace Swift {
+ class String;
+
+ class PlatformDomainNameResolver : public DomainNameResolver {
+ public:
+ PlatformDomainNameResolver();
+
+ std::vector<HostAddressPort> resolve(const String& domain);
+
+ private:
+ std::vector<HostAddressPort> resolveDomain(const std::string& domain);
+ std::vector<HostAddressPort> resolveXMPPService(const std::string& domain);
+ HostAddress resolveHostName(const std::string& hostName);
+ };
+}
diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript
index 06d9350..652bda1 100644
--- a/Swiften/Network/SConscript
+++ b/Swiften/Network/SConscript
@@ -11,7 +11,10 @@ objects = myenv.StaticObject([
"BoostIOServiceThread.cpp",
"ConnectionFactory.cpp",
"ConnectionServer.cpp",
+ "Connector.cpp",
"DomainNameResolver.cpp",
+ "PlatformDomainNameResolver.cpp",
+ "StaticDomainNameResolver.cpp",
"HostAddress.cpp",
"Timer.cpp",
])
diff --git a/Swiften/Network/SRVRecord.h b/Swiften/Network/SRVRecord.h
new file mode 100644
index 0000000..5a11635
--- /dev/null
+++ b/Swiften/Network/SRVRecord.h
@@ -0,0 +1,10 @@
+#pragma once
+
+namespace Swift {
+ struct SRVRecord {
+ std::string hostname;
+ int port;
+ int priority;
+ int weight;
+ };
+}
diff --git a/Swiften/Network/SRVRecordPriorityComparator.h b/Swiften/Network/SRVRecordPriorityComparator.h
new file mode 100644
index 0000000..fc16597
--- /dev/null
+++ b/Swiften/Network/SRVRecordPriorityComparator.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "Swiften/Network/SRVRecord.h"
+
+namespace Swift {
+ struct SRVRecordPriorityComparator {
+ bool operator()(const SRVRecord& a, const SRVRecord& b) const {
+ return a.priority < b.priority;
+ }
+ };
+}
diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp
new file mode 100644
index 0000000..8ca4062
--- /dev/null
+++ b/Swiften/Network/StaticDomainNameResolver.cpp
@@ -0,0 +1,28 @@
+#include "Swiften/Network/StaticDomainNameResolver.h"
+#include "Swiften/Network/DomainNameResolveException.h"
+#include "Swiften/Base/String.h"
+
+namespace Swift {
+
+StaticDomainNameResolver::StaticDomainNameResolver() {
+}
+
+std::vector<HostAddressPort> StaticDomainNameResolver::resolve(const String& queriedDomain) {
+ std::vector<HostAddressPort> result;
+
+ for(DomainCollection::const_iterator i = domains.begin(); i != domains.end(); ++i) {
+ if (i->first == queriedDomain) {
+ result.push_back(i->second);
+ }
+ }
+ if (result.empty()) {
+ throw DomainNameResolveException();
+ }
+ return result;
+}
+
+void StaticDomainNameResolver::addDomain(const String& domain, const HostAddressPort& addressPort) {
+ domains.push_back(std::make_pair(domain, addressPort));
+}
+
+}
diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h
new file mode 100644
index 0000000..8688429
--- /dev/null
+++ b/Swiften/Network/StaticDomainNameResolver.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <vector>
+#include <map>
+
+#include "Swiften/Network/DomainNameResolver.h"
+
+namespace Swift {
+ class String;
+
+ class StaticDomainNameResolver : public DomainNameResolver {
+ public:
+ StaticDomainNameResolver();
+
+ virtual std::vector<HostAddressPort> resolve(const String& domain);
+
+ void addDomain(const String& domain, const HostAddressPort& addressPort);
+
+ private:
+ typedef std::vector< std::pair<String, HostAddressPort> > DomainCollection;
+ DomainCollection domains;
+ };
+}
diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp
new file mode 100644
index 0000000..32893d8
--- /dev/null
+++ b/Swiften/Network/UnitTest/ConnectorTest.cpp
@@ -0,0 +1,140 @@
+#include <cppunit/extensions/HelperMacros.h>
+#include <cppunit/extensions/TestFactoryRegistry.h>
+
+#include <boost/optional.hpp>
+#include <boost/bind.hpp>
+
+#include "Swiften/Network/Connector.h"
+#include "Swiften/Network/Connection.h"
+#include "Swiften/Network/ConnectionFactory.h"
+#include "Swiften/Network/HostAddressPort.h"
+#include "Swiften/Network/StaticDomainNameResolver.h"
+#include "Swiften/EventLoop/MainEventLoop.h"
+#include "Swiften/EventLoop/DummyEventLoop.h"
+
+using namespace Swift;
+
+class ConnectorTest : public CppUnit::TestFixture {
+ CPPUNIT_TEST_SUITE(ConnectorTest);
+ CPPUNIT_TEST(testConnect);
+ CPPUNIT_TEST(testConnect_NoHosts);
+ CPPUNIT_TEST(testConnect_FirstHostFails);
+ CPPUNIT_TEST(testConnect_AllHostsFail);
+ CPPUNIT_TEST_SUITE_END();
+
+ public:
+ ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345) {
+ }
+
+ void setUp() {
+ eventLoop = new DummyEventLoop();
+ resolver = new StaticDomainNameResolver();
+ connectionFactory = new MockConnectionFactory();
+ }
+
+ void tearDown() {
+ delete connectionFactory;
+ delete resolver;
+ delete eventLoop;
+ }
+
+ void testConnect() {
+ std::auto_ptr<Connector> testling(createConnector());
+ resolver->addDomain("foo.com", host1);
+ resolver->addDomain("foo.com", host2);
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(connections[0]);
+ CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort));
+ }
+
+ void testConnect_NoHosts() {
+ std::auto_ptr<Connector> testling(createConnector());
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(!connections[0]);
+ }
+
+ void testConnect_FirstHostFails() {
+ std::auto_ptr<Connector> testling(createConnector());
+ resolver->addDomain("foo.com", host1);
+ resolver->addDomain("foo.com", host2);
+ connectionFactory->failingPorts.push_back(host1);
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort));
+ }
+
+ void testConnect_AllHostsFail() {
+ std::auto_ptr<Connector> testling(createConnector());
+ resolver->addDomain("foo.com", host1);
+ resolver->addDomain("foo.com", host2);
+ connectionFactory->failingPorts.push_back(host1);
+ connectionFactory->failingPorts.push_back(host2);
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(!connections[0]);
+ }
+
+ private:
+ Connector* createConnector() {
+ Connector* connector = new Connector("foo.com", resolver, connectionFactory);
+ connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1));
+ return connector;
+ }
+
+ void handleConnectorFinished(boost::shared_ptr<Connection> connection) {
+ boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection));
+ if (connection) {
+ assert(c);
+ }
+ connections.push_back(c);
+ }
+
+ struct MockConnection : public Connection {
+ public:
+ MockConnection(const std::vector<HostAddressPort>& failingPorts) : failingPorts(failingPorts) {}
+
+ void listen() { assert(false); }
+ void connect(const HostAddressPort& address) {
+ hostAddressPort = address;
+ MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end()));
+ }
+
+ void disconnect() { assert(false); }
+ void write(const ByteArray&) { assert(false); }
+
+ boost::optional<HostAddressPort> hostAddressPort;
+ std::vector<HostAddressPort> failingPorts;
+ };
+
+ struct MockConnectionFactory : public ConnectionFactory {
+ boost::shared_ptr<Connection> createConnection() {
+ return boost::shared_ptr<Connection>(new MockConnection(failingPorts));
+ }
+
+ std::vector<HostAddressPort> failingPorts;
+ };
+
+ private:
+ HostAddressPort host1;
+ HostAddressPort host2;
+ DummyEventLoop* eventLoop;
+ StaticDomainNameResolver* resolver;
+ MockConnectionFactory* connectionFactory;
+ std::vector< boost::shared_ptr<MockConnection> > connections;
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(ConnectorTest);