summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xBuildTools/FixIncludes.py5
-rw-r--r--Swift/ChangeLog.md1
-rw-r--r--Swiften/Base/Debug.cpp50
-rw-r--r--Swiften/Base/Debug.h8
-rw-r--r--Swiften/ChangeLog.md4
-rw-r--r--Swiften/Client/ClientOptions.h54
-rw-r--r--Swiften/Client/ClientSession.cpp184
-rw-r--r--Swiften/Client/ClientSession.h40
-rw-r--r--Swiften/Client/ClientSessionStanzaChannel.h2
-rw-r--r--Swiften/Client/CoreClient.cpp6
-rw-r--r--Swiften/Client/UnitTest/ClientSessionTest.cpp111
-rw-r--r--Swiften/Session/BasicSessionStream.cpp6
-rw-r--r--Swiften/Session/BasicSessionStream.h1
-rw-r--r--Swiften/Session/SessionStream.h1
-rw-r--r--Swiften/StreamStack/XMPPLayer.cpp1
-rw-r--r--Swiften/StreamStack/XMPPLayer.h1
16 files changed, 350 insertions, 125 deletions
diff --git a/BuildTools/FixIncludes.py b/BuildTools/FixIncludes.py
index d1b8268..8984944 100755
--- a/BuildTools/FixIncludes.py
+++ b/BuildTools/FixIncludes.py
@@ -1,93 +1,96 @@
#!/usr/bin/env python
import sys;
import os;
import re;
from sets import Set
filename = sys.argv[1]
inPlace = False
if "-i" in sys.argv:
inPlace = True
filename_base = os.path.basename(filename)
(filename_name, filename_ext) = os.path.splitext(filename_base)
c_stdlib_headers = Set(["assert.h", "limits.h", "signal.h", "stdlib.h", "ctype.h", "locale.h", "stdarg.h", "string.h", "errno.h", "math.h", "stddef.h", "time.h", "float.h", "setjmp.h", "stdio.h", "iso646.h", "wchar.h", "wctype.h", "complex.h", "inttypes.h", "stdint.h", "tgmath.h", "fenv.h", "stdbool.h"])
cpp_stdlib_headers = Set(["algorithm", "fstream", "list", "regex", "typeindex", "array", "functional", "locale", "set", "typeinfo", "atomic", "future", "map", "sstream", "type_traits", "bitset", "initializer_list", "memory", "stack", "unordered_map", "chrono", "iomanip", "mutex", "stdexcept", "unordered_set", "codecvt", "ios", "new", "streambuf", "utility", "complex", "iosfwd", "numeric", "string", "valarray", "condition_variable", "iostream", "ostream", "strstream", "vector", "deque", "istream", "queue", "system_error", "exception", "iterator", "random", "thread", "forward_list", "limits", "ratio", "tuple", "cassert", "ciso646", "csetjmp", "cstdio", "ctime", "cctype", "climits", "csignal", "cstdlib", "cwchar", "cerrno", "clocale", "cstdarg", "cstring", "cwctype", "cfloat", "cmath", "cstddef"])
class HeaderType:
- PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(14)
+ PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, SWIFTEN_BASE_DEBUG, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(15)
def findHeaderBlock(lines):
start = False
end = False
lastLine = None
for idx, line in enumerate(lines):
if not start and line.startswith("#"):
start = idx
elif start and (not end) and (not line.startswith("#")) and line.strip():
end = idx-1
break
if not end:
end = len(lines)
return (start, end)
def lineToFileName(line):
match = re.match( r'#include "(.*)"', line)
if match:
return match.group(1)
match = re.match( r'#include <(.*)>', line)
if match:
return match.group(1)
return False
def fileNameToHeaderType(name):
if name.endswith("/" + filename_name + ".h"):
return HeaderType.CORRESPONDING_HEADER
if name in c_stdlib_headers:
return HeaderType.C_STDLIB
if name in cpp_stdlib_headers:
return HeaderType.CPP_STDLIB
if name.startswith("boost"):
return HeaderType.BOOST
if name.startswith("Q"):
return HeaderType.QT
+ if name.startswith("Swiften/Base/Debug.h"):
+ return HeaderType.SWIFTEN_BASE_DEBUG
+
if name.startswith("Swiften"):
return HeaderType.SWIFTEN
if name.startswith("Limber"):
return HeaderType.LIMBER
if name.startswith("Slimber"):
return HeaderType.SLIMBER
if name.startswith("Swift/Controllers"):
return HeaderType.SWIFT_CONTROLLERS
if name.startswith("Sluift"):
return HeaderType.SLUIFT
if name.startswith("SwifTools"):
return HeaderType.SWIFTOOLS
if name.startswith("Swift"):
return HeaderType.SWIFT
return HeaderType.OTHER
def serializeHeaderGroups(groups):
headerList = []
for group in range(0, HeaderType.SWIFT + 1):
if group in groups:
# sorted and without duplicates
headers = sorted(list(set(groups[group])))
headerList.extend(headers)
diff --git a/Swift/ChangeLog.md b/Swift/ChangeLog.md
index 475062c..4c88d88 100644
--- a/Swift/ChangeLog.md
+++ b/Swift/ChangeLog.md
@@ -1,34 +1,35 @@
4.0-in-progress
---------------
- Fix UI layout issue for translations that require right-to-left (RTL) layout
- macOS releases are now code-signed with a key from Apple, so they can be run without Gatekeeper trust warnings
+- Handle sessions being closed by the server
4.0-beta2 ( 2016-07-20 )
------------------------
- Fix Swift bug introduced in 4.0-beta1 that results in the UI sometimes getting stuck during login
4.0-beta1 ( 2016-07-15 )
------------------------
- Support for message carbons (XEP-0280)
- Improved spell checker support on Linux
- Enabled trellis mode as a default feature, allowing several tiled chats windows to be shown at once
- New chat theme including a new font
- And assorted smaller features and usability enhancements
3.0 ( 2016-02-24 )
------------------
- File transfer and Mac Notification Center issues fixed
- Fix connection to servers with invalid or untrusted certificates on OS/X
- Support for the Notification Center on OS X
- Users can now authenticate using certificates (and smart cards on Windows) when using the 'BOSH' connection type.
- Encryption on OS X now uses the platform's native 'Secure Transport' mechanisms.
- Emoticons menu in chat dialogs
- Bookmark for rooms can now be edited directly from the ‘Recent Chats’ list
- Adds option to workaround servers that don’t interoperate well with Windows (schannel) encryption
- Rooms entered while offline will now get entered on reconnect
- Chats can now be seamlessly upgraded to multi-person chats by either inviting someone via the ‘cog’ menu, or dragging them from the roster. This relies on server-side support with an appropriate chatroom (MUC) service.
- Highlighting of keywords and messages from particular users can now be configured (Keyword Highlighting Blog post).
- Full profile vcards (contact information etc.) are now supported and can be configured for the user and queried for contacts.
- Simple Communication Blocking is now supported (subject to server support) to allow the blocking of nuisance users.
- Swift can now transfer files via the ‘Jingle File Transfer’ protocol.
- The status setter will now remember previously set statuses and will allow quick access to these when the user types part of a recently used status.
diff --git a/Swiften/Base/Debug.cpp b/Swiften/Base/Debug.cpp
index b59de35..b4245c3 100644
--- a/Swiften/Base/Debug.cpp
+++ b/Swiften/Base/Debug.cpp
@@ -1,42 +1,43 @@
/*
* Copyright (c) 2015-2016 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
#include <Swiften/Base/Debug.h>
#include <iostream>
#include <memory>
#include <Swiften/Client/ClientError.h>
+#include <Swiften/Client/ClientSession.h>
#include <Swiften/Serializer/PayloadSerializer.h>
#include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h>
#include <Swiften/Serializer/XMPPSerializer.h>
std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error) {
os << "ClientError(";
switch(error.getType()) {
case Swift::ClientError::UnknownError:
os << "UnknownError";
break;
case Swift::ClientError::DomainNameResolveError:
os << "DomainNameResolveError";
break;
case Swift::ClientError::ConnectionError:
os << "ConnectionError";
break;
case Swift::ClientError::ConnectionReadError:
os << "ConnectionReadError";
break;
case Swift::ClientError::ConnectionWriteError:
os << "ConnectionWriteError";
break;
case Swift::ClientError::XMLError:
os << "XMLError";
break;
case Swift::ClientError::AuthenticationFailedError:
os << "AuthenticationFailedError";
break;
case Swift::ClientError::CompressionFailedError:
os << "CompressionFailedError";
@@ -111,30 +112,79 @@ std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error) {
os << "RevocationCheckFailedError";
break;
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, Swift::Element* ele) {
using namespace Swift;
std::shared_ptr<Element> element = std::shared_ptr<Element>(ele);
std::shared_ptr<Payload> payload = std::dynamic_pointer_cast<Payload>(element);
if (payload) {
FullPayloadSerializerCollection payloadSerializerCollection;
PayloadSerializer *serializer = payloadSerializerCollection.getPayloadSerializer(payload);
os << "Payload(" << serializer->serialize(payload) << ")";
return os;
}
std::shared_ptr<ToplevelElement> topLevelElement = std::dynamic_pointer_cast<ToplevelElement>(element);
if (topLevelElement) {
FullPayloadSerializerCollection payloadSerializerCollection;
XMPPSerializer xmppSerializer(&payloadSerializerCollection, ClientStreamType, false);
SafeByteArray serialized = xmppSerializer.serializeElement(topLevelElement);
os << "TopLevelElement(" << safeByteArrayToString(serialized) << ")";
return os;
}
os << "Element(Unknown)";
return os;
}
+
+std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state) {
+ using CS = Swift::ClientSession;
+ switch (state) {
+ case CS::State::Initial:
+ os << "ClientSession::State::Initial";
+ break;
+ case CS::State::WaitingForStreamStart:
+ os << "ClientSession::State::WaitingForStreamStart";
+ break;
+ case CS::State::Negotiating:
+ os << "ClientSession::State::Negotiating";
+ break;
+ case CS::State::Compressing:
+ os << "ClientSession::State::Compressing";
+ break;
+ case CS::State::WaitingForEncrypt:
+ os << "ClientSession::State::WaitingForEncrypt";
+ break;
+ case CS::State::Encrypting:
+ os << "ClientSession::State::Encrypting";
+ break;
+ case CS::State::WaitingForCredentials:
+ os << "ClientSession::State::WaitingForCredentials";
+ break;
+ case CS::State::Authenticating:
+ os << "ClientSession::State::Authenticating";
+ break;
+ case CS::State::EnablingSessionManagement:
+ os << "ClientSession::State::EnablingSessionManagement";
+ break;
+ case CS::State::BindingResource:
+ os << "ClientSession::State::BindingResource";
+ break;
+ case CS::State::StartingSession:
+ os << "ClientSession::State::StartingSession";
+ break;
+ case CS::State::Initialized:
+ os << "ClientSession::State::Initialized";
+ break;
+ case CS::State::Finishing:
+ os << "ClientSession::State::Finishing";
+ break;
+ case CS::State::Finished:
+ os << "ClientSession::State::Finished";
+ break;
+ }
+ return os;
+}
diff --git a/Swiften/Base/Debug.h b/Swiften/Base/Debug.h
index 18e7fb4..9dde74c 100644
--- a/Swiften/Base/Debug.h
+++ b/Swiften/Base/Debug.h
@@ -1,20 +1,26 @@
/*
- * Copyright (c) 2015 Isode Limited.
+ * Copyright (c) 2015-2016 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
+#pragma once
+
#include <iosfwd>
+#include <Swiften/Client/ClientSession.h>
+
namespace Swift {
class ClientError;
class Element;
}
namespace boost {
template<class T> class shared_ptr;
}
std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error);
std::ostream& operator<<(std::ostream& os, Swift::Element* ele);
+
+std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state);
diff --git a/Swiften/ChangeLog.md b/Swiften/ChangeLog.md
index 4d52c68..62cea5e 100644
--- a/Swiften/ChangeLog.md
+++ b/Swiften/ChangeLog.md
@@ -1,13 +1,17 @@
+4.0-in-progress
+---------------
+- Handle sessions being closed by the server
+
4.0-beta1 ( 2016-07-15 )
------------------------
- Moved code-base to C++11
- Use C++11 threading instead of Boost.Thread library
- Use C++11 smart pointers instead of Boost's
- Migrated from Boost.Signals to Boost.Signals2
- Build without warnings on our CI platforms
- General cleanup like remove of superflous files and #include statements. This means header files that previously were included implictly need to be explicitly included now
- Support IPv6 addresses in URLs
- Changed source code style to use soft tabs (4 spaces wide) instead of hard tabs. Custom patches for Swiften will need to be reformatted accordingly
- Require a TLS backend for building
- Update 3rdParty/lcov to version 1.12
- Fix several possible race conditions and other small bugs \ No newline at end of file
diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h
index 3a93197..1a337b6 100644
--- a/Swiften/Client/ClientOptions.h
+++ b/Swiften/Client/ClientOptions.h
@@ -3,160 +3,152 @@
* All rights reserved.
* See the COPYING file for more information.
*/
#pragma once
#include <memory>
#include <Swiften/Base/API.h>
#include <Swiften/Base/SafeString.h>
#include <Swiften/Base/URL.h>
#include <Swiften/TLS/TLSOptions.h>
namespace Swift {
class HTTPTrafficFilter;
struct SWIFTEN_API ClientOptions {
enum UseTLS {
NeverUseTLS,
UseTLSWhenAvailable,
RequireTLS
};
enum ProxyType {
NoProxy,
SystemConfiguredProxy,
SOCKS5Proxy,
HTTPConnectProxy
};
- ClientOptions() :
- useStreamCompression(true),
- useTLS(UseTLSWhenAvailable),
- allowPLAINWithoutTLS(false),
- useStreamResumption(false),
- forgetPassword(false),
- useAcks(true),
- singleSignOn(false),
- manualHostname(""),
- manualPort(-1),
- proxyType(SystemConfiguredProxy),
- manualProxyHostname(""),
- manualProxyPort(-1),
- boshHTTPConnectProxyAuthID(""),
- boshHTTPConnectProxyAuthPassword("") {
+ ClientOptions() {
}
/**
* Whether ZLib stream compression should be used when available.
*
* Default: true
*/
- bool useStreamCompression;
+ bool useStreamCompression = true;
/**
* Sets whether TLS encryption should be used.
*
* Default: UseTLSWhenAvailable
*/
- UseTLS useTLS;
+ UseTLS useTLS = UseTLSWhenAvailable;
/**
* Sets whether plaintext authentication is
* allowed over non-TLS-encrypted connections.
*
* Default: false
*/
- bool allowPLAINWithoutTLS;
+ bool allowPLAINWithoutTLS = false;
/**
* Use XEP-196 stream resumption when available.
*
* Default: false
*/
- bool useStreamResumption;
+ bool useStreamResumption = false;
/**
* Forget the password once it's used.
* This makes the Client useless after the first login attempt.
*
* FIXME: This is a temporary workaround.
*
* Default: false
*/
- bool forgetPassword;
+ bool forgetPassword = false;
/**
* Use XEP-0198 acks in the stream when available.
* Default: true
*/
- bool useAcks;
+ bool useAcks = true;
/**
* Use Single Sign On.
* Default: false
*/
- bool singleSignOn;
+ bool singleSignOn = false;
/**
* The hostname to connect to.
* Leave this empty for standard XMPP connection, based on the JID domain.
*/
- std::string manualHostname;
+ std::string manualHostname = "";
/**
* The port to connect to.
* Leave this to -1 to use the port discovered by SRV lookups, and 5222 as a
* fallback.
*/
- int manualPort;
+ int manualPort = -1;
/**
* The type of proxy to use for connecting to the XMPP
* server.
*/
- ProxyType proxyType;
+ ProxyType proxyType = SystemConfiguredProxy;
/**
* Override the system-configured proxy hostname.
*/
- std::string manualProxyHostname;
+ std::string manualProxyHostname = "";
/**
* Override the system-configured proxy port.
*/
- int manualProxyPort;
+ int manualProxyPort = -1;
/**
* If non-empty, use BOSH instead of direct TCP, with the given URL.
* Default: empty (no BOSH)
*/
- URL boshURL;
+ URL boshURL = URL();
/**
* If non-empty, BOSH connections will try to connect over this HTTP CONNECT
* proxy instead of directly.
* Default: empty (no proxy)
*/
- URL boshHTTPConnectProxyURL;
+ URL boshHTTPConnectProxyURL = URL();
/**
* If this and matching Password are non-empty, BOSH connections over
* HTTP CONNECT proxies will use these credentials for proxy access.
* Default: empty (no authentication needed by the proxy)
*/
- SafeString boshHTTPConnectProxyAuthID;
- SafeString boshHTTPConnectProxyAuthPassword;
+ SafeString boshHTTPConnectProxyAuthID = SafeString("");
+ SafeString boshHTTPConnectProxyAuthPassword = SafeString("");
/**
* This can be initialized with a custom HTTPTrafficFilter, which allows HTTP CONNECT
* proxy initialization to be customized.
*/
std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter;
/**
* Options passed to the TLS stack
*/
TLSOptions tlsOptions;
+
+ /**
+ * Session shutdown timeout in milliseconds. This is the maximum time Swiften
+ * waits from a session close to the socket close.
+ */
+ int sessionShutdownTimeoutInMilliseconds = 10000;
};
}
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index c301881..bcfb004 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -1,514 +1,576 @@
/*
* Copyright (c) 2010-2016 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
#include <Swiften/Client/ClientSession.h>
+#include <memory>
+
#include <boost/bind.hpp>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <boost/uuid/uuid_generators.hpp>
-#include <memory>
-#include <Swiften/Base/Platform.h>
#include <Swiften/Base/Log.h>
-#include <Swiften/Elements/ProtocolHeader.h>
-#include <Swiften/Elements/StreamFeatures.h>
-#include <Swiften/Elements/StreamError.h>
-#include <Swiften/Elements/StartTLSRequest.h>
-#include <Swiften/Elements/StartTLSFailure.h>
-#include <Swiften/Elements/TLSProceed.h>
-#include <Swiften/Elements/AuthRequest.h>
-#include <Swiften/Elements/AuthSuccess.h>
-#include <Swiften/Elements/AuthFailure.h>
+#include <Swiften/Base/Platform.h>
+#include <Swiften/Crypto/CryptoProvider.h>
#include <Swiften/Elements/AuthChallenge.h>
+#include <Swiften/Elements/AuthFailure.h>
+#include <Swiften/Elements/AuthRequest.h>
#include <Swiften/Elements/AuthResponse.h>
-#include <Swiften/Elements/Compressed.h>
+#include <Swiften/Elements/AuthSuccess.h>
#include <Swiften/Elements/CompressFailure.h>
#include <Swiften/Elements/CompressRequest.h>
+#include <Swiften/Elements/Compressed.h>
#include <Swiften/Elements/EnableStreamManagement.h>
-#include <Swiften/Elements/StreamManagementEnabled.h>
-#include <Swiften/Elements/StreamManagementFailed.h>
-#include <Swiften/Elements/StartSession.h>
-#include <Swiften/Elements/StanzaAck.h>
-#include <Swiften/Elements/StanzaAckRequest.h>
#include <Swiften/Elements/IQ.h>
+#include <Swiften/Elements/ProtocolHeader.h>
#include <Swiften/Elements/ResourceBind.h>
-#include <Swiften/SASL/PLAINClientAuthenticator.h>
+#include <Swiften/Elements/StanzaAck.h>
+#include <Swiften/Elements/StanzaAckRequest.h>
+#include <Swiften/Elements/StartSession.h>
+#include <Swiften/Elements/StartTLSFailure.h>
+#include <Swiften/Elements/StartTLSRequest.h>
+#include <Swiften/Elements/StreamError.h>
+#include <Swiften/Elements/StreamFeatures.h>
+#include <Swiften/Elements/StreamManagementEnabled.h>
+#include <Swiften/Elements/StreamManagementFailed.h>
+#include <Swiften/Elements/TLSProceed.h>
+#include <Swiften/Network/Timer.h>
+#include <Swiften/Network/TimerFactory.h>
+#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h>
#include <Swiften/SASL/EXTERNALClientAuthenticator.h>
+#include <Swiften/SASL/PLAINClientAuthenticator.h>
#include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h>
-#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h>
-#include <Swiften/Crypto/CryptoProvider.h>
#include <Swiften/Session/SessionStream.h>
+#include <Swiften/StreamManagement/StanzaAckRequester.h>
+#include <Swiften/StreamManagement/StanzaAckResponder.h>
#include <Swiften/TLS/CertificateTrustChecker.h>
#include <Swiften/TLS/ServerIdentityVerifier.h>
#ifdef SWIFTEN_PLATFORM_WIN32
#include <Swiften/Base/WindowsRegistry.h>
#include <Swiften/SASL/WindowsGSSAPIClientAuthenticator.h>
#endif
#define CHECK_STATE_OR_RETURN(a) \
if (!checkState(a)) { return; }
namespace Swift {
ClientSession::ClientSession(
const JID& jid,
std::shared_ptr<SessionStream> stream,
IDNConverter* idnConverter,
- CryptoProvider* crypto) :
+ CryptoProvider* crypto,
+ TimerFactory* timerFactory) :
localJID(jid),
- state(Initial),
+ state(State::Initial),
stream(stream),
idnConverter(idnConverter),
crypto(crypto),
+ timerFactory(timerFactory),
allowPLAINOverNonTLS(false),
useStreamCompression(true),
useTLS(UseTLSWhenAvailable),
useAcks(true),
needSessionStart(false),
needResourceBind(false),
needAcking(false),
rosterVersioningSupported(false),
authenticator(nullptr),
certificateTrustChecker(nullptr),
singleSignOn(false),
authenticationPort(-1) {
#ifdef SWIFTEN_PLATFORM_WIN32
if (WindowsRegistry::isFIPSEnabled()) {
SWIFT_LOG(info) << "Windows is running in FIPS-140 mode. Some authentication methods will be unavailable." << std::endl;
}
#endif
}
ClientSession::~ClientSession() {
}
void ClientSession::start() {
stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1));
+ stream->onStreamEndReceived.connect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this()));
stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1));
stream->onClosed.connect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1));
stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this()));
- assert(state == Initial);
- state = WaitingForStreamStart;
+ assert(state == State::Initial);
+ state = State::WaitingForStreamStart;
sendStreamHeader();
}
void ClientSession::sendStreamHeader() {
ProtocolHeader header;
header.setTo(getRemoteJID());
stream->writeHeader(header);
}
void ClientSession::sendStanza(std::shared_ptr<Stanza> stanza) {
stream->writeElement(stanza);
if (stanzaAckRequester_) {
stanzaAckRequester_->handleStanzaSent(stanza);
}
}
void ClientSession::handleStreamStart(const ProtocolHeader&) {
- CHECK_STATE_OR_RETURN(WaitingForStreamStart);
- state = Negotiating;
+ CHECK_STATE_OR_RETURN(State::WaitingForStreamStart);
+ state = State::Negotiating;
+}
+
+void ClientSession::handleStreamEnd() {
+ if (state == State::Finishing) {
+ // We are already in finishing state if we iniated the close of the session.
+ stream->close();
+ }
+ else {
+ state = State::Finishing;
+ initiateShutdown(true);
+ }
}
void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) {
if (std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element)) {
if (stanzaAckResponder_) {
stanzaAckResponder_->handleStanzaReceived();
}
- if (getState() == Initialized) {
+ if (getState() == State::Initialized) {
onStanzaReceived(stanza);
}
else if (std::shared_ptr<IQ> iq = std::dynamic_pointer_cast<IQ>(element)) {
- if (state == BindingResource) {
+ if (state == State::BindingResource) {
std::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());
if (iq->getType() == IQ::Error && iq->getID() == "session-bind") {
finishSession(Error::ResourceBindError);
}
else if (!resourceBind) {
finishSession(Error::UnexpectedElementError);
}
else if (iq->getType() == IQ::Result) {
localJID = resourceBind->getJID();
if (!localJID.isValid()) {
finishSession(Error::ResourceBindError);
}
needResourceBind = false;
continueSessionInitialization();
}
else {
finishSession(Error::UnexpectedElementError);
}
}
- else if (state == StartingSession) {
+ else if (state == State::StartingSession) {
if (iq->getType() == IQ::Result) {
needSessionStart = false;
continueSessionInitialization();
}
else if (iq->getType() == IQ::Error) {
finishSession(Error::SessionStartError);
}
else {
finishSession(Error::UnexpectedElementError);
}
}
else {
finishSession(Error::UnexpectedElementError);
}
}
}
else if (std::dynamic_pointer_cast<StanzaAckRequest>(element)) {
if (stanzaAckResponder_) {
stanzaAckResponder_->handleAckRequestReceived();
}
}
else if (std::shared_ptr<StanzaAck> ack = std::dynamic_pointer_cast<StanzaAck>(element)) {
if (stanzaAckRequester_) {
if (ack->isValid()) {
stanzaAckRequester_->handleAckReceived(ack->getHandledStanzasCount());
}
else {
SWIFT_LOG(warning) << "Got invalid ack from server";
}
}
else {
SWIFT_LOG(warning) << "Ignoring ack";
}
}
else if (StreamError::ref streamError = std::dynamic_pointer_cast<StreamError>(element)) {
finishSession(Error::StreamError);
}
- else if (getState() == Initialized) {
+ else if (getState() == State::Initialized) {
std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element);
if (stanza) {
if (stanzaAckResponder_) {
stanzaAckResponder_->handleStanzaReceived();
}
onStanzaReceived(stanza);
}
}
else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) {
- CHECK_STATE_OR_RETURN(Negotiating);
+ CHECK_STATE_OR_RETURN(State::Negotiating);
if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) {
- state = WaitingForEncrypt;
+ state = State::WaitingForEncrypt;
stream->writeElement(std::make_shared<StartTLSRequest>());
}
else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) {
finishSession(Error::NoSupportedAuthMechanismsError);
}
else if (useStreamCompression && stream->supportsZLibCompression() && streamFeatures->hasCompressionMethod("zlib")) {
- state = Compressing;
+ state = State::Compressing;
stream->writeElement(std::make_shared<CompressRequest>("zlib"));
}
else if (streamFeatures->hasAuthenticationMechanisms()) {
#ifdef SWIFTEN_PLATFORM_WIN32
if (singleSignOn) {
const boost::optional<std::string> authenticationHostname = streamFeatures->getAuthenticationHostname();
bool gssapiSupported = streamFeatures->hasAuthenticationMechanism("GSSAPI") && authenticationHostname && !authenticationHostname->empty();
if (!gssapiSupported) {
finishSession(Error::NoSupportedAuthMechanismsError);
}
else {
WindowsGSSAPIClientAuthenticator* gssapiAuthenticator = new WindowsGSSAPIClientAuthenticator(*authenticationHostname, localJID.getDomain(), authenticationPort);
std::shared_ptr<Error> error = std::make_shared<Error>(Error::AuthenticationFailedError);
authenticator = gssapiAuthenticator;
if (!gssapiAuthenticator->isError()) {
- state = Authenticating;
+ state = State::Authenticating;
stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse()));
}
else {
error->errorCode = gssapiAuthenticator->getErrorCode();
finishSession(error);
}
}
}
else
#endif
if (stream->hasTLSCertificate()) {
if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
authenticator = new EXTERNALClientAuthenticator();
- state = Authenticating;
+ state = State::Authenticating;
stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
}
else {
finishSession(Error::TLSClientCertificateError);
}
}
else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
authenticator = new EXTERNALClientAuthenticator();
- state = Authenticating;
+ state = State::Authenticating;
stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray("")));
}
else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) {
std::ostringstream s;
ByteArray finishMessage;
bool plus = streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS");
if (stream->isTLSEncrypted()) {
finishMessage = stream->getTLSFinishMessage();
plus &= !finishMessage.empty();
}
s << boost::uuids::random_generator()();
SCRAMSHA1ClientAuthenticator* scramAuthenticator = new SCRAMSHA1ClientAuthenticator(s.str(), plus, idnConverter, crypto);
if (!finishMessage.empty()) {
scramAuthenticator->setTLSChannelBindingData(finishMessage);
}
authenticator = scramAuthenticator;
- state = WaitingForCredentials;
+ state = State::WaitingForCredentials;
onNeedCredentials();
}
else if ((stream->isTLSEncrypted() || allowPLAINOverNonTLS) && streamFeatures->hasAuthenticationMechanism("PLAIN")) {
authenticator = new PLAINClientAuthenticator();
- state = WaitingForCredentials;
+ state = State::WaitingForCredentials;
onNeedCredentials();
}
else if (streamFeatures->hasAuthenticationMechanism("DIGEST-MD5") && crypto->isMD5AllowedForCrypto()) {
std::ostringstream s;
s << boost::uuids::random_generator()();
// FIXME: Host should probably be the actual host
authenticator = new DIGESTMD5ClientAuthenticator(localJID.getDomain(), s.str(), crypto);
- state = WaitingForCredentials;
+ state = State::WaitingForCredentials;
onNeedCredentials();
}
else {
finishSession(Error::NoSupportedAuthMechanismsError);
}
}
else {
// Start the session
rosterVersioningSupported = streamFeatures->hasRosterVersioning();
stream->setWhitespacePingEnabled(true);
needSessionStart = streamFeatures->hasSession();
needResourceBind = streamFeatures->hasResourceBind();
needAcking = streamFeatures->hasStreamManagement() && useAcks;
if (!needResourceBind) {
// Resource binding is a MUST
finishSession(Error::ResourceBindError);
}
else {
continueSessionInitialization();
}
}
}
else if (std::dynamic_pointer_cast<Compressed>(element)) {
- CHECK_STATE_OR_RETURN(Compressing);
- state = WaitingForStreamStart;
+ CHECK_STATE_OR_RETURN(State::Compressing);
+ state = State::WaitingForStreamStart;
stream->addZLibCompression();
stream->resetXMPPParser();
sendStreamHeader();
}
else if (std::dynamic_pointer_cast<CompressFailure>(element)) {
finishSession(Error::CompressionFailedError);
}
else if (std::dynamic_pointer_cast<StreamManagementEnabled>(element)) {
stanzaAckRequester_ = std::make_shared<StanzaAckRequester>();
stanzaAckRequester_->onRequestAck.connect(boost::bind(&ClientSession::requestAck, shared_from_this()));
stanzaAckRequester_->onStanzaAcked.connect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1));
stanzaAckResponder_ = std::make_shared<StanzaAckResponder>();
stanzaAckResponder_->onAck.connect(boost::bind(&ClientSession::ack, shared_from_this(), _1));
needAcking = false;
continueSessionInitialization();
}
else if (std::dynamic_pointer_cast<StreamManagementFailed>(element)) {
needAcking = false;
continueSessionInitialization();
}
else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) {
- CHECK_STATE_OR_RETURN(Authenticating);
+ CHECK_STATE_OR_RETURN(State::Authenticating);
assert(authenticator);
if (authenticator->setChallenge(challenge->getValue())) {
stream->writeElement(std::make_shared<AuthResponse>(authenticator->getResponse()));
}
#ifdef SWIFTEN_PLATFORM_WIN32
else if (WindowsGSSAPIClientAuthenticator* gssapiAuthenticator = dynamic_cast<WindowsGSSAPIClientAuthenticator*>(authenticator)) {
std::shared_ptr<Error> error = std::make_shared<Error>(Error::AuthenticationFailedError);
error->errorCode = gssapiAuthenticator->getErrorCode();
finishSession(error);
}
#endif
else {
finishSession(Error::AuthenticationFailedError);
}
}
else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) {
- CHECK_STATE_OR_RETURN(Authenticating);
+ CHECK_STATE_OR_RETURN(State::Authenticating);
assert(authenticator);
if (!authenticator->setChallenge(authSuccess->getValue())) {
finishSession(Error::ServerVerificationFailedError);
}
else {
- state = WaitingForStreamStart;
+ state = State::WaitingForStreamStart;
delete authenticator;
authenticator = nullptr;
stream->resetXMPPParser();
sendStreamHeader();
}
}
else if (dynamic_cast<AuthFailure*>(element.get())) {
finishSession(Error::AuthenticationFailedError);
}
else if (dynamic_cast<TLSProceed*>(element.get())) {
- CHECK_STATE_OR_RETURN(WaitingForEncrypt);
- state = Encrypting;
+ CHECK_STATE_OR_RETURN(State::WaitingForEncrypt);
+ state = State::Encrypting;
stream->addTLSEncryption();
}
else if (dynamic_cast<StartTLSFailure*>(element.get())) {
finishSession(Error::TLSError);
}
else {
// FIXME Not correct?
- state = Initialized;
+ state = State::Initialized;
onInitialized();
}
}
void ClientSession::continueSessionInitialization() {
if (needResourceBind) {
- state = BindingResource;
+ state = State::BindingResource;
std::shared_ptr<ResourceBind> resourceBind(std::make_shared<ResourceBind>());
if (!localJID.getResource().empty()) {
resourceBind->setResource(localJID.getResource());
}
sendStanza(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
}
else if (needAcking) {
- state = EnablingSessionManagement;
+ state = State::EnablingSessionManagement;
stream->writeElement(std::make_shared<EnableStreamManagement>());
}
else if (needSessionStart) {
- state = StartingSession;
+ state = State::StartingSession;
sendStanza(IQ::createRequest(IQ::Set, JID(), "session-start", std::make_shared<StartSession>()));
}
else {
- state = Initialized;
+ state = State::Initialized;
onInitialized();
}
}
bool ClientSession::checkState(State state) {
if (this->state != state) {
finishSession(Error::UnexpectedElementError);
return false;
}
return true;
}
void ClientSession::sendCredentials(const SafeByteArray& password) {
- assert(WaitingForCredentials);
+ assert(state == State::WaitingForCredentials);
assert(authenticator);
- state = Authenticating;
+ state = State::Authenticating;
authenticator->setCredentials(localJID.getNode(), password);
stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse()));
}
void ClientSession::handleTLSEncrypted() {
- CHECK_STATE_OR_RETURN(Encrypting);
+ CHECK_STATE_OR_RETURN(State::Encrypting);
std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain();
std::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError();
if (verificationError) {
checkTrustOrFinish(certificateChain, verificationError);
}
else {
ServerIdentityVerifier identityVerifier(localJID, idnConverter);
if (!certificateChain.empty() && identityVerifier.certificateVerifies(certificateChain[0])) {
continueAfterTLSEncrypted();
}
else {
checkTrustOrFinish(certificateChain, std::make_shared<CertificateVerificationError>(CertificateVerificationError::InvalidServerIdentity));
}
}
}
void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error) {
if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificateChain)) {
continueAfterTLSEncrypted();
}
else {
finishSession(error);
}
}
+void ClientSession::initiateShutdown(bool sendFooter) {
+ if (!streamShutdownTimeout) {
+ streamShutdownTimeout = timerFactory->createTimer(sessionShutdownTimeoutInMilliseconds);
+ streamShutdownTimeout->onTick.connect(boost::bind(&ClientSession::handleStreamShutdownTimeout, shared_from_this()));
+ streamShutdownTimeout->start();
+ }
+ if (sendFooter) {
+ stream->writeFooter();
+ }
+ if (state == State::Finishing) {
+ // The other side already send </stream>; we can close the socket.
+ stream->close();
+ }
+ else {
+ state = State::Finishing;
+ }
+}
+
void ClientSession::continueAfterTLSEncrypted() {
- state = WaitingForStreamStart;
+ state = State::WaitingForStreamStart;
stream->resetXMPPParser();
sendStreamHeader();
}
void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError) {
State previousState = state;
- state = Finished;
+ state = State::Finished;
+
+ if (streamShutdownTimeout) {
+ streamShutdownTimeout->stop();
+ streamShutdownTimeout.reset();
+ }
if (stanzaAckRequester_) {
stanzaAckRequester_->onRequestAck.disconnect(boost::bind(&ClientSession::requestAck, shared_from_this()));
stanzaAckRequester_->onStanzaAcked.disconnect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1));
stanzaAckRequester_.reset();
}
if (stanzaAckResponder_) {
stanzaAckResponder_->onAck.disconnect(boost::bind(&ClientSession::ack, shared_from_this(), _1));
stanzaAckResponder_.reset();
}
stream->setWhitespacePingEnabled(false);
stream->onStreamStartReceived.disconnect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1));
+ stream->onStreamEndReceived.disconnect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this()));
stream->onElementReceived.disconnect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1));
stream->onClosed.disconnect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1));
stream->onTLSEncrypted.disconnect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this()));
- if (previousState == Finishing) {
+ if (previousState == State::Finishing) {
onFinished(error_);
}
else {
onFinished(streamError);
}
}
+void ClientSession::handleStreamShutdownTimeout() {
+ handleStreamClosed(std::shared_ptr<Swift::Error>());
+}
+
void ClientSession::finish() {
- finishSession(std::shared_ptr<Error>());
+ if (state != State::Finishing && state != State::Finished) {
+ finishSession(std::shared_ptr<Error>());
+ }
+ else {
+ SWIFT_LOG(warning) << "Session already finished or finishing." << std::endl;
+ }
}
void ClientSession::finishSession(Error::Type error) {
finishSession(std::make_shared<Swift::ClientSession::Error>(error));
}
void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) {
- state = Finishing;
if (!error_) {
error_ = error;
}
else {
SWIFT_LOG(warning) << "Session finished twice";
}
assert(stream->isOpen());
if (stanzaAckResponder_) {
stanzaAckResponder_->handleAckRequestReceived();
}
if (authenticator) {
delete authenticator;
authenticator = nullptr;
}
- stream->writeFooter();
- stream->close();
+ // Immidiately close TCP connection without stream closure.
+ if (std::dynamic_pointer_cast<CertificateVerificationError>(error)) {
+ state = State::Finishing;
+ initiateShutdown(false);
+ }
+ else {
+ if (state == State::Finishing) {
+ initiateShutdown(true);
+ }
+ else if (state != State::Finished) {
+ initiateShutdown(true);
+ }
+ }
}
void ClientSession::requestAck() {
stream->writeElement(std::make_shared<StanzaAckRequest>());
}
void ClientSession::handleStanzaAcked(std::shared_ptr<Stanza> stanza) {
onStanzaAcked(stanza);
}
void ClientSession::ack(unsigned int handledStanzasCount) {
stream->writeElement(std::make_shared<StanzaAck>(handledStanzasCount));
}
}
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index ad1b78c..c7b3658 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -1,201 +1,215 @@
/*
* Copyright (c) 2010-2016 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
#pragma once
#include <memory>
#include <string>
#include <boost/signals2.hpp>
#include <Swiften/Base/API.h>
#include <Swiften/Base/Error.h>
#include <Swiften/Elements/ToplevelElement.h>
#include <Swiften/JID/JID.h>
#include <Swiften/Session/SessionStream.h>
-#include <Swiften/StreamManagement/StanzaAckRequester.h>
-#include <Swiften/StreamManagement/StanzaAckResponder.h>
namespace Swift {
- class ClientAuthenticator;
class CertificateTrustChecker;
- class IDNConverter;
+ class ClientAuthenticator;
class CryptoProvider;
+ class IDNConverter;
+ class Stanza;
+ class StanzaAckRequester;
+ class StanzaAckResponder;
+ class TimerFactory;
+ class Timer;
class SWIFTEN_API ClientSession : public std::enable_shared_from_this<ClientSession> {
public:
- enum State {
+ enum class State {
Initial,
WaitingForStreamStart,
Negotiating,
Compressing,
WaitingForEncrypt,
Encrypting,
WaitingForCredentials,
Authenticating,
EnablingSessionManagement,
BindingResource,
StartingSession,
Initialized,
Finishing,
Finished
};
struct Error : public Swift::Error {
enum Type {
AuthenticationFailedError,
CompressionFailedError,
ServerVerificationFailedError,
NoSupportedAuthMechanismsError,
UnexpectedElementError,
ResourceBindError,
SessionStartError,
TLSClientCertificateError,
TLSError,
- StreamError
+ StreamError,
+ StreamEndError, // The server send a closing stream tag.
} type;
std::shared_ptr<boost::system::error_code> errorCode;
Error(Type type) : type(type) {}
};
enum UseTLS {
NeverUseTLS,
UseTLSWhenAvailable,
RequireTLS
};
~ClientSession();
- static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto) {
- return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto));
+ static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto, TimerFactory* timerFactory) {
+ return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto, timerFactory));
}
State getState() const {
return state;
}
void setAllowPLAINOverNonTLS(bool b) {
allowPLAINOverNonTLS = b;
}
void setUseStreamCompression(bool b) {
useStreamCompression = b;
}
void setUseTLS(UseTLS b) {
useTLS = b;
}
void setUseAcks(bool b) {
useAcks = b;
}
-
bool getStreamManagementEnabled() const {
// Explicitly convert to bool. In C++11, it would be cleaner to
// compare to nullptr.
return static_cast<bool>(stanzaAckRequester_);
}
bool getRosterVersioningSupported() const {
return rosterVersioningSupported;
}
std::vector<Certificate::ref> getPeerCertificateChain() const {
return stream->getPeerCertificateChain();
}
const JID& getLocalJID() const {
return localJID;
}
void start();
void finish();
bool isFinished() const {
- return getState() == Finished;
+ return getState() == State::Finished;
}
void sendCredentials(const SafeByteArray& password);
void sendStanza(std::shared_ptr<Stanza>);
void setCertificateTrustChecker(CertificateTrustChecker* checker) {
certificateTrustChecker = checker;
}
void setSingleSignOn(bool b) {
singleSignOn = b;
}
/**
* Sets the port number used in Kerberos authentication
* Does not affect network connectivity.
*/
void setAuthenticationPort(int i) {
authenticationPort = i;
}
+ void setSessionShutdownTimeout(int timeoutInMilliseconds) {
+ sessionShutdownTimeoutInMilliseconds = timeoutInMilliseconds;
+ }
+
public:
boost::signals2::signal<void ()> onNeedCredentials;
boost::signals2::signal<void ()> onInitialized;
boost::signals2::signal<void (std::shared_ptr<Swift::Error>)> onFinished;
boost::signals2::signal<void (std::shared_ptr<Stanza>)> onStanzaReceived;
boost::signals2::signal<void (std::shared_ptr<Stanza>)> onStanzaAcked;
private:
ClientSession(
const JID& jid,
std::shared_ptr<SessionStream>,
IDNConverter* idnConverter,
- CryptoProvider* crypto);
+ CryptoProvider* crypto,
+ TimerFactory* timerFactory);
void finishSession(Error::Type error);
void finishSession(std::shared_ptr<Swift::Error> error);
JID getRemoteJID() const {
return JID("", localJID.getDomain());
}
void sendStreamHeader();
void handleElement(std::shared_ptr<ToplevelElement>);
void handleStreamStart(const ProtocolHeader&);
+ void handleStreamEnd();
void handleStreamClosed(std::shared_ptr<Swift::Error>);
+ void handleStreamShutdownTimeout();
void handleTLSEncrypted();
bool checkState(State);
void continueSessionInitialization();
void requestAck();
void handleStanzaAcked(std::shared_ptr<Stanza> stanza);
void ack(unsigned int handledStanzasCount);
void continueAfterTLSEncrypted();
void checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error);
+ void initiateShutdown(bool sendFooter);
private:
JID localJID;
State state;
std::shared_ptr<SessionStream> stream;
- IDNConverter* idnConverter;
- CryptoProvider* crypto;
+ IDNConverter* idnConverter = nullptr;
+ CryptoProvider* crypto = nullptr;
+ TimerFactory* timerFactory = nullptr;
+ std::shared_ptr<Timer> streamShutdownTimeout;
+ int sessionShutdownTimeoutInMilliseconds = 10000;
bool allowPLAINOverNonTLS;
bool useStreamCompression;
UseTLS useTLS;
bool useAcks;
bool needSessionStart;
bool needResourceBind;
bool needAcking;
bool rosterVersioningSupported;
ClientAuthenticator* authenticator;
std::shared_ptr<StanzaAckRequester> stanzaAckRequester_;
std::shared_ptr<StanzaAckResponder> stanzaAckResponder_;
std::shared_ptr<Swift::Error> error_;
CertificateTrustChecker* certificateTrustChecker;
bool singleSignOn;
int authenticationPort;
};
}
diff --git a/Swiften/Client/ClientSessionStanzaChannel.h b/Swiften/Client/ClientSessionStanzaChannel.h
index 0527a5c..c4ee393 100644
--- a/Swiften/Client/ClientSessionStanzaChannel.h
+++ b/Swiften/Client/ClientSessionStanzaChannel.h
@@ -6,47 +6,47 @@
#pragma once
#include <memory>
#include <Swiften/Base/API.h>
#include <Swiften/Base/IDGenerator.h>
#include <Swiften/Client/ClientSession.h>
#include <Swiften/Client/StanzaChannel.h>
#include <Swiften/Elements/IQ.h>
#include <Swiften/Elements/Message.h>
#include <Swiften/Elements/Presence.h>
namespace Swift {
/**
* StanzaChannel implementation around a ClientSession.
*/
class SWIFTEN_API ClientSessionStanzaChannel : public StanzaChannel {
public:
virtual ~ClientSessionStanzaChannel();
void setSession(std::shared_ptr<ClientSession> session);
void sendIQ(std::shared_ptr<IQ> iq);
void sendMessage(std::shared_ptr<Message> message);
void sendPresence(std::shared_ptr<Presence> presence);
bool getStreamManagementEnabled() const;
virtual std::vector<Certificate::ref> getPeerCertificateChain() const;
bool isAvailable() const {
- return session && session->getState() == ClientSession::Initialized;
+ return session && session->getState() == ClientSession::State::Initialized;
}
private:
std::string getNewIQID();
void send(std::shared_ptr<Stanza> stanza);
void handleSessionFinished(std::shared_ptr<Error> error);
void handleStanza(std::shared_ptr<Stanza> stanza);
void handleStanzaAcked(std::shared_ptr<Stanza> stanza);
void handleSessionInitialized();
private:
IDGenerator idGenerator;
std::shared_ptr<ClientSession> session;
};
}
diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp
index 3d75d8b..3c7902e 100644
--- a/Swiften/Client/CoreClient.cpp
+++ b/Swiften/Client/CoreClient.cpp
@@ -127,66 +127,67 @@ void CoreClient::connect(const ClientOptions& o) {
std::shared_ptr<BOSHSessionStream> boshSessionStream_ = std::shared_ptr<BOSHSessionStream>(new BOSHSessionStream(
options.boshURL,
getPayloadParserFactories(),
getPayloadSerializers(),
networkFactories->getConnectionFactory(),
networkFactories->getTLSContextFactory(),
networkFactories->getTimerFactory(),
networkFactories->getXMLParserFactory(),
networkFactories->getEventLoop(),
networkFactories->getDomainNameResolver(),
host,
options.boshHTTPConnectProxyURL,
options.boshHTTPConnectProxyAuthID,
options.boshHTTPConnectProxyAuthPassword,
options.tlsOptions,
options.httpTrafficFilter));
sessionStream_ = boshSessionStream_;
sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1));
sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1));
if (certificate_ && !certificate_->isNull()) {
SWIFT_LOG(debug) << "set certificate" << std::endl;
sessionStream_->setTLSCertificate(certificate_);
}
boshSessionStream_->open();
bindSessionToStream();
}
}
void CoreClient::bindSessionToStream() {
- session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider());
+ session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider(), networkFactories->getTimerFactory());
session_->setCertificateTrustChecker(certificateTrustChecker);
session_->setUseStreamCompression(options.useStreamCompression);
session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS);
session_->setSingleSignOn(options.singleSignOn);
session_->setAuthenticationPort(options.manualPort);
+ session_->setSessionShutdownTimeout(options.sessionShutdownTimeoutInMilliseconds);
switch(options.useTLS) {
case ClientOptions::UseTLSWhenAvailable:
session_->setUseTLS(ClientSession::UseTLSWhenAvailable);
break;
case ClientOptions::NeverUseTLS:
session_->setUseTLS(ClientSession::NeverUseTLS);
break;
case ClientOptions::RequireTLS:
session_->setUseTLS(ClientSession::RequireTLS);
break;
}
session_->setUseAcks(options.useAcks);
stanzaChannel_->setSession(session_);
session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1));
session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this));
session_->start();
}
/**
* Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream.
*/
void CoreClient::handleConnectorFinished(std::shared_ptr<Connection> connection, std::shared_ptr<Error> error) {
resetConnector();
if (!connection) {
if (options.forgetPassword) {
purgePassword();
}
boost::optional<ClientError> clientError;
if (!disconnectRequested_) {
clientError = std::dynamic_pointer_cast<DomainNameResolveError>(error) ? boost::optional<ClientError>(ClientError::DomainNameResolveError) : boost::optional<ClientError>(ClientError::ConnectionError);
@@ -246,60 +247,63 @@ void CoreClient::handleSessionFinished(std::shared_ptr<Error> error) {
case ClientSession::Error::AuthenticationFailedError:
clientError = ClientError(ClientError::AuthenticationFailedError);
break;
case ClientSession::Error::CompressionFailedError:
clientError = ClientError(ClientError::CompressionFailedError);
break;
case ClientSession::Error::ServerVerificationFailedError:
clientError = ClientError(ClientError::ServerVerificationFailedError);
break;
case ClientSession::Error::NoSupportedAuthMechanismsError:
clientError = ClientError(ClientError::NoSupportedAuthMechanismsError);
break;
case ClientSession::Error::UnexpectedElementError:
clientError = ClientError(ClientError::UnexpectedElementError);
break;
case ClientSession::Error::ResourceBindError:
clientError = ClientError(ClientError::ResourceBindError);
break;
case ClientSession::Error::SessionStartError:
clientError = ClientError(ClientError::SessionStartError);
break;
case ClientSession::Error::TLSError:
clientError = ClientError(ClientError::TLSError);
break;
case ClientSession::Error::TLSClientCertificateError:
clientError = ClientError(ClientError::ClientCertificateError);
break;
case ClientSession::Error::StreamError:
clientError = ClientError(ClientError::StreamError);
break;
+ case ClientSession::Error::StreamEndError:
+ clientError = ClientError(ClientError::StreamError);
+ break;
}
clientError.setErrorCode(actualError->errorCode);
}
else if (std::shared_ptr<TLSError> actualError = std::dynamic_pointer_cast<TLSError>(error)) {
switch(actualError->getType()) {
case TLSError::CertificateCardRemoved:
clientError = ClientError(ClientError::CertificateCardRemoved);
break;
case TLSError::UnknownError:
clientError = ClientError(ClientError::TLSError);
break;
}
}
else if (std::shared_ptr<SessionStream::SessionStreamError> actualError = std::dynamic_pointer_cast<SessionStream::SessionStreamError>(error)) {
switch(actualError->type) {
case SessionStream::SessionStreamError::ParseError:
clientError = ClientError(ClientError::XMLError);
break;
case SessionStream::SessionStreamError::TLSError:
clientError = ClientError(ClientError::TLSError);
break;
case SessionStream::SessionStreamError::InvalidTLSCertificateError:
clientError = ClientError(ClientError::ClientCertificateLoadError);
break;
case SessionStream::SessionStreamError::ConnectionReadError:
clientError = ClientError(ClientError::ConnectionReadError);
break;
case SessionStream::SessionStreamError::ConnectionWriteError:
clientError = ClientError(ClientError::ConnectionWriteError);
break;
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index bd93f4b..44df961 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -1,380 +1,458 @@
/*
* Copyright (c) 2010-2016 Isode Limited.
* All rights reserved.
* See the COPYING file for more information.
*/
#include <deque>
#include <memory>
#include <boost/bind.hpp>
#include <boost/optional.hpp>
+#include <Swiften/Base/Debug.h>
+
#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <Swiften/Client/ClientSession.h>
#include <Swiften/Crypto/CryptoProvider.h>
#include <Swiften/Crypto/PlatformCryptoProvider.h>
#include <Swiften/Elements/AuthChallenge.h>
#include <Swiften/Elements/AuthFailure.h>
#include <Swiften/Elements/AuthRequest.h>
#include <Swiften/Elements/AuthSuccess.h>
#include <Swiften/Elements/EnableStreamManagement.h>
#include <Swiften/Elements/IQ.h>
#include <Swiften/Elements/Message.h>
#include <Swiften/Elements/ResourceBind.h>
#include <Swiften/Elements/StanzaAck.h>
#include <Swiften/Elements/StartTLSFailure.h>
#include <Swiften/Elements/StartTLSRequest.h>
#include <Swiften/Elements/StreamError.h>
#include <Swiften/Elements/StreamFeatures.h>
#include <Swiften/Elements/StreamManagementEnabled.h>
#include <Swiften/Elements/StreamManagementFailed.h>
#include <Swiften/Elements/TLSProceed.h>
#include <Swiften/IDN/IDNConverter.h>
#include <Swiften/IDN/PlatformIDNConverter.h>
+#include <Swiften/Network/DummyTimerFactory.h>
#include <Swiften/Session/SessionStream.h>
#include <Swiften/TLS/BlindCertificateTrustChecker.h>
#include <Swiften/TLS/SimpleCertificate.h>
using namespace Swift;
class ClientSessionTest : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(ClientSessionTest);
CPPUNIT_TEST(testStart_Error);
CPPUNIT_TEST(testStart_StreamError);
CPPUNIT_TEST(testStartTLS);
CPPUNIT_TEST(testStartTLS_ServerError);
CPPUNIT_TEST(testStartTLS_ConnectError);
CPPUNIT_TEST(testStartTLS_InvalidIdentity);
CPPUNIT_TEST(testStart_StreamFeaturesWithoutResourceBindingFails);
CPPUNIT_TEST(testAuthenticate);
CPPUNIT_TEST(testAuthenticate_Unauthorized);
CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms);
CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS);
CPPUNIT_TEST(testAuthenticate_RequireTLS);
CPPUNIT_TEST(testAuthenticate_EXTERNAL);
CPPUNIT_TEST(testStreamManagement);
CPPUNIT_TEST(testStreamManagement_Failed);
CPPUNIT_TEST(testUnexpectedChallenge);
CPPUNIT_TEST(testFinishAcksStanzas);
+
+ CPPUNIT_TEST(testServerInitiatedSessionClose);
+ CPPUNIT_TEST(testClientInitiatedSessionClose);
+ CPPUNIT_TEST(testTimeoutOnShutdown);
+
/*
CPPUNIT_TEST(testResourceBind);
CPPUNIT_TEST(testResourceBind_ChangeResource);
CPPUNIT_TEST(testResourceBind_EmptyResource);
CPPUNIT_TEST(testResourceBind_Error);
CPPUNIT_TEST(testSessionStart);
CPPUNIT_TEST(testSessionStart_Error);
CPPUNIT_TEST(testSessionStart_AfterResourceBind);
CPPUNIT_TEST(testWhitespacePing);
CPPUNIT_TEST(testReceiveElementAfterSessionStarted);
CPPUNIT_TEST(testSendElement);
*/
CPPUNIT_TEST_SUITE_END();
public:
void setUp() {
crypto = std::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create());
idnConverter = std::shared_ptr<IDNConverter>(PlatformIDNConverter::create());
+ timerFactory = std::make_shared<DummyTimerFactory>();
server = std::make_shared<MockSessionStream>();
sessionFinishedReceived = false;
needCredentials = false;
blindCertificateTrustChecker = new BlindCertificateTrustChecker();
}
void tearDown() {
delete blindCertificateTrustChecker;
}
void testStart_Error() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->breakConnection();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testStart_StreamError() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->sendStreamStart();
server->sendStreamError();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testStartTLS() {
std::shared_ptr<ClientSession> session(createSession());
session->setCertificateTrustChecker(blindCertificateTrustChecker);
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithStartTLS();
server->receiveStartTLS();
CPPUNIT_ASSERT(!server->tlsEncrypted);
server->sendTLSProceed();
CPPUNIT_ASSERT(server->tlsEncrypted);
server->onTLSEncrypted();
server->receiveStreamStart();
server->sendStreamStart();
session->finish();
}
void testStartTLS_ServerError() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithStartTLS();
server->receiveStartTLS();
server->sendTLSFailure();
CPPUNIT_ASSERT(!server->tlsEncrypted);
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testStartTLS_ConnectError() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithStartTLS();
server->receiveStartTLS();
server->sendTLSProceed();
server->breakTLS();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testStartTLS_InvalidIdentity() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithStartTLS();
server->receiveStartTLS();
CPPUNIT_ASSERT(!server->tlsEncrypted);
server->sendTLSProceed();
CPPUNIT_ASSERT(server->tlsEncrypted);
server->onTLSEncrypted();
+ server->close();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
+ CPPUNIT_ASSERT(std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError));
CPPUNIT_ASSERT_EQUAL(CertificateVerificationError::InvalidServerIdentity, std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)->getType());
}
void testStart_StreamFeaturesWithoutResourceBindingFails() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendEmptyStreamFeatures();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testAuthenticate() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
CPPUNIT_ASSERT(needCredentials);
- CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState());
session->sendCredentials(createSafeByteArray("mypass"));
server->receiveAuthRequest("PLAIN");
server->sendAuthSuccess();
server->receiveStreamStart();
session->finish();
}
void testAuthenticate_Unauthorized() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
CPPUNIT_ASSERT(needCredentials);
- CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState());
session->sendCredentials(createSafeByteArray("mypass"));
server->receiveAuthRequest("PLAIN");
server->sendAuthFailure();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testAuthenticate_PLAINOverNonTLS() {
std::shared_ptr<ClientSession> session(createSession());
session->setAllowPLAINOverNonTLS(false);
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testAuthenticate_RequireTLS() {
std::shared_ptr<ClientSession> session(createSession());
session->setUseTLS(ClientSession::RequireTLS);
session->setAllowPLAINOverNonTLS(true);
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithMultipleAuthentication();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testAuthenticate_NoValidAuthMechanisms() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithUnknownAuthentication();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ server->close();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testAuthenticate_EXTERNAL() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithEXTERNALAuthentication();
server->receiveAuthRequest("EXTERNAL");
server->sendAuthSuccess();
server->receiveStreamStart();
session->finish();
}
void testUnexpectedChallenge() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithEXTERNALAuthentication();
server->receiveAuthRequest("EXTERNAL");
server->sendChallenge();
server->sendChallenge();
- CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
CPPUNIT_ASSERT(sessionFinishedReceived);
CPPUNIT_ASSERT(sessionFinishedError);
}
void testStreamManagement() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
session->sendCredentials(createSafeByteArray("mypass"));
server->receiveAuthRequest("PLAIN");
server->sendAuthSuccess();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithBindAndStreamManagement();
server->receiveBind();
server->sendBindResult();
server->receiveStreamManagementEnable();
server->sendStreamManagementEnabled();
CPPUNIT_ASSERT(session->getStreamManagementEnabled());
// TODO: Test if the requesters & responders do their work
- CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState());
session->finish();
}
void testStreamManagement_Failed() {
std::shared_ptr<ClientSession> session(createSession());
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
session->sendCredentials(createSafeByteArray("mypass"));
server->receiveAuthRequest("PLAIN");
server->sendAuthSuccess();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithBindAndStreamManagement();
server->receiveBind();
server->sendBindResult();
server->receiveStreamManagementEnable();
server->sendStreamManagementFailed();
CPPUNIT_ASSERT(!session->getStreamManagementEnabled());
- CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState());
session->finish();
}
void testFinishAcksStanzas() {
std::shared_ptr<ClientSession> session(createSession());
initializeSession(session);
server->sendMessage();
server->sendMessage();
server->sendMessage();
session->finish();
server->receiveAck(3);
}
+ void testServerInitiatedSessionClose() {
+ std::shared_ptr<ClientSession> session(createSession());
+ initializeSession(session);
+
+ server->onStreamEndReceived();
+ server->close();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+ }
+
+ void testClientInitiatedSessionClose() {
+ std::shared_ptr<ClientSession> session(createSession());
+ initializeSession(session);
+
+ session->finish();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+
+ server->onStreamEndReceived();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
+ }
+
+ void testTimeoutOnShutdown() {
+ std::shared_ptr<ClientSession> session(createSession());
+ initializeSession(session);
+
+ session->finish();
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState());
+ CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer);
+ timerFactory->setTime(60000);
+
+ CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState());
+ CPPUNIT_ASSERT(sessionFinishedReceived);
+ }
+
private:
std::shared_ptr<ClientSession> createSession() {
- std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get());
+ std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get(), timerFactory.get());
session->onFinished.connect(boost::bind(&ClientSessionTest::handleSessionFinished, this, _1));
session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::handleSessionNeedCredentials, this));
session->setAllowPLAINOverNonTLS(true);
return session;
}
void initializeSession(std::shared_ptr<ClientSession> session) {
session->start();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithPLAINAuthentication();
session->sendCredentials(createSafeByteArray("mypass"));
server->receiveAuthRequest("PLAIN");
server->sendAuthSuccess();
server->receiveStreamStart();
server->sendStreamStart();
server->sendStreamFeaturesWithBindAndStreamManagement();
server->receiveBind();
server->sendBindResult();
server->receiveStreamManagementEnable();
server->sendStreamManagementEnabled();
}
void handleSessionFinished(std::shared_ptr<Error> error) {
sessionFinishedReceived = true;
sessionFinishedError = error;
}
void handleSessionNeedCredentials() {
needCredentials = true;
@@ -604,60 +682,61 @@ class ClientSessionTest : public CppUnit::TestFixture {
CPPUNIT_ASSERT(event.element);
std::shared_ptr<StanzaAck> ack = std::dynamic_pointer_cast<StanzaAck>(event.element);
CPPUNIT_ASSERT(ack);
CPPUNIT_ASSERT_EQUAL(n, ack->getHandledStanzasCount());
}
Event popEvent() {
CPPUNIT_ASSERT(!receivedEvents.empty());
Event event = receivedEvents.front();
receivedEvents.pop_front();
return event;
}
bool available;
bool canTLSEncrypt;
bool tlsEncrypted;
bool compressed;
bool whitespacePingEnabled;
std::string bindID;
int resetCount;
std::deque<Event> receivedEvents;
};
std::shared_ptr<IDNConverter> idnConverter;
std::shared_ptr<MockSessionStream> server;
bool sessionFinishedReceived;
bool needCredentials;
std::shared_ptr<Error> sessionFinishedError;
BlindCertificateTrustChecker* blindCertificateTrustChecker;
std::shared_ptr<CryptoProvider> crypto;
+ std::shared_ptr<DummyTimerFactory> timerFactory;
};
CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest);
#if 0
void testAuthenticate() {
std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithAuthentication();
session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
CPPUNIT_ASSERT(needCredentials_);
getMockServer()->expectAuth("me", "mypass");
getMockServer()->sendAuthSuccess();
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
session->sendCredentials("mypass");
CPPUNIT_ASSERT_EQUAL(ClientSession::Authenticating, session->getState());
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState());
}
void testAuthenticate_Unauthorized() {
std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp
index 402f642..10c6ad0 100644
--- a/Swiften/Session/BasicSessionStream.cpp
+++ b/Swiften/Session/BasicSessionStream.cpp
@@ -13,88 +13,90 @@
#include <Swiften/StreamStack/CompressionLayer.h>
#include <Swiften/StreamStack/ConnectionLayer.h>
#include <Swiften/StreamStack/StreamStack.h>
#include <Swiften/StreamStack/TLSLayer.h>
#include <Swiften/StreamStack/WhitespacePingLayer.h>
#include <Swiften/StreamStack/XMPPLayer.h>
#include <Swiften/TLS/TLSContext.h>
#include <Swiften/TLS/TLSContextFactory.h>
namespace Swift {
BasicSessionStream::BasicSessionStream(
StreamType streamType,
std::shared_ptr<Connection> connection,
PayloadParserFactoryCollection* payloadParserFactories,
PayloadSerializerCollection* payloadSerializers,
TLSContextFactory* tlsContextFactory,
TimerFactory* timerFactory,
XMLParserFactory* xmlParserFactory,
const TLSOptions& tlsOptions) :
available(false),
connection(connection),
tlsContextFactory(tlsContextFactory),
timerFactory(timerFactory),
compressionLayer(nullptr),
tlsLayer(nullptr),
whitespacePingLayer(nullptr),
tlsOptions_(tlsOptions) {
xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, streamType);
xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1));
+ xmppLayer->onStreamEnd.connect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this));
xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1));
xmppLayer->onError.connect(boost::bind(&BasicSessionStream::handleXMPPError, this));
xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, this, _1));
xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1));
connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1));
connectionLayer = new ConnectionLayer(connection);
streamStack = new StreamStack(xmppLayer, connectionLayer);
available = true;
}
BasicSessionStream::~BasicSessionStream() {
delete compressionLayer;
if (tlsLayer) {
tlsLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleTLSError, this, _1));
tlsLayer->onConnected.disconnect(boost::bind(&BasicSessionStream::handleTLSConnected, this));
delete tlsLayer;
}
delete whitespacePingLayer;
delete streamStack;
connection->onDisconnected.disconnect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1));
delete connectionLayer;
xmppLayer->onStreamStart.disconnect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1));
+ xmppLayer->onStreamEnd.disconnect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this));
xmppLayer->onElement.disconnect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1));
xmppLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleXMPPError, this));
xmppLayer->onDataRead.disconnect(boost::bind(&BasicSessionStream::handleDataRead, this, _1));
xmppLayer->onWriteData.disconnect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1));
delete xmppLayer;
}
void BasicSessionStream::writeHeader(const ProtocolHeader& header) {
assert(available);
xmppLayer->writeHeader(header);
}
void BasicSessionStream::writeElement(std::shared_ptr<ToplevelElement> element) {
assert(available);
xmppLayer->writeElement(element);
}
void BasicSessionStream::writeFooter() {
assert(available);
xmppLayer->writeFooter();
}
void BasicSessionStream::writeData(const std::string& data) {
assert(available);
xmppLayer->writeData(data);
}
void BasicSessionStream::close() {
connection->disconnect();
}
@@ -144,60 +146,64 @@ ByteArray BasicSessionStream::getTLSFinishMessage() const {
bool BasicSessionStream::supportsZLibCompression() {
return true;
}
void BasicSessionStream::addZLibCompression() {
compressionLayer = new CompressionLayer();
streamStack->addLayer(compressionLayer);
}
void BasicSessionStream::setWhitespacePingEnabled(bool enabled) {
if (enabled) {
if (!whitespacePingLayer) {
whitespacePingLayer = new WhitespacePingLayer(timerFactory);
streamStack->addLayer(whitespacePingLayer);
}
whitespacePingLayer->setActive();
}
else if (whitespacePingLayer) {
whitespacePingLayer->setInactive();
}
}
void BasicSessionStream::resetXMPPParser() {
xmppLayer->resetParser();
}
void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header) {
onStreamStartReceived(header);
}
+void BasicSessionStream::handleStreamEndReceived() {
+ onStreamEndReceived();
+}
+
void BasicSessionStream::handleElementReceived(std::shared_ptr<ToplevelElement> element) {
onElementReceived(element);
}
void BasicSessionStream::handleXMPPError() {
available = false;
onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ParseError));
}
void BasicSessionStream::handleTLSConnected() {
onTLSEncrypted();
}
void BasicSessionStream::handleTLSError(std::shared_ptr<TLSError> error) {
available = false;
onClosed(error);
}
void BasicSessionStream::handleConnectionFinished(const boost::optional<Connection::Error>& error) {
available = false;
if (error == Connection::ReadError) {
onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ConnectionReadError));
}
else if (error) {
onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ConnectionWriteError));
}
else {
onClosed(std::shared_ptr<SessionStreamError>());
}
}
diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h
index 1806cef..48b3d63 100644
--- a/Swiften/Session/BasicSessionStream.h
+++ b/Swiften/Session/BasicSessionStream.h
@@ -46,49 +46,50 @@ namespace Swift {
virtual void close();
virtual bool isOpen();
virtual void writeHeader(const ProtocolHeader& header);
virtual void writeElement(std::shared_ptr<ToplevelElement>);
virtual void writeFooter();
virtual void writeData(const std::string& data);
virtual bool supportsZLibCompression();
virtual void addZLibCompression();
virtual bool supportsTLSEncryption();
virtual void addTLSEncryption();
virtual bool isTLSEncrypted();
virtual Certificate::ref getPeerCertificate() const;
virtual std::vector<Certificate::ref> getPeerCertificateChain() const;
virtual std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
virtual ByteArray getTLSFinishMessage() const;
virtual void setWhitespacePingEnabled(bool);
virtual void resetXMPPParser();
private:
void handleConnectionFinished(const boost::optional<Connection::Error>& error);
void handleXMPPError();
void handleTLSConnected();
void handleTLSError(std::shared_ptr<TLSError>);
void handleStreamStartReceived(const ProtocolHeader&);
+ void handleStreamEndReceived();
void handleElementReceived(std::shared_ptr<ToplevelElement>);
void handleDataRead(const SafeByteArray& data);
void handleDataWritten(const SafeByteArray& data);
private:
bool available;
std::shared_ptr<Connection> connection;
TLSContextFactory* tlsContextFactory;
TimerFactory* timerFactory;
XMPPLayer* xmppLayer;
ConnectionLayer* connectionLayer;
CompressionLayer* compressionLayer;
TLSLayer* tlsLayer;
WhitespacePingLayer* whitespacePingLayer;
StreamStack* streamStack;
TLSOptions tlsOptions_;
};
}
diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h
index f56f495..c5ec42a 100644
--- a/Swiften/Session/SessionStream.h
+++ b/Swiften/Session/SessionStream.h
@@ -48,45 +48,46 @@ namespace Swift {
virtual void writeHeader(const ProtocolHeader& header) = 0;
virtual void writeFooter() = 0;
virtual void writeElement(std::shared_ptr<ToplevelElement>) = 0;
virtual void writeData(const std::string& data) = 0;
virtual bool supportsZLibCompression() = 0;
virtual void addZLibCompression() = 0;
virtual bool supportsTLSEncryption() = 0;
virtual void addTLSEncryption() = 0;
virtual bool isTLSEncrypted() = 0;
virtual void setWhitespacePingEnabled(bool enabled) = 0;
virtual void resetXMPPParser() = 0;
void setTLSCertificate(CertificateWithKey::ref cert) {
certificate = cert;
}
virtual bool hasTLSCertificate() {
return certificate && !certificate->isNull();
}
virtual Certificate::ref getPeerCertificate() const = 0;
virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0;
virtual std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0;
virtual ByteArray getTLSFinishMessage() const = 0;
boost::signals2::signal<void (const ProtocolHeader&)> onStreamStartReceived;
+ boost::signals2::signal<void ()> onStreamEndReceived;
boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElementReceived;
boost::signals2::signal<void (std::shared_ptr<Error>)> onClosed;
boost::signals2::signal<void ()> onTLSEncrypted;
boost::signals2::signal<void (const SafeByteArray&)> onDataRead;
boost::signals2::signal<void (const SafeByteArray&)> onDataWritten;
protected:
CertificateWithKey::ref getTLSCertificate() const {
return certificate;
}
private:
CertificateWithKey::ref certificate;
};
}
diff --git a/Swiften/StreamStack/XMPPLayer.cpp b/Swiften/StreamStack/XMPPLayer.cpp
index 2eed906..982d13f 100644
--- a/Swiften/StreamStack/XMPPLayer.cpp
+++ b/Swiften/StreamStack/XMPPLayer.cpp
@@ -59,42 +59,43 @@ void XMPPLayer::handleDataRead(const SafeByteArray& data) {
inParser_ = true;
// FIXME: Converting to unsafe string. Should be ok, since we don't take passwords
// from the stream in clients. If servers start using this, and require safe storage,
// we need to fix this.
if (!xmppParser_->parse(byteArrayToString(ByteArray(data.begin(), data.end())))) {
inParser_ = false;
onError();
return;
}
inParser_ = false;
if (resetParserAfterParse_) {
doResetParser();
}
}
void XMPPLayer::doResetParser() {
delete xmppParser_;
xmppParser_ = new XMPPParser(this, payloadParserFactories_, xmlParserFactory_);
resetParserAfterParse_ = false;
}
void XMPPLayer::handleStreamStart(const ProtocolHeader& header) {
onStreamStart(header);
}
void XMPPLayer::handleElement(std::shared_ptr<ToplevelElement> stanza) {
onElement(stanza);
}
void XMPPLayer::handleStreamEnd() {
+ onStreamEnd();
}
void XMPPLayer::resetParser() {
if (inParser_) {
resetParserAfterParse_ = true;
}
else {
doResetParser();
}
}
}
diff --git a/Swiften/StreamStack/XMPPLayer.h b/Swiften/StreamStack/XMPPLayer.h
index 1d4abf8..f0b5afb 100644
--- a/Swiften/StreamStack/XMPPLayer.h
+++ b/Swiften/StreamStack/XMPPLayer.h
@@ -24,53 +24,54 @@ namespace Swift {
class PayloadParserFactoryCollection;
class XMPPSerializer;
class PayloadSerializerCollection;
class XMLParserFactory;
class BOSHSessionStream;
class SWIFTEN_API XMPPLayer : public XMPPParserClient, public HighLayer, boost::noncopyable {
friend class BOSHSessionStream;
public:
XMPPLayer(
PayloadParserFactoryCollection* payloadParserFactories,
PayloadSerializerCollection* payloadSerializers,
XMLParserFactory* xmlParserFactory,
StreamType streamType,
bool setExplictNSonTopLevelElements = false);
virtual ~XMPPLayer();
void writeHeader(const ProtocolHeader& header);
void writeFooter();
void writeElement(std::shared_ptr<ToplevelElement>);
void writeData(const std::string& data);
void resetParser();
protected:
void handleDataRead(const SafeByteArray& data);
void writeDataInternal(const SafeByteArray& data);
public:
boost::signals2::signal<void (const ProtocolHeader&)> onStreamStart;
+ boost::signals2::signal<void ()> onStreamEnd;
boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElement;
boost::signals2::signal<void (const SafeByteArray&)> onWriteData;
boost::signals2::signal<void (const SafeByteArray&)> onDataRead;
boost::signals2::signal<void ()> onError;
private:
void handleStreamStart(const ProtocolHeader&);
void handleElement(std::shared_ptr<ToplevelElement>);
void handleStreamEnd();
void doResetParser();
private:
PayloadParserFactoryCollection* payloadParserFactories_;
XMPPParser* xmppParser_;
PayloadSerializerCollection* payloadSerializers_;
XMLParserFactory* xmlParserFactory_;
XMPPSerializer* xmppSerializer_;
bool setExplictNSonTopLevelElements_;
bool resetParserAfterParse_;
bool inParser_;
};
}