diff options
author | Remko Tronçon <git@el-tramo.be> | 2009-11-10 21:24:03 (GMT) |
---|---|---|
committer | Remko Tronçon <git@el-tramo.be> | 2009-11-10 21:24:03 (GMT) |
commit | 54781ce12f7654f8136e645d4ebc5934d90c6bea (patch) | |
tree | 90bad869f9f64d57a3c0af209b83a538a47c7762 /Swiften/Client | |
parent | fcfac59db5cb4503554f2b30854b2e91928296f6 (diff) | |
parent | 66ced3654ad295478b33d3e4f1716f66ab4048b5 (diff) | |
download | swift-contrib-54781ce12f7654f8136e645d4ebc5934d90c6bea.zip swift-contrib-54781ce12f7654f8136e645d4ebc5934d90c6bea.tar.bz2 |
Refactored session management.
Diffstat (limited to 'Swiften/Client')
-rw-r--r-- | Swiften/Client/Client.cpp | 41 | ||||
-rw-r--r-- | Swiften/Client/Client.h | 15 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 197 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 83 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 2 |
5 files changed, 170 insertions, 168 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index 60dfade..9e38626 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -10,6 +10,7 @@ #include "Swiften/Network/BoostConnectionFactory.h" #include "Swiften/Network/DomainNameResolveException.h" #include "Swiften/TLS/PKCS12Certificate.h" +#include "Swiften/Session/BasicSessionStream.h" namespace Swift { @@ -20,6 +21,9 @@ Client::Client(const JID& jid, const String& password) : } Client::~Client() { + if (session_ || connection_) { + std::cerr << "Warning: Client not disconnected properly" << std::endl; + } delete tlsLayerFactory_; delete connectionFactory_; } @@ -46,23 +50,32 @@ void Client::handleConnectionConnectFinished(bool error) { onError(ClientError::ConnectionError); } else { - session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_)); + assert(!sessionStream_); + sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_)); if (!certificate_.isEmpty()) { - session_->setCertificate(PKCS12Certificate(certificate_, password_)); + sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_)); } - session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected))); - session_->onSessionFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); + sessionStream_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); + sessionStream_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); + sessionStream_->initialize(); + + session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, sessionStream_)); + session_->onInitialized.connect(boost::bind(boost::ref(onConnected))); + session_->onFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); - session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); - session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); - session_->startSession(); + session_->start(); } } void Client::disconnect() { if (session_) { - session_->finishSession(); + session_->finish(); + session_.reset(); + } + if (connection_) { + connection_->disconnect(); + connection_.reset(); } } @@ -110,9 +123,10 @@ void Client::setCertificate(const String& certificate) { certificate_ = certificate; } -void Client::handleSessionFinished(const boost::optional<Session::SessionError>& error) { +void Client::handleSessionFinished(boost::shared_ptr<Error> error) { if (error) { ClientError clientError; + /* switch (*error) { case Session::ConnectionReadError: clientError = ClientError(ClientError::ConnectionReadError); @@ -148,6 +162,7 @@ void Client::handleSessionFinished(const boost::optional<Session::SessionError>& clientError = ClientError(ClientError::ClientCertificateError); break; } + */ onError(clientError); } } @@ -156,12 +171,12 @@ void Client::handleNeedCredentials() { session_->sendCredentials(password_); } -void Client::handleDataRead(const ByteArray& data) { - onDataRead(String(data.getData(), data.getSize())); +void Client::handleDataRead(const String& data) { + onDataRead(data); } -void Client::handleDataWritten(const ByteArray& data) { - onDataWritten(String(data.getData(), data.getSize())); +void Client::handleDataWritten(const String& data) { + onDataWritten(data); } } diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index 59e1c05..5188789 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -4,6 +4,7 @@ #include <boost/signals.hpp> #include <boost/shared_ptr.hpp> +#include "Swiften/Base/Error.h" #include "Swiften/Client/ClientSession.h" #include "Swiften/Client/ClientError.h" #include "Swiften/Elements/Presence.h" @@ -20,6 +21,7 @@ namespace Swift { class TLSLayerFactory; class ConnectionFactory; class ClientSession; + class BasicSessionStream; class Client : public StanzaChannel, public IQRouter, public boost::bsignals::trackable { public: @@ -38,7 +40,7 @@ namespace Swift { virtual void sendPresence(boost::shared_ptr<Presence>); public: - boost::signal<void (ClientError)> onError; + boost::signal<void (const ClientError&)> onError; boost::signal<void ()> onConnected; boost::signal<void (const String&)> onDataRead; boost::signal<void (const String&)> onDataWritten; @@ -48,10 +50,12 @@ namespace Swift { void send(boost::shared_ptr<Stanza>); virtual String getNewIQID(); void handleElement(boost::shared_ptr<Element>); - void handleSessionFinished(const boost::optional<Session::SessionError>& error); + void handleSessionFinished(boost::shared_ptr<Error>); void handleNeedCredentials(); - void handleDataRead(const ByteArray&); - void handleDataWritten(const ByteArray&); + void handleDataRead(const String&); + void handleDataWritten(const String&); + + void reset(); private: JID jid_; @@ -61,8 +65,9 @@ namespace Swift { TLSLayerFactory* tlsLayerFactory_; FullPayloadParserFactoryCollection payloadParserFactories_; FullPayloadSerializerCollection payloadSerializers_; - boost::shared_ptr<ClientSession> session_; boost::shared_ptr<Connection> connection_; + boost::shared_ptr<BasicSessionStream> sessionStream_; + boost::shared_ptr<ClientSession> session_; String certificate_; }; } diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index a0e1289..a185ea0 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -2,13 +2,7 @@ #include <boost/bind.hpp> -#include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Elements/ProtocolHeader.h" -#include "Swiften/StreamStack/StreamStack.h" -#include "Swiften/StreamStack/ConnectionLayer.h" -#include "Swiften/StreamStack/XMPPLayer.h" -#include "Swiften/StreamStack/TLSLayer.h" -#include "Swiften/StreamStack/TLSLayerFactory.h" #include "Swiften/Elements/StreamFeatures.h" #include "Swiften/Elements/StartTLSRequest.h" #include "Swiften/Elements/StartTLSFailure.h" @@ -20,47 +14,47 @@ #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/ResourceBind.h" #include "Swiften/SASL/PLAINMessage.h" -#include "Swiften/StreamStack/WhitespacePingLayer.h" +#include "Swiften/Session/SessionStream.h" namespace Swift { ClientSession::ClientSession( const JID& jid, - boost::shared_ptr<Connection> connection, - TLSLayerFactory* tlsLayerFactory, - PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : - Session(connection, payloadParserFactories, payloadSerializers), - tlsLayerFactory_(tlsLayerFactory), - state_(Initial), - needSessionStart_(false) { - setLocalJID(jid); - setRemoteJID(JID("", jid.getDomain())); + boost::shared_ptr<SessionStream> stream) : + localJID(jid), + state(Initial), + stream(stream), + needSessionStart(false) { } -void ClientSession::handleSessionStarted() { - assert(state_ == Initial); - state_ = WaitingForStreamStart; +void ClientSession::start() { + stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1)); + stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1)); + stream->onError.connect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1)); + stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this())); + + assert(state == Initial); + state = WaitingForStreamStart; sendStreamHeader(); } void ClientSession::sendStreamHeader() { ProtocolHeader header; header.setTo(getRemoteJID()); - getXMPPLayer()->writeHeader(header); + stream->writeHeader(header); } -void ClientSession::setCertificate(const PKCS12Certificate& certificate) { - certificate_ = certificate; +void ClientSession::sendElement(boost::shared_ptr<Element> element) { + stream->writeElement(element); } void ClientSession::handleStreamStart(const ProtocolHeader&) { checkState(WaitingForStreamStart); - state_ = Negotiating; + state = Negotiating; } void ClientSession::handleElement(boost::shared_ptr<Element> element) { - if (getState() == SessionStarted) { + if (getState() == Initialized) { onElementReceived(element); } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { @@ -68,152 +62,121 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { return; } - if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { - state_ = Encrypting; - getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); + if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption()) { + state = WaitingForEncrypt; + stream->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); } else if (streamFeatures->hasAuthenticationMechanisms()) { - if (!certificate_.isNull()) { - if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { - state_ = Authenticating; - getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); - } - else { - finishSession(ClientCertificateError); - } + if (stream->hasTLSCertificate() && streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { + state = Authenticating; + stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); } else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { - state_ = WaitingForCredentials; + state = WaitingForCredentials; onNeedCredentials(); } else { - finishSession(NoSupportedAuthMechanismsError); + finishSession(Error::NoSupportedAuthMechanismsError); } } else { // Start the session - - // Add a whitespace ping layer - whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); - getStreamStack()->addLayer(whitespacePingLayer_); - whitespacePingLayer_->setActive(); + stream->setWhitespacePingEnabled(true); if (streamFeatures->hasSession()) { - needSessionStart_ = true; + needSessionStart = true; } if (streamFeatures->hasResourceBind()) { - state_ = BindingResource; + state = BindingResource; boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); - if (!getLocalJID().getResource().isEmpty()) { - resourceBind->setResource(getLocalJID().getResource()); + if (!localJID.getResource().isEmpty()) { + resourceBind->setResource(localJID.getResource()); } - getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); + stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } - else if (needSessionStart_) { + else if (needSessionStart) { sendSessionStart(); } else { - state_ = SessionStarted; - onSessionStarted(); + state = Initialized; + onInitialized(); } } } else if (dynamic_cast<AuthSuccess*>(element.get())) { checkState(Authenticating); - state_ = WaitingForStreamStart; - getXMPPLayer()->resetParser(); + state = WaitingForStreamStart; + stream->resetXMPPParser(); sendStreamHeader(); } else if (dynamic_cast<AuthFailure*>(element.get())) { - finishSession(AuthenticationFailedError); + finishSession(Error::AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { - tlsLayer_ = tlsLayerFactory_->createTLSLayer(); - getStreamStack()->addLayer(tlsLayer_); - if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { - finishSession(ClientCertificateLoadError); - } - else { - tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this)); - tlsLayer_->onError.connect(boost::bind(&ClientSession::handleTLSError, this)); - tlsLayer_->connect(); - } + checkState(WaitingForEncrypt); + state = Encrypting; + stream->addTLSEncryption(); } else if (dynamic_cast<StartTLSFailure*>(element.get())) { - finishSession(TLSError); + finishSession(Error::TLSError); } else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { - if (state_ == BindingResource) { + if (state == BindingResource) { boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { - finishSession(ResourceBindError); + finishSession(Error::ResourceBindError); } else if (!resourceBind) { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } else if (iq->getType() == IQ::Result) { - setLocalJID(resourceBind->getJID()); - if (!getLocalJID().isValid()) { - finishSession(ResourceBindError); + localJID = resourceBind->getJID(); + if (!localJID.isValid()) { + finishSession(Error::ResourceBindError); } - if (needSessionStart_) { + if (needSessionStart) { sendSessionStart(); } else { - state_ = SessionStarted; + state = Initialized; } } else { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } } - else if (state_ == StartingSession) { + else if (state == StartingSession) { if (iq->getType() == IQ::Result) { - state_ = SessionStarted; - onSessionStarted(); + state = Initialized; + onInitialized(); } else if (iq->getType() == IQ::Error) { - finishSession(SessionStartError); + finishSession(Error::SessionStartError); } else { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } } else { - finishSession(UnexpectedElementError); + finishSession(Error::UnexpectedElementError); } } else { // FIXME Not correct? - state_ = SessionStarted; - onSessionStarted(); + state = Initialized; + onInitialized(); } } void ClientSession::sendSessionStart() { - state_ = StartingSession; - getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); -} - -void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) { - if (whitespacePingLayer_) { - whitespacePingLayer_->setInactive(); - } - - if (error) { - //assert(!error_); - state_ = Error; - error_ = error; - } - else { - state_ = Finished; - } + state = StartingSession; + stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); } bool ClientSession::checkState(State state) { - if (state_ != state) { - finishSession(UnexpectedElementError); + if (state != state) { + finishSession(Error::UnexpectedElementError); return false; } return true; @@ -221,18 +184,36 @@ bool ClientSession::checkState(State state) { void ClientSession::sendCredentials(const String& password) { assert(WaitingForCredentials); - state_ = Authenticating; - getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(getLocalJID().getNode(), password).getValue()))); + state = Authenticating; + stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(localJID.getNode(), password).getValue()))); } -void ClientSession::handleTLSConnected() { - state_ = WaitingForStreamStart; - getXMPPLayer()->resetParser(); +void ClientSession::handleTLSEncrypted() { + checkState(WaitingForEncrypt); + state = WaitingForStreamStart; + stream->resetXMPPParser(); sendStreamHeader(); } -void ClientSession::handleTLSError() { - finishSession(TLSError); +void ClientSession::handleStreamError(boost::shared_ptr<Swift::Error> error) { + finishSession(error); +} + +void ClientSession::finish() { + if (stream->isAvailable()) { + stream->writeFooter(); + } + finishSession(boost::shared_ptr<Error>()); } +void ClientSession::finishSession(Error::Type error) { + finishSession(boost::shared_ptr<Swift::ClientSession::Error>(new Swift::ClientSession::Error(error))); +} + +void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) { + stream->setWhitespacePingEnabled(false); + onFinished(error); +} + + } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index fead182..e09861b 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -2,87 +2,88 @@ #include <boost/signal.hpp> #include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> -#include "Swiften/Session/Session.h" +#include "Swiften/Base/Error.h" +#include "Swiften/Session/SessionStream.h" +#include "Swiften/Session/BasicSessionStream.h" #include "Swiften/Base/String.h" #include "Swiften/JID/JID.h" #include "Swiften/Elements/Element.h" -#include "Swiften/Network/Connection.h" -#include "Swiften/TLS/PKCS12Certificate.h" namespace Swift { - class PayloadParserFactoryCollection; - class PayloadSerializerCollection; - class ConnectionFactory; - class Connection; - class StreamStack; - class XMPPLayer; - class ConnectionLayer; - class TLSLayerFactory; - class TLSLayer; - class WhitespacePingLayer; - - class ClientSession : public Session { + class ClientSession : public boost::enable_shared_from_this<ClientSession> { public: enum State { Initial, WaitingForStreamStart, Negotiating, Compressing, + WaitingForEncrypt, Encrypting, WaitingForCredentials, Authenticating, BindingResource, StartingSession, - SessionStarted, - Error, + Initialized, Finished }; + struct Error : public Swift::Error { + enum Type { + AuthenticationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSError, + } type; + Error(Type type) : type(type) {} + }; + ClientSession( const JID& jid, - boost::shared_ptr<Connection>, - TLSLayerFactory*, - PayloadParserFactoryCollection*, - PayloadSerializerCollection*); + boost::shared_ptr<SessionStream>); State getState() const { - return state_; + return state; } - boost::optional<SessionError> getError() const { - return error_; - } + void start(); + void finish(); void sendCredentials(const String& password); - void setCertificate(const PKCS12Certificate& certificate); + void sendElement(boost::shared_ptr<Element> element); private: + void finishSession(Error::Type error); + void finishSession(boost::shared_ptr<Swift::Error> error); + + JID getRemoteJID() const { + return JID("", localJID.getDomain()); + } + void sendStreamHeader(); void sendSessionStart(); - virtual void handleSessionStarted(); - virtual void handleSessionFinished(const boost::optional<SessionError>& error); - virtual void handleElement(boost::shared_ptr<Element>); - virtual void handleStreamStart(const ProtocolHeader&); + void handleElement(boost::shared_ptr<Element>); + void handleStreamStart(const ProtocolHeader&); + void handleStreamError(boost::shared_ptr<Swift::Error>); - void handleTLSConnected(); - void handleTLSError(); + void handleTLSEncrypted(); - void setError(SessionError); bool checkState(State); public: boost::signal<void ()> onNeedCredentials; - boost::signal<void ()> onSessionStarted; + boost::signal<void ()> onInitialized; + boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished; + boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; private: - TLSLayerFactory* tlsLayerFactory_; - State state_; - boost::optional<SessionError> error_; - boost::shared_ptr<TLSLayer> tlsLayer_; - boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_; - bool needSessionStart_; - PKCS12Certificate certificate_; + JID localJID; + State state; + boost::shared_ptr<SessionStream> stream; + bool needSessionStart; }; } diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index cbf20d2..70d4ba9 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -14,7 +14,7 @@ #include "Swiften/Elements/ProtocolHeader.h" #include "Swiften/Elements/StreamFeatures.h" #include "Swiften/Elements/Element.h" -#include "Swiften/Elements/Error.h" +#include "Swiften/Elements/ErrorPayload.h" #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/AuthRequest.h" #include "Swiften/Elements/AuthSuccess.h" |