diff options
-rwxr-xr-x | BuildTools/FixIncludes.py | 5 | ||||
-rw-r--r-- | Swift/ChangeLog.md | 1 | ||||
-rw-r--r-- | Swiften/Base/Debug.cpp | 50 | ||||
-rw-r--r-- | Swiften/Base/Debug.h | 8 | ||||
-rw-r--r-- | Swiften/ChangeLog.md | 4 | ||||
-rw-r--r-- | Swiften/Client/ClientOptions.h | 54 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 184 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 40 | ||||
-rw-r--r-- | Swiften/Client/ClientSessionStanzaChannel.h | 2 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.cpp | 6 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 111 | ||||
-rw-r--r-- | Swiften/Session/BasicSessionStream.cpp | 6 | ||||
-rw-r--r-- | Swiften/Session/BasicSessionStream.h | 1 | ||||
-rw-r--r-- | Swiften/Session/SessionStream.h | 1 | ||||
-rw-r--r-- | Swiften/StreamStack/XMPPLayer.cpp | 1 | ||||
-rw-r--r-- | Swiften/StreamStack/XMPPLayer.h | 1 |
16 files changed, 350 insertions, 125 deletions
diff --git a/BuildTools/FixIncludes.py b/BuildTools/FixIncludes.py index d1b8268..8984944 100755 --- a/BuildTools/FixIncludes.py +++ b/BuildTools/FixIncludes.py @@ -1,93 +1,96 @@ #!/usr/bin/env python import sys; import os; import re; from sets import Set filename = sys.argv[1] inPlace = False if "-i" in sys.argv: inPlace = True filename_base = os.path.basename(filename) (filename_name, filename_ext) = os.path.splitext(filename_base) c_stdlib_headers = Set(["assert.h", "limits.h", "signal.h", "stdlib.h", "ctype.h", "locale.h", "stdarg.h", "string.h", "errno.h", "math.h", "stddef.h", "time.h", "float.h", "setjmp.h", "stdio.h", "iso646.h", "wchar.h", "wctype.h", "complex.h", "inttypes.h", "stdint.h", "tgmath.h", "fenv.h", "stdbool.h"]) cpp_stdlib_headers = Set(["algorithm", "fstream", "list", "regex", "typeindex", "array", "functional", "locale", "set", "typeinfo", "atomic", "future", "map", "sstream", "type_traits", "bitset", "initializer_list", "memory", "stack", "unordered_map", "chrono", "iomanip", "mutex", "stdexcept", "unordered_set", "codecvt", "ios", "new", "streambuf", "utility", "complex", "iosfwd", "numeric", "string", "valarray", "condition_variable", "iostream", "ostream", "strstream", "vector", "deque", "istream", "queue", "system_error", "exception", "iterator", "random", "thread", "forward_list", "limits", "ratio", "tuple", "cassert", "ciso646", "csetjmp", "cstdio", "ctime", "cctype", "climits", "csignal", "cstdlib", "cwchar", "cerrno", "clocale", "cstdarg", "cstring", "cwctype", "cfloat", "cmath", "cstddef"]) class HeaderType: - PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(14) + PRAGMA_ONCE, CORRESPONDING_HEADER, C_STDLIB, CPP_STDLIB, BOOST, QT, SWIFTEN_BASE_DEBUG, OTHER, SWIFTEN, LIMBER, SLIMBER, SWIFT_CONTROLLERS, SLUIFT, SWIFTOOLS, SWIFT = range(15) def findHeaderBlock(lines): start = False end = False lastLine = None for idx, line in enumerate(lines): if not start and line.startswith("#"): start = idx elif start and (not end) and (not line.startswith("#")) and line.strip(): end = idx-1 break if not end: end = len(lines) return (start, end) def lineToFileName(line): match = re.match( r'#include "(.*)"', line) if match: return match.group(1) match = re.match( r'#include <(.*)>', line) if match: return match.group(1) return False def fileNameToHeaderType(name): if name.endswith("/" + filename_name + ".h"): return HeaderType.CORRESPONDING_HEADER if name in c_stdlib_headers: return HeaderType.C_STDLIB if name in cpp_stdlib_headers: return HeaderType.CPP_STDLIB if name.startswith("boost"): return HeaderType.BOOST if name.startswith("Q"): return HeaderType.QT + if name.startswith("Swiften/Base/Debug.h"): + return HeaderType.SWIFTEN_BASE_DEBUG + if name.startswith("Swiften"): return HeaderType.SWIFTEN if name.startswith("Limber"): return HeaderType.LIMBER if name.startswith("Slimber"): return HeaderType.SLIMBER if name.startswith("Swift/Controllers"): return HeaderType.SWIFT_CONTROLLERS if name.startswith("Sluift"): return HeaderType.SLUIFT if name.startswith("SwifTools"): return HeaderType.SWIFTOOLS if name.startswith("Swift"): return HeaderType.SWIFT return HeaderType.OTHER def serializeHeaderGroups(groups): headerList = [] for group in range(0, HeaderType.SWIFT + 1): if group in groups: # sorted and without duplicates headers = sorted(list(set(groups[group]))) headerList.extend(headers) diff --git a/Swift/ChangeLog.md b/Swift/ChangeLog.md index 475062c..4c88d88 100644 --- a/Swift/ChangeLog.md +++ b/Swift/ChangeLog.md @@ -1,34 +1,35 @@ 4.0-in-progress --------------- - Fix UI layout issue for translations that require right-to-left (RTL) layout - macOS releases are now code-signed with a key from Apple, so they can be run without Gatekeeper trust warnings +- Handle sessions being closed by the server 4.0-beta2 ( 2016-07-20 ) ------------------------ - Fix Swift bug introduced in 4.0-beta1 that results in the UI sometimes getting stuck during login 4.0-beta1 ( 2016-07-15 ) ------------------------ - Support for message carbons (XEP-0280) - Improved spell checker support on Linux - Enabled trellis mode as a default feature, allowing several tiled chats windows to be shown at once - New chat theme including a new font - And assorted smaller features and usability enhancements 3.0 ( 2016-02-24 ) ------------------ - File transfer and Mac Notification Center issues fixed - Fix connection to servers with invalid or untrusted certificates on OS/X - Support for the Notification Center on OS X - Users can now authenticate using certificates (and smart cards on Windows) when using the 'BOSH' connection type. - Encryption on OS X now uses the platform's native 'Secure Transport' mechanisms. - Emoticons menu in chat dialogs - Bookmark for rooms can now be edited directly from the ‘Recent Chats’ list - Adds option to workaround servers that don’t interoperate well with Windows (schannel) encryption - Rooms entered while offline will now get entered on reconnect - Chats can now be seamlessly upgraded to multi-person chats by either inviting someone via the ‘cog’ menu, or dragging them from the roster. This relies on server-side support with an appropriate chatroom (MUC) service. - Highlighting of keywords and messages from particular users can now be configured (Keyword Highlighting Blog post). - Full profile vcards (contact information etc.) are now supported and can be configured for the user and queried for contacts. - Simple Communication Blocking is now supported (subject to server support) to allow the blocking of nuisance users. - Swift can now transfer files via the ‘Jingle File Transfer’ protocol. - The status setter will now remember previously set statuses and will allow quick access to these when the user types part of a recently used status. diff --git a/Swiften/Base/Debug.cpp b/Swiften/Base/Debug.cpp index b59de35..b4245c3 100644 --- a/Swiften/Base/Debug.cpp +++ b/Swiften/Base/Debug.cpp @@ -1,42 +1,43 @@ /* * Copyright (c) 2015-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <Swiften/Base/Debug.h> #include <iostream> #include <memory> #include <Swiften/Client/ClientError.h> +#include <Swiften/Client/ClientSession.h> #include <Swiften/Serializer/PayloadSerializer.h> #include <Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h> #include <Swiften/Serializer/XMPPSerializer.h> std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error) { os << "ClientError("; switch(error.getType()) { case Swift::ClientError::UnknownError: os << "UnknownError"; break; case Swift::ClientError::DomainNameResolveError: os << "DomainNameResolveError"; break; case Swift::ClientError::ConnectionError: os << "ConnectionError"; break; case Swift::ClientError::ConnectionReadError: os << "ConnectionReadError"; break; case Swift::ClientError::ConnectionWriteError: os << "ConnectionWriteError"; break; case Swift::ClientError::XMLError: os << "XMLError"; break; case Swift::ClientError::AuthenticationFailedError: os << "AuthenticationFailedError"; break; case Swift::ClientError::CompressionFailedError: os << "CompressionFailedError"; @@ -111,30 +112,79 @@ std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error) { os << "RevocationCheckFailedError"; break; } os << ")"; return os; } std::ostream& operator<<(std::ostream& os, Swift::Element* ele) { using namespace Swift; std::shared_ptr<Element> element = std::shared_ptr<Element>(ele); std::shared_ptr<Payload> payload = std::dynamic_pointer_cast<Payload>(element); if (payload) { FullPayloadSerializerCollection payloadSerializerCollection; PayloadSerializer *serializer = payloadSerializerCollection.getPayloadSerializer(payload); os << "Payload(" << serializer->serialize(payload) << ")"; return os; } std::shared_ptr<ToplevelElement> topLevelElement = std::dynamic_pointer_cast<ToplevelElement>(element); if (topLevelElement) { FullPayloadSerializerCollection payloadSerializerCollection; XMPPSerializer xmppSerializer(&payloadSerializerCollection, ClientStreamType, false); SafeByteArray serialized = xmppSerializer.serializeElement(topLevelElement); os << "TopLevelElement(" << safeByteArrayToString(serialized) << ")"; return os; } os << "Element(Unknown)"; return os; } + +std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state) { + using CS = Swift::ClientSession; + switch (state) { + case CS::State::Initial: + os << "ClientSession::State::Initial"; + break; + case CS::State::WaitingForStreamStart: + os << "ClientSession::State::WaitingForStreamStart"; + break; + case CS::State::Negotiating: + os << "ClientSession::State::Negotiating"; + break; + case CS::State::Compressing: + os << "ClientSession::State::Compressing"; + break; + case CS::State::WaitingForEncrypt: + os << "ClientSession::State::WaitingForEncrypt"; + break; + case CS::State::Encrypting: + os << "ClientSession::State::Encrypting"; + break; + case CS::State::WaitingForCredentials: + os << "ClientSession::State::WaitingForCredentials"; + break; + case CS::State::Authenticating: + os << "ClientSession::State::Authenticating"; + break; + case CS::State::EnablingSessionManagement: + os << "ClientSession::State::EnablingSessionManagement"; + break; + case CS::State::BindingResource: + os << "ClientSession::State::BindingResource"; + break; + case CS::State::StartingSession: + os << "ClientSession::State::StartingSession"; + break; + case CS::State::Initialized: + os << "ClientSession::State::Initialized"; + break; + case CS::State::Finishing: + os << "ClientSession::State::Finishing"; + break; + case CS::State::Finished: + os << "ClientSession::State::Finished"; + break; + } + return os; +} diff --git a/Swiften/Base/Debug.h b/Swiften/Base/Debug.h index 18e7fb4..9dde74c 100644 --- a/Swiften/Base/Debug.h +++ b/Swiften/Base/Debug.h @@ -1,20 +1,26 @@ /* - * Copyright (c) 2015 Isode Limited. + * Copyright (c) 2015-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ +#pragma once + #include <iosfwd> +#include <Swiften/Client/ClientSession.h> + namespace Swift { class ClientError; class Element; } namespace boost { template<class T> class shared_ptr; } std::ostream& operator<<(std::ostream& os, const Swift::ClientError& error); std::ostream& operator<<(std::ostream& os, Swift::Element* ele); + +std::ostream& operator<<(std::ostream& os, Swift::ClientSession::State state); diff --git a/Swiften/ChangeLog.md b/Swiften/ChangeLog.md index 4d52c68..62cea5e 100644 --- a/Swiften/ChangeLog.md +++ b/Swiften/ChangeLog.md @@ -1,13 +1,17 @@ +4.0-in-progress +--------------- +- Handle sessions being closed by the server + 4.0-beta1 ( 2016-07-15 ) ------------------------ - Moved code-base to C++11 - Use C++11 threading instead of Boost.Thread library - Use C++11 smart pointers instead of Boost's - Migrated from Boost.Signals to Boost.Signals2 - Build without warnings on our CI platforms - General cleanup like remove of superflous files and #include statements. This means header files that previously were included implictly need to be explicitly included now - Support IPv6 addresses in URLs - Changed source code style to use soft tabs (4 spaces wide) instead of hard tabs. Custom patches for Swiften will need to be reformatted accordingly - Require a TLS backend for building - Update 3rdParty/lcov to version 1.12 - Fix several possible race conditions and other small bugs
\ No newline at end of file diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h index 3a93197..1a337b6 100644 --- a/Swiften/Client/ClientOptions.h +++ b/Swiften/Client/ClientOptions.h @@ -3,160 +3,152 @@ * All rights reserved. * See the COPYING file for more information. */ #pragma once #include <memory> #include <Swiften/Base/API.h> #include <Swiften/Base/SafeString.h> #include <Swiften/Base/URL.h> #include <Swiften/TLS/TLSOptions.h> namespace Swift { class HTTPTrafficFilter; struct SWIFTEN_API ClientOptions { enum UseTLS { NeverUseTLS, UseTLSWhenAvailable, RequireTLS }; enum ProxyType { NoProxy, SystemConfiguredProxy, SOCKS5Proxy, HTTPConnectProxy }; - ClientOptions() : - useStreamCompression(true), - useTLS(UseTLSWhenAvailable), - allowPLAINWithoutTLS(false), - useStreamResumption(false), - forgetPassword(false), - useAcks(true), - singleSignOn(false), - manualHostname(""), - manualPort(-1), - proxyType(SystemConfiguredProxy), - manualProxyHostname(""), - manualProxyPort(-1), - boshHTTPConnectProxyAuthID(""), - boshHTTPConnectProxyAuthPassword("") { + ClientOptions() { } /** * Whether ZLib stream compression should be used when available. * * Default: true */ - bool useStreamCompression; + bool useStreamCompression = true; /** * Sets whether TLS encryption should be used. * * Default: UseTLSWhenAvailable */ - UseTLS useTLS; + UseTLS useTLS = UseTLSWhenAvailable; /** * Sets whether plaintext authentication is * allowed over non-TLS-encrypted connections. * * Default: false */ - bool allowPLAINWithoutTLS; + bool allowPLAINWithoutTLS = false; /** * Use XEP-196 stream resumption when available. * * Default: false */ - bool useStreamResumption; + bool useStreamResumption = false; /** * Forget the password once it's used. * This makes the Client useless after the first login attempt. * * FIXME: This is a temporary workaround. * * Default: false */ - bool forgetPassword; + bool forgetPassword = false; /** * Use XEP-0198 acks in the stream when available. * Default: true */ - bool useAcks; + bool useAcks = true; /** * Use Single Sign On. * Default: false */ - bool singleSignOn; + bool singleSignOn = false; /** * The hostname to connect to. * Leave this empty for standard XMPP connection, based on the JID domain. */ - std::string manualHostname; + std::string manualHostname = ""; /** * The port to connect to. * Leave this to -1 to use the port discovered by SRV lookups, and 5222 as a * fallback. */ - int manualPort; + int manualPort = -1; /** * The type of proxy to use for connecting to the XMPP * server. */ - ProxyType proxyType; + ProxyType proxyType = SystemConfiguredProxy; /** * Override the system-configured proxy hostname. */ - std::string manualProxyHostname; + std::string manualProxyHostname = ""; /** * Override the system-configured proxy port. */ - int manualProxyPort; + int manualProxyPort = -1; /** * If non-empty, use BOSH instead of direct TCP, with the given URL. * Default: empty (no BOSH) */ - URL boshURL; + URL boshURL = URL(); /** * If non-empty, BOSH connections will try to connect over this HTTP CONNECT * proxy instead of directly. * Default: empty (no proxy) */ - URL boshHTTPConnectProxyURL; + URL boshHTTPConnectProxyURL = URL(); /** * If this and matching Password are non-empty, BOSH connections over * HTTP CONNECT proxies will use these credentials for proxy access. * Default: empty (no authentication needed by the proxy) */ - SafeString boshHTTPConnectProxyAuthID; - SafeString boshHTTPConnectProxyAuthPassword; + SafeString boshHTTPConnectProxyAuthID = SafeString(""); + SafeString boshHTTPConnectProxyAuthPassword = SafeString(""); /** * This can be initialized with a custom HTTPTrafficFilter, which allows HTTP CONNECT * proxy initialization to be customized. */ std::shared_ptr<HTTPTrafficFilter> httpTrafficFilter; /** * Options passed to the TLS stack */ TLSOptions tlsOptions; + + /** + * Session shutdown timeout in milliseconds. This is the maximum time Swiften + * waits from a session close to the socket close. + */ + int sessionShutdownTimeoutInMilliseconds = 10000; }; } diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index c301881..bcfb004 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -1,514 +1,576 @@ /* * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <Swiften/Client/ClientSession.h> +#include <memory> + #include <boost/bind.hpp> #include <boost/uuid/uuid.hpp> #include <boost/uuid/uuid_io.hpp> #include <boost/uuid/uuid_generators.hpp> -#include <memory> -#include <Swiften/Base/Platform.h> #include <Swiften/Base/Log.h> -#include <Swiften/Elements/ProtocolHeader.h> -#include <Swiften/Elements/StreamFeatures.h> -#include <Swiften/Elements/StreamError.h> -#include <Swiften/Elements/StartTLSRequest.h> -#include <Swiften/Elements/StartTLSFailure.h> -#include <Swiften/Elements/TLSProceed.h> -#include <Swiften/Elements/AuthRequest.h> -#include <Swiften/Elements/AuthSuccess.h> -#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Base/Platform.h> +#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Elements/AuthChallenge.h> +#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Elements/AuthRequest.h> #include <Swiften/Elements/AuthResponse.h> -#include <Swiften/Elements/Compressed.h> +#include <Swiften/Elements/AuthSuccess.h> #include <Swiften/Elements/CompressFailure.h> #include <Swiften/Elements/CompressRequest.h> +#include <Swiften/Elements/Compressed.h> #include <Swiften/Elements/EnableStreamManagement.h> -#include <Swiften/Elements/StreamManagementEnabled.h> -#include <Swiften/Elements/StreamManagementFailed.h> -#include <Swiften/Elements/StartSession.h> -#include <Swiften/Elements/StanzaAck.h> -#include <Swiften/Elements/StanzaAckRequest.h> #include <Swiften/Elements/IQ.h> +#include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Elements/ResourceBind.h> -#include <Swiften/SASL/PLAINClientAuthenticator.h> +#include <Swiften/Elements/StanzaAck.h> +#include <Swiften/Elements/StanzaAckRequest.h> +#include <Swiften/Elements/StartSession.h> +#include <Swiften/Elements/StartTLSFailure.h> +#include <Swiften/Elements/StartTLSRequest.h> +#include <Swiften/Elements/StreamError.h> +#include <Swiften/Elements/StreamFeatures.h> +#include <Swiften/Elements/StreamManagementEnabled.h> +#include <Swiften/Elements/StreamManagementFailed.h> +#include <Swiften/Elements/TLSProceed.h> +#include <Swiften/Network/Timer.h> +#include <Swiften/Network/TimerFactory.h> +#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> #include <Swiften/SASL/EXTERNALClientAuthenticator.h> +#include <Swiften/SASL/PLAINClientAuthenticator.h> #include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h> -#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> -#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Session/SessionStream.h> +#include <Swiften/StreamManagement/StanzaAckRequester.h> +#include <Swiften/StreamManagement/StanzaAckResponder.h> #include <Swiften/TLS/CertificateTrustChecker.h> #include <Swiften/TLS/ServerIdentityVerifier.h> #ifdef SWIFTEN_PLATFORM_WIN32 #include <Swiften/Base/WindowsRegistry.h> #include <Swiften/SASL/WindowsGSSAPIClientAuthenticator.h> #endif #define CHECK_STATE_OR_RETURN(a) \ if (!checkState(a)) { return; } namespace Swift { ClientSession::ClientSession( const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, - CryptoProvider* crypto) : + CryptoProvider* crypto, + TimerFactory* timerFactory) : localJID(jid), - state(Initial), + state(State::Initial), stream(stream), idnConverter(idnConverter), crypto(crypto), + timerFactory(timerFactory), allowPLAINOverNonTLS(false), useStreamCompression(true), useTLS(UseTLSWhenAvailable), useAcks(true), needSessionStart(false), needResourceBind(false), needAcking(false), rosterVersioningSupported(false), authenticator(nullptr), certificateTrustChecker(nullptr), singleSignOn(false), authenticationPort(-1) { #ifdef SWIFTEN_PLATFORM_WIN32 if (WindowsRegistry::isFIPSEnabled()) { SWIFT_LOG(info) << "Windows is running in FIPS-140 mode. Some authentication methods will be unavailable." << std::endl; } #endif } ClientSession::~ClientSession() { } void ClientSession::start() { stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.connect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.connect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - assert(state == Initial); - state = WaitingForStreamStart; + assert(state == State::Initial); + state = State::WaitingForStreamStart; sendStreamHeader(); } void ClientSession::sendStreamHeader() { ProtocolHeader header; header.setTo(getRemoteJID()); stream->writeHeader(header); } void ClientSession::sendStanza(std::shared_ptr<Stanza> stanza) { stream->writeElement(stanza); if (stanzaAckRequester_) { stanzaAckRequester_->handleStanzaSent(stanza); } } void ClientSession::handleStreamStart(const ProtocolHeader&) { - CHECK_STATE_OR_RETURN(WaitingForStreamStart); - state = Negotiating; + CHECK_STATE_OR_RETURN(State::WaitingForStreamStart); + state = State::Negotiating; +} + +void ClientSession::handleStreamEnd() { + if (state == State::Finishing) { + // We are already in finishing state if we iniated the close of the session. + stream->close(); + } + else { + state = State::Finishing; + initiateShutdown(true); + } } void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { if (std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element)) { if (stanzaAckResponder_) { stanzaAckResponder_->handleStanzaReceived(); } - if (getState() == Initialized) { + if (getState() == State::Initialized) { onStanzaReceived(stanza); } else if (std::shared_ptr<IQ> iq = std::dynamic_pointer_cast<IQ>(element)) { - if (state == BindingResource) { + if (state == State::BindingResource) { std::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { finishSession(Error::ResourceBindError); } else if (!resourceBind) { finishSession(Error::UnexpectedElementError); } else if (iq->getType() == IQ::Result) { localJID = resourceBind->getJID(); if (!localJID.isValid()) { finishSession(Error::ResourceBindError); } needResourceBind = false; continueSessionInitialization(); } else { finishSession(Error::UnexpectedElementError); } } - else if (state == StartingSession) { + else if (state == State::StartingSession) { if (iq->getType() == IQ::Result) { needSessionStart = false; continueSessionInitialization(); } else if (iq->getType() == IQ::Error) { finishSession(Error::SessionStartError); } else { finishSession(Error::UnexpectedElementError); } } else { finishSession(Error::UnexpectedElementError); } } } else if (std::dynamic_pointer_cast<StanzaAckRequest>(element)) { if (stanzaAckResponder_) { stanzaAckResponder_->handleAckRequestReceived(); } } else if (std::shared_ptr<StanzaAck> ack = std::dynamic_pointer_cast<StanzaAck>(element)) { if (stanzaAckRequester_) { if (ack->isValid()) { stanzaAckRequester_->handleAckReceived(ack->getHandledStanzasCount()); } else { SWIFT_LOG(warning) << "Got invalid ack from server"; } } else { SWIFT_LOG(warning) << "Ignoring ack"; } } else if (StreamError::ref streamError = std::dynamic_pointer_cast<StreamError>(element)) { finishSession(Error::StreamError); } - else if (getState() == Initialized) { + else if (getState() == State::Initialized) { std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element); if (stanza) { if (stanzaAckResponder_) { stanzaAckResponder_->handleStanzaReceived(); } onStanzaReceived(stanza); } } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { - CHECK_STATE_OR_RETURN(Negotiating); + CHECK_STATE_OR_RETURN(State::Negotiating); if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { - state = WaitingForEncrypt; + state = State::WaitingForEncrypt; stream->writeElement(std::make_shared<StartTLSRequest>()); } else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) { finishSession(Error::NoSupportedAuthMechanismsError); } else if (useStreamCompression && stream->supportsZLibCompression() && streamFeatures->hasCompressionMethod("zlib")) { - state = Compressing; + state = State::Compressing; stream->writeElement(std::make_shared<CompressRequest>("zlib")); } else if (streamFeatures->hasAuthenticationMechanisms()) { #ifdef SWIFTEN_PLATFORM_WIN32 if (singleSignOn) { const boost::optional<std::string> authenticationHostname = streamFeatures->getAuthenticationHostname(); bool gssapiSupported = streamFeatures->hasAuthenticationMechanism("GSSAPI") && authenticationHostname && !authenticationHostname->empty(); if (!gssapiSupported) { finishSession(Error::NoSupportedAuthMechanismsError); } else { WindowsGSSAPIClientAuthenticator* gssapiAuthenticator = new WindowsGSSAPIClientAuthenticator(*authenticationHostname, localJID.getDomain(), authenticationPort); std::shared_ptr<Error> error = std::make_shared<Error>(Error::AuthenticationFailedError); authenticator = gssapiAuthenticator; if (!gssapiAuthenticator->isError()) { - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } else { error->errorCode = gssapiAuthenticator->getErrorCode(); finishSession(error); } } } else #endif if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else { finishSession(Error::TLSClientCertificateError); } } else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) { std::ostringstream s; ByteArray finishMessage; bool plus = streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS"); if (stream->isTLSEncrypted()) { finishMessage = stream->getTLSFinishMessage(); plus &= !finishMessage.empty(); } s << boost::uuids::random_generator()(); SCRAMSHA1ClientAuthenticator* scramAuthenticator = new SCRAMSHA1ClientAuthenticator(s.str(), plus, idnConverter, crypto); if (!finishMessage.empty()) { scramAuthenticator->setTLSChannelBindingData(finishMessage); } authenticator = scramAuthenticator; - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if ((stream->isTLSEncrypted() || allowPLAINOverNonTLS) && streamFeatures->hasAuthenticationMechanism("PLAIN")) { authenticator = new PLAINClientAuthenticator(); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if (streamFeatures->hasAuthenticationMechanism("DIGEST-MD5") && crypto->isMD5AllowedForCrypto()) { std::ostringstream s; s << boost::uuids::random_generator()(); // FIXME: Host should probably be the actual host authenticator = new DIGESTMD5ClientAuthenticator(localJID.getDomain(), s.str(), crypto); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else { finishSession(Error::NoSupportedAuthMechanismsError); } } else { // Start the session rosterVersioningSupported = streamFeatures->hasRosterVersioning(); stream->setWhitespacePingEnabled(true); needSessionStart = streamFeatures->hasSession(); needResourceBind = streamFeatures->hasResourceBind(); needAcking = streamFeatures->hasStreamManagement() && useAcks; if (!needResourceBind) { // Resource binding is a MUST finishSession(Error::ResourceBindError); } else { continueSessionInitialization(); } } } else if (std::dynamic_pointer_cast<Compressed>(element)) { - CHECK_STATE_OR_RETURN(Compressing); - state = WaitingForStreamStart; + CHECK_STATE_OR_RETURN(State::Compressing); + state = State::WaitingForStreamStart; stream->addZLibCompression(); stream->resetXMPPParser(); sendStreamHeader(); } else if (std::dynamic_pointer_cast<CompressFailure>(element)) { finishSession(Error::CompressionFailedError); } else if (std::dynamic_pointer_cast<StreamManagementEnabled>(element)) { stanzaAckRequester_ = std::make_shared<StanzaAckRequester>(); stanzaAckRequester_->onRequestAck.connect(boost::bind(&ClientSession::requestAck, shared_from_this())); stanzaAckRequester_->onStanzaAcked.connect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1)); stanzaAckResponder_ = std::make_shared<StanzaAckResponder>(); stanzaAckResponder_->onAck.connect(boost::bind(&ClientSession::ack, shared_from_this(), _1)); needAcking = false; continueSessionInitialization(); } else if (std::dynamic_pointer_cast<StreamManagementFailed>(element)) { needAcking = false; continueSessionInitialization(); } else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (authenticator->setChallenge(challenge->getValue())) { stream->writeElement(std::make_shared<AuthResponse>(authenticator->getResponse())); } #ifdef SWIFTEN_PLATFORM_WIN32 else if (WindowsGSSAPIClientAuthenticator* gssapiAuthenticator = dynamic_cast<WindowsGSSAPIClientAuthenticator*>(authenticator)) { std::shared_ptr<Error> error = std::make_shared<Error>(Error::AuthenticationFailedError); error->errorCode = gssapiAuthenticator->getErrorCode(); finishSession(error); } #endif else { finishSession(Error::AuthenticationFailedError); } } else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (!authenticator->setChallenge(authSuccess->getValue())) { finishSession(Error::ServerVerificationFailedError); } else { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; delete authenticator; authenticator = nullptr; stream->resetXMPPParser(); sendStreamHeader(); } } else if (dynamic_cast<AuthFailure*>(element.get())) { finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { - CHECK_STATE_OR_RETURN(WaitingForEncrypt); - state = Encrypting; + CHECK_STATE_OR_RETURN(State::WaitingForEncrypt); + state = State::Encrypting; stream->addTLSEncryption(); } else if (dynamic_cast<StartTLSFailure*>(element.get())) { finishSession(Error::TLSError); } else { // FIXME Not correct? - state = Initialized; + state = State::Initialized; onInitialized(); } } void ClientSession::continueSessionInitialization() { if (needResourceBind) { - state = BindingResource; + state = State::BindingResource; std::shared_ptr<ResourceBind> resourceBind(std::make_shared<ResourceBind>()); if (!localJID.getResource().empty()) { resourceBind->setResource(localJID.getResource()); } sendStanza(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } else if (needAcking) { - state = EnablingSessionManagement; + state = State::EnablingSessionManagement; stream->writeElement(std::make_shared<EnableStreamManagement>()); } else if (needSessionStart) { - state = StartingSession; + state = State::StartingSession; sendStanza(IQ::createRequest(IQ::Set, JID(), "session-start", std::make_shared<StartSession>())); } else { - state = Initialized; + state = State::Initialized; onInitialized(); } } bool ClientSession::checkState(State state) { if (this->state != state) { finishSession(Error::UnexpectedElementError); return false; } return true; } void ClientSession::sendCredentials(const SafeByteArray& password) { - assert(WaitingForCredentials); + assert(state == State::WaitingForCredentials); assert(authenticator); - state = Authenticating; + state = State::Authenticating; authenticator->setCredentials(localJID.getNode(), password); stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } void ClientSession::handleTLSEncrypted() { - CHECK_STATE_OR_RETURN(Encrypting); + CHECK_STATE_OR_RETURN(State::Encrypting); std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain(); std::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError(); if (verificationError) { checkTrustOrFinish(certificateChain, verificationError); } else { ServerIdentityVerifier identityVerifier(localJID, idnConverter); if (!certificateChain.empty() && identityVerifier.certificateVerifies(certificateChain[0])) { continueAfterTLSEncrypted(); } else { checkTrustOrFinish(certificateChain, std::make_shared<CertificateVerificationError>(CertificateVerificationError::InvalidServerIdentity)); } } } void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error) { if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificateChain)) { continueAfterTLSEncrypted(); } else { finishSession(error); } } +void ClientSession::initiateShutdown(bool sendFooter) { + if (!streamShutdownTimeout) { + streamShutdownTimeout = timerFactory->createTimer(sessionShutdownTimeoutInMilliseconds); + streamShutdownTimeout->onTick.connect(boost::bind(&ClientSession::handleStreamShutdownTimeout, shared_from_this())); + streamShutdownTimeout->start(); + } + if (sendFooter) { + stream->writeFooter(); + } + if (state == State::Finishing) { + // The other side already send </stream>; we can close the socket. + stream->close(); + } + else { + state = State::Finishing; + } +} + void ClientSession::continueAfterTLSEncrypted() { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; stream->resetXMPPParser(); sendStreamHeader(); } void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError) { State previousState = state; - state = Finished; + state = State::Finished; + + if (streamShutdownTimeout) { + streamShutdownTimeout->stop(); + streamShutdownTimeout.reset(); + } if (stanzaAckRequester_) { stanzaAckRequester_->onRequestAck.disconnect(boost::bind(&ClientSession::requestAck, shared_from_this())); stanzaAckRequester_->onStanzaAcked.disconnect(boost::bind(&ClientSession::handleStanzaAcked, shared_from_this(), _1)); stanzaAckRequester_.reset(); } if (stanzaAckResponder_) { stanzaAckResponder_->onAck.disconnect(boost::bind(&ClientSession::ack, shared_from_this(), _1)); stanzaAckResponder_.reset(); } stream->setWhitespacePingEnabled(false); stream->onStreamStartReceived.disconnect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.disconnect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.disconnect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.disconnect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.disconnect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - if (previousState == Finishing) { + if (previousState == State::Finishing) { onFinished(error_); } else { onFinished(streamError); } } +void ClientSession::handleStreamShutdownTimeout() { + handleStreamClosed(std::shared_ptr<Swift::Error>()); +} + void ClientSession::finish() { - finishSession(std::shared_ptr<Error>()); + if (state != State::Finishing && state != State::Finished) { + finishSession(std::shared_ptr<Error>()); + } + else { + SWIFT_LOG(warning) << "Session already finished or finishing." << std::endl; + } } void ClientSession::finishSession(Error::Type error) { finishSession(std::make_shared<Swift::ClientSession::Error>(error)); } void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) { - state = Finishing; if (!error_) { error_ = error; } else { SWIFT_LOG(warning) << "Session finished twice"; } assert(stream->isOpen()); if (stanzaAckResponder_) { stanzaAckResponder_->handleAckRequestReceived(); } if (authenticator) { delete authenticator; authenticator = nullptr; } - stream->writeFooter(); - stream->close(); + // Immidiately close TCP connection without stream closure. + if (std::dynamic_pointer_cast<CertificateVerificationError>(error)) { + state = State::Finishing; + initiateShutdown(false); + } + else { + if (state == State::Finishing) { + initiateShutdown(true); + } + else if (state != State::Finished) { + initiateShutdown(true); + } + } } void ClientSession::requestAck() { stream->writeElement(std::make_shared<StanzaAckRequest>()); } void ClientSession::handleStanzaAcked(std::shared_ptr<Stanza> stanza) { onStanzaAcked(stanza); } void ClientSession::ack(unsigned int handledStanzasCount) { stream->writeElement(std::make_shared<StanzaAck>(handledStanzasCount)); } } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index ad1b78c..c7b3658 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -1,201 +1,215 @@ /* * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #pragma once #include <memory> #include <string> #include <boost/signals2.hpp> #include <Swiften/Base/API.h> #include <Swiften/Base/Error.h> #include <Swiften/Elements/ToplevelElement.h> #include <Swiften/JID/JID.h> #include <Swiften/Session/SessionStream.h> -#include <Swiften/StreamManagement/StanzaAckRequester.h> -#include <Swiften/StreamManagement/StanzaAckResponder.h> namespace Swift { - class ClientAuthenticator; class CertificateTrustChecker; - class IDNConverter; + class ClientAuthenticator; class CryptoProvider; + class IDNConverter; + class Stanza; + class StanzaAckRequester; + class StanzaAckResponder; + class TimerFactory; + class Timer; class SWIFTEN_API ClientSession : public std::enable_shared_from_this<ClientSession> { public: - enum State { + enum class State { Initial, WaitingForStreamStart, Negotiating, Compressing, WaitingForEncrypt, Encrypting, WaitingForCredentials, Authenticating, EnablingSessionManagement, BindingResource, StartingSession, Initialized, Finishing, Finished }; struct Error : public Swift::Error { enum Type { AuthenticationFailedError, CompressionFailedError, ServerVerificationFailedError, NoSupportedAuthMechanismsError, UnexpectedElementError, ResourceBindError, SessionStartError, TLSClientCertificateError, TLSError, - StreamError + StreamError, + StreamEndError, // The server send a closing stream tag. } type; std::shared_ptr<boost::system::error_code> errorCode; Error(Type type) : type(type) {} }; enum UseTLS { NeverUseTLS, UseTLSWhenAvailable, RequireTLS }; ~ClientSession(); - static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto) { - return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto)); + static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto, TimerFactory* timerFactory) { + return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto, timerFactory)); } State getState() const { return state; } void setAllowPLAINOverNonTLS(bool b) { allowPLAINOverNonTLS = b; } void setUseStreamCompression(bool b) { useStreamCompression = b; } void setUseTLS(UseTLS b) { useTLS = b; } void setUseAcks(bool b) { useAcks = b; } - bool getStreamManagementEnabled() const { // Explicitly convert to bool. In C++11, it would be cleaner to // compare to nullptr. return static_cast<bool>(stanzaAckRequester_); } bool getRosterVersioningSupported() const { return rosterVersioningSupported; } std::vector<Certificate::ref> getPeerCertificateChain() const { return stream->getPeerCertificateChain(); } const JID& getLocalJID() const { return localJID; } void start(); void finish(); bool isFinished() const { - return getState() == Finished; + return getState() == State::Finished; } void sendCredentials(const SafeByteArray& password); void sendStanza(std::shared_ptr<Stanza>); void setCertificateTrustChecker(CertificateTrustChecker* checker) { certificateTrustChecker = checker; } void setSingleSignOn(bool b) { singleSignOn = b; } /** * Sets the port number used in Kerberos authentication * Does not affect network connectivity. */ void setAuthenticationPort(int i) { authenticationPort = i; } + void setSessionShutdownTimeout(int timeoutInMilliseconds) { + sessionShutdownTimeoutInMilliseconds = timeoutInMilliseconds; + } + public: boost::signals2::signal<void ()> onNeedCredentials; boost::signals2::signal<void ()> onInitialized; boost::signals2::signal<void (std::shared_ptr<Swift::Error>)> onFinished; boost::signals2::signal<void (std::shared_ptr<Stanza>)> onStanzaReceived; boost::signals2::signal<void (std::shared_ptr<Stanza>)> onStanzaAcked; private: ClientSession( const JID& jid, std::shared_ptr<SessionStream>, IDNConverter* idnConverter, - CryptoProvider* crypto); + CryptoProvider* crypto, + TimerFactory* timerFactory); void finishSession(Error::Type error); void finishSession(std::shared_ptr<Swift::Error> error); JID getRemoteJID() const { return JID("", localJID.getDomain()); } void sendStreamHeader(); void handleElement(std::shared_ptr<ToplevelElement>); void handleStreamStart(const ProtocolHeader&); + void handleStreamEnd(); void handleStreamClosed(std::shared_ptr<Swift::Error>); + void handleStreamShutdownTimeout(); void handleTLSEncrypted(); bool checkState(State); void continueSessionInitialization(); void requestAck(); void handleStanzaAcked(std::shared_ptr<Stanza> stanza); void ack(unsigned int handledStanzasCount); void continueAfterTLSEncrypted(); void checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error); + void initiateShutdown(bool sendFooter); private: JID localJID; State state; std::shared_ptr<SessionStream> stream; - IDNConverter* idnConverter; - CryptoProvider* crypto; + IDNConverter* idnConverter = nullptr; + CryptoProvider* crypto = nullptr; + TimerFactory* timerFactory = nullptr; + std::shared_ptr<Timer> streamShutdownTimeout; + int sessionShutdownTimeoutInMilliseconds = 10000; bool allowPLAINOverNonTLS; bool useStreamCompression; UseTLS useTLS; bool useAcks; bool needSessionStart; bool needResourceBind; bool needAcking; bool rosterVersioningSupported; ClientAuthenticator* authenticator; std::shared_ptr<StanzaAckRequester> stanzaAckRequester_; std::shared_ptr<StanzaAckResponder> stanzaAckResponder_; std::shared_ptr<Swift::Error> error_; CertificateTrustChecker* certificateTrustChecker; bool singleSignOn; int authenticationPort; }; } diff --git a/Swiften/Client/ClientSessionStanzaChannel.h b/Swiften/Client/ClientSessionStanzaChannel.h index 0527a5c..c4ee393 100644 --- a/Swiften/Client/ClientSessionStanzaChannel.h +++ b/Swiften/Client/ClientSessionStanzaChannel.h @@ -6,47 +6,47 @@ #pragma once #include <memory> #include <Swiften/Base/API.h> #include <Swiften/Base/IDGenerator.h> #include <Swiften/Client/ClientSession.h> #include <Swiften/Client/StanzaChannel.h> #include <Swiften/Elements/IQ.h> #include <Swiften/Elements/Message.h> #include <Swiften/Elements/Presence.h> namespace Swift { /** * StanzaChannel implementation around a ClientSession. */ class SWIFTEN_API ClientSessionStanzaChannel : public StanzaChannel { public: virtual ~ClientSessionStanzaChannel(); void setSession(std::shared_ptr<ClientSession> session); void sendIQ(std::shared_ptr<IQ> iq); void sendMessage(std::shared_ptr<Message> message); void sendPresence(std::shared_ptr<Presence> presence); bool getStreamManagementEnabled() const; virtual std::vector<Certificate::ref> getPeerCertificateChain() const; bool isAvailable() const { - return session && session->getState() == ClientSession::Initialized; + return session && session->getState() == ClientSession::State::Initialized; } private: std::string getNewIQID(); void send(std::shared_ptr<Stanza> stanza); void handleSessionFinished(std::shared_ptr<Error> error); void handleStanza(std::shared_ptr<Stanza> stanza); void handleStanzaAcked(std::shared_ptr<Stanza> stanza); void handleSessionInitialized(); private: IDGenerator idGenerator; std::shared_ptr<ClientSession> session; }; } diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index 3d75d8b..3c7902e 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -127,66 +127,67 @@ void CoreClient::connect(const ClientOptions& o) { std::shared_ptr<BOSHSessionStream> boshSessionStream_ = std::shared_ptr<BOSHSessionStream>(new BOSHSessionStream( options.boshURL, getPayloadParserFactories(), getPayloadSerializers(), networkFactories->getConnectionFactory(), networkFactories->getTLSContextFactory(), networkFactories->getTimerFactory(), networkFactories->getXMLParserFactory(), networkFactories->getEventLoop(), networkFactories->getDomainNameResolver(), host, options.boshHTTPConnectProxyURL, options.boshHTTPConnectProxyAuthID, options.boshHTTPConnectProxyAuthPassword, options.tlsOptions, options.httpTrafficFilter)); sessionStream_ = boshSessionStream_; sessionStream_->onDataRead.connect(boost::bind(&CoreClient::handleDataRead, this, _1)); sessionStream_->onDataWritten.connect(boost::bind(&CoreClient::handleDataWritten, this, _1)); if (certificate_ && !certificate_->isNull()) { SWIFT_LOG(debug) << "set certificate" << std::endl; sessionStream_->setTLSCertificate(certificate_); } boshSessionStream_->open(); bindSessionToStream(); } } void CoreClient::bindSessionToStream() { - session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider()); + session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider(), networkFactories->getTimerFactory()); session_->setCertificateTrustChecker(certificateTrustChecker); session_->setUseStreamCompression(options.useStreamCompression); session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); session_->setSingleSignOn(options.singleSignOn); session_->setAuthenticationPort(options.manualPort); + session_->setSessionShutdownTimeout(options.sessionShutdownTimeoutInMilliseconds); switch(options.useTLS) { case ClientOptions::UseTLSWhenAvailable: session_->setUseTLS(ClientSession::UseTLSWhenAvailable); break; case ClientOptions::NeverUseTLS: session_->setUseTLS(ClientSession::NeverUseTLS); break; case ClientOptions::RequireTLS: session_->setUseTLS(ClientSession::RequireTLS); break; } session_->setUseAcks(options.useAcks); stanzaChannel_->setSession(session_); session_->onFinished.connect(boost::bind(&CoreClient::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&CoreClient::handleNeedCredentials, this)); session_->start(); } /** * Only called for TCP sessions. BOSH is handled inside the BOSHSessionStream. */ void CoreClient::handleConnectorFinished(std::shared_ptr<Connection> connection, std::shared_ptr<Error> error) { resetConnector(); if (!connection) { if (options.forgetPassword) { purgePassword(); } boost::optional<ClientError> clientError; if (!disconnectRequested_) { clientError = std::dynamic_pointer_cast<DomainNameResolveError>(error) ? boost::optional<ClientError>(ClientError::DomainNameResolveError) : boost::optional<ClientError>(ClientError::ConnectionError); @@ -246,60 +247,63 @@ void CoreClient::handleSessionFinished(std::shared_ptr<Error> error) { case ClientSession::Error::AuthenticationFailedError: clientError = ClientError(ClientError::AuthenticationFailedError); break; case ClientSession::Error::CompressionFailedError: clientError = ClientError(ClientError::CompressionFailedError); break; case ClientSession::Error::ServerVerificationFailedError: clientError = ClientError(ClientError::ServerVerificationFailedError); break; case ClientSession::Error::NoSupportedAuthMechanismsError: clientError = ClientError(ClientError::NoSupportedAuthMechanismsError); break; case ClientSession::Error::UnexpectedElementError: clientError = ClientError(ClientError::UnexpectedElementError); break; case ClientSession::Error::ResourceBindError: clientError = ClientError(ClientError::ResourceBindError); break; case ClientSession::Error::SessionStartError: clientError = ClientError(ClientError::SessionStartError); break; case ClientSession::Error::TLSError: clientError = ClientError(ClientError::TLSError); break; case ClientSession::Error::TLSClientCertificateError: clientError = ClientError(ClientError::ClientCertificateError); break; case ClientSession::Error::StreamError: clientError = ClientError(ClientError::StreamError); break; + case ClientSession::Error::StreamEndError: + clientError = ClientError(ClientError::StreamError); + break; } clientError.setErrorCode(actualError->errorCode); } else if (std::shared_ptr<TLSError> actualError = std::dynamic_pointer_cast<TLSError>(error)) { switch(actualError->getType()) { case TLSError::CertificateCardRemoved: clientError = ClientError(ClientError::CertificateCardRemoved); break; case TLSError::UnknownError: clientError = ClientError(ClientError::TLSError); break; } } else if (std::shared_ptr<SessionStream::SessionStreamError> actualError = std::dynamic_pointer_cast<SessionStream::SessionStreamError>(error)) { switch(actualError->type) { case SessionStream::SessionStreamError::ParseError: clientError = ClientError(ClientError::XMLError); break; case SessionStream::SessionStreamError::TLSError: clientError = ClientError(ClientError::TLSError); break; case SessionStream::SessionStreamError::InvalidTLSCertificateError: clientError = ClientError(ClientError::ClientCertificateLoadError); break; case SessionStream::SessionStreamError::ConnectionReadError: clientError = ClientError(ClientError::ConnectionReadError); break; case SessionStream::SessionStreamError::ConnectionWriteError: clientError = ClientError(ClientError::ConnectionWriteError); break; diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index bd93f4b..44df961 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -1,380 +1,458 @@ /* * Copyright (c) 2010-2016 Isode Limited. * All rights reserved. * See the COPYING file for more information. */ #include <deque> #include <memory> #include <boost/bind.hpp> #include <boost/optional.hpp> +#include <Swiften/Base/Debug.h> + #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> #include <Swiften/Client/ClientSession.h> #include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Crypto/PlatformCryptoProvider.h> #include <Swiften/Elements/AuthChallenge.h> #include <Swiften/Elements/AuthFailure.h> #include <Swiften/Elements/AuthRequest.h> #include <Swiften/Elements/AuthSuccess.h> #include <Swiften/Elements/EnableStreamManagement.h> #include <Swiften/Elements/IQ.h> #include <Swiften/Elements/Message.h> #include <Swiften/Elements/ResourceBind.h> #include <Swiften/Elements/StanzaAck.h> #include <Swiften/Elements/StartTLSFailure.h> #include <Swiften/Elements/StartTLSRequest.h> #include <Swiften/Elements/StreamError.h> #include <Swiften/Elements/StreamFeatures.h> #include <Swiften/Elements/StreamManagementEnabled.h> #include <Swiften/Elements/StreamManagementFailed.h> #include <Swiften/Elements/TLSProceed.h> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/IDN/PlatformIDNConverter.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Session/SessionStream.h> #include <Swiften/TLS/BlindCertificateTrustChecker.h> #include <Swiften/TLS/SimpleCertificate.h> using namespace Swift; class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(ClientSessionTest); CPPUNIT_TEST(testStart_Error); CPPUNIT_TEST(testStart_StreamError); CPPUNIT_TEST(testStartTLS); CPPUNIT_TEST(testStartTLS_ServerError); CPPUNIT_TEST(testStartTLS_ConnectError); CPPUNIT_TEST(testStartTLS_InvalidIdentity); CPPUNIT_TEST(testStart_StreamFeaturesWithoutResourceBindingFails); CPPUNIT_TEST(testAuthenticate); CPPUNIT_TEST(testAuthenticate_Unauthorized); CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); CPPUNIT_TEST(testAuthenticate_PLAINOverNonTLS); CPPUNIT_TEST(testAuthenticate_RequireTLS); CPPUNIT_TEST(testAuthenticate_EXTERNAL); CPPUNIT_TEST(testStreamManagement); CPPUNIT_TEST(testStreamManagement_Failed); CPPUNIT_TEST(testUnexpectedChallenge); CPPUNIT_TEST(testFinishAcksStanzas); + + CPPUNIT_TEST(testServerInitiatedSessionClose); + CPPUNIT_TEST(testClientInitiatedSessionClose); + CPPUNIT_TEST(testTimeoutOnShutdown); + /* CPPUNIT_TEST(testResourceBind); CPPUNIT_TEST(testResourceBind_ChangeResource); CPPUNIT_TEST(testResourceBind_EmptyResource); CPPUNIT_TEST(testResourceBind_Error); CPPUNIT_TEST(testSessionStart); CPPUNIT_TEST(testSessionStart_Error); CPPUNIT_TEST(testSessionStart_AfterResourceBind); CPPUNIT_TEST(testWhitespacePing); CPPUNIT_TEST(testReceiveElementAfterSessionStarted); CPPUNIT_TEST(testSendElement); */ CPPUNIT_TEST_SUITE_END(); public: void setUp() { crypto = std::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create()); idnConverter = std::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); + timerFactory = std::make_shared<DummyTimerFactory>(); server = std::make_shared<MockSessionStream>(); sessionFinishedReceived = false; needCredentials = false; blindCertificateTrustChecker = new BlindCertificateTrustChecker(); } void tearDown() { delete blindCertificateTrustChecker; } void testStart_Error() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->breakConnection(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStart_StreamError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->sendStreamStart(); server->sendStreamError(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setCertificateTrustChecker(blindCertificateTrustChecker); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); CPPUNIT_ASSERT(!server->tlsEncrypted); server->sendTLSProceed(); CPPUNIT_ASSERT(server->tlsEncrypted); server->onTLSEncrypted(); server->receiveStreamStart(); server->sendStreamStart(); session->finish(); } void testStartTLS_ServerError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); server->sendTLSFailure(); CPPUNIT_ASSERT(!server->tlsEncrypted); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS_ConnectError() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); server->sendTLSProceed(); server->breakTLS(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStartTLS_InvalidIdentity() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithStartTLS(); server->receiveStartTLS(); CPPUNIT_ASSERT(!server->tlsEncrypted); server->sendTLSProceed(); CPPUNIT_ASSERT(server->tlsEncrypted); server->onTLSEncrypted(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)); CPPUNIT_ASSERT_EQUAL(CertificateVerificationError::InvalidServerIdentity, std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)->getType()); } void testStart_StreamFeaturesWithoutResourceBindingFails() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendEmptyStreamFeatures(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); session->finish(); } void testAuthenticate_Unauthorized() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthFailure(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_PLAINOverNonTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setAllowPLAINOverNonTLS(false); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_RequireTLS() { std::shared_ptr<ClientSession> session(createSession()); session->setUseTLS(ClientSession::RequireTLS); session->setAllowPLAINOverNonTLS(true); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithMultipleAuthentication(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_NoValidAuthMechanisms() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithUnknownAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testAuthenticate_EXTERNAL() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithEXTERNALAuthentication(); server->receiveAuthRequest("EXTERNAL"); server->sendAuthSuccess(); server->receiveStreamStart(); session->finish(); } void testUnexpectedChallenge() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithEXTERNALAuthentication(); server->receiveAuthRequest("EXTERNAL"); server->sendChallenge(); server->sendChallenge(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } void testStreamManagement() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementEnabled(); CPPUNIT_ASSERT(session->getStreamManagementEnabled()); // TODO: Test if the requesters & responders do their work - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } void testStreamManagement_Failed() { std::shared_ptr<ClientSession> session(createSession()); session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementFailed(); CPPUNIT_ASSERT(!session->getStreamManagementEnabled()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } void testFinishAcksStanzas() { std::shared_ptr<ClientSession> session(createSession()); initializeSession(session); server->sendMessage(); server->sendMessage(); server->sendMessage(); session->finish(); server->receiveAck(3); } + void testServerInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + } + + void testClientInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + } + + void testTimeoutOnShutdown() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + timerFactory->setTime(60000); + + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + } + private: std::shared_ptr<ClientSession> createSession() { - std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get()); + std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get(), timerFactory.get()); session->onFinished.connect(boost::bind(&ClientSessionTest::handleSessionFinished, this, _1)); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::handleSessionNeedCredentials, this)); session->setAllowPLAINOverNonTLS(true); return session; } void initializeSession(std::shared_ptr<ClientSession> session) { session->start(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithBindAndStreamManagement(); server->receiveBind(); server->sendBindResult(); server->receiveStreamManagementEnable(); server->sendStreamManagementEnabled(); } void handleSessionFinished(std::shared_ptr<Error> error) { sessionFinishedReceived = true; sessionFinishedError = error; } void handleSessionNeedCredentials() { needCredentials = true; @@ -604,60 +682,61 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(event.element); std::shared_ptr<StanzaAck> ack = std::dynamic_pointer_cast<StanzaAck>(event.element); CPPUNIT_ASSERT(ack); CPPUNIT_ASSERT_EQUAL(n, ack->getHandledStanzasCount()); } Event popEvent() { CPPUNIT_ASSERT(!receivedEvents.empty()); Event event = receivedEvents.front(); receivedEvents.pop_front(); return event; } bool available; bool canTLSEncrypt; bool tlsEncrypted; bool compressed; bool whitespacePingEnabled; std::string bindID; int resetCount; std::deque<Event> receivedEvents; }; std::shared_ptr<IDNConverter> idnConverter; std::shared_ptr<MockSessionStream> server; bool sessionFinishedReceived; bool needCredentials; std::shared_ptr<Error> sessionFinishedError; BlindCertificateTrustChecker* blindCertificateTrustChecker; std::shared_ptr<CryptoProvider> crypto; + std::shared_ptr<DummyTimerFactory> timerFactory; }; CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest); #if 0 void testAuthenticate() { std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this)); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithAuthentication(); session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); CPPUNIT_ASSERT(needCredentials_); getMockServer()->expectAuth("me", "mypass"); getMockServer()->sendAuthSuccess(); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); session->sendCredentials("mypass"); CPPUNIT_ASSERT_EQUAL(ClientSession::Authenticating, session->getState()); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState()); } void testAuthenticate_Unauthorized() { std::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp index 402f642..10c6ad0 100644 --- a/Swiften/Session/BasicSessionStream.cpp +++ b/Swiften/Session/BasicSessionStream.cpp @@ -13,88 +13,90 @@ #include <Swiften/StreamStack/CompressionLayer.h> #include <Swiften/StreamStack/ConnectionLayer.h> #include <Swiften/StreamStack/StreamStack.h> #include <Swiften/StreamStack/TLSLayer.h> #include <Swiften/StreamStack/WhitespacePingLayer.h> #include <Swiften/StreamStack/XMPPLayer.h> #include <Swiften/TLS/TLSContext.h> #include <Swiften/TLS/TLSContextFactory.h> namespace Swift { BasicSessionStream::BasicSessionStream( StreamType streamType, std::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSContextFactory* tlsContextFactory, TimerFactory* timerFactory, XMLParserFactory* xmlParserFactory, const TLSOptions& tlsOptions) : available(false), connection(connection), tlsContextFactory(tlsContextFactory), timerFactory(timerFactory), compressionLayer(nullptr), tlsLayer(nullptr), whitespacePingLayer(nullptr), tlsOptions_(tlsOptions) { xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, streamType); xmppLayer->onStreamStart.connect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1)); + xmppLayer->onStreamEnd.connect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this)); xmppLayer->onElement.connect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1)); xmppLayer->onError.connect(boost::bind(&BasicSessionStream::handleXMPPError, this)); xmppLayer->onDataRead.connect(boost::bind(&BasicSessionStream::handleDataRead, this, _1)); xmppLayer->onWriteData.connect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1)); connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1)); connectionLayer = new ConnectionLayer(connection); streamStack = new StreamStack(xmppLayer, connectionLayer); available = true; } BasicSessionStream::~BasicSessionStream() { delete compressionLayer; if (tlsLayer) { tlsLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleTLSError, this, _1)); tlsLayer->onConnected.disconnect(boost::bind(&BasicSessionStream::handleTLSConnected, this)); delete tlsLayer; } delete whitespacePingLayer; delete streamStack; connection->onDisconnected.disconnect(boost::bind(&BasicSessionStream::handleConnectionFinished, this, _1)); delete connectionLayer; xmppLayer->onStreamStart.disconnect(boost::bind(&BasicSessionStream::handleStreamStartReceived, this, _1)); + xmppLayer->onStreamEnd.disconnect(boost::bind(&BasicSessionStream::handleStreamEndReceived, this)); xmppLayer->onElement.disconnect(boost::bind(&BasicSessionStream::handleElementReceived, this, _1)); xmppLayer->onError.disconnect(boost::bind(&BasicSessionStream::handleXMPPError, this)); xmppLayer->onDataRead.disconnect(boost::bind(&BasicSessionStream::handleDataRead, this, _1)); xmppLayer->onWriteData.disconnect(boost::bind(&BasicSessionStream::handleDataWritten, this, _1)); delete xmppLayer; } void BasicSessionStream::writeHeader(const ProtocolHeader& header) { assert(available); xmppLayer->writeHeader(header); } void BasicSessionStream::writeElement(std::shared_ptr<ToplevelElement> element) { assert(available); xmppLayer->writeElement(element); } void BasicSessionStream::writeFooter() { assert(available); xmppLayer->writeFooter(); } void BasicSessionStream::writeData(const std::string& data) { assert(available); xmppLayer->writeData(data); } void BasicSessionStream::close() { connection->disconnect(); } @@ -144,60 +146,64 @@ ByteArray BasicSessionStream::getTLSFinishMessage() const { bool BasicSessionStream::supportsZLibCompression() { return true; } void BasicSessionStream::addZLibCompression() { compressionLayer = new CompressionLayer(); streamStack->addLayer(compressionLayer); } void BasicSessionStream::setWhitespacePingEnabled(bool enabled) { if (enabled) { if (!whitespacePingLayer) { whitespacePingLayer = new WhitespacePingLayer(timerFactory); streamStack->addLayer(whitespacePingLayer); } whitespacePingLayer->setActive(); } else if (whitespacePingLayer) { whitespacePingLayer->setInactive(); } } void BasicSessionStream::resetXMPPParser() { xmppLayer->resetParser(); } void BasicSessionStream::handleStreamStartReceived(const ProtocolHeader& header) { onStreamStartReceived(header); } +void BasicSessionStream::handleStreamEndReceived() { + onStreamEndReceived(); +} + void BasicSessionStream::handleElementReceived(std::shared_ptr<ToplevelElement> element) { onElementReceived(element); } void BasicSessionStream::handleXMPPError() { available = false; onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ParseError)); } void BasicSessionStream::handleTLSConnected() { onTLSEncrypted(); } void BasicSessionStream::handleTLSError(std::shared_ptr<TLSError> error) { available = false; onClosed(error); } void BasicSessionStream::handleConnectionFinished(const boost::optional<Connection::Error>& error) { available = false; if (error == Connection::ReadError) { onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ConnectionReadError)); } else if (error) { onClosed(std::make_shared<SessionStreamError>(SessionStreamError::ConnectionWriteError)); } else { onClosed(std::shared_ptr<SessionStreamError>()); } } diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h index 1806cef..48b3d63 100644 --- a/Swiften/Session/BasicSessionStream.h +++ b/Swiften/Session/BasicSessionStream.h @@ -46,49 +46,50 @@ namespace Swift { virtual void close(); virtual bool isOpen(); virtual void writeHeader(const ProtocolHeader& header); virtual void writeElement(std::shared_ptr<ToplevelElement>); virtual void writeFooter(); virtual void writeData(const std::string& data); virtual bool supportsZLibCompression(); virtual void addZLibCompression(); virtual bool supportsTLSEncryption(); virtual void addTLSEncryption(); virtual bool isTLSEncrypted(); virtual Certificate::ref getPeerCertificate() const; virtual std::vector<Certificate::ref> getPeerCertificateChain() const; virtual std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; virtual ByteArray getTLSFinishMessage() const; virtual void setWhitespacePingEnabled(bool); virtual void resetXMPPParser(); private: void handleConnectionFinished(const boost::optional<Connection::Error>& error); void handleXMPPError(); void handleTLSConnected(); void handleTLSError(std::shared_ptr<TLSError>); void handleStreamStartReceived(const ProtocolHeader&); + void handleStreamEndReceived(); void handleElementReceived(std::shared_ptr<ToplevelElement>); void handleDataRead(const SafeByteArray& data); void handleDataWritten(const SafeByteArray& data); private: bool available; std::shared_ptr<Connection> connection; TLSContextFactory* tlsContextFactory; TimerFactory* timerFactory; XMPPLayer* xmppLayer; ConnectionLayer* connectionLayer; CompressionLayer* compressionLayer; TLSLayer* tlsLayer; WhitespacePingLayer* whitespacePingLayer; StreamStack* streamStack; TLSOptions tlsOptions_; }; } diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h index f56f495..c5ec42a 100644 --- a/Swiften/Session/SessionStream.h +++ b/Swiften/Session/SessionStream.h @@ -48,45 +48,46 @@ namespace Swift { virtual void writeHeader(const ProtocolHeader& header) = 0; virtual void writeFooter() = 0; virtual void writeElement(std::shared_ptr<ToplevelElement>) = 0; virtual void writeData(const std::string& data) = 0; virtual bool supportsZLibCompression() = 0; virtual void addZLibCompression() = 0; virtual bool supportsTLSEncryption() = 0; virtual void addTLSEncryption() = 0; virtual bool isTLSEncrypted() = 0; virtual void setWhitespacePingEnabled(bool enabled) = 0; virtual void resetXMPPParser() = 0; void setTLSCertificate(CertificateWithKey::ref cert) { certificate = cert; } virtual bool hasTLSCertificate() { return certificate && !certificate->isNull(); } virtual Certificate::ref getPeerCertificate() const = 0; virtual std::vector<Certificate::ref> getPeerCertificateChain() const = 0; virtual std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const = 0; virtual ByteArray getTLSFinishMessage() const = 0; boost::signals2::signal<void (const ProtocolHeader&)> onStreamStartReceived; + boost::signals2::signal<void ()> onStreamEndReceived; boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElementReceived; boost::signals2::signal<void (std::shared_ptr<Error>)> onClosed; boost::signals2::signal<void ()> onTLSEncrypted; boost::signals2::signal<void (const SafeByteArray&)> onDataRead; boost::signals2::signal<void (const SafeByteArray&)> onDataWritten; protected: CertificateWithKey::ref getTLSCertificate() const { return certificate; } private: CertificateWithKey::ref certificate; }; } diff --git a/Swiften/StreamStack/XMPPLayer.cpp b/Swiften/StreamStack/XMPPLayer.cpp index 2eed906..982d13f 100644 --- a/Swiften/StreamStack/XMPPLayer.cpp +++ b/Swiften/StreamStack/XMPPLayer.cpp @@ -59,42 +59,43 @@ void XMPPLayer::handleDataRead(const SafeByteArray& data) { inParser_ = true; // FIXME: Converting to unsafe string. Should be ok, since we don't take passwords // from the stream in clients. If servers start using this, and require safe storage, // we need to fix this. if (!xmppParser_->parse(byteArrayToString(ByteArray(data.begin(), data.end())))) { inParser_ = false; onError(); return; } inParser_ = false; if (resetParserAfterParse_) { doResetParser(); } } void XMPPLayer::doResetParser() { delete xmppParser_; xmppParser_ = new XMPPParser(this, payloadParserFactories_, xmlParserFactory_); resetParserAfterParse_ = false; } void XMPPLayer::handleStreamStart(const ProtocolHeader& header) { onStreamStart(header); } void XMPPLayer::handleElement(std::shared_ptr<ToplevelElement> stanza) { onElement(stanza); } void XMPPLayer::handleStreamEnd() { + onStreamEnd(); } void XMPPLayer::resetParser() { if (inParser_) { resetParserAfterParse_ = true; } else { doResetParser(); } } } diff --git a/Swiften/StreamStack/XMPPLayer.h b/Swiften/StreamStack/XMPPLayer.h index 1d4abf8..f0b5afb 100644 --- a/Swiften/StreamStack/XMPPLayer.h +++ b/Swiften/StreamStack/XMPPLayer.h @@ -24,53 +24,54 @@ namespace Swift { class PayloadParserFactoryCollection; class XMPPSerializer; class PayloadSerializerCollection; class XMLParserFactory; class BOSHSessionStream; class SWIFTEN_API XMPPLayer : public XMPPParserClient, public HighLayer, boost::noncopyable { friend class BOSHSessionStream; public: XMPPLayer( PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, XMLParserFactory* xmlParserFactory, StreamType streamType, bool setExplictNSonTopLevelElements = false); virtual ~XMPPLayer(); void writeHeader(const ProtocolHeader& header); void writeFooter(); void writeElement(std::shared_ptr<ToplevelElement>); void writeData(const std::string& data); void resetParser(); protected: void handleDataRead(const SafeByteArray& data); void writeDataInternal(const SafeByteArray& data); public: boost::signals2::signal<void (const ProtocolHeader&)> onStreamStart; + boost::signals2::signal<void ()> onStreamEnd; boost::signals2::signal<void (std::shared_ptr<ToplevelElement>)> onElement; boost::signals2::signal<void (const SafeByteArray&)> onWriteData; boost::signals2::signal<void (const SafeByteArray&)> onDataRead; boost::signals2::signal<void ()> onError; private: void handleStreamStart(const ProtocolHeader&); void handleElement(std::shared_ptr<ToplevelElement>); void handleStreamEnd(); void doResetParser(); private: PayloadParserFactoryCollection* payloadParserFactories_; XMPPParser* xmppParser_; PayloadSerializerCollection* payloadSerializers_; XMLParserFactory* xmlParserFactory_; XMPPSerializer* xmppSerializer_; bool setExplictNSonTopLevelElements_; bool resetParserAfterParse_; bool inParser_; }; } |