From 897ad55ffee76c9e84ffb174d700f6182a3e7d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= Date: Fri, 4 Dec 2009 22:49:44 +0100 Subject: Implemented CAresDomainNameResolver. diff --git a/3rdParty/CAres/.gitignore b/3rdParty/CAres/.gitignore new file mode 100644 index 0000000..a1a869c --- /dev/null +++ b/3rdParty/CAres/.gitignore @@ -0,0 +1 @@ +include diff --git a/3rdParty/CAres/SConscript b/3rdParty/CAres/SConscript index 9a3c379..f1bb103 100644 --- a/3rdParty/CAres/SConscript +++ b/3rdParty/CAres/SConscript @@ -2,7 +2,7 @@ Import("env") env["CARES_FLAGS"] = { "CPPDEFINES": ["CARES_STATICLIB"], - "CPPPATH": [Dir("src"), Dir(".")], + "CPPPATH": [Dir("include")], "LIBPATH": [Dir(".")], "LIBS": ["CAres"], } @@ -14,6 +14,12 @@ myenv.Append(CPPPATH = ["src", "."]) if myenv["PLATFORM"] != "win32" : myenv.Append(CPPDEFINES = ["HAVE_CONFIG_H"]) +myenv.Install("include", [ + "src/ares.h", + "src/ares_version.h", + "src/ares_build.h", + "src/ares_rules.h" + ]) myenv.StaticLibrary("CAres", [ "src/ares__close_sockets.c", "src/ares__get_hostent.c", diff --git a/Swiften/Network/CAresDomainNameResolver.cpp b/Swiften/Network/CAresDomainNameResolver.cpp new file mode 100644 index 0000000..6daba3d --- /dev/null +++ b/Swiften/Network/CAresDomainNameResolver.cpp @@ -0,0 +1,159 @@ +// TODO: Check the second param of postEvent. We sometimes omit it. Same +// goes for the PlatformDomainNameResolver. + +#include "Swiften/Network/CAresDomainNameResolver.h" + +#include +#include +#include + +#include "Swiften/Network/DomainNameServiceQuery.h" +#include "Swiften/Network/DomainNameAddressQuery.h" +#include "Swiften/Base/ByteArray.h" +#include "Swiften/EventLoop/MainEventLoop.h" +#include "Swiften/Base/foreach.h" + +namespace Swift { + +class CAresQuery : public boost::enable_shared_from_this, public EventOwner { + public: + CAresQuery(const String& query, int dnsclass, int type, CAresDomainNameResolver* resolver) : query(query), dnsclass(dnsclass), type(type), resolver(resolver) { + } + + virtual ~CAresQuery() { + } + + void addToQueue() { + resolver->addToQueue(shared_from_this()); + } + + void doRun(ares_channel* channel) { + ares_query(*channel, query.getUTF8Data(), dnsclass, type, &CAresQuery::handleResult, this); + } + + static void handleResult(void* arg, int status, int timeouts, unsigned char* buffer, int len) { + reinterpret_cast(arg)->handleResult(status, timeouts, buffer, len); + } + + virtual void handleResult(int status, int, unsigned char* buffer, int len) = 0; + + private: + String query; + int dnsclass; + int type; + CAresDomainNameResolver* resolver; +}; + +class CAresDomainNameServiceQuery : public DomainNameServiceQuery, public CAresQuery { + public: + CAresDomainNameServiceQuery(const String& service, CAresDomainNameResolver* resolver) : CAresQuery(service, 1, 33, resolver) { + } + + virtual void run() { + addToQueue(); + } + + void handleResult(int status, int, unsigned char* buffer, int len) { + if (status == ARES_SUCCESS) { + std::vector records; + ares_srv_reply* rawRecords; + if (ares_parse_srv_reply(buffer, len, &rawRecords) == ARES_SUCCESS) { + for( ; rawRecords != NULL; rawRecords = rawRecords->next) { + DomainNameServiceQuery::Result record; + record.priority = rawRecords->priority; + record.weight = rawRecords->weight; + record.port = rawRecords->port; + record.hostname = String(rawRecords->host); + records.push_back(record); + } + } + std::sort(records.begin(), records.end(), ResultPriorityComparator()); + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), records)); + } + else if (status != ARES_EDESTRUCTION) { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), std::vector()), shared_from_this()); + } + } +}; + +class CAresDomainNameAddressQuery : public DomainNameAddressQuery, public CAresQuery { + public: + CAresDomainNameAddressQuery(const String& host, CAresDomainNameResolver* resolver) : CAresQuery(host, 1, 1, resolver) { + } + + virtual void run() { + addToQueue(); + } + + void handleResult(int status, int, unsigned char* buffer, int len) { + if (status == ARES_SUCCESS) { + struct hostent* hosts; + if (ares_parse_a_reply(buffer, len, &hosts, NULL, NULL) == ARES_SUCCESS) { + // Check whether the different fields are what we expect them to be + struct in_addr addr; + addr.s_addr = *(unsigned int*)hosts->h_addr_list[0]; + HostAddress result(inet_ntoa(addr)); + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), result, boost::optional()), boost::dynamic_pointer_cast(shared_from_this())); + ares_free_hostent(hosts); + } + else { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional(DomainNameResolveError())), shared_from_this()); + } + } + else if (status != ARES_EDESTRUCTION) { + MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional(DomainNameResolveError())), shared_from_this()); + } + } +}; + +CAresDomainNameResolver::CAresDomainNameResolver() : stopRequested(false) { + ares_init(&channel); + thread = new boost::thread(boost::bind(&CAresDomainNameResolver::run, this)); +} + +CAresDomainNameResolver::~CAresDomainNameResolver() { + stopRequested = true; + thread->join(); + ares_destroy(channel); +} + +boost::shared_ptr CAresDomainNameResolver::createServiceQuery(const String& name) { + return boost::shared_ptr(new CAresDomainNameServiceQuery(getNormalized(name), this)); +} + +boost::shared_ptr CAresDomainNameResolver::createAddressQuery(const String& name) { + return boost::shared_ptr(new CAresDomainNameAddressQuery(getNormalized(name), this)); +} + +void CAresDomainNameResolver::addToQueue(boost::shared_ptr query) { + boost::lock_guard lock(pendingQueriesMutex); + pendingQueries.push_back(query); +} + +void CAresDomainNameResolver::run() { + fd_set readers, writers; + struct timeval timeout; + timeout.tv_sec = 0; + timeout.tv_usec = 100000; + while(!stopRequested) { + { + boost::unique_lock lock(pendingQueriesMutex); + foreach(const boost::shared_ptr& query, pendingQueries) { + query->doRun(&channel); + } + pendingQueries.clear(); + } + FD_ZERO(&readers); + FD_ZERO(&writers); + int nfds = ares_fds(channel, &readers, &writers); + //if (nfds) { + // break; + //} + struct timeval tv; + struct timeval* tvp = ares_timeout(channel, &timeout, &tv); + select(nfds, &readers, &writers, NULL, tvp); + ares_process(channel, &readers, &writers); + } +} + +} diff --git a/Swiften/Network/CAresDomainNameResolver.h b/Swiften/Network/CAresDomainNameResolver.h new file mode 100644 index 0000000..0cdd163 --- /dev/null +++ b/Swiften/Network/CAresDomainNameResolver.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include +#include + +#include "Swiften/Network/DomainNameResolver.h" + +namespace Swift { + class CAresQuery; + + class CAresDomainNameResolver : public DomainNameResolver { + public: + CAresDomainNameResolver(); + ~CAresDomainNameResolver(); + + virtual boost::shared_ptr createServiceQuery(const String& name); + virtual boost::shared_ptr createAddressQuery(const String& name); + + private: + friend class CAresQuery; + + void run(); + void addToQueue(boost::shared_ptr); + + private: + bool stopRequested; + ares_channel channel; + boost::thread* thread; + boost::mutex pendingQueriesMutex; + std::list< boost::shared_ptr > pendingQueries; + }; +} diff --git a/Swiften/Network/DomainNameResolver.cpp b/Swiften/Network/DomainNameResolver.cpp index 907dfc9..63ed881 100644 --- a/Swiften/Network/DomainNameResolver.cpp +++ b/Swiften/Network/DomainNameResolver.cpp @@ -1,8 +1,22 @@ #include "Swiften/Network/DomainNameResolver.h" +#include + namespace Swift { DomainNameResolver::~DomainNameResolver() { } +String DomainNameResolver::getNormalized(const String& domain) { + char* output; + if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { + String result(output); + free(output); + return result; + } + else { + return domain; + } +} + } diff --git a/Swiften/Network/DomainNameResolver.h b/Swiften/Network/DomainNameResolver.h index b99ace3..d3dab26 100644 --- a/Swiften/Network/DomainNameResolver.h +++ b/Swiften/Network/DomainNameResolver.h @@ -2,6 +2,8 @@ #include +#include "Swiften/Base/String.h" + namespace Swift { class DomainNameServiceQuery; class DomainNameAddressQuery; @@ -13,5 +15,8 @@ namespace Swift { virtual boost::shared_ptr createServiceQuery(const String& name) = 0; virtual boost::shared_ptr createAddressQuery(const String& name) = 0; + + protected: + static String getNormalized(const String& domain); }; } diff --git a/Swiften/Network/DomainNameServiceQuery.h b/Swiften/Network/DomainNameServiceQuery.h index 3c08749..57e48d3 100644 --- a/Swiften/Network/DomainNameServiceQuery.h +++ b/Swiften/Network/DomainNameServiceQuery.h @@ -18,6 +18,12 @@ namespace Swift { int weight; }; + struct ResultPriorityComparator { + bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const { + return a.priority < b.priority; + } + }; + virtual ~DomainNameServiceQuery(); virtual void run() = 0; diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp index 755b177..5ffa2fb 100644 --- a/Swiften/Network/PlatformDomainNameResolver.cpp +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include "Swiften/Base/String.h" @@ -77,17 +76,6 @@ namespace { bool safeToJoin; }; - String getNormalized(const String& domain) { - char* output; - if (idna_to_ascii_8z(domain.getUTF8Data(), &output, 0) == IDNA_SUCCESS) { - String result(output); - free(output); - return result; - } - else { - return domain; - } - } } namespace Swift { diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp index d6c87dc..659f397 100644 --- a/Swiften/Network/PlatformDomainNameServiceQuery.cpp +++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp @@ -22,14 +22,6 @@ using namespace Swift; -namespace { - struct SRVRecordPriorityComparator { - bool operator()(const DomainNameServiceQuery::Result& a, const DomainNameServiceQuery::Result& b) const { - return a.priority < b.priority; - } - }; -} - namespace Swift { PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const String& service) : thread(NULL), service(service), safeToJoin(true) { @@ -165,7 +157,7 @@ void PlatformDomainNameServiceQuery::doRun() { #endif safeToJoin = true; - std::sort(records.begin(), records.end(), SRVRecordPriorityComparator()); + std::sort(records.begin(), records.end(), ResultPriorityComparator()); std::cout << "Sending out " << records.size() << " SRV results " << std::endl; MainEventLoop::postEvent(boost::bind(boost::ref(onResult), records)); } diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index 9aa8139..d63b673 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -1,7 +1,8 @@ Import("swiften_env") myenv = swiften_env.Clone() -myenv.MergeFlags(swiften_env["LIBIDN_FLAGS"]) +myenv.MergeFlags(myenv["LIBIDN_FLAGS"]) +myenv.MergeFlags(myenv["CARES_FLAGS"]) objects = myenv.StaticObject([ "BoostConnection.cpp", @@ -19,6 +20,7 @@ objects = myenv.StaticObject([ "DomainNameServiceQuery.cpp", "PlatformDomainNameResolver.cpp", "PlatformDomainNameServiceQuery.cpp", + "CAresDomainNameResolver.cpp", "StaticDomainNameResolver.cpp", "HostAddress.cpp", "Timer.cpp", -- cgit v0.10.2-6-g49f6