From a6fcd9e7aa12c5e00c61ff809e81fba14babd70c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be> Date: Sun, 19 Jul 2009 15:21:38 +0200 Subject: Factor out common session stuff into Session class. diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp index a38416a..95f6c0f 100644 --- a/Swiften/Client/Client.cpp +++ b/Swiften/Client/Client.cpp @@ -12,21 +12,17 @@ namespace Swift { Client::Client(const JID& jid, const String& password) : - IQRouter(this), jid_(jid), password_(password), session_(0) { + IQRouter(this), jid_(jid), password_(password) { connectionFactory_ = new BoostConnectionFactory(&boostIOServiceThread_.getIOService()); tlsLayerFactory_ = new PlatformTLSLayerFactory(); } Client::~Client() { - delete session_; delete tlsLayerFactory_; delete connectionFactory_; } void Client::connect() { - delete session_; - session_ = 0; - DomainNameResolver resolver; try { HostAddressPort remote = resolver.resolve(jid_.getDomain().getUTF8String()); @@ -44,23 +40,23 @@ void Client::handleConnectionConnectFinished(bool error) { onError(ClientError::ConnectionError); } else { - session_ = new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_)); if (!certificate_.isEmpty()) { session_->setCertificate(PKCS12Certificate(certificate_, password_)); } session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected))); - session_->onError.connect(boost::bind(&Client::handleSessionError, this, _1)); + session_->onSessionFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1)); session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this)); session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1)); session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1)); session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1)); - session_->start(); + session_->startSession(); } } void Client::disconnect() { if (session_) { - session_->stop(); + session_->finishSession(); } } @@ -108,47 +104,46 @@ void Client::setCertificate(const String& certificate) { certificate_ = certificate; } -void Client::handleSessionError(ClientSession::SessionError error) { - ClientError clientError; - switch (error) { - case ClientSession::NoError: - assert(false); - break; - case ClientSession::ConnectionReadError: - clientError = ClientError(ClientError::ConnectionReadError); - break; - case ClientSession::ConnectionWriteError: - clientError = ClientError(ClientError::ConnectionWriteError); - break; - case ClientSession::XMLError: - clientError = ClientError(ClientError::XMLError); - break; - case ClientSession::AuthenticationFailedError: - clientError = ClientError(ClientError::AuthenticationFailedError); - break; - case ClientSession::NoSupportedAuthMechanismsError: - clientError = ClientError(ClientError::NoSupportedAuthMechanismsError); - break; - case ClientSession::UnexpectedElementError: - clientError = ClientError(ClientError::UnexpectedElementError); - break; - case ClientSession::ResourceBindError: - clientError = ClientError(ClientError::ResourceBindError); - break; - case ClientSession::SessionStartError: - clientError = ClientError(ClientError::SessionStartError); - break; - case ClientSession::TLSError: - clientError = ClientError(ClientError::TLSError); - break; - case ClientSession::ClientCertificateLoadError: - clientError = ClientError(ClientError::ClientCertificateLoadError); - break; - case ClientSession::ClientCertificateError: - clientError = ClientError(ClientError::ClientCertificateError); - break; +void Client::handleSessionFinished(const boost::optional<Session::SessionError>& error) { + if (error) { + ClientError clientError; + switch (*error) { + case Session::ConnectionReadError: + clientError = ClientError(ClientError::ConnectionReadError); + break; + case Session::ConnectionWriteError: + clientError = ClientError(ClientError::ConnectionWriteError); + break; + case Session::XMLError: + clientError = ClientError(ClientError::XMLError); + break; + case Session::AuthenticationFailedError: + clientError = ClientError(ClientError::AuthenticationFailedError); + break; + case Session::NoSupportedAuthMechanismsError: + clientError = ClientError(ClientError::NoSupportedAuthMechanismsError); + break; + case Session::UnexpectedElementError: + clientError = ClientError(ClientError::UnexpectedElementError); + break; + case Session::ResourceBindError: + clientError = ClientError(ClientError::ResourceBindError); + break; + case Session::SessionStartError: + clientError = ClientError(ClientError::SessionStartError); + break; + case Session::TLSError: + clientError = ClientError(ClientError::TLSError); + break; + case Session::ClientCertificateLoadError: + clientError = ClientError(ClientError::ClientCertificateLoadError); + break; + case Session::ClientCertificateError: + clientError = ClientError(ClientError::ClientCertificateError); + break; + } + onError(clientError); } - onError(clientError); } void Client::handleNeedCredentials() { diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h index 48b76d9..1561c75 100644 --- a/Swiften/Client/Client.h +++ b/Swiften/Client/Client.h @@ -47,7 +47,7 @@ namespace Swift { void send(boost::shared_ptr<Stanza>); virtual String getNewIQID(); void handleElement(boost::shared_ptr<Element>); - void handleSessionError(ClientSession::SessionError error); + void handleSessionFinished(const boost::optional<Session::SessionError>& error); void handleNeedCredentials(); void handleDataRead(const ByteArray&); void handleDataWritten(const ByteArray&); @@ -61,7 +61,7 @@ namespace Swift { TLSLayerFactory* tlsLayerFactory_; FullPayloadParserFactoryCollection payloadParserFactories_; FullPayloadSerializerCollection payloadSerializers_; - ClientSession* session_; + boost::shared_ptr<ClientSession> session_; boost::shared_ptr<Connection> connection_; String certificate_; }; diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index 11317e8..4fcf1f8 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -30,70 +30,30 @@ ClientSession::ClientSession( TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : + Session(connection, payloadParserFactories, payloadSerializers), jid_(jid), tlsLayerFactory_(tlsLayerFactory), - payloadParserFactories_(payloadParserFactories), - payloadSerializers_(payloadSerializers), state_(Initial), - error_(NoError), - connection_(connection), - streamStack_(0), needSessionStart_(false) { } -ClientSession::~ClientSession() { - delete streamStack_; -} - -void ClientSession::start() { +void ClientSession::handleSessionStarted() { assert(state_ == Initial); - - connection_->onDisconnected.connect(boost::bind(&ClientSession::handleDisconnected, this, _1)); - initializeStreamStack(); state_ = WaitingForStreamStart; sendStreamHeader(); } -void ClientSession::stop() { - // TODO: Send end stream header if applicable - connection_->disconnect(); -} - void ClientSession::sendStreamHeader() { ProtocolHeader header; header.setTo(jid_.getDomain()); - xmppLayer_->writeHeader(header); -} - -void ClientSession::initializeStreamStack() { - xmppLayer_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(payloadParserFactories_, payloadSerializers_)); - xmppLayer_->onStreamStart.connect(boost::bind(&ClientSession::handleStreamStart, this)); - xmppLayer_->onElement.connect(boost::bind(&ClientSession::handleElement, this, _1)); - xmppLayer_->onError.connect(boost::bind(&ClientSession::setError, this, XMLError)); - xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1)); - xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1)); - connectionLayer_ = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection_)); - streamStack_ = new StreamStack(xmppLayer_, connectionLayer_); -} - -void ClientSession::handleDisconnected(const boost::optional<Connection::Error>& error) { - if (error) { - switch (*error) { - case Connection::ReadError: - setError(ConnectionReadError); - break; - case Connection::WriteError: - setError(ConnectionWriteError); - break; - } - } + getXMPPLayer()->writeHeader(header); } void ClientSession::setCertificate(const PKCS12Certificate& certificate) { certificate_ = certificate; } -void ClientSession::handleStreamStart() { +void ClientSession::handleStreamStart(const ProtocolHeader&) { checkState(WaitingForStreamStart); state_ = Negotiating; } @@ -109,16 +69,16 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { state_ = Encrypting; - xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); + getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); } else if (streamFeatures->hasAuthenticationMechanisms()) { if (!certificate_.isNull()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { state_ = Authenticating; - xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); + getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); } else { - setError(ClientCertificateError); + finishSession(ClientCertificateError); } } else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { @@ -126,7 +86,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { onNeedCredentials(); } else { - setError(NoSupportedAuthMechanismsError); + finishSession(NoSupportedAuthMechanismsError); } } else { @@ -134,7 +94,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { // Add a whitespace ping layer whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); - streamStack_->addLayer(whitespacePingLayer_); + getStreamStack()->addLayer(whitespacePingLayer_); if (streamFeatures->hasSession()) { needSessionStart_ = true; @@ -146,31 +106,31 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { if (!jid_.getResource().isEmpty()) { resourceBind->setResource(jid_.getResource()); } - xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); + getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } else if (needSessionStart_) { sendSessionStart(); } else { state_ = SessionStarted; - onSessionStarted(); + setInitialized(); } } } else if (dynamic_cast<AuthSuccess*>(element.get())) { checkState(Authenticating); state_ = WaitingForStreamStart; - xmppLayer_->resetParser(); + getXMPPLayer()->resetParser(); sendStreamHeader(); } else if (dynamic_cast<AuthFailure*>(element.get())) { - setError(AuthenticationFailedError); + finishSession(AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { tlsLayer_ = tlsLayerFactory_->createTLSLayer(); - streamStack_->addLayer(tlsLayer_); + getStreamStack()->addLayer(tlsLayer_); if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { - setError(ClientCertificateLoadError); + finishSession(ClientCertificateLoadError); } else { tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this)); @@ -179,21 +139,21 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } } else if (dynamic_cast<StartTLSFailure*>(element.get())) { - setError(TLSError); + finishSession(TLSError); } else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { if (state_ == BindingResource) { boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { - setError(ResourceBindError); + finishSession(ResourceBindError); } else if (!resourceBind) { - setError(UnexpectedElementError); + finishSession(UnexpectedElementError); } else if (iq->getType() == IQ::Result) { jid_ = resourceBind->getJID(); if (!jid_.isValid()) { - setError(ResourceBindError); + finishSession(ResourceBindError); } if (needSessionStart_) { sendSessionStart(); @@ -203,47 +163,51 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) { } } else { - setError(UnexpectedElementError); + finishSession(UnexpectedElementError); } } else if (state_ == StartingSession) { if (iq->getType() == IQ::Result) { state_ = SessionStarted; - onSessionStarted(); + setInitialized(); } else if (iq->getType() == IQ::Error) { - setError(SessionStartError); + finishSession(SessionStartError); } else { - setError(UnexpectedElementError); + finishSession(UnexpectedElementError); } } else { - setError(UnexpectedElementError); + finishSession(UnexpectedElementError); } } else { // FIXME Not correct? state_ = SessionStarted; - onSessionStarted(); + setInitialized(); } } void ClientSession::sendSessionStart() { state_ = StartingSession; - xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); + getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); } -void ClientSession::setError(SessionError error) { - assert(error != NoError); - state_ = Error; - error_ = error; - onError(error); +void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) { + if (error) { + assert(!error_); + state_ = Error; + error_ = error; + } + else { + state_ = Finished; + } } bool ClientSession::checkState(State state) { if (state_ != state) { - setError(UnexpectedElementError); + finishSession(UnexpectedElementError); return false; } return true; @@ -252,22 +216,17 @@ bool ClientSession::checkState(State state) { void ClientSession::sendCredentials(const String& password) { assert(WaitingForCredentials); state_ = Authenticating; - xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue()))); -} - -void ClientSession::sendElement(boost::shared_ptr<Element> element) { - assert(SessionStarted); - xmppLayer_->writeElement(element); + getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue()))); } void ClientSession::handleTLSConnected() { state_ = WaitingForStreamStart; - xmppLayer_->resetParser(); + getXMPPLayer()->resetParser(); sendStreamHeader(); } void ClientSession::handleTLSError() { - setError(TLSError); + finishSession(TLSError); } } diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index 50dae24..22e4a88 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -3,6 +3,7 @@ #include <boost/signal.hpp> #include <boost/shared_ptr.hpp> +#include "Swiften/Session/Session.h" #include "Swiften/Base/String.h" #include "Swiften/JID/JID.h" #include "Swiften/Elements/Element.h" @@ -21,7 +22,7 @@ namespace Swift { class TLSLayer; class WhitespacePingLayer; - class ClientSession { + class ClientSession : public Session { public: enum State { Initial, @@ -34,21 +35,8 @@ namespace Swift { BindingResource, StartingSession, SessionStarted, - Error - }; - enum SessionError { - NoError, - ConnectionReadError, - ConnectionWriteError, - XMLError, - AuthenticationFailedError, - NoSupportedAuthMechanismsError, - UnexpectedElementError, - ResourceBindError, - SessionStartError, - TLSError, - ClientCertificateLoadError, - ClientCertificateError + Error, + Finished }; ClientSession( @@ -57,13 +45,12 @@ namespace Swift { TLSLayerFactory*, PayloadParserFactoryCollection*, PayloadSerializerCollection*); - ~ClientSession(); State getState() const { return state_; } - SessionError getError() const { + boost::optional<SessionError> getError() const { return error_; } @@ -71,26 +58,18 @@ namespace Swift { return jid_; } - void start(); - void stop(); - void sendCredentials(const String& password); - void sendElement(boost::shared_ptr<Element>); void setCertificate(const PKCS12Certificate& certificate); - protected: - StreamStack* getStreamStack() const { - return streamStack_; - } - private: - void initializeStreamStack(); void sendStreamHeader(); void sendSessionStart(); - void handleDisconnected(const boost::optional<Connection::Error>&); - void handleElement(boost::shared_ptr<Element>); - void handleStreamStart(); + virtual void handleSessionStarted(); + virtual void handleSessionFinished(const boost::optional<SessionError>& error); + virtual void handleElement(boost::shared_ptr<Element>); + virtual void handleStreamStart(const ProtocolHeader&); + void handleTLSConnected(); void handleTLSError(); @@ -98,26 +77,15 @@ namespace Swift { bool checkState(State); public: - boost::signal<void ()> onSessionStarted; - boost::signal<void (SessionError)> onError; boost::signal<void ()> onNeedCredentials; - boost::signal<void (boost::shared_ptr<Element>) > onElementReceived; - boost::signal<void (const ByteArray&)> onDataWritten; - boost::signal<void (const ByteArray&)> onDataRead; private: JID jid_; TLSLayerFactory* tlsLayerFactory_; - PayloadParserFactoryCollection* payloadParserFactories_; - PayloadSerializerCollection* payloadSerializers_; State state_; - SessionError error_; - boost::shared_ptr<Connection> connection_; - boost::shared_ptr<XMPPLayer> xmppLayer_; + boost::optional<SessionError> error_; boost::shared_ptr<TLSLayer> tlsLayer_; - boost::shared_ptr<ConnectionLayer> connectionLayer_; boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_; - StreamStack* streamStack_; bool needSessionStart_; PKCS12Certificate certificate_; }; diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index 1e66019..c86442d 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -78,15 +78,15 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testConstructor() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); CPPUNIT_ASSERT_EQUAL(ClientSession::Initial, session->getState()); } void testStart_Error() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState()); @@ -94,14 +94,14 @@ class ClientSessionTest : public CppUnit::TestFixture { processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, *session->getError()); } void testStart_XMLError() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState()); @@ -109,29 +109,29 @@ class ClientSessionTest : public CppUnit::TestFixture { processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, *session->getError()); } void testStartTLS_NoTLSSupport() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); tlsLayerFactory_->setTLSSupported(false); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); } void testStartTLS() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); // FIXME: Test 'encrypting' state getMockServer()->sendTLSProceed(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Encrypting, session->getState()); CPPUNIT_ASSERT(session->getTLSLayer()); @@ -147,42 +147,42 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testStartTLS_ServerError() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); getMockServer()->sendTLSFailure(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError()); } void testStartTLS_ConnectError() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); getMockServer()->sendTLSProceed(); - session->start(); + session->startSession(); processEvents(); session->getTLSLayer()->setError(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError()); } void testStartTLS_ErrorAfterConnect() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithStartTLS(); getMockServer()->expectStartTLS(); getMockServer()->sendTLSProceed(); - session->start(); + session->startSession(); processEvents(); getMockServer()->resetParser(); getMockServer()->expectStreamStart(); @@ -193,16 +193,16 @@ class ClientSessionTest : public CppUnit::TestFixture { session->getTLSLayer()->setError(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError()); } void testAuthenticate() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this)); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithAuthentication(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState()); CPPUNIT_ASSERT(needCredentials_); @@ -218,11 +218,11 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testAuthenticate_Unauthorized() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithAuthentication(); - session->start(); + session->startSession(); processEvents(); getMockServer()->expectAuth("me", "mypass"); @@ -231,30 +231,30 @@ class ClientSessionTest : public CppUnit::TestFixture { processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, *session->getError()); } void testAuthenticate_NoValidAuthMechanisms() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, *session->getError()); } void testResourceBind() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("Bar", "session-bind"); // FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::BindingResource, session->getState()); getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); - session->start(); + session->startSession(); processEvents(); @@ -263,13 +263,13 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testResourceBind_ChangeResource() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("Bar", "session-bind"); getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind"); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); @@ -277,13 +277,13 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testResourceBind_EmptyResource() { - std::auto_ptr<MockSession> session(createSession("me@foo.com")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("", "session-bind"); getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind"); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); @@ -291,21 +291,21 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testResourceBind_Error() { - std::auto_ptr<MockSession> session(createSession("me@foo.com")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithResourceBind(); getMockServer()->expectResourceBind("", "session-bind"); getMockServer()->sendError("session-bind"); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, *session->getError()); } void testSessionStart() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this)); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); @@ -313,7 +313,7 @@ class ClientSessionTest : public CppUnit::TestFixture { getMockServer()->expectSessionStart("session-start"); // FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::StartingSession, session->getState()); getMockServer()->sendSessionStartResponse("session-start"); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); @@ -321,21 +321,21 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testSessionStart_Error() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeaturesWithSession(); getMockServer()->expectSessionStart("session-start"); getMockServer()->sendError("session-start"); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState()); - CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, session->getError()); + CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, *session->getError()); } void testSessionStart_AfterResourceBind() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this)); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); @@ -344,7 +344,7 @@ class ClientSessionTest : public CppUnit::TestFixture { getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind"); getMockServer()->expectSessionStart("session-start"); getMockServer()->sendSessionStartResponse("session-start"); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState()); @@ -352,21 +352,21 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testWhitespacePing() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeatures(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT(session->getWhitespacePingLayer()); } void testReceiveElementAfterSessionStarted() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeatures(); - session->start(); + session->startSession(); processEvents(); getMockServer()->expectMessage(); @@ -374,13 +374,13 @@ class ClientSessionTest : public CppUnit::TestFixture { } void testSendElement() { - std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar")); + boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar")); session->onElementReceived.connect(boost::bind(&ClientSessionTest::addReceivedElement, this, _1)); getMockServer()->expectStreamStart(); getMockServer()->sendStreamStart(); getMockServer()->sendStreamFeatures(); getMockServer()->sendMessage(); - session->start(); + session->startSession(); processEvents(); CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size())); @@ -684,8 +684,8 @@ class ClientSessionTest : public CppUnit::TestFixture { } }; - MockSession* createSession(const String& jid) { - return new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_); + boost::shared_ptr<MockSession> createSession(const String& jid) { + return boost::shared_ptr<MockSession>(new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_)); } diff --git a/Swiften/Makefile.inc b/Swiften/Makefile.inc index 6fa2df8..d66a2b9 100644 --- a/Swiften/Makefile.inc +++ b/Swiften/Makefile.inc @@ -10,6 +10,7 @@ include Swiften/Serializer/Makefile.inc include Swiften/Parser/Makefile.inc include Swiften/MUC/Makefile.inc include Swiften/Network/Makefile.inc +include Swiften/Session/Makefile.inc include Swiften/Client/Makefile.inc include Swiften/TLS/Makefile.inc include Swiften/SASL/Makefile.inc diff --git a/Swiften/Session/Session.cpp b/Swiften/Session/Session.cpp index 9ab8e4d..84354e5 100644 --- a/Swiften/Session/Session.cpp +++ b/Swiften/Session/Session.cpp @@ -28,12 +28,14 @@ void Session::startSession() { void Session::finishSession() { connection->disconnect(); - onSessionFinished(boost::optional<Error>()); + handleSessionFinished(boost::optional<SessionError>()); + onSessionFinished(boost::optional<SessionError>()); } -void Session::finishSession(const Error& error) { +void Session::finishSession(const SessionError& error) { connection->disconnect(); - onSessionFinished(boost::optional<Error>(error)); + handleSessionFinished(boost::optional<SessionError>(error)); + onSessionFinished(boost::optional<SessionError>(error)); } void Session::initializeStreamStack() { @@ -41,23 +43,31 @@ void Session::initializeStreamStack() { new XMPPLayer(payloadParserFactories, payloadSerializers)); xmppLayer->onStreamStart.connect( boost::bind(&Session::handleStreamStart, this, _1)); - xmppLayer->onElement.connect( - boost::bind(&Session::handleElement, this, _1)); + xmppLayer->onElement.connect(boost::bind(&Session::handleElement, this, _1)); xmppLayer->onError.connect( boost::bind(&Session::finishSession, this, XMLError)); + xmppLayer->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1)); + xmppLayer->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1)); connection->onDisconnected.connect( boost::bind(&Session::handleDisconnected, shared_from_this(), _1)); connectionLayer = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection)); streamStack = new StreamStack(xmppLayer, connectionLayer); } -void Session::sendStanza(boost::shared_ptr<Stanza> stanza) { +void Session::sendElement(boost::shared_ptr<Element> stanza) { xmppLayer->writeElement(stanza); } void Session::handleDisconnected(const boost::optional<Connection::Error>& connectionError) { if (connectionError) { - finishSession(ConnectionError); + switch (*connectionError) { + case Connection::ReadError: + finishSession(ConnectionReadError); + break; + case Connection::WriteError: + finishSession(ConnectionWriteError); + break; + } } else { finishSession(); diff --git a/Swiften/Session/Session.h b/Swiften/Session/Session.h index bf8049a..b35179c 100644 --- a/Swiften/Session/Session.h +++ b/Swiften/Session/Session.h @@ -14,7 +14,7 @@ namespace Swift { class ProtocolHeader; class StreamStack; class JID; - class Stanza; + class Element; class ByteArray; class PayloadParserFactoryCollection; class PayloadSerializerCollection; @@ -22,9 +22,18 @@ namespace Swift { class Session : public boost::enable_shared_from_this<Session> { public: - enum Error { - ConnectionError, - XMLError + enum SessionError { + ConnectionReadError, + ConnectionWriteError, + XMLError, + AuthenticationFailedError, + NoSupportedAuthMechanismsError, + UnexpectedElementError, + ResourceBindError, + SessionStartError, + TLSError, + ClientCertificateLoadError, + ClientCertificateError }; Session( @@ -35,18 +44,19 @@ namespace Swift { void startSession(); void finishSession(); - void sendStanza(boost::shared_ptr<Stanza>); + void sendElement(boost::shared_ptr<Element>); - boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived; + boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; boost::signal<void ()> onSessionStarted; - boost::signal<void (const boost::optional<Error>&)> onSessionFinished; + boost::signal<void (const boost::optional<SessionError>&)> onSessionFinished; boost::signal<void (const ByteArray&)> onDataWritten; boost::signal<void (const ByteArray&)> onDataRead; protected: - void finishSession(const Error&); + void finishSession(const SessionError&); virtual void handleSessionStarted() {} + virtual void handleSessionFinished(const boost::optional<SessionError>&) {} virtual void handleElement(boost::shared_ptr<Element>) = 0; virtual void handleStreamStart(const ProtocolHeader&) = 0; @@ -56,11 +66,17 @@ namespace Swift { return xmppLayer; } + StreamStack* getStreamStack() const { + return streamStack; + } + void setInitialized(); bool isInitialized() const { return initialized; } + void setFinished(); + private: void handleDisconnected(const boost::optional<Connection::Error>& error); -- cgit v0.10.2-6-g49f6