summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTobias Markmann <tm@ayena.de>2016-07-29 09:47:23 (GMT)
committerTobias Markmann <tm@ayena.de>2016-11-28 10:35:05 (GMT)
commit2039930eadd4756068a8a60c8340d9908a7136d3 (patch)
treed8aca4bf98a2bb6e3b819305b1f87af3117f4910
parent2f90eb7409df91a80c60b189242ac0c1de313910 (diff)
downloadswift-2039930eadd4756068a8a60c8340d9908a7136d3.zip
swift-2039930eadd4756068a8a60c8340d9908a7136d3.tar.bz2
Correctly handle server initiated closing of stream
If a server closes the XMPP stream, it sends a </stream:stream> tag. The client is supposed to respond with the same tag and then both parties can close the TLS/TCP socket. Previously Swift(-en) would simply ignore </stream:stream> tag if it was not directly followed by a shutdown of the TCP connection. In addition there is now a timeout timer started as soon as Swiften or the server initiates a shutdown. It will close the socket and cleanup the ClientSession if the server does not respond in time or the network is faulty. Refactored some code in ClientSession in the process. Moved ClientSession::State to a C++11 strongly typed enum class. This also fixes issues where duplicated </stream:stream> tags would be send by Swift. Test-Information: Tested against Prosody ba782a093b14 and M-Link 16.3v6-0, which provide ad-hoc commands to end a user session. Previously this was ignored by Swift. Now it correctly responds to the server, detects it as a disconnect and tries to reconnect afterwards. Added unit test for the case where the server closes the session stream. Change-Id: I59dfde3aa6b50dc117f340e5db6b9e58b54b3c60
-rwxr-xr-xBuildTools/FixIncludes.py5
-rw-r--r--Swift/ChangeLog.md1
-rw-r--r--Swiften/Base/Debug.cpp50
-rw-r--r--Swiften/Base/Debug.h8
-rw-r--r--Swiften/ChangeLog.md4
-rw-r--r--Swiften/Client/ClientOptions.h54
-rw-r--r--Swiften/Client/ClientSession.cpp184
-rw-r--r--Swiften/Client/ClientSession.h40
-rw-r--r--Swiften/Client/ClientSessionStanzaChannel.h2
-rw-r--r--Swiften/Client/CoreClient.cpp6
-rw-r--r--Swiften/Client/UnitTest/ClientSessionTest.cpp111
-rw-r--r--Swiften/Session/BasicSessionStream.cpp6
-rw-r--r--Swiften/Session/BasicSessionStream.h1
-rw-r--r--Swiften/Session/SessionStream.h1
-rw-r--r--Swiften/StreamStack/XMPPLayer.cpp1
-rw-r--r--Swiften/StreamStack/XMPPLayer.h1
16 files changed, 350 insertions, 125 deletions
diff --git a/BuildTools/FixIncludes.py b/BuildTools/FixIncludes.py
index d1b8268..8984944 100755
--- a/BuildTools/FixIncludes.py
+++ b/BuildTools/FixIncludes.py
@@ -19,7 +19,7 @@ c_stdlib_headers = Set(["assert.h", "limits.h", "signal.h", "stdlib.h", "ctyp
cpp_stdlib_headers = Set(["algorithm", "fstream", "list", "regex", "typeindex", "array", "functional", "locale", "set", "typeinfo", "atomic", "future", "map", "sstream", "type_traits", "bitset", "initializer_list", "memory", "stack", "unordered_map", "chrono", "iomanip", "mutex", "stdexcept", "unordered_set", "codecvt", "ios", "new", "streambuf", "utility", "complex", "iosfwd", "numeric", "string", "valarray", "condition_variable", "iostream", "ostream", "strstream", "vector", "deque", "istream", "queue", "system_error", "exception", "iterator", "random", "thread", "forward_list", "limits", "ratio", "tuple", "cassert", "ciso646", "csetjmp", "cstdio", "ctime", "cctype", "climits", "csignal", "cstdlib", "cwchar", "cerrno", "clocale", "cstdarg", "cstring", "cwctype", "cfloat", "cmath", "cstddef"])
class HeaderType:
- PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(14)
+ PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, SWIFTEN_BASE_DEBUG, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(15)
def findHeaderBlock(lines):
start = False
@@ -61,6 +61,9 @@ def fileNameToHeaderType(name):
if name.startswith("Q"):
return HeaderType.QT
+ if name.startswith("Swiften/Base/Debug.h"):
+ return HeaderType.SWIFTEN_BASE_DEBUG
+
if name.startswith("Swiften"):
return HeaderType.SWIFTEN
diff --git a/Swift/ChangeLog.md b/Swift/ChangeLog.md
index 475062c..4c88d88 100644
--- a/Swift/ChangeLog.md
+++ b/Swift/ChangeLog.md
@@ -2,6 +2,7 @@
---------------
- Fix UI layout issue for translations that require right-to-left (RTL) layout
- macOS releases are now code-signed with a key from Apple, so they can be run without Gatekeeper trust warnings
+- Handle sessions being closed by the server
4.0-beta2 ( 2016-07-20 )
------------------------
diff --git a/Swiften/Base/Debug.cpp b/Swiften/Base/Debug.cpp
index b59de35..b4245c3 100644
--- a/Swiften/Base/Debug.cpp
+++ b/Swiften/Base/Debug.cpp
@@ -10,6 +10,7 @@
#include <memory>
#include <Swiften/Client/ClientError.h>
+#include <Swiften/Client/ClientSession.h>
#include <Swiften/Serializer/PayloadSerializer.h>
#include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h>
#include <Swiften/Serializer/XMPPSerializer.h>
@@ -138,3 +139,52 @@ std::ostream& operator<<(std::ostream& os, Swift::Element* ele) {
os << "Element(Unknown)";
return os;
}
+
+std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state) {
+ using CS = Swift::ClientSession;
+ switch (state) {
+ case CS::State::Initial:
+ os << "ClientSession::State::Initial";
+ break;
+ case CS::State::WaitingForStreamStart:
+ os << "ClientSession::State::WaitingForStreamStart";
+ break;
+ case CS::State::Negotiating:
+ os << "ClientSession::State::Negotiating";
+ break;
+ case CS::State::Compressing:
+ os << "ClientSession::State::Compressing";
+ break;
+ case CS::State::WaitingForEncrypt:
+ os << "ClientSession::State::WaitingForEncrypt";
+ break;
+ case CS::State::Encrypting:
+ os << "ClientSession::State::Encrypting";
+ break;
+ case CS::State::WaitingForCredentials:
+ os << "ClientSession::State::WaitingForCredentials";
+ break;
+ case CS::State::Authenticating:
+ os << "ClientSession::State::Authenticating";
+ break;
+ case CS::State::EnablingSessionManagement:
+ os << "ClientSession::State::EnablingSessionManagement";
+ break;
+ case CS::State::BindingResource:
+ os << "ClientSession::State::BindingResource";
+ break;
+ case CS::State::StartingSession:
+ os << "ClientSession::State::StartingSession";
+ break;
+ case CS::State::Initialized:
+ os << "ClientSession::State::Initialized";
+ break;
+ case CS::State::Finishing:
+ os << "ClientSession::State::Finishing";
+ break;
+ case CS::State::Finished:
+ os << "ClientSession::State::Finished";
+ break;
+ }
+ return os;
+}
diff --git a/Swiften/Base/Debug.h b/Swiften/Base/Debug.h
index 18e7fb4..9dde74c 100644
--- a/Swiften/Base/Debug.h
+++ b/Swiften/Base/Debug.h
@@ -1,11 +1,15 @@
/*
- * Copyright (c) 2015 Isode Limited.
+ * Copyright (c) 2015-2016 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
+#pragma once
+
#include <iosfwd>
+#include <Swiften/Client/ClientSession.h>
+
namespace Swift {
class ClientError;
class Element;
@@ -18,3 +22,5 @@ namespace boost {
std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error);
std::ostream& operator<<(std::ostream& os, Swift::Element* ele);
+
+std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state);
diff --git a/Swiften/ChangeLog.md b/Swiften/ChangeLog.md
index 4d52c68..62cea5e 100644
--- a/Swiften/ChangeLog.md
+++ b/Swiften/ChangeLog.md
@@ -1,3 +1,7 @@
+4.0-in-progress
+---------------
+- Handle sessions being closed by the server
+
4.0-beta1 ( 2016-07-15 )
------------------------
- Moved code-base to C++11
diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h
index 3a93197..1a337b6 100644
--- a/Swiften/Client/ClientOptions.h
+++ b/Swiften/Client/ClientOptions.h
@@ -30,21 +30,7 @@ namespace Swift {
HTTPConnectProxy
};
- ClientOptions() :
- useStreamCompression(true),
- useTLS(UseTLSWhenAvailable),
- allowPLAINWithoutTLS(false),
- useStreamResumption(false),
- forgetPassword(false),
- useAcks(true),
- singleSignOn(false),
- manualHostname(""),
- manualPort(-1),
- proxyType(SystemConfiguredProxy),
- manualProxyHostname(""),
- manualProxyPort(-1),
- boshHTTPConnectProxyAuthID(""),
- boshHTTPConnectProxyAuthPassword("") {
+ ClientOptions() {
}
/**
@@ -52,14 +38,14 @@ namespace Swift {
*
* Default: true
*/
- bool useStreamCompression;
+ bool useStreamCompression = true;
/**
* Sets whether TLS encryption should be used.
*
* Default: UseTLSWhenAvailable
*/
- UseTLS useTLS;
+ UseTLS useTLS = UseTLSWhenAvailable;
/**
* Sets whether plaintext authentication is
@@ -67,14 +53,14 @@ namespace Swift {
*
* Default: false
*/
- bool allowPLAINWithoutTLS;
+ bool allowPLAINWithoutTLS = false;
/**
* Use XEP-196 stream resumption when available.
*
* Default: false
*/
- bool useStreamResumption;
+ bool useStreamResumption = false;
/**
* Forget the password once it's used.
@@ -84,69 +70,69 @@ namespace Swift {
*
* Default: false
*/
- bool forgetPassword;
+ bool forgetPassword = false;
/**
* Use XEP-0198 acks in the stream when available.
* Default: true
*/
- bool useAcks;
+ bool useAcks = true;
/**
* Use Single Sign On.
* Default: false
*/
- bool singleSignOn;
+ bool singleSignOn = false;
/**
* The hostname to connect to.
* Leave this empty for standard XMPP connection, based on the JID domain.
*/
- std::string manualHostname;
+ std::string manualHostname = "";
/**
* The port to connect to.
* Leave this to -1 to use the port discovered by SRV lookups, and 5222 as a
* fallback.
*/
- int manualPort;
+ int manualPort = -1;
/**
* The type of proxy to use for connecting to the XMPP
* server.
*/
- ProxyType proxyType;
+ ProxyType proxyType = SystemConfiguredProxy;
/**
* Override the system-configured proxy hostname.
*/
- std::string manualProxyHostname;
+ std::string manualProxyHostname = "";
/**
* Override the system-configured proxy port.
*/
- int manualProxyPort;
+ int manualProxyPort = -1;
/**
* If non-empty, use BOSH instead of direct TCP, with the given URL.
* Default: empty (no BOSH)
*/
- URL boshURL;
+ URL boshURL = URL();
/**
* If non-empty, BOSH connections will try to connect over this HTTP CONNECT
* proxy instead of directly.
* Default: empty (no proxy)
*/
- URL boshHTTPConnectProxyURL;
+ URL boshHTTPConnectProxyURL = URL();
/**
* If this and matching Password are non-empty, BOSH connections over
* HTTP CONNECT proxies will use these credentials for proxy access.
* Default: empty (no authentication needed by the proxy)
*/
- SafeString boshHTTPConnectProxyAuthID;
- SafeString boshHTTPConnectProxyAuthPassword;
+ SafeString boshHTTPConnectProxyAuthID = SafeString("");
+ SafeString boshHTTPConnectProxyAuthPassword = SafeString("");
/**
* This can be initialized with a custom HTTPTrafficFilter, which allows HTTP CONNECT
@@ -158,5 +144,11 @@ namespace Swift {
* Options passed to the TLS stack
*/
TLSOptions tlsOptions;
+
+ /**
+ * Session shutdown timeout in milliseconds. This is the maximum time Swiften
+ * waits from a session close to the socket close.
+ */
+ int sessionShutdownTimeoutInMilliseconds = 10000;
};
}
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index c301881..bcfb004 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -6,42 +6,47 @@
#include <Swiften/Client/ClientSession.h>
+#include <memory>
+
#include <boost/bind.hpp>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <boost/uuid/uuid_generators.hpp>
-#include <memory>
-#include <Swiften/Base/Platform.h>
#include <Swiften/Base/Log.h>
-#include <Swiften/Elements/ProtocolHeader.h>
-#include <Swiften/Elements/StreamFeatures.h>
-#include <Swiften/Elements/StreamError.h>
-#include <Swiften/Elements/StartTLSRequest.h>
-#include <Swiften/Elements/StartTLSFailure.h>
-#include <Swiften/Elements/TLSProceed.h>
-#include <Swiften/Elements/AuthRequest.h>
-#include <Swiften/Elements/AuthSuccess.h>
-#include <Swiften/Elements/AuthFailure.h>
+#include <Swiften/Base/Platform.h>
+#include <Swiften/Crypto/CryptoProvider.h>
#include <Swiften/Elements/AuthChallenge.h>
+#include <Swiften/Elements/AuthFailure.h>
+#include <Swiften/Elements/AuthRequest.h>
#include <Swiften/Elements/AuthResponse.h>
-#include <Swiften/Elements/Compressed.h>
+#include <Swiften/Elements/AuthSuccess.h>
#include <Swiften/Elements/CompressFailure.h>
#include <Swiften/Elements/CompressRequest.h>
+#include <Swiften/Elements/Compressed.h>
#include <Swiften/Elements/EnableStreamManagement.h>
-#include <Swiften/Elements/StreamManagementEnabled.h>
-#include <Swiften/Elements/StreamManagementFailed.h>
-#include <Swiften/Elements/StartSession.h>
-#include <Swiften/Elements/StanzaAck.h>
-#include <Swiften/Elements/StanzaAckRequest.h>
#include <Swiften/Elements/IQ.h>
+#include <Swiften/Elements/ProtocolHeader.h>
#include <Swiften/Elements/ResourceBind.h>
-#include <Swiften/SASL/PLAINClientAuthenticator.h>
+#include <Swiften/Elements/StanzaAck.h>
+#include <Swiften/Elements/StanzaAckRequest.h>
+#include <Swiften/Elements/StartSession.h>
+#include <Swiften/Elements/StartTLSFailure.h>
+#include <Swiften/Elements/StartTLSRequest.h>
+#include <Swiften/Elements/StreamError.h>
+#include <Swiften/Elements/StreamFeatures.h>
+#include <Swiften/Elements/StreamManagementEnabled.h>
+#include <Swiften/Elements/StreamManagementFailed.h>
+#include <Swiften/Elements/TLSProceed.h>
+#include <Swiften/Network/Timer.h>
+#include <Swiften/Network/TimerFactory.h>
+#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h>
#include <Swiften/SASL/EXTERNALClientAuthenticator.h>
+#include <Swiften/SASL/PLAINClientAuthenticator.h>
#include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h>
-#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h>
-#include <Swiften/Crypto/CryptoProvider.h>
#include <Swiften/Session/SessionStream.h>
+#include <Swiften/StreamManagement/StanzaAckRequester.h>
+#include <Swiften/StreamManagement/StanzaAckResponder.h>
#include <Swiften/TLS/CertificateTrustChecker.h>
#include <Swiften/TLS/ServerIdentityVerifier.h>
@@ -59,12 +64,14 @@ ClientSession::ClientSession(
const JID& jid,
std::shared_ptr<SessionStream> stream,
IDNConverter* idnConverter,
- CryptoProvider* crypto) :
+ CryptoProvider* crypto,
+ TimerFactory* timerFactory) :
localJID(jid),
- state(Initial),
+ state(State::Initial),
stream(stream),
idnConverter(idnConverter),
crypto(crypto),
+ timerFactory(timerFactory),
allowPLAINOverNonTLS(false),
useStreamCompression(true),
useTLS(UseTLSWhenAvailable),
@@ -89,12 +96,13 @@ ClientSession::~ClientSession() {
void ClientSession::start() {
stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1));
+ stream->onStreamEndReceived.connect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this()));
stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1));
stream->onClosed.connect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1));
stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this()));
- assert(state == Initial);
- state = WaitingForStreamStart;
+ assert(state == State::Initial);
+ state = State::WaitingForStreamStart;
sendStreamHeader();
}
@@ -112,8 +120,19 @@ void ClientSession::sendStanza(std::shared_ptr<Stanza> stanza) {
}
void ClientSession::handleStreamStart(const ProtocolHeader&) {
- CHECK_STATE_OR_RETURN(WaitingForStreamStart);
- state = Negotiating;
+ CHECK_STATE_OR_RETURN(State::WaitingForStreamStart);
+ state = State::Negotiating;
+}
+
+void ClientSession::handleStreamEnd() {
+ if (state == State::Finishing) {
+ // We are already in finishing state if we iniated the close of the session.
+ stream->close();
+ }
+ else {
+ state = State::Finishing;
+ initiateShutdown(true);
+ }
}
void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
@@ -121,11 +140,11 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
if (stanzaAckResponder_) {
stanzaAckResponder_->handleStanzaReceived();
}
- if (getState() == Initialized) {
+ if (getState() == State::Initialized) {
onStanzaReceived(stanza);
}
else if (std::shared_ptr<IQ> iq = std::dynamic_pointer_cast<IQ>(element)) {
- if (state == BindingResource) {
+ if (state == State::BindingResource) {
std::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());
if (iq->getType() == IQ::Error && iq->getID() == "session-bind") {
finishSession(Error::ResourceBindError);
@@ -145,7 +164,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
finishSession(Error::UnexpectedElementError);
}
}
- else if (state == StartingSession) {
+ else if (state == State::StartingSession) {
if (iq->getType() == IQ::Result) {
needSessionStart = false;
continueSessionInitialization();
@@ -183,7 +202,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
else if (StreamError::ref streamError = std::dynamic_pointer_cast<StreamError>(element)) {
finishSession(Error::StreamError);
}
- else if (getState() == Initialized) {
+ else if (getState() == State::Initialized) {
std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element);
if (stanza) {
if (stanzaAckResponder_) {
@@ -193,17 +212,17 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
}
}
else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) {
- CHECK_STATE_OR_RETURN(Negotiating);
+ CHECK_STATE_OR_RETURN(State::Negotiating);
if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) {
- state = WaitingForEncrypt;
+ state = State::WaitingForEncrypt;
stream->writeElement(std::make_shared<StartTLSRequest>());
}
else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) {
finishSession(Error::NoSupportedAuthMechanismsError);
}
else if (useStreamCompression && stream->supportsZLibCompression() && streamFeatures->hasCompressionMethod("zlib")) {
- state = Compressing;
+ state = State::Compressing;
stream->writeElement(std::make_shared<CompressRequest>("zlib"));
}
else if (streamFeatures->hasAuthenticationMechanisms()) {
@@ -222,7 +241,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
authenticator = gssapiAuthenticator;
if (!gssapiAuthenticator->isError()) {
- state = Authenticating;
+ state = State::Authenticating;
stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse()));
}
else {
@@ -236,7 +255,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
if (stream->hasTLSCertificate()) {
if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
authenticator = new EXTERNALClientAuthenticator();
- state = Authenticating;
+ state = State::Authenticating;
stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
}
else {
@@ -245,7 +264,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
}
else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
authenticator = new EXTERNALClientAuthenticator();
- state = Authenticating;
+ state = State::Authenticating;
stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
}
else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) {
@@ -262,12 +281,12 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
scramAuthenticator->setTLSChannelBindingData(finishMessage);
}
authenticator = scramAuthenticator;
- state = WaitingForCredentials;
+ state = State::WaitingForCredentials;
onNeedCredentials();
}
else if ((stream->isTLSEncrypted() || allowPLAINOverNonTLS) && streamFeatures->hasAuthenticationMechanism("PLAIN")) {
authenticator = new PLAINClientAuthenticator();
- state = WaitingForCredentials;
+ state = State::WaitingForCredentials;
onNeedCredentials();
}
else if (streamFeatures->hasAuthenticationMechanism("DIGEST-MD5") && crypto->isMD5AllowedForCrypto()) {
@@ -275,7 +294,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
s << boost::uuids::random_generator()();
// FIXME: Host should probably be the actual host
authenticator = new DIGESTMD5ClientAuthenticator(localJID.getDomain(), s.str(), crypto);
- state = WaitingForCredentials;
+ state = State::WaitingForCredentials;
onNeedCredentials();
}
else {
@@ -299,8 +318,8 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
}
}
else if (std::dynamic_pointer_cast<Compressed>(element)) {
- CHECK_STATE_OR_RETURN(Compressing);
- state = WaitingForStreamStart;
+ CHECK_STATE_OR_RETURN(State::Compressing);
+ state = State::WaitingForStreamStart;
stream->addZLibCompression();
stream->resetXMPPParser();
sendStreamHeader();
@@ -322,7 +341,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
continueSessionInitialization();
}
else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) {
- CHECK_STATE_OR_RETURN(Authenticating);
+ CHECK_STATE_OR_RETURN(State::Authenticating);
assert(authenticator);
if (authenticator->setChallenge(challenge->getValue())) {
stream->writeElement(std::make_shared<AuthResponse>(authenticator->getResponse()));
@@ -340,13 +359,13 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
}
}
else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) {
- CHECK_STATE_OR_RETURN(Authenticating);
+ CHECK_STATE_OR_RETURN(State::Authenticating);
assert(authenticator);
if (!authenticator->setChallenge(authSuccess->getValue())) {
finishSession(Error::ServerVerificationFailedError);
}
else {
- state = WaitingForStreamStart;
+ state = State::WaitingForStreamStart;
delete authenticator;
authenticator = nullptr;
stream->resetXMPPParser();
@@ -357,8 +376,8 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
finishSession(Error::AuthenticationFailedError);
}
else if (dynamic_cast<TLSProceed*>(element.get())) {
- CHECK_STATE_OR_RETURN(WaitingForEncrypt);
- state = Encrypting;
+ CHECK_STATE_OR_RETURN(State::WaitingForEncrypt);
+ state = State::Encrypting;
stream->addTLSEncryption();
}
else if (dynamic_cast<StartTLSFailure*>(element.get())) {
@@ -366,14 +385,14 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
}
else {
// FIXME Not correct?
- state = Initialized;
+ state = State::Initialized;
onInitialized();
}
}
void ClientSession::continueSessionInitialization() {
if (needResourceBind) {
- state = BindingResource;
+ state = State::BindingResource;
std::shared_ptr<ResourceBind> resourceBind(std::make_shared<ResourceBind>());
if (!localJID.getResource().empty()) {
resourceBind->setResource(localJID.getResource());
@@ -381,15 +400,15 @@ void ClientSession::continueSessionInitialization() {
sendStanza(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
}
else if (needAcking) {
- state = EnablingSessionManagement;
+ state = State::EnablingSessionManagement;
stream->writeElement(std::make_shared<EnableStreamManagement>());
}
else if (needSessionStart) {
- state = StartingSession;
+ state = State::StartingSession;
sendStanza(IQ::createRequest(IQ::Set, JID(), "session-start", std::make_shared<StartSession>()));
}
else {
- state = Initialized;
+ state = State::Initialized;
onInitialized();
}
}
@@ -403,15 +422,15 @@ bool ClientSession::checkState(State state) {
}
void ClientSession::sendCredentials(const SafeByteArray& password) {
- assert(WaitingForCredentials);
+ assert(state == State::WaitingForCredentials);
assert(authenticator);
- state = Authenticating;
+ state = State::Authenticating;
authenticator->setCredentials(localJID.getNode(), password);
stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse()));
}
void ClientSession::handleTLSEncrypted() {
- CHECK_STATE_OR_RETURN(Encrypting);
+ CHECK_STATE_OR_RETURN(State::Encrypting);
std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain();
std::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError();
@@ -438,15 +457,38 @@ void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& cert
}
}
+void ClientSession::initiateShutdown(bool sendFooter) {
+ if (!streamShutdownTimeout) {
+ streamShutdownTimeout = timerFactory->createTimer(sessionShutdownTimeoutInMilliseconds);
+ streamShutdownTimeout->onTick.connect(boost::bind(&ClientSession::handleStreamShutdownTimeout, shared_from_this()));
+ streamShutdownTimeout->start();
+ }
+ if (sendFooter) {
+ stream->writeFooter();
+ }
+ if (state == State::Finishing) {
+ // The other side already send </stream>; we can close the socket.
+ stream->close();
+ }
+ else {
+ state = State::Finishing;
+ }
+}
+
void ClientSession::continueAfterTLSEncrypted() {
- state = WaitingForStreamStart;
+ state = State::WaitingForStreamStart;
stream->resetXMPPParser();
sendStreamHeader();
}
void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError) {
State previousState = state;
- state = Finished;
+ state = State::Finished;
+
+ if (streamShutdownTimeout) {
+ streamShutdownTimeout->stop();
+ streamShutdownTimeout.reset();
+ }
if (stanzaAckRequester_) {
stanzaAckRequester_->onRequestAck.disconnect(boost::bind(&ClientSession::requestAck, shared_from_this()));
@@ -459,11 +501,12 @@ void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError
}
stream->setWhitespacePingEnabled(false);
stream->onStreamStartReceived.disconnect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1));
+ stream->onStreamEndReceived.disconnect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this()));
stream->onElementReceived.disconnect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1));
stream->onClosed.disconnect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1));
stream->onTLSEncrypted.disconnect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this()));
- if (previousState == Finishing) {
+ if (previousState == State::Finishing) {
onFinished(error_);
}
else {
@@ -471,8 +514,17 @@ void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError
}
}
+void ClientSession::handleStreamShutdownTimeout() {
+ handleStreamClosed(std::shared_ptr<Swift::Error>());
+}
+
void ClientSession::finish() {
- finishSession(std::shared_ptr<Error>());
+ if (state != State::Finishing && state != State::Finished) {
+ finishSession(std::shared_ptr<Error>());
+ }
+ else {
+ SWIFT_LOG(warning) << "Session already finished or finishing." << std::endl;
+ }
}
void ClientSession::finishSession(Error::Type error) {
@@ -480,7 +532,6 @@ void ClientSession::finishSession(Error::Type error) {
}
void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) {
- state = Finishing;
if (!error_) {
error_ = error;
}
@@ -495,8 +546,19 @@ void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) {
delete authenticator;
authenticator = nullptr;
}
- stream->writeFooter();
- stream->close();
+ // Immidiately close TCP connection without stream closure.
+ if (std::dynamic_pointer_cast<CertificateVerificationError>(error)) {
+ state = State::Finishing;
+ initiateShutdown(false);
+ }
+ else {
+ if (state == State::Finishing) {
+ initiateShutdown(true);
+ }
+ else if (state != State::Finished) {
+ initiateShutdown(true);
+ }
+ }
}
void ClientSession::requestAck() {
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index ad1b78c..c7b3658 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -16,18 +16,21 @@
#include <Swiften/Elements/ToplevelElement.h>
#include <Swiften/JID/JID.h>
#include <Swiften/Session/SessionStream.h>
-#include <Swiften/StreamManagement/StanzaAckRequester.h>
-#include <Swiften/StreamManagement/StanzaAckResponder.h>
namespace Swift {
- class ClientAuthenticator;
class CertificateTrustChecker;
- class IDNConverter;
+ class ClientAuthenticator;
class CryptoProvider;
+ class IDNConverter;
+ class Stanza;
+ class StanzaAckRequester;
+ class StanzaAckResponder;
+ class TimerFactory;
+ class Timer;
class SWIFTEN_API ClientSession : public std::enable_shared_from_this<ClientSession> {
public:
- enum State {
+ enum class State {
Initial,
WaitingForStreamStart,
Negotiating,
@@ -55,7 +58,8 @@ namespace Swift {
SessionStartError,
TLSClientCertificateError,
TLSError,
- StreamError
+ StreamError,
+ StreamEndError, // The server send a closing stream tag.
} type;
std::shared_ptr<boost::system::error_code> errorCode;
Error(Type type) : type(type) {}
@@ -69,8 +73,8 @@ namespace Swift {
~ClientSession();
- static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto) {
- return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto));
+ static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto, TimerFactory* timerFactory) {
+ return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto, timerFactory));
}
State getState() const {
@@ -93,7 +97,6 @@ namespace Swift {
useAcks = b;
}
-
bool getStreamManagementEnabled() const {
// Explicitly convert to bool. In C++11, it would be cleaner to
// compare to nullptr.
@@ -116,7 +119,7 @@ namespace Swift {
void finish();
bool isFinished() const {
- return getState() == Finished;
+ return getState() == State::Finished;
}
void sendCredentials(const SafeByteArray& password);
@@ -138,6 +141,10 @@ namespace Swift {
authenticationPort = i;
}
+ void setSessionShutdownTimeout(int timeoutInMilliseconds) {
+ sessionShutdownTimeoutInMilliseconds = timeoutInMilliseconds;
+ }
+
public:
boost::signals2::signal<void ()> onNeedCredentials;
boost::signals2::signal<void ()> onInitialized;
@@ -150,7 +157,8 @@ namespace Swift {
const JID& jid,
std::shared_ptr<SessionStream>,
IDNConverter* idnConverter,
- CryptoProvider* crypto);
+ CryptoProvider* crypto,
+ TimerFactory* timerFactory);
void finishSession(Error::Type error);
void finishSession(std::shared_ptr<Swift::Error> error);
@@ -163,7 +171,9 @@ namespace Swift {
void handleElement(std::shared_ptr<ToplevelElement>);
void handleStreamStart(const ProtocolHeader&);
+ void handleStreamEnd();
void handleStreamClosed(std::shared_ptr<Swift::Error>);
+ void handleStreamShutdownTimeout();
void handleTLSEncrypted();
@@ -175,13 +185,17 @@ namespace Swift {
void ack(unsigned int handledStanzasCount);
void continueAfterTLSEncrypted();
void checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error);
+ void initiateShutdown(bool sendFooter);
private:
JID localJID;
State state;
std::shared_ptr<SessionStream> stream;
- IDNConverter* idnConverter;
- CryptoProvider* crypto;
+ IDNConverter* idnConverter = nullptr;
+ CryptoProvider* crypto = nullptr;
+ TimerFactory* timerFactory = nullptr;
+ std::shared_ptr<Timer> streamShutdownTimeout;
+ int sessionShutdownTimeoutInMilliseconds = 10000;
bool allowPLAINOverNonTLS;
bool useStreamCompression;
UseTLS useTLS;
diff --git a/Swiften/Client/ClientSessionStanzaChannel.h b/Swiften/Client/ClientSessionStanzaChannel.h
index 0527a5c..c4ee393 100644
--- a/Swiften/Client/ClientSessionStanzaChannel.h
+++ b/Swiften/Client/ClientSessionStanzaChannel.h
@@ -33,7 +33,7 @@ namespace Swift {
virtual std::vector<Certificate::ref> getPeerCertificateChain() const;
bool isAvailable() const {
- return session && session->getState() == ClientSession::Initialized;
+ return session && session->getState() == ClientSession::State::Initialized;
}
private:
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index 3d75d8b..3c7902e 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -154,12 +154,13 @@ void CoreClient::connect(const ClientOptions& o) {
}
void CoreClient::bindSessionToStream() {
- session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider());
+ session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider(), networkFactories->getTimerFactory());
session_->setCertificateTrustChecker(certificateTrustChecker);
session_->setUseStreamCompression(options.useStreamCompression);
session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS);
session_->setSingleSignOn(options.singleSignOn);
session_->setAuthenticationPort(options.manualPort);
+ session_->setSessionShutdownTimeout(options.sessionShutdownTimeoutInMilliseconds);
switch(options.useTLS) {
case ClientOptions::UseTLSWhenAvailable:
session_->setUseTLS(ClientSession::UseTLSWhenAvailable);
@@ -273,6 +274,9 @@ void CoreClient::handleSessionFinished(std::shared_ptr<Error> error) {
case ClientSession::Error::StreamError:
clientError = ClientError(ClientError::StreamError);
break;
+ case ClientSession::Error::StreamEndError:
+ clientError = ClientError(ClientError::StreamError);
+ break;
}
clientError.setErrorCode(actualError->errorCode);
}
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index bd93f4b..44df961 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -10,6 +10,8 @@
#include <boost/bind.hpp>
#include <boost/optional.hpp>
+#include <Swiften/Base/Debug.h>
+
#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
@@ -34,6 +36,7 @@
#include <Swiften/Elements/TLSProceed.h>
#include <Swiften/IDN/IDNConverter.h>
#include <Swiften/IDN/PlatformIDNConverter.h>
+#include <Swiften/Network/DummyTimerFactory.h>
#include <Swiften/Session/SessionStream.h>
#include <Swiften/TLS/BlindCertificateTrustChecker.h>
#include <Swiften/TLS/SimpleCertificate.h>
@@ -59,6 +62,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
CPPUNIT_TEST(testStreamManagement_Failed);
CPPUNIT_TEST(testUnexpectedChallenge);
CPPUNIT_TEST(testFinishAcksStanzas);
+
+ CPPUNIT_TEST(testServerInitiatedSessionClose);
+ CPPUNIT_TEST(testClientInitiatedSessionClose);
+ CPPUNIT_TEST(testTimeoutOnShutdown);
+
/*
CPPUNIT_TEST(testResourceBind);
CPPUNIT_TEST(testResourceBind_ChangeResource);
@@ -77,6 +85,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
void setUp() {
crypto = std::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create());
idnConverter = std::shared_ptr<IDNConverter>(PlatformIDNConverter::create());
+ timerFactory = std::make_shared<DummyTimerFactory>();
server = std::make_shared<MockSessionStream>();
sessionFinishedReceived = false;
needCredentials = false;
@@ -92,7 +101,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
session->start();
server->breakConnection();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -103,7 +112,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendStreamStart();
server->sendStreamError();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -136,7 +149,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendTLSFailure();
CPPUNIT_ASSERT(!server->tlsEncrypted);
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -151,7 +168,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendTLSProceed();
server->breakTLS();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -167,10 +184,12 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendTLSProceed();
CPPUNIT_ASSERT(server->tlsEncrypted);
server->onTLSEncrypted();
+ server->close();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
+ CPPUNIT_ASSERT(std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError));
CPPUNIT_ASSERT_EQUAL(CertificateVerificationError::InvalidServerIdentity, std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)->getType());
}
@@ -181,7 +200,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendStreamStart();
server->sendEmptyStreamFeatures();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -193,7 +216,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
CPPUNIT_ASSERT(needCredentials);
- CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState());
session->sendCredentials(createSafeByteArray("mypass"));
server->receiveAuthRequest("PLAIN");
server->sendAuthSuccess();
@@ -209,12 +232,16 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
CPPUNIT_ASSERT(needCredentials);
- CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState());
session->sendCredentials(createSafeByteArray("mypass"));
server->receiveAuthRequest("PLAIN");
server->sendAuthFailure();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -227,7 +254,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -240,8 +271,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithMultipleAuthentication();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -253,7 +287,12 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendStreamStart();
server->sendStreamFeaturesWithUnknownAuthentication();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ server->close();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -281,7 +320,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendChallenge();
server->sendChallenge();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
@@ -305,7 +348,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
CPPUNIT_ASSERT(session->getStreamManagementEnabled());
// TODO: Test if the requesters & responders do their work
- CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState());
session->finish();
}
@@ -328,7 +371,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->sendStreamManagementFailed();
CPPUNIT_ASSERT(!session->getStreamManagementEnabled());
- CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState());
session->finish();
}
@@ -345,9 +388,44 @@ class ClientSessionTest : public CppUnit::TestFixture {
server->receiveAck(3);
}
+ void testServerInitiatedSessionClose() {
+ std::shared_ptr<ClientSession> session(createSession());
+ initializeSession(session);
+
+ server->onStreamEndReceived();
+ server->close();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+ }
+
+ void testClientInitiatedSessionClose() {
+ std::shared_ptr<ClientSession> session(createSession());
+ initializeSession(session);
+
+ session->finish();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
+ }
+
+ void testTimeoutOnShutdown() {
+ std::shared_ptr<ClientSession> session(createSession());
+ initializeSession(session);
+
+ session->finish();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+ timerFactory->setTime(60000);
+
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
+ CPPUNIT_ASSERT(sessionFinishedReceived);
+ }
+
private:
std::shared_ptr<ClientSession> createSession() {
- std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get());
+ std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get(), timerFactory.get());
session->onFinished.connect(boost::bind(&ClientSessionTest::handleSessionFinished, this, _1));
session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::handleSessionNeedCredentials, this));
session->setAllowPLAINOverNonTLS(true);
@@ -631,6 +709,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
std::shared_ptr<Error> sessionFinishedError;
BlindCertificateTrustChecker* blindCertificateTrustChecker;
std::shared_ptr<CryptoProvider> crypto;
+ std::shared_ptr<DummyTimerFactory> timerFactory;
};
CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest);
diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp
index 402f642..10c6ad0 100644
--- a/Swiften/Session/BasicSessionStream.cpp
+++ b/Swiften/Session/BasicSessionStream.cpp
@@ -40,6 +40,7 @@ BasicSessionStream::BasicSessionStream(
tlsOptions_(tlsOptions) {
xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, streamType);
xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1));
+ xmppLayer->onStreamEnd.connect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this));
xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1));
xmppLayer->onError.connect(boost::bind(&BasicSessionStream::handleXMPPError, this));
xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, this, _1));
@@ -68,6 +69,7 @@ BasicSessionStream::~BasicSessionStream() {
delete connectionLayer;
xmppLayer->onStreamStart.disconnect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1));
+ xmppLayer->onStreamEnd.disconnect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this));
xmppLayer->onElement.disconnect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1));
xmppLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleXMPPError, this));
xmppLayer->onDataRead.disconnect(boost::bind(&BasicSessionStream::handleDataRead, this, _1));
@@ -171,6 +173,10 @@ void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header)
onStreamStartReceived(header);
}
+void BasicSessionStream::handleStreamEndReceived() {
+ onStreamEndReceived();
+}
+
void BasicSessionStream::handleElementReceived(std::shared_ptr<ToplevelElement> element) {
onElementReceived(element);
}
diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h
index 1806cef..48b3d63 100644
--- a/Swiften/Session/BasicSessionStream.h
+++ b/Swiften/Session/BasicSessionStream.h
@@ -73,6 +73,7 @@ namespace Swift {
void handleTLSConnected();
void handleTLSError(std::shared_ptr<TLSError>);
void handleStreamStartReceived(const ProtocolHeader&);
+ void handleStreamEndReceived();
void handleElementReceived(std::shared_ptr<ToplevelElement>);
void handleDataRead(const SafeByteArray& data);
void handleDataWritten(const SafeByteArray& data);
diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h
index f56f495..c5ec42a 100644
--- a/Swiften/Session/SessionStream.h
+++ b/Swiften/Session/SessionStream.h
@@ -75,6 +75,7 @@ namespace Swift {
virtual ByteArray getTLSFinishMessage() const = 0;
boost::signals2::signal<void (const ProtocolHeader&)> onStreamStartReceived;
+ boost::signals2::signal<void ()> onStreamEndReceived;
boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElementReceived;
boost::signals2::signal<void (std::shared_ptr<Error>)> onClosed;
boost::signals2::signal<void ()> onTLSEncrypted;
diff --git a/Swiften/StreamStack/XMPPLayer.cpp b/Swiften/StreamStack/XMPPLayer.cpp
index 2eed906..982d13f 100644
--- a/Swiften/StreamStack/XMPPLayer.cpp
+++ b/Swiften/StreamStack/XMPPLayer.cpp
@@ -86,6 +86,7 @@ void XMPPLayer::handleElement(std::shared_ptr<ToplevelElement> stanza) {
}
void XMPPLayer::handleStreamEnd() {
+ onStreamEnd();
}
void XMPPLayer::resetParser() {
diff --git a/Swiften/StreamStack/XMPPLayer.h b/Swiften/StreamStack/XMPPLayer.h
index 1d4abf8..f0b5afb 100644
--- a/Swiften/StreamStack/XMPPLayer.h
+++ b/Swiften/StreamStack/XMPPLayer.h
@@ -51,6 +51,7 @@ namespace Swift {
public:
boost::signals2::signal<void (const ProtocolHeader&)> onStreamStart;
+ boost::signals2::signal<void ()> onStreamEnd;
boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElement;
boost::signals2::signal<void (const SafeByteArray&)> onWriteData;
boost::signals2::signal<void (const SafeByteArray&)> onDataRead;