From 2086abd85c97ee4e03f6d7b266076c6607012243 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Wed, 7 Apr 2010 21:20:54 +0200
Subject: Support fallback multiple host addresses when connecting.

Resolves: #305

diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp
index ff45481..d2144d7 100644
--- a/Swiften/Network/Connector.cpp
+++ b/Swiften/Network/Connector.cpp
@@ -10,7 +10,7 @@
 
 namespace Swift {
 
-Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllHosts(true) {
+Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllServices(true) {
 }
 
 void Connector::setTimeoutMilliseconds(int milliseconds) {
@@ -22,7 +22,7 @@ void Connector::start() {
 	assert(!currentConnection);
 	assert(!serviceQuery);
 	assert(!timer);
-	queriedAllHosts = false;
+	queriedAllServices = false;
 	serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname);
 	serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1));
 	if (timeoutMilliseconds > 0) {
@@ -44,18 +44,18 @@ void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuer
 	//std::cout << "Received SRV results" << std::endl;
 	serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end());
 	serviceQuery.reset();
-	tryNextHostname();
+	tryNextServiceOrFallback();
 }
 
-void Connector::tryNextHostname() {
-	if (queriedAllHosts) {
-		//std::cout << "Connector::tryNextHostName(): Queried all hosts. Error." << std::endl;
+void Connector::tryNextServiceOrFallback() {
+	if (queriedAllServices) {
+		//std::cout << "Connector::tryNextServiceOrCallback(): Queried all hosts. Error." << std::endl;
 		finish(boost::shared_ptr<Connection>());
 	}
 	else if (serviceQueryResults.empty()) {
 		//std::cout << "Connector::tryNextHostName(): Falling back on A resolution" << std::endl;
 		// Fall back on simple address resolving
-		queriedAllHosts = true;
+		queriedAllServices = true;
 		queryAddress(hostname);
 	}
 	else {
@@ -67,28 +67,38 @@ void Connector::tryNextHostname() {
 void Connector::handleAddressQueryResult(const std::vector<HostAddress>& addresses, boost::optional<DomainNameResolveError> error) {
 	//std::cout << "Connector::handleAddressQueryResult(): Start" << std::endl;
 	addressQuery.reset();
-	if (!serviceQueryResults.empty()) {
-		DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front();
-		serviceQueryResults.pop_front();
-		if (error || addresses.empty()) {
-			//std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " failed." << std::endl;
-			tryNextHostname();
-		}
-		else {
-			//std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " succeeded: " << address.toString() << std::endl;
-			tryConnect(HostAddressPort(addresses[0], serviceQueryResult.port));
+	if (error || addresses.empty()) {
+		if (!serviceQueryResults.empty()) {
+			serviceQueryResults.pop_front();
 		}
+		tryNextServiceOrFallback();
 	}
-	else if (error || addresses.empty()) {
-		//std::cout << "Connector::handleAddressQueryResult(): Fallback address query failed. Giving up" << std::endl;
-		// The fallback address query failed
-		assert(queriedAllHosts);
-		finish(boost::shared_ptr<Connection>());
+	else {
+		addressQueryResults = std::deque<HostAddress>(addresses.begin(), addresses.end());
+		tryNextAddress();
+	}
+}
+
+void Connector::tryNextAddress() {
+	if (addressQueryResults.empty()) {
+		//std::cout << "Connector::tryNextAddress(): Done trying addresses. Moving on" << std::endl;
+		// Done trying all addresses. Move on to the next host.
+		if (!serviceQueryResults.empty()) {
+			serviceQueryResults.pop_front();
+		}
+		tryNextServiceOrFallback();
 	}
 	else {
-		//std::cout << "Connector::handleAddressQueryResult(): Fallback address query succeeded: " << address.toString() << std::endl;
-		// The fallback query succeeded
-		tryConnect(HostAddressPort(addresses[0], 5222));
+		//std::cout << "Connector::tryNextAddress(): trying next address." << std::endl;
+		HostAddress address = addressQueryResults.front();
+		addressQueryResults.pop_front();
+
+		int port = 5222;
+		if (!serviceQueryResults.empty()) {
+			port = serviceQueryResults.front().port;
+		}
+
+		tryConnect(HostAddressPort(address, port));
 	}
 }
 
@@ -104,7 +114,15 @@ void Connector::handleConnectionConnectFinished(bool error) {
 	//std::cout << "Connector::handleConnectionConnectFinished() " << error << std::endl;
 	if (error) {
 		currentConnection.reset();
-		tryNextHostname();
+		if (!addressQueryResults.empty()) {
+			tryNextAddress();
+		}
+		else {
+			if (!serviceQueryResults.empty()) {
+				serviceQueryResults.pop_front();
+			}
+			tryNextServiceOrFallback();
+		}
 	}
 	else {
 		finish(currentConnection);
diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h
index 32dd9ab..59fe708 100644
--- a/Swiften/Network/Connector.h
+++ b/Swiften/Network/Connector.h
@@ -31,7 +31,8 @@ namespace Swift {
 			void handleAddressQueryResult(const std::vector<HostAddress>& address, boost::optional<DomainNameResolveError> error);
 			void queryAddress(const String& hostname);
 
-			void tryNextHostname();
+			void tryNextServiceOrFallback();
+			void tryNextAddress();
 			void tryConnect(const HostAddressPort& target);
 
 			void handleConnectionConnectFinished(bool error);
@@ -48,7 +49,8 @@ namespace Swift {
 			boost::shared_ptr<DomainNameServiceQuery> serviceQuery;
 			std::deque<DomainNameServiceQuery::Result> serviceQueryResults;
 			boost::shared_ptr<DomainNameAddressQuery> addressQuery;
-			bool queriedAllHosts;
+			std::deque<HostAddress> addressQueryResults;
+			bool queriedAllServices;
 			boost::shared_ptr<Connection> currentConnection;
 	};
 };
diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp
index 663011c..2a2ab41 100644
--- a/Swiften/Network/UnitTest/ConnectorTest.cpp
+++ b/Swiften/Network/UnitTest/ConnectorTest.cpp
@@ -18,6 +18,7 @@ using namespace Swift;
 class ConnectorTest : public CppUnit::TestFixture {
 		CPPUNIT_TEST_SUITE(ConnectorTest);
 		CPPUNIT_TEST(testConnect);
+		CPPUNIT_TEST(testConnect_FirstAddressHostFails);
 		CPPUNIT_TEST(testConnect_NoSRVHost);
 		CPPUNIT_TEST(testConnect_NoHosts);
 		CPPUNIT_TEST(testConnect_FirstSRVHostFails);
@@ -73,6 +74,24 @@ class ConnectorTest : public CppUnit::TestFixture {
 			CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort));
 		}
 
+		void testConnect_FirstAddressHostFails() {
+			std::auto_ptr<Connector> testling(createConnector());
+
+			HostAddress address1("1.1.1.1");
+			HostAddress address2("2.2.2.2");
+			resolver->addXMPPClientService("foo.com", "host-foo.com", 1234);
+			resolver->addAddress("host-foo.com", address1);
+			resolver->addAddress("host-foo.com", address2);
+			connectionFactory->failingPorts.push_back(HostAddressPort(address1, 1234));
+
+			testling->start();
+			eventLoop->processEvents();
+
+			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+			CPPUNIT_ASSERT(connections[0]);
+			CPPUNIT_ASSERT(HostAddressPort(address2, 1234) == *(connections[0]->hostAddressPort));
+		}
+
 		void testConnect_NoHosts() {
 			std::auto_ptr<Connector> testling(createConnector());
 
@@ -207,7 +226,8 @@ class ConnectorTest : public CppUnit::TestFixture {
 				void connect(const HostAddressPort& address) {
 					hostAddressPort = address;
 					if (isResponsive) {
-						MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end()));
+						bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end();
+						MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), fail));
 					}
 				}
 
-- 
cgit v0.10.2-6-g49f6