From 312a114c7e204cfe4cfe961509ab9b24ccde7860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= Date: Sat, 18 Dec 2010 19:50:50 +0100 Subject: Move all domain name resolve queries into one thread. This avoids reentrancy problems on some platform DNS calls. Resolves: #443 diff --git a/Swiften/Examples/BenchTool/BenchTool.cpp b/Swiften/Examples/BenchTool/BenchTool.cpp index a5c0925..1dcc8c2 100644 --- a/Swiften/Examples/BenchTool/BenchTool.cpp +++ b/Swiften/Examples/BenchTool/BenchTool.cpp @@ -22,6 +22,8 @@ using namespace Swift; SimpleEventLoop eventLoop; BoostNetworkFactories networkFactories(&eventLoop); int numberOfConnectedClients = 0; +int numberOfInstances = 100; + void handleConnected() { numberOfConnectedClients++; @@ -29,8 +31,6 @@ void handleConnected() { } int main(int, char**) { - int numberOfInstances = 1000; - char* jid = getenv("SWIFT_BENCHTOOL_JID"); if (!jid) { std::cerr << "Please set the SWIFT_BENCHTOOL_JID environment variable" << std::endl; diff --git a/Swiften/Network/DomainNameAddressQuery.h b/Swiften/Network/DomainNameAddressQuery.h index 390916f..5bac350 100644 --- a/Swiften/Network/DomainNameAddressQuery.h +++ b/Swiften/Network/DomainNameAddressQuery.h @@ -16,6 +16,8 @@ namespace Swift { class DomainNameAddressQuery { public: + typedef boost::shared_ptr ref; + virtual ~DomainNameAddressQuery(); virtual void run() = 0; diff --git a/Swiften/Network/DomainNameServiceQuery.h b/Swiften/Network/DomainNameServiceQuery.h index 3ba3a00..fb44e82 100644 --- a/Swiften/Network/DomainNameServiceQuery.h +++ b/Swiften/Network/DomainNameServiceQuery.h @@ -9,6 +9,7 @@ #include "Swiften/Base/boost_bsignals.h" #include #include +#include #include "Swiften/Base/String.h" #include "Swiften/Network/DomainNameResolveError.h" @@ -16,6 +17,8 @@ namespace Swift { class DomainNameServiceQuery { public: + typedef boost::shared_ptr ref; + struct Result { Result(const String& hostname = "", int port = -1, int priority = -1, int weight = -1) : hostname(hostname), port(port), priority(priority), weight(weight) {} String hostname; diff --git a/Swiften/Network/PlatformDomainNameAddressQuery.cpp b/Swiften/Network/PlatformDomainNameAddressQuery.cpp new file mode 100644 index 0000000..2a8574d --- /dev/null +++ b/Swiften/Network/PlatformDomainNameAddressQuery.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include + +#include +#include + +namespace Swift { + +PlatformDomainNameAddressQuery::PlatformDomainNameAddressQuery(const String& host, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), hostname(host), eventLoop(eventLoop) { +} + +void PlatformDomainNameAddressQuery::run() { + getResolver()->addQueryToQueue(shared_from_this()); +} + +void PlatformDomainNameAddressQuery::runBlocking() { + //std::cout << "PlatformDomainNameResolver::doRun()" << std::endl; + boost::asio::ip::tcp::resolver resolver(ioService); + boost::asio::ip::tcp::resolver::query query(hostname.getUTF8String(), "5222"); + try { + //std::cout << "PlatformDomainNameResolver::doRun(): Resolving" << std::endl; + boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); + //std::cout << "PlatformDomainNameResolver::doRun(): Resolved" << std::endl; + if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { + //std::cout << "PlatformDomainNameResolver::doRun(): Error 1" << std::endl; + emitError(); + } + else { + std::vector results; + for ( ; endpointIterator != boost::asio::ip::tcp::resolver::iterator(); ++endpointIterator) { + boost::asio::ip::address address = (*endpointIterator).endpoint().address(); + results.push_back(address.is_v4() ? HostAddress(&address.to_v4().to_bytes()[0], 4) : HostAddress(&address.to_v6().to_bytes()[0], 16)); + } + + //std::cout << "PlatformDomainNameResolver::doRun(): Success" << std::endl; + eventLoop->postEvent( + boost::bind(boost::ref(onResult), results, boost::optional()), + shared_from_this()); + } + } + catch (...) { + //std::cout << "PlatformDomainNameResolver::doRun(): Error 2" << std::endl; + emitError(); + } +} + +void PlatformDomainNameAddressQuery::emitError() { + eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector(), boost::optional(DomainNameResolveError())), shared_from_this()); +} + +} diff --git a/Swiften/Network/PlatformDomainNameAddressQuery.h b/Swiften/Network/PlatformDomainNameAddressQuery.h new file mode 100644 index 0000000..0153688 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameAddressQuery.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace Swift { + class PlatformDomainNameResolver; + class EventLoop; + + class PlatformDomainNameAddressQuery : public DomainNameAddressQuery, public PlatformDomainNameQuery, public boost::enable_shared_from_this, public EventOwner { + public: + PlatformDomainNameAddressQuery(const String& host, EventLoop* eventLoop, PlatformDomainNameResolver*); + + void run(); + + private: + void runBlocking(); + void emitError(); + + private: + boost::asio::io_service ioService; + String hostname; + EventLoop* eventLoop; + }; +} + + diff --git a/Swiften/Network/PlatformDomainNameQuery.h b/Swiften/Network/PlatformDomainNameQuery.h new file mode 100644 index 0000000..bbfb1d1 --- /dev/null +++ b/Swiften/Network/PlatformDomainNameQuery.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include + +namespace Swift { + class PlatformDomainNameResolver; + + class PlatformDomainNameQuery { + public: + typedef boost::shared_ptr ref; + + PlatformDomainNameQuery(PlatformDomainNameResolver* resolver) : resolver(resolver) {} + virtual ~PlatformDomainNameQuery() {} + + virtual void runBlocking() = 0; + + protected: + PlatformDomainNameResolver* getResolver() { + return resolver; + } + + private: + PlatformDomainNameResolver* resolver; + }; +} diff --git a/Swiften/Network/PlatformDomainNameResolver.cpp b/Swiften/Network/PlatformDomainNameResolver.cpp index 3f72466..6a61337 100644 --- a/Swiften/Network/PlatformDomainNameResolver.cpp +++ b/Swiften/Network/PlatformDomainNameResolver.cpp @@ -11,10 +11,8 @@ #include #include -#include #include #include -#include #include #include "Swiften/Base/String.h" @@ -23,84 +21,55 @@ #include "Swiften/EventLoop/EventLoop.h" #include "Swiften/Network/HostAddressPort.h" #include "Swiften/Network/DomainNameAddressQuery.h" +#include using namespace Swift; -namespace { - struct AddressQuery : public DomainNameAddressQuery, public boost::enable_shared_from_this, public EventOwner { - AddressQuery(const String& host, EventLoop* eventLoop) : hostname(host), eventLoop(eventLoop), thread(NULL), safeToJoin(false) {} - - ~AddressQuery() { - if (safeToJoin) { - thread->join(); - } - else { - // FIXME: UGLYYYYY - } - delete thread; - } - - void run() { - safeToJoin = false; - thread = new boost::thread(boost::bind(&AddressQuery::doRun, shared_from_this())); - } - - void doRun() { - //std::cout << "PlatformDomainNameResolver::doRun()" << std::endl; - boost::asio::ip::tcp::resolver resolver(ioService); - boost::asio::ip::tcp::resolver::query query(hostname.getUTF8String(), "5222"); - try { - //std::cout << "PlatformDomainNameResolver::doRun(): Resolving" << std::endl; - boost::asio::ip::tcp::resolver::iterator endpointIterator = resolver.resolve(query); - //std::cout << "PlatformDomainNameResolver::doRun(): Resolved" << std::endl; - if (endpointIterator == boost::asio::ip::tcp::resolver::iterator()) { - //std::cout << "PlatformDomainNameResolver::doRun(): Error 1" << std::endl; - emitError(); - } - else { - std::vector results; - for ( ; endpointIterator != boost::asio::ip::tcp::resolver::iterator(); ++endpointIterator) { - boost::asio::ip::address address = (*endpointIterator).endpoint().address(); - results.push_back(address.is_v4() ? HostAddress(&address.to_v4().to_bytes()[0], 4) : HostAddress(&address.to_v6().to_bytes()[0], 16)); - } - - //std::cout << "PlatformDomainNameResolver::doRun(): Success" << std::endl; - eventLoop->postEvent( - boost::bind(boost::ref(onResult), results, boost::optional()), - shared_from_this()); - } - } - catch (...) { - //std::cout << "PlatformDomainNameResolver::doRun(): Error 2" << std::endl; - emitError(); - } - safeToJoin = true; - } - - void emitError() { - eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector(), boost::optional(DomainNameResolveError())), shared_from_this()); - } - - boost::asio::io_service ioService; - String hostname; - EventLoop* eventLoop; - boost::thread* thread; - bool safeToJoin; - }; +namespace Swift { +PlatformDomainNameResolver::PlatformDomainNameResolver(EventLoop* eventLoop) : eventLoop(eventLoop), stopRequested(false) { + thread = new boost::thread(boost::bind(&PlatformDomainNameResolver::run, this)); } -namespace Swift { - -PlatformDomainNameResolver::PlatformDomainNameResolver(EventLoop* eventLoop) : eventLoop(eventLoop) { +PlatformDomainNameResolver::~PlatformDomainNameResolver() { + stopRequested = true; + addQueryToQueue(boost::shared_ptr()); + thread->join(); } boost::shared_ptr PlatformDomainNameResolver::createServiceQuery(const String& name) { - return boost::shared_ptr(new PlatformDomainNameServiceQuery(IDNA::getEncoded(name), eventLoop)); + return boost::shared_ptr(new PlatformDomainNameServiceQuery(IDNA::getEncoded(name), eventLoop, this)); } boost::shared_ptr PlatformDomainNameResolver::createAddressQuery(const String& name) { - return boost::shared_ptr(new AddressQuery(IDNA::getEncoded(name), eventLoop)); + return boost::shared_ptr(new PlatformDomainNameAddressQuery(IDNA::getEncoded(name), eventLoop, this)); +} + +void PlatformDomainNameResolver::run() { + while (!stopRequested) { + PlatformDomainNameQuery::ref query; + { + boost::unique_lock lock(queueMutex); + while (queue.empty()) { + queueNonEmpty.wait(lock); + } + query = queue.front(); + queue.pop_front(); + } + // Check whether we don't have a non-null query (used to stop the + // resolver) + if (query) { + query->runBlocking(); + } + } +} + +void PlatformDomainNameResolver::addQueryToQueue(PlatformDomainNameQuery::ref query) { + { + boost::lock_guard lock(queueMutex); + queue.push_back(query); + } + queueNonEmpty.notify_one(); } } diff --git a/Swiften/Network/PlatformDomainNameResolver.h b/Swiften/Network/PlatformDomainNameResolver.h index 46c209b..249f2e3 100644 --- a/Swiften/Network/PlatformDomainNameResolver.h +++ b/Swiften/Network/PlatformDomainNameResolver.h @@ -6,7 +6,15 @@ #pragma once -#include "Swiften/Network/DomainNameResolver.h" +#include +#include +#include +#include + +#include +#include +#include +#include namespace Swift { class String; @@ -15,11 +23,23 @@ namespace Swift { class PlatformDomainNameResolver : public DomainNameResolver { public: PlatformDomainNameResolver(EventLoop* eventLoop); + ~PlatformDomainNameResolver(); + + virtual DomainNameServiceQuery::ref createServiceQuery(const String& name); + virtual DomainNameAddressQuery::ref createAddressQuery(const String& name); - virtual boost::shared_ptr createServiceQuery(const String& name); - virtual boost::shared_ptr createAddressQuery(const String& name); + private: + void run(); + void addQueryToQueue(PlatformDomainNameQuery::ref); private: + friend class PlatformDomainNameServiceQuery; + friend class PlatformDomainNameAddressQuery; EventLoop* eventLoop; + bool stopRequested; + boost::thread* thread; + std::deque queue; + boost::mutex queueMutex; + boost::condition_variable queueNonEmpty; }; } diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.cpp b/Swiften/Network/PlatformDomainNameServiceQuery.cpp index 7ab6e7a..bdbb664 100644 --- a/Swiften/Network/PlatformDomainNameServiceQuery.cpp +++ b/Swiften/Network/PlatformDomainNameServiceQuery.cpp @@ -28,30 +28,20 @@ #include "Swiften/EventLoop/EventLoop.h" #include "Swiften/Base/foreach.h" #include +#include using namespace Swift; namespace Swift { -PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const String& service, EventLoop* eventLoop) : eventLoop(eventLoop), thread(NULL), service(service), safeToJoin(true) { -} - -PlatformDomainNameServiceQuery::~PlatformDomainNameServiceQuery() { - if (safeToJoin) { - thread->join(); - } - else { - // FIXME: UGLYYYYY - } - delete thread; +PlatformDomainNameServiceQuery::PlatformDomainNameServiceQuery(const String& service, EventLoop* eventLoop, PlatformDomainNameResolver* resolver) : PlatformDomainNameQuery(resolver), eventLoop(eventLoop), service(service) { } void PlatformDomainNameServiceQuery::run() { - safeToJoin = false; - thread = new boost::thread(boost::bind(&PlatformDomainNameServiceQuery::doRun, shared_from_this())); + getResolver()->addQueryToQueue(shared_from_this()); } -void PlatformDomainNameServiceQuery::doRun() { +void PlatformDomainNameServiceQuery::runBlocking() { SWIFT_LOG(debug) << "Querying " << service << std::endl; std::vector records; @@ -166,14 +156,12 @@ void PlatformDomainNameServiceQuery::doRun() { } #endif - safeToJoin = true; std::sort(records.begin(), records.end(), ResultPriorityComparator()); //std::cout << "Sending out " << records.size() << " SRV results " << std::endl; eventLoop->postEvent(boost::bind(boost::ref(onResult), records)); } void PlatformDomainNameServiceQuery::emitError() { - safeToJoin = true; eventLoop->postEvent(boost::bind(boost::ref(onResult), std::vector()), shared_from_this()); } diff --git a/Swiften/Network/PlatformDomainNameServiceQuery.h b/Swiften/Network/PlatformDomainNameServiceQuery.h index 9808196..c9dbd65 100644 --- a/Swiften/Network/PlatformDomainNameServiceQuery.h +++ b/Swiften/Network/PlatformDomainNameServiceQuery.h @@ -6,31 +6,28 @@ #pragma once -#include #include #include "Swiften/Network/DomainNameServiceQuery.h" #include "Swiften/EventLoop/EventOwner.h" #include "Swiften/Base/String.h" +#include namespace Swift { class EventLoop; - class PlatformDomainNameServiceQuery : public DomainNameServiceQuery, public boost::enable_shared_from_this, public EventOwner { + class PlatformDomainNameServiceQuery : public DomainNameServiceQuery, public PlatformDomainNameQuery, public boost::enable_shared_from_this, public EventOwner { public: - PlatformDomainNameServiceQuery(const String& service, EventLoop* eventLoop); - ~PlatformDomainNameServiceQuery(); + PlatformDomainNameServiceQuery(const String& service, EventLoop* eventLoop, PlatformDomainNameResolver* resolver); virtual void run(); private: - void doRun(); + void runBlocking(); void emitError(); private: EventLoop* eventLoop; - boost::thread* thread; String service; - bool safeToJoin; }; } diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index f193407..2e376af 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -21,6 +21,7 @@ sourceList = [ "DomainNameServiceQuery.cpp", "PlatformDomainNameResolver.cpp", "PlatformDomainNameServiceQuery.cpp", + "PlatformDomainNameAddressQuery.cpp", "StaticDomainNameResolver.cpp", "HostAddress.cpp", "NetworkFactories.cpp", diff --git a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp index 1bda585..d0e0a43 100644 --- a/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp +++ b/Swiften/QA/NetworkTest/DomainNameResolverTest.cpp @@ -34,6 +34,7 @@ class DomainNameResolverTest : public CppUnit::TestFixture { CPPUNIT_TEST(testResolveAddress_IPv4and6); CPPUNIT_TEST(testResolveAddress_International); CPPUNIT_TEST(testResolveAddress_Localhost); + CPPUNIT_TEST(testResolveAddress_Parallel); CPPUNIT_TEST(testResolveService); CPPUNIT_TEST(testResolveService_Error); CPPUNIT_TEST_SUITE_END(); @@ -115,6 +116,31 @@ class DomainNameResolverTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(std::find(addressQueryResult.begin(), addressQueryResult.end(), HostAddress("127.0.0.1")) != addressQueryResult.end()); } + void testResolveAddress_Parallel() { + std::vector queries; + static const size_t numQueries = 100; + for (size_t i = 0; i < numQueries; ++i) { + DomainNameAddressQuery::ref query(createAddressQuery("xmpp.test.swift.im")); + queries.push_back(query); + query->run(); + } + + eventLoop->processEvents(); + int ticks = 0; + while (allAddressQueryResults.size() < numQueries) { + ticks++; + if (ticks > 1000) { + CPPUNIT_ASSERT(false); + } + Swift::sleep(10); + eventLoop->processEvents(); + } + + CPPUNIT_ASSERT_EQUAL(numQueries, allAddressQueryResults.size()); + for (size_t i = 0; i < numQueries; ++i) { + CPPUNIT_ASSERT_EQUAL(std::string("10.0.0.0"), allAddressQueryResults[i].toString()); + } + } void testResolveService() { boost::shared_ptr query(createServiceQuery("_xmpp-client._tcp.xmpp-srv.test.swift.im")); @@ -144,10 +170,6 @@ class DomainNameResolverTest : public CppUnit::TestFixture { void testResolveService_Error() { } -/* - } - */ - private: boost::shared_ptr createAddressQuery(const String& domain) { boost::shared_ptr result = resolver->createAddressQuery(domain); @@ -158,6 +180,7 @@ class DomainNameResolverTest : public CppUnit::TestFixture { void handleAddressQueryResult(const std::vector& addresses, boost::optional error) { addressQueryResult = addresses; std::sort(addressQueryResult.begin(), addressQueryResult.end(), CompareHostAddresses()); + allAddressQueryResults.insert(allAddressQueryResults.begin(), addresses.begin(), addresses.end()); addressQueryError = error; resultsAvailable = true; } @@ -190,6 +213,7 @@ class DomainNameResolverTest : public CppUnit::TestFixture { DummyEventLoop* eventLoop; bool resultsAvailable; std::vector addressQueryResult; + std::vector allAddressQueryResults; boost::optional addressQueryError; std::vector serviceQueryResult; PlatformDomainNameResolver* resolver; -- cgit v0.10.2-6-g49f6