summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRemko Tronçon <git@el-tramo.be>2009-07-19 13:21:38 (GMT)
committerRemko Tronçon <git@el-tramo.be>2009-07-19 13:27:00 (GMT)
commita6fcd9e7aa12c5e00c61ff809e81fba14babd70c (patch)
treef65419f4f9a78a1db574f8fa792e745f7fbdf76c
parent9ccf1973ec3e23e4ba061b774c3f3e3bde4f1040 (diff)
downloadswift-contrib-a6fcd9e7aa12c5e00c61ff809e81fba14babd70c.zip
swift-contrib-a6fcd9e7aa12c5e00c61ff809e81fba14babd70c.tar.bz2
Factor out common session stuff into Session class.
-rw-r--r--Swiften/Client/Client.cpp93
-rw-r--r--Swiften/Client/Client.h4
-rw-r--r--Swiften/Client/ClientSession.cpp119
-rw-r--r--Swiften/Client/ClientSession.h54
-rw-r--r--Swiften/Client/UnitTest/ClientSessionTest.cpp104
-rw-r--r--Swiften/Makefile.inc1
-rw-r--r--Swiften/Session/Session.cpp24
-rw-r--r--Swiften/Session/Session.h32
8 files changed, 190 insertions, 241 deletions
diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp
index a38416a..95f6c0f 100644
--- a/Swiften/Client/Client.cpp
+++ b/Swiften/Client/Client.cpp
@@ -12,21 +12,17 @@
namespace Swift {
Client::Client(const JID& jid, const String& password) :
- IQRouter(this), jid_(jid), password_(password), session_(0) {
+ IQRouter(this), jid_(jid), password_(password) {
connectionFactory_ = new BoostConnectionFactory(&boostIOServiceThread_.getIOService());
tlsLayerFactory_ = new PlatformTLSLayerFactory();
}
Client::~Client() {
- delete session_;
delete tlsLayerFactory_;
delete connectionFactory_;
}
void Client::connect() {
- delete session_;
- session_ = 0;
-
DomainNameResolver resolver;
try {
HostAddressPort remote = resolver.resolve(jid_.getDomain().getUTF8String());
@@ -44,23 +40,23 @@ void Client::handleConnectionConnectFinished(bool error) {
onError(ClientError::ConnectionError);
}
else {
- session_ = new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_);
+ session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_));
if (!certificate_.isEmpty()) {
session_->setCertificate(PKCS12Certificate(certificate_, password_));
}
session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected)));
- session_->onError.connect(boost::bind(&Client::handleSessionError, this, _1));
+ session_->onSessionFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1));
session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this));
session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1));
session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1));
session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1));
- session_->start();
+ session_->startSession();
}
}
void Client::disconnect() {
if (session_) {
- session_->stop();
+ session_->finishSession();
}
}
@@ -108,47 +104,46 @@ void Client::setCertificate(const String& certificate) {
certificate_ = certificate;
}
-void Client::handleSessionError(ClientSession::SessionError error) {
- ClientError clientError;
- switch (error) {
- case ClientSession::NoError:
- assert(false);
- break;
- case ClientSession::ConnectionReadError:
- clientError = ClientError(ClientError::ConnectionReadError);
- break;
- case ClientSession::ConnectionWriteError:
- clientError = ClientError(ClientError::ConnectionWriteError);
- break;
- case ClientSession::XMLError:
- clientError = ClientError(ClientError::XMLError);
- break;
- case ClientSession::AuthenticationFailedError:
- clientError = ClientError(ClientError::AuthenticationFailedError);
- break;
- case ClientSession::NoSupportedAuthMechanismsError:
- clientError = ClientError(ClientError::NoSupportedAuthMechanismsError);
- break;
- case ClientSession::UnexpectedElementError:
- clientError = ClientError(ClientError::UnexpectedElementError);
- break;
- case ClientSession::ResourceBindError:
- clientError = ClientError(ClientError::ResourceBindError);
- break;
- case ClientSession::SessionStartError:
- clientError = ClientError(ClientError::SessionStartError);
- break;
- case ClientSession::TLSError:
- clientError = ClientError(ClientError::TLSError);
- break;
- case ClientSession::ClientCertificateLoadError:
- clientError = ClientError(ClientError::ClientCertificateLoadError);
- break;
- case ClientSession::ClientCertificateError:
- clientError = ClientError(ClientError::ClientCertificateError);
- break;
+void Client::handleSessionFinished(const boost::optional<Session::SessionError>& error) {
+ if (error) {
+ ClientError clientError;
+ switch (*error) {
+ case Session::ConnectionReadError:
+ clientError = ClientError(ClientError::ConnectionReadError);
+ break;
+ case Session::ConnectionWriteError:
+ clientError = ClientError(ClientError::ConnectionWriteError);
+ break;
+ case Session::XMLError:
+ clientError = ClientError(ClientError::XMLError);
+ break;
+ case Session::AuthenticationFailedError:
+ clientError = ClientError(ClientError::AuthenticationFailedError);
+ break;
+ case Session::NoSupportedAuthMechanismsError:
+ clientError = ClientError(ClientError::NoSupportedAuthMechanismsError);
+ break;
+ case Session::UnexpectedElementError:
+ clientError = ClientError(ClientError::UnexpectedElementError);
+ break;
+ case Session::ResourceBindError:
+ clientError = ClientError(ClientError::ResourceBindError);
+ break;
+ case Session::SessionStartError:
+ clientError = ClientError(ClientError::SessionStartError);
+ break;
+ case Session::TLSError:
+ clientError = ClientError(ClientError::TLSError);
+ break;
+ case Session::ClientCertificateLoadError:
+ clientError = ClientError(ClientError::ClientCertificateLoadError);
+ break;
+ case Session::ClientCertificateError:
+ clientError = ClientError(ClientError::ClientCertificateError);
+ break;
+ }
+ onError(clientError);
}
- onError(clientError);
}
void Client::handleNeedCredentials() {
diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h
index 48b76d9..1561c75 100644
--- a/Swiften/Client/Client.h
+++ b/Swiften/Client/Client.h
@@ -47,7 +47,7 @@ namespace Swift {
void send(boost::shared_ptr<Stanza>);
virtual String getNewIQID();
void handleElement(boost::shared_ptr<Element>);
- void handleSessionError(ClientSession::SessionError error);
+ void handleSessionFinished(const boost::optional<Session::SessionError>& error);
void handleNeedCredentials();
void handleDataRead(const ByteArray&);
void handleDataWritten(const ByteArray&);
@@ -61,7 +61,7 @@ namespace Swift {
TLSLayerFactory* tlsLayerFactory_;
FullPayloadParserFactoryCollection payloadParserFactories_;
FullPayloadSerializerCollection payloadSerializers_;
- ClientSession* session_;
+ boost::shared_ptr<ClientSession> session_;
boost::shared_ptr<Connection> connection_;
String certificate_;
};
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index 11317e8..4fcf1f8 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -30,70 +30,30 @@ ClientSession::ClientSession(
TLSLayerFactory* tlsLayerFactory,
PayloadParserFactoryCollection* payloadParserFactories,
PayloadSerializerCollection* payloadSerializers) :
+ Session(connection, payloadParserFactories, payloadSerializers),
jid_(jid),
tlsLayerFactory_(tlsLayerFactory),
- payloadParserFactories_(payloadParserFactories),
- payloadSerializers_(payloadSerializers),
state_(Initial),
- error_(NoError),
- connection_(connection),
- streamStack_(0),
needSessionStart_(false) {
}
-ClientSession::~ClientSession() {
- delete streamStack_;
-}
-
-void ClientSession::start() {
+void ClientSession::handleSessionStarted() {
assert(state_ == Initial);
-
- connection_->onDisconnected.connect(boost::bind(&ClientSession::handleDisconnected, this, _1));
- initializeStreamStack();
state_ = WaitingForStreamStart;
sendStreamHeader();
}
-void ClientSession::stop() {
- // TODO: Send end stream header if applicable
- connection_->disconnect();
-}
-
void ClientSession::sendStreamHeader() {
ProtocolHeader header;
header.setTo(jid_.getDomain());
- xmppLayer_->writeHeader(header);
-}
-
-void ClientSession::initializeStreamStack() {
- xmppLayer_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(payloadParserFactories_, payloadSerializers_));
- xmppLayer_->onStreamStart.connect(boost::bind(&ClientSession::handleStreamStart, this));
- xmppLayer_->onElement.connect(boost::bind(&ClientSession::handleElement, this, _1));
- xmppLayer_->onError.connect(boost::bind(&ClientSession::setError, this, XMLError));
- xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1));
- xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1));
- connectionLayer_ = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection_));
- streamStack_ = new StreamStack(xmppLayer_, connectionLayer_);
-}
-
-void ClientSession::handleDisconnected(const boost::optional<Connection::Error>& error) {
- if (error) {
- switch (*error) {
- case Connection::ReadError:
- setError(ConnectionReadError);
- break;
- case Connection::WriteError:
- setError(ConnectionWriteError);
- break;
- }
- }
+ getXMPPLayer()->writeHeader(header);
}
void ClientSession::setCertificate(const PKCS12Certificate& certificate) {
certificate_ = certificate;
}
-void ClientSession::handleStreamStart() {
+void ClientSession::handleStreamStart(const ProtocolHeader&) {
checkState(WaitingForStreamStart);
state_ = Negotiating;
}
@@ -109,16 +69,16 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) {
state_ = Encrypting;
- xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
+ getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
}
else if (streamFeatures->hasAuthenticationMechanisms()) {
if (!certificate_.isNull()) {
if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
state_ = Authenticating;
- xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
+ getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
}
else {
- setError(ClientCertificateError);
+ finishSession(ClientCertificateError);
}
}
else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) {
@@ -126,7 +86,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
onNeedCredentials();
}
else {
- setError(NoSupportedAuthMechanismsError);
+ finishSession(NoSupportedAuthMechanismsError);
}
}
else {
@@ -134,7 +94,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
// Add a whitespace ping layer
whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer());
- streamStack_->addLayer(whitespacePingLayer_);
+ getStreamStack()->addLayer(whitespacePingLayer_);
if (streamFeatures->hasSession()) {
needSessionStart_ = true;
@@ -146,31 +106,31 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
if (!jid_.getResource().isEmpty()) {
resourceBind->setResource(jid_.getResource());
}
- xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
+ getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
}
else if (needSessionStart_) {
sendSessionStart();
}
else {
state_ = SessionStarted;
- onSessionStarted();
+ setInitialized();
}
}
}
else if (dynamic_cast<AuthSuccess*>(element.get())) {
checkState(Authenticating);
state_ = WaitingForStreamStart;
- xmppLayer_->resetParser();
+ getXMPPLayer()->resetParser();
sendStreamHeader();
}
else if (dynamic_cast<AuthFailure*>(element.get())) {
- setError(AuthenticationFailedError);
+ finishSession(AuthenticationFailedError);
}
else if (dynamic_cast<TLSProceed*>(element.get())) {
tlsLayer_ = tlsLayerFactory_->createTLSLayer();
- streamStack_->addLayer(tlsLayer_);
+ getStreamStack()->addLayer(tlsLayer_);
if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) {
- setError(ClientCertificateLoadError);
+ finishSession(ClientCertificateLoadError);
}
else {
tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this));
@@ -179,21 +139,21 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
}
}
else if (dynamic_cast<StartTLSFailure*>(element.get())) {
- setError(TLSError);
+ finishSession(TLSError);
}
else if (IQ* iq = dynamic_cast<IQ*>(element.get())) {
if (state_ == BindingResource) {
boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());
if (iq->getType() == IQ::Error && iq->getID() == "session-bind") {
- setError(ResourceBindError);
+ finishSession(ResourceBindError);
}
else if (!resourceBind) {
- setError(UnexpectedElementError);
+ finishSession(UnexpectedElementError);
}
else if (iq->getType() == IQ::Result) {
jid_ = resourceBind->getJID();
if (!jid_.isValid()) {
- setError(ResourceBindError);
+ finishSession(ResourceBindError);
}
if (needSessionStart_) {
sendSessionStart();
@@ -203,47 +163,51 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
}
}
else {
- setError(UnexpectedElementError);
+ finishSession(UnexpectedElementError);
}
}
else if (state_ == StartingSession) {
if (iq->getType() == IQ::Result) {
state_ = SessionStarted;
- onSessionStarted();
+ setInitialized();
}
else if (iq->getType() == IQ::Error) {
- setError(SessionStartError);
+ finishSession(SessionStartError);
}
else {
- setError(UnexpectedElementError);
+ finishSession(UnexpectedElementError);
}
}
else {
- setError(UnexpectedElementError);
+ finishSession(UnexpectedElementError);
}
}
else {
// FIXME Not correct?
state_ = SessionStarted;
- onSessionStarted();
+ setInitialized();
}
}
void ClientSession::sendSessionStart() {
state_ = StartingSession;
- xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
+ getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
}
-void ClientSession::setError(SessionError error) {
- assert(error != NoError);
- state_ = Error;
- error_ = error;
- onError(error);
+void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) {
+ if (error) {
+ assert(!error_);
+ state_ = Error;
+ error_ = error;
+ }
+ else {
+ state_ = Finished;
+ }
}
bool ClientSession::checkState(State state) {
if (state_ != state) {
- setError(UnexpectedElementError);
+ finishSession(UnexpectedElementError);
return false;
}
return true;
@@ -252,22 +216,17 @@ bool ClientSession::checkState(State state) {
void ClientSession::sendCredentials(const String& password) {
assert(WaitingForCredentials);
state_ = Authenticating;
- xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue())));
-}
-
-void ClientSession::sendElement(boost::shared_ptr<Element> element) {
- assert(SessionStarted);
- xmppLayer_->writeElement(element);
+ getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue())));
}
void ClientSession::handleTLSConnected() {
state_ = WaitingForStreamStart;
- xmppLayer_->resetParser();
+ getXMPPLayer()->resetParser();
sendStreamHeader();
}
void ClientSession::handleTLSError() {
- setError(TLSError);
+ finishSession(TLSError);
}
}
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index 50dae24..22e4a88 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -3,6 +3,7 @@
#include <boost/signal.hpp>
#include <boost/shared_ptr.hpp>
+#include "Swiften/Session/Session.h"
#include "Swiften/Base/String.h"
#include "Swiften/JID/JID.h"
#include "Swiften/Elements/Element.h"
@@ -21,7 +22,7 @@ namespace Swift {
class TLSLayer;
class WhitespacePingLayer;
- class ClientSession {
+ class ClientSession : public Session {
public:
enum State {
Initial,
@@ -34,21 +35,8 @@ namespace Swift {
BindingResource,
StartingSession,
SessionStarted,
- Error
- };
- enum SessionError {
- NoError,
- ConnectionReadError,
- ConnectionWriteError,
- XMLError,
- AuthenticationFailedError,
- NoSupportedAuthMechanismsError,
- UnexpectedElementError,
- ResourceBindError,
- SessionStartError,
- TLSError,
- ClientCertificateLoadError,
- ClientCertificateError
+ Error,
+ Finished
};
ClientSession(
@@ -57,13 +45,12 @@ namespace Swift {
TLSLayerFactory*,
PayloadParserFactoryCollection*,
PayloadSerializerCollection*);
- ~ClientSession();
State getState() const {
return state_;
}
- SessionError getError() const {
+ boost::optional<SessionError> getError() const {
return error_;
}
@@ -71,26 +58,18 @@ namespace Swift {
return jid_;
}
- void start();
- void stop();
-
void sendCredentials(const String& password);
- void sendElement(boost::shared_ptr<Element>);
void setCertificate(const PKCS12Certificate& certificate);
- protected:
- StreamStack* getStreamStack() const {
- return streamStack_;
- }
-
private:
- void initializeStreamStack();
void sendStreamHeader();
void sendSessionStart();
- void handleDisconnected(const boost::optional<Connection::Error>&);
- void handleElement(boost::shared_ptr<Element>);
- void handleStreamStart();
+ virtual void handleSessionStarted();
+ virtual void handleSessionFinished(const boost::optional<SessionError>& error);
+ virtual void handleElement(boost::shared_ptr<Element>);
+ virtual void handleStreamStart(const ProtocolHeader&);
+
void handleTLSConnected();
void handleTLSError();
@@ -98,26 +77,15 @@ namespace Swift {
bool checkState(State);
public:
- boost::signal<void ()> onSessionStarted;
- boost::signal<void (SessionError)> onError;
boost::signal<void ()> onNeedCredentials;
- boost::signal<void (boost::shared_ptr<Element>) > onElementReceived;
- boost::signal<void (const ByteArray&)> onDataWritten;
- boost::signal<void (const ByteArray&)> onDataRead;
private:
JID jid_;
TLSLayerFactory* tlsLayerFactory_;
- PayloadParserFactoryCollection* payloadParserFactories_;
- PayloadSerializerCollection* payloadSerializers_;
State state_;
- SessionError error_;
- boost::shared_ptr<Connection> connection_;
- boost::shared_ptr<XMPPLayer> xmppLayer_;
+ boost::optional<SessionError> error_;
boost::shared_ptr<TLSLayer> tlsLayer_;
- boost::shared_ptr<ConnectionLayer> connectionLayer_;
boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_;
- StreamStack* streamStack_;
bool needSessionStart_;
PKCS12Certificate certificate_;
};
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index 1e66019..c86442d 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -78,15 +78,15 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testConstructor() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
CPPUNIT_ASSERT_EQUAL(ClientSession::Initial, session->getState());
}
void testStart_Error() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState());
@@ -94,14 +94,14 @@ class ClientSessionTest : public CppUnit::TestFixture {
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, *session->getError());
}
void testStart_XMLError() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState());
@@ -109,29 +109,29 @@ class ClientSessionTest : public CppUnit::TestFixture {
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, *session->getError());
}
void testStartTLS_NoTLSSupport() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
tlsLayerFactory_->setTLSSupported(false);
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithStartTLS();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
}
void testStartTLS() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithStartTLS();
getMockServer()->expectStartTLS();
// FIXME: Test 'encrypting' state
getMockServer()->sendTLSProceed();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Encrypting, session->getState());
CPPUNIT_ASSERT(session->getTLSLayer());
@@ -147,42 +147,42 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testStartTLS_ServerError() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithStartTLS();
getMockServer()->expectStartTLS();
getMockServer()->sendTLSFailure();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
}
void testStartTLS_ConnectError() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithStartTLS();
getMockServer()->expectStartTLS();
getMockServer()->sendTLSProceed();
- session->start();
+ session->startSession();
processEvents();
session->getTLSLayer()->setError();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
}
void testStartTLS_ErrorAfterConnect() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithStartTLS();
getMockServer()->expectStartTLS();
getMockServer()->sendTLSProceed();
- session->start();
+ session->startSession();
processEvents();
getMockServer()->resetParser();
getMockServer()->expectStreamStart();
@@ -193,16 +193,16 @@ class ClientSessionTest : public CppUnit::TestFixture {
session->getTLSLayer()->setError();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
}
void testAuthenticate() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithAuthentication();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
CPPUNIT_ASSERT(needCredentials_);
@@ -218,11 +218,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testAuthenticate_Unauthorized() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithAuthentication();
- session->start();
+ session->startSession();
processEvents();
getMockServer()->expectAuth("me", "mypass");
@@ -231,30 +231,30 @@ class ClientSessionTest : public CppUnit::TestFixture {
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, *session->getError());
}
void testAuthenticate_NoValidAuthMechanisms() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, *session->getError());
}
void testResourceBind() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithResourceBind();
getMockServer()->expectResourceBind("Bar", "session-bind");
// FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::BindingResource, session->getState());
getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
- session->start();
+ session->startSession();
processEvents();
@@ -263,13 +263,13 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testResourceBind_ChangeResource() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithResourceBind();
getMockServer()->expectResourceBind("Bar", "session-bind");
getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind");
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -277,13 +277,13 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testResourceBind_EmptyResource() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithResourceBind();
getMockServer()->expectResourceBind("", "session-bind");
getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind");
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -291,21 +291,21 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testResourceBind_Error() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithResourceBind();
getMockServer()->expectResourceBind("", "session-bind");
getMockServer()->sendError("session-bind");
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, *session->getError());
}
void testSessionStart() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
@@ -313,7 +313,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
getMockServer()->expectSessionStart("session-start");
// FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::StartingSession, session->getState());
getMockServer()->sendSessionStartResponse("session-start");
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -321,21 +321,21 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testSessionStart_Error() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeaturesWithSession();
getMockServer()->expectSessionStart("session-start");
getMockServer()->sendError("session-start");
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
- CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, session->getError());
+ CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, *session->getError());
}
void testSessionStart_AfterResourceBind() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
@@ -344,7 +344,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
getMockServer()->expectSessionStart("session-start");
getMockServer()->sendSessionStartResponse("session-start");
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -352,21 +352,21 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testWhitespacePing() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeatures();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT(session->getWhitespacePingLayer());
}
void testReceiveElementAfterSessionStarted() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeatures();
- session->start();
+ session->startSession();
processEvents();
getMockServer()->expectMessage();
@@ -374,13 +374,13 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
void testSendElement() {
- std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+ boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
session->onElementReceived.connect(boost::bind(&ClientSessionTest::addReceivedElement, this, _1));
getMockServer()->expectStreamStart();
getMockServer()->sendStreamStart();
getMockServer()->sendStreamFeatures();
getMockServer()->sendMessage();
- session->start();
+ session->startSession();
processEvents();
CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size()));
@@ -684,8 +684,8 @@ class ClientSessionTest : public CppUnit::TestFixture {
}
};
- MockSession* createSession(const String& jid) {
- return new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_);
+ boost::shared_ptr<MockSession> createSession(const String& jid) {
+ return boost::shared_ptr<MockSession>(new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_));
}
diff --git a/Swiften/Makefile.inc b/Swiften/Makefile.inc
index 6fa2df8..d66a2b9 100644
--- a/Swiften/Makefile.inc
+++ b/Swiften/Makefile.inc
@@ -10,6 +10,7 @@ include Swiften/Serializer/Makefile.inc
include Swiften/Parser/Makefile.inc
include Swiften/MUC/Makefile.inc
include Swiften/Network/Makefile.inc
+include Swiften/Session/Makefile.inc
include Swiften/Client/Makefile.inc
include Swiften/TLS/Makefile.inc
include Swiften/SASL/Makefile.inc
diff --git a/Swiften/Session/Session.cpp b/Swiften/Session/Session.cpp
index 9ab8e4d..84354e5 100644
--- a/Swiften/Session/Session.cpp
+++ b/Swiften/Session/Session.cpp
@@ -28,12 +28,14 @@ void Session::startSession() {
void Session::finishSession() {
connection->disconnect();
- onSessionFinished(boost::optional<Error>());
+ handleSessionFinished(boost::optional<SessionError>());
+ onSessionFinished(boost::optional<SessionError>());
}
-void Session::finishSession(const Error& error) {
+void Session::finishSession(const SessionError& error) {
connection->disconnect();
- onSessionFinished(boost::optional<Error>(error));
+ handleSessionFinished(boost::optional<SessionError>(error));
+ onSessionFinished(boost::optional<SessionError>(error));
}
void Session::initializeStreamStack() {
@@ -41,23 +43,31 @@ void Session::initializeStreamStack() {
new XMPPLayer(payloadParserFactories, payloadSerializers));
xmppLayer->onStreamStart.connect(
boost::bind(&Session::handleStreamStart, this, _1));
- xmppLayer->onElement.connect(
- boost::bind(&Session::handleElement, this, _1));
+ xmppLayer->onElement.connect(boost::bind(&Session::handleElement, this, _1));
xmppLayer->onError.connect(
boost::bind(&Session::finishSession, this, XMLError));
+ xmppLayer->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1));
+ xmppLayer->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1));
connection->onDisconnected.connect(
boost::bind(&Session::handleDisconnected, shared_from_this(), _1));
connectionLayer = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection));
streamStack = new StreamStack(xmppLayer, connectionLayer);
}
-void Session::sendStanza(boost::shared_ptr<Stanza> stanza) {
+void Session::sendElement(boost::shared_ptr<Element> stanza) {
xmppLayer->writeElement(stanza);
}
void Session::handleDisconnected(const boost::optional<Connection::Error>& connectionError) {
if (connectionError) {
- finishSession(ConnectionError);
+ switch (*connectionError) {
+ case Connection::ReadError:
+ finishSession(ConnectionReadError);
+ break;
+ case Connection::WriteError:
+ finishSession(ConnectionWriteError);
+ break;
+ }
}
else {
finishSession();
diff --git a/Swiften/Session/Session.h b/Swiften/Session/Session.h
index bf8049a..b35179c 100644
--- a/Swiften/Session/Session.h
+++ b/Swiften/Session/Session.h
@@ -14,7 +14,7 @@ namespace Swift {
class ProtocolHeader;
class StreamStack;
class JID;
- class Stanza;
+ class Element;
class ByteArray;
class PayloadParserFactoryCollection;
class PayloadSerializerCollection;
@@ -22,9 +22,18 @@ namespace Swift {
class Session : public boost::enable_shared_from_this<Session> {
public:
- enum Error {
- ConnectionError,
- XMLError
+ enum SessionError {
+ ConnectionReadError,
+ ConnectionWriteError,
+ XMLError,
+ AuthenticationFailedError,
+ NoSupportedAuthMechanismsError,
+ UnexpectedElementError,
+ ResourceBindError,
+ SessionStartError,
+ TLSError,
+ ClientCertificateLoadError,
+ ClientCertificateError
};
Session(
@@ -35,18 +44,19 @@ namespace Swift {
void startSession();
void finishSession();
- void sendStanza(boost::shared_ptr<Stanza>);
+ void sendElement(boost::shared_ptr<Element>);
- boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived;
+ boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
boost::signal<void ()> onSessionStarted;
- boost::signal<void (const boost::optional<Error>&)> onSessionFinished;
+ boost::signal<void (const boost::optional<SessionError>&)> onSessionFinished;
boost::signal<void (const ByteArray&)> onDataWritten;
boost::signal<void (const ByteArray&)> onDataRead;
protected:
- void finishSession(const Error&);
+ void finishSession(const SessionError&);
virtual void handleSessionStarted() {}
+ virtual void handleSessionFinished(const boost::optional<SessionError>&) {}
virtual void handleElement(boost::shared_ptr<Element>) = 0;
virtual void handleStreamStart(const ProtocolHeader&) = 0;
@@ -56,11 +66,17 @@ namespace Swift {
return xmppLayer;
}
+ StreamStack* getStreamStack() const {
+ return streamStack;
+ }
+
void setInitialized();
bool isInitialized() const {
return initialized;
}
+ void setFinished();
+
private:
void handleDisconnected(const boost::optional<Connection::Error>& error);