diff options
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 20 | ||||
-rw-r--r-- | Swiften/Network/BOSHConnectionPool.cpp | 1 | ||||
-rw-r--r-- | Swiften/Network/BOSHConnectionPool.h | 1 | ||||
-rw-r--r-- | Swiften/Session/BOSHSessionStream.cpp | 7 | ||||
-rw-r--r-- | Swiften/Session/BOSHSessionStream.h | 1 |
5 files changed, 23 insertions, 7 deletions
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index bcfb004..661a832 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -45,6 +45,8 @@ #include <Swiften/SASL/PLAINClientAuthenticator.h> #include <Swiften/SASL/SCRAMSHA1ClientAuthenticator.h> #include <Swiften/Session/SessionStream.h> +#include <Swiften/Session/BasicSessionStream.h> +#include <Swiften/Session/BOSHSessionStream.h> #include <Swiften/StreamManagement/StanzaAckRequester.h> #include <Swiften/StreamManagement/StanzaAckResponder.h> #include <Swiften/TLS/CertificateTrustChecker.h> @@ -430,7 +432,9 @@ void ClientSession::sendCredentials(const SafeByteArray& password) { } void ClientSession::handleTLSEncrypted() { - CHECK_STATE_OR_RETURN(State::Encrypting); + if (!std::dynamic_pointer_cast<BOSHSessionStream>(stream)) { + CHECK_STATE_OR_RETURN(State::Encrypting); + } std::vector<Certificate::ref> certificateChain = stream->getPeerCertificateChain(); std::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError(); @@ -450,7 +454,9 @@ void ClientSession::handleTLSEncrypted() { void ClientSession::checkTrustOrFinish(const std::vector<Certificate::ref>& certificateChain, std::shared_ptr<CertificateVerificationError> error) { if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificateChain)) { - continueAfterTLSEncrypted(); + if (!std::dynamic_pointer_cast<BOSHSessionStream>(stream)) { + continueAfterTLSEncrypted(); + } } else { finishSession(error); @@ -476,9 +482,11 @@ void ClientSession::initiateShutdown(bool sendFooter) { } void ClientSession::continueAfterTLSEncrypted() { - state = State::WaitingForStreamStart; - stream->resetXMPPParser(); - sendStreamHeader(); + if (!std::dynamic_pointer_cast<BOSHSessionStream>(stream)) { + state = State::WaitingForStreamStart; + stream->resetXMPPParser(); + sendStreamHeader(); + } } void ClientSession::handleStreamClosed(std::shared_ptr<Swift::Error> streamError) { @@ -536,7 +544,7 @@ void ClientSession::finishSession(std::shared_ptr<Swift::Error> error) { error_ = error; } else { - SWIFT_LOG(warning) << "Session finished twice"; + SWIFT_LOG(warning) << "Session finished twice" << std::endl; } assert(stream->isOpen()); if (stanzaAckResponder_) { diff --git a/Swiften/Network/BOSHConnectionPool.cpp b/Swiften/Network/BOSHConnectionPool.cpp index e4ca471..8a75e81 100644 --- a/Swiften/Network/BOSHConnectionPool.cpp +++ b/Swiften/Network/BOSHConnectionPool.cpp @@ -142,6 +142,7 @@ void BOSHConnectionPool::handleConnectFinished(bool error, BOSHConnection::ref c } if (!pinnedCertificateChain_.empty()) { lastVerificationError_ = connection->getPeerCertificateVerificationError(); + onTLSConnectionEstablished(); } if (sid.empty()) { diff --git a/Swiften/Network/BOSHConnectionPool.h b/Swiften/Network/BOSHConnectionPool.h index 1a805de..c4d827c 100644 --- a/Swiften/Network/BOSHConnectionPool.h +++ b/Swiften/Network/BOSHConnectionPool.h @@ -41,6 +41,7 @@ namespace Swift { std::vector<Certificate::ref> getPeerCertificateChain() const; std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const; + boost::signals2::signal<void ()> onTLSConnectionEstablished; boost::signals2::signal<void (BOSHError::ref)> onSessionTerminated; boost::signals2::signal<void ()> onSessionStarted; boost::signals2::signal<void (const SafeByteArray&)> onXMPPDataRead; diff --git a/Swiften/Session/BOSHSessionStream.cpp b/Swiften/Session/BOSHSessionStream.cpp index 4c7bdee..a335b93 100644 --- a/Swiften/Session/BOSHSessionStream.cpp +++ b/Swiften/Session/BOSHSessionStream.cpp @@ -55,7 +55,7 @@ BOSHSessionStream::BOSHSessionStream(const URL& boshURL, connectionPool->onXMPPDataRead.connect(boost::bind(&BOSHSessionStream::handlePoolXMPPDataRead, this, _1)); connectionPool->onBOSHDataRead.connect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataRead, this, _1)); connectionPool->onBOSHDataWritten.connect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataWritten, this, _1)); - + connectionPool->onTLSConnectionEstablished.connect(boost::bind(&BOSHSessionStream::handlePoolTLSEstablished, this)); xmppLayer = new XMPPLayer(payloadParserFactories, payloadSerializers, xmlParserFactory, ClientStreamType, true); xmppLayer->onStreamStart.connect(boost::bind(&BOSHSessionStream::handleStreamStartReceived, this, _1)); xmppLayer->onElement.connect(boost::bind(&BOSHSessionStream::handleElementReceived, this, _1)); @@ -72,6 +72,7 @@ BOSHSessionStream::~BOSHSessionStream() { connectionPool->onXMPPDataRead.disconnect(boost::bind(&BOSHSessionStream::handlePoolXMPPDataRead, this, _1)); connectionPool->onBOSHDataRead.disconnect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataRead, this, _1)); connectionPool->onBOSHDataWritten.disconnect(boost::bind(&BOSHSessionStream::handlePoolBOSHDataWritten, this, _1)); + connectionPool->onTLSConnectionEstablished.disconnect(boost::bind(&BOSHSessionStream::handlePoolTLSEstablished, this)); delete connectionPool; connectionPool = nullptr; xmppLayer->onStreamStart.disconnect(boost::bind(&BOSHSessionStream::handleStreamStartReceived, this, _1)); @@ -178,6 +179,10 @@ void BOSHSessionStream::handlePoolSessionTerminated(BOSHError::ref error) { eventLoop->postEvent(boost::bind(&BOSHSessionStream::fakeStreamFooterReceipt, this, error), shared_from_this()); } +void BOSHSessionStream::handlePoolTLSEstablished() { + onTLSEncrypted(); +} + void BOSHSessionStream::writeHeader(const ProtocolHeader& header) { streamHeader = header; /*First time we're told to do this, don't (the sending of the initial header is handled on connect) diff --git a/Swiften/Session/BOSHSessionStream.h b/Swiften/Session/BOSHSessionStream.h index 0c26848..719f1f0 100644 --- a/Swiften/Session/BOSHSessionStream.h +++ b/Swiften/Session/BOSHSessionStream.h @@ -85,6 +85,7 @@ namespace Swift { void handlePoolBOSHDataRead(const SafeByteArray& data); void handlePoolBOSHDataWritten(const SafeByteArray& data); void handlePoolSessionTerminated(BOSHError::ref condition); + void handlePoolTLSEstablished(); private: void fakeStreamHeaderReceipt(); |