diff options
author | Tobias Markmann <tm@ayena.de> | 2016-07-29 09:47:23 (GMT) |
---|---|---|
committer | Tobias Markmann <tm@ayena.de> | 2016-11-28 10:35:05 (GMT) |
commit | 2039930eadd4756068a8a60c8340d9908a7136d3 (patch) | |
tree | d8aca4bf98a2bb6e3b819305b1f87af3117f4910 | |
parent | 2f90eb7409df91a80c60b189242ac0c1de313910 (diff) | |
download | swift-2039930eadd4756068a8a60c8340d9908a7136d3.zip swift-2039930eadd4756068a8a60c8340d9908a7136d3.tar.bz2 |
Correctly handle server initiated closing of stream
If a server closes the XMPP stream, it sends a </stream:stream>
tag. The client is supposed to respond with the same tag and
then both parties can close the TLS/TCP socket.
Previously Swift(-en) would simply ignore </stream:stream>
tag if it was not directly followed by a shutdown of the TCP
connection.
In addition there is now a timeout timer started as soon as
Swiften or the server initiates a shutdown. It will close
the socket and cleanup the ClientSession if the server does
not respond in time or the network is faulty.
Refactored some code in ClientSession in the process. Moved
ClientSession::State to a C++11 strongly typed enum class.
This also fixes issues where duplicated </stream:stream>
tags would be send by Swift.
Test-Information:
Tested against Prosody ba782a093b14 and M-Link 16.3v6-0,
which provide ad-hoc commands to end a user session.
Previously this was ignored by Swift. Now it correctly responds
to the server, detects it as a disconnect and tries to
reconnect afterwards.
Added unit test for the case where the server closes the
session stream.
Change-Id: I59dfde3aa6b50dc117f340e5db6b9e58b54b3c60
-rwxr-xr-x | BuildTools/FixIncludes.py | 5 | ||||
-rw-r--r-- | Swift/ChangeLog.md | 1 | ||||
-rw-r--r-- | Swiften/Base/Debug.cpp | 50 | ||||
-rw-r--r-- | Swiften/Base/Debug.h | 8 | ||||
-rw-r--r-- | Swiften/ChangeLog.md | 4 | ||||
-rw-r--r-- | Swiften/Client/ClientOptions.h | 54 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 184 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 40 | ||||
-rw-r--r-- | Swiften/Client/ClientSessionStanzaChannel.h | 2 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.cpp | 6 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 111 | ||||
-rw-r--r-- | Swiften/Session/BasicSessionStream.cpp | 6 | ||||
-rw-r--r-- | Swiften/Session/BasicSessionStream.h | 1 | ||||
-rw-r--r-- | Swiften/Session/SessionStream.h | 1 | ||||
-rw-r--r-- | Swiften/StreamStack/XMPPLayer.cpp | 1 | ||||
-rw-r--r-- | Swiften/StreamStack/XMPPLayer.h | 1 |
16 files changed, 350 insertions, 125 deletions
diff --git a/BuildTools/FixIncludes.py b/BuildTools/FixIncludes.py index d1b8268..8984944 100755 --- a/BuildTools/FixIncludes.py +++ b/BuildTools/FixIncludes.py @@ -1,93 +1,96 @@ #!/usr/bin/env python import sys; import os; import re; from sets import Set filename = sys.argv[1] inPlace = False if "-i" in sys.argv: inPlace = True filename_base = os.path.basename(filename) (filename_name, filename_ext) = os.path.splitext(filename_base) c_stdlib_headers = Set(["assert.h", "limits.h", "signal.h", "stdlib.h", "ctype.h", "locale.h", "stdarg.h", "string.h", "errno.h", "math.h", "stddef.h", "time.h", "float.h", "setjmp.h", "stdio.h", "iso646.h", "wchar.h", "wctype.h", "complex.h", "inttypes.h", "stdint.h", "tgmath.h", "fenv.h", "stdbool.h"]) cpp_stdlib_headers = Set(["algorithm", "fstream", "list", "regex", "typeindex", "array", "functional", "locale", "set", "typeinfo", "atomic", "future", "map", "sstream", "type_traits", "bitset", "initializer_list", "memory", "stack", "unordered_map", "chrono", "iomanip", "mutex", "stdexcept", "unordered_set", "codecvt", "ios", "new", "streambuf", "utility", "complex", "iosfwd", "numeric", "string", "valarray", "condition_variable", "iostream", "ostream", "strstream", "vector", "deque", "istream", "queue", "system_error", "exception", "iterator", "random", "thread", "forward_list", "limits", "ratio", "tuple", "cassert", "ciso646", "csetjmp", "cstdio", "ctime", "cctype", "climits", "csignal", "cstdlib", "cwchar", "cerrno", "clocale", "cstdarg", "cstring", "cwctype", "cfloat", "cmath", "cstddef"]) class HeaderType: - PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(14) + PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, SWIFTEN_BASE_DEBUG, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(15) def findHeaderBlock(lines): start = False end = False lastLine = None for idx, line in enumerate(lines): if not start and line.startswith("#"): start = idx elif start and (not end) and (not line.startswith("#")) and line.strip(): end = idx-1 break if not end: end = len(lines) return (start, end) def lineToFileName(line): match = re.match( r'#include "(.*)"', line) if match: return match.group(1) match = re.match( r'#include <(.*)>', line) if match: return match.group(1) return False def fileNameToHeaderType(name): if name.endswith("/" + filename_name + ".h"): return HeaderType.CORRESPONDING_HEADER if name in c_stdlib_headers: return HeaderType.C_STDLIB if name in cpp_stdlib_headers: return HeaderType.CPP_STDLIB if name.startswith("boost"): return HeaderType.BOOST if name.startswith("Q"): return HeaderType.QT + if name.startswith("Swiften/Base/Debug.h"): + return HeaderType.SWIFTEN_BASE_DEBUG + if name.startswith("Swiften"): return HeaderType.SWIFTEN if name.startswith("Limber"): return HeaderType.LIMBER if name.startswith("Slimber"): return HeaderType.SLIMBER if name.startswith("Swift/Controllers"): return HeaderType.SWIFT_CONTROLLERS if name.startswith("Sluift"): return HeaderType.SLUIFT if name.startswith("SwifTools"): return HeaderType.SWIFTOOLS if name.startswith("Swift"): return HeaderType.SWIFT return HeaderType.OTHER def serializeHeaderGroups(groups): headerList = [] for group in range(0, HeaderType.SWIFT + 1): if group in groups: # sorted and without duplicates headers = sorted(list(set(groups[group]))) headerList.extend(headers) diff --git a/Swift/ChangeLog.md b/Swift/ChangeLog.md index 475062c..4c88d88 100644 --- a/Swift/ChangeLog.md +++ b/Swift/ChangeLog.md @@ -1,34 +1,35 @@ 4.0-in-progress --------------- - Fix UI layout issue for translations that require right-to-left (RTL) layout - macOS releases are now code-signed with a key from Apple, so they can be run without Gatekeeper trust warnings +- Handle sessions being closed by the server 4.0-beta2 ( 2016-07-20 ) ------------------------ - Fix Swift bug introduced in 4.0-beta1 that results in the UI sometimes getting stuck during login 4.0-beta1 ( 2016-07-15 ) ------------------------ - Support for message carbons (XEP-0280) - Improved spell checker support on Linux - Enabled trellis mode as a default feature, allowing several tiled chats windows to be shown at once - New chat theme including a new font - And assorted smaller features and usability enhancements 3.0 ( 2016-02-24 ) ------------------ - File transfer and Mac Notification Center issues fixed - Fix connection to servers with invalid or untrusted certificates on OS/X - Support for the Notification Center on OS X - Users can now authenticate using certificates (and smart cards on Windows) when using the 'BOSH' connection type. - Encryption on OS X now uses the platform's native 'Secure Transport' mechanisms. - Emoticons menu in chat dialogs - Bookmark for rooms can now be edited directly from the ‘Recent Chats’ list - Adds option to workaround servers that don’t interoperate well with Windows (schannel) encryption - Rooms entered while offline will now get entered on reconnect - Chats can now be seamlessly upgraded to multi-person chats by either inviting someone via the ‘cog’ menu, or dragging them from the roster. This relies on server-side support with an appropriate chatroom (MUC) service. - Highlighting of keywords and messages from particular users can now be configured (Keyword Highlighting Blog post). - Full profile vcards (contact information etc.) are now supported and can be configured for the user and queried for contacts. - Simple Communication Blocking is now supported (subject to server support) to allow the blocking of nuisance users. - Swift can now transfer files via the ‘Jingle File Transfer’ protocol. - The status setter will now remember previously set statuses and will allow quick access to these when the user types part of a recently used status. diff --git a/Swiften/Base/Debug.cpp b/Swiften/Base/Debug.cpp index b59de35..b4245c3 100644 --- a/Swiften/Base/Debug.cpp +++ b/Swiften/Base/Debug.cpp @@ -1,42 +1,43 @@ /* * Copyright (c) 2015-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <Swiften/Base/Debug.h> #include <iostream> #include <memory> #include <Swiften/Client/ClientError.h> +#include <Swiften/Client/ClientSession.h> #include <Swiften/Serializer/PayloadSerializer.h> #include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h> #include <Swiften/Serializer/XMPPSerializer.h> std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error) { os << "ClientError("; switch(error.getType()) { case Swift::ClientError::UnknownError: os << "UnknownError"; break; case Swift::ClientError::DomainNameResolveError: os << "DomainNameResolveError"; break; case Swift::ClientError::ConnectionError: os << "ConnectionError"; break; case Swift::ClientError::ConnectionReadError: os << "ConnectionReadError"; break; case Swift::ClientError::ConnectionWriteError: os << "ConnectionWriteError"; break; case Swift::ClientError::XMLError: os << "XMLError"; break; case Swift::ClientError::AuthenticationFailedError: os << "AuthenticationFailedError"; break; case Swift::ClientError::CompressionFailedError: os << "CompressionFailedError"; @@ -111,30 +112,79 @@ std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error) { os << "RevocationCheckFailedError"; break; } os << ")"; return os; } std::ostream& operator<<(std::ostream& os, Swift::Element* ele) { using namespace Swift; std::shared_ptr<Element> element = std::shared_ptr<Element>(ele); std::shared_ptr<Payload> payload = std::dynamic_pointer_cast<Payload>(element); if (payload) { FullPayloadSerializerCollection payloadSerializerCollection; PayloadSerializer *serializer = payloadSerializerCollection.getPayloadSerializer(payload); os << "Payload(" << serializer->serialize(payload) << ")"; return os; } std::shared_ptr<ToplevelElement> topLevelElement = std::dynamic_pointer_cast<ToplevelElement>(element); if (topLevelElement) { FullPayloadSerializerCollection payloadSerializerCollection; XMPPSerializer xmppSerializer(&payloadSerializerCollection, ClientStreamType, false); SafeByteArray serialized = xmppSerializer.serializeElement(topLevelElement); os << "TopLevelElement(" << safeByteArrayToString(serialized) << ")"; return os; } os << "Element(Unknown)"; return os; } + +std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state) { + using CS = Swift::ClientSession; + switch (state) { + case CS::State::Initial: + os << "ClientSession::State::Initial"; + break; + case CS::State::WaitingForStreamStart: + os << "ClientSession::State::WaitingForStreamStart"; + break; + case CS::State::Negotiating: + os << "ClientSession::State::Negotiating"; + break; + case CS::State::Compressing: + os << "ClientSession::State::Compressing"; + break; + case CS::State::WaitingForEncrypt: + os << "ClientSession::State::WaitingForEncrypt"; + break; + case CS::State::Encrypting: + os << "ClientSession::State::Encrypting"; + break; + case CS::State::WaitingForCredentials: + os << "ClientSession::State::WaitingForCredentials"; + break; + case CS::State::Authenticating: + os << "ClientSession::State::Authenticating"; + break; + case CS::State::EnablingSessionManagement: + os << "ClientSession::State::EnablingSessionManagement"; + break; + case CS::State::BindingResource: + os << "ClientSession::State::BindingResource"; + break; + case CS::State::StartingSession: + os << "ClientSession::State::StartingSession"; + break; + case CS::State::Initialized: + os << "ClientSession::State::Initialized"; + break; + case CS::State::Finishing: + os << "ClientSession::State::Finishing"; + break; + case CS::State::Finished: + os << "ClientSession::State::Finished"; + break; + } + return os; +} diff --git a/Swiften/Base/Debug.h b/Swiften/Base/Debug.h index 18e7fb4..9dde74c 100644 --- a/Swiften/Base/Debug.h +++ b/Swiften/Base/Debug.h @@ -1,20 +1,26 @@ /* - * Copyright (c) 2015 Isode Limited. + * Copyright (c) 2015-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ +#pragma once + #include <iosfwd> +#include <Swiften/Client/ClientSession.h> + namespace Swift { class ClientError; class Element; } namespace boost { template<class T> class shared_ptr; } std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error); std::ostream& operator<<(std::ostream& os, Swift::Element* ele); + +std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state); diff --git a/Swiften/ChangeLog.md b/Swiften/ChangeLog.md index 4d52c68..62cea5e 100644 --- a/Swiften/ChangeLog.md +++ b/Swiften/ChangeLog.md @@ -1,13 +1,17 @@ +4.0-in-progress +--------------- +- Handle sessions being closed by the server + 4.0-beta1 ( 2016-07-15 ) ------------------------ - Moved code-base to C++11 - Use C++11 threading instead of Boost.Thread library - Use C++11 smart pointers instead of Boost's - Migrated from Boost.Signals to Boost.Signals2 - Build without warnings on our CI platforms - General cleanup like remove of superflous files and #include statements. This means header files that previously were included implictly need to be explicitly included now - Support IPv6 addresses in URLs - Changed source code style to use soft tabs (4 spaces wide) instead of hard tabs. Custom patches for Swiften will need to be reformatted accordingly - Require a TLS backend for building - Update 3rdParty/lcov to version 1.12 - Fix several possible race conditions and other small bugs
\ No newline at end of file diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h index 3a93197..1a337b6 100644 --- a/Swiften/Client/ClientOptions.h +++ b/Swiften/Client/ClientOptions.h @@ -3,160 +3,152 @@ * All rights reserved. * See the COPYING file for more information. */ #pragma once #include <memory> #include <Swiften/Base/API.h> #include <Swiften/Base/SafeString.h> #include <Swiften/Base/URL.h> #include <Swiften/TLS/TLSOptions.h> namespace Swift { class HTTPTrafficFilter; struct SWIFTEN_API ClientOptions { enum UseTLS { NeverUseTLS, UseTLSWhenAvailable, RequireTLS }; enum ProxyType { NoProxy, SystemConfiguredProxy, SOCKS5Proxy, HTTPConnectProxy }; - ClientOptions() : - useStreamCompression(true), - useTLS(UseTLSWhenAvailable), - allowPLAINWithoutTLS(false), - useStreamResumption(false), - forgetPassword(false), - useAcks(true), - singleSignOn(false), - manualHostname(""), - manualPort(-1), - proxyType(SystemConfiguredProxy), - manualProxyHostname(""), - manualProxyPort(-1), - boshHTTPConnectProxyAuthID(""), - boshHTTPConnectProxyAuthPassword("") { + ClientOptions() { } /** * Whether ZLib stream compression should be used when available. * * Default: true */ - bool useStreamCompression; + bool useStreamCompression = true; /** * Sets whether TLS encryption should be used. * * Default: UseTLSWhenAvailable */ - UseTLS useTLS; + UseTLS useTLS = UseTLSWhenAvailable; /** * Sets whether plaintext authentication is * allowed over non-TLS-encrypted connections. * * Default: false */ - bool allowPLAINWithoutTLS; + bool allowPLAINWithoutTLS = false; /** * Use XEP-196 stream resumption when available. * * Default: false */ - bool useStreamResumption; + bool useStreamResumption = false; /** * Forget the password once it's used. * This makes the Client useless after the first login attempt. * * FIXME: This is a temporary workaround. * * Default: false */ - bool forgetPassword; + bool forgetPassword = false; /** * Use XEP-0198 acks in the stream when available. * Default: true */ - bool useAcks; + bool useAcks = true; /** * Use Single Sign On. * Default: false */ - bool singleSignOn; + bool singleSignOn = false; /** * The hostname to connect to. * Leave this empty for standard XMPP connection, based on the JID domain. */ - std::string manualHostname; + std::string manualHostname = ""; /** * The port to connect to. * Leave this to -1 to use the port discovered by SRV lookups, and 5222 as a * fallback. */ - int manualPort; + int manualPort = -1; /** * The type of proxy to use for connecting to the XMPP * server. */ - ProxyType proxyType; + ProxyType proxyType = SystemConfiguredProxy; /** * Override the system-configured proxy hostname. */ - std::string manualProxyHostname; + std::string manualProxyHostname = ""; /** * Override the system-configured proxy port. */ - int manualProxyPort; + int manualProxyPort = -1; /** * If non-empty, use BOSH instead of direct TCP, with the given URL. * Default: empty (no BOSH) */ - URL boshURL; + URL boshURL = URL(); /** * If non-empty, BOSH connections will try to connect over this HTTP CONNECT * proxy instead of directly. * Default: empty (no proxy) */ - URL boshHTTPConnectProxyURL; + URL boshHTTPConnectProxyURL = URL(); /** * If this and matching Password are non-empty, BOSH connections over * HTTP CONNECT proxies will use these credentials for proxy access. * Default: empty (no authentication needed by the proxy) */ - SafeString boshHTTPConnectProxyAuthID; - SafeString boshHTTPConnectProxyAuthPassword; + SafeString boshHTTPConnectProxyAuthID = SafeString(""); + SafeString boshHTTPConnectProxyAuthPassword = SafeString(""); /** * This can be initialized with a custom HTTPTrafficFilter, which allows HTTP CONNECT * proxy initialization to be customized. */ std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter; /** * Options passed to the TLS stack */ TLSOptions tlsOptions; + + /** + * Session shutdown timeout in milliseconds. This is the maximum time Swiften + * waits from a session close to the socket close. + */ + int sessionShutdownTimeoutInMilliseconds = 10000; }; } diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index c301881..bcfb004 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -1,514 +1,576 @@ /* * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <Swiften/Client/ClientSession.h> +#include <memory> + #include <boost/bind.hpp> #include <boost/uuid/uuid.hpp> #include <boost/uuid/uuid_io.hpp> #include <boost/uuid/uuid_generators.hpp> -#include <memory> -#include <Swiften/Base/Platform.h> #include <Swiften/Base/Log.h> -#include <Swiften/Elements/ProtocolHeader.h> -#include <Swiften/Elements/StreamFeatures.h> -#include <Swiften/Elements/StreamError.h> -#include <Swiften/Elements/StartTLSRequest.h> -#include <Swiften/Elements/StartTLSFailure.h> -#include <Swiften/Elements/TLSProceed.h> -#include <Swiften/Elements/AuthRequest.h> -#include <Swiften/Elements/AuthSuccess.h> -#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Base/Platform.h> +#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Elements/AuthChallenge.h> +#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Elements/AuthRequest.h> #include <Swiften/Elements/AuthResponse.h> -#include <Swiften/Elements/Compressed.h> +#include <Swiften/Elements/AuthSuccess.h> #include <Swiften/Elements/CompressFailure.h> #include <Swiften/Elements/CompressRequest.h> +#include <Swiften/Elements/Compressed.h> #include <Swiften/Elements/EnableStreamManagement.h> -#include <Swiften/Elements/StreamManagementEnabled.h> -#include <Swiften/Elements/StreamManagementFailed.h> -#include <Swiften/Elements/StartSession.h> -#include <Swiften/Elements/StanzaAck.h> -#include <Swiften/Elements/StanzaAckRequest.h> #include <Swiften/Elements/IQ.h> +#include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Elements/ResourceBind.h> -#include <Swiften/SASL/PLAINClientAuthenticator.h> +#include <Swiften/Elements/StanzaAck.h> +#include <Swiften/Elements/StanzaAckRequest.h> +#include <Swiften/Elements/StartSession.h> +#include <Swiften/Elements/StartTLSFailure.h> +#include <Swiften/Elements/StartTLSRequest.h> +#include <Swiften/Elements/StreamError.h> +#include <Swiften/Elements/StreamFeatures.h> +#include <Swiften/Elements/StreamManagementEnabled.h> +#include <Swiften/Elements/StreamManagementFailed.h> +#include <Swiften/Elements/TLSProceed.h> +#include <Swiften/Network/Timer.h> +#include <Swiften/Network/TimerFactory.h> +#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> #include <Swiften/SASL/EXTERNALClientAuthenticator.h> +#include <Swiften/SASL/PLAINClientAuthenticator.h> #include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h> -#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> -#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Session/SessionStream.h> +#include <Swiften/StreamManagement/StanzaAckRequester.h> +#include <Swiften/StreamManagement/StanzaAckResponder.h> #include <Swiften/TLS/CertificateTrustChecker.h> #include <Swiften/TLS/ServerIdentityVerifier.h> #ifdef SWIFTEN_PLATFORM_WIN32 #include <Swiften/Base/WindowsRegistry.h> #include <Swiften/SASL/WindowsGSSAPIClientAuthenticator.h> #endif #define CHECK_STATE_OR_RETURN(a) \ if (!checkState(a)) { return; } namespace Swift { ClientSession::ClientSession( const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, - CryptoProvider* crypto) : + CryptoProvider* crypto, + TimerFactory* timerFactory) : localJID(jid), - state(Initial), + state(State::Initial), stream(stream), idnConverter(idnConverter), crypto(crypto), + timerFactory(timerFactory), allowPLAINOverNonTLS(false), useStreamCompression(true), useTLS(UseTLSWhenAvailable), useAcks(true), needSessionStart(false), needResourceBind(false), needAcking(false), rosterVersioningSupported(false), authenticator(nullptr), certificateTrustChecker(nullptr), singleSignOn(false), authenticationPort(-1) { #ifdef SWIFTEN_PLATFORM_WIN32 if (WindowsRegistry::isFIPSEnabled()) { SWIFT_LOG(info) << "Windows is running in FIPS-140 mode. Some authentication methods will be unavailable." << std::endl; } #endif } ClientSession::~ClientSession() { } void ClientSession::start() { stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.connect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.connect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - assert(state == Initial); - state = WaitingForStreamStart; + assert(state == State::Initial); + state = State::WaitingForStreamStart; sendStreamHeader(); } void ClientSession::sendStreamHeader() { ProtocolHeader header; header.setTo(getRemoteJID()); stream->writeHeader(header); } void ClientSession::sendStanza(std::shared_ptr<Stanza> stanza) { stream->writeElement(stanza); if (stanzaAckRequester_) { stanzaAckRequester_->handleStanzaSent(stanza); } } void ClientSession::handleStreamStart(const ProtocolHeader&) { - CHECK_STATE_OR_RETURN(WaitingForStreamStart); - state = Negotiating; + CHECK_STATE_OR_RETURN(State::WaitingForStreamStart); + state = State::Negotiating; +} + +void ClientSession::handleStreamEnd() { + if (state == State::Finishing) { + // We are already in finishing state if we iniated the close of the session. + stream->close(); + } + else { + state = State::Finishing; + initiateShutdown(true); + } } void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { if (std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element)) { if (stanzaAckResponder_) { stanzaAckResponder_->handleStanzaReceived(); } - if (getState() == Initialized) { + if (getState() == State::Initialized) { onStanzaReceived(stanza); } else if (std::shared_ptr<IQ> iq = std::dynamic_pointer_cast<IQ>(element)) { - if (state == BindingResource) { + if (state == State::BindingResource) { std::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { finishSession(Error::ResourceBindError); } else if (!resourceBind) { finishSession(Error::UnexpectedElementError); } else if (iq->getType() == IQ::Result) { localJID = resourceBind->getJID(); if (!localJID.isValid()) { finishSession(Error::ResourceBindError); } needResourceBind = false; continueSessionInitialization(); } else { finishSession(Error::UnexpectedElementError); } } - else if (state == StartingSession) { + else if (state == State::StartingSession) { if (iq->getType() == IQ::Result) { needSessionStart = false; continueSessionInitialization(); } else if (iq->getType() == IQ::Error) { finishSession(Error::SessionStartError); } else { finishSession(Error::UnexpectedElementError); } } else { finishSession(Error::UnexpectedElementError); } } } else if (std::dynamic_pointer_cast<StanzaAckRequest>(element)) { if (stanzaAckResponder_) { stanzaAckResponder_->handleAckRequestReceived(); } } else if (std::shared_ptr<StanzaAck> ack = std::dynamic_pointer_cast<StanzaAck>(element)) { if (stanzaAckRequester_) { if (ack->isValid()) { stanzaAckRequester_->handleAckReceived(ack->getHandledStanzasCount()); } else { SWIFT_LOG(warning) << "Got invalid ack from server"; } } else { SWIFT_LOG(warning) << "Ignoring ack"; } } else if (StreamError::ref streamError = std::dynamic_pointer_cast<StreamError>(element)) { finishSession(Error::StreamError); } - else if (getState() == Initialized) { + else if (getState() == State::Initialized) { std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element); if (stanza) { if (stanzaAckResponder_) { stanzaAckResponder_->handleStanzaReceived(); } onStanzaReceived(stanza); } } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { - CHECK_STATE_OR_RETURN(Negotiating); + CHECK_STATE_OR_RETURN(State::Negotiating); if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { - state = WaitingForEncrypt; + state = State::WaitingForEncrypt; stream->writeElement(std::make_shared<StartTLSRequest>()); } else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) { finishSession(Error::NoSupportedAuthMechanismsError); } else if (useStreamCompression && stream->supportsZLibCompression() && streamFeatures->hasCompressionMethod("zlib")) { - state = Compressing; + state = State::Compressing; stream->writeElement(std::make_shared<CompressRequest>("zlib")); } else if (streamFeatures->hasAuthenticationMechanisms()) { #ifdef SWIFTEN_PLATFORM_WIN32 if (singleSignOn) { const boost::optional<std::string> authenticationHostname = streamFeatures->getAuthenticationHostname(); bool gssapiSupported = streamFeatures->hasAuthenticationMechanism("GSSAPI") && authenticationHostname && !authenticationHostname->empty(); if (!gssapiSupported) { finishSession(Error::NoSupportedAuthMechanismsError); } else { WindowsGSSAPIClientAuthenticator* gssapiAuthenticator = new WindowsGSSAPIClientAuthenticator(*authenticationHostname, localJID.getDomain(), authenticationPort); std::shared_ptr<Error> error = std::make_shared<Error>(Error::AuthenticationFailedError); authenticator = gssapiAuthenticator; if (!gssapiAuthenticator->isError()) { - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } else { error->errorCode = gssapiAuthenticator->getErrorCode(); finishSession(error); } } } else #endif if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else { finishSession(Error::TLSClientCertificateError); } } else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) { std::ostringstream s; ByteArray finishMessage; bool plus = streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS"); if (stream->isTLSEncrypted()) { finishMessage = stream->getTLSFinishMessage(); plus &= !finishMessage.empty(); } s << boost::uuids::random_generator()(); SCRAMSHA1ClientAuthenticator* scramAuthenticator = new SCRAMSHA1ClientAuthenticator(s.str(), plus, idnConverter, crypto); if (!finishMessage.empty()) { scramAuthenticator->setTLSChannelBindingData(finishMessage); } authenticator = scramAuthenticator; - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if ((stream->isTLSEncrypted() || allowPLAINOverNonTLS) && streamFeatures->hasAuthenticationMechanism("PLAIN")) { authenticator = new PLAINClientAuthenticator(); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if (streamFeatures->hasAuthenticationMechanism("DIGEST-MD5") && crypto->isMD5AllowedForCrypto()) { std::ostringstream s; s << boost::uuids::random_generator()(); // FIXME: Host should probably be the actual host authenticator = new DIGESTMD5ClientAuthenticator(localJID.getDomain(), s.str(), crypto); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else { finishSession(Error::NoSupportedAuthMechanismsError); } } else { // Start the session rosterVersioningSupported = streamFeatures->hasRosterVersioning(); stream->setWhitespacePingEnabled(true); needSessionStart = streamFeatures->hasSession(); needResourceBind = streamFeatures->hasResourceBind(); needAcking = streamFeatures->hasStreamManagement() && useAcks; if (!needResourceBind) { // Resource binding is a MUST finishSession(Error::ResourceBindError); } else { continueSessionInitialization(); } } } else if (std::dynamic_pointer_cast<Compressed>(element)) { - CHECK_STATE_OR_RETURN(Compressing); - state = WaitingForStreamStart; + CHECK_STATE_OR_RETURN(State::Compressing); + state = State::WaitingForStreamStart; stream->addZLibCompression(); stream->resetXMPPParser(); sendStreamHeader(); } else if (std::dynamic_pointer_cast<CompressFailure>(element)) { finishSession(Error::CompressionFailedError); } else if (std::dynamic_pointer_cast<StreamManagementEnabled>(element)) { stanzaAckRequester_ = std::make_shared<StanzaAckRequester>(); stanzaAckRequester_->onRequestAck.connect(boost::bind(&ClientSession::requestAck, shared_from_this())); stanzaAckRequester_->onStanzaAcked.connect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1)); stanzaAckResponder_ = std::make_shared<StanzaAckResponder>(); stanzaAckResponder_->onAck.connect(boost::bind(&ClientSession::ack, shared_from_this(), _1)); needAcking = false; continueSessionInitialization(); } else if (std::dynamic_pointer_cast<StreamManagementFailed>(element)) { needAcking = false; continueSessionInitialization(); } else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (authenticator->setChallenge(challenge->getValue())) { stream->writeElement(std::make_shared<AuthResponse>(authenticator->getResponse())); } #ifdef SWIFTEN_PLATFORM_WIN32 else if (WindowsGSSAPIClientAuthenticator* gssapiAuthenticator = dynamic_cast<WindowsGSSAPIClientAuthenticator*>(authenticator)) { std::shared_ptr<Error> error = std::make_shared<Error>(Error::AuthenticationFailedError); error->errorCode = gssapiAuthenticator->getErrorCode(); finishSession(error); } #endif else { finishSession(Error::AuthenticationFailedError); } } else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (!authenticator->setChallenge(authSuccess->getValue())) { finishSession(Error::ServerVerificationFailedError); } else { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; delete authenticator; authenticator = nullptr; stream->resetXMPPParser(); sendStreamHeader(); } } else if (dynamic_cast<AuthFailure*>(element.get())) { finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { - CHECK_STATE_OR_RETURN(WaitingForEncrypt); - state = Encrypting; + CHECK_STATE_OR_RETURN(State::WaitingForEncrypt); + state = State::Encrypting; stream->addTLSEncryption(); } else if (dynamic_cast<StartTLSFailure*>(element.get())) { finishSession(Error::TLSError); } else { // FIXME Not correct? - state = Initialized; + state = State::Initialized; onInitialized(); } } void ClientSession::continueSessionInitialization() { if (needResourceBind) { - state = BindingResource; + state = State::BindingResource; std::shared_ptr<ResourceBind> resourceBind(std::make_shared<ResourceBind>()); if (!localJID.getResource().empty()) { resourceBind->setResource(localJID.getResource()); } sendStanza(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } else if (needAcking) { - state = EnablingSessionManagement; + state = State::EnablingSessionManagement; stream->writeElement(std::make_shared<EnableStreamManagement>()); } else if (needSessionStart) { - state = StartingSession; + state = State::StartingSession; sendStanza(IQ::createRequest(IQ::Set, JID(), "session-start", std::make_shared<StartSession>())); } else { - state = Initialized; + state = State::Initialized; onInitialized(); } } bool ClientSession::checkState(State state) { if (this->state != state) { finishSession(Error::UnexpectedElementError); return false; } return true; } void ClientSession::sendCredentials(const SafeByteArray& password) { - assert(WaitingForCredentials); + assert(state == State::WaitingForCredentials); assert(authenticator); - state = Authenticating; + state = State::Authenticating; authenticator->setCredentials(localJID.getNode(), password); stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } void ClientSession::handleTLSEncrypted() { - CHECK_STATE_OR_RETURN(Encrypting); + CHECK_STATE_OR_RETURN(State::Encrypting); std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain(); std::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError(); if (verificationError) { checkTrustOrFinish(certificateChain, verificationError); } else { ServerIdentityVerifier identityVerifier(localJID, idnConverter); if (!certificateChain.empty() && identityVerifier.certificateVerifies(certificateChain[0])) { continueAfterTLSEncrypted(); } else { checkTrustOrFinish(certificateChain, std::make_shared<CertificateVerificationError>(CertificateVerificationError::InvalidServerIdentity)); } } } void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error) { if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificateChain)) { continueAfterTLSEncrypted(); } else { finishSession(error); } } +void ClientSession::initiateShutdown(bool sendFooter) { + if (!streamShutdownTimeout) { + streamShutdownTimeout = timerFactory->createTimer(sessionShutdownTimeoutInMilliseconds); + streamShutdownTimeout->onTick.connect(boost::bind(&ClientSession::handleStreamShutdownTimeout, shared_from_this())); + streamShutdownTimeout->start(); + } + if (sendFooter) { + stream->writeFooter(); + } + if (state == State::Finishing) { + // The other side already send </stream>; we can close the socket. + stream->close(); + } + else { + state = State::Finishing; + } +} + void ClientSession::continueAfterTLSEncrypted() { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; stream->resetXMPPParser(); sendStreamHeader(); } void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError) { State previousState = state; - state = Finished; + state = State::Finished; + + if (streamShutdownTimeout) { + streamShutdownTimeout->stop(); + streamShutdownTimeout.reset(); + } if (stanzaAckRequester_) { stanzaAckRequester_->onRequestAck.disconnect(boost::bind(&ClientSession::requestAck, shared_from_this())); stanzaAckRequester_->onStanzaAcked.disconnect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1)); stanzaAckRequester_.reset(); } if (stanzaAckResponder_) { stanzaAckResponder_->onAck.disconnect(boost::bind(&ClientSession::ack, shared_from_this(), _1)); stanzaAckResponder_.reset(); } stream->setWhitespacePingEnabled(false); stream->onStreamStartReceived.disconnect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.disconnect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.disconnect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.disconnect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.disconnect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - if (previousState == Finishing) { + if (previousState == State::Finishing) { onFinished(error_); } else { onFinished(streamError); } } +void ClientSession::handleStreamShutdownTimeout() { + handleStreamClosed(std::shared_ptr<Swift::Error>()); +} + void ClientSession::finish() { - finishSession(std::shared_ptr<Error>()); + if (state != State::Finishing && state != State::Finished) { + finishSession(std::shared_ptr<Error>()); + } + else { + SWIFT_LOG(warning) << "Session already finished or finishing." << std::endl; + } } void ClientSession::finishSession(Error::Type error) { finishSession(std::make_shared<Swift::ClientSession::Error>(error)); } void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) { - state = Finishing; if (!error_) { error_ = error; } else { SWIFT_LOG(warning) << "Session finished twice"; } assert(stream->isOpen()); if (stanzaAckResponder_) { stanzaAckResponder_->handleAckRequestReceived(); } if (authenticator) { delete authenticator; authenticator = nullptr; } - stream->writeFooter(); - stream->close(); + // Immidiately close TCP connection without stream closure. + if (std::dynamic_pointer_cast<CertificateVerificationError>(error)) { + state = State::Finishing; + initiateShutdown(false); + } + else { + if (state == State::Finishing) { + initiateShutdown(true); + } + else if (state != State::Finished) { + initiateShutdown(true); + } + } } void ClientSession::requestAck() { stream->writeElement(std::make_shared<StanzaAckRequest>()); } void ClientSession::handleStanzaAcked(std::shared_ptr<Stanza> stanza) { onStanzaAcked(stanza); } void ClientSession::ack(unsigned int handledStanzasCount) { stream->writeElement(std::make_shared<StanzaAck>(handledStanzasCount)); } } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index ad1b78c..c7b3658 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -1,201 +1,215 @@ /* * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once #include <memory> #include <string> #include <boost/signals2.hpp> #include <Swiften/Base/API.h> #include <Swiften/Base/Error.h> #include <Swiften/Elements/ToplevelElement.h> #include <Swiften/JID/JID.h> #include <Swiften/Session/SessionStream.h> -#include <Swiften/StreamManagement/StanzaAckRequester.h> -#include <Swiften/StreamManagement/StanzaAckResponder.h> namespace Swift { - class ClientAuthenticator; class CertificateTrustChecker; - class IDNConverter; + class ClientAuthenticator; class CryptoProvider; + class IDNConverter; + class Stanza; + class StanzaAckRequester; + class StanzaAckResponder; + class TimerFactory; + class Timer; class SWIFTEN_API ClientSession : public std::enable_shared_from_this<ClientSession> { public: - enum State { + enum class State { Initial, WaitingForStreamStart, Negotiating, Compressing, WaitingForEncrypt, Encrypting, WaitingForCredentials, Authenticating, EnablingSessionManagement, BindingResource, StartingSession, Initialized, Finishing, Finished }; struct Error : public Swift::Error { enum Type { AuthenticationFailedError, CompressionFailedError, ServerVerificationFailedError, NoSupportedAuthMechanismsError, UnexpectedElementError, ResourceBindError, SessionStartError, TLSClientCertificateError, TLSError, - StreamError + StreamError, + StreamEndError, // The server send a closing stream tag. } type; std::shared_ptr<boost::system::error_code> errorCode; Error(Type type) : type(type) {} }; enum UseTLS { NeverUseTLS, UseTLSWhenAvailable, RequireTLS }; ~ClientSession(); - static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto) { - return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto)); + static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto, TimerFactory* timerFactory) { + return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto, timerFactory)); } State getState() const { return state; } void setAllowPLAINOverNonTLS(bool b) { allowPLAINOverNonTLS = b; } void setUseStreamCompression(bool b) { useStreamCompression = b; } void setUseTLS(UseTLS b) { useTLS = b; } void setUseAcks(bool b) { useAcks = b; } - bool getStreamManagementEnabled() const { // Explicitly convert to bool. In C++11, it would be cleaner to // compare to nullptr. return static_cast<bool>(stanzaAckRequester_); } bool getRosterVersioningSupported() const { return rosterVersioningSupported; } std::vector<Certificate::ref> getPeerCertificateChain() const { return stream->getPeerCertificateChain(); } const JID& getLocalJID() const { return localJID; } void start(); void finish(); bool isFinished() const { - return getState() == Finished; + return getState() == State::Finished; } void sendCredentials(const SafeByteArray& password); void sendStanza(std::shared_ptr<Stanza>); void setCertificateTrustChecker(CertificateTrustChecker* checker) { certificateTrustChecker = checker; } void setSingleSignOn(bool b) { singleSignOn = b; } /** * Sets the port number used in Kerberos authentication * Does not affect network connectivity. */ void setAuthenticationPort(int i) { authenticationPort = i; } + void setSessionShutdownTimeout(int timeoutInMilliseconds) { + sessionShutdownTimeoutInMilliseconds = timeoutInMilliseconds; + } + public: boost::signals2::signal<void ()> onNeedCredentials; boost::signals2::signal<void ()> onInitialized; boost::signals2::signal<void (std::shared_ptr<Swift::Error>)> onFinished; boost::signals2::signal<void (std::shared_ptr<Stanza>)> onStanzaReceived; boost::signals2::signal<void (std::shared_ptr<Stanza>)> onStanzaAcked; private: ClientSession( const JID& jid, std::shared_ptr<SessionStream>, IDNConverter* idnConverter, - CryptoProvider* crypto); + CryptoProvider* crypto, + TimerFactory* timerFactory); void finishSession(Error::Type error); void finishSession(std::shared_ptr<Swift::Error> error); JID getRemoteJID() const { return JID("", localJID.getDomain()); } void sendStreamHeader(); void handleElement(std::shared_ptr<ToplevelElement>); void handleStreamStart(const ProtocolHeader&); + void handleStreamEnd(); void handleStreamClosed(std::shared_ptr<Swift::Error>); + void handleStreamShutdownTimeout(); void handleTLSEncrypted(); bool checkState(State); void continueSessionInitialization(); void requestAck(); void handleStanzaAcked(std::shared_ptr<Stanza> stanza); void ack(unsigned int handledStanzasCount); void continueAfterTLSEncrypted(); void checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error); + void initiateShutdown(bool sendFooter); private: JID localJID; State state; std::shared_ptr<SessionStream> stream; - IDNConverter* idnConverter; - CryptoProvider* crypto; + IDNConverter* idnConverter = nullptr; + CryptoProvider* crypto = nullptr; + TimerFactory* timerFactory = nullptr; + std::shared_ptr<Timer> streamShutdownTimeout; + int sessionShutdownTimeoutInMilliseconds = 10000; bool allowPLAINOverNonTLS; bool useStreamCompression; UseTLS useTLS; bool useAcks; bool needSessionStart; bool needResourceBind; bool needAcking; bool rosterVersioningSupported; ClientAuthenticator* authenticator; std::shared_ptr<StanzaAckRequester> stanzaAckRequester_; std::shared_ptr<StanzaAckResponder> stanzaAckResponder_; std::shared_ptr<Swift::Error> error_; CertificateTrustChecker* certificateTrustChecker; bool singleSignOn; int authenticationPort; }; } diff --git a/Swiften/Client/ClientSessionStanzaChannel.h b/Swiften/Client/ClientSessionStanzaChannel.h index 0527a5c..c4ee393 100644 --- a/Swiften/Client/ClientSessionStanzaChannel.h +++ b/Swiften/Client/ClientSessionStanzaChannel.h @@ -6,47 +6,47 @@ #pragma once #include <memory> #include <Swiften/Base/API.h> #include <Swiften/Base/IDGenerator.h> #include <Swiften/Client/ClientSession.h> #include <Swiften/Client/StanzaChannel.h> #include <Swiften/Elements/IQ.h> #include <Swiften/Elements/Message.h> #include <Swiften/Elements/Presence.h> namespace Swift { /** * StanzaChannel implementation around a ClientSession. */ class SWIFTEN_API ClientSessionStanzaChannel : public StanzaChannel { public: virtual ~ClientSessionStanzaChannel(); void setSession(std::shared_ptr<ClientSession> session); void sendIQ(std::shared_ptr<IQ> iq); void sendMessage(std::shared_ptr<Message> message); void sendPresence(std::shared_ptr<Presence> presence); bool getStreamManagementEnabled() const; virtual std::vector<Certificate::ref> getPeerCertificateChain() const; bool isAvailable() const { - return session && session->getState() == ClientSession::Initialized; + return session && session->getState() == ClientSession::State::Initialized; } private: std::string getNewIQID(); void send(std::shared_ptr<Stanza> stanza); void handleSessionFinished(std::shared_ptr<Error> error); void handleStanza(std::shared_ptr<Stanza> stanza); void handleStanzaAcked(std::shared_ptr<Stanza> stanza); void handleSessionInitialized(); private: IDGenerator idGenerator; std::shared_ptr<ClientSession> session; }; } diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index 3d75d8b..3c7902e 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -127,66 +127,67 @@ void CoreClient::connect(const ClientOptions& o) { std::shared_ptr<BOSHSessionStream> boshSessionStream_ = std::shared_ptr<BOSHSessionStream>(new BOSHSessionStream( options.boshURL, getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getConnectionFactory(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory(), networkFactories->getEventLoop(), networkFactories->getDomainNameResolver(), host, options.boshHTTPConnectProxyURL, options.boshHTTPConnectProxyAuthID, options.boshHTTPConnectProxyAuthPassword, options.tlsOptions, options.httpTrafficFilter)); sessionStream_ = boshSessionStream_; sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); if (certificate_ && !certificate_->isNull()) { SWIFT_LOG(debug) << "set certificate" << std::endl; sessionStream_->setTLSCertificate(certificate_); } boshSessionStream_->open(); bindSessionToStream(); } } void CoreClient::bindSessionToStream() { - session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider()); + session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider(), networkFactories->getTimerFactory()); session_->setCertificateTrustChecker(certificateTrustChecker); session_->setUseStreamCompression(options.useStreamCompression); session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); session_->setSingleSignOn(options.singleSignOn); session_->setAuthenticationPort(options.manualPort); + session_->setSessionShutdownTimeout(options.sessionShutdownTimeoutInMilliseconds); switch(options.useTLS) { case ClientOptions::UseTLSWhenAvailable: session_->setUseTLS(ClientSession::UseTLSWhenAvailable); break; case ClientOptions::NeverUseTLS: session_->setUseTLS(ClientSession::NeverUseTLS); break; case ClientOptions::RequireTLS: session_->setUseTLS(ClientSession::RequireTLS); break; } session_->setUseAcks(options.useAcks); stanzaChannel_->setSession(session_); session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); session_->start(); } /** * Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream. */ void CoreClient::handleConnectorFinished(std::shared_ptr<Connection> connection, std::shared_ptr<Error> error) { resetConnector(); if (!connection) { if (options.forgetPassword) { purgePassword(); } boost::optional<ClientError> clientError; if (!disconnectRequested_) { clientError = std::dynamic_pointer_cast<DomainNameResolveError>(error) ? boost::optional<ClientError>(ClientError::DomainNameResolveError) : boost::optional<ClientError>(ClientError::ConnectionError); @@ -246,60 +247,63 @@ void CoreClient::handleSessionFinished(std::shared_ptr<Error> error) { case ClientSession::Error::AuthenticationFailedError: clientError = ClientError(ClientError::AuthenticationFailedError); break; case ClientSession::Error::CompressionFailedError: clientError = ClientError(ClientError::CompressionFailedError); break; case ClientSession::Error::ServerVerificationFailedError: clientError = ClientError(ClientError::ServerVerificationFailedError); break; case ClientSession::Error::NoSupportedAuthMechanismsError: clientError = ClientError(ClientError::NoSupportedAuthMechanismsError); break; case ClientSession::Error::UnexpectedElementError: clientError = ClientError(ClientError::UnexpectedElementError); break; case ClientSession::Error::ResourceBindError: clientError = ClientError(ClientError::ResourceBindError); break; case ClientSession::Error::SessionStartError: clientError = ClientError(ClientError::SessionStartError); break; case ClientSession::Error::TLSError: clientError = ClientError(ClientError::TLSError); break; case ClientSession::Error::TLSClientCertificateError: clientError = ClientError(ClientError::ClientCertificateError); break; case ClientSession::Error::StreamError: clientError = ClientError(ClientError::StreamError); break; + case ClientSession::Error::StreamEndError: + clientError = ClientError(ClientError::StreamError); + break; } clientError.setErrorCode(actualError->errorCode); } else if (std::shared_ptr<TLSError> actualError = std::dynamic_pointer_cast<TLSError>(error)) { switch(actualError->getType()) { case TLSError::CertificateCardRemoved: clientError = ClientError(ClientError::CertificateCardRemoved); break; case TLSError::UnknownError: clientError = ClientError(ClientError::TLSError); break; } } else if (std::shared_ptr<SessionStream::SessionStreamError> actualError = std::dynamic_pointer_cast<SessionStream::SessionStreamError>(error)) { switch(actualError->type) { case SessionStream::SessionStreamError::ParseError: clientError = ClientError(ClientError::XMLError); break; case SessionStream::SessionStreamError::TLSError: clientError = ClientError(ClientError::TLSError); break; case SessionStream::SessionStreamError::InvalidTLSCertificateError: clientError = ClientError(ClientError::ClientCertificateLoadError); break; case SessionStream::SessionStreamError::ConnectionReadError: clientError = ClientError(ClientError::ConnectionReadError); break; case SessionStream::SessionStreamError::ConnectionWriteError: clientError = ClientError(ClientError::ConnectionWriteError); break; diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index bd93f4b..44df961 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -1,380 +1,458 @@ /* * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <deque> #include <memory> #include <boost/bind.hpp> #include <boost/optional.hpp> +#include <Swiften/Base/Debug.h> + #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <Swiften/Client/ClientSession.h> #include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Crypto/PlatformCryptoProvider.h> #include <Swiften/Elements/AuthChallenge.h> #include <Swiften/Elements/AuthFailure.h> #include <Swiften/Elements/AuthRequest.h> #include <Swiften/Elements/AuthSuccess.h> #include <Swiften/Elements/EnableStreamManagement.h> #include <Swiften/Elements/IQ.h> #include <Swiften/Elements/Message.h> #include <Swiften/Elements/ResourceBind.h> #include <Swiften/Elements/StanzaAck.h> #include <Swiften/Elements/StartTLSFailure.h> #include <Swiften/Elements/StartTLSRequest.h> #include <Swiften/Elements/StreamError.h> #include <Swiften/Elements/StreamFeatures.h> #include <Swiften/Elements/StreamManagementEnabled.h> #include <Swiften/Elements/StreamManagementFailed.h> #include <Swiften/Elements/TLSProceed.h> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/IDN/PlatformIDNConverter.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Session/SessionStream.h> #include <Swiften/TLS/BlindCertificateTrustChecker.h> #include <Swiften/TLS/SimpleCertificate.h> using namespace Swift; class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(ClientSessionTest); CPPUNIT_TEST(testStart_Error); CPPUNIT_TEST(testStart_StreamError); CPPUNIT_TEST(testStartTLS); CPPUNIT_TEST(testStartTLS_ServerError); CPPUNIT_TEST(testStartTLS_ConnectError); CPPUNIT_TEST(testStartTLS_InvalidIdentity); CPPUNIT_TEST(testStart_StreamFeaturesWithoutResourceBindingFails); CPPUNIT_TEST(testAuthenticate); CPPUNIT_TEST(testAuthenticate_Unauthorized); CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS); CPPUNIT_TEST(testAuthenticate_RequireTLS); CPPUNIT_TEST(testAuthenticate_EXTERNAL); CPPUNIT_TEST(testStreamManagement); CPPUNIT_TEST(testStreamManagement_Failed); CPPUNIT_TEST(testUnexpectedChallenge); CPPUNIT_TEST(testFinishAcksStanzas); + + CPPUNIT_TEST(testServerInitiatedSessionClose); + CPPUNIT_TEST(testClientInitiatedSessionClose); + CPPUNIT_TEST(testTimeoutOnShutdown); + /* CPPUNIT_TEST(testResourceBind); CPPUNIT_TEST(testResourceBind_ChangeResource); CPPUNIT_TEST(testResourceBind_EmptyResource); CPPUNIT_TEST(testResourceBind_Error); CPPUNIT_TEST(testSessionStart); CPPUNIT_TEST(testSessionStart_Error); CPPUNIT_TEST(testSessionStart_AfterResourceBind); CPPUNIT_TEST(testWhitespacePing); CPPUNIT_TEST(testReceiveElementAfterSessionStarted); CPPUNIT_TEST(testSendElement); */ CPPUNIT_TEST_SUITE_END(); public: void setUp() { crypto = std::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create()); idnConverter = std::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); + timerFactory = std::make_shared<DummyTimerFactory>(); server = std::make_shared<MockSessionStream>(); sessionFinishedReceived = false; needCredentials = false; blindCertificateTrustChecker = new BlindCertificateTrustChecker(); } void tearDown() { delete blindCertificateTrustChecker; } void testStart_Error() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->breakConnection(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStart_StreamError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->sendStreamStart(); server->sendStreamError(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setCertificateTrustChecker(blindCertificateTrustChecker); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); CPPUNIT_ASSERT(!server->tlsEncrypted); server->sendTLSProceed(); CPPUNIT_ASSERT(server->tlsEncrypted); server->onTLSEncrypted(); server->receiveStreamStart(); server->sendStreamStart(); session->finish(); } void testStartTLS_ServerError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); server->sendTLSFailure(); CPPUNIT_ASSERT(!server->tlsEncrypted); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS_ConnectError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); server->sendTLSProceed(); server->breakTLS(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS_InvalidIdentity() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); CPPUNIT_ASSERT(!server->tlsEncrypted); server->sendTLSProceed(); CPPUNIT_ASSERT(server->tlsEncrypted); server->onTLSEncrypted(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)); CPPUNIT_ASSERT_EQUAL(CertificateVerificationError::InvalidServerIdentity, std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)->getType()); } void testStart_StreamFeaturesWithoutResourceBindingFails() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendEmptyStreamFeatures(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); session->finish(); } void testAuthenticate_Unauthorized() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthFailure(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_PLAINOverNonTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setAllowPLAINOverNonTLS(false); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_RequireTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setUseTLS(ClientSession::RequireTLS); session->setAllowPLAINOverNonTLS(true); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithMultipleAuthentication(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_NoValidAuthMechanisms() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithUnknownAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_EXTERNAL() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithEXTERNALAuthentication(); server->receiveAuthRequest("EXTERNAL"); server->sendAuthSuccess(); server->receiveStreamStart(); session->finish(); } void testUnexpectedChallenge() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithEXTERNALAuthentication(); server->receiveAuthRequest("EXTERNAL"); server->sendChallenge(); server->sendChallenge(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStreamManagement() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementEnabled(); CPPUNIT_ASSERT(session->getStreamManagementEnabled()); // TODO: Test if the requesters & responders do their work - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } void testStreamManagement_Failed() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementFailed(); CPPUNIT_ASSERT(!session->getStreamManagementEnabled()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } void testFinishAcksStanzas() { std::shared_ptr<ClientSession> session(createSession()); initializeSession(session); server->sendMessage(); server->sendMessage(); server->sendMessage(); session->finish(); server->receiveAck(3); } + void testServerInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + } + + void testClientInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + } + + void testTimeoutOnShutdown() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + timerFactory->setTime(60000); + + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + } + private: std::shared_ptr<ClientSession> createSession() { - std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get()); + std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get(), timerFactory.get()); session->onFinished.connect(boost::bind(&ClientSessionTest::handleSessionFinished, this, _1)); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::handleSessionNeedCredentials, this)); session->setAllowPLAINOverNonTLS(true); return session; } void initializeSession(std::shared_ptr<ClientSession> session) { session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementEnabled(); } void handleSessionFinished(std::shared_ptr<Error> error) { sessionFinishedReceived = true; sessionFinishedError = error; } void handleSessionNeedCredentials() { needCredentials = true; @@ -604,60 +682,61 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(event.element); std::shared_ptr<StanzaAck> ack = std::dynamic_pointer_cast<StanzaAck>(event.element); CPPUNIT_ASSERT(ack); CPPUNIT_ASSERT_EQUAL(n, ack->getHandledStanzasCount()); } Event popEvent() { CPPUNIT_ASSERT(!receivedEvents.empty()); Event event = receivedEvents.front(); receivedEvents.pop_front(); return event; } bool available; bool canTLSEncrypt; bool tlsEncrypted; bool compressed; bool whitespacePingEnabled; std::string bindID; int resetCount; std::deque<Event> receivedEvents; }; std::shared_ptr<IDNConverter> idnConverter; std::shared_ptr<MockSessionStream> server; bool sessionFinishedReceived; bool needCredentials; std::shared_ptr<Error> sessionFinishedError; BlindCertificateTrustChecker* blindCertificateTrustChecker; std::shared_ptr<CryptoProvider> crypto; + std::shared_ptr<DummyTimerFactory> timerFactory; }; CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest); #if 0 void testAuthenticate() { std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this)); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithAuthentication(); session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); CPPUNIT_ASSERT(needCredentials_); getMockServer()->expectAuth("me", "mypass"); getMockServer()->sendAuthSuccess(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); session->sendCredentials("mypass"); CPPUNIT_ASSERT_EQUAL(ClientSession::Authenticating, session->getState()); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState()); } void testAuthenticate_Unauthorized() { std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index 402f642..10c6ad0 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -13,88 +13,90 @@ #include <Swiften/StreamStack/CompressionLayer.h> #include <Swiften/StreamStack/ConnectionLayer.h> #include <Swiften/StreamStack/StreamStack.h> #include <Swiften/StreamStack/TLSLayer.h> #include <Swiften/StreamStack/WhitespacePingLayer.h> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/TLS/TLSContext.h> #include <Swiften/TLS/TLSContextFactory.h> namespace Swift { BasicSessionStream::BasicSessionStream( StreamType streamType, std::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSContextFactory* tlsContextFactory, TimerFactory* timerFactory, XMLParserFactory* xmlParserFactory, const TLSOptions& tlsOptions) : available(false), connection(connection), tlsContextFactory(tlsContextFactory), timerFactory(timerFactory), compressionLayer(nullptr), tlsLayer(nullptr), whitespacePingLayer(nullptr), tlsOptions_(tlsOptions) { xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, streamType); xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1)); + xmppLayer->onStreamEnd.connect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this)); xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1)); xmppLayer->onError.connect(boost::bind(&BasicSessionStream::handleXMPPError, this)); xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, this, _1)); xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1)); connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1)); connectionLayer = new ConnectionLayer(connection); streamStack = new StreamStack(xmppLayer, connectionLayer); available = true; } BasicSessionStream::~BasicSessionStream() { delete compressionLayer; if (tlsLayer) { tlsLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleTLSError, this, _1)); tlsLayer->onConnected.disconnect(boost::bind(&BasicSessionStream::handleTLSConnected, this)); delete tlsLayer; } delete whitespacePingLayer; delete streamStack; connection->onDisconnected.disconnect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1)); delete connectionLayer; xmppLayer->onStreamStart.disconnect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1)); + xmppLayer->onStreamEnd.disconnect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this)); xmppLayer->onElement.disconnect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1)); xmppLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleXMPPError, this)); xmppLayer->onDataRead.disconnect(boost::bind(&BasicSessionStream::handleDataRead, this, _1)); xmppLayer->onWriteData.disconnect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1)); delete xmppLayer; } void BasicSessionStream::writeHeader(const ProtocolHeader& header) { assert(available); xmppLayer->writeHeader(header); } void BasicSessionStream::writeElement(std::shared_ptr<ToplevelElement> element) { assert(available); xmppLayer->writeElement(element); } void BasicSessionStream::writeFooter() { assert(available); xmppLayer->writeFooter(); } void BasicSessionStream::writeData(const std::string& data) { assert(available); xmppLayer->writeData(data); } void BasicSessionStream::close() { connection->disconnect(); } @@ -144,60 +146,64 @@ ByteArray BasicSessionStream::getTLSFinishMessage() const { bool BasicSessionStream::supportsZLibCompression() { return true; } void BasicSessionStream::addZLibCompression() { compressionLayer = new CompressionLayer(); streamStack->addLayer(compressionLayer); } void BasicSessionStream::setWhitespacePingEnabled(bool enabled) { if (enabled) { if (!whitespacePingLayer) { whitespacePingLayer = new WhitespacePingLayer(timerFactory); streamStack->addLayer(whitespacePingLayer); } whitespacePingLayer->setActive(); } else if (whitespacePingLayer) { whitespacePingLayer->setInactive(); } } void BasicSessionStream::resetXMPPParser() { xmppLayer->resetParser(); } void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header) { onStreamStartReceived(header); } +void BasicSessionStream::handleStreamEndReceived() { + onStreamEndReceived(); +} + void BasicSessionStream::handleElementReceived(std::shared_ptr<ToplevelElement> element) { onElementReceived(element); } void BasicSessionStream::handleXMPPError() { available = false; onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ParseError)); } void BasicSessionStream::handleTLSConnected() { onTLSEncrypted(); } void BasicSessionStream::handleTLSError(std::shared_ptr<TLSError> error) { available = false; onClosed(error); } void BasicSessionStream::handleConnectionFinished(const boost::optional<Connection::Error>& error) { available = false; if (error == Connection::ReadError) { onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ConnectionReadError)); } else if (error) { onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ConnectionWriteError)); } else { onClosed(std::shared_ptr<SessionStreamError>()); } } diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index 1806cef..48b3d63 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -46,49 +46,50 @@ namespace Swift { virtual void close(); virtual bool isOpen(); virtual void writeHeader(const ProtocolHeader& header); virtual void writeElement(std::shared_ptr<ToplevelElement>); virtual void writeFooter(); virtual void writeData(const std::string& data); virtual bool supportsZLibCompression(); virtual void addZLibCompression(); virtual bool supportsTLSEncryption(); virtual void addTLSEncryption(); virtual bool isTLSEncrypted(); virtual Certificate::ref getPeerCertificate() const; virtual std::vector<Certificate::ref> getPeerCertificateChain() const; virtual std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; virtual ByteArray getTLSFinishMessage() const; virtual void setWhitespacePingEnabled(bool); virtual void resetXMPPParser(); private: void handleConnectionFinished(const boost::optional<Connection::Error>& error); void handleXMPPError(); void handleTLSConnected(); void handleTLSError(std::shared_ptr<TLSError>); void handleStreamStartReceived(const ProtocolHeader&); + void handleStreamEndReceived(); void handleElementReceived(std::shared_ptr<ToplevelElement>); void handleDataRead(const SafeByteArray& data); void handleDataWritten(const SafeByteArray& data); private: bool available; std::shared_ptr<Connection> connection; TLSContextFactory* tlsContextFactory; TimerFactory* timerFactory; XMPPLayer* xmppLayer; ConnectionLayer* connectionLayer; CompressionLayer* compressionLayer; TLSLayer* tlsLayer; WhitespacePingLayer* whitespacePingLayer; StreamStack* streamStack; TLSOptions tlsOptions_; }; } diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index f56f495..c5ec42a 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -48,45 +48,46 @@ namespace Swift { virtual void writeHeader(const ProtocolHeader& header) = 0; virtual void writeFooter() = 0; virtual void writeElement(std::shared_ptr<ToplevelElement>) = 0; virtual void writeData(const std::string& data) = 0; virtual bool supportsZLibCompression() = 0; virtual void addZLibCompression() = 0; virtual bool supportsTLSEncryption() = 0; virtual void addTLSEncryption() = 0; virtual bool isTLSEncrypted() = 0; virtual void setWhitespacePingEnabled(bool enabled) = 0; virtual void resetXMPPParser() = 0; void setTLSCertificate(CertificateWithKey::ref cert) { certificate = cert; } virtual bool hasTLSCertificate() { return certificate && !certificate->isNull(); } virtual Certificate::ref getPeerCertificate() const = 0; virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0; virtual std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0; virtual ByteArray getTLSFinishMessage() const = 0; boost::signals2::signal<void (const ProtocolHeader&)> onStreamStartReceived; + boost::signals2::signal<void ()> onStreamEndReceived; boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElementReceived; boost::signals2::signal<void (std::shared_ptr<Error>)> onClosed; boost::signals2::signal<void ()> onTLSEncrypted; boost::signals2::signal<void (const SafeByteArray&)> onDataRead; boost::signals2::signal<void (const SafeByteArray&)> onDataWritten; protected: CertificateWithKey::ref getTLSCertificate() const { return certificate; } private: CertificateWithKey::ref certificate; }; } diff --git a/Swiften/StreamStack/XMPPLayer.cpp b/Swiften/StreamStack/XMPPLayer.cpp index 2eed906..982d13f 100644 --- a/Swiften/StreamStack/XMPPLayer.cpp +++ b/Swiften/StreamStack/XMPPLayer.cpp @@ -59,42 +59,43 @@ void XMPPLayer::handleDataRead(const SafeByteArray& data) { inParser_ = true; // FIXME: Converting to unsafe string. Should be ok, since we don't take passwords // from the stream in clients. If servers start using this, and require safe storage, // we need to fix this. if (!xmppParser_->parse(byteArrayToString(ByteArray(data.begin(), data.end())))) { inParser_ = false; onError(); return; } inParser_ = false; if (resetParserAfterParse_) { doResetParser(); } } void XMPPLayer::doResetParser() { delete xmppParser_; xmppParser_ = new XMPPParser(this, payloadParserFactories_, xmlParserFactory_); resetParserAfterParse_ = false; } void XMPPLayer::handleStreamStart(const ProtocolHeader& header) { onStreamStart(header); } void XMPPLayer::handleElement(std::shared_ptr<ToplevelElement> stanza) { onElement(stanza); } void XMPPLayer::handleStreamEnd() { + onStreamEnd(); } void XMPPLayer::resetParser() { if (inParser_) { resetParserAfterParse_ = true; } else { doResetParser(); } } } diff --git a/Swiften/StreamStack/XMPPLayer.h b/Swiften/StreamStack/XMPPLayer.h index 1d4abf8..f0b5afb 100644 --- a/Swiften/StreamStack/XMPPLayer.h +++ b/Swiften/StreamStack/XMPPLayer.h @@ -24,53 +24,54 @@ namespace Swift { class PayloadParserFactoryCollection; class XMPPSerializer; class PayloadSerializerCollection; class XMLParserFactory; class BOSHSessionStream; class SWIFTEN_API XMPPLayer : public XMPPParserClient, public HighLayer, boost::noncopyable { friend class BOSHSessionStream; public: XMPPLayer( PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, XMLParserFactory* xmlParserFactory, StreamType streamType, bool setExplictNSonTopLevelElements = false); virtual ~XMPPLayer(); void writeHeader(const ProtocolHeader& header); void writeFooter(); void writeElement(std::shared_ptr<ToplevelElement>); void writeData(const std::string& data); void resetParser(); protected: void handleDataRead(const SafeByteArray& data); void writeDataInternal(const SafeByteArray& data); public: boost::signals2::signal<void (const ProtocolHeader&)> onStreamStart; + boost::signals2::signal<void ()> onStreamEnd; boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElement; boost::signals2::signal<void (const SafeByteArray&)> onWriteData; boost::signals2::signal<void (const SafeByteArray&)> onDataRead; boost::signals2::signal<void ()> onError; private: void handleStreamStart(const ProtocolHeader&); void handleElement(std::shared_ptr<ToplevelElement>); void handleStreamEnd(); void doResetParser(); private: PayloadParserFactoryCollection* payloadParserFactories_; XMPPParser* xmppParser_; PayloadSerializerCollection* payloadSerializers_; XMLParserFactory* xmlParserFactory_; XMPPSerializer* xmppSerializer_; bool setExplictNSonTopLevelElements_; bool resetParserAfterParse_; bool inParser_; }; } |