diff options
-rw-r--r-- | Swiften/Network/Connector.cpp | 4 | ||||
-rw-r--r-- | Swiften/Network/Connector.h | 3 | ||||
-rw-r--r-- | Swiften/Network/DummyTimerFactory.cpp | 57 | ||||
-rw-r--r-- | Swiften/Network/DummyTimerFactory.h | 22 | ||||
-rw-r--r-- | Swiften/Network/SConscript | 1 | ||||
-rw-r--r-- | Swiften/Network/StaticDomainNameResolver.cpp | 9 | ||||
-rw-r--r-- | Swiften/Network/StaticDomainNameResolver.h | 11 | ||||
-rw-r--r-- | Swiften/Network/UnitTest/ConnectorTest.cpp | 64 |
8 files changed, 167 insertions, 4 deletions
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index e424f64..9ea5a7f 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -1,62 +1,66 @@ #include "Swiften/Network/Connector.h" #include <boost/bind.hpp> #include <iostream> #include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Network/DomainNameResolver.h" #include "Swiften/Network/DomainNameAddressQuery.h" namespace Swift { Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), queriedAllHosts(true) { } +void Connector::setTimeoutMilliseconds(int milliseconds) { + timeoutMilliseconds = milliseconds; +} + void Connector::start() { //std::cout << "Connector::start()" << std::endl; assert(!currentConnection); assert(!serviceQuery); queriedAllHosts = false; serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname); serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1)); serviceQuery->run(); } void Connector::queryAddress(const String& hostname) { assert(!addressQuery); addressQuery = resolver->createAddressQuery(hostname); addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, this, _1, _2)); addressQuery->run(); } void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { //std::cout << "Received SRV results" << std::endl; serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); serviceQuery.reset(); tryNextHostname(); } void Connector::tryNextHostname() { if (queriedAllHosts) { //std::cout << "Connector::tryNextHostName(): Queried all hosts. Error." << std::endl; onConnectFinished(boost::shared_ptr<Connection>()); } else if (serviceQueryResults.empty()) { //std::cout << "Connector::tryNextHostName(): Falling back on A resolution" << std::endl; // Fall back on simple address resolving queriedAllHosts = true; queryAddress(hostname); } else { //std::cout << "Connector::tryNextHostName(): Querying next address" << std::endl; queryAddress(serviceQueryResults.front().hostname); } } void Connector::handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error) { //std::cout << "Connector::handleAddressQueryResult(): Start" << std::endl; addressQuery.reset(); if (!serviceQueryResults.empty()) { DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front(); serviceQueryResults.pop_front(); if (error) { diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h index cb885ab..6df3970 100644 --- a/Swiften/Network/Connector.h +++ b/Swiften/Network/Connector.h @@ -1,47 +1,48 @@ #pragma once #include <deque> #include <boost/signal.hpp> #include <boost/shared_ptr.hpp> #include "Swiften/Network/DomainNameServiceQuery.h" #include "Swiften/Network/Connection.h" #include "Swiften/Network/HostAddressPort.h" #include "Swiften/Base/String.h" #include "Swiften/Network/DomainNameResolveError.h" namespace Swift { class DomainNameAddressQuery; class DomainNameResolver; class ConnectionFactory; class Connector : public boost::bsignals::trackable { public: Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*); + void setTimeoutMilliseconds(int milliseconds); void start(); boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished; private: void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result); void handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error); void queryAddress(const String& hostname); void tryNextHostname(); void tryConnect(const HostAddressPort& target); void handleConnectionConnectFinished(bool error); private: String hostname; DomainNameResolver* resolver; ConnectionFactory* connectionFactory; + int timeoutMilliseconds; boost::shared_ptr<DomainNameServiceQuery> serviceQuery; std::deque<DomainNameServiceQuery::Result> serviceQueryResults; boost::shared_ptr<DomainNameAddressQuery> addressQuery; bool queriedAllHosts; boost::shared_ptr<Connection> currentConnection; }; - }; diff --git a/Swiften/Network/DummyTimerFactory.cpp b/Swiften/Network/DummyTimerFactory.cpp new file mode 100644 index 0000000..72523bb --- /dev/null +++ b/Swiften/Network/DummyTimerFactory.cpp @@ -0,0 +1,57 @@ +#include "Swiften/Network/DummyTimerFactory.h" + +#include <algorithm> + +#include "Swiften/Network/Timer.h" +#include "Swiften/Base/foreach.h" + +namespace Swift { + +class DummyTimerFactory::DummyTimer : public Timer { + public: + DummyTimer(int timeout) : timeout(timeout), isRunning(false) { + } + + virtual void start() { + isRunning = true; + } + + virtual void stop() { + isRunning = false; + } + + int timeout; + bool isRunning; +}; + + +DummyTimerFactory::DummyTimerFactory() : currentTime(0) { +} + +boost::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) { + boost::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds)); + timers.push_back(timer); + return timer; +} + +static bool hasZeroTimeout(boost::shared_ptr<DummyTimerFactory::DummyTimer> timer) { + return timer->timeout == 0; +} + +void DummyTimerFactory::setTime(int time) { + assert(time > currentTime); + int increment = currentTime - time; + std::vector< boost::shared_ptr<DummyTimer> > notifyTimers(timers.begin(), timers.end()); + foreach(boost::shared_ptr<DummyTimer> timer, notifyTimers) { + if (increment >= timer->timeout) { + if (timer->isRunning) { + timer->onTick(); + } + timer->timeout = 0; + } + } + timers.erase(std::remove_if(timers.begin(), timers.end(), hasZeroTimeout), timers.end()); + currentTime = time; +} + +} diff --git a/Swiften/Network/DummyTimerFactory.h b/Swiften/Network/DummyTimerFactory.h new file mode 100644 index 0000000..feac029 --- /dev/null +++ b/Swiften/Network/DummyTimerFactory.h @@ -0,0 +1,22 @@ +#pragma once + +#include <list> + +#include "Swiften/Network/TimerFactory.h" + +namespace Swift { + class DummyTimerFactory : public TimerFactory { + public: + class DummyTimer; + + DummyTimerFactory(); + + virtual boost::shared_ptr<Timer> createTimer(int milliseconds); + void setTime(int time); + + private: + friend class DummyTimer; + int currentTime; + std::list<boost::shared_ptr<DummyTimer> > timers; + }; +} diff --git a/Swiften/Network/SConscript b/Swiften/Network/SConscript index d63b673..767eee2 100644 --- a/Swiften/Network/SConscript +++ b/Swiften/Network/SConscript @@ -1,29 +1,30 @@ Import("swiften_env") myenv = swiften_env.Clone() myenv.MergeFlags(myenv["LIBIDN_FLAGS"]) myenv.MergeFlags(myenv["CARES_FLAGS"]) objects = myenv.StaticObject([ "BoostConnection.cpp", "BoostConnectionFactory.cpp", "BoostConnectionServer.cpp", "MainBoostIOServiceThread.cpp", "BoostIOServiceThread.cpp", "ConnectionFactory.cpp", "ConnectionServer.cpp", "Connector.cpp", "TimerFactory.cpp", + "DummyTimerFactory.cpp", "BoostTimerFactory.cpp", "DomainNameResolver.cpp", "DomainNameAddressQuery.cpp", "DomainNameServiceQuery.cpp", "PlatformDomainNameResolver.cpp", "PlatformDomainNameServiceQuery.cpp", "CAresDomainNameResolver.cpp", "StaticDomainNameResolver.cpp", "HostAddress.cpp", "Timer.cpp", "BoostTimer.cpp", ]) swiften_env.Append(SWIFTEN_OBJECTS = [objects]) diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp index 609bbdd..a7275d2 100644 --- a/Swiften/Network/StaticDomainNameResolver.cpp +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -1,76 +1,85 @@ #include "Swiften/Network/StaticDomainNameResolver.h" #include <boost/bind.hpp> #include <boost/lexical_cast.hpp> #include "Swiften/Network/DomainNameResolveError.h" #include "Swiften/Base/String.h" using namespace Swift; namespace { struct ServiceQuery : public DomainNameServiceQuery, public EventOwner { ServiceQuery(const String& service, Swift::StaticDomainNameResolver* resolver) : service(service), resolver(resolver) {} virtual void run() { + if (!resolver->getIsResponsive()) { + return; + } std::vector<DomainNameServiceQuery::Result> results; for(StaticDomainNameResolver::ServicesCollection::const_iterator i = resolver->getServices().begin(); i != resolver->getServices().end(); ++i) { if (i->first == service) { results.push_back(i->second); } } MainEventLoop::postEvent(boost::bind(boost::ref(onResult), results)); } String service; StaticDomainNameResolver* resolver; }; struct AddressQuery : public DomainNameAddressQuery, public EventOwner { AddressQuery(const String& host, StaticDomainNameResolver* resolver) : host(host), resolver(resolver) {} virtual void run() { + if (!resolver->getIsResponsive()) { + return; + } StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host); if (i != resolver->getAddresses().end()) { MainEventLoop::postEvent( boost::bind(boost::ref(onResult), i->second, boost::optional<DomainNameResolveError>())); } else { MainEventLoop::postEvent(boost::bind(boost::ref(onResult), HostAddress(), boost::optional<DomainNameResolveError>(DomainNameResolveError()))); } } String host; StaticDomainNameResolver* resolver; }; } namespace Swift { +StaticDomainNameResolver::StaticDomainNameResolver() : isResponsive(true) { +} + void StaticDomainNameResolver::addAddress(const String& domain, const HostAddress& address) { addresses[domain] = address; } void StaticDomainNameResolver::addService(const String& service, const DomainNameServiceQuery::Result& result) { services.push_back(std::make_pair(service, result)); } void StaticDomainNameResolver::addXMPPClientService(const String& domain, const HostAddressPort& address) { static int hostid = 0; String hostname(std::string("host-") + boost::lexical_cast<std::string>(hostid)); hostid++; addService("_xmpp-client._tcp." + domain, ServiceQuery::Result(hostname, address.getPort(), 0, 0)); addAddress(hostname, address.getAddress()); } boost::shared_ptr<DomainNameServiceQuery> StaticDomainNameResolver::createServiceQuery(const String& name) { return boost::shared_ptr<DomainNameServiceQuery>(new ServiceQuery(name, this)); } boost::shared_ptr<DomainNameAddressQuery> StaticDomainNameResolver::createAddressQuery(const String& name) { return boost::shared_ptr<DomainNameAddressQuery>(new AddressQuery(name, this)); } } diff --git a/Swiften/Network/StaticDomainNameResolver.h b/Swiften/Network/StaticDomainNameResolver.h index 0e877d3..d7e7ba4 100644 --- a/Swiften/Network/StaticDomainNameResolver.h +++ b/Swiften/Network/StaticDomainNameResolver.h @@ -1,41 +1,52 @@ #pragma once #include <vector> #include <map> #include "Swiften/Network/HostAddress.h" #include "Swiften/Network/HostAddressPort.h" #include "Swiften/Network/DomainNameResolver.h" #include "Swiften/Network/DomainNameServiceQuery.h" #include "Swiften/Network/DomainNameAddressQuery.h" #include "Swiften/EventLoop/MainEventLoop.h" namespace Swift { class String; class StaticDomainNameResolver : public DomainNameResolver { public: typedef std::map<String, HostAddress> AddressesMap; typedef std::vector< std::pair<String, DomainNameServiceQuery::Result> > ServicesCollection; public: + StaticDomainNameResolver(); + void addAddress(const String& domain, const HostAddress& address); void addService(const String& service, const DomainNameServiceQuery::Result& result); void addXMPPClientService(const String& domain, const HostAddressPort&); const AddressesMap& getAddresses() const { return addresses; } const ServicesCollection& getServices() const { return services; } + bool getIsResponsive() const { + return isResponsive; + } + + void setIsResponsive(bool b) { + isResponsive = b; + } + virtual boost::shared_ptr<DomainNameServiceQuery> createServiceQuery(const String& name); virtual boost::shared_ptr<DomainNameAddressQuery> createAddressQuery(const String& name); private: + bool isResponsive; AddressesMap addresses; ServicesCollection services; }; } diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp index af1ad4e..08b9bc1 100644 --- a/Swiften/Network/UnitTest/ConnectorTest.cpp +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -1,86 +1,92 @@ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <boost/optional.hpp> #include <boost/bind.hpp> #include "Swiften/Network/Connector.h" #include "Swiften/Network/Connection.h" #include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Network/HostAddressPort.h" #include "Swiften/Network/StaticDomainNameResolver.h" +#include "Swiften/Network/DummyTimerFactory.h" #include "Swiften/EventLoop/MainEventLoop.h" #include "Swiften/EventLoop/DummyEventLoop.h" using namespace Swift; class ConnectorTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(ConnectorTest); CPPUNIT_TEST(testConnect); CPPUNIT_TEST(testConnect_NoSRVHost); CPPUNIT_TEST(testConnect_NoHosts); CPPUNIT_TEST(testConnect_FirstSRVHostFails); CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost); CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost); CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail); + //CPPUNIT_TEST(testConnect_TimeoutDuringResolve); + //CPPUNIT_TEST(testConnect_TimeoutDuringConnect); + //CPPUNIT_TEST(testConnect_NoTimeout); CPPUNIT_TEST_SUITE_END(); public: ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345), host3(HostAddress("3.3.3.3"), 5222) { } void setUp() { eventLoop = new DummyEventLoop(); resolver = new StaticDomainNameResolver(); connectionFactory = new MockConnectionFactory(); + timerFactory = new DummyTimerFactory(); } void tearDown() { + delete timerFactory; delete connectionFactory; delete resolver; delete eventLoop; } void testConnect() { std::auto_ptr<Connector> testling(createConnector()); resolver->addXMPPClientService("foo.com", host1); resolver->addXMPPClientService("foo.com", host2); resolver->addAddress("foo.com", host3.getAddress()); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(connections[0]); CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort)); } void testConnect_NoSRVHost() { std::auto_ptr<Connector> testling(createConnector()); resolver->addAddress("foo.com", host3.getAddress()); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(connections[0]); CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); } void testConnect_NoHosts() { std::auto_ptr<Connector> testling(createConnector()); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(!connections[0]); } void testConnect_FirstSRVHostFails() { std::auto_ptr<Connector> testling(createConnector()); resolver->addXMPPClientService("foo.com", host1); resolver->addXMPPClientService("foo.com", host2); connectionFactory->failingPorts.push_back(host1); testling->start(); @@ -89,99 +95,151 @@ class ConnectorTest : public CppUnit::TestFixture { CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort)); } void testConnect_AllSRVHostsFailWithoutFallbackHost() { std::auto_ptr<Connector> testling(createConnector()); resolver->addXMPPClientService("foo.com", host1); resolver->addXMPPClientService("foo.com", host2); connectionFactory->failingPorts.push_back(host1); connectionFactory->failingPorts.push_back(host2); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(!connections[0]); } void testConnect_AllSRVHostsFailWithFallbackHost() { std::auto_ptr<Connector> testling(createConnector()); resolver->addXMPPClientService("foo.com", host1); resolver->addXMPPClientService("foo.com", host2); resolver->addAddress("foo.com", host3.getAddress()); connectionFactory->failingPorts.push_back(host1); connectionFactory->failingPorts.push_back(host2); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(connections[0]); CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); } void testConnect_SRVAndFallbackHostsFail() { std::auto_ptr<Connector> testling(createConnector()); resolver->addXMPPClientService("foo.com", host1); resolver->addAddress("foo.com", host3.getAddress()); connectionFactory->failingPorts.push_back(host1); connectionFactory->failingPorts.push_back(host3); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(!connections[0]); } + void testConnect_TimeoutDuringResolve() { + std::auto_ptr<Connector> testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->setIsResponsive(false); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_TimeoutDuringConnect() { + std::auto_ptr<Connector> testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + connectionFactory->isResponsive = false; + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testConnect_NoTimeout() { + std::auto_ptr<Connector> testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + + testling->start(); + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(connections[0]); + } + + private: Connector* createConnector() { Connector* connector = new Connector("foo.com", resolver, connectionFactory); connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1)); return connector; } void handleConnectorFinished(boost::shared_ptr<Connection> connection) { boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection)); if (connection) { assert(c); } connections.push_back(c); } struct MockConnection : public Connection { public: - MockConnection(const std::vector<HostAddressPort>& failingPorts) : failingPorts(failingPorts) {} + MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive) : failingPorts(failingPorts), isResponsive(isResponsive) {} void listen() { assert(false); } void connect(const HostAddressPort& address) { hostAddressPort = address; - MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end())); + if (isResponsive) { + MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end())); + } } void disconnect() { assert(false); } void write(const ByteArray&) { assert(false); } boost::optional<HostAddressPort> hostAddressPort; std::vector<HostAddressPort> failingPorts; + bool isResponsive; }; struct MockConnectionFactory : public ConnectionFactory { + MockConnectionFactory() : isResponsive(true) { + } + boost::shared_ptr<Connection> createConnection() { - return boost::shared_ptr<Connection>(new MockConnection(failingPorts)); + return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive)); } + bool isResponsive; std::vector<HostAddressPort> failingPorts; }; private: HostAddressPort host1; HostAddressPort host2; HostAddressPort host3; DummyEventLoop* eventLoop; StaticDomainNameResolver* resolver; MockConnectionFactory* connectionFactory; + DummyTimerFactory* timerFactory; std::vector< boost::shared_ptr<MockConnection> > connections; }; CPPUNIT_TEST_SUITE_REGISTRATION(ConnectorTest); |