diff options
Diffstat (limited to 'Swiften/Client/ClientSession.cpp')
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 184 |
1 files changed, 123 insertions, 61 deletions
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index c301881..bcfb004 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -6,42 +6,47 @@ #include <Swiften/Client/ClientSession.h> +#include <memory> + #include <boost/bind.hpp> #include <boost/uuid/uuid.hpp> #include <boost/uuid/uuid_io.hpp> #include <boost/uuid/uuid_generators.hpp> -#include <memory> -#include <Swiften/Base/Platform.h> #include <Swiften/Base/Log.h> -#include <Swiften/Elements/ProtocolHeader.h> -#include <Swiften/Elements/StreamFeatures.h> -#include <Swiften/Elements/StreamError.h> -#include <Swiften/Elements/StartTLSRequest.h> -#include <Swiften/Elements/StartTLSFailure.h> -#include <Swiften/Elements/TLSProceed.h> -#include <Swiften/Elements/AuthRequest.h> -#include <Swiften/Elements/AuthSuccess.h> -#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Base/Platform.h> +#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Elements/AuthChallenge.h> +#include <Swiften/Elements/AuthFailure.h> +#include <Swiften/Elements/AuthRequest.h> #include <Swiften/Elements/AuthResponse.h> -#include <Swiften/Elements/Compressed.h> +#include <Swiften/Elements/AuthSuccess.h> #include <Swiften/Elements/CompressFailure.h> #include <Swiften/Elements/CompressRequest.h> +#include <Swiften/Elements/Compressed.h> #include <Swiften/Elements/EnableStreamManagement.h> -#include <Swiften/Elements/StreamManagementEnabled.h> -#include <Swiften/Elements/StreamManagementFailed.h> -#include <Swiften/Elements/StartSession.h> -#include <Swiften/Elements/StanzaAck.h> -#include <Swiften/Elements/StanzaAckRequest.h> #include <Swiften/Elements/IQ.h> +#include <Swiften/Elements/ProtocolHeader.h> #include <Swiften/Elements/ResourceBind.h> -#include <Swiften/SASL/PLAINClientAuthenticator.h> +#include <Swiften/Elements/StanzaAck.h> +#include <Swiften/Elements/StanzaAckRequest.h> +#include <Swiften/Elements/StartSession.h> +#include <Swiften/Elements/StartTLSFailure.h> +#include <Swiften/Elements/StartTLSRequest.h> +#include <Swiften/Elements/StreamError.h> +#include <Swiften/Elements/StreamFeatures.h> +#include <Swiften/Elements/StreamManagementEnabled.h> +#include <Swiften/Elements/StreamManagementFailed.h> +#include <Swiften/Elements/TLSProceed.h> +#include <Swiften/Network/Timer.h> +#include <Swiften/Network/TimerFactory.h> +#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> #include <Swiften/SASL/EXTERNALClientAuthenticator.h> +#include <Swiften/SASL/PLAINClientAuthenticator.h> #include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h> -#include <Swiften/SASL/DIGESTMD5ClientAuthenticator.h> -#include <Swiften/Crypto/CryptoProvider.h> #include <Swiften/Session/SessionStream.h> +#include <Swiften/StreamManagement/StanzaAckRequester.h> +#include <Swiften/StreamManagement/StanzaAckResponder.h> #include <Swiften/TLS/CertificateTrustChecker.h> #include <Swiften/TLS/ServerIdentityVerifier.h> @@ -59,12 +64,14 @@ ClientSession::ClientSession( const JID& jid, std::shared_ptr<SessionStream> stream, IDNConverter* idnConverter, - CryptoProvider* crypto) : + CryptoProvider* crypto, + TimerFactory* timerFactory) : localJID(jid), - state(Initial), + state(State::Initial), stream(stream), idnConverter(idnConverter), crypto(crypto), + timerFactory(timerFactory), allowPLAINOverNonTLS(false), useStreamCompression(true), useTLS(UseTLSWhenAvailable), @@ -89,12 +96,13 @@ ClientSession::~ClientSession() { void ClientSession::start() { stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.connect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.connect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - assert(state == Initial); - state = WaitingForStreamStart; + assert(state == State::Initial); + state = State::WaitingForStreamStart; sendStreamHeader(); } @@ -112,8 +120,19 @@ void ClientSession::sendStanza(std::shared_ptr<Stanza> stanza) { } void ClientSession::handleStreamStart(const ProtocolHeader&) { - CHECK_STATE_OR_RETURN(WaitingForStreamStart); - state = Negotiating; + CHECK_STATE_OR_RETURN(State::WaitingForStreamStart); + state = State::Negotiating; +} + +void ClientSession::handleStreamEnd() { + if (state == State::Finishing) { + // We are already in finishing state if we iniated the close of the session. + stream->close(); + } + else { + state = State::Finishing; + initiateShutdown(true); + } } void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { @@ -121,11 +140,11 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { if (stanzaAckResponder_) { stanzaAckResponder_->handleStanzaReceived(); } - if (getState() == Initialized) { + if (getState() == State::Initialized) { onStanzaReceived(stanza); } else if (std::shared_ptr<IQ> iq = std::dynamic_pointer_cast<IQ>(element)) { - if (state == BindingResource) { + if (state == State::BindingResource) { std::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { finishSession(Error::ResourceBindError); @@ -145,7 +164,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { finishSession(Error::UnexpectedElementError); } } - else if (state == StartingSession) { + else if (state == State::StartingSession) { if (iq->getType() == IQ::Result) { needSessionStart = false; continueSessionInitialization(); @@ -183,7 +202,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { else if (StreamError::ref streamError = std::dynamic_pointer_cast<StreamError>(element)) { finishSession(Error::StreamError); } - else if (getState() == Initialized) { + else if (getState() == State::Initialized) { std::shared_ptr<Stanza> stanza = std::dynamic_pointer_cast<Stanza>(element); if (stanza) { if (stanzaAckResponder_) { @@ -193,17 +212,17 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { - CHECK_STATE_OR_RETURN(Negotiating); + CHECK_STATE_OR_RETURN(State::Negotiating); if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption() && useTLS != NeverUseTLS) { - state = WaitingForEncrypt; + state = State::WaitingForEncrypt; stream->writeElement(std::make_shared<StartTLSRequest>()); } else if (useTLS == RequireTLS && !stream->isTLSEncrypted()) { finishSession(Error::NoSupportedAuthMechanismsError); } else if (useStreamCompression && stream->supportsZLibCompression() && streamFeatures->hasCompressionMethod("zlib")) { - state = Compressing; + state = State::Compressing; stream->writeElement(std::make_shared<CompressRequest>("zlib")); } else if (streamFeatures->hasAuthenticationMechanisms()) { @@ -222,7 +241,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { authenticator = gssapiAuthenticator; if (!gssapiAuthenticator->isError()) { - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } else { @@ -236,7 +255,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { if (stream->hasTLSCertificate()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else { @@ -245,7 +264,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } else if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { authenticator = new EXTERNALClientAuthenticator(); - state = Authenticating; + state = State::Authenticating; stream->writeElement(std::make_shared<AuthRequest>("EXTERNAL", createSafeByteArray(""))); } else if (streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1") || streamFeatures->hasAuthenticationMechanism("SCRAM-SHA-1-PLUS")) { @@ -262,12 +281,12 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { scramAuthenticator->setTLSChannelBindingData(finishMessage); } authenticator = scramAuthenticator; - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if ((stream->isTLSEncrypted() || allowPLAINOverNonTLS) && streamFeatures->hasAuthenticationMechanism("PLAIN")) { authenticator = new PLAINClientAuthenticator(); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else if (streamFeatures->hasAuthenticationMechanism("DIGEST-MD5") && crypto->isMD5AllowedForCrypto()) { @@ -275,7 +294,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { s << boost::uuids::random_generator()(); // FIXME: Host should probably be the actual host authenticator = new DIGESTMD5ClientAuthenticator(localJID.getDomain(), s.str(), crypto); - state = WaitingForCredentials; + state = State::WaitingForCredentials; onNeedCredentials(); } else { @@ -299,8 +318,8 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } } else if (std::dynamic_pointer_cast<Compressed>(element)) { - CHECK_STATE_OR_RETURN(Compressing); - state = WaitingForStreamStart; + CHECK_STATE_OR_RETURN(State::Compressing); + state = State::WaitingForStreamStart; stream->addZLibCompression(); stream->resetXMPPParser(); sendStreamHeader(); @@ -322,7 +341,7 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { continueSessionInitialization(); } else if (AuthChallenge* challenge = dynamic_cast<AuthChallenge*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (authenticator->setChallenge(challenge->getValue())) { stream->writeElement(std::make_shared<AuthResponse>(authenticator->getResponse())); @@ -340,13 +359,13 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } } else if (AuthSuccess* authSuccess = dynamic_cast<AuthSuccess*>(element.get())) { - CHECK_STATE_OR_RETURN(Authenticating); + CHECK_STATE_OR_RETURN(State::Authenticating); assert(authenticator); if (!authenticator->setChallenge(authSuccess->getValue())) { finishSession(Error::ServerVerificationFailedError); } else { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; delete authenticator; authenticator = nullptr; stream->resetXMPPParser(); @@ -357,8 +376,8 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { - CHECK_STATE_OR_RETURN(WaitingForEncrypt); - state = Encrypting; + CHECK_STATE_OR_RETURN(State::WaitingForEncrypt); + state = State::Encrypting; stream->addTLSEncryption(); } else if (dynamic_cast<StartTLSFailure*>(element.get())) { @@ -366,14 +385,14 @@ void ClientSession::handleElement(std::shared_ptr<ToplevelElement> element) { } else { // FIXME Not correct? - state = Initialized; + state = State::Initialized; onInitialized(); } } void ClientSession::continueSessionInitialization() { if (needResourceBind) { - state = BindingResource; + state = State::BindingResource; std::shared_ptr<ResourceBind> resourceBind(std::make_shared<ResourceBind>()); if (!localJID.getResource().empty()) { resourceBind->setResource(localJID.getResource()); @@ -381,15 +400,15 @@ void ClientSession::continueSessionInitialization() { sendStanza(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } else if (needAcking) { - state = EnablingSessionManagement; + state = State::EnablingSessionManagement; stream->writeElement(std::make_shared<EnableStreamManagement>()); } else if (needSessionStart) { - state = StartingSession; + state = State::StartingSession; sendStanza(IQ::createRequest(IQ::Set, JID(), "session-start", std::make_shared<StartSession>())); } else { - state = Initialized; + state = State::Initialized; onInitialized(); } } @@ -403,15 +422,15 @@ bool ClientSession::checkState(State state) { } void ClientSession::sendCredentials(const SafeByteArray& password) { - assert(WaitingForCredentials); + assert(state == State::WaitingForCredentials); assert(authenticator); - state = Authenticating; + state = State::Authenticating; authenticator->setCredentials(localJID.getNode(), password); stream->writeElement(std::make_shared<AuthRequest>(authenticator->getName(), authenticator->getResponse())); } void ClientSession::handleTLSEncrypted() { - CHECK_STATE_OR_RETURN(Encrypting); + CHECK_STATE_OR_RETURN(State::Encrypting); std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain(); std::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError(); @@ -438,15 +457,38 @@ void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& cert } } +void ClientSession::initiateShutdown(bool sendFooter) { + if (!streamShutdownTimeout) { + streamShutdownTimeout = timerFactory->createTimer(sessionShutdownTimeoutInMilliseconds); + streamShutdownTimeout->onTick.connect(boost::bind(&ClientSession::handleStreamShutdownTimeout, shared_from_this())); + streamShutdownTimeout->start(); + } + if (sendFooter) { + stream->writeFooter(); + } + if (state == State::Finishing) { + // The other side already send </stream>; we can close the socket. + stream->close(); + } + else { + state = State::Finishing; + } +} + void ClientSession::continueAfterTLSEncrypted() { - state = WaitingForStreamStart; + state = State::WaitingForStreamStart; stream->resetXMPPParser(); sendStreamHeader(); } void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError) { State previousState = state; - state = Finished; + state = State::Finished; + + if (streamShutdownTimeout) { + streamShutdownTimeout->stop(); + streamShutdownTimeout.reset(); + } if (stanzaAckRequester_) { stanzaAckRequester_->onRequestAck.disconnect(boost::bind(&ClientSession::requestAck, shared_from_this())); @@ -459,11 +501,12 @@ void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError } stream->setWhitespacePingEnabled(false); stream->onStreamStartReceived.disconnect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onStreamEndReceived.disconnect(boost::bind(&ClientSession::handleStreamEnd, shared_from_this())); stream->onElementReceived.disconnect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); stream->onClosed.disconnect(boost::bind(&ClientSession::handleStreamClosed, shared_from_this(), _1)); stream->onTLSEncrypted.disconnect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); - if (previousState == Finishing) { + if (previousState == State::Finishing) { onFinished(error_); } else { @@ -471,8 +514,17 @@ void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError } } +void ClientSession::handleStreamShutdownTimeout() { + handleStreamClosed(std::shared_ptr<Swift::Error>()); +} + void ClientSession::finish() { - finishSession(std::shared_ptr<Error>()); + if (state != State::Finishing && state != State::Finished) { + finishSession(std::shared_ptr<Error>()); + } + else { + SWIFT_LOG(warning) << "Session already finished or finishing." << std::endl; + } } void ClientSession::finishSession(Error::Type error) { @@ -480,7 +532,6 @@ void ClientSession::finishSession(Error::Type error) { } void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) { - state = Finishing; if (!error_) { error_ = error; } @@ -495,8 +546,19 @@ void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) { delete authenticator; authenticator = nullptr; } - stream->writeFooter(); - stream->close(); + // Immidiately close TCP connection without stream closure. + if (std::dynamic_pointer_cast<CertificateVerificationError>(error)) { + state = State::Finishing; + initiateShutdown(false); + } + else { + if (state == State::Finishing) { + initiateShutdown(true); + } + else if (state != State::Finished) { + initiateShutdown(true); + } + } } void ClientSession::requestAck() { |