From 152c455e18aae9b613f17bca8bba4a2beafe0228 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Thu, 31 Dec 2009 14:27:52 +0100
Subject: Added tests for timing out initial connect.


diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp
index e424f64..9ea5a7f 100644
--- a/Swiften/Network/Connector.cpp
+++ b/Swiften/Network/Connector.cpp
@@ -12,6 +12,10 @@ namespace Swift {
 Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), queriedAllHosts(true) {
 }
 
+void Connector::setTimeoutMilliseconds(int milliseconds) {
+	timeoutMilliseconds = milliseconds;
+}
+
 void Connector::start() {
 	//std::cout << "Connector::start()" << std::endl;
 	assert(!currentConnection);
diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h
index cb885ab..6df3970 100644
--- a/Swiften/Network/Connector.h
+++ b/Swiften/Network/Connector.h
@@ -19,6 +19,7 @@ namespace Swift {
 		public:
 			Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*);
 
+			void setTimeoutMilliseconds(int milliseconds);
 			void start();
 
 			boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished;
@@ -37,11 +38,11 @@ namespace Swift {
 			String hostname;
 			DomainNameResolver* resolver;
 			ConnectionFactory* connectionFactory;
+			int timeoutMilliseconds;
 			boost::shared_ptr<DomainNameServiceQuery> serviceQuery;
 			std::deque<DomainNameServiceQuery::Result> serviceQueryResults;
 			boost::shared_ptr<DomainNameAddressQuery> addressQuery;
 			bool queriedAllHosts;
 			boost::shared_ptr<Connection> currentConnection;
 	};
-		
 };
