diff options
Diffstat (limited to 'Swiften/Network')
-rw-r--r-- | Swiften/Network/Connector.cpp | 70 | ||||
-rw-r--r-- | Swiften/Network/Connector.h | 6 | ||||
-rw-r--r-- | Swiften/Network/StaticDomainNameResolver.cpp | 10 | ||||
-rw-r--r-- | Swiften/Network/StaticDomainNameResolver.h | 3 | ||||
-rw-r--r-- | Swiften/Network/UnitTest/ConnectorTest.cpp | 22 |
5 files changed, 77 insertions, 34 deletions
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index ff45481..d2144d7 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -10,7 +10,7 @@ namespace Swift { -Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllHosts(true) { +Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllServices(true) { } void Connector::setTimeoutMilliseconds(int milliseconds) { @@ -22,7 +22,7 @@ void Connector::start() { assert(!currentConnection); assert(!serviceQuery); assert(!timer); - queriedAllHosts = false; + queriedAllServices = false; serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname); serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1)); if (timeoutMilliseconds > 0) { @@ -44,18 +44,18 @@ void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuer //std::cout << "Received SRV results" << std::endl; serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); serviceQuery.reset(); - tryNextHostname(); + tryNextServiceOrFallback(); } -void Connector::tryNextHostname() { - if (queriedAllHosts) { - //std::cout << "Connector::tryNextHostName(): Queried all hosts. Error." << std::endl; +void Connector::tryNextServiceOrFallback() { + if (queriedAllServices) { + //std::cout << "Connector::tryNextServiceOrCallback(): Queried all hosts. Error." << std::endl; finish(boost::shared_ptr<Connection>()); } else if (serviceQueryResults.empty()) { //std::cout << "Connector::tryNextHostName(): Falling back on A resolution" << std::endl; // Fall back on simple address resolving - queriedAllHosts = true; + queriedAllServices = true; queryAddress(hostname); } else { @@ -67,28 +67,38 @@ void Connector::tryNextHostname() { void Connector::handleAddressQueryResult(const std::vector<HostAddress>& addresses, boost::optional<DomainNameResolveError> error) { //std::cout << "Connector::handleAddressQueryResult(): Start" << std::endl; addressQuery.reset(); - if (!serviceQueryResults.empty()) { - DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front(); - serviceQueryResults.pop_front(); - if (error || addresses.empty()) { - //std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " failed." << std::endl; - tryNextHostname(); - } - else { - //std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " succeeded: " << address.toString() << std::endl; - tryConnect(HostAddressPort(addresses[0], serviceQueryResult.port)); + if (error || addresses.empty()) { + if (!serviceQueryResults.empty()) { + serviceQueryResults.pop_front(); } + tryNextServiceOrFallback(); } - else if (error || addresses.empty()) { - //std::cout << "Connector::handleAddressQueryResult(): Fallback address query failed. Giving up" << std::endl; - // The fallback address query failed - assert(queriedAllHosts); - finish(boost::shared_ptr<Connection>()); + else { + addressQueryResults = std::deque<HostAddress>(addresses.begin(), addresses.end()); + tryNextAddress(); + } +} + +void Connector::tryNextAddress() { + if (addressQueryResults.empty()) { + //std::cout << "Connector::tryNextAddress(): Done trying addresses. Moving on" << std::endl; + // Done trying all addresses. Move on to the next host. + if (!serviceQueryResults.empty()) { + serviceQueryResults.pop_front(); + } + tryNextServiceOrFallback(); } else { - //std::cout << "Connector::handleAddressQueryResult(): Fallback address query succeeded: " << address.toString() << std::endl; - // The fallback query succeeded - tryConnect(HostAddressPort(addresses[0], 5222)); + //std::cout << "Connector::tryNextAddress(): trying next address." << std::endl; + HostAddress address = addressQueryResults.front(); + addressQueryResults.pop_front(); + + int port = 5222; + if (!serviceQueryResults.empty()) { + port = serviceQueryResults.front().port; + } + + tryConnect(HostAddressPort(address, port)); } } @@ -104,7 +114,15 @@ void Connector::handleConnectionConnectFinished(bool error) { //std::cout << "Connector::handleConnectionConnectFinished() " << error << std::endl; if (error) { currentConnection.reset(); - tryNextHostname(); + if (!addressQueryResults.empty()) { + tryNextAddress(); + } + else { + if (!serviceQueryResults.empty()) { + serviceQueryResults.pop_front(); + } + tryNextServiceOrFallback(); + } } else { finish(currentConnection); diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h index 32dd9ab..59fe708 100644 --- a/Swiften/Network/Connector.h +++ b/Swiften/Network/Connector.h @@ -31,7 +31,8 @@ namespace Swift { void handleAddressQueryResult(const std::vector<HostAddress>& address, boost::optional<DomainNameResolveError> error); void queryAddress(const String& hostname); - void tryNextHostname(); + void tryNextServiceOrFallback(); + void tryNextAddress(); void tryConnect(const HostAddressPort& target); void handleConnectionConnectFinished(bool error); @@ -48,7 +49,8 @@ namespace Swift { boost::shared_ptr<DomainNameServiceQuery> serviceQuery; std::deque<DomainNameServiceQuery::Result> serviceQueryResults; boost::shared_ptr<DomainNameAddressQuery> addressQuery; - bool queriedAllHosts; + std::deque<HostAddress> addressQueryResults; + bool queriedAllServices; boost::shared_ptr<Connection> currentConnection; }; }; diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp index a751fdd..196176f 100644 --- a/Swiften/Network/StaticDomainNameResolver.cpp +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -38,10 +38,8 @@ namespace { } StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host); if (i != resolver->getAddresses().end()) { - std::vector<HostAddress> result; - result.push_back(i->second); MainEventLoop::postEvent( - boost::bind(boost::ref(onResult), result, boost::optional<DomainNameResolveError>())); + boost::bind(boost::ref(onResult), i->second, boost::optional<DomainNameResolveError>())); } else { MainEventLoop::postEvent(boost::bind(boost::ref(onResult), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError()))); @@ -60,7 +58,7 @@ StaticDomainNameResolver::StaticDomainNameResolver() : isResponsive(true) { } void StaticDomainNameResolver::addAddress(const String& domain, const HostAddress& address) { - addresses[domain] = address; + addresses[domain].push_back(address); } void StaticDomainNameResolver::addService(const String& service, const DomainNameServiceQuery::Result& result) { @@ -76,6 +74,10 @@ void StaticDomainNameResolver::addXMPPClientService(const String& domain, const addAddress(hostname, address.getAddress()); } +void StaticDomainNameResolver::addXMPPClientService(const String& domain, const String& hostname, int port) { + addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, port, 0, 0)); +} + boost::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const String& name) { return boost::shared_ptr<DomainNameServiceQuery>(new ServiceQuery(name, this)); } diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h index d7e7ba4..2428d29 100644 --- a/Swiften/Network/StaticDomainNameResolver.h +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -15,7 +15,7 @@ namespace Swift { class StaticDomainNameResolver : public DomainNameResolver { public: - typedef std::map<String, HostAddress> AddressesMap; + typedef std::map<String, std::vector<HostAddress> > AddressesMap; typedef std::vector< std::pair<String, DomainNameServiceQuery::Result> > ServicesCollection; public: @@ -24,6 +24,7 @@ namespace Swift { void addAddress(const String& domain, const HostAddress& address); void addService(const String& service, const DomainNameServiceQuery::Result& result); void addXMPPClientService(const String& domain, const HostAddressPort&); + void addXMPPClientService(const String& domain, const String& host, int port); const AddressesMap& getAddresses() const { return addresses; diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp index 663011c..2a2ab41 100644 --- a/Swiften/Network/UnitTest/ConnectorTest.cpp +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -18,6 +18,7 @@ using namespace Swift; class ConnectorTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(ConnectorTest); CPPUNIT_TEST(testConnect); + CPPUNIT_TEST(testConnect_FirstAddressHostFails); CPPUNIT_TEST(testConnect_NoSRVHost); CPPUNIT_TEST(testConnect_NoHosts); CPPUNIT_TEST(testConnect_FirstSRVHostFails); @@ -73,6 +74,24 @@ class ConnectorTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); } + void testConnect_FirstAddressHostFails() { + std::auto_ptr<Connector> testling(createConnector()); + + HostAddress address1("1.1.1.1"); + HostAddress address2("2.2.2.2"); + resolver->addXMPPClientService("foo.com", "host-foo.com", 1234); + resolver->addAddress("host-foo.com", address1); + resolver->addAddress("host-foo.com", address2); + connectionFactory->failingPorts.push_back(HostAddressPort(address1, 1234)); + + testling->start(); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + CPPUNIT_ASSERT(HostAddressPort(address2, 1234) == *(connections[0]->hostAddressPort)); + } + void testConnect_NoHosts() { std::auto_ptr<Connector> testling(createConnector()); @@ -207,7 +226,8 @@ class ConnectorTest : public CppUnit::TestFixture { void connect(const HostAddressPort& address) { hostAddressPort = address; if (isResponsive) { - MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end())); + bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end(); + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), fail)); } } |