summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/Network')
-rw-r--r--Swiften/Network/Connector.cpp27
-rw-r--r--Swiften/Network/Connector.h8
-rw-r--r--Swiften/Network/DummyTimerFactory.cpp7
-rw-r--r--Swiften/Network/UnitTest/ConnectorTest.cpp8
4 files changed, 39 insertions, 11 deletions
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp
index 9ea5a7f..d372bf2 100644
--- a/Swiften/Network/Connector.cpp
+++ b/Swiften/Network/Connector.cpp
@@ -1,107 +1,126 @@
#include "Swiften/Network/Connector.h"
#include <boost/bind.hpp>
#include <iostream>
#include "Swiften/Network/ConnectionFactory.h"
#include "Swiften/Network/DomainNameResolver.h"
#include "Swiften/Network/DomainNameAddressQuery.h"
+#include "Swiften/Network/TimerFactory.h"
namespace Swift {
-Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), queriedAllHosts(true) {
+Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllHosts(true) {
}
void Connector::setTimeoutMilliseconds(int milliseconds) {
timeoutMilliseconds = milliseconds;
}
void Connector::start() {
//std::cout << "Connector::start()" << std::endl;
assert(!currentConnection);
assert(!serviceQuery);
+ assert(!timer);
queriedAllHosts = false;
serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname);
serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1));
+ if (timeoutMilliseconds > 0) {
+ timer = timerFactory->createTimer(timeoutMilliseconds);
+ timer->onTick.connect(boost::bind(&Connector::handleTimeout, this));
+ timer->start();
+ }
serviceQuery->run();
}
void Connector::queryAddress(const String& hostname) {
assert(!addressQuery);
addressQuery = resolver->createAddressQuery(hostname);
addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, this, _1, _2));
addressQuery->run();
}
void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) {
//std::cout << "Received SRV results" << std::endl;
serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end());
serviceQuery.reset();
tryNextHostname();
}
void Connector::tryNextHostname() {
if (queriedAllHosts) {
//std::cout << "Connector::tryNextHostName(): Queried all hosts. Error." << std::endl;
- onConnectFinished(boost::shared_ptr<Connection>());
+ finish(boost::shared_ptr<Connection>());
}
else if (serviceQueryResults.empty()) {
//std::cout << "Connector::tryNextHostName(): Falling back on A resolution" << std::endl;
// Fall back on simple address resolving
queriedAllHosts = true;
queryAddress(hostname);
}
else {
//std::cout << "Connector::tryNextHostName(): Querying next address" << std::endl;
queryAddress(serviceQueryResults.front().hostname);
}
}
void Connector::handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error) {
//std::cout << "Connector::handleAddressQueryResult(): Start" << std::endl;
addressQuery.reset();
if (!serviceQueryResults.empty()) {
DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front();
serviceQueryResults.pop_front();
if (error) {
//std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " failed." << std::endl;
tryNextHostname();
}
else {
//std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " succeeded: " << address.toString() << std::endl;
tryConnect(HostAddressPort(address, serviceQueryResult.port));
}
}
else if (error) {
//std::cout << "Connector::handleAddressQueryResult(): Fallback address query failed. Giving up" << std::endl;
// The fallback address query failed
assert(queriedAllHosts);
- onConnectFinished(boost::shared_ptr<Connection>());
+ finish(boost::shared_ptr<Connection>());
}
else {
//std::cout << "Connector::handleAddressQueryResult(): Fallback address query succeeded: " << address.toString() << std::endl;
// The fallback query succeeded
tryConnect(HostAddressPort(address, 5222));
}
}
void Connector::tryConnect(const HostAddressPort& target) {
assert(!currentConnection);
//std::cout << "Connector::tryConnect() " << target.getAddress().toString() << " " << target.getPort() << std::endl;
currentConnection = connectionFactory->createConnection();
currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1));
currentConnection->connect(target);
}
void Connector::handleConnectionConnectFinished(bool error) {
//std::cout << "Connector::handleConnectionConnectFinished() " << error << std::endl;
if (error) {
currentConnection.reset();
tryNextHostname();
}
else {
- onConnectFinished(currentConnection);
+ finish(currentConnection);
+ }
+}
+
+void Connector::finish(boost::shared_ptr<Connection> connection) {
+ if (timer) {
+ timer->stop();
+ timer.reset();
}
+ onConnectFinished(connection);
+}
+
+void Connector::handleTimeout() {
+ finish(boost::shared_ptr<Connection>());
}
};
diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h
index 6df3970..507f085 100644
--- a/Swiften/Network/Connector.h
+++ b/Swiften/Network/Connector.h
@@ -1,48 +1,54 @@
#pragma once
#include <deque>
#include <boost/signal.hpp>
#include <boost/shared_ptr.hpp>
#include "Swiften/Network/DomainNameServiceQuery.h"
#include "Swiften/Network/Connection.h"
+#include "Swiften/Network/Timer.h"
#include "Swiften/Network/HostAddressPort.h"
#include "Swiften/Base/String.h"
#include "Swiften/Network/DomainNameResolveError.h"
namespace Swift {
class DomainNameAddressQuery;
class DomainNameResolver;
class ConnectionFactory;
+ class TimerFactory;
class Connector : public boost::bsignals::trackable {
public:
- Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*);
+ Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*, TimerFactory*);
void setTimeoutMilliseconds(int milliseconds);
void start();
boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished;
private:
void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result);
void handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error);
void queryAddress(const String& hostname);
void tryNextHostname();
void tryConnect(const HostAddressPort& target);
void handleConnectionConnectFinished(bool error);
+ void finish(boost::shared_ptr<Connection>);
+ void handleTimeout();
private:
String hostname;
DomainNameResolver* resolver;
ConnectionFactory* connectionFactory;
+ TimerFactory* timerFactory;
int timeoutMilliseconds;
+ boost::shared_ptr<Timer> timer;
boost::shared_ptr<DomainNameServiceQuery> serviceQuery;
std::deque<DomainNameServiceQuery::Result> serviceQueryResults;
boost::shared_ptr<DomainNameAddressQuery> addressQuery;
bool queriedAllHosts;
boost::shared_ptr<Connection> currentConnection;
};
};
diff --git a/Swiften/Network/DummyTimerFactory.cpp b/Swiften/Network/DummyTimerFactory.cpp
index 72523bb..7626584 100644
--- a/Swiften/Network/DummyTimerFactory.cpp
+++ b/Swiften/Network/DummyTimerFactory.cpp
@@ -1,57 +1,60 @@
#include "Swiften/Network/DummyTimerFactory.h"
#include <algorithm>
-#include "Swiften/Network/Timer.h"
#include "Swiften/Base/foreach.h"
+#include "Swiften/Network/Timer.h"
namespace Swift {
class DummyTimerFactory::DummyTimer : public Timer {
public:
DummyTimer(int timeout) : timeout(timeout), isRunning(false) {
}
virtual void start() {
isRunning = true;
}
virtual void stop() {
isRunning = false;
}
int timeout;
bool isRunning;
};
DummyTimerFactory::DummyTimerFactory() : currentTime(0) {
}
boost::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) {
boost::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds));
timers.push_back(timer);
return timer;
}
static bool hasZeroTimeout(boost::shared_ptr<DummyTimerFactory::DummyTimer> timer) {
return timer->timeout == 0;
}
void DummyTimerFactory::setTime(int time) {
assert(time > currentTime);
- int increment = currentTime - time;
+ int increment = time - currentTime;
std::vector< boost::shared_ptr<DummyTimer> > notifyTimers(timers.begin(), timers.end());
foreach(boost::shared_ptr<DummyTimer> timer, notifyTimers) {
if (increment >= timer->timeout) {
if (timer->isRunning) {
timer->onTick();
}
timer->timeout = 0;
}
+ else {
+ timer->timeout -= increment;
+ }
}
timers.erase(std::remove_if(timers.begin(), timers.end(), hasZeroTimeout), timers.end());
currentTime = time;
}
}
diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp
index 08b9bc1..663011c 100644
--- a/Swiften/Network/UnitTest/ConnectorTest.cpp
+++ b/Swiften/Network/UnitTest/ConnectorTest.cpp
@@ -1,77 +1,77 @@
#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <boost/optional.hpp>
#include <boost/bind.hpp>
#include "Swiften/Network/Connector.h"
#include "Swiften/Network/Connection.h"
#include "Swiften/Network/ConnectionFactory.h"
#include "Swiften/Network/HostAddressPort.h"
#include "Swiften/Network/StaticDomainNameResolver.h"
#include "Swiften/Network/DummyTimerFactory.h"
#include "Swiften/EventLoop/MainEventLoop.h"
#include "Swiften/EventLoop/DummyEventLoop.h"
using namespace Swift;
class ConnectorTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(ConnectorTest);
CPPUNIT_TEST(testConnect);
CPPUNIT_TEST(testConnect_NoSRVHost);
CPPUNIT_TEST(testConnect_NoHosts);
CPPUNIT_TEST(testConnect_FirstSRVHostFails);
CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost);
CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost);
CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail);
- //CPPUNIT_TEST(testConnect_TimeoutDuringResolve);
- //CPPUNIT_TEST(testConnect_TimeoutDuringConnect);
- //CPPUNIT_TEST(testConnect_NoTimeout);
+ CPPUNIT_TEST(testConnect_TimeoutDuringResolve);
+ CPPUNIT_TEST(testConnect_TimeoutDuringConnect);
+ CPPUNIT_TEST(testConnect_NoTimeout);
CPPUNIT_TEST_SUITE_END();
public:
ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345), host3(HostAddress("3.3.3.3"), 5222) {
}
void setUp() {
eventLoop = new DummyEventLoop();
resolver = new StaticDomainNameResolver();
connectionFactory = new MockConnectionFactory();
timerFactory = new DummyTimerFactory();
}
void tearDown() {
delete timerFactory;
delete connectionFactory;
delete resolver;
delete eventLoop;
}
void testConnect() {
std::auto_ptr<Connector> testling(createConnector());
resolver->addXMPPClientService("foo.com", host1);
resolver->addXMPPClientService("foo.com", host2);
resolver->addAddress("foo.com", host3.getAddress());
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort));
}
void testConnect_NoSRVHost() {
std::auto_ptr<Connector> testling(createConnector());
resolver->addAddress("foo.com", host3.getAddress());
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort));
}
void testConnect_NoHosts() {
std::auto_ptr<Connector> testling(createConnector());
@@ -141,97 +141,97 @@ class ConnectorTest : public CppUnit::TestFixture {
}
void testConnect_TimeoutDuringResolve() {
std::auto_ptr<Connector> testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->setIsResponsive(false);
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
}
void testConnect_TimeoutDuringConnect() {
std::auto_ptr<Connector> testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->addXMPPClientService("foo.com", host1);
connectionFactory->isResponsive = false;
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
}
void testConnect_NoTimeout() {
std::auto_ptr<Connector> testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->addXMPPClientService("foo.com", host1);
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
}
private:
Connector* createConnector() {
- Connector* connector = new Connector("foo.com", resolver, connectionFactory);
+ Connector* connector = new Connector("foo.com", resolver, connectionFactory, timerFactory);
connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1));
return connector;
}
void handleConnectorFinished(boost::shared_ptr<Connection> connection) {
boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection));
if (connection) {
assert(c);
}
connections.push_back(c);
}
struct MockConnection : public Connection {
public:
MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive) : failingPorts(failingPorts), isResponsive(isResponsive) {}
void listen() { assert(false); }
void connect(const HostAddressPort& address) {
hostAddressPort = address;
if (isResponsive) {
MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end()));
}
}
void disconnect() { assert(false); }
void write(const ByteArray&) { assert(false); }
boost::optional<HostAddressPort> hostAddressPort;
std::vector<HostAddressPort> failingPorts;
bool isResponsive;
};
struct MockConnectionFactory : public ConnectionFactory {
MockConnectionFactory() : isResponsive(true) {
}
boost::shared_ptr<Connection> createConnection() {
return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive));
}
bool isResponsive;
std::vector<HostAddressPort> failingPorts;
};
private:
HostAddressPort host1;
HostAddressPort host2;
HostAddressPort host3;