summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Swiften/Client/CoreClient.cpp15
-rw-r--r--Swiften/Client/CoreClient.h2
-rw-r--r--Swiften/Elements/DeliveryReceipt.h2
-rw-r--r--Swiften/Network/ChainedConnector.cpp19
-rw-r--r--Swiften/Network/ChainedConnector.h8
-rw-r--r--Swiften/Network/Connector.cpp8
-rw-r--r--Swiften/Network/Connector.h3
-rw-r--r--Swiften/Network/UnitTest/ChainedConnectorTest.cpp29
-rw-r--r--Swiften/Network/UnitTest/ConnectorTest.cpp34
9 files changed, 88 insertions, 32 deletions
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index e2a8e5a..f7e3b21 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -1,173 +1,178 @@
/*
* Copyright (c) 2010-2011 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <Swiften/Client/CoreClient.h>
#include <boost/bind.hpp>
#include <boost/smart_ptr/make_shared.hpp>
#include <Swiften/Base/IDGenerator.h>
#include <Swiften/Base/Log.h>
#include <Swiften/Base/foreach.h>
#include <Swiften/Base/Algorithm.h>
#include <Swiften/Client/ClientSession.h>
#include <Swiften/TLS/CertificateVerificationError.h>
#include <Swiften/Network/ChainedConnector.h>
#include <Swiften/Network/NetworkFactories.h>
#include <Swiften/Network/ProxyProvider.h>
+#include <Swiften/Network/DomainNameResolveError.h>
#include <Swiften/TLS/PKCS12Certificate.h>
#include <Swiften/Session/BasicSessionStream.h>
#include <Swiften/Session/BOSHSessionStream.h>
#include <Swiften/Queries/IQRouter.h>
#include <Swiften/Client/ClientSessionStanzaChannel.h>
#include <Swiften/Network/SOCKS5ProxiedConnectionFactory.h>
#include <Swiften/Network/HTTPConnectProxiedConnectionFactory.h>
namespace Swift {
CoreClient::CoreClient(const JID& jid, const SafeByteArray& password, NetworkFactories* networkFactories) : jid_(jid), password_(password), networkFactories(networkFactories), disconnectRequested_(false), certificateTrustChecker(NULL) {
stanzaChannel_ = new ClientSessionStanzaChannel();
stanzaChannel_->onMessageReceived.connect(boost::bind(&CoreClient::handleMessageReceived, this, _1));
stanzaChannel_->onPresenceReceived.connect(boost::bind(&CoreClient::handlePresenceReceived, this, _1));
stanzaChannel_->onStanzaAcked.connect(boost::bind(&CoreClient::handleStanzaAcked, this, _1));
stanzaChannel_->onAvailableChanged.connect(boost::bind(&CoreClient::handleStanzaChannelAvailableChanged, this, _1));
iqRouter_ = new IQRouter(stanzaChannel_);
iqRouter_->setJID(jid);
}
CoreClient::~CoreClient() {
forceReset();
delete iqRouter_;
stanzaChannel_->onAvailableChanged.disconnect(boost::bind(&CoreClient::handleStanzaChannelAvailableChanged, this, _1));
stanzaChannel_->onMessageReceived.disconnect(boost::bind(&CoreClient::handleMessageReceived, this, _1));
stanzaChannel_->onPresenceReceived.disconnect(boost::bind(&CoreClient::handlePresenceReceived, this, _1));
stanzaChannel_->onStanzaAcked.disconnect(boost::bind(&CoreClient::handleStanzaAcked, this, _1));
delete stanzaChannel_;
}
void CoreClient::connect(const ClientOptions& o) {
SWIFT_LOG(debug) << "Connecting" << std::endl;
options = o;
connect(jid_.getDomain());
}
void CoreClient::connect(const std::string& host) {
forceReset();
SWIFT_LOG(debug) << "Connecting to host " << host << std::endl;
disconnectRequested_ = false;
assert(!connector_);
assert(proxyConnectionFactories.empty());
- if(networkFactories->getProxyProvider()->getSOCKS5Proxy().isValid()) {
+ if (networkFactories->getProxyProvider()->getSOCKS5Proxy().isValid()) {
proxyConnectionFactories.push_back(new SOCKS5ProxiedConnectionFactory(networkFactories->getConnectionFactory(), networkFactories->getProxyProvider()->getSOCKS5Proxy()));
}
if(networkFactories->getProxyProvider()->getHTTPConnectProxy().isValid()) {
proxyConnectionFactories.push_back(new HTTPConnectProxiedConnectionFactory(networkFactories->getDomainNameResolver(), networkFactories->getConnectionFactory(), networkFactories->getTimerFactory(), networkFactories->getEventLoop(), networkFactories->getProxyProvider()->getHTTPConnectProxy().getAddress().toString(), networkFactories->getProxyProvider()->getHTTPConnectProxy().getPort()));
}
std::vector<ConnectionFactory*> connectionFactories(proxyConnectionFactories);
if (options.boshURL.empty()) {
connectionFactories.push_back(networkFactories->getConnectionFactory());
connector_ = boost::make_shared<ChainedConnector>(host, networkFactories->getDomainNameResolver(), connectionFactories, networkFactories->getTimerFactory());
- connector_->onConnectFinished.connect(boost::bind(&CoreClient::handleConnectorFinished, this, _1));
+ connector_->onConnectFinished.connect(boost::bind(&CoreClient::handleConnectorFinished, this, _1, _2));
connector_->setTimeoutMilliseconds(60*1000);
connector_->start();
}
else {
/* Autodiscovery of which proxy works is largely ok with a TCP session, because this is a one-off. With BOSH
* it would be quite painful given that potentially every stanza could be sent on a new connection.
*/
//sessionStream_ = boost::make_shared<BOSHSessionStream>(boost::make_shared<BOSHConnectionFactory>(options.boshURL, networkFactories->getConnectionFactory(), networkFactories->getXMLParserFactory(), networkFactories->getTLSContextFactory()), getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory(), networkFactories->getEventLoop(), host, options.boshHTTPConnectProxyURL, options.boshHTTPConnectProxyAuthID, options.boshHTTPConnectProxyAuthPassword);
sessionStream_ = boost::shared_ptr<BOSHSessionStream>(new BOSHSessionStream(
options.boshURL,
getPayloadParserFactories(),
getPayloadSerializers(),
networkFactories->getConnectionFactory(),
networkFactories->getTLSContextFactory(),
networkFactories->getTimerFactory(),
networkFactories->getXMLParserFactory(),
networkFactories->getEventLoop(),
networkFactories->getDomainNameResolver(),
host,
options.boshHTTPConnectProxyURL,
options.boshHTTPConnectProxyAuthID,
options.boshHTTPConnectProxyAuthPassword));
sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1));
sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1));
bindSessionToStream();
}
}
void CoreClient::bindSessionToStream() {
session_ = ClientSession::create(jid_, sessionStream_);
session_->setCertificateTrustChecker(certificateTrustChecker);
session_->setUseStreamCompression(options.useStreamCompression);
session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS);
switch(options.useTLS) {
case ClientOptions::UseTLSWhenAvailable:
session_->setUseTLS(ClientSession::UseTLSWhenAvailable);
break;
case ClientOptions::NeverUseTLS:
session_->setUseTLS(ClientSession::NeverUseTLS);
break;
case ClientOptions::RequireTLS:
session_->setUseTLS(ClientSession::RequireTLS);
break;
}
session_->setUseAcks(options.useAcks);
stanzaChannel_->setSession(session_);
session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1));
session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this));
session_->start();
}
/**
* Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream.
*/
-void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connection) {
+void CoreClient::handleConnectorFinished(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> error) {
resetConnector();
if (!connection) {
if (options.forgetPassword) {
purgePassword();
}
- onDisconnected(disconnectRequested_ ? boost::optional<ClientError>() : boost::optional<ClientError>(ClientError::ConnectionError));
+ boost::optional<ClientError> clientError;
+ if (!disconnectRequested_) {
+ clientError = boost::dynamic_pointer_cast<DomainNameResolveError>(error) ? boost::optional<ClientError>(ClientError::DomainNameResolveError) : boost::optional<ClientError>(ClientError::ConnectionError);
+ }
+ onDisconnected(clientError);
}
else {
assert(!connection_);
connection_ = connection;
assert(!sessionStream_);
sessionStream_ = boost::make_shared<BasicSessionStream>(ClientStreamType, connection_, getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory());
if (certificate_ && !certificate_->isNull()) {
sessionStream_->setTLSCertificate(certificate_);
}
sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1));
sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1));
bindSessionToStream();
}
}
void CoreClient::disconnect() {
// FIXME: We should be able to do without this boolean. We just have to make sure we can tell the difference between
// connector finishing without a connection due to an error or because of a disconnect.
disconnectRequested_ = true;
if (session_ && !session_->isFinished()) {
session_->finish();
}
else if (connector_) {
connector_->stop();
}
}
void CoreClient::setCertificate(CertificateWithKey::ref certificate) {
certificate_ = certificate;
}
void CoreClient::handleSessionFinished(boost::shared_ptr<Error> error) {
if (options.forgetPassword) {
@@ -324,71 +329,71 @@ void CoreClient::handlePresenceReceived(Presence::ref presence) {
void CoreClient::handleMessageReceived(Message::ref message) {
onMessageReceived(message);
}
void CoreClient::handleStanzaAcked(Stanza::ref stanza) {
onStanzaAcked(stanza);
}
bool CoreClient::isAvailable() const {
return stanzaChannel_->isAvailable();
}
bool CoreClient::getStreamManagementEnabled() const {
return stanzaChannel_->getStreamManagementEnabled();
}
StanzaChannel* CoreClient::getStanzaChannel() const {
return stanzaChannel_;
}
const JID& CoreClient::getJID() const {
if (session_) {
return session_->getLocalJID();
}
else {
return jid_;
}
}
void CoreClient::purgePassword() {
safeClear(password_);
}
void CoreClient::resetConnector() {
- connector_->onConnectFinished.disconnect(boost::bind(&CoreClient::handleConnectorFinished, this, _1));
+ connector_->onConnectFinished.disconnect(boost::bind(&CoreClient::handleConnectorFinished, this, _1, _2));
connector_.reset();
foreach(ConnectionFactory* f, proxyConnectionFactories) {
delete f;
}
proxyConnectionFactories.clear();
}
void CoreClient::resetSession() {
session_->onFinished.disconnect(boost::bind(&CoreClient::handleSessionFinished, this, _1));
session_->onNeedCredentials.disconnect(boost::bind(&CoreClient::handleNeedCredentials, this));
sessionStream_->onDataRead.disconnect(boost::bind(&CoreClient::handleDataRead, this, _1));
sessionStream_->onDataWritten.disconnect(boost::bind(&CoreClient::handleDataWritten, this, _1));
if (connection_) {
connection_->disconnect();
}
else if (boost::dynamic_pointer_cast<BOSHSessionStream>(sessionStream_)) {
sessionStream_->close();
}
sessionStream_.reset();
connection_.reset();
}
void CoreClient::forceReset() {
if (connector_) {
std::cerr << "Warning: Client not disconnected properly: Connector still active" << std::endl;
resetConnector();
}
if (sessionStream_ || connection_) {
std::cerr << "Warning: Client not disconnected properly: Session still active" << std::endl;
resetSession();
}
}
diff --git a/Swiften/Client/CoreClient.h b/Swiften/Client/CoreClient.h
index c9a6f30..cafc634 100644
--- a/Swiften/Client/CoreClient.h
+++ b/Swiften/Client/CoreClient.h
@@ -171,68 +171,68 @@ namespace Swift {
boost::signal<void (const SafeByteArray&)> onDataWritten;
/**
* Emitted when a message is received.
*/
boost::signal<void (boost::shared_ptr<Message>)> onMessageReceived;
/**
* Emitted when a presence stanza is received.
*/
boost::signal<void (boost::shared_ptr<Presence>) > onPresenceReceived;
/**
* Emitted when the server acknowledges receipt of a
* stanza (if acknowledgements are available).
*
* \see getStreamManagementEnabled()
*/
boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaAcked;
protected:
boost::shared_ptr<ClientSession> getSession() const {
return session_;
}
NetworkFactories* getNetworkFactories() const {
return networkFactories;
}
/**
* Called before onConnected signal is emmitted.
*/
virtual void handleConnected() {};
private:
- void handleConnectorFinished(boost::shared_ptr<Connection>);
+ void handleConnectorFinished(boost::shared_ptr<Connection>, boost::shared_ptr<Error> error);
void handleStanzaChannelAvailableChanged(bool available);
void handleSessionFinished(boost::shared_ptr<Error>);
void handleNeedCredentials();
void handleDataRead(const SafeByteArray&);
void handleDataWritten(const SafeByteArray&);
void handlePresenceReceived(boost::shared_ptr<Presence>);
void handleMessageReceived(boost::shared_ptr<Message>);
void handleStanzaAcked(boost::shared_ptr<Stanza>);
void purgePassword();
void bindSessionToStream();
void resetConnector();
void resetSession();
void forceReset();
private:
JID jid_;
SafeByteArray password_;
NetworkFactories* networkFactories;
ClientSessionStanzaChannel* stanzaChannel_;
IQRouter* iqRouter_;
ClientOptions options;
boost::shared_ptr<ChainedConnector> connector_;
std::vector<ConnectionFactory*> proxyConnectionFactories;
boost::shared_ptr<Connection> connection_;
boost::shared_ptr<SessionStream> sessionStream_;
boost::shared_ptr<ClientSession> session_;
CertificateWithKey::ref certificate_;
bool disconnectRequested_;
CertificateTrustChecker* certificateTrustChecker;
};
}
diff --git a/Swiften/Elements/DeliveryReceipt.h b/Swiften/Elements/DeliveryReceipt.h
index f42176f..bd634db 100644
--- a/Swiften/Elements/DeliveryReceipt.h
+++ b/Swiften/Elements/DeliveryReceipt.h
@@ -1,36 +1,38 @@
/*
* Copyright (c) 2011 Tobias Markmann
* Licensed under the BSD license.
* See http://www.opensource.org/licenses/bsd-license.php for more information.
*/
#pragma once
#include <string>
+#include <string>
+
#include <Swiften/Elements/Payload.h>
namespace Swift {
class DeliveryReceipt : public Payload {
public:
typedef boost::shared_ptr<DeliveryReceipt> ref;
public:
DeliveryReceipt() {}
DeliveryReceipt(const std::string& msgId) : receivedID_(msgId) {}
void setReceivedID(const std::string& msgId) {
receivedID_ = msgId;
}
std::string getReceivedID() const {
return receivedID_;
}
private:
std::string receivedID_;
};
}
diff --git a/Swiften/Network/ChainedConnector.cpp b/Swiften/Network/ChainedConnector.cpp
index 1a38e53..0a1283f 100644
--- a/Swiften/Network/ChainedConnector.cpp
+++ b/Swiften/Network/ChainedConnector.cpp
@@ -9,74 +9,75 @@
#include <boost/bind.hpp>
#include <typeinfo>
#include <Swiften/Base/Log.h>
#include <Swiften/Base/foreach.h>
#include <Swiften/Network/Connector.h>
#include <Swiften/Network/ConnectionFactory.h>
using namespace Swift;
ChainedConnector::ChainedConnector(
const std::string& hostname,
DomainNameResolver* resolver,
const std::vector<ConnectionFactory*>& connectionFactories,
TimerFactory* timerFactory) :
hostname(hostname),
resolver(resolver),
connectionFactories(connectionFactories),
timerFactory(timerFactory),
timeoutMilliseconds(0) {
}
void ChainedConnector::setTimeoutMilliseconds(int milliseconds) {
timeoutMilliseconds = milliseconds;
}
void ChainedConnector::start() {
SWIFT_LOG(debug) << "Starting queued connector for " << hostname << std::endl;
connectionFactoryQueue = std::deque<ConnectionFactory*>(connectionFactories.begin(), connectionFactories.end());
tryNextConnectionFactory();
}
void ChainedConnector::stop() {
if (currentConnector) {
- currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1));
+ currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2));
currentConnector->stop();
currentConnector.reset();
}
- finish(boost::shared_ptr<Connection>());
+ finish(boost::shared_ptr<Connection>(), boost::shared_ptr<Error>());
}
void ChainedConnector::tryNextConnectionFactory() {
assert(!currentConnector);
if (connectionFactoryQueue.empty()) {
SWIFT_LOG(debug) << "No more connection factories" << std::endl;
- finish(boost::shared_ptr<Connection>());
+ finish(boost::shared_ptr<Connection>(), lastError);
}
else {
ConnectionFactory* connectionFactory = connectionFactoryQueue.front();
SWIFT_LOG(debug) << "Trying next connection factory: " << typeid(*connectionFactory).name() << std::endl;
connectionFactoryQueue.pop_front();
currentConnector = Connector::create(hostname, resolver, connectionFactory, timerFactory);
currentConnector->setTimeoutMilliseconds(timeoutMilliseconds);
- currentConnector->onConnectFinished.connect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1));
+ currentConnector->onConnectFinished.connect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2));
currentConnector->start();
}
}
-void ChainedConnector::handleConnectorFinished(boost::shared_ptr<Connection> connection) {
+void ChainedConnector::handleConnectorFinished(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> error) {
SWIFT_LOG(debug) << "Connector finished" << std::endl;
- currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1));
+ currentConnector->onConnectFinished.disconnect(boost::bind(&ChainedConnector::handleConnectorFinished, this, _1, _2));
+ lastError = error;
currentConnector.reset();
if (connection) {
- finish(connection);
+ finish(connection, error);
}
else {
tryNextConnectionFactory();
}
}
-void ChainedConnector::finish(boost::shared_ptr<Connection> connection) {
- onConnectFinished(connection);
+void ChainedConnector::finish(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> error) {
+ onConnectFinished(connection, error);
}
diff --git a/Swiften/Network/ChainedConnector.h b/Swiften/Network/ChainedConnector.h
index 15b17f3..12ef023 100644
--- a/Swiften/Network/ChainedConnector.h
+++ b/Swiften/Network/ChainedConnector.h
@@ -1,47 +1,49 @@
/*
* Copyright (c) 2011 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <string>
#include <vector>
#include <deque>
#include <boost/shared_ptr.hpp>
#include <Swiften/Base/boost_bsignals.h>
+#include <Swiften/Base/Error.h>
namespace Swift {
class Connection;
class Connector;
class ConnectionFactory;
class TimerFactory;
class DomainNameResolver;
class ChainedConnector {
public:
ChainedConnector(const std::string& hostname, DomainNameResolver*, const std::vector<ConnectionFactory*>&, TimerFactory*);
void setTimeoutMilliseconds(int milliseconds);
void start();
void stop();
- boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished;
+ boost::signal<void (boost::shared_ptr<Connection>, boost::shared_ptr<Error>)> onConnectFinished;
private:
- void finish(boost::shared_ptr<Connection> connection);
+ void finish(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error>);
void tryNextConnectionFactory();
- void handleConnectorFinished(boost::shared_ptr<Connection>);
+ void handleConnectorFinished(boost::shared_ptr<Connection>, boost::shared_ptr<Error>);
private:
std::string hostname;
DomainNameResolver* resolver;
std::vector<ConnectionFactory*> connectionFactories;
TimerFactory* timerFactory;
int timeoutMilliseconds;
std::deque<ConnectionFactory*> connectionFactoryQueue;
boost::shared_ptr<Connector> currentConnector;
+ boost::shared_ptr<Error> lastError;
};
};
diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp
index 378875b..5e7f8d9 100644
--- a/Swiften/Network/Connector.cpp
+++ b/Swiften/Network/Connector.cpp
@@ -1,123 +1,127 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <Swiften/Network/Connector.h>
#include <boost/bind.hpp>
#include <iostream>
#include <Swiften/Network/ConnectionFactory.h>
#include <Swiften/Network/DomainNameResolver.h>
#include <Swiften/Network/DomainNameAddressQuery.h>
#include <Swiften/Network/TimerFactory.h>
#include <Swiften/Base/Log.h>
namespace Swift {
-Connector::Connector(const std::string& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, int defaultPort) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), defaultPort(defaultPort), timeoutMilliseconds(0), queriedAllServices(true) {
+Connector::Connector(const std::string& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, int defaultPort) : hostname(hostname), resolver(resolver), connectionFactory(connectionFactory), timerFactory(timerFactory), defaultPort(defaultPort), timeoutMilliseconds(0), queriedAllServices(true), foundSomeDNS(false) {
}
void Connector::setTimeoutMilliseconds(int milliseconds) {
timeoutMilliseconds = milliseconds;
}
void Connector::start() {
SWIFT_LOG(debug) << "Starting connector for " << hostname << std::endl;
//std::cout << "Connector::start()" << std::endl;
assert(!currentConnection);
assert(!serviceQuery);
assert(!timer);
queriedAllServices = false;
serviceQuery = resolver->createServiceQuery("_xmpp-client._tcp." + hostname);
serviceQuery->onResult.connect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1));
if (timeoutMilliseconds > 0) {
timer = timerFactory->createTimer(timeoutMilliseconds);
timer->onTick.connect(boost::bind(&Connector::handleTimeout, shared_from_this()));
timer->start();
}
serviceQuery->run();
}
void Connector::stop() {
finish(boost::shared_ptr<Connection>());
}
void Connector::queryAddress(const std::string& hostname) {
assert(!addressQuery);
addressQuery = resolver->createAddressQuery(hostname);
addressQuery->onResult.connect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2));
addressQuery->run();
}
void Connector::handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result) {
SWIFT_LOG(debug) << result.size() << " SRV result(s)" << std::endl;
serviceQueryResults = std::deque<DomainNameServiceQuery::Result>(result.begin(), result.end());
serviceQuery.reset();
+ if (!serviceQueryResults.empty()) {
+ foundSomeDNS = true;
+ }
tryNextServiceOrFallback();
}
void Connector::tryNextServiceOrFallback() {
if (queriedAllServices) {
SWIFT_LOG(debug) << "Queried all services" << std::endl;
finish(boost::shared_ptr<Connection>());
}
else if (serviceQueryResults.empty()) {
SWIFT_LOG(debug) << "Falling back on A resolution" << std::endl;
// Fall back on simple address resolving
queriedAllServices = true;
queryAddress(hostname);
}
else {
SWIFT_LOG(debug) << "Querying next address" << std::endl;
queryAddress(serviceQueryResults.front().hostname);
}
}
void Connector::handleAddressQueryResult(const std::vector<HostAddress>& addresses, boost::optional<DomainNameResolveError> error) {
SWIFT_LOG(debug) << addresses.size() << " addresses" << std::endl;
addressQuery.reset();
if (error || addresses.empty()) {
if (!serviceQueryResults.empty()) {
serviceQueryResults.pop_front();
}
tryNextServiceOrFallback();
}
else {
+ foundSomeDNS = true;
addressQueryResults = std::deque<HostAddress>(addresses.begin(), addresses.end());
tryNextAddress();
}
}
void Connector::tryNextAddress() {
if (addressQueryResults.empty()) {
SWIFT_LOG(debug) << "Done trying addresses. Moving on." << std::endl;
// Done trying all addresses. Move on to the next host.
if (!serviceQueryResults.empty()) {
serviceQueryResults.pop_front();
}
tryNextServiceOrFallback();
}
else {
SWIFT_LOG(debug) << "Trying next address" << std::endl;
HostAddress address = addressQueryResults.front();
addressQueryResults.pop_front();
int port = defaultPort;
if (!serviceQueryResults.empty()) {
port = serviceQueryResults.front().port;
}
tryConnect(HostAddressPort(address, port));
}
}
void Connector::tryConnect(const HostAddressPort& target) {
assert(!currentConnection);
SWIFT_LOG(debug) << "Trying to connect to " << target.getAddress().toString() << ":" << target.getPort() << std::endl;
currentConnection = connectionFactory->createConnection();
currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1));
currentConnection->connect(target);
}
@@ -128,44 +132,44 @@ void Connector::handleConnectionConnectFinished(bool error) {
if (error) {
currentConnection.reset();
if (!addressQueryResults.empty()) {
tryNextAddress();
}
else {
if (!serviceQueryResults.empty()) {
serviceQueryResults.pop_front();
}
tryNextServiceOrFallback();
}
}
else {
finish(currentConnection);
}
}
void Connector::finish(boost::shared_ptr<Connection> connection) {
if (timer) {
timer->stop();
timer->onTick.disconnect(boost::bind(&Connector::handleTimeout, shared_from_this()));
timer.reset();
}
if (serviceQuery) {
serviceQuery->onResult.disconnect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1));
serviceQuery.reset();
}
if (addressQuery) {
addressQuery->onResult.disconnect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2));
addressQuery.reset();
}
if (currentConnection) {
currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1));
currentConnection.reset();
}
- onConnectFinished(connection);
+ onConnectFinished(connection, (connection || foundSomeDNS) ? boost::shared_ptr<Error>() : boost::make_shared<DomainNameResolveError>());
}
void Connector::handleTimeout() {
SWIFT_LOG(debug) << "Timeout" << std::endl;
finish(boost::shared_ptr<Connection>());
}
};
diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h
index 8f2c359..bf0efaf 100644
--- a/Swiften/Network/Connector.h
+++ b/Swiften/Network/Connector.h
@@ -3,69 +3,70 @@
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#pragma once
#include <deque>
#include <Swiften/Base/boost_bsignals.h>
#include <boost/shared_ptr.hpp>
#include <Swiften/Network/DomainNameServiceQuery.h>
#include <Swiften/Network/Connection.h>
#include <Swiften/Network/Timer.h>
#include <Swiften/Network/HostAddressPort.h>
#include <string>
#include <Swiften/Network/DomainNameResolveError.h>
namespace Swift {
class DomainNameAddressQuery;
class DomainNameResolver;
class ConnectionFactory;
class TimerFactory;
class Connector : public boost::bsignals::trackable, public boost::enable_shared_from_this<Connector> {
public:
typedef boost::shared_ptr<Connector> ref;
static Connector::ref create(const std::string& hostname, DomainNameResolver* resolver, ConnectionFactory* connectionFactory, TimerFactory* timerFactory, int defaultPort = 5222) {
return ref(new Connector(hostname, resolver, connectionFactory, timerFactory, defaultPort));
}
void setTimeoutMilliseconds(int milliseconds);
void start();
void stop();
- boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished;
+ boost::signal<void (boost::shared_ptr<Connection>, boost::shared_ptr<Error>)> onConnectFinished;
private:
Connector(const std::string& hostname, DomainNameResolver*, ConnectionFactory*, TimerFactory*, int defaultPort);
void handleServiceQueryResult(const std::vector<DomainNameServiceQuery::Result>& result);
void handleAddressQueryResult(const std::vector<HostAddress>& address, boost::optional<DomainNameResolveError> error);
void queryAddress(const std::string& hostname);
void tryNextServiceOrFallback();
void tryNextAddress();
void tryConnect(const HostAddressPort& target);
void handleConnectionConnectFinished(bool error);
void finish(boost::shared_ptr<Connection>);
void handleTimeout();
private:
std::string hostname;
DomainNameResolver* resolver;
ConnectionFactory* connectionFactory;
TimerFactory* timerFactory;
int defaultPort;
int timeoutMilliseconds;
boost::shared_ptr<Timer> timer;
boost::shared_ptr<DomainNameServiceQuery> serviceQuery;
std::deque<DomainNameServiceQuery::Result> serviceQueryResults;
boost::shared_ptr<DomainNameAddressQuery> addressQuery;
std::deque<HostAddress> addressQueryResults;
bool queriedAllServices;
boost::shared_ptr<Connection> currentConnection;
+ bool foundSomeDNS;
};
};
diff --git a/Swiften/Network/UnitTest/ChainedConnectorTest.cpp b/Swiften/Network/UnitTest/ChainedConnectorTest.cpp
index c7d23da..a2fceb9 100644
--- a/Swiften/Network/UnitTest/ChainedConnectorTest.cpp
+++ b/Swiften/Network/UnitTest/ChainedConnectorTest.cpp
@@ -1,161 +1,186 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <boost/bind.hpp>
#include <boost/smart_ptr/make_shared.hpp>
#include <Swiften/Network/ChainedConnector.h>
#include <Swiften/Network/Connection.h>
#include <Swiften/Network/ConnectionFactory.h>
#include <Swiften/Network/HostAddressPort.h>
#include <Swiften/Network/StaticDomainNameResolver.h>
#include <Swiften/Network/DummyTimerFactory.h>
#include <Swiften/EventLoop/DummyEventLoop.h>
+#include <Swiften/Network/DomainNameResolveError.h>
using namespace Swift;
class ChainedConnectorTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(ChainedConnectorTest);
CPPUNIT_TEST(testConnect_FirstConnectorSucceeds);
CPPUNIT_TEST(testConnect_SecondConnectorSucceeds);
CPPUNIT_TEST(testConnect_NoConnectorSucceeds);
+ CPPUNIT_TEST(testConnect_NoDNS);
CPPUNIT_TEST(testStop);
CPPUNIT_TEST_SUITE_END();
public:
void setUp() {
+ error.reset();
host = HostAddressPort(HostAddress("1.1.1.1"), 1234);
eventLoop = new DummyEventLoop();
resolver = new StaticDomainNameResolver(eventLoop);
resolver->addXMPPClientService("foo.com", host);
connectionFactory1 = new MockConnectionFactory(eventLoop, 1);
connectionFactory2 = new MockConnectionFactory(eventLoop, 2);
timerFactory = new DummyTimerFactory();
}
void tearDown() {
delete timerFactory;
delete connectionFactory2;
delete connectionFactory1;
delete resolver;
delete eventLoop;
}
void testConnect_FirstConnectorSucceeds() {
boost::shared_ptr<ChainedConnector> testling(createConnector());
connectionFactory1->connects = true;
connectionFactory2->connects = false;
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT_EQUAL(1, boost::dynamic_pointer_cast<MockConnection>(connections[0])->id);
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_SecondConnectorSucceeds() {
boost::shared_ptr<ChainedConnector> testling(createConnector());
connectionFactory1->connects = false;
connectionFactory2->connects = true;
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT_EQUAL(2, boost::dynamic_pointer_cast<MockConnection>(connections[0])->id);
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_NoConnectorSucceeds() {
boost::shared_ptr<ChainedConnector> testling(createConnector());
connectionFactory1->connects = false;
connectionFactory2->connects = false;
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
+ }
+
+ void testConnect_NoDNS() {
+ /* Reset resolver so there's no record */
+ delete resolver;
+ resolver = new StaticDomainNameResolver(eventLoop);
+ boost::shared_ptr<ChainedConnector> testling(createConnector());
+ connectionFactory1->connects = false;
+ connectionFactory2->connects = false;
+
+ testling->start();
+ //testling->stop();
+ eventLoop->processEvents();
+
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testStop() {
boost::shared_ptr<ChainedConnector> testling(createConnector());
connectionFactory1->connects = true;
connectionFactory2->connects = false;
testling->start();
testling->stop();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
}
private:
boost::shared_ptr<ChainedConnector> createConnector() {
std::vector<ConnectionFactory*> factories;
factories.push_back(connectionFactory1);
factories.push_back(connectionFactory2);
boost::shared_ptr<ChainedConnector> connector = boost::make_shared<ChainedConnector>("foo.com", resolver, factories, timerFactory);
- connector->onConnectFinished.connect(boost::bind(&ChainedConnectorTest::handleConnectorFinished, this, _1));
+ connector->onConnectFinished.connect(boost::bind(&ChainedConnectorTest::handleConnectorFinished, this, _1, _2));
return connector;
}
- void handleConnectorFinished(boost::shared_ptr<Connection> connection) {
+ void handleConnectorFinished(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> resultError) {
+ error = resultError;
boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection));
if (connection) {
assert(c);
}
connections.push_back(c);
}
struct MockConnection : public Connection {
public:
MockConnection(bool connects, int id, EventLoop* eventLoop) : connects(connects), id(id), eventLoop(eventLoop) {
}
void listen() { assert(false); }
void connect(const HostAddressPort&) {
eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), !connects));
}
HostAddressPort getLocalAddress() const { return HostAddressPort(); }
void disconnect() { assert(false); }
void write(const SafeByteArray&) { assert(false); }
bool connects;
int id;
EventLoop* eventLoop;
};
struct MockConnectionFactory : public ConnectionFactory {
MockConnectionFactory(EventLoop* eventLoop, int id) : eventLoop(eventLoop), connects(true), id(id) {
}
boost::shared_ptr<Connection> createConnection() {
return boost::make_shared<MockConnection>(connects, id, eventLoop);
}
EventLoop* eventLoop;
bool connects;
int id;
};
private:
HostAddressPort host;
DummyEventLoop* eventLoop;
StaticDomainNameResolver* resolver;
MockConnectionFactory* connectionFactory1;
MockConnectionFactory* connectionFactory2;
DummyTimerFactory* timerFactory;
std::vector< boost::shared_ptr<MockConnection> > connections;
+ boost::shared_ptr<Error> error;
};
CPPUNIT_TEST_SUITE_REGISTRATION(ChainedConnectorTest);
diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp
index 6488e67..67270be 100644
--- a/Swiften/Network/UnitTest/ConnectorTest.cpp
+++ b/Swiften/Network/UnitTest/ConnectorTest.cpp
@@ -1,304 +1,320 @@
/*
* Copyright (c) 2010 Remko Tronçon
* Licensed under the GNU General Public License v3.
* See Documentation/Licenses/GPLv3.txt for more information.
*/
#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <boost/optional.hpp>
#include <boost/bind.hpp>
#include <Swiften/Network/Connector.h>
#include <Swiften/Network/Connection.h>
#include <Swiften/Network/ConnectionFactory.h>
#include <Swiften/Network/HostAddressPort.h>
#include <Swiften/Network/StaticDomainNameResolver.h>
#include <Swiften/Network/DummyTimerFactory.h>
#include <Swiften/EventLoop/DummyEventLoop.h>
+#include <Swiften/Network/DomainNameAddressQuery.h>
using namespace Swift;
class ConnectorTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(ConnectorTest);
CPPUNIT_TEST(testConnect);
CPPUNIT_TEST(testConnect_FirstAddressHostFails);
CPPUNIT_TEST(testConnect_NoSRVHost);
CPPUNIT_TEST(testConnect_NoHosts);
CPPUNIT_TEST(testConnect_FirstSRVHostFails);
CPPUNIT_TEST(testConnect_AllSRVHostsFailWithoutFallbackHost);
CPPUNIT_TEST(testConnect_AllSRVHostsFailWithFallbackHost);
CPPUNIT_TEST(testConnect_SRVAndFallbackHostsFail);
CPPUNIT_TEST(testConnect_TimeoutDuringResolve);
CPPUNIT_TEST(testConnect_TimeoutDuringConnect);
CPPUNIT_TEST(testConnect_NoTimeout);
CPPUNIT_TEST(testStop_DuringSRVQuery);
CPPUNIT_TEST(testStop_Timeout);
CPPUNIT_TEST_SUITE_END();
public:
void setUp() {
host1 = HostAddressPort(HostAddress("1.1.1.1"), 1234);
host2 = HostAddressPort(HostAddress("2.2.2.2"), 2345);
host3 = HostAddressPort(HostAddress("3.3.3.3"), 5222);
eventLoop = new DummyEventLoop();
resolver = new StaticDomainNameResolver(eventLoop);
connectionFactory = new MockConnectionFactory(eventLoop);
timerFactory = new DummyTimerFactory();
}
void tearDown() {
delete timerFactory;
delete connectionFactory;
delete resolver;
delete eventLoop;
}
void testConnect() {
Connector::ref testling(createConnector());
resolver->addXMPPClientService("foo.com", host1);
resolver->addXMPPClientService("foo.com", host2);
resolver->addAddress("foo.com", host3.getAddress());
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(host1 == *(connections[0]->hostAddressPort));
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_NoSRVHost() {
Connector::ref testling(createConnector());
resolver->addAddress("foo.com", host3.getAddress());
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort));
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_FirstAddressHostFails() {
Connector::ref testling(createConnector());
HostAddress address1("1.1.1.1");
HostAddress address2("2.2.2.2");
resolver->addXMPPClientService("foo.com", "host-foo.com", 1234);
resolver->addAddress("host-foo.com", address1);
resolver->addAddress("host-foo.com", address2);
connectionFactory->failingPorts.push_back(HostAddressPort(address1, 1234));
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(HostAddressPort(address2, 1234) == *(connections[0]->hostAddressPort));
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_NoHosts() {
Connector::ref testling(createConnector());
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_FirstSRVHostFails() {
Connector::ref testling(createConnector());
resolver->addXMPPClientService("foo.com", host1);
resolver->addXMPPClientService("foo.com", host2);
connectionFactory->failingPorts.push_back(host1);
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(host2 == *(connections[0]->hostAddressPort));
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_AllSRVHostsFailWithoutFallbackHost() {
Connector::ref testling(createConnector());
resolver->addXMPPClientService("foo.com", host1);
resolver->addXMPPClientService("foo.com", host2);
connectionFactory->failingPorts.push_back(host1);
connectionFactory->failingPorts.push_back(host2);
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_AllSRVHostsFailWithFallbackHost() {
Connector::ref testling(createConnector());
resolver->addXMPPClientService("foo.com", host1);
resolver->addXMPPClientService("foo.com", host2);
resolver->addAddress("foo.com", host3.getAddress());
connectionFactory->failingPorts.push_back(host1);
connectionFactory->failingPorts.push_back(host2);
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
CPPUNIT_ASSERT(host3 == *(connections[0]->hostAddressPort));
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_SRVAndFallbackHostsFail() {
Connector::ref testling(createConnector());
resolver->addXMPPClientService("foo.com", host1);
resolver->addAddress("foo.com", host3.getAddress());
connectionFactory->failingPorts.push_back(host1);
connectionFactory->failingPorts.push_back(host3);
testling->start();
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_TimeoutDuringResolve() {
Connector::ref testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->setIsResponsive(false);
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error));
CPPUNIT_ASSERT(!connections[0]);
}
void testConnect_TimeoutDuringConnect() {
Connector::ref testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->addXMPPClientService("foo.com", host1);
connectionFactory->isResponsive = false;
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testConnect_NoTimeout() {
Connector::ref testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->addXMPPClientService("foo.com", host1);
testling->start();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(connections[0]);
+ CPPUNIT_ASSERT(!boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testStop_DuringSRVQuery() {
- Connector::ref testling(createConnector());
- resolver->addXMPPClientService("foo.com", host1);
+ Connector::ref testling(createConnector());
+ resolver->addXMPPClientService("foo.com", host1);
- testling->start();
- testling->stop();
+ testling->start();
+ testling->stop();
- eventLoop->processEvents();
+ eventLoop->processEvents();
- CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
- CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
+ CPPUNIT_ASSERT(!connections[0]);
+ CPPUNIT_ASSERT(boost::dynamic_pointer_cast<DomainNameResolveError>(error));
}
void testStop_Timeout() {
Connector::ref testling(createConnector());
testling->setTimeoutMilliseconds(10);
resolver->addXMPPClientService("foo.com", host1);
testling->start();
testling->stop();
eventLoop->processEvents();
timerFactory->setTime(10);
eventLoop->processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size()));
CPPUNIT_ASSERT(!connections[0]);
}
private:
Connector::ref createConnector() {
Connector::ref connector = Connector::create("foo.com", resolver, connectionFactory, timerFactory);
- connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1));
+ connector->onConnectFinished.connect(boost::bind(&ConnectorTest::handleConnectorFinished, this, _1, _2));
return connector;
}
- void handleConnectorFinished(boost::shared_ptr<Connection> connection) {
+ void handleConnectorFinished(boost::shared_ptr<Connection> connection, boost::shared_ptr<Error> resultError) {
boost::shared_ptr<MockConnection> c(boost::dynamic_pointer_cast<MockConnection>(connection));
if (connection) {
assert(c);
}
connections.push_back(c);
+ error = resultError;
}
struct MockConnection : public Connection {
public:
MockConnection(const std::vector<HostAddressPort>& failingPorts, bool isResponsive, EventLoop* eventLoop) : eventLoop(eventLoop), failingPorts(failingPorts), isResponsive(isResponsive) {}
void listen() { assert(false); }
void connect(const HostAddressPort& address) {
hostAddressPort = address;
if (isResponsive) {
bool fail = std::find(failingPorts.begin(), failingPorts.end(), address) != failingPorts.end();
eventLoop->postEvent(boost::bind(boost::ref(onConnectFinished), fail));
}
}
HostAddressPort getLocalAddress() const { return HostAddressPort(); }
void disconnect() { assert(false); }
void write(const SafeByteArray&) { assert(false); }
EventLoop* eventLoop;
boost::optional<HostAddressPort> hostAddressPort;
std::vector<HostAddressPort> failingPorts;
bool isResponsive;
};
struct MockConnectionFactory : public ConnectionFactory {
MockConnectionFactory(EventLoop* eventLoop) : eventLoop(eventLoop), isResponsive(true) {
}
boost::shared_ptr<Connection> createConnection() {
return boost::shared_ptr<Connection>(new MockConnection(failingPorts, isResponsive, eventLoop));
}
EventLoop* eventLoop;
bool isResponsive;
std::vector<HostAddressPort> failingPorts;
};
private:
HostAddressPort host1;
HostAddressPort host2;
HostAddressPort host3;
DummyEventLoop* eventLoop;
StaticDomainNameResolver* resolver;
MockConnectionFactory* connectionFactory;
DummyTimerFactory* timerFactory;
std::vector< boost::shared_ptr<MockConnection> > connections;
+ boost::shared_ptr<Error> error;
+
};
CPPUNIT_TEST_SUITE_REGISTRATION(ConnectorTest);