diff options
-rw-r--r-- | Swiften/Client/Client.cpp | 3 | ||||
-rw-r--r-- | Swiften/Elements/MUCPayload.h | 13 | ||||
-rw-r--r-- | Swiften/Network/Connector.cpp | 27 | ||||
-rw-r--r-- | Swiften/Network/Connector.h | 8 | ||||
-rw-r--r-- | Swiften/Network/DummyTimerFactory.cpp | 7 | ||||
-rw-r--r-- | Swiften/Network/UnitTest/ConnectorTest.cpp | 8 |
6 files changed, 45 insertions, 21 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 27f3d9c..c704248 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -1,88 +1,89 @@ #include "Swiften/Client/Client.h" #include <boost/bind.hpp> #include "Swiften/Network/MainBoostIOServiceThread.h" #include "Swiften/Network/BoostIOServiceThread.h" #include "Swiften/Client/ClientSession.h" #include "Swiften/StreamStack/PlatformTLSLayerFactory.h" #include "Swiften/Network/Connector.h" #include "Swiften/Network/BoostConnectionFactory.h" #include "Swiften/Network/BoostTimerFactory.h" #include "Swiften/TLS/PKCS12Certificate.h" #include "Swiften/Session/BasicSessionStream.h" namespace Swift { Client::Client(const JID& jid, const String& password) : IQRouter(this), jid_(jid), password_(password) { connectionFactory_ = new BoostConnectionFactory(&MainBoostIOServiceThread::getInstance().getIOService()); timerFactory_ = new BoostTimerFactory(&MainBoostIOServiceThread::getInstance().getIOService()); tlsLayerFactory_ = new PlatformTLSLayerFactory(); } Client::~Client() { if (session_ || connection_) { std::cerr << "Warning: Client not disconnected properly" << std::endl; } delete tlsLayerFactory_; delete timerFactory_; delete connectionFactory_; } bool Client::isAvailable() { return session_; } void Client::connect() { assert(!connector_); - connector_ = boost::shared_ptr<Connector>(new Connector(jid_.getDomain(), &resolver_, connectionFactory_)); + connector_ = boost::shared_ptr<Connector>(new Connector(jid_.getDomain(), &resolver_, connectionFactory_, timerFactory_)); connector_->onConnectFinished.connect(boost::bind(&Client::handleConnectorFinished, this, _1)); + connector_->setTimeoutMilliseconds(60*1000); connector_->start(); } void Client::handleConnectorFinished(boost::shared_ptr<Connection> connection) { // TODO: Add domain name resolver error connector_.reset(); if (!connection) { onError(ClientError::ConnectionError); } else { assert(!connection_); connection_ = connection; assert(!sessionStream_); sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_, timerFactory_)); if (!certificate_.isEmpty()) { sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); } sessionStream_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); sessionStream_->initialize(); session_ = ClientSession::create(jid_, sessionStream_); session_->onInitialized.connect(boost::bind(boost::ref(onConnected))); session_->onFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); session_->start(); } } void Client::disconnect() { if (session_) { session_->finish(); } else { closeConnection(); } } void Client::closeConnection() { if (sessionStream_) { sessionStream_.reset(); } if (connection_) { connection_->disconnect(); connection_.reset(); } diff --git a/Swiften/Elements/MUCPayload.h b/Swiften/Elements/MUCPayload.h index 205ae46..97932a1 100644 --- a/Swiften/Elements/MUCPayload.h +++ b/Swiften/Elements/MUCPayload.h @@ -1,16 +1,11 @@ -#ifndef SWIFTEN_MUCPayload_H -#define SWIFTEN_MUCPayload_H +#pragma once -#include "Swiften/Base/String.h" #include "Swiften/Elements/Payload.h" namespace Swift { - class MUCPayload : public Payload - { + class MUCPayload : public Payload { public: - MUCPayload() { } - + MUCPayload() { + } }; } - -#endif diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index 9ea5a7f..d372bf2 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -1,107 +1,126 @@ #include "Swiften/Network/Connector.h" #include <boost/bind.hpp> #include <iostream> #include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Network/DomainNameResolver.h" #include "Swiften/Network/DomainNameAddressQuery.h" +#include "Swiften/Network/TimerFactory.h" namespace Swift { -Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), queriedAllHosts(true) { +Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllHosts(true) { } void Connector::setTimeoutMilliseconds(int milliseconds) { timeoutMilliseconds = milliseconds; } void Connector::start() { //std::cout << "Connector::start()" << std::endl; assert(!currentConnection); assert(!serviceQuery); + assert(!timer); queriedAllHosts = false; serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname); serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1)); + if (timeoutMilliseconds > 0) { + timer = timerFactory->createTimer(timeoutMilliseconds); + timer->onTick.connect(boost::bind(&Connector::handleTimeout, this)); + timer->start(); + } serviceQuery->run(); } void Connector::queryAddress(const String& hostname) { assert(!addressQuery); addressQuery = resolver->createAddressQuery(hostname); addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, this, _1, _2)); addressQuery->run(); } void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) { //std::cout << "Received SRV results" << std::endl; serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end()); serviceQuery.reset(); tryNextHostname(); } void Connector::tryNextHostname() { if (queriedAllHosts) { //std::cout << "Connector::tryNextHostName(): Queried all hosts. Error." << std::endl; - onConnectFinished(boost::shared_ptr<Connection>()); + finish(boost::shared_ptr<Connection>()); } else if (serviceQueryResults.empty()) { //std::cout << "Connector::tryNextHostName(): Falling back on A resolution" << std::endl; // Fall back on simple address resolving queriedAllHosts = true; queryAddress(hostname); } else { //std::cout << "Connector::tryNextHostName(): Querying next address" << std::endl; queryAddress(serviceQueryResults.front().hostname); } } void Connector::handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error) { //std::cout << "Connector::handleAddressQueryResult(): Start" << std::endl; addressQuery.reset(); if (!serviceQueryResults.empty()) { DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front(); serviceQueryResults.pop_front(); if (error) { //std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " failed." << std::endl; tryNextHostname(); } else { //std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " succeeded: " << address.toString() << std::endl; tryConnect(HostAddressPort(address, serviceQueryResult.port)); } } else if (error) { //std::cout << "Connector::handleAddressQueryResult(): Fallback address query failed. Giving up" << std::endl; // The fallback address query failed assert(queriedAllHosts); - onConnectFinished(boost::shared_ptr<Connection>()); + finish(boost::shared_ptr<Connection>()); } else { //std::cout << "Connector::handleAddressQueryResult(): Fallback address query succeeded: " << address.toString() << std::endl; // The fallback query succeeded tryConnect(HostAddressPort(address, 5222)); } } void Connector::tryConnect(const HostAddressPort& target) { assert(!currentConnection); //std::cout << "Connector::tryConnect() " << target.getAddress().toString() << " " << target.getPort() << std::endl; currentConnection = connectionFactory->createConnection(); currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1)); currentConnection->connect(target); } void Connector::handleConnectionConnectFinished(bool error) { //std::cout << "Connector::handleConnectionConnectFinished() " << error << std::endl; if (error) { currentConnection.reset(); tryNextHostname(); } else { - onConnectFinished(currentConnection); + finish(currentConnection); + } +} + +void Connector::finish(boost::shared_ptr<Connection> connection) { + if (timer) { + timer->stop(); + timer.reset(); } + onConnectFinished(connection); +} + +void Connector::handleTimeout() { + finish(boost::shared_ptr<Connection>()); } }; diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h index 6df3970..507f085 100644 --- a/Swiften/Network/Connector.h +++ b/Swiften/Network/Connector.h @@ -1,48 +1,54 @@ #pragma once #include <deque> #include <boost/signal.hpp> #include <boost/shared_ptr.hpp> #include "Swiften/Network/DomainNameServiceQuery.h" #include "Swiften/Network/Connection.h" +#include "Swiften/Network/Timer.h" #include "Swiften/Network/HostAddressPort.h" #include "Swiften/Base/String.h" #include "Swiften/Network/DomainNameResolveError.h" namespace Swift { class DomainNameAddressQuery; class DomainNameResolver; class ConnectionFactory; + class TimerFactory; class Connector : public boost::bsignals::trackable { public: - Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*); + Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*, TimerFactory*); void setTimeoutMilliseconds(int milliseconds); void start(); boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished; private: void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result); void handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error); void queryAddress(const String& hostname); void tryNextHostname(); void tryConnect(const HostAddressPort& target); void handleConnectionConnectFinished(bool error); + void finish(boost::shared_ptr<Connection>); + void handleTimeout(); private: String hostname; DomainNameResolver* resolver; ConnectionFactory* connectionFactory; + TimerFactory* timerFactory; int timeoutMilliseconds; + boost::shared_ptr<Timer> timer; boost::shared_ptr<DomainNameServiceQuery> serviceQuery; std::deque<DomainNameServiceQuery::Result> serviceQueryResults; boost::shared_ptr<DomainNameAddressQuery> addressQuery; bool queriedAllHosts; boost::shared_ptr<Connection> currentConnection; }; }; diff --git a/Swiften/Network/DummyTimerFactory.cpp b/Swiften/Network/DummyTimerFactory.cpp index 72523bb..7626584 100644 --- a/Swiften/Network/DummyTimerFactory.cpp +++ b/Swiften/Network/DummyTimerFactory.cpp @@ -1,57 +1,60 @@ #include "Swiften/Network/DummyTimerFactory.h" #include <algorithm> -#include "Swiften/Network/Timer.h" #include "Swiften/Base/foreach.h" +#include "Swiften/Network/Timer.h" namespace Swift { class DummyTimerFactory::DummyTimer : public Timer { public: DummyTimer(int timeout) : timeout(timeout), isRunning(false) { } virtual void start() { isRunning = true; } virtual void stop() { isRunning = false; } int timeout; bool isRunning; }; DummyTimerFactory::DummyTimerFactory() : currentTime(0) { } boost::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) { boost::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds)); timers.push_back(timer); return timer; } static bool hasZeroTimeout(boost::shared_ptr<DummyTimerFactory::DummyTimer> timer) { return timer->timeout == 0; } void DummyTimerFactory::setTime(int time) { assert(time > currentTime); - int increment = currentTime - time; + int increment = time - currentTime; std::vector< boost::shared_ptr<DummyTimer> > notifyTimers(timers.begin(), timers.end()); foreach(boost::shared_ptr<DummyTimer> timer, notifyTimers) { if (increment >= timer->timeout) { if (timer->isRunning) { timer->onTick(); } timer->timeout = 0; } + else { + timer->timeout -= increment; + } } timers.erase(std::remove_if(timers.begin(), timers.end(), hasZeroTimeout), timers.end()); currentTime = time; } } diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp index 08b9bc1..663011c 100644 --- a/Swiften/Network/UnitTest/ConnectorTest.cpp +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -1,77 +1,77 @@ #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <boost/optional.hpp> #include <boost/bind.hpp> #include "Swiften/Network/Connector.h" #include "Swiften/Network/Connection.h" #include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Network/HostAddressPort.h" #include "Swiften/Network/StaticDomainNameResolver.h" #include "Swiften/Network/DummyTimerFactory.h" #include "Swiften/EventLoop/MainEventLoop.h" #include "Swiften/EventLoop/DummyEventLoop.h" using namespace Swift; class ConnectorTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(ConnectorTest); CPPUNIT_TEST(testConnect); CPPUNIT_TEST(testConnect_NoSRVHost); CPPUNIT_TEST(testConnect_NoHosts); CPPUNIT_TEST(testConnect_FirstSRVHostFails); CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost); CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost); CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail); - //CPPUNIT_TEST(testConnect_TimeoutDuringResolve); - //CPPUNIT_TEST(testConnect_TimeoutDuringConnect); - //CPPUNIT_TEST(testConnect_NoTimeout); + CPPUNIT_TEST(testConnect_TimeoutDuringResolve); + CPPUNIT_TEST(testConnect_TimeoutDuringConnect); + CPPUNIT_TEST(testConnect_NoTimeout); CPPUNIT_TEST_SUITE_END(); public: ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345), host3(HostAddress("3.3.3.3"), 5222) { } void setUp() { eventLoop = new DummyEventLoop(); resolver = new StaticDomainNameResolver(); connectionFactory = new MockConnectionFactory(); timerFactory = new DummyTimerFactory(); } void tearDown() { delete timerFactory; delete connectionFactory; delete resolver; delete eventLoop; } void testConnect() { std::auto_ptr<Connector> testling(createConnector()); resolver->addXMPPClientService("foo.com", host1); resolver->addXMPPClientService("foo.com", host2); resolver->addAddress("foo.com", host3.getAddress()); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(connections[0]); CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort)); } void testConnect_NoSRVHost() { std::auto_ptr<Connector> testling(createConnector()); resolver->addAddress("foo.com", host3.getAddress()); testling->start(); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(connections[0]); CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort)); } void testConnect_NoHosts() { std::auto_ptr<Connector> testling(createConnector()); @@ -141,97 +141,97 @@ class ConnectorTest : public CppUnit::TestFixture { } void testConnect_TimeoutDuringResolve() { std::auto_ptr<Connector> testling(createConnector()); testling->setTimeoutMilliseconds(10); resolver->setIsResponsive(false); testling->start(); eventLoop->processEvents(); timerFactory->setTime(10); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(!connections[0]); } void testConnect_TimeoutDuringConnect() { std::auto_ptr<Connector> testling(createConnector()); testling->setTimeoutMilliseconds(10); resolver->addXMPPClientService("foo.com", host1); connectionFactory->isResponsive = false; testling->start(); eventLoop->processEvents(); timerFactory->setTime(10); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(!connections[0]); } void testConnect_NoTimeout() { std::auto_ptr<Connector> testling(createConnector()); testling->setTimeoutMilliseconds(10); resolver->addXMPPClientService("foo.com", host1); testling->start(); eventLoop->processEvents(); timerFactory->setTime(10); eventLoop->processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); CPPUNIT_ASSERT(connections[0]); } private: Connector* createConnector() { - Connector* connector = new Connector("foo.com", resolver, connectionFactory); + Connector* connector = new Connector("foo.com", resolver, connectionFactory, timerFactory); connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1)); return connector; } void handleConnectorFinished(boost::shared_ptr<Connection> connection) { boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection)); if (connection) { assert(c); } connections.push_back(c); } struct MockConnection : public Connection { public: MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive) : failingPorts(failingPorts), isResponsive(isResponsive) {} void listen() { assert(false); } void connect(const HostAddressPort& address) { hostAddressPort = address; if (isResponsive) { MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end())); } } void disconnect() { assert(false); } void write(const ByteArray&) { assert(false); } boost::optional<HostAddressPort> hostAddressPort; std::vector<HostAddressPort> failingPorts; bool isResponsive; }; struct MockConnectionFactory : public ConnectionFactory { MockConnectionFactory() : isResponsive(true) { } boost::shared_ptr<Connection> createConnection() { return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive)); } bool isResponsive; std::vector<HostAddressPort> failingPorts; }; private: HostAddressPort host1; HostAddressPort host2; HostAddressPort host3; |