diff --git a/Swiften/Network/DummyTimerFactory.cpp b/Swiften/Network/DummyTimerFactory.cpp
new file mode 100644
index 0000000..72523bb
--- /dev/null
+++ b/Swiften/Network/DummyTimerFactory.cpp
@@ -0,0 +1,57 @@
+#include "Swiften/Network/DummyTimerFactory.h"
+
+#include <algorithm>
+
+#include "Swiften/Network/Timer.h"
+#include "Swiften/Base/foreach.h"
+
+namespace Swift {
+
+class DummyTimerFactory::DummyTimer : public Timer {
+	public:
+		DummyTimer(int timeout) : timeout(timeout), isRunning(false) {
+		}
+
+		virtual void start() {
+			isRunning = true;
+		}
+
+		virtual void stop() {
+			isRunning = false;
+		}
+	
+		int timeout;
+		bool isRunning;
+};
+
+
+DummyTimerFactory::DummyTimerFactory() : currentTime(0) {
+}
+
+boost::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) {
+	boost::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds));
+	timers.push_back(timer);
+	return timer;
+}
+
+static bool hasZeroTimeout(boost::shared_ptr<DummyTimerFactory::DummyTimer> timer) {
+	return timer->timeout == 0;
+}
+
+void DummyTimerFactory::setTime(int time) {
+	assert(time > currentTime);
+	int increment = currentTime - time;
+	std::vector< boost::shared_ptr<DummyTimer> > notifyTimers(timers.begin(), timers.end());
+	foreach(boost::shared_ptr<DummyTimer> timer, notifyTimers) {
+		if (increment >= timer->timeout) {
+			if (timer->isRunning) {
+				timer->onTick();
+			}
+			timer->timeout = 0;
+		}
+	}
+	timers.erase(std::remove_if(timers.begin(), timers.end(), hasZeroTimeout), timers.end());
+	currentTime = time;
+}
+
+}
diff --git a/Swiften/Network/DummyTimerFactory.h b/Swiften/Network/DummyTimerFactory.h
new file mode 100644
index 0000000..feac029
--- /dev/null
+++ b/Swiften/Network/DummyTimerFactory.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <list>
+
+#include "Swiften/Network/TimerFactory.h"
+
+namespace Swift {
+	class DummyTimerFactory : public TimerFactory {
+		public:
+			class DummyTimer;
+
+			DummyTimerFactory();
+
+			virtual boost::shared_ptr<Timer> createTimer(int milliseconds);
+			void setTime(int time);
+
+		private:
+			friend class DummyTimer;
+			int currentTime;
+			std::list<boost::shared_ptr<DummyTimer> > timers;
+	};
+}
diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript
index d63b673..767eee2 100644
--- a/Swiften/Network/SConscript
+++ b/Swiften/Network/SConscript
@@ -14,6 +14,7 @@ objects = myenv.StaticObject([
 			"ConnectionServer.cpp",
       "Connector.cpp",
 			"TimerFactory.cpp",
+			"DummyTimerFactory.cpp",
 			"BoostTimerFactory.cpp",
 			"DomainNameResolver.cpp",
 			"DomainNameAddressQuery.cpp",
diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp
index 609bbdd..a7275d2 100644
--- a/Swiften/Network/StaticDomainNameResolver.cpp
+++ b/Swiften/Network/StaticDomainNameResolver.cpp
@@ -13,6 +13,9 @@ namespace {
 		ServiceQuery(const String& service, Swift::StaticDomainNameResolver* resolver) : service(service), resolver(resolver) {}
 
 		virtual void run() {
+			if (!resolver->getIsResponsive()) {
+				return;
+			}
 			std::vector<DomainNameServiceQuery::Result> results;
 			for(StaticDomainNameResolver::ServicesCollection::const_iterator i = resolver->getServices().begin(); i != resolver->getServices().end(); ++i) {
 				if (i->first == service) {
@@ -30,6 +33,9 @@ namespace {
 		AddressQuery(const String& host, StaticDomainNameResolver* resolver) : host(host), resolver(resolver) {}
 
 		virtual void run() {
+			if (!resolver->getIsResponsive()) {
+				return;
+			}
 			StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host);
 			if (i != resolver->getAddresses().end()) {
 				MainEventLoop::postEvent(
@@ -48,6 +54,9 @@ namespace {
 
 namespace Swift {
 
+StaticDomainNameResolver::StaticDomainNameResolver() : isResponsive(true) {
+}
+
 void StaticDomainNameResolver::addAddress(const String& domain, const HostAddress& address) {
 	addresses[domain] = address;
 }
diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h
index 0e877d3..d7e7ba4 100644
--- a/Swiften/Network/StaticDomainNameResolver.h
+++ b/Swiften/Network/StaticDomainNameResolver.h
@@ -19,6 +19,8 @@ namespace Swift {
 			typedef std::vector< std::pair<String, DomainNameServiceQuery::Result> > ServicesCollection;
 
 		public:
+			StaticDomainNameResolver();
+
 			void addAddress(const String& domain, const HostAddress& address);
 			void addService(const String& service, const DomainNameServiceQuery::Result& result);
 			void addXMPPClientService(const String& domain, const HostAddressPort&);
@@ -31,10 +33,19 @@ namespace Swift {
 				return services;
 			}
 
+			bool getIsResponsive() const {
+				return isResponsive;
+			}
+
+			void setIsResponsive(bool b) {
+				isResponsive = b;
+			}
+
 			virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name);
 			virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name);
 			
 		private:
+			bool isResponsive;
 			AddressesMap addresses;
 			ServicesCollection services;
 	};
diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp
index af1ad4e..08b9bc1 100644
--- a/Swiften/Network/UnitTest/ConnectorTest.cpp
+++ b/Swiften/Network/UnitTest/ConnectorTest.cpp
@@ -9,6 +9,7 @@
 #include "Swiften/Network/ConnectionFactory.h"
 #include "Swiften/Network/HostAddressPort.h"
 #include "Swiften/Network/StaticDomainNameResolver.h"
+#include "Swiften/Network/DummyTimerFactory.h"
 #include "Swiften/EventLoop/MainEventLoop.h"
 #include "Swiften/EventLoop/DummyEventLoop.h"
 
@@ -23,6 +24,9 @@ class ConnectorTest : public CppUnit::TestFixture {
 		CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost);
 		CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost);
 		CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail);
+		//CPPUNIT_TEST(testConnect_TimeoutDuringResolve);
+		//CPPUNIT_TEST(testConnect_TimeoutDuringConnect);
+		//CPPUNIT_TEST(testConnect_NoTimeout);
 		CPPUNIT_TEST_SUITE_END();
 
 	public:
@@ -33,9 +37,11 @@ class ConnectorTest : public CppUnit::TestFixture {
 			eventLoop = new DummyEventLoop();
 			resolver = new StaticDomainNameResolver();
 			connectionFactory = new MockConnectionFactory();
+			timerFactory = new DummyTimerFactory();
 		}
 
 		void tearDown() {
+			delete timerFactory;
 			delete connectionFactory;
 			delete resolver;
 			delete eventLoop;
@@ -134,6 +140,50 @@ class ConnectorTest : public CppUnit::TestFixture {
 			CPPUNIT_ASSERT(!connections[0]);
 		}
 
+		void testConnect_TimeoutDuringResolve() {
+			std::auto_ptr<Connector> testling(createConnector());
+			testling->setTimeoutMilliseconds(10);
+			resolver->setIsResponsive(false);
+
+			testling->start();
+			eventLoop->processEvents();
+			timerFactory->setTime(10);
+			eventLoop->processEvents();
+
+			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+			CPPUNIT_ASSERT(!connections[0]);
+		}
+
+		void testConnect_TimeoutDuringConnect() {
+			std::auto_ptr<Connector> testling(createConnector());
+			testling->setTimeoutMilliseconds(10);
+			resolver->addXMPPClientService("foo.com", host1);
+			connectionFactory->isResponsive = false;
+
+			testling->start();
+			eventLoop->processEvents();
+			timerFactory->setTime(10);
+			eventLoop->processEvents();
+
+			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+			CPPUNIT_ASSERT(!connections[0]);
+		}
+
+		void testConnect_NoTimeout() {
+			std::auto_ptr<Connector> testling(createConnector());
+			testling->setTimeoutMilliseconds(10);
+			resolver->addXMPPClientService("foo.com", host1);
+
+			testling->start();
+			eventLoop->processEvents();
+			timerFactory->setTime(10);
+			eventLoop->processEvents();
+
+			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+			CPPUNIT_ASSERT(connections[0]);
+		}
+
+
 	private:
 		Connector* createConnector() {
 			Connector* connector = new Connector("foo.com", resolver, connectionFactory);
@@ -151,12 +201,14 @@ class ConnectorTest : public CppUnit::TestFixture {
 
 		struct MockConnection : public Connection {
 			public:
-				MockConnection(const std::vector<HostAddressPort>& failingPorts) : failingPorts(failingPorts) {}
+				MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive) : failingPorts(failingPorts), isResponsive(isResponsive) {}
 
 				void listen() { assert(false); }
 				void connect(const HostAddressPort& address) {
 					hostAddressPort = address;
-					MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end()));
+					if (isResponsive) {
+						MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end()));
+					}
 				}
 
 				void disconnect() { assert(false); }
@@ -164,13 +216,18 @@ class ConnectorTest : public CppUnit::TestFixture {
 
 				boost::optional<HostAddressPort> hostAddressPort;
 				std::vector<HostAddressPort> failingPorts;
+				bool isResponsive;
 		};
 
 		struct MockConnectionFactory : public ConnectionFactory {
+			MockConnectionFactory() : isResponsive(true) {
+			}
+
 			boost::shared_ptr<Connection> createConnection() {
-				return boost::shared_ptr<Connection>(new MockConnection(failingPorts));
+				return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive));
 			}
 
+			bool isResponsive;
 			std::vector<HostAddressPort> failingPorts;
 		};
 
@@ -181,6 +238,7 @@ class ConnectorTest : public CppUnit::TestFixture {
 		DummyEventLoop* eventLoop;
 		StaticDomainNameResolver* resolver;
 		MockConnectionFactory* connectionFactory;
+		DummyTimerFactory* timerFactory;
 		std::vector< boost::shared_ptr<MockConnection> > connections;
 };
 
-- 
cgit v0.10.2-6-g49f6