From 5608da36a3a319070494d5a70ff984e7c172186e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be> Date: Wed, 2 Dec 2009 21:42:30 +0100 Subject: DNS querying is now asynchronous. This means we can now move them to a separate thread. diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 19f7ee5..e9de19a 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -9,7 +9,6 @@ #include "Swiften/Network/Connector.h" #include "Swiften/Network/BoostConnectionFactory.h" #include "Swiften/Network/BoostTimerFactory.h" -#include "Swiften/Network/DomainNameResolveException.h" #include "Swiften/TLS/PKCS12Certificate.h" #include "Swiften/Session/BasicSessionStream.h" diff --git a/Swiften/Network/BoostConnection.cpp b/Swiften/Network/BoostConnection.cpp index 9f2a7da..0d62300 100644 --- a/Swiften/Network/BoostConnection.cpp +++ b/Swiften/Network/BoostConnection.cpp @@ -7,8 +7,7 @@ #include "Swiften/EventLoop/MainEventLoop.h" #include "Swiften/Base/String.h" #include "Swiften/Base/ByteArray.h" -#include "Swiften/Network/DomainNameResolver.h" -#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Network/HostAddressPort.h" namespace Swift { diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index 5b4fe22..c9087d7 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -1,43 +1,86 @@ #include "Swiften/Network/Connector.h" #include <boost/bind.hpp> +#include <iostream> #include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Network/DomainNameResolver.h" -#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Network/DomainNameAddressQuery.h" namespace Swift { -Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory) { +Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), queriedAllHosts(true) { } void Connector::start() { assert(!currentConnection); - try { - std::vector<HostAddressPort> resolveResult = resolver->resolve(hostname.getUTF8String()); - resolvedHosts = std::deque<HostAddressPort>(resolveResult.begin(), resolveResult.end()); - tryNextHostname(); - } - catch (const DomainNameResolveException&) { + assert(!serviceQuery); + queriedAllHosts = false; + serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname); + serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1)); + serviceQuery->run(); +} + +void Connector::queryAddress(const String& hostname) { + assert(!addressQuery); + addressQuery = resolver->createAddressQuery(hostname); + addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, this, _1, _2)); + addressQuery->run(); +} + +void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { + serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); + serviceQuery.reset(); + tryNextHostname(); +} + +void Connector::tryNextHostname() { + if (queriedAllHosts) { onConnectFinished(boost::shared_ptr<Connection>()); } + else if (serviceQueryResults.empty()) { + // Fall back on simple address resolving + queriedAllHosts = true; + queryAddress(hostname); + } + else { + queryAddress(serviceQueryResults.front().hostname); + } } -void Connector::tryNextHostname() { - if (resolvedHosts.empty()) { +void Connector::handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error) { + addressQuery.reset(); + if (!serviceQueryResults.empty()) { + DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front(); + serviceQueryResults.pop_front(); + if (error) { + tryNextHostname(); + } + else { + tryConnect(HostAddressPort(address, serviceQueryResult.port)); + } + } + else if (error) { + // The fallback address query failed + assert(queriedAllHosts); onConnectFinished(boost::shared_ptr<Connection>()); } else { - HostAddressPort remote = resolvedHosts.front(); - resolvedHosts.pop_front(); - currentConnection = connectionFactory->createConnection(); - currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1)); - currentConnection->connect(remote); + // The fallback query succeeded + tryConnect(HostAddressPort(address, 5222)); } } +void Connector::tryConnect(const HostAddressPort& target) { + assert(!currentConnection); + currentConnection = connectionFactory->createConnection(); + currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1)); + currentConnection->connect(target); +} + void Connector::handleConnectionConnectFinished(bool error) { if (error) { + currentConnection.reset(); tryNextHostname(); } else { diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h index 44b4584..cb885ab 100644 --- a/Swiften/Network/Connector.h +++ b/Swiften/Network/Connector.h @@ -4,11 +4,14 @@ #include <boost/signal.hpp> #include <boost/shared_ptr.hpp> +#include "Swiften/Network/DomainNameServiceQuery.h" #include "Swiften/Network/Connection.h" #include "Swiften/Network/HostAddressPort.h" #include "Swiften/Base/String.h" +#include "Swiften/Network/DomainNameResolveError.h" namespace Swift { + class DomainNameAddressQuery; class DomainNameResolver; class ConnectionFactory; @@ -21,14 +24,23 @@ namespace Swift { boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished; private: + void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result); + void handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error); + void queryAddress(const String& hostname); + void tryNextHostname(); + void tryConnect(const HostAddressPort& target); + void handleConnectionConnectFinished(bool error); private: String hostname; DomainNameResolver* resolver; ConnectionFactory* connectionFactory; - std::deque<HostAddressPort> resolvedHosts; + boost::shared_ptr<DomainNameServiceQuery> serviceQuery; + std::deque<DomainNameServiceQuery::Result> serviceQueryResults; + boost::shared_ptr<DomainNameAddressQuery> addressQuery; + bool queriedAllHosts; boost::shared_ptr<Connection> currentConnection; }; diff --git a/Swiften/Network/DomainNameAddressQuery.cpp b/Swiften/Network/DomainNameAddressQuery.cpp new file mode 100644 index 0000000..5e77cd7 --- /dev/null +++ b/Swiften/Network/DomainNameAddressQuery.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/DomainNameAddressQuery.h" + +namespace Swift { + +DomainNameAddressQuery::~DomainNameAddressQuery() { +} + +} diff --git a/Swiften/Network/DomainNameAddressQuery.h b/Swiften/Network/DomainNameAddressQuery.h new file mode 100644 index 0000000..66a79db --- /dev/null +++ b/Swiften/Network/DomainNameAddressQuery.h @@ -0,0 +1,19 @@ +#pragma once + +#include <boost/signals.hpp> +#include <boost/optional.hpp> +#include <boost/shared_ptr.hpp> + +#include "Swiften/Network/DomainNameResolveError.h" +#include "Swiften/Network/HostAddress.h" + +namespace Swift { + class DomainNameAddressQuery { + public: + virtual ~DomainNameAddressQuery(); + + virtual void run() = 0; + + boost::signal<void (const HostAddress&, boost::optional<DomainNameResolveError>)> onResult; + }; +} diff --git a/Swiften/Network/DomainNameResolveError.h b/Swiften/Network/DomainNameResolveError.h new file mode 100644 index 0000000..860ea23 --- /dev/null +++ b/Swiften/Network/DomainNameResolveError.h @@ -0,0 +1,10 @@ +#pragma once + +#include "Swiften/Base/Error.h" + +namespace Swift { + class DomainNameResolveError : public Error { + public: + DomainNameResolveError() {} + }; +} diff --git a/Swiften/Network/DomainNameResolveException.h b/Swiften/Network/DomainNameResolveException.h deleted file mode 100644 index a6cfbc6..0000000 --- a/Swiften/Network/DomainNameResolveException.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef SWIFTEN_DOMAINNAMELOOKUPEXCEPTION_H -#define SWIFTEN_DOMAINNAMELOOKUPEXCEPTION_H - -namespace Swift { - class DomainNameResolveException { - public: - DomainNameResolveException() {} - }; -} - -#endif diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h index 5c83622..b99ace3 100644 --- a/Swiften/Network/DomainNameResolver.h +++ b/Swiften/Network/DomainNameResolver.h @@ -1,16 +1,17 @@ #pragma once -#include <vector> - -#include "Swiften/Network/HostAddressPort.h" +#include <boost/shared_ptr.hpp> namespace Swift { + class DomainNameServiceQuery; + class DomainNameAddressQuery; class String; class DomainNameResolver { public: virtual ~DomainNameResolver(); - virtual std::vector<HostAddressPort> resolve(const String& domain) = 0; + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name) = 0; + virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name) = 0; }; } diff --git a/Swiften/Network/DomainNameServiceQuery.cpp b/Swiften/Network/DomainNameServiceQuery.cpp new file mode 100644 index 0000000..7dfd353 --- /dev/null +++ b/Swiften/Network/DomainNameServiceQuery.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Network/DomainNameServiceQuery.h" + +namespace Swift { + +DomainNameServiceQuery::~DomainNameServiceQuery() { +} + +} diff --git a/Swiften/Network/DomainNameServiceQuery.h b/Swiften/Network/DomainNameServiceQuery.h new file mode 100644 index 0000000..3c08749 --- /dev/null +++ b/Swiften/Network/DomainNameServiceQuery.h @@ -0,0 +1,27 @@ +#pragma once + +#include <boost/signals.hpp> +#include <boost/optional.hpp> +#include <vector> + +#include "Swiften/Base/String.h" +#include "Swiften/Network/DomainNameResolveError.h" + +namespace Swift { + class DomainNameServiceQuery { + public: + struct Result { + Result(const String& hostname = "", int port = -1, int priority = -1, int weight = -1) : hostname(hostname), port(port), priority(priority), weight(weight) {} + String hostname; + int port; + int priority; + int weight; + }; + + virtual ~DomainNameServiceQuery(); + + virtual void run() = 0; + + boost::signal<void (const std::vector<Result>&)> onResult; + }; +} diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp index 3803e7c..e30615b 100644 --- a/Swiften/Network/PlatformDomainNameResolver.cpp +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -1,200 +1,81 @@ #include "Swiften/Network/PlatformDomainNameResolver.h" -#include "Swiften/Base/Platform.h" -#include "Swiften/Base/foreach.h" +// Putting this early on, because some system types conflict with thread +#include "Swiften/Network/PlatformDomainNameServiceQuery.h" -#include <stdlib.h> +#include <string> +#include <vector> #include <boost/asio.hpp> +#include <boost/bind.hpp> +#include <boost/enable_shared_from_this.hpp> #include <idna.h> -#ifdef SWIFTEN_PLATFORM_WINDOWS -#undef UNICODE -#include <windows.h> -#include <windns.h> -#ifndef DNS_TYPE_SRV -#define DNS_TYPE_SRV 33 -#endif -#else -#include <arpa/nameser.h> -#include <arpa/nameser_compat.h> -#include <resolv.h> -#endif #include <algorithm> -#include "Swiften/Network/DomainNameResolveException.h" #include "Swiften/Base/String.h" -#include "Swiften/Base/ByteArray.h" -#include "Swiften/Network/SRVRecord.h" -#include "Swiften/Network/SRVRecordPriorityComparator.h" - -namespace Swift { - -PlatformDomainNameResolver::PlatformDomainNameResolver() { -} - -std::vector<HostAddressPort> PlatformDomainNameResolver::resolve(const String& domain) { - char* output; - if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { - std::string outputString(output); - free(output); - return resolveDomain(outputString); - } - else { - return resolveDomain(domain.getUTF8String()); - } -} - -std::vector<HostAddressPort> PlatformDomainNameResolver::resolveDomain(const std::string& domain) { - try { - return resolveXMPPService(domain); - } - catch (const DomainNameResolveException&) { - } - std::vector<HostAddressPort> result; - result.push_back(HostAddressPort(resolveHostName(domain), 5222)); - return result; -} - -std::vector<HostAddressPort> PlatformDomainNameResolver::resolveXMPPService(const std::string& domain) { - std::vector<SRVRecord> records; - - std::string srvQuery = "_xmpp-client._tcp." + domain; - -#if defined(SWIFTEN_PLATFORM_WINDOWS) - DNS_RECORD* responses; - // FIXME: This conversion doesn't work if unicode is deffed above - if (DnsQuery(srvQuery.c_str(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { - throw DomainNameResolveException(); - } - - DNS_RECORD* currentEntry = responses; - while (currentEntry) { - if (currentEntry->wType == DNS_TYPE_SRV) { - SRVRecord record; - record.priority = currentEntry->Data.SRV.wPriority; - record.weight = currentEntry->Data.SRV.wWeight; - record.port = currentEntry->Data.SRV.wPort; - - // The pNameTarget is actually a PCWSTR, so I would have expected this - // conversion to not work at all, but it does. - // Actually, it doesn't. Fix this and remove explicit cast - // Remove unicode undef above as well - record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget); - records.push_back(record); +#include "Swiften/Network/HostAddress.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/Network/HostAddressPort.h" +#include "Swiften/Network/DomainNameAddressQuery.h" + +using namespace Swift; + +namespace { + struct AddressQuery : public DomainNameAddressQuery, public boost::enable_shared_from_this<AddressQuery>, public EventOwner { + AddressQuery(const String& host) : hostname(host) {} + + virtual void run() { + boost::asio::ip::tcp::resolver resolver(ioService); + boost::asio::ip::tcp::resolver::query query(hostname.getUTF8String(), "5222"); + try { + boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); + if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { + emitError(); + } + else { + boost::asio::ip::address address = (*endpointIterator).endpoint().address(); + HostAddress result = (address.is_v4() ? HostAddress(&address.to_v4().to_bytes()[0], 4) : HostAddress(&address.to_v6().to_bytes()[0], 16)); + MainEventLoop::postEvent( + boost::bind(boost::ref(onResult), result, boost::optional<DomainNameResolveError>()), + shared_from_this()); + } + } + catch (...) { + emitError(); + } } - currentEntry = currentEntry->pNext; - } - DnsRecordListFree(responses, DnsFreeRecordList); - -#else - - ByteArray response; - response.resize(NS_PACKETSZ); - int responseLength = res_query(const_cast<char*>(srvQuery.c_str()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize()); - if (responseLength == -1) { - throw DomainNameResolveException(); - } - - // Parse header - HEADER* header = reinterpret_cast<HEADER*>(response.getData()); - unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData()); - unsigned char* messageEnd = messageStart + responseLength; - unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; - // Skip over the queries - int queriesCount = ntohs(header->qdcount); - while (queriesCount > 0) { - int entryLength = dn_skipname(currentEntry, messageEnd); - if (entryLength < 0) { - throw DomainNameResolveException(); + void emitError() { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional<DomainNameResolveError>(DomainNameResolveError())), shared_from_this()); } - currentEntry += entryLength + NS_QFIXEDSZ; - queriesCount--; - } - - // Process the SRV answers - int answersCount = ntohs(header->ancount); - while (answersCount > 0) { - SRVRecord record; - int entryLength = dn_skipname(currentEntry, messageEnd); - currentEntry += entryLength; - currentEntry += NS_RRFIXEDSZ; + boost::asio::io_service ioService; + String hostname; + }; - // Priority - if (currentEntry + 2 >= messageEnd) { - throw DomainNameResolveException(); + String getNormalized(const String& domain) { + char* output; + if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { + String result(output); + free(output); + return result; } - record.priority = ns_get16(currentEntry); - currentEntry += 2; - - // Weight - if (currentEntry + 2 >= messageEnd) { - throw DomainNameResolveException(); + else { + return domain; } - record.weight = ns_get16(currentEntry); - currentEntry += 2; + } +} - // Port - if (currentEntry + 2 >= messageEnd) { - throw DomainNameResolveException(); - } - record.port = ns_get16(currentEntry); - currentEntry += 2; +namespace Swift { - // Hostname - if (currentEntry >= messageEnd) { - throw DomainNameResolveException(); - } - ByteArray entry; - entry.resize(NS_MAXDNAME); - entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize()); - if (entryLength < 0) { - throw DomainNameResolveException(); - } - record.hostname = std::string(entry.getData(), entryLength); - records.push_back(record); - currentEntry += entryLength; - answersCount--; - } -#endif +PlatformDomainNameResolver::PlatformDomainNameResolver() { +} - // Resolve the hostnames in the records, and build the result list - std::sort(records.begin(), records.end(), SRVRecordPriorityComparator()); - std::vector<HostAddressPort> result; - foreach(const SRVRecord& record, records) { - try { - HostAddress address = resolveHostName(record.hostname); - result.push_back(HostAddressPort(address, record.port)); - } - catch (const DomainNameResolveException&) { - } - } - if (result.empty()) { - throw DomainNameResolveException(); - } - return result; +boost::shared_ptr<DomainNameServiceQuery> PlatformDomainNameResolver::createServiceQuery(const String& name) { + return boost::shared_ptr<DomainNameServiceQuery>(new PlatformDomainNameServiceQuery(getNormalized(name))); } -HostAddress PlatformDomainNameResolver::resolveHostName(const std::string& hostname) { - boost::asio::io_service ioService; - boost::asio::ip::tcp::resolver resolver(ioService); - boost::asio::ip::tcp::resolver::query query(hostname, "5222"); - try { - boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); - if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { - throw DomainNameResolveException(); - } - boost::asio::ip::address address = (*endpointIterator).endpoint().address(); - if (address.is_v4()) { - return HostAddress(&address.to_v4().to_bytes()[0], 4); - } - else { - return HostAddress(&address.to_v6().to_bytes()[0], 16); - } - } - catch (...) { - throw DomainNameResolveException(); - } +boost::shared_ptr<DomainNameAddressQuery> PlatformDomainNameResolver::createAddressQuery(const String& name) { + return boost::shared_ptr<DomainNameAddressQuery>(new AddressQuery(getNormalized(name))); } } diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h index 651bc97..4617b15 100644 --- a/Swiften/Network/PlatformDomainNameResolver.h +++ b/Swiften/Network/PlatformDomainNameResolver.h @@ -1,12 +1,6 @@ #pragma once -#include <string> -#include <vector> - -#include "Swiften/Base/String.h" #include "Swiften/Network/DomainNameResolver.h" -#include "Swiften/Network/HostAddress.h" -#include "Swiften/Network/HostAddressPort.h" namespace Swift { class String; @@ -15,11 +9,7 @@ namespace Swift { public: PlatformDomainNameResolver(); - std::vector<HostAddressPort> resolve(const String& domain); - - private: - std::vector<HostAddressPort> resolveDomain(const std::string& domain); - std::vector<HostAddressPort> resolveXMPPService(const std::string& domain); - HostAddress resolveHostName(const std::string& hostName); + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name); + virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name); }; } diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp new file mode 100644 index 0000000..eeb9fd6 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp @@ -0,0 +1,156 @@ +#include "Swiften/Network/PlatformDomainNameServiceQuery.h" + +#include "Swiften/Base/Platform.h" +#include <stdlib.h> +#ifdef SWIFTEN_PLATFORM_WINDOWS +#undef UNICODE +#include <windows.h> +#include <windns.h> +#ifndef DNS_TYPE_SRV +#define DNS_TYPE_SRV 33 +#endif +#else +#include <arpa/nameser.h> +#include <arpa/nameser_compat.h> +#include <resolv.h> +#endif +#include <boost/bind.hpp> + +#include "Swiften/Base/ByteArray.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/Base/foreach.h" + +using namespace Swift; + +namespace { + struct SRVRecordPriorityComparator { + bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const { + return a.priority < b.priority; + } + }; +} + +namespace Swift { + +PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const String& service) : service(service) { +} + +void PlatformDomainNameServiceQuery::run() { + std::vector<DomainNameServiceQuery::Result> records; + +#if defined(SWIFTEN_PLATFORM_WINDOWS) + DNS_RECORD* responses; + // FIXME: This conversion doesn't work if unicode is deffed above + if (DnsQuery(service.getUTF8Data(), DNS_TYPE_SRV, DNS_QUERY_STANDARD, NULL, &responses, NULL) != ERROR_SUCCESS) { + emitError(); + return; + } + + DNS_RECORD* currentEntry = responses; + while (currentEntry) { + if (currentEntry->wType == DNS_TYPE_SRV) { + SRVRecord record; + record.priority = currentEntry->Data.SRV.wPriority; + record.weight = currentEntry->Data.SRV.wWeight; + record.port = currentEntry->Data.SRV.wPort; + + // The pNameTarget is actually a PCWSTR, so I would have expected this + // conversion to not work at all, but it does. + // Actually, it doesn't. Fix this and remove explicit cast + // Remove unicode undef above as well + record.hostname = std::string((const char*) currentEntry->Data.SRV.pNameTarget); + records.push_back(record); + } + currentEntry = currentEntry->pNext; + } + DnsRecordListFree(responses, DnsFreeRecordList); + +#else + + ByteArray response; + response.resize(NS_PACKETSZ); + int responseLength = res_query(const_cast<char*>(service.getUTF8Data()), ns_c_in, ns_t_srv, reinterpret_cast<u_char*>(response.getData()), response.getSize()); + if (responseLength == -1) { + emitError(); + return; + } + + // Parse header + HEADER* header = reinterpret_cast<HEADER*>(response.getData()); + unsigned char* messageStart = reinterpret_cast<unsigned char*>(response.getData()); + unsigned char* messageEnd = messageStart + responseLength; + unsigned char* currentEntry = messageStart + NS_HFIXEDSZ; + + // Skip over the queries + int queriesCount = ntohs(header->qdcount); + while (queriesCount > 0) { + int entryLength = dn_skipname(currentEntry, messageEnd); + if (entryLength < 0) { + emitError(); + return; + } + currentEntry += entryLength + NS_QFIXEDSZ; + queriesCount--; + } + + // Process the SRV answers + int answersCount = ntohs(header->ancount); + while (answersCount > 0) { + DomainNameServiceQuery::Result record; + + int entryLength = dn_skipname(currentEntry, messageEnd); + currentEntry += entryLength; + currentEntry += NS_RRFIXEDSZ; + + // Priority + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.priority = ns_get16(currentEntry); + currentEntry += 2; + + // Weight + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.weight = ns_get16(currentEntry); + currentEntry += 2; + + // Port + if (currentEntry + 2 >= messageEnd) { + emitError(); + return; + } + record.port = ns_get16(currentEntry); + currentEntry += 2; + + // Hostname + if (currentEntry >= messageEnd) { + emitError(); + return; + } + ByteArray entry; + entry.resize(NS_MAXDNAME); + entryLength = dn_expand(messageStart, messageEnd, currentEntry, entry.getData(), entry.getSize()); + if (entryLength < 0) { + emitError(); + return; + } + record.hostname = String(entry.getData()); + records.push_back(record); + currentEntry += entryLength; + answersCount--; + } +#endif + + std::sort(records.begin(), records.end(), SRVRecordPriorityComparator()); + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), records)); +} + +void PlatformDomainNameServiceQuery::emitError() { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), std::vector<DomainNameServiceQuery::Result>()), shared_from_this()); +} + +} diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.h b/Swiften/Network/PlatformDomainNameServiceQuery.h new file mode 100644 index 0000000..58257af --- /dev/null +++ b/Swiften/Network/PlatformDomainNameServiceQuery.h @@ -0,0 +1,22 @@ +#pragma once + +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/EventLoop/EventOwner.h" +#include "Swiften/Base/String.h" + +namespace Swift { + class PlatformDomainNameServiceQuery : public DomainNameServiceQuery, public boost::enable_shared_from_this<PlatformDomainNameServiceQuery>, public EventOwner { + public: + PlatformDomainNameServiceQuery(const String& service); + + virtual void run(); + + private: + void emitError(); + + private: + String service; + }; +} diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index 475d6e4..9aa8139 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -15,7 +15,10 @@ objects = myenv.StaticObject([ "TimerFactory.cpp", "BoostTimerFactory.cpp", "DomainNameResolver.cpp", + "DomainNameAddressQuery.cpp", + "DomainNameServiceQuery.cpp", "PlatformDomainNameResolver.cpp", + "PlatformDomainNameServiceQuery.cpp", "StaticDomainNameResolver.cpp", "HostAddress.cpp", "Timer.cpp", diff --git a/Swiften/Network/SRVRecord.h b/Swiften/Network/SRVRecord.h deleted file mode 100644 index 5a11635..0000000 --- a/Swiften/Network/SRVRecord.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -namespace Swift { - struct SRVRecord { - std::string hostname; - int port; - int priority; - int weight; - }; -} diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp index 8ca4062..275ec78 100644 --- a/Swiften/Network/StaticDomainNameResolver.cpp +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -1,28 +1,52 @@ #include "Swiften/Network/StaticDomainNameResolver.h" -#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Network/DomainNameResolveError.h" #include "Swiften/Base/String.h" +using namespace Swift; + +namespace { + struct ServiceQuery : public DomainNameServiceQuery { + ServiceQuery(const String& service, Swift::StaticDomainNameResolver* resolver) : service(service), resolver(resolver) {} + + virtual void run() { + } + + String service; + StaticDomainNameResolver* resolver; + }; + + struct AddressQuery : public DomainNameAddressQuery { + AddressQuery(const String& host, StaticDomainNameResolver* resolver) : host(host), resolver(resolver) {} + + virtual void run() { + } + + String host; + StaticDomainNameResolver* resolver; + }; +} + namespace Swift { -StaticDomainNameResolver::StaticDomainNameResolver() { +void StaticDomainNameResolver::addAddress(const String& domain, const HostAddress& address) { + addresses[domain] = address; } -std::vector<HostAddressPort> StaticDomainNameResolver::resolve(const String& queriedDomain) { - std::vector<HostAddressPort> result; +void StaticDomainNameResolver::addService(const String& service, const DomainNameServiceQuery::Result& result) { + services.push_back(std::make_pair(service, result)); +} - for(DomainCollection::const_iterator i = domains.begin(); i != domains.end(); ++i) { - if (i->first == queriedDomain) { - result.push_back(i->second); - } - } - if (result.empty()) { - throw DomainNameResolveException(); - } - return result; +void StaticDomainNameResolver::addXMPPClientService(const String& domain, const HostAddressPort& address) { + addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(domain, 5222, 0, 0)); + addAddress(domain, address.getAddress()); +} + +boost::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const String& name) { + return boost::shared_ptr<DomainNameServiceQuery>(new ServiceQuery(name, this)); } -void StaticDomainNameResolver::addDomain(const String& domain, const HostAddressPort& addressPort) { - domains.push_back(std::make_pair(domain, addressPort)); +boost::shared_ptr<DomainNameAddressQuery> StaticDomainNameResolver::createAddressQuery(const String& name) { + return boost::shared_ptr<DomainNameAddressQuery>(new AddressQuery(name, this)); } } diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h index 8688429..ed8c613 100644 --- a/Swiften/Network/StaticDomainNameResolver.h +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -3,21 +3,38 @@ #include <vector> #include <map> +#include "Swiften/Network/HostAddress.h" +#include "Swiften/Network/HostAddressPort.h" #include "Swiften/Network/DomainNameResolver.h" +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/Network/DomainNameAddressQuery.h" namespace Swift { class String; class StaticDomainNameResolver : public DomainNameResolver { public: - StaticDomainNameResolver(); + typedef std::map<String, HostAddress> AddressesMap; + typedef std::vector< std::pair<String, DomainNameServiceQuery::Result> > ServicesCollection; - virtual std::vector<HostAddressPort> resolve(const String& domain); + public: + void addAddress(const String& domain, const HostAddress& address); + void addService(const String& service, const DomainNameServiceQuery::Result& result); + void addXMPPClientService(const String& domain, const HostAddressPort&); + + const AddressesMap& getAddresses() const { + return addresses; + } + + const ServicesCollection& getServices() const { + return services; + } - void addDomain(const String& domain, const HostAddressPort& addressPort); + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name); + virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name); private: - typedef std::vector< std::pair<String, HostAddressPort> > DomainCollection; - DomainCollection domains; + AddressesMap addresses; + ServicesCollection services; }; } diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp index 32893d8..05c6e28 100644 --- a/Swiften/Network/UnitTest/ConnectorTest.cpp +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -40,8 +40,8 @@ class ConnectorTest : public CppUnit::TestFixture { void testConnect() { std::auto_ptr<Connector> testling(createConnector()); - resolver->addDomain("foo.com", host1); - resolver->addDomain("foo.com", host2); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); testling->start(); eventLoop->processEvents(); @@ -63,8 +63,8 @@ class ConnectorTest : public CppUnit::TestFixture { void testConnect_FirstHostFails() { std::auto_ptr<Connector> testling(createConnector()); - resolver->addDomain("foo.com", host1); - resolver->addDomain("foo.com", host2); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); connectionFactory->failingPorts.push_back(host1); testling->start(); @@ -76,8 +76,8 @@ class ConnectorTest : public CppUnit::TestFixture { void testConnect_AllHostsFail() { std::auto_ptr<Connector> testling(createConnector()); - resolver->addDomain("foo.com", host1); - resolver->addDomain("foo.com", host2); + resolver->addXMPPClientService("foo.com", host1); + resolver->addXMPPClientService("foo.com", host2); connectionFactory->failingPorts.push_back(host1); connectionFactory->failingPorts.push_back(host2); diff --git a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp index cb812a1..09837d6 100644 --- a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp +++ b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp @@ -1,76 +1,168 @@ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> +#include <boost/bind.hpp> +#include "Swiften/Base/sleep.h" #include "Swiften/Base/String.h" +#include "Swiften/Base/ByteArray.h" #include "Swiften/Network/PlatformDomainNameResolver.h" -#include "Swiften/Network/DomainNameResolveException.h" +#include "Swiften/Network/DomainNameAddressQuery.h" +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/EventLoop/DummyEventLoop.h" using namespace Swift; class DomainNameResolverTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(DomainNameResolverTest); - CPPUNIT_TEST(testResolve_NoSRV); - CPPUNIT_TEST(testResolve_SRV); - CPPUNIT_TEST(testResolve_Invalid); - //CPPUNIT_TEST(testResolve_IPv6); - CPPUNIT_TEST(testResolve_International); - CPPUNIT_TEST(testResolve_Localhost); + CPPUNIT_TEST(testResolveAddress); + CPPUNIT_TEST(testResolveAddress_Error); + //CPPUNIT_TEST(testResolveAddress_IPv6); + CPPUNIT_TEST(testResolveAddress_International); + CPPUNIT_TEST(testResolveAddress_Localhost); + CPPUNIT_TEST(testResolveService); + CPPUNIT_TEST(testResolveService_Error); CPPUNIT_TEST_SUITE_END(); public: DomainNameResolverTest() {} void setUp() { - resolver_ = new PlatformDomainNameResolver(); + eventLoop = new DummyEventLoop(); + resolver = new PlatformDomainNameResolver(); + resultsAvailable = false; } void tearDown() { - delete resolver_; + delete resolver; + delete eventLoop; } - void testResolve_NoSRV() { - HostAddressPort result = resolver_->resolve("xmpp.test.swift.im")[0]; + void testResolveAddress() { + boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("xmpp.test.swift.im")); - CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.0"), result.getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); + query->run(); + waitForResults(); + + CPPUNIT_ASSERT(!addressQueryError); + CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.0"), addressQueryResult.toString()); } - void testResolve_SRV() { - std::vector<HostAddressPort> result = resolver_->resolve("xmpp-srv.test.swift.im"); + void testResolveAddress_Error() { + boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("invalid.test.swift.im")); + + query->run(); + waitForResults(); - CPPUNIT_ASSERT_EQUAL(3, static_cast<int>(result.size())); - CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.1"), result[0].getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5000, result[0].getPort()); - CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), result[1].getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5000, result[1].getPort()); - CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.2"), result[2].getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5000, result[2].getPort()); + CPPUNIT_ASSERT(addressQueryError); } - void testResolve_Invalid() { - CPPUNIT_ASSERT_THROW(resolver_->resolve("invalid.test.swift.im"), DomainNameResolveException); + void testResolveAddress_IPv6() { + boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("xmpp-ipv6.test.swift.im")); + + query->run(); + waitForResults(); + + CPPUNIT_ASSERT(!addressQueryError); + CPPUNIT_ASSERT_EQUAL(std::string("0000:0000:0000:0000:0000:ffff:0a00:0104"), addressQueryResult.toString()); } - void testResolve_IPv6() { - HostAddressPort result = resolver_->resolve("xmpp-ipv6.test.swift.im")[0]; - CPPUNIT_ASSERT_EQUAL(std::string("0000:0000:0000:0000:0000:ffff:0a00:0104"), result.getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); + void testResolveAddress_International() { + boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("tron\xc3\xa7on.test.swift.im")); + + query->run(); + waitForResults(); + + CPPUNIT_ASSERT(!addressQueryError); + CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), addressQueryResult.toString()); } - void testResolve_International() { - HostAddressPort result = resolver_->resolve("tron\xc3\xa7on.test.swift.im")[0]; - CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.3"), result.getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); + void testResolveAddress_Localhost() { + boost::shared_ptr<DomainNameAddressQuery> query(createAddressQuery("localhost")); + + query->run(); + waitForResults(); + + CPPUNIT_ASSERT(!addressQueryError); + CPPUNIT_ASSERT_EQUAL(std::string("127.0.0.1"), addressQueryResult.toString()); } - void testResolve_Localhost() { - HostAddressPort result = resolver_->resolve("localhost")[0]; - CPPUNIT_ASSERT_EQUAL(std::string("127.0.0.1"), result.getAddress().toString()); - CPPUNIT_ASSERT_EQUAL(5222, result.getPort()); + + void testResolveService() { + boost::shared_ptr<DomainNameServiceQuery> query(createServiceQuery("_xmpp-client._tcp.xmpp-srv.test.swift.im")); + + query->run(); + waitForResults(); + + CPPUNIT_ASSERT_EQUAL(4, static_cast<int>(serviceQueryResult.size())); + CPPUNIT_ASSERT_EQUAL(String("xmpp1.test.swift.im"), serviceQueryResult[0].hostname); + CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[0].port); + CPPUNIT_ASSERT_EQUAL(0, serviceQueryResult[0].priority); + CPPUNIT_ASSERT_EQUAL(1, serviceQueryResult[0].weight); + CPPUNIT_ASSERT_EQUAL(String("xmpp-invalid.test.swift.im"), serviceQueryResult[1].hostname); + CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[1].port); + CPPUNIT_ASSERT_EQUAL(1, serviceQueryResult[1].priority); + CPPUNIT_ASSERT_EQUAL(100, serviceQueryResult[1].weight); + CPPUNIT_ASSERT_EQUAL(String("xmpp3.test.swift.im"), serviceQueryResult[2].hostname); + CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[2].port); + CPPUNIT_ASSERT_EQUAL(3, serviceQueryResult[2].priority); + CPPUNIT_ASSERT_EQUAL(100, serviceQueryResult[2].weight); + CPPUNIT_ASSERT_EQUAL(String("xmpp2.test.swift.im"), serviceQueryResult[3].hostname); + CPPUNIT_ASSERT_EQUAL(5000, serviceQueryResult[3].port); + CPPUNIT_ASSERT_EQUAL(5, serviceQueryResult[3].priority); + CPPUNIT_ASSERT_EQUAL(100, serviceQueryResult[3].weight); + } + + void testResolveService_Error() { } +/* + } + */ + + private: + boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& domain) { + boost::shared_ptr<DomainNameAddressQuery> result = resolver->createAddressQuery(domain); + result->onResult.connect(boost::bind(&DomainNameResolverTest::handleAddressQueryResult, this, _1, _2)); + return result; + } + + void handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error) { + addressQueryResult = address; + addressQueryError = error; + resultsAvailable = true; + } + + boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& domain) { + boost::shared_ptr<DomainNameServiceQuery> result = resolver->createServiceQuery(domain); + result->onResult.connect(boost::bind(&DomainNameResolverTest::handleServiceQueryResult, this, _1)); + return result; + } + + void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { + serviceQueryResult = result; + resultsAvailable = true; + } + + void waitForResults() { + eventLoop->processEvents(); + int ticks = 0; + while (!resultsAvailable) { + ticks++; + if (ticks > 1000) { + CPPUNIT_ASSERT(false); + } + Swift::sleep(10); + eventLoop->processEvents(); + } + } + private: - PlatformDomainNameResolver* resolver_; + DummyEventLoop* eventLoop; + bool resultsAvailable; + HostAddress addressQueryResult; + boost::optional<DomainNameResolveError> addressQueryError; + std::vector<DomainNameServiceQuery::Result> serviceQueryResult; + PlatformDomainNameResolver* resolver; }; CPPUNIT_TEST_SUITE_REGISTRATION(DomainNameResolverTest); -- cgit v0.10.2-6-g49f6