summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Swiften/Client/Client.cpp3
-rw-r--r--Swiften/Elements/MUCPayload.h13
-rw-r--r--Swiften/Network/Connector.cpp27
-rw-r--r--Swiften/Network/Connector.h8
-rw-r--r--Swiften/Network/DummyTimerFactory.cpp7
-rw-r--r--Swiften/Network/UnitTest/ConnectorTest.cpp8
6 files changed, 45 insertions, 21 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp
index 27f3d9c..c704248 100644
--- a/Swiften/Client/Client.cpp
+++ b/Swiften/Client/Client.cpp
@@ -1,88 +1,89 @@
#include "Swiften/Client/Client.h"
#include <boost/bind.hpp>
#include "Swiften/Network/MainBoostIOServiceThread.h"
#include "Swiften/Network/BoostIOServiceThread.h"
#include "Swiften/Client/ClientSession.h"
#include "Swiften/StreamStack/PlatformTLSLayerFactory.h"
#include "Swiften/Network/Connector.h"
#include "Swiften/Network/BoostConnectionFactory.h"
#include "Swiften/Network/BoostTimerFactory.h"
#include "Swiften/TLS/PKCS12Certificate.h"
#include "Swiften/Session/BasicSessionStream.h"
namespace Swift {
Client::Client(const JID& jid, const String& password) :
IQRouter(this), jid_(jid), password_(password) {
connectionFactory_ = new BoostConnectionFactory(&MainBoostIOServiceThread::getInstance().getIOService());
timerFactory_ = new BoostTimerFactory(&MainBoostIOServiceThread::getInstance().getIOService());
tlsLayerFactory_ = new PlatformTLSLayerFactory();
}
Client::~Client() {
if (session_ || connection_) {
std::cerr << "Warning: Client not disconnected properly" << std::endl;
}
delete tlsLayerFactory_;
delete timerFactory_;
delete connectionFactory_;
}
bool Client::isAvailable() {
return session_;
}
void Client::connect() {
assert(!connector_);
- connector_ = boost::shared_ptr<Connector>(new Connector(jid_.getDomain(), &resolver_, connectionFactory_));
+ connector_ = boost::shared_ptr<Connector>(new Connector(jid_.getDomain(), &resolver_, connectionFactory_, timerFactory_));
connector_->onConnectFinished.connect(boost::bind(&Client::handleConnectorFinished, this, _1));
+ connector_->setTimeoutMilliseconds(60*1000);
connector_->start();
}
void Client::handleConnectorFinished(boost::shared_ptr<Connection> connection) {
// TODO: Add domain name resolver error
connector_.reset();
if (!connection) {
onError(ClientError::ConnectionError);
}
else {
assert(!connection_);
connection_ = connection;
assert(!sessionStream_);
sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_, timerFactory_));
if (!certificate_.isEmpty()) {
sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_));
}
sessionStream_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1));
sessionStream_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1));
sessionStream_->initialize();
session_ = ClientSession::create(jid_, sessionStream_);
session_->onInitialized.connect(boost::bind(boost::ref(onConnected)));
session_->onFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1));
session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this));
session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1));
session_->start();
}
}
void Client::disconnect() {
if (session_) {
session_->finish();
}
else {
closeConnection();
}
}
void Client::closeConnection() {
if (sessionStream_) {
sessionStream_.reset();
}
if (connection_) {
connection_->disconnect();
connection_.reset();
}
diff --git a/Swiften/Elements/MUCPayload.h b/Swiften/Elements/MUCPayload.h
index 205ae46..97932a1 100644
--- a/Swiften/Elements/MUCPayload.h
+++ b/Swiften/Elements/MUCPayload.h
@@ -1,16 +1,11 @@
-#ifndef SWIFTEN_MUCPayload_H
-#define SWIFTEN_MUCPayload_H
+#pragma once
-#include "Swiften/Base/String.h"
#include "Swiften/Elements/Payload.h"
namespace Swift {
- class MUCPayload : public Payload
- {
+ class MUCPayload : public Payload {
public:
- MUCPayload() { }
-
+ MUCPayload() {
+ }
};
}
-
-#endif
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp
index 9ea5a7f..d372bf2 100644
--- a/Swiften/Network/Connector.cpp
+++ b/Swiften/Network/Connector.cpp
@@ -1,107 +1,126 @@
#include "Swiften/Network/Connector.h"
#include <boost/bind.hpp>
#include <iostream>
#include "Swiften/Network/ConnectionFactory.h"
#include "Swiften/Network/DomainNameResolver.h"
#include "Swiften/Network/DomainNameAddressQuery.h"
+#include "Swiften/Network/TimerFactory.h"
namespace Swift {
-Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), queriedAllHosts(true) {
+Connector::Connector(const String& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), timeoutMilliseconds(0), queriedAllHosts(true) {
}
void Connector::setTimeoutMilliseconds(int milliseconds) {
timeoutMilliseconds = milliseconds;
}
void Connector::start() {
//std::cout << "Connector::start()" << std::endl;
assert(!currentConnection);
assert(!serviceQuery);
+ assert(!timer);
queriedAllHosts = false;
serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname);
serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, this, _1));
+ if (timeoutMilliseconds > 0) {
+ timer = timerFactory->createTimer(timeoutMilliseconds);
+ timer->onTick.connect(boost::bind(&Connector::handleTimeout, this));
+ timer->start();
+ }
serviceQuery->run();
}
void Connector::queryAddress(const String& hostname) {
assert(!addressQuery);
addressQuery = resolver->createAddressQuery(hostname);
addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, this, _1, _2));
addressQuery->run();
}
void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) {
//std::cout << "Received SRV results" << std::endl;
serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end());
serviceQuery.reset();
tryNextHostname();
}
void Connector::tryNextHostname() {
if (queriedAllHosts) {
//std::cout << "Connector::tryNextHostName(): Queried all hosts. Error." << std::endl;
- onConnectFinished(boost::shared_ptr<Connection>());
+ finish(boost::shared_ptr<Connection>());
}
else if (serviceQueryResults.empty()) {
//std::cout << "Connector::tryNextHostName(): Falling back on A resolution" << std::endl;
// Fall back on simple address resolving
queriedAllHosts = true;
queryAddress(hostname);
}
else {
//std::cout << "Connector::tryNextHostName(): Querying next address" << std::endl;
queryAddress(serviceQueryResults.front().hostname);
}
}
void Connector::handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error) {
//std::cout << "Connector::handleAddressQueryResult(): Start" << std::endl;
addressQuery.reset();
if (!serviceQueryResults.empty()) {
DomainNameServiceQuery::Result serviceQueryResult = serviceQueryResults.front();
serviceQueryResults.pop_front();
if (error) {
//std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " failed." << std::endl;
tryNextHostname();
}
else {
//std::cout << "Connector::handleAddressQueryResult(): A lookup for SRV host " << serviceQueryResult.hostname << " succeeded: " << address.toString() << std::endl;
tryConnect(HostAddressPort(address, serviceQueryResult.port));
}
}
else if (error) {
//std::cout << "Connector::handleAddressQueryResult(): Fallback address query failed. Giving up" << std::endl;
// The fallback address query failed
assert(queriedAllHosts);
- onConnectFinished(boost::shared_ptr<Connection>());
+ finish(boost::shared_ptr<Connection>());
}
else {
//std::cout << "Connector::handleAddressQueryResult(): Fallback address query succeeded: " << address.toString() << std::endl;
// The fallback query succeeded
tryConnect(HostAddressPort(address, 5222));
}
}
void Connector::tryConnect(const HostAddressPort& target) {
assert(!currentConnection);
//std::cout << "Connector::tryConnect() " << target.getAddress().toString() << " " << target.getPort() << std::endl;
currentConnection = connectionFactory->createConnection();
currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, this, _1));
currentConnection->connect(target);
}
void Connector::handleConnectionConnectFinished(bool error) {
//std::cout << "Connector::handleConnectionConnectFinished() " << error << std::endl;
if (error) {
currentConnection.reset();
tryNextHostname();
}
else {
- onConnectFinished(currentConnection);
+ finish(currentConnection);
+ }
+}
+
+void Connector::finish(boost::shared_ptr<Connection> connection) {
+ if (timer) {
+ timer->stop();
+ timer.reset();
}
+ onConnectFinished(connection);
+}
+
+void Connector::handleTimeout() {
+ finish(boost::shared_ptr<Connection>());
}
};
diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h
index 6df3970..507f085 100644
--- a/Swiften/Network/Connector.h
+++ b/Swiften/Network/Connector.h
@@ -1,48 +1,54 @@
#pragma once
#include <deque>
#include <boost/signal.hpp>
#include <boost/shared_ptr.hpp>
#include "Swiften/Network/DomainNameServiceQuery.h"
#include "Swiften/Network/Connection.h"
+#include "Swiften/Network/Timer.h"
#include "Swiften/Network/HostAddressPort.h"
#include "Swiften/Base/String.h"
#include "Swiften/Network/DomainNameResolveError.h"
namespace Swift {
class DomainNameAddressQuery;
class DomainNameResolver;
class ConnectionFactory;
+ class TimerFactory;
class Connector : public boost::bsignals::trackable {
public:
- Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*);
+ Connector(const String& hostname, DomainNameResolver*, ConnectionFactory*, TimerFactory*);
void setTimeoutMilliseconds(int milliseconds);
void start();
boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished;
private:
void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result);
void handleAddressQueryResult(const HostAddress& address, boost::optional<DomainNameResolveError> error);
void queryAddress(const String& hostname);
void tryNextHostname();
void tryConnect(const HostAddressPort& target);
void handleConnectionConnectFinished(bool error);
+ void finish(boost::shared_ptr<Connection>);
+ void handleTimeout();
private:
String hostname;
DomainNameResolver* resolver;
ConnectionFactory* connectionFactory;
+ TimerFactory* timerFactory;
int timeoutMilliseconds;
+ boost::shared_ptr<Timer> timer;
boost::shared_ptr<DomainNameServiceQuery> serviceQuery;
std::deque<DomainNameServiceQuery::Result> serviceQueryResults;
boost::shared_ptr<DomainNameAddressQuery> addressQuery;
bool queriedAllHosts;
boost::shared_ptr<Connection> currentConnection;
};
};
diff --git a/Swiften/Network/DummyTimerFactory.cpp b/Swiften/Network/DummyTimerFactory.cpp
index 72523bb..7626584 100644
--- a/Swiften/Network/DummyTimerFactory.cpp
+++ b/Swiften/Network/DummyTimerFactory.cpp
@@ -1,57 +1,60 @@
#include "Swiften/Network/DummyTimerFactory.h"
#include <algorithm>
-#include "Swiften/Network/Timer.h"
#include "Swiften/Base/foreach.h"
+#include "Swiften/Network/Timer.h"
namespace Swift {
class DummyTimerFactory::DummyTimer : public Timer {
public:
DummyTimer(int timeout) : timeout(timeout), isRunning(false) {
}
virtual void start() {
isRunning = true;
}
virtual void stop() {
isRunning = false;
}
int timeout;
bool isRunning;
};
DummyTimerFactory::DummyTimerFactory() : currentTime(0) {
}
boost::shared_ptr<Timer> DummyTimerFactory::createTimer(int milliseconds) {
boost::shared_ptr<DummyTimer> timer(new DummyTimer(milliseconds));
timers.push_back(timer);
return timer;
}
static bool hasZeroTimeout(boost::shared_ptr<DummyTimerFactory::DummyTimer> timer) {
return timer->timeout == 0;
}
void DummyTimerFactory::setTime(int time) {
assert(time > currentTime);
- int increment = currentTime - time;
+ int increment = time - currentTime;
std::vector< boost::shared_ptr<DummyTimer> > notifyTimers(timers.begin(), timers.end());
foreach(boost::shared_ptr<DummyTimer> timer, notifyTimers) {
if (increment >= timer->timeout) {
if (timer->isRunning) {
timer->onTick();
}
timer->timeout = 0;
}
+ else {
+ timer->timeout -= increment;
+ }
}
timers.erase(std::remove_if(timers.begin(), timers.end(), hasZeroTimeout), timers.end());
currentTime = time;
}
}
diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp
index 08b9bc1..663011c 100644
--- a/Swiften/Network/UnitTest/ConnectorTest.cpp
+++ b/Swiften/Network/UnitTest/ConnectorTest.cpp
@@ -1,77 +1,77 @@
#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <boost/optional.hpp>
#include <boost/bind.hpp>
#include "Swiften/Network/Connector.h"
#include "Swiften/Network/Connection.h"
#include "Swiften/Network/ConnectionFactory.h"
#include "Swiften/Network/HostAddressPort.h"
#include "Swiften/Network/StaticDomainNameResolver.h"
#include "Swiften/Network/DummyTimerFactory.h"
#include "Swiften/EventLoop/MainEventLoop.h"
#include "Swiften/EventLoop/DummyEventLoop.h"
using namespace Swift;
class ConnectorTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(ConnectorTest);
CPPUNIT_TEST(testConnect);
CPPUNIT_TEST(testConnect_NoSRVHost);
CPPUNIT_TEST(testConnect_NoHosts);
CPPUNIT_TEST(testConnect_FirstSRVHostFails);
CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost);
CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost);
CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail);
- //CPPUNIT_TEST(testConnect_TimeoutDuringResolve);
- //CPPUNIT_TEST(testConnect_TimeoutDuringConnect);
- //CPPUNIT_TEST(testConnect_NoTimeout);
+ CPPUNIT_TEST(testConnect_TimeoutDuringResolve);
+ CPPUNIT_TEST(testConnect_TimeoutDuringConnect);
+ CPPUNIT_TEST(testConnect_NoTimeout);
CPPUNIT_TEST_SUITE_END();
public:
ConnectorTest() : host1(HostAddress("1.1.1.1"), 1234), host2(HostAddress("2.2.2.2"), 2345), host3(HostAddress("3.3.3.3"), 5222) {
}
void setUp() {
eventLoop = new DummyEventLoop();
resolver = new StaticDomainNameResolver();
connectionFactory = new MockConnectionFactory();
timerFactory = new DummyTimerFactory();
}
void tearDown() {
delete timerFactory;
delete connectionFactory;
delete resolver;
delete eventLoop;
}
void testConnect() {
std::auto_ptr<Connector> testling(createConnector());
resolver->addXMPPClientService("foo.com", host1);
resolver->addXMPPClientService("foo.com", host2);
resolver->addAddress("foo.com", host3.getAddress());
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort));
}
void testConnect_NoSRVHost() {
std::auto_ptr<Connector> testling(createConnector());
resolver->addAddress("foo.com", host3.getAddress());
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort));
}
void testConnect_NoHosts() {
std::auto_ptr<Connector> testling(createConnector());
@@ -141,97 +141,97 @@ class ConnectorTest : public CppUnit::TestFixture {
}
void testConnect_TimeoutDuringResolve() {
std::auto_ptr<Connector> testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->setIsResponsive(false);
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
}
void testConnect_TimeoutDuringConnect() {
std::auto_ptr<Connector> testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->addXMPPClientService("foo.com", host1);
connectionFactory->isResponsive = false;
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
}
void testConnect_NoTimeout() {
std::auto_ptr<Connector> testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->addXMPPClientService("foo.com", host1);
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
}
private:
Connector* createConnector() {
- Connector* connector = new Connector("foo.com", resolver, connectionFactory);
+ Connector* connector = new Connector("foo.com", resolver, connectionFactory, timerFactory);
connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1));
return connector;
}
void handleConnectorFinished(boost::shared_ptr<Connection> connection) {
boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection));
if (connection) {
assert(c);
}
connections.push_back(c);
}
struct MockConnection : public Connection {
public:
MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive) : failingPorts(failingPorts), isResponsive(isResponsive) {}
void listen() { assert(false); }
void connect(const HostAddressPort& address) {
hostAddressPort = address;
if (isResponsive) {
MainEventLoop::postEvent(boost::bind(boost::ref(onConnectFinished), std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end()));
}
}
void disconnect() { assert(false); }
void write(const ByteArray&) { assert(false); }
boost::optional<HostAddressPort> hostAddressPort;
std::vector<HostAddressPort> failingPorts;
bool isResponsive;
};
struct MockConnectionFactory : public ConnectionFactory {
MockConnectionFactory() : isResponsive(true) {
}
boost::shared_ptr<Connection> createConnection() {
return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive));
}
bool isResponsive;
std::vector<HostAddressPort> failingPorts;
};
private:
HostAddressPort host1;
HostAddressPort host2;
HostAddressPort host3;