diff options
author | Remko Tronçon <git@el-tramo.be> | 2010-10-07 21:37:48 (GMT) |
---|---|---|
committer | Remko Tronçon <git@el-tramo.be> | 2010-10-07 21:39:40 (GMT) |
commit | 091f554f42dcdef534718fb759eb45b622adfd4f (patch) | |
tree | b8753f62884ef5ef46d04782bb38d8ef2ed38d01 /Swiften | |
parent | 88eab3d1d9b722590da3837e3c79839189ea58d2 (diff) | |
download | swift-091f554f42dcdef534718fb759eb45b622adfd4f.zip swift-091f554f42dcdef534718fb759eb45b622adfd4f.tar.bz2 |
Fix crashes on disconnect during connect.
Resolves: #588
Diffstat (limited to 'Swiften')
-rw-r--r-- | Swiften/Client/Client.cpp | 62 | ||||
-rw-r--r-- | Swiften/Client/Client.h | 5 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 37 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 5 | ||||
-rw-r--r-- | Swiften/Network/Connector.cpp | 13 | ||||
-rw-r--r-- | Swiften/Network/Connector.h | 2 | ||||
-rw-r--r-- | Swiften/Network/StaticDomainNameResolver.cpp | 17 | ||||
-rw-r--r-- | Swiften/Network/UnitTest/ConnectorTest.cpp | 31 |
8 files changed, 113 insertions, 59 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 974e256..5b57672 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -21,7 +21,7 @@ namespace Swift { -Client::Client(const JID& jid, const String& password) : jid_(jid), password_(password) { +Client::Client(const JID& jid, const String& password) : jid_(jid), password_(password), disconnectRequested_(false) { iqRouter_ = new IQRouter(this); connectionFactory_ = new BoostConnectionFactory(&MainBoostIOServiceThread::getInstance().getIOService()); timerFactory_ = new BoostTimerFactory(&MainBoostIOServiceThread::getInstance().getIOService()); @@ -52,26 +52,20 @@ void Client::connect(const JID& jid) { } void Client::connect(const String& host) { - assert(!connector_); // Crash on reconnect is here. + assert(!connector_); connector_ = Connector::create(host, &resolver_, connectionFactory_, timerFactory_); - connector_->onConnectFinished.connect(boost::bind(&Client::handleConnectorFinished, this, _1, connector_)); + connector_->onConnectFinished.connect(boost::bind(&Client::handleConnectorFinished, this, _1)); connector_->setTimeoutMilliseconds(60*1000); connector_->start(); } -void Client::handleConnectorFinished(boost::shared_ptr<Connection> connection, Connector::ref connector) { - bool currentConnection = connector_ && (connector.get() == connector_.get()); - // TODO: Add domain name resolver error - if (!currentConnection) { - /* disconnect() was called, this connection should be thrown away*/ - if (connection) { - connection->disconnect(); - } - return; - } +void Client::handleConnectorFinished(boost::shared_ptr<Connection> connection) { + connector_->onConnectFinished.disconnect(boost::bind(&Client::handleConnectorFinished, this, _1)); connector_.reset(); if (!connection) { - onError(ClientError::ConnectionError); + if (!disconnectRequested_) { + onError(ClientError::ConnectionError); + } } else { assert(!connection_); @@ -97,25 +91,20 @@ void Client::handleConnectorFinished(boost::shared_ptr<Connection> connection, C } void Client::disconnect() { - if (connector_) { - connector_.reset(); - } + // FIXME: We should be able to do without this boolean. We just have to make sure we can tell the difference between + // connector finishing without a connection due to an error or because of a disconnect. + disconnectRequested_ = true; if (session_) { session_->finish(); } - else { - closeConnection(); - } -} - -void Client::closeConnection() { - if (sessionStream_) { - sessionStream_.reset(); - } - if (connection_) { - connection_->disconnect(); - connection_.reset(); + else if (connector_) { + connector_->stop(); + assert(!session_); } + assert(!session_); + assert(!sessionStream_); + assert(!connector_); + disconnectRequested_ = false; } void Client::send(boost::shared_ptr<Stanza> stanza) { @@ -167,9 +156,22 @@ void Client::setCertificate(const String& certificate) { } void Client::handleSessionFinished(boost::shared_ptr<Error> error) { + session_->onInitialized.disconnect(boost::bind(&Client::handleSessionInitialized, this)); + session_->onStanzaAcked.disconnect(boost::bind(&Client::handleStanzaAcked, this, _1)); + session_->onFinished.disconnect(boost::bind(&Client::handleSessionFinished, this, _1)); + session_->onNeedCredentials.disconnect(boost::bind(&Client::handleNeedCredentials, this)); + session_->onStanzaReceived.disconnect(boost::bind(&Client::handleStanza, this, _1)); session_.reset(); - closeConnection(); + + sessionStream_->onDataRead.disconnect(boost::bind(&Client::handleDataRead, this, _1)); + sessionStream_->onDataWritten.disconnect(boost::bind(&Client::handleDataWritten, this, _1)); + sessionStream_.reset(); + + connection_->disconnect(); + connection_.reset(); + onAvailableChanged(false); + if (error) { ClientError clientError; if (boost::shared_ptr<ClientSession::Error> actualError = boost::dynamic_pointer_cast<ClientSession::Error>(error)) { diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index e046b3c..7e55289 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -70,7 +70,7 @@ namespace Swift { boost::signal<void (const String&)> onDataWritten; private: - void handleConnectorFinished(boost::shared_ptr<Connection>, Connector::ref); + void handleConnectorFinished(boost::shared_ptr<Connection>); void handleSessionInitialized(); void send(boost::shared_ptr<Stanza>); virtual String getNewIQID(); @@ -81,8 +81,6 @@ namespace Swift { void handleDataWritten(const String&); void handleStanzaAcked(boost::shared_ptr<Stanza>); - void closeConnection(); - private: PlatformDomainNameResolver resolver_; JID jid_; @@ -99,5 +97,6 @@ namespace Swift { boost::shared_ptr<BasicSessionStream> sessionStream_; boost::shared_ptr<ClientSession> session_; String certificate_; + bool disconnectRequested_; }; } diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 17b3931..4c2be95 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -56,10 +56,10 @@ ClientSession::~ClientSession() { } void ClientSession::start() { - streamOnStreamStartReceivedConnection = stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); - streamOnElementReceivedConnection = stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); - streamOnErrorConnection = stream->onError.connect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1)); - streamOnTLSEncryptedConnection = stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); + stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); + stream->onError.connect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1)); + stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); assert(state == Initial); state = WaitingForStreamStart; @@ -230,10 +230,10 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } else if (boost::dynamic_pointer_cast<StreamManagementEnabled>(element)) { stanzaAckRequester_ = boost::shared_ptr<StanzaAckRequester>(new StanzaAckRequester()); - stanzaAckRequester_->onRequestAck.connect(boost::bind(&ClientSession::requestAck, this)); - stanzaAckRequester_->onStanzaAcked.connect(boost::bind(&ClientSession::handleStanzaAcked, this, _1)); + stanzaAckRequester_->onRequestAck.connect(boost::bind(&ClientSession::requestAck, shared_from_this())); + stanzaAckRequester_->onStanzaAcked.connect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1)); stanzaAckResponder_ = boost::shared_ptr<StanzaAckResponder>(new StanzaAckResponder()); - stanzaAckResponder_->onAck.connect(boost::bind(&ClientSession::ack, this, _1)); + stanzaAckResponder_->onAck.connect(boost::bind(&ClientSession::ack, shared_from_this(), _1)); needAcking = false; continueSessionInitialization(); } @@ -334,9 +334,6 @@ void ClientSession::handleStreamError(boost::shared_ptr<Swift::Error> error) { } void ClientSession::finish() { - if (stream->isAvailable()) { - stream->writeFooter(); - } finishSession(boost::shared_ptr<Error>()); } @@ -346,11 +343,23 @@ void ClientSession::finishSession(Error::Type error) { void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) { state = Finished; + if (stanzaAckRequester_) { + stanzaAckRequester_->onRequestAck.disconnect(boost::bind(&ClientSession::requestAck, shared_from_this())); + stanzaAckRequester_->onStanzaAcked.disconnect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1)); + stanzaAckRequester_.reset(); + } + if (stanzaAckResponder_) { + stanzaAckResponder_->onAck.disconnect(boost::bind(&ClientSession::ack, shared_from_this(), _1)); + stanzaAckResponder_.reset(); + } stream->setWhitespacePingEnabled(false); - streamOnStreamStartReceivedConnection.disconnect(); - streamOnElementReceivedConnection.disconnect(); - streamOnErrorConnection.disconnect(); - streamOnTLSEncryptedConnection.disconnect(); + stream->onStreamStartReceived.disconnect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onElementReceived.disconnect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); + stream->onError.disconnect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1)); + stream->onTLSEncrypted.disconnect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); + if (stream->isAvailable()) { + stream->writeFooter(); + } onFinished(error); } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index 2af9cab..2c1bda8 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -56,6 +56,7 @@ namespace Swift { }; ~ClientSession(); + static boost::shared_ptr<ClientSession> create(const JID& jid, boost::shared_ptr<SessionStream> stream) { return boost::shared_ptr<ClientSession>(new ClientSession(jid, stream)); } @@ -127,9 +128,5 @@ namespace Swift { ClientAuthenticator* authenticator; boost::shared_ptr<StanzaAckRequester> stanzaAckRequester_; boost::shared_ptr<StanzaAckResponder> stanzaAckResponder_; - boost::bsignals::connection streamOnStreamStartReceivedConnection; - boost::bsignals::connection streamOnElementReceivedConnection; - boost::bsignals::connection streamOnErrorConnection; - boost::bsignals::connection streamOnTLSEncryptedConnection; }; } diff --git a/Swiften/Network/Connector.cpp b/Swiften/Network/Connector.cpp index d9a4c5d..01875f7 100644 --- a/Swiften/Network/Connector.cpp +++ b/Swiften/Network/Connector.cpp @@ -39,6 +39,10 @@ void Connector::start() { serviceQuery->run(); } +void Connector::stop() { + finish(boost::shared_ptr<Connection>()); +} + void Connector::queryAddress(const String& hostname) { assert(!addressQuery); addressQuery = resolver->createAddressQuery(hostname); @@ -112,7 +116,7 @@ void Connector::tryConnect(const HostAddressPort& target) { assert(!currentConnection); //std::cout << "Connector::tryConnect() " << target.getAddress().toString() << " " << target.getPort() << std::endl; currentConnection = connectionFactory->createConnection(); - connectFinishedConnection = currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); + currentConnection->onConnectFinished.connect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); currentConnection->connect(target); } @@ -138,15 +142,20 @@ void Connector::handleConnectionConnectFinished(bool error) { void Connector::finish(boost::shared_ptr<Connection> connection) { if (timer) { timer->stop(); + timer->onTick.disconnect(boost::bind(&Connector::handleTimeout, shared_from_this())); timer.reset(); } if (serviceQuery) { + serviceQuery->onResult.disconnect(boost::bind(&Connector::handleServiceQueryResult, shared_from_this(), _1)); serviceQuery.reset(); } if (addressQuery) { + addressQuery->onResult.disconnect(boost::bind(&Connector::handleAddressQueryResult, shared_from_this(), _1, _2)); addressQuery.reset(); } - connectFinishedConnection.disconnect(); + if (currentConnection) { + currentConnection->onConnectFinished.disconnect(boost::bind(&Connector::handleConnectionConnectFinished, shared_from_this(), _1)); + } onConnectFinished(connection); } diff --git a/Swiften/Network/Connector.h b/Swiften/Network/Connector.h index 36026de..52779c2 100644 --- a/Swiften/Network/Connector.h +++ b/Swiften/Network/Connector.h @@ -33,6 +33,7 @@ namespace Swift { void setTimeoutMilliseconds(int milliseconds); void start(); + void stop(); boost::signal<void (boost::shared_ptr<Connection>)> onConnectFinished; @@ -65,6 +66,5 @@ namespace Swift { std::deque<HostAddress> addressQueryResults; bool queriedAllServices; boost::shared_ptr<Connection> currentConnection; - boost::bsignals::connection connectFinishedConnection; }; }; diff --git a/Swiften/Network/StaticDomainNameResolver.cpp b/Swiften/Network/StaticDomainNameResolver.cpp index b94dd11..636f310 100644 --- a/Swiften/Network/StaticDomainNameResolver.cpp +++ b/Swiften/Network/StaticDomainNameResolver.cpp @@ -15,7 +15,7 @@ using namespace Swift; namespace { - struct ServiceQuery : public DomainNameServiceQuery, public EventOwner { + struct ServiceQuery : public DomainNameServiceQuery, public boost::enable_shared_from_this<ServiceQuery> { ServiceQuery(const String& service, Swift::StaticDomainNameResolver* resolver) : service(service), resolver(resolver) {} virtual void run() { @@ -28,14 +28,18 @@ namespace { results.push_back(i->second); } } - MainEventLoop::postEvent(boost::bind(boost::ref(onResult), results)); + MainEventLoop::postEvent(boost::bind(&ServiceQuery::emitOnResult, shared_from_this(), results)); + } + + void emitOnResult(std::vector<DomainNameServiceQuery::Result> results) { + onResult(results); } String service; StaticDomainNameResolver* resolver; }; - struct AddressQuery : public DomainNameAddressQuery, public EventOwner { + struct AddressQuery : public DomainNameAddressQuery, public boost::enable_shared_from_this<AddressQuery> { AddressQuery(const String& host, StaticDomainNameResolver* resolver) : host(host), resolver(resolver) {} virtual void run() { @@ -45,12 +49,15 @@ namespace { StaticDomainNameResolver::AddressesMap::const_iterator i = resolver->getAddresses().find(host); if (i != resolver->getAddresses().end()) { MainEventLoop::postEvent( - boost::bind(boost::ref(onResult), i->second, boost::optional<DomainNameResolveError>())); + boost::bind(&AddressQuery::emitOnResult, shared_from_this(), i->second, boost::optional<DomainNameResolveError>())); } else { - MainEventLoop::postEvent(boost::bind(boost::ref(onResult), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError()))); + MainEventLoop::postEvent(boost::bind(&AddressQuery::emitOnResult, shared_from_this(), std::vector<HostAddress>(), boost::optional<DomainNameResolveError>(DomainNameResolveError()))); } + } + void emitOnResult(std::vector<HostAddress> results, boost::optional<DomainNameResolveError> error) { + onResult(results, error); } String host; diff --git a/Swiften/Network/UnitTest/ConnectorTest.cpp b/Swiften/Network/UnitTest/ConnectorTest.cpp index 32a7157..2e396b3 100644 --- a/Swiften/Network/UnitTest/ConnectorTest.cpp +++ b/Swiften/Network/UnitTest/ConnectorTest.cpp @@ -34,6 +34,8 @@ class ConnectorTest : public CppUnit::TestFixture { CPPUNIT_TEST(testConnect_TimeoutDuringResolve); CPPUNIT_TEST(testConnect_TimeoutDuringConnect); CPPUNIT_TEST(testConnect_NoTimeout); + CPPUNIT_TEST(testStop_DuringSRVQuery); + CPPUNIT_TEST(testStop_Timeout); CPPUNIT_TEST_SUITE_END(); public: @@ -208,6 +210,35 @@ class ConnectorTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(connections[0]); } + void testStop_DuringSRVQuery() { + Connector::ref testling(createConnector()); + resolver->addXMPPClientService("foo.com", host1); + + testling->start(); + testling->stop(); + + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + + void testStop_Timeout() { + Connector::ref testling(createConnector()); + testling->setTimeoutMilliseconds(10); + resolver->addXMPPClientService("foo.com", host1); + + testling->start(); + testling->stop(); + + eventLoop->processEvents(); + timerFactory->setTime(10); + eventLoop->processEvents(); + + CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(connections.size())); + CPPUNIT_ASSERT(!connections[0]); + } + private: Connector::ref createConnector() { |