summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRemko Tronçon <git@el-tramo.be>2009-11-12 18:12:47 (GMT)
committerRemko Tronçon <git@el-tramo.be>2009-11-12 18:12:47 (GMT)
commitfdd8755e2363e8d706a3d0bdc2e71f234abdf829 (patch)
tree470401f6f80873c4e1ce5af5cd30ab6837854d04
parent6a20be61e229255f93d55f13be3346525698237a (diff)
downloadswift-contrib-fdd8755e2363e8d706a3d0bdc2e71f234abdf829.zip
swift-contrib-fdd8755e2363e8d706a3d0bdc2e71f234abdf829.tar.bz2
Refactored DNS handling.
Connections now fallback on other DNS entries upon failure, taking into account SRV priorities.
-rwxr-xr-x[-rw-r--r--]BuildTools/Git/Hooks/pre-commit0
-rw-r--r--Swiften/Client/Client.cpp33
-rw-r--r--Swiften/Client/Client.h6
-rw-r--r--Swiften/Network/Connector.cpp48
-rw-r--r--Swiften/Network/Connector.h34
-rw-r--r--Swiften/Network/DomainNameResolver.cpp168
-rw-r--r--Swiften/Network/DomainNameResolver.h17
-rw-r--r--Swiften/Network/DummyConnection.h40
-rw-r--r--Swiften/Network/HostAddress.h4
-rw-r--r--Swiften/Network/HostAddressPort.h4
-rw-r--r--Swiften/Network/PlatformDomainNameResolver.cpp200
-rw-r--r--Swiften/Network/PlatformDomainNameResolver.h25
-rw-r--r--Swiften/Network/SConscript3
-rw-r--r--Swiften/Network/SRVRecord.h10
-rw-r--r--Swiften/Network/SRVRecordPriorityComparator.h11
-rw-r--r--Swiften/Network/StaticDomainNameResolver.cpp28
-rw-r--r--Swiften/Network/StaticDomainNameResolver.h23
-rw-r--r--Swiften/Network/UnitTest/ConnectorTest.cpp140
-rw-r--r--Swiften/QA/ClientTest/ClientTest.cpp14
-rw-r--r--Swiften/QA/NetworkTest/DomainNameResolverTest.cpp30
-rw-r--r--Swiften/SConscript1
-rw-r--r--Swiften/Session/BasicSessionStream.cpp14
-rw-r--r--Swiften/Session/BasicSessionStream.h50
23 files changed, 660 insertions, 243 deletions
diff --git a/BuildTools/Git/Hooks/pre-commit b/BuildTools/Git/Hooks/pre-commit
index 9ffc1c7..9ffc1c7 100644..100755
--- a/BuildTools/Git/Hooks/pre-commit
+++ b/BuildTools/Git/Hooks/pre-commit
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp
index 85db4ac..6614bf7 100644
--- a/Swiften/Client/Client.cpp
+++ b/Swiften/Client/Client.cpp
@@ -2,11 +2,11 @@
#include <boost/bind.hpp>
-#include "Swiften/Network/DomainNameResolver.h"
#include "Swiften/Network/MainBoostIOServiceThread.h"
#include "Swiften/Network/BoostIOServiceThread.h"
#include "Swiften/Client/ClientSession.h"
#include "Swiften/StreamStack/PlatformTLSLayerFactory.h"
+#include "Swiften/Network/Connector.h"
#include "Swiften/Network/BoostConnectionFactory.h"
#include "Swiften/Network/DomainNameResolveException.h"
#include "Swiften/TLS/PKCS12Certificate.h"
@@ -33,24 +33,22 @@ bool Client::isAvailable() {
}
void Client::connect() {
- assert(!connection_);
- DomainNameResolver resolver;
- try {
- HostAddressPort remote = resolver.resolve(jid_.getDomain().getUTF8String());
- connection_ = connectionFactory_->createConnection();
- connection_->onConnectFinished.connect(boost::bind(&Client::handleConnectionConnectFinished, this, _1));
- connection_->connect(remote);
- }
- catch (const DomainNameResolveException& e) {
- onError(ClientError::DomainNameResolveError);
- }
+ assert(!connector_);
+ connector_ = boost::shared_ptr<Connector>(new Connector(jid_.getDomain(), &resolver_, connectionFactory_));
+ connector_->onConnectFinished.connect(boost::bind(&Client::handleConnectorFinished, this, _1));
+ connector_->start();
}
-void Client::handleConnectionConnectFinished(bool error) {
- if (error) {
+void Client::handleConnectorFinished(boost::shared_ptr<Connection> connection) {
+ // TODO: Add domain name resolver error
+ connector_.reset();
+ if (!connection) {
onError(ClientError::ConnectionError);
}
else {
+ assert(!connection_);
+ connection_ = connection;
+
assert(!sessionStream_);
sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_));
if (!certificate_.isEmpty()) {
@@ -78,6 +76,9 @@ void Client::disconnect() {
}
void Client::closeConnection() {
+ if (sessionStream_) {
+ sessionStream_.reset();
+ }
if (connection_) {
connection_->disconnect();
connection_.reset();
@@ -186,11 +187,11 @@ void Client::handleNeedCredentials() {
}
void Client::handleDataRead(const String& data) {
- onDataRead(data);
+ onDataRead(data);
}
void Client::handleDataWritten(const String& data) {
- onDataWritten(data);
+ onDataWritten(data);
}
}
diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h
index 3f7d350..f09c916 100644
--- a/Swiften/Client/Client.h
+++ b/Swiften/Client/Client.h
@@ -4,6 +4,7 @@
#include <boost/signals.hpp>
#include <boost/shared_ptr.hpp>
+#include "Swiften/Network/PlatformDomainNameResolver.h"
#include "Swiften/Base/Error.h"
#include "Swiften/Client/ClientSession.h"
#include "Swiften/Client/ClientError.h"
@@ -22,6 +23,7 @@ namespace Swift {
class ConnectionFactory;
class ClientSession;
class BasicSessionStream;
+ class Connector;
class Client : public StanzaChannel, public IQRouter, public boost::bsignals::trackable {
public:
@@ -46,7 +48,7 @@ namespace Swift {
boost::signal<void (const String&)> onDataWritten;
private:
- void handleConnectionConnectFinished(bool error);
+ void handleConnectorFinished(boost::shared_ptr<Connection>);
void send(boost::shared_ptr<Stanza>);
virtual String getNewIQID();
void handleElement(boost::shared_ptr<Element>);
@@ -58,9 +60,11 @@ namespace Swift {
void closeConnection();
private:
+ PlatformDomainNameResolver resolver_;
JID jid_;
String password_;
IDGenerator idGenerator_;
+ boost::shared_ptr<Connector> connector_;
ConnectionFactory* connectionFactory_;
TLSLayerFactory* tlsLayerFactory_;
FullPayloadParserFactoryCollection payloadParserFactories_;
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp
new file mode 100644
index 0000000..5b4fe22
--- /dev/null
+++ b/Swiften/Network/Connector.cpp
@@ -0,0 +1,48 @@
+#include "Swiften/Network/Connector.h"
+
+#include <boost/bind.hpp>
+
+#include "Swiften/Network/ConnectionFactory.h"
+#include "Swiften/Network/DomainNameResolver.h"
+#include "Swiften/Network/DomainNameResolveException.h"
+
+namespace Swift {
+
+Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory) {
+}
+
+void Connector::start() {
+ assert(!currentConnection);
+ try {
+ std::vector<HostAddressPort> resolveResult = resolver->resolve(hostname.getUTF8String());
+ resolvedHosts = std::deque<HostAddressPort>(resolveResult.begin(), resolveResult.end());
+ tryNextHostname();
+ }
+ catch (const DomainNameResolveException&) {
+ onConnectFinished(boost::shared_ptr<Connection>());
+ }
+}
+
+void Connector::tryNextHostname() {
+ if (resolvedHosts.empty()) {
+ onConnectFinished(boost::shared_ptr<Connection>());
+ }
+ else {
+ HostAddressPort remote = resolvedHosts.front();
+ resolvedHosts.pop_front();
+ currentConnection = connectionFactory->createConnection();
+ currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1));
+ currentConnection->connect(remote);
+ }
+}
+
+void Connector::handleConnectionConnectFinished(bool error) {
+ if (error) {
+ tryNextHostname();
+ }
+ else {
+ onConnectFinished(currentConnection);
+ }
+}
+
+};
diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h
new file mode 100644
index 0000000..084c416
--- /dev/null
+++ b/Swiften/Network/Connector.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <deque>
+#include <boost/signal.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include "Swiften/Network/Connection.h"
+#include "Swiften/Base/String.h"
+
+namespace Swift {
+ class DomainNameResolver;
+ class ConnectionFactory;
+
+ class Connector : public boost::bsignals::trackable {
+ public:
+ Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*);
+
+ void start();
+
+ boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished;
+
+ private:
+ void tryNextHostname();
+ void handleConnectionConnectFinished(bool error);
+
+ private:
+ String hostname;
+ DomainNameResolver* resolver;
+ ConnectionFactory* connectionFactory;
+ std::deque<HostAddressPort> resolvedHosts;
+ boost::shared_ptr<Connection> currentConnection;
+ };
+
+};
diff --git a/Swiften/Network/DomainNameResolver.cpp b/Swiften/Network/DomainNameResolver.cpp
index 44b3ecf..907dfc9 100644
--- a/Swiften/Network/DomainNameResolver.cpp
+++ b/Swiften/Network/DomainNameResolver.cpp
@@ -1,176 +1,8 @@
#include "Swiften/Network/DomainNameResolver.h"
-#include "Swiften/Base/Platform.h"
-
-#include <stdlib.h>
-#include <boost/asio.hpp>
-#include <idna.h>
-#ifdef SWIFTEN_PLATFORM_WINDOWS
-#undef UNICODE
-#include <windows.h>
-#include <windns.h>
-#ifndef DNS_TYPE_SRV
-#define DNS_TYPE_SRV 33
-#endif
-#else
-#include <arpa/nameser.h>
-#include <arpa/nameser_compat.h>
-#include <resolv.h>
-#endif
-
-#include "Swiften/Network/DomainNameResolveException.h"
-#include "Swiften/Base/String.h"
-#include "Swiften/Base/ByteArray.h"
namespace Swift {
-DomainNameResolver::DomainNameResolver() {
-}
-
DomainNameResolver::~DomainNameResolver() {
}
-HostAddressPort DomainNameResolver::resolve(const String& domain) {
- char* output;
- if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) {
- std::string outputString(output);
- free(output);
- return resolveDomain(outputString);
- }
- else {
- return resolveDomain(domain.getUTF8String());
- }
-}
-
-HostAddressPort DomainNameResolver::resolveDomain(const std::string& domain) {
- try {
- return resolveXMPPService(domain);
- }
- catch (const DomainNameResolveException&) {
- }
- return HostAddressPort(resolveHostName(domain), 5222);
-}
-
-HostAddressPort DomainNameResolver::resolveXMPPService(const std::string& domain) {
- std::string srvQuery = "_xmpp-client._tcp." + domain;
-
-#if defined(SWIFTEN_PLATFORM_WINDOWS)
- DNS_RECORD* responses;
- // FIXME: This conversion doesn't work if unicode is deffed above
- if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) {
- throw DomainNameResolveException();
- }
-
- DNS_RECORD* currentEntry = responses;
- while (currentEntry) {
- if (currentEntry->wType == DNS_TYPE_SRV) {
- int port = currentEntry->Data.SRV.wPort;
- try {
- // The pNameTarget is actually a PCWSTR, so I would have expected this
- // conversion to not work at all, but it does.
- // Actually, it doesn't. Fix this and remove explicit cast
- // Remove unicode undef above as well
- std::string hostname((const char*) currentEntry->Data.SRV.pNameTarget);
- HostAddress address = resolveHostName(hostname);
- DnsRecordListFree(responses, DnsFreeRecordList);
- return HostAddressPort(address, port);
- }
- catch (const DomainNameResolveException&) {
- }
- }
- currentEntry = currentEntry->pNext;
- }
- DnsRecordListFree(responses, DnsFreeRecordList);
-
-#else
-
- ByteArray response;
- response.resize(NS_PACKETSZ);
- int responseLength = res_query(const_cast<char*>(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize());
- if (responseLength == -1) {
- throw DomainNameResolveException();
- }
-
- // Parse header
- HEADER* header = reinterpret_cast<HEADER*>(response.getData());
- unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData());
- unsigned char* messageEnd = messageStart + responseLength;
- unsigned char* currentEntry = messageStart + NS_HFIXEDSZ;
-
- // Skip over the queries
- int queriesCount = ntohs(header->qdcount);
- while (queriesCount > 0) {
- int entryLength = dn_skipname(currentEntry, messageEnd);
- if (entryLength < 0) {
- throw DomainNameResolveException();
- }
- currentEntry += entryLength + NS_QFIXEDSZ;
- queriesCount--;
- }
-
- // Process the SRV answers
- int answersCount = ntohs(header->ancount);
- while (answersCount > 0) {
- int entryLength = dn_skipname(currentEntry, messageEnd);
- currentEntry += entryLength;
- currentEntry += NS_RRFIXEDSZ;
-
- // Uninteresting information
- currentEntry += 2; // PRIORITY
- currentEntry += 2; // WEIGHT
-
- // Port
- if (currentEntry >= messageEnd) {
- throw DomainNameResolveException();
- }
- int port = ns_get16(currentEntry);
- currentEntry += 2;
-
- // Hostname
- if (currentEntry >= messageEnd) {
- throw DomainNameResolveException();
- }
- ByteArray entry;
- entry.resize(NS_MAXDNAME);
- entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize());
- if (entryLength < 0) {
- throw DomainNameResolveException();
- }
- try {
- // Resolve the hostname
- std::string hostname(entry.getData(), entryLength);
- HostAddress address = resolveHostName(hostname);
- return HostAddressPort(address, port);
- }
- catch (const DomainNameResolveException&) {
- }
- currentEntry += entryLength;
- answersCount--;
- }
-#endif
-
- throw DomainNameResolveException();
-}
-
-HostAddress DomainNameResolver::resolveHostName(const std::string& hostname) {
- boost::asio::io_service ioService;
- boost::asio::ip::tcp::resolver resolver(ioService);
- boost::asio::ip::tcp::resolver::query query(hostname, "5222");
- try {
- boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query);
- if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) {
- throw DomainNameResolveException();
- }
- boost::asio::ip::address address = (*endpointIterator).endpoint().address();
- if (address.is_v4()) {
- return HostAddress(&address.to_v4().to_bytes()[0], 4);
- }
- else {
- return HostAddress(&address.to_v6().to_bytes()[0], 16);
- }
- }
- catch (...) {
- throw DomainNameResolveException();
- }
-}
-
}
diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h
index c7736b1..5c83622 100644
--- a/Swiften/Network/DomainNameResolver.h
+++ b/Swiften/Network/DomainNameResolver.h
@@ -1,10 +1,7 @@
-#ifndef SWIFTEN_DOMAINNAMERESOLVER_H
-#define SWIFTEN_DOMAINNAMERESOLVER_H
+#pragma once
-#include <string>
+#include <vector>
-#include "Swiften/Base/String.h"
-#include "Swiften/Network/HostAddress.h"
#include "Swiften/Network/HostAddressPort.h"
namespace Swift {
@@ -12,16 +9,8 @@ namespace Swift {
class DomainNameResolver {
public:
- DomainNameResolver();
virtual ~DomainNameResolver();
- HostAddressPort resolve(const String& domain);
-
- private:
- virtual HostAddressPort resolveDomain(const std::string& domain);
- HostAddressPort resolveXMPPService(const std::string& domain);
- HostAddress resolveHostName(const std::string& hostName);
+ virtual std::vector<HostAddressPort> resolve(const String& domain) = 0;
};
}
-
-#endif
diff --git a/Swiften/Network/DummyConnection.h b/Swiften/Network/DummyConnection.h
new file mode 100644
index 0000000..11281b3
--- /dev/null
+++ b/Swiften/Network/DummyConnection.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include <cassert>
+#include <boost/bind.hpp>
+#include <boost/enable_shared_from_this.hpp>
+
+#include "Swiften/Network/Connection.h"
+#include "Swiften/EventLoop/MainEventLoop.h"
+#include "Swiften/EventLoop/EventOwner.h"
+
+namespace Swift {
+ class DummyConnection :
+ public Connection,
+ public EventOwner,
+ public boost::enable_shared_from_this<DummyConnection> {
+
+ void listen() {
+ assert(false);
+ }
+
+ void connect(const HostAddressPort&) {
+ assert(false);
+ }
+
+ void disconnect() {
+ assert(false);
+ }
+
+ void write(const ByteArray& data) {
+ onDataWritten(data);
+ }
+
+ void receive(const ByteArray& data) {
+ MainEventLoop::postEvent(boost::bind(
+ boost::ref(onDataRead), ByteArray(data)), shared_from_this());
+ }
+
+ boost::signal<void (const ByteArray&)> onDataWritten;
+ };
+}
diff --git a/Swiften/Network/HostAddress.h b/Swiften/Network/HostAddress.h
index bf6d2f8..11f8a2b 100644
--- a/Swiften/Network/HostAddress.h
+++ b/Swiften/Network/HostAddress.h
@@ -18,6 +18,10 @@ namespace Swift {
std::string toString() const;
+ bool operator==(const HostAddress& o) const {
+ return address_ == o.address_;
+ }
+
private:
std::vector<unsigned char> address_;
};
diff --git a/Swiften/Network/HostAddressPort.h b/Swiften/Network/HostAddressPort.h
index 8668ae4..d632058 100644
--- a/Swiften/Network/HostAddressPort.h
+++ b/Swiften/Network/HostAddressPort.h
@@ -17,6 +17,10 @@ namespace Swift {
return port_;
}
+ bool operator==(const HostAddressPort& o) const {
+ return address_ == o.address_ && port_ == o.port_;
+ }
+
private:
HostAddress address_;
int port_;
diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp
new file mode 100644
index 0000000..3803e7c
--- /dev/null
+++ b/Swiften/Network/PlatformDomainNameResolver.cpp
@@ -0,0 +1,200 @@
+#include "Swiften/Network/PlatformDomainNameResolver.h"
+
+#include "Swiften/Base/Platform.h"
+#include "Swiften/Base/foreach.h"
+
+#include <stdlib.h>
+#include <boost/asio.hpp>
+#include <idna.h>
+#ifdef SWIFTEN_PLATFORM_WINDOWS
+#undef UNICODE
+#include <windows.h>
+#include <windns.h>
+#ifndef DNS_TYPE_SRV
+#define DNS_TYPE_SRV 33
+#endif
+#else
+#include <arpa/nameser.h>
+#include <arpa/nameser_compat.h>
+#include <resolv.h>
+#endif
+#include <algorithm>
+
+#include "Swiften/Network/DomainNameResolveException.h"
+#include "Swiften/Base/String.h"
+#include "Swiften/Base/ByteArray.h"
+#include "Swiften/Network/SRVRecord.h"
+#include "Swiften/Network/SRVRecordPriorityComparator.h"
+
+namespace Swift {
+
+PlatformDomainNameResolver::PlatformDomainNameResolver() {
+}
+
+std::vector<HostAddressPort> PlatformDomainNameResolver::resolve(const String& domain) {
+ char* output;
+ if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) {
+ std::string outputString(output);
+ free(output);
+ return resolveDomain(outputString);
+ }
+ else {
+ return resolveDomain(domain.getUTF8String());
+ }
+}
+
+std::vector<HostAddressPort> PlatformDomainNameResolver::resolveDomain(const std::string& domain) {
+ try {
+ return resolveXMPPService(domain);
+ }
+ catch (const DomainNameResolveException&) {
+ }
+ std::vector<HostAddressPort> result;
+ result.push_back(HostAddressPort(resolveHostName(domain), 5222));
+ return result;
+}
+
+std::vector<HostAddressPort> PlatformDomainNameResolver::resolveXMPPService(const std::string& domain) {
+ std::vector<SRVRecord> records;
+
+ std::string srvQuery = "_xmpp-client._tcp." + domain;
+
+#if defined(SWIFTEN_PLATFORM_WINDOWS)
+ DNS_RECORD* responses;
+ // FIXME: This conversion doesn't work if unicode is deffed above
+ if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) {
+ throw DomainNameResolveException();
+ }
+
+ DNS_RECORD* currentEntry = responses;
+ while (currentEntry) {
+ if (currentEntry->wType == DNS_TYPE_SRV) {
+ SRVRecord record;
+ record.priority = currentEntry->Data.SRV.wPriority;
+ record.weight = currentEntry->Data.SRV.wWeight;
+ record.port = currentEntry->Data.SRV.wPort;
+
+ // The pNameTarget is actually a PCWSTR, so I would have expected this
+ // conversion to not work at all, but it does.
+ // Actually, it doesn't. Fix this and remove explicit cast
+ // Remove unicode undef above as well
+ record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget);
+ records.push_back(record);
+ }
+ currentEntry = currentEntry->pNext;
+ }
+ DnsRecordListFree(responses, DnsFreeRecordList);
+
+#else
+
+ ByteArray response;
+ response.resize(NS_PACKETSZ);
+ int responseLength = res_query(const_cast<char*>(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize());
+ if (responseLength == -1) {
+ throw DomainNameResolveException();
+ }
+
+ // Parse header
+ HEADER* header = reinterpret_cast<HEADER*>(response.getData());
+ unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData());
+ unsigned char* messageEnd = messageStart + responseLength;
+ unsigned char* currentEntry = messageStart + NS_HFIXEDSZ;
+
+ // Skip over the queries
+ int queriesCount = ntohs(header->qdcount);
+ while (queriesCount > 0) {
+ int entryLength = dn_skipname(currentEntry, messageEnd);
+ if (entryLength < 0) {
+ throw DomainNameResolveException();
+ }
+ currentEntry += entryLength + NS_QFIXEDSZ;
+ queriesCount--;
+ }
+
+ // Process the SRV answers
+ int answersCount = ntohs(header->ancount);
+ while (answersCount > 0) {
+ SRVRecord record;
+
+ int entryLength = dn_skipname(currentEntry, messageEnd);
+ currentEntry += entryLength;
+ currentEntry += NS_RRFIXEDSZ;
+
+ // Priority
+ if (currentEntry + 2 >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ record.priority = ns_get16(currentEntry);
+ currentEntry += 2;
+
+ // Weight
+ if (currentEntry + 2 >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ record.weight = ns_get16(currentEntry);
+ currentEntry += 2;
+
+ // Port
+ if (currentEntry + 2 >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ record.port = ns_get16(currentEntry);
+ currentEntry += 2;
+
+ // Hostname
+ if (currentEntry >= messageEnd) {
+ throw DomainNameResolveException();
+ }
+ ByteArray entry;
+ entry.resize(NS_MAXDNAME);
+ entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize());
+ if (entryLength < 0) {
+ throw DomainNameResolveException();
+ }
+ record.hostname = std::string(entry.getData(), entryLength);
+ records.push_back(record);
+ currentEntry += entryLength;
+ answersCount--;
+ }
+#endif
+
+ // Resolve the hostnames in the records, and build the result list
+ std::sort(records.begin(), records.end(), SRVRecordPriorityComparator());
+ std::vector<HostAddressPort> result;
+ foreach(const SRVRecord& record, records) {
+ try {
+ HostAddress address = resolveHostName(record.hostname);
+ result.push_back(HostAddressPort(address, record.port));
+ }
+ catch (const DomainNameResolveException&) {
+ }
+ }
+ if (result.empty()) {
+ throw DomainNameResolveException();
+ }
+ return result;
+}
+
+HostAddress PlatformDomainNameResolver::resolveHostName(const std::string& hostname) {
+ boost::asio::io_service ioService;
+ boost::asio::ip::tcp::resolver resolver(ioService);
+ boost::asio::ip::tcp::resolver::query query(hostname, "5222");
+ try {
+ boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query);
+ if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) {
+ throw DomainNameResolveException();
+ }
+ boost::asio::ip::address address = (*endpointIterator).endpoint().address();
+ if (address.is_v4()) {
+ return HostAddress(&address.to_v4().to_bytes()[0], 4);
+ }
+ else {
+ return HostAddress(&address.to_v6().to_bytes()[0], 16);
+ }
+ }
+ catch (...) {
+ throw DomainNameResolveException();
+ }
+}
+
+}
diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h
new file mode 100644
index 0000000..651bc97
--- /dev/null
+++ b/Swiften/Network/PlatformDomainNameResolver.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "Swiften/Base/String.h"
+#include "Swiften/Network/DomainNameResolver.h"
+#include "Swiften/Network/HostAddress.h"
+#include "Swiften/Network/HostAddressPort.h"
+
+namespace Swift {
+ class String;
+
+ class PlatformDomainNameResolver : public DomainNameResolver {
+ public:
+ PlatformDomainNameResolver();
+
+ std::vector<HostAddressPort> resolve(const String& domain);
+
+ private:
+ std::vector<HostAddressPort> resolveDomain(const std::string& domain);
+ std::vector<HostAddressPort> resolveXMPPService(const std::string& domain);
+ HostAddress resolveHostName(const std::string& hostName);
+ };
+}
diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript
index 06d9350..652bda1 100644
--- a/Swiften/Network/SConscript
+++ b/Swiften/Network/SConscript
@@ -11,7 +11,10 @@ objects = myenv.StaticObject([
"BoostIOServiceThread.cpp",
"ConnectionFactory.cpp",
"ConnectionServer.cpp",
+ "Connector.cpp",
"DomainNameResolver.cpp",
+ "PlatformDomainNameResolver.cpp",
+ "StaticDomainNameResolver.cpp",
"HostAddress.cpp",
"Timer.cpp",
])
diff --git a/Swiften/Network/SRVRecord.h b/Swiften/Network/SRVRecord.h
new file mode 100644
index 0000000..5a11635
--- /dev/null
+++ b/Swiften/Network/SRVRecord.h
@@ -0,0 +1,10 @@
+#pragma once
+
+namespace Swift {
+ struct SRVRecord {
+ std::string hostname;
+ int port;
+ int priority;
+ int weight;
+ };
+}
diff --git a/Swiften/Network/SRVRecordPriorityComparator.h b/Swiften/Network/SRVRecordPriorityComparator.h
new file mode 100644
index 0000000..fc16597
--- /dev/null
+++ b/Swiften/Network/SRVRecordPriorityComparator.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "Swiften/Network/SRVRecord.h"
+
+namespace Swift {
+ struct SRVRecordPriorityComparator {
+ bool operator()(const SRVRecord& a, const SRVRecord& b) const {
+ return a.priority < b.priority;
+ }
+ };
+}
diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp
new file mode 100644
index 0000000..8ca4062
--- /dev/null
+++ b/Swiften/Network/StaticDomainNameResolver.cpp
@@ -0,0 +1,28 @@
+#include "Swiften/Network/StaticDomainNameResolver.h"
+#include "Swiften/Network/DomainNameResolveException.h"
+#include "Swiften/Base/String.h"
+
+namespace Swift {
+
+StaticDomainNameResolver::StaticDomainNameResolver() {
+}
+
+std::vector<HostAddressPort> StaticDomainNameResolver::resolve(const String& queriedDomain) {
+ std::vector<HostAddressPort> result;
+
+ for(DomainCollection::const_iterator i = domains.begin(); i != domains.end(); ++i) {
+ if (i->first == queriedDomain) {
+ result.push_back(i->second);
+ }
+ }
+ if (result.empty()) {
+ throw DomainNameResolveException();
+ }
+ return result;
+}
+
+void StaticDomainNameResolver::addDomain(const String& domain, const HostAddressPort& addressPort) {
+ domains.push_back(std::make_pair(domain, addressPort));
+}
+
+}
diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h
new file mode 100644
index 0000000..8688429
--- /dev/null
+++ b/Swiften/Network/StaticDomainNameResolver.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <vector>
+#include <map>
+
+#include "Swiften/Network/DomainNameResolver.h"
+
+namespace Swift {
+ class String;
+
+ class StaticDomainNameResolver : public DomainNameResolver {
+ public:
+ StaticDomainNameResolver();
+
+ virtual std::vector<HostAddressPort> resolve(const String& domain);
+
+ void addDomain(const String& domain, const HostAddressPort& addressPort);
+
+ private:
+ typedef std::vector< std::pair<String, HostAddressPort> > DomainCollection;
+ DomainCollection domains;
+ };
+}
diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp
new file mode 100644
index 0000000..32893d8
--- /dev/null
+++ b/Swiften/Network/UnitTest/ConnectorTest.cpp
@@ -0,0 +1,140 @@
+#include <cppunit/extensions/HelperMacros.h>
+#include <cppunit/extensions/TestFactoryRegistry.h>
+
+#include <boost/optional.hpp>
+#include <boost/bind.hpp>
+
+#include "Swiften/Network/Connector.h"
+#include "Swiften/Network/Connection.h"
+#include "Swiften/Network/ConnectionFactory.h"
+#include "Swiften/Network/HostAddressPort.h"
+#include "Swiften/Network/StaticDomainNameResolver.h"
+#include "Swiften/EventLoop/MainEventLoop.h"
+#include "Swiften/EventLoop/DummyEventLoop.h"
+
+using namespace Swift;
+
+class ConnectorTest : public CppUnit::TestFixture {
+ CPPUNIT_TEST_SUITE(ConnectorTest);
+ CPPUNIT_TEST(testConnect);
+ CPPUNIT_TEST(testConnect_NoHosts);
+ CPPUNIT_TEST(testConnect_FirstHostFails);
+ CPPUNIT_TEST(testConnect_AllHostsFail);
+ CPPUNIT_TEST_SUITE_END();
+
+ public:
+ ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345) {
+ }
+
+ void setUp() {
+ eventLoop = new DummyEventLoop();
+ resolver = new StaticDomainNameResolver();
+ connectionFactory = new MockConnectionFactory();
+ }
+
+ void tearDown() {
+ delete connectionFactory;
+ delete resolver;
+ delete eventLoop;
+ }
+
+ void testConnect() {
+ std::auto_ptr<Connector> testling(createConnector());
+ resolver->addDomain("foo.com", host1);
+ resolver->addDomain("foo.com", host2);
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(connections[0]);
+ CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort));
+ }
+
+ void testConnect_NoHosts() {
+ std::auto_ptr<Connector> testling(createConnector());
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(!connections[0]);
+ }
+
+ void testConnect_FirstHostFails() {
+ std::auto_ptr<Connector> testling(createConnector());
+ resolver->addDomain("foo.com", host1);
+ resolver->addDomain("foo.com", host2);
+ connectionFactory->failingPorts.push_back(host1);
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort));
+ }
+
+ void testConnect_AllHostsFail() {
+ std::auto_ptr<Connector> testling(createConnector());
+ resolver->addDomain("foo.com", host1);
+ resolver->addDomain("foo.com", host2);
+ connectionFactory->failingPorts.push_back(host1);
+ connectionFactory->failingPorts.push_back(host2);
+
+ testling->start();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(!connections[0]);
+ }
+
+ private:
+ Connector* createConnector() {
+ Connector* connector = new Connector("foo.com", resolver, connectionFactory);
+ connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1));
+ return connector;
+ }
+
+ void handleConnectorFinished(boost::shared_ptr<Connection> connection) {
+ boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection));
+ if (connection) {
+ assert(c);
+ }
+ connections.push_back(c);
+ }
+
+ struct MockConnection : public Connection {
+ public:
+ MockConnection(const std::vector<HostAddressPort>& failingPorts) : failingPorts(failingPorts) {}
+
+ void listen() { assert(false); }
+ void connect(const HostAddressPort& address) {
+ hostAddressPort = address;
+ MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end()));
+ }
+
+ void disconnect() { assert(false); }
+ void write(const ByteArray&) { assert(false); }
+
+ boost::optional<HostAddressPort> hostAddressPort;
+ std::vector<HostAddressPort> failingPorts;
+ };
+
+ struct MockConnectionFactory : public ConnectionFactory {
+ boost::shared_ptr<Connection> createConnection() {
+ return boost::shared_ptr<Connection>(new MockConnection(failingPorts));
+ }
+
+ std::vector<HostAddressPort> failingPorts;
+ };
+
+ private:
+ HostAddressPort host1;
+ HostAddressPort host2;
+ DummyEventLoop* eventLoop;
+ StaticDomainNameResolver* resolver;
+ MockConnectionFactory* connectionFactory;
+ std::vector< boost::shared_ptr<MockConnection> > connections;
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(ConnectorTest);
diff --git a/Swiften/QA/ClientTest/ClientTest.cpp b/Swiften/QA/ClientTest/ClientTest.cpp
index b50a0bf..cf1c161 100644
--- a/Swiften/QA/ClientTest/ClientTest.cpp
+++ b/Swiften/QA/ClientTest/ClientTest.cpp
@@ -15,12 +15,20 @@ using namespace Swift;
SimpleEventLoop eventLoop;
Client* client = 0;
+bool reconnected = false;
bool rosterReceived = false;
void handleRosterReceived(boost::shared_ptr<Payload>) {
- rosterReceived = true;
- client->disconnect();
- eventLoop.stop();
+ if (reconnected) {
+ rosterReceived = true;
+ client->disconnect();
+ eventLoop.stop();
+ }
+ else {
+ reconnected = true;
+ client->disconnect();
+ client->connect();
+ }
}
void handleConnected() {
diff --git a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp
index 8968efd..cb812a1 100644
--- a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp
+++ b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp
@@ -2,7 +2,7 @@
#include <cppunit/extensions/TestFactoryRegistry.h>
#include "Swiften/Base/String.h"
-#include "Swiften/Network/DomainNameResolver.h"
+#include "Swiften/Network/PlatformDomainNameResolver.h"
#include "Swiften/Network/DomainNameResolveException.h"
using namespace Swift;
@@ -14,13 +14,14 @@ class DomainNameResolverTest : public CppUnit::TestFixture {
CPPUNIT_TEST(testResolve_Invalid);
//CPPUNIT_TEST(testResolve_IPv6);
CPPUNIT_TEST(testResolve_International);
+ CPPUNIT_TEST(testResolve_Localhost);
CPPUNIT_TEST_SUITE_END();
public:
DomainNameResolverTest() {}
void setUp() {
- resolver_ = new DomainNameResolver();
+ resolver_ = new PlatformDomainNameResolver();
}
void tearDown() {
@@ -28,17 +29,22 @@ class DomainNameResolverTest : public CppUnit::TestFixture {
}
void testResolve_NoSRV() {
- HostAddressPort result = resolver_->resolve("xmpp.test.swift.im");
+ HostAddressPort result = resolver_->resolve("xmpp.test.swift.im")[0];
CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.0"), result.getAddress().toString());
CPPUNIT_ASSERT_EQUAL(5222, result.getPort());
}
void testResolve_SRV() {
- HostAddressPort result = resolver_->resolve("xmpp-srv.test.swift.im");
+ std::vector<HostAddressPort> result = resolver_->resolve("xmpp-srv.test.swift.im");
- CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.1"), result.getAddress().toString());
- CPPUNIT_ASSERT_EQUAL(5000, result.getPort());
+ CPPUNIT_ASSERT_EQUAL(3, static_cast<int>(result.size()));
+ CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.1"), result[0].getAddress().toString());
+ CPPUNIT_ASSERT_EQUAL(5000, result[0].getPort());
+ CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), result[1].getAddress().toString());
+ CPPUNIT_ASSERT_EQUAL(5000, result[1].getPort());
+ CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.2"), result[2].getAddress().toString());
+ CPPUNIT_ASSERT_EQUAL(5000, result[2].getPort());
}
void testResolve_Invalid() {
@@ -46,19 +52,25 @@ class DomainNameResolverTest : public CppUnit::TestFixture {
}
void testResolve_IPv6() {
- HostAddressPort result = resolver_->resolve("xmpp-ipv6.test.swift.im");
+ HostAddressPort result = resolver_->resolve("xmpp-ipv6.test.swift.im")[0];
CPPUNIT_ASSERT_EQUAL(std::string("0000:0000:0000:0000:0000:ffff:0a00:0104"), result.getAddress().toString());
CPPUNIT_ASSERT_EQUAL(5222, result.getPort());
}
void testResolve_International() {
- HostAddressPort result = resolver_->resolve("tron\xc3\xa7on.test.swift.im");
+ HostAddressPort result = resolver_->resolve("tron\xc3\xa7on.test.swift.im")[0];
CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), result.getAddress().toString());
CPPUNIT_ASSERT_EQUAL(5222, result.getPort());
}
+ void testResolve_Localhost() {
+ HostAddressPort result = resolver_->resolve("localhost")[0];
+ CPPUNIT_ASSERT_EQUAL(std::string("127.0.0.1"), result.getAddress().toString());
+ CPPUNIT_ASSERT_EQUAL(5222, result.getPort());
+ }
+
private:
- DomainNameResolver* resolver_;
+ PlatformDomainNameResolver* resolver_;
};
CPPUNIT_TEST_SUITE_REGISTRATION(DomainNameResolverTest);
diff --git a/Swiften/SConscript b/Swiften/SConscript
index d5ddce4..d896cd8 100644
--- a/Swiften/SConscript
+++ b/Swiften/SConscript
@@ -121,6 +121,7 @@ env.Append(UNITTEST_SOURCES = [
File("LinkLocal/UnitTest/LinkLocalServiceInfoTest.cpp"),
File("LinkLocal/UnitTest/LinkLocalServiceTest.cpp"),
File("Network/UnitTest/HostAddressTest.cpp"),
+ File("Network/UnitTest/ConnectorTest.cpp"),
File("Parser/PayloadParsers/UnitTest/BodyParserTest.cpp"),
File("Parser/PayloadParsers/UnitTest/DiscoInfoParserTest.cpp"),
File("Parser/PayloadParsers/UnitTest/ErrorParserTest.cpp"),
diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp
index e0fbce7..a9a3cb0 100644
--- a/Swiften/Session/BasicSessionStream.cpp
+++ b/Swiften/Session/BasicSessionStream.cpp
@@ -20,13 +20,13 @@ void BasicSessionStream::initialize() {
xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, shared_from_this(), _1));
xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, shared_from_this(), _1));
xmppLayer->onError.connect(boost::bind(
- &BasicSessionStream::handleXMPPError, shared_from_this()));
- xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, shared_from_this(), _1));
- xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, shared_from_this(), _1));
+ &BasicSessionStream::handleXMPPError, shared_from_this()));
+ xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, shared_from_this(), _1));
+ xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, shared_from_this(), _1));
connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionError, shared_from_this(), _1));
connectionLayer = boost::shared_ptr<ConnectionLayer>(
- new ConnectionLayer(connection));
+ new ConnectionLayer(connection));
streamStack = new StreamStack(xmppLayer, connectionLayer);
@@ -34,7 +34,7 @@ void BasicSessionStream::initialize() {
}
BasicSessionStream::~BasicSessionStream() {
- delete streamStack;
+ delete streamStack;
}
void BasicSessionStream::writeHeader(const ProtocolHeader& header) {
@@ -57,7 +57,7 @@ bool BasicSessionStream::isAvailable() {
}
bool BasicSessionStream::supportsTLSEncryption() {
- return tlsLayerFactory && tlsLayerFactory->canCreate();
+ return tlsLayerFactory && tlsLayerFactory->canCreate();
}
void BasicSessionStream::addTLSEncryption() {
@@ -88,7 +88,7 @@ void BasicSessionStream::setWhitespacePingEnabled(bool enabled) {
}
void BasicSessionStream::resetXMPPParser() {
- xmppLayer->resetParser();
+ xmppLayer->resetParser();
}
void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header) {
diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h
index 0cb50eb..07bae81 100644
--- a/Swiften/Session/BasicSessionStream.h
+++ b/Swiften/Session/BasicSessionStream.h
@@ -7,26 +7,26 @@
#include "Swiften/Session/SessionStream.h"
namespace Swift {
- class TLSLayerFactory;
- class TLSLayer;
- class WhitespacePingLayer;
+ class TLSLayerFactory;
+ class TLSLayer;
+ class WhitespacePingLayer;
class PayloadParserFactoryCollection;
class PayloadSerializerCollection;
- class StreamStack;
+ class StreamStack;
class XMPPLayer;
- class ConnectionLayer;
+ class ConnectionLayer;
- class BasicSessionStream :
- public SessionStream,
- public boost::enable_shared_from_this<BasicSessionStream> {
- public:
- BasicSessionStream(
- boost::shared_ptr<Connection> connection,
- PayloadParserFactoryCollection* payloadParserFactories,
- PayloadSerializerCollection* payloadSerializers,
- TLSLayerFactory* tlsLayerFactory
- );
- ~BasicSessionStream();
+ class BasicSessionStream :
+ public SessionStream,
+ public boost::enable_shared_from_this<BasicSessionStream> {
+ public:
+ BasicSessionStream(
+ boost::shared_ptr<Connection> connection,
+ PayloadParserFactoryCollection* payloadParserFactories,
+ PayloadSerializerCollection* payloadSerializers,
+ TLSLayerFactory* tlsLayerFactory
+ );
+ ~BasicSessionStream();
void initialize();
@@ -43,17 +43,17 @@ namespace Swift {
virtual void resetXMPPParser();
- private:
+ private:
void handleConnectionError(const boost::optional<Connection::Error>& error);
- void handleXMPPError();
+ void handleXMPPError();
void handleTLSConnected();
- void handleTLSError();
+ void handleTLSError();
void handleStreamStartReceived(const ProtocolHeader&);
void handleElementReceived(boost::shared_ptr<Element>);
- void handleDataRead(const ByteArray& data);
- void handleDataWritten(const ByteArray& data);
+ void handleDataRead(const ByteArray& data);
+ void handleDataWritten(const ByteArray& data);
- private:
+ private:
bool available;
boost::shared_ptr<Connection> connection;
PayloadParserFactoryCollection* payloadParserFactories;
@@ -62,7 +62,7 @@ namespace Swift {
boost::shared_ptr<XMPPLayer> xmppLayer;
boost::shared_ptr<ConnectionLayer> connectionLayer;
StreamStack* streamStack;
- boost::shared_ptr<TLSLayer> tlsLayer;
- boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer;
- };
+ boost::shared_ptr<TLSLayer> tlsLayer;
+ boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer;
+ };
}