diff options
Diffstat (limited to 'Swiften/Client')
-rw-r--r-- | Swiften/Client/ClientOptions.h | 54 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 184 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 40 | ||||
-rw-r--r-- | Swiften/Client/ClientSessionStanzaChannel.h | 2 | ||||
-rw-r--r-- | Swiften/Client/CoreClient.cpp | 6 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 111 |
6 files changed, 274 insertions, 123 deletions
diff --git a/Swiften/Client/ClientOptions.h b/Swiften/Client/ClientOptions.h index 3a93197..1a337b6 100644 --- a/Swiften/Client/ClientOptions.h +++ b/Swiften/Client/ClientOptions.h @@ -30,21 +30,7 @@ namespace Swift { HTTPConnectProxy }; - ClientOptions() : - useStreamCompression(true), - useTLS(UseTLSWhenAvailable), - allowPLAINWithoutTLS(false), - useStreamResumption(false), - forgetPassword(false), - useAcks(true), - singleSignOn(false), - manualHostname(""), - manualPort(-1), - proxyType(SystemConfiguredProxy), - manualProxyHostname(""), - manualProxyPort(-1), - boshHTTPConnectProxyAuthID(""), - boshHTTPConnectProxyAuthPassword("") { + ClientOptions() { } /** @@ -52,14 +38,14 @@ namespace Swift { * * Default: true */ - bool useStreamCompression; + bool useStreamCompression = true; /** * Sets whether TLS encryption should be used. * * Default: UseTLSWhenAvailable */ - UseTLS useTLS; + UseTLS useTLS = UseTLSWhenAvailable; /** * Sets whether plaintext authentication is @@ -67,14 +53,14 @@ namespace Swift { * * Default: false */ - bool allowPLAINWithoutTLS; + bool allowPLAINWithoutTLS = false; /** * Use XEP-196 stream resumption when available. * * Default: false */ - bool useStreamResumption; + bool useStreamResumption = false; /** * Forget the password once it's used. @@ -84,69 +70,69 @@ namespace Swift { * * Default: false */ - bool forgetPassword; + bool forgetPassword = false; /** * Use XEP-0198 acks in the stream when available. * Default: true */ - bool useAcks; + bool useAcks = true; /** * Use Single Sign On. * Default: false */ - bool singleSignOn; + bool singleSignOn = false; /** * The hostname to connect to. * Leave this empty for standard XMPP connection, based on the JID domain. */ - std::string manualHostname; + std::string manualHostname = ""; /** * The port to connect to. * Leave this to -1 to use the port discovered by SRV lookups, and 5222 as a * fallback. */ - int manualPort; + int manualPort = -1; /** * The type of proxy to use for connecting to the XMPP * server. */ - ProxyType proxyType; + ProxyType proxyType = SystemConfiguredProxy; /** * Override the system-configured proxy hostname. */ - std::string manualProxyHostname; + std::string manualProxyHostname = ""; /** * Override the system-configured proxy port. */ - int manualProxyPort; + int manualProxyPort = -1; /** * If non-empty, use BOSH instead of direct TCP, with the given URL. * Default: empty (no BOSH) */ - URL boshURL; + URL boshURL = URL(); /** * If non-empty, BOSH connections will try to connect over this HTTP CONNECT * proxy instead of directly. * Default: empty (no proxy) */ - URL boshHTTPConnectProxyURL; + URL boshHTTPConnectProxyURL = URL(); /** * If this and matching Password are non-empty, BOSH connections over * HTTP CONNECT proxies will use these credentials for proxy access. * Default: empty (no authentication needed by the proxy) */ - SafeString boshHTTPConnectProxyAuthID; - SafeString boshHTTPConnectProxyAuthPassword; + SafeString boshHTTPConnectProxyAuthID = SafeString(""); + SafeString boshHTTPConnectProxyAuthPassword = SafeString(""); /** * This can be initialized with a custom HTTPTrafficFilter, which allows HTTP CONNECT @@ -158,5 +144,11 @@ namespace Swift { * Options passed to the TLS stack */ TLSOptions tlsOptions; + + /** + * Session shutdown timeout in milliseconds. This is the maximum time Swiften + * waits from a session close to the socket close. + */ + int sessionShutdownTimeoutInMilliseconds = 10000; }; } diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index c301881..bcfb004 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -6,42 +6,47 @@ #include <Swiften/Client/ClientSession.h> +#include <memory> + #include <boost/bind.hpp> #include <boost/uuid/uuid.hpp> #include <boost/uuid/uuid_io.hpp> #include <boost/uuid/uuid_generators.hpp> -#include <memory> -#include <Swiften/Base/Platform.h> #include <Swiften/Base/Log.h> -#include <Swiften/Elements/ProtocolHeader.h> -#include <Swiften/Elements/StreamFeatures.h> -#include <Swiften/Elements/StreamError.h> -#include <Swiften/Elements/StartTLSRequest.h> -#include <Swiften/Elements/StartTLSFailure.h> -#include <Swiften/Elements/TLSProceed.h> -#include <Swiften/Elements/AuthRequest.h> -#include <Swiften/Elements/AuthSuccess.h> -#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Base/Platform.h> +#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Elements/AuthChallenge.h> +#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Elements/AuthRequest.h> #include <Swiften/Elements/AuthResponse.h> -#include <Swiften/Elements/Compressed.h> +#include <Swiften/Elements/AuthSuccess.h> #include <Swiften/Elements/CompressFailure.h> #include <Swiften/Elements/CompressRequest.h> +#include <Swiften/Elements/Compressed.h> #include <Swiften/Elements/EnableStreamManagement.h> -#include <Swiften/Elements/StreamManagementEnabled.h> -#include <Swiften/Elements/StreamManagementFailed.h> -#include <Swiften/Elements/StartSession.h> -#include <Swiften/Elements/StanzaAck.h> -#include <Swiften/Elements/StanzaAckRequest.h> #include <Swiften/Elements/IQ.h> +#include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Elements/ResourceBind.h> -#include <Swiften/SASL/PLAINClientAuthenticator.h> +#include <Swiften/Elements/StanzaAck.h> +#include <Swiften/Elements/StanzaAckRequest.h> +#include <Swiften/Elements/StartSession.h> +#include <Swiften/Elements/StartTLSFailure.h> +#include <Swiften/Elements/StartTLSRequest.h> +#include <Swiften/Elements/StreamError.h> +#include <Swiften/Elements/StreamFeatures.h> +#include <Swiften/Elements/StreamManagementEnabled.h> +#include <Swiften/Elements/StreamManagementFailed.h> +#include <Swiften/Elements/TLSProceed.h> +#include <Swiften/Network/Timer.h> +#include <Swiften/Network/TimerFactory.h> +#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> #include <Swiften/SASL/EXTERNALClientAuthenticator.h> +#include <Swiften/SASL/PLAINClientAuthenticator.h> #include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h> -#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> -#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Session/SessionStream.h> +#include <Swiften/StreamManagement/StanzaAckRequester.h> +#include <Swiften/StreamManagement/StanzaAckResponder.h> #include <Swiften/TLS/CertificateTrustChecker.h> #include <Swiften/TLS/ServerIdentityVerifier.h> @@ -59,12 +64,14 @@ ClientSession::ClientSession( const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, - CryptoProvider* crypto) : + CryptoProvider* crypto, + TimerFactory* timerFactory) : localJID(jid), - state(Initial), + state(State::Initial), stream(stream), idnConverter(idnConverter), crypto(crypto), + timerFactory(timerFactory), allowPLAINOverNonTLS(false), useStreamCompression(true), useTLS(UseTLSWhenAvailable), @@ -89,12 +96,13 @@ ClientSession::~ClientSession() { void ClientSession::start() { stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.connect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.connect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - assert(state == Initial); - state = WaitingForStreamStart; + assert(state == State::Initial); + state = State::WaitingForStreamStart; sendStreamHeader(); } @@ -112,8 +120,19 @@ void ClientSession::sendStanza(std::shared_ptr<Stanza> stanza) { } void ClientSession::handleStreamStart(const ProtocolHeader&) { - CHECK_STATE_OR_RETURN(WaitingForStreamStart); - state = Negotiating; + CHECK_STATE_OR_RETURN(State::WaitingForStreamStart); + state = State::Negotiating; +} + +void ClientSession::handleStreamEnd() { + if (state == State::Finishing) { + // We are already in finishing state if we iniated the close of the session. + stream->close(); + } + else { + state = State::Finishing; + initiateShutdown(true); + } } void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { @@ -121,11 +140,11 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { if (stanzaAckResponder_) { stanzaAckResponder_->handleStanzaReceived(); } - if (getState() == Initialized) { + if (getState() == State::Initialized) { onStanzaReceived(stanza); } else if (std::shared_ptr<IQ> iq = std::dynamic_pointer_cast<IQ>(element)) { - if (state == BindingResource) { + if (state == State::BindingResource) { std::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { finishSession(Error::ResourceBindError); @@ -145,7 +164,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { finishSession(Error::UnexpectedElementError); } } - else if (state == StartingSession) { + else if (state == State::StartingSession) { if (iq->getType() == IQ::Result) { needSessionStart = false; continueSessionInitialization(); @@ -183,7 +202,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { else if (StreamError::ref streamError = std::dynamic_pointer_cast<StreamError>(element)) { finishSession(Error::StreamError); } - else if (getState() == Initialized) { + else if (getState() == State::Initialized) { std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element); if (stanza) { if (stanzaAckResponder_) { @@ -193,17 +212,17 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { - CHECK_STATE_OR_RETURN(Negotiating); + CHECK_STATE_OR_RETURN(State::Negotiating); if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { - state = WaitingForEncrypt; + state = State::WaitingForEncrypt; stream->writeElement(std::make_shared<StartTLSRequest>()); } else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) { finishSession(Error::NoSupportedAuthMechanismsError); } else if (useStreamCompression && stream->supportsZLibCompression() && streamFeatures->hasCompressionMethod("zlib")) { - state = Compressing; + state = State::Compressing; stream->writeElement(std::make_shared<CompressRequest>("zlib")); } else if (streamFeatures->hasAuthenticationMechanisms()) { @@ -222,7 +241,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { authenticator = gssapiAuthenticator; if (!gssapiAuthenticator->isError()) { - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } else { @@ -236,7 +255,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else { @@ -245,7 +264,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) { @@ -262,12 +281,12 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { scramAuthenticator->setTLSChannelBindingData(finishMessage); } authenticator = scramAuthenticator; - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if ((stream->isTLSEncrypted() || allowPLAINOverNonTLS) && streamFeatures->hasAuthenticationMechanism("PLAIN")) { authenticator = new PLAINClientAuthenticator(); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if (streamFeatures->hasAuthenticationMechanism("DIGEST-MD5") && crypto->isMD5AllowedForCrypto()) { @@ -275,7 +294,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { s << boost::uuids::random_generator()(); // FIXME: Host should probably be the actual host authenticator = new DIGESTMD5ClientAuthenticator(localJID.getDomain(), s.str(), crypto); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else { @@ -299,8 +318,8 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } } else if (std::dynamic_pointer_cast<Compressed>(element)) { - CHECK_STATE_OR_RETURN(Compressing); - state = WaitingForStreamStart; + CHECK_STATE_OR_RETURN(State::Compressing); + state = State::WaitingForStreamStart; stream->addZLibCompression(); stream->resetXMPPParser(); sendStreamHeader(); @@ -322,7 +341,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { continueSessionInitialization(); } else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (authenticator->setChallenge(challenge->getValue())) { stream->writeElement(std::make_shared<AuthResponse>(authenticator->getResponse())); @@ -340,13 +359,13 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } } else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (!authenticator->setChallenge(authSuccess->getValue())) { finishSession(Error::ServerVerificationFailedError); } else { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; delete authenticator; authenticator = nullptr; stream->resetXMPPParser(); @@ -357,8 +376,8 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { - CHECK_STATE_OR_RETURN(WaitingForEncrypt); - state = Encrypting; + CHECK_STATE_OR_RETURN(State::WaitingForEncrypt); + state = State::Encrypting; stream->addTLSEncryption(); } else if (dynamic_cast<StartTLSFailure*>(element.get())) { @@ -366,14 +385,14 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } else { // FIXME Not correct? - state = Initialized; + state = State::Initialized; onInitialized(); } } void ClientSession::continueSessionInitialization() { if (needResourceBind) { - state = BindingResource; + state = State::BindingResource; std::shared_ptr<ResourceBind> resourceBind(std::make_shared<ResourceBind>()); if (!localJID.getResource().empty()) { resourceBind->setResource(localJID.getResource()); @@ -381,15 +400,15 @@ void ClientSession::continueSessionInitialization() { sendStanza(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } else if (needAcking) { - state = EnablingSessionManagement; + state = State::EnablingSessionManagement; stream->writeElement(std::make_shared<EnableStreamManagement>()); } else if (needSessionStart) { - state = StartingSession; + state = State::StartingSession; sendStanza(IQ::createRequest(IQ::Set, JID(), "session-start", std::make_shared<StartSession>())); } else { - state = Initialized; + state = State::Initialized; onInitialized(); } } @@ -403,15 +422,15 @@ bool ClientSession::checkState(State state) { } void ClientSession::sendCredentials(const SafeByteArray& password) { - assert(WaitingForCredentials); + assert(state == State::WaitingForCredentials); assert(authenticator); - state = Authenticating; + state = State::Authenticating; authenticator->setCredentials(localJID.getNode(), password); stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } void ClientSession::handleTLSEncrypted() { - CHECK_STATE_OR_RETURN(Encrypting); + CHECK_STATE_OR_RETURN(State::Encrypting); std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain(); std::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError(); @@ -438,15 +457,38 @@ void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& cert } } +void ClientSession::initiateShutdown(bool sendFooter) { + if (!streamShutdownTimeout) { + streamShutdownTimeout = timerFactory->createTimer(sessionShutdownTimeoutInMilliseconds); + streamShutdownTimeout->onTick.connect(boost::bind(&ClientSession::handleStreamShutdownTimeout, shared_from_this())); + streamShutdownTimeout->start(); + } + if (sendFooter) { + stream->writeFooter(); + } + if (state == State::Finishing) { + // The other side already send </stream>; we can close the socket. + stream->close(); + } + else { + state = State::Finishing; + } +} + void ClientSession::continueAfterTLSEncrypted() { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; stream->resetXMPPParser(); sendStreamHeader(); } void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError) { State previousState = state; - state = Finished; + state = State::Finished; + + if (streamShutdownTimeout) { + streamShutdownTimeout->stop(); + streamShutdownTimeout.reset(); + } if (stanzaAckRequester_) { stanzaAckRequester_->onRequestAck.disconnect(boost::bind(&ClientSession::requestAck, shared_from_this())); @@ -459,11 +501,12 @@ void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError } stream->setWhitespacePingEnabled(false); stream->onStreamStartReceived.disconnect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.disconnect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.disconnect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.disconnect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.disconnect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - if (previousState == Finishing) { + if (previousState == State::Finishing) { onFinished(error_); } else { @@ -471,8 +514,17 @@ void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError } } +void ClientSession::handleStreamShutdownTimeout() { + handleStreamClosed(std::shared_ptr<Swift::Error>()); +} + void ClientSession::finish() { - finishSession(std::shared_ptr<Error>()); + if (state != State::Finishing && state != State::Finished) { + finishSession(std::shared_ptr<Error>()); + } + else { + SWIFT_LOG(warning) << "Session already finished or finishing." << std::endl; + } } void ClientSession::finishSession(Error::Type error) { @@ -480,7 +532,6 @@ void ClientSession::finishSession(Error::Type error) { } void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) { - state = Finishing; if (!error_) { error_ = error; } @@ -495,8 +546,19 @@ void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) { delete authenticator; authenticator = nullptr; } - stream->writeFooter(); - stream->close(); + // Immidiately close TCP connection without stream closure. + if (std::dynamic_pointer_cast<CertificateVerificationError>(error)) { + state = State::Finishing; + initiateShutdown(false); + } + else { + if (state == State::Finishing) { + initiateShutdown(true); + } + else if (state != State::Finished) { + initiateShutdown(true); + } + } } void ClientSession::requestAck() { diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index ad1b78c..c7b3658 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -16,18 +16,21 @@ #include <Swiften/Elements/ToplevelElement.h> #include <Swiften/JID/JID.h> #include <Swiften/Session/SessionStream.h> -#include <Swiften/StreamManagement/StanzaAckRequester.h> -#include <Swiften/StreamManagement/StanzaAckResponder.h> namespace Swift { - class ClientAuthenticator; class CertificateTrustChecker; - class IDNConverter; + class ClientAuthenticator; class CryptoProvider; + class IDNConverter; + class Stanza; + class StanzaAckRequester; + class StanzaAckResponder; + class TimerFactory; + class Timer; class SWIFTEN_API ClientSession : public std::enable_shared_from_this<ClientSession> { public: - enum State { + enum class State { Initial, WaitingForStreamStart, Negotiating, @@ -55,7 +58,8 @@ namespace Swift { SessionStartError, TLSClientCertificateError, TLSError, - StreamError + StreamError, + StreamEndError, // The server send a closing stream tag. } type; std::shared_ptr<boost::system::error_code> errorCode; Error(Type type) : type(type) {} @@ -69,8 +73,8 @@ namespace Swift { ~ClientSession(); - static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto) { - return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto)); + static std::shared_ptr<ClientSession> create(const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, CryptoProvider* crypto, TimerFactory* timerFactory) { + return std::shared_ptr<ClientSession>(new ClientSession(jid, stream, idnConverter, crypto, timerFactory)); } State getState() const { @@ -93,7 +97,6 @@ namespace Swift { useAcks = b; } - bool getStreamManagementEnabled() const { // Explicitly convert to bool. In C++11, it would be cleaner to // compare to nullptr. @@ -116,7 +119,7 @@ namespace Swift { void finish(); bool isFinished() const { - return getState() == Finished; + return getState() == State::Finished; } void sendCredentials(const SafeByteArray& password); @@ -138,6 +141,10 @@ namespace Swift { authenticationPort = i; } + void setSessionShutdownTimeout(int timeoutInMilliseconds) { + sessionShutdownTimeoutInMilliseconds = timeoutInMilliseconds; + } + public: boost::signals2::signal<void ()> onNeedCredentials; boost::signals2::signal<void ()> onInitialized; @@ -150,7 +157,8 @@ namespace Swift { const JID& jid, std::shared_ptr<SessionStream>, IDNConverter* idnConverter, - CryptoProvider* crypto); + CryptoProvider* crypto, + TimerFactory* timerFactory); void finishSession(Error::Type error); void finishSession(std::shared_ptr<Swift::Error> error); @@ -163,7 +171,9 @@ namespace Swift { void handleElement(std::shared_ptr<ToplevelElement>); void handleStreamStart(const ProtocolHeader&); + void handleStreamEnd(); void handleStreamClosed(std::shared_ptr<Swift::Error>); + void handleStreamShutdownTimeout(); void handleTLSEncrypted(); @@ -175,13 +185,17 @@ namespace Swift { void ack(unsigned int handledStanzasCount); void continueAfterTLSEncrypted(); void checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error); + void initiateShutdown(bool sendFooter); private: JID localJID; State state; std::shared_ptr<SessionStream> stream; - IDNConverter* idnConverter; - CryptoProvider* crypto; + IDNConverter* idnConverter = nullptr; + CryptoProvider* crypto = nullptr; + TimerFactory* timerFactory = nullptr; + std::shared_ptr<Timer> streamShutdownTimeout; + int sessionShutdownTimeoutInMilliseconds = 10000; bool allowPLAINOverNonTLS; bool useStreamCompression; UseTLS useTLS; diff --git a/Swiften/Client/ClientSessionStanzaChannel.h b/Swiften/Client/ClientSessionStanzaChannel.h index 0527a5c..c4ee393 100644 --- a/Swiften/Client/ClientSessionStanzaChannel.h +++ b/Swiften/Client/ClientSessionStanzaChannel.h @@ -33,7 +33,7 @@ namespace Swift { virtual std::vector<Certificate::ref> getPeerCertificateChain() const; bool isAvailable() const { - return session && session->getState() == ClientSession::Initialized; + return session && session->getState() == ClientSession::State::Initialized; } private: diff --git a/Swiften/Client/CoreClient.cpp b/Swiften/Client/CoreClient.cpp index 3d75d8b..3c7902e 100644 --- a/Swiften/Client/CoreClient.cpp +++ b/Swiften/Client/CoreClient.cpp @@ -154,12 +154,13 @@ void CoreClient::connect(const ClientOptions& o) { } void CoreClient::bindSessionToStream() { - session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider()); + session_ = ClientSession::create(jid_, sessionStream_, networkFactories->getIDNConverter(), networkFactories->getCryptoProvider(), networkFactories->getTimerFactory()); session_->setCertificateTrustChecker(certificateTrustChecker); session_->setUseStreamCompression(options.useStreamCompression); session_->setAllowPLAINOverNonTLS(options.allowPLAINWithoutTLS); session_->setSingleSignOn(options.singleSignOn); session_->setAuthenticationPort(options.manualPort); + session_->setSessionShutdownTimeout(options.sessionShutdownTimeoutInMilliseconds); switch(options.useTLS) { case ClientOptions::UseTLSWhenAvailable: session_->setUseTLS(ClientSession::UseTLSWhenAvailable); @@ -273,6 +274,9 @@ void CoreClient::handleSessionFinished(std::shared_ptr<Error> error) { case ClientSession::Error::StreamError: clientError = ClientError(ClientError::StreamError); break; + case ClientSession::Error::StreamEndError: + clientError = ClientError(ClientError::StreamError); + break; } clientError.setErrorCode(actualError->errorCode); } diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index bd93f4b..44df961 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -10,6 +10,8 @@ #include <boost/bind.hpp> #include <boost/optional.hpp> +#include <Swiften/Base/Debug.h> + #include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/TestFactoryRegistry.h> @@ -34,6 +36,7 @@ #include <Swiften/Elements/TLSProceed.h> #include <Swiften/IDN/IDNConverter.h> #include <Swiften/IDN/PlatformIDNConverter.h> +#include <Swiften/Network/DummyTimerFactory.h> #include <Swiften/Session/SessionStream.h> #include <Swiften/TLS/BlindCertificateTrustChecker.h> #include <Swiften/TLS/SimpleCertificate.h> @@ -59,6 +62,11 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST(testStreamManagement_Failed); CPPUNIT_TEST(testUnexpectedChallenge); CPPUNIT_TEST(testFinishAcksStanzas); + + CPPUNIT_TEST(testServerInitiatedSessionClose); + CPPUNIT_TEST(testClientInitiatedSessionClose); + CPPUNIT_TEST(testTimeoutOnShutdown); + /* CPPUNIT_TEST(testResourceBind); CPPUNIT_TEST(testResourceBind_ChangeResource); @@ -77,6 +85,7 @@ class ClientSessionTest : public CppUnit::TestFixture { void setUp() { crypto = std::shared_ptr<CryptoProvider>(PlatformCryptoProvider::create()); idnConverter = std::shared_ptr<IDNConverter>(PlatformIDNConverter::create()); + timerFactory = std::make_shared<DummyTimerFactory>(); server = std::make_shared<MockSessionStream>(); sessionFinishedReceived = false; needCredentials = false; @@ -92,7 +101,7 @@ class ClientSessionTest : public CppUnit::TestFixture { session->start(); server->breakConnection(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -103,7 +112,11 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendStreamStart(); server->sendStreamError(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -136,7 +149,11 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendTLSFailure(); CPPUNIT_ASSERT(!server->tlsEncrypted); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -151,7 +168,7 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendTLSProceed(); server->breakTLS(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -167,10 +184,12 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendTLSProceed(); CPPUNIT_ASSERT(server->tlsEncrypted); server->onTLSEncrypted(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); + CPPUNIT_ASSERT(std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)); CPPUNIT_ASSERT_EQUAL(CertificateVerificationError::InvalidServerIdentity, std::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)->getType()); } @@ -181,7 +200,11 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendStreamStart(); server->sendEmptyStreamFeatures(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -193,7 +216,7 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthSuccess(); @@ -209,12 +232,16 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); CPPUNIT_ASSERT(needCredentials); - CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::WaitingForCredentials, session->getState()); session->sendCredentials(createSafeByteArray("mypass")); server->receiveAuthRequest("PLAIN"); server->sendAuthFailure(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -227,7 +254,11 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendStreamStart(); server->sendStreamFeaturesWithPLAINAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -240,8 +271,11 @@ class ClientSessionTest : public CppUnit::TestFixture { server->receiveStreamStart(); server->sendStreamStart(); server->sendStreamFeaturesWithMultipleAuthentication(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -253,7 +287,12 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendStreamStart(); server->sendStreamFeaturesWithUnknownAuthentication(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -281,7 +320,11 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendChallenge(); server->sendChallenge(); - CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); CPPUNIT_ASSERT(sessionFinishedReceived); CPPUNIT_ASSERT(sessionFinishedError); } @@ -305,7 +348,7 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(session->getStreamManagementEnabled()); // TODO: Test if the requesters & responders do their work - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } @@ -328,7 +371,7 @@ class ClientSessionTest : public CppUnit::TestFixture { server->sendStreamManagementFailed(); CPPUNIT_ASSERT(!session->getStreamManagementEnabled()); - CPPUNIT_ASSERT_EQUAL(ClientSession::Initialized, session->getState()); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Initialized, session->getState()); session->finish(); } @@ -345,9 +388,44 @@ class ClientSessionTest : public CppUnit::TestFixture { server->receiveAck(3); } + void testServerInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + server->onStreamEndReceived(); + server->close(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + } + + void testClientInitiatedSessionClose() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + + server->onStreamEndReceived(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + } + + void testTimeoutOnShutdown() { + std::shared_ptr<ClientSession> session(createSession()); + initializeSession(session); + + session->finish(); + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finishing, session->getState()); + CPPUNIT_ASSERT_EQUAL(true, server->receivedEvents.back().footer); + timerFactory->setTime(60000); + + CPPUNIT_ASSERT_EQUAL(ClientSession::State::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + } + private: std::shared_ptr<ClientSession> createSession() { - std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get()); + std::shared_ptr<ClientSession> session = ClientSession::create(JID("me@foo.com"), server, idnConverter.get(), crypto.get(), timerFactory.get()); session->onFinished.connect(boost::bind(&ClientSessionTest::handleSessionFinished, this, _1)); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::handleSessionNeedCredentials, this)); session->setAllowPLAINOverNonTLS(true); @@ -631,6 +709,7 @@ class ClientSessionTest : public CppUnit::TestFixture { std::shared_ptr<Error> sessionFinishedError; BlindCertificateTrustChecker* blindCertificateTrustChecker; std::shared_ptr<CryptoProvider> crypto; + std::shared_ptr<DummyTimerFactory> timerFactory; }; CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest); |