From a6fcd9e7aa12c5e00c61ff809e81fba14babd70c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Sun, 19 Jul 2009 15:21:38 +0200
Subject: Factor out common session stuff into Session class.


diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp
index a38416a..95f6c0f 100644
--- a/Swiften/Client/Client.cpp
+++ b/Swiften/Client/Client.cpp
@@ -12,21 +12,17 @@
 namespace Swift {
 
 Client::Client(const JID& jid, const String& password) :
-		IQRouter(this), jid_(jid), password_(password), session_(0) {
+		IQRouter(this), jid_(jid), password_(password) {
 	connectionFactory_ = new BoostConnectionFactory(&boostIOServiceThread_.getIOService());
 	tlsLayerFactory_ = new PlatformTLSLayerFactory();
 }
 
 Client::~Client() {
-	delete session_;
 	delete tlsLayerFactory_;
 	delete connectionFactory_;
 }
 
 void Client::connect() {
-	delete session_;
-	session_ = 0;
-
 	DomainNameResolver resolver;
 	try {
 		HostAddressPort remote = resolver.resolve(jid_.getDomain().getUTF8String());
@@ -44,23 +40,23 @@ void Client::handleConnectionConnectFinished(bool error) {
 		onError(ClientError::ConnectionError);
 	}
 	else {
-		session_ = new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_);
+		session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_));
 		if (!certificate_.isEmpty()) {
 			session_->setCertificate(PKCS12Certificate(certificate_, password_));
 		}
 		session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected)));
-		session_->onError.connect(boost::bind(&Client::handleSessionError, this, _1));
+		session_->onSessionFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1));
 		session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this));
 		session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1));
 		session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1));
 		session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1));
-		session_->start();
+		session_->startSession();
 	}
 }
 
 void Client::disconnect() {
 	if (session_) {
-		session_->stop();
+		session_->finishSession();
 	}
 }
 
@@ -108,47 +104,46 @@ void Client::setCertificate(const String& certificate) {
 	certificate_ = certificate;
 }
 
-void Client::handleSessionError(ClientSession::SessionError error) {
-	ClientError clientError;
-	switch (error) {
-		case ClientSession::NoError: 
-			assert(false);
-			break;
-		case ClientSession::ConnectionReadError:
-			clientError = ClientError(ClientError::ConnectionReadError);
-			break;
-		case ClientSession::ConnectionWriteError:
-			clientError = ClientError(ClientError::ConnectionWriteError);
-			break;
-		case ClientSession::XMLError:
-			clientError = ClientError(ClientError::XMLError);
-			break;
-		case ClientSession::AuthenticationFailedError:
-			clientError = ClientError(ClientError::AuthenticationFailedError);
-			break;
-		case ClientSession::NoSupportedAuthMechanismsError:
-			clientError = ClientError(ClientError::NoSupportedAuthMechanismsError);
-			break;
-		case ClientSession::UnexpectedElementError:
-			clientError = ClientError(ClientError::UnexpectedElementError);
-			break;
-		case ClientSession::ResourceBindError:
-			clientError = ClientError(ClientError::ResourceBindError);
-			break;
-		case ClientSession::SessionStartError:
-			clientError = ClientError(ClientError::SessionStartError);
-			break;
-		case ClientSession::TLSError:
-			clientError = ClientError(ClientError::TLSError);
-			break;
-		case ClientSession::ClientCertificateLoadError:
-			clientError = ClientError(ClientError::ClientCertificateLoadError);
-			break;
-		case ClientSession::ClientCertificateError:
-			clientError = ClientError(ClientError::ClientCertificateError);
-			break;
+void Client::handleSessionFinished(const boost::optional<Session::SessionError>& error) {
+	if (error) {
+		ClientError clientError;
+		switch (*error) {
+			case Session::ConnectionReadError:
+				clientError = ClientError(ClientError::ConnectionReadError);
+				break;
+			case Session::ConnectionWriteError:
+				clientError = ClientError(ClientError::ConnectionWriteError);
+				break;
+			case Session::XMLError:
+				clientError = ClientError(ClientError::XMLError);
+				break;
+			case Session::AuthenticationFailedError:
+				clientError = ClientError(ClientError::AuthenticationFailedError);
+				break;
+			case Session::NoSupportedAuthMechanismsError:
+				clientError = ClientError(ClientError::NoSupportedAuthMechanismsError);
+				break;
+			case Session::UnexpectedElementError:
+				clientError = ClientError(ClientError::UnexpectedElementError);
+				break;
+			case Session::ResourceBindError:
+				clientError = ClientError(ClientError::ResourceBindError);
+				break;
+			case Session::SessionStartError:
+				clientError = ClientError(ClientError::SessionStartError);
+				break;
+			case Session::TLSError:
+				clientError = ClientError(ClientError::TLSError);
+				break;
+			case Session::ClientCertificateLoadError:
+				clientError = ClientError(ClientError::ClientCertificateLoadError);
+				break;
+			case Session::ClientCertificateError:
+				clientError = ClientError(ClientError::ClientCertificateError);
+				break;
+		}
+		onError(clientError);
 	}
-	onError(clientError);
 }
 
 void Client::handleNeedCredentials() {
diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h
index 48b76d9..1561c75 100644
--- a/Swiften/Client/Client.h
+++ b/Swiften/Client/Client.h
@@ -47,7 +47,7 @@ namespace Swift {
 			void send(boost::shared_ptr<Stanza>);
 			virtual String getNewIQID();
 			void handleElement(boost::shared_ptr<Element>);
-			void handleSessionError(ClientSession::SessionError error);
+			void handleSessionFinished(const boost::optional<Session::SessionError>& error);
 			void handleNeedCredentials();
 			void handleDataRead(const ByteArray&);
 			void handleDataWritten(const ByteArray&);
@@ -61,7 +61,7 @@ namespace Swift {
 			TLSLayerFactory* tlsLayerFactory_;
 			FullPayloadParserFactoryCollection payloadParserFactories_;
 			FullPayloadSerializerCollection payloadSerializers_;
-			ClientSession* session_;
+			boost::shared_ptr<ClientSession> session_;
 			boost::shared_ptr<Connection> connection_;
 			String certificate_;
 	};
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index 11317e8..4fcf1f8 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -30,70 +30,30 @@ ClientSession::ClientSession(
 		TLSLayerFactory* tlsLayerFactory, 
 		PayloadParserFactoryCollection* payloadParserFactories, 
 		PayloadSerializerCollection* payloadSerializers) : 
+			Session(connection, payloadParserFactories, payloadSerializers),
 			jid_(jid), 
 			tlsLayerFactory_(tlsLayerFactory),
-			payloadParserFactories_(payloadParserFactories),
-			payloadSerializers_(payloadSerializers),
 			state_(Initial), 
-			error_(NoError),
-			connection_(connection),
-			streamStack_(0),
 			needSessionStart_(false) {
 }
 
-ClientSession::~ClientSession() {
-	delete streamStack_;
-}
-
-void ClientSession::start() {
+void ClientSession::handleSessionStarted() {
 	assert(state_ == Initial);
-
-	connection_->onDisconnected.connect(boost::bind(&ClientSession::handleDisconnected, this, _1));
-	initializeStreamStack();
 	state_ = WaitingForStreamStart;
 	sendStreamHeader();
 }
 
-void ClientSession::stop() {
-	// TODO: Send end stream header if applicable
-	connection_->disconnect();
-}
-
 void ClientSession::sendStreamHeader() {
 	ProtocolHeader header;
 	header.setTo(jid_.getDomain());
-	xmppLayer_->writeHeader(header);
-}
-
-void ClientSession::initializeStreamStack() {
-	xmppLayer_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(payloadParserFactories_, payloadSerializers_));
-	xmppLayer_->onStreamStart.connect(boost::bind(&ClientSession::handleStreamStart, this));
-	xmppLayer_->onElement.connect(boost::bind(&ClientSession::handleElement, this, _1));
-	xmppLayer_->onError.connect(boost::bind(&ClientSession::setError, this, XMLError));
-	xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1));
-	xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1));
-	connectionLayer_ = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection_));
-	streamStack_ = new StreamStack(xmppLayer_, connectionLayer_);
-}
-
-void ClientSession::handleDisconnected(const boost::optional<Connection::Error>& error) {
-	if (error) {
-		switch (*error) {
-			case Connection::ReadError:
-				setError(ConnectionReadError);
-				break;
-			case Connection::WriteError:
-				setError(ConnectionWriteError);
-				break;
-		}
-	}
+	getXMPPLayer()->writeHeader(header);
 }
 
 void ClientSession::setCertificate(const PKCS12Certificate& certificate) {
 	certificate_ = certificate;
 }
 
-void ClientSession::handleStreamStart() {
+void ClientSession::handleStreamStart(const ProtocolHeader&) {
 	checkState(WaitingForStreamStart);
 	state_ = Negotiating;
 }
@@ -109,16 +69,16 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 
 		if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) {
 			state_ = Encrypting;
-			xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
+			getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
 		}
 		else if (streamFeatures->hasAuthenticationMechanisms()) {
 			if (!certificate_.isNull()) {
 				if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
 					state_ = Authenticating;
-					xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
+					getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
 				}
 				else {
-					setError(ClientCertificateError);
+					finishSession(ClientCertificateError);
 				}
 			}
 			else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) {
@@ -126,7 +86,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 				onNeedCredentials();
 			}
 			else {
-				setError(NoSupportedAuthMechanismsError);
+				finishSession(NoSupportedAuthMechanismsError);
 			}
 		}
 		else {
@@ -134,7 +94,7 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 
 			// Add a whitespace ping layer
 			whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer());
-			streamStack_->addLayer(whitespacePingLayer_);
+			getStreamStack()->addLayer(whitespacePingLayer_);
 
 			if (streamFeatures->hasSession()) {
 				needSessionStart_ = true;
@@ -146,31 +106,31 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 				if (!jid_.getResource().isEmpty()) {
 					resourceBind->setResource(jid_.getResource());
 				}
-				xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
+				getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
 			}
 			else if (needSessionStart_) {
 				sendSessionStart();
 			}
 			else {
 				state_ = SessionStarted;
-				onSessionStarted();
+				setInitialized();
 			}
 		}
 	}
 	else if (dynamic_cast<AuthSuccess*>(element.get())) {
 		checkState(Authenticating);
 		state_ = WaitingForStreamStart;
-		xmppLayer_->resetParser();
+		getXMPPLayer()->resetParser();
 		sendStreamHeader();
 	}
 	else if (dynamic_cast<AuthFailure*>(element.get())) {
-		setError(AuthenticationFailedError);
+		finishSession(AuthenticationFailedError);
 	}
 	else if (dynamic_cast<TLSProceed*>(element.get())) {
 		tlsLayer_ = tlsLayerFactory_->createTLSLayer();
-		streamStack_->addLayer(tlsLayer_);
+		getStreamStack()->addLayer(tlsLayer_);
 		if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) {
-			setError(ClientCertificateLoadError);
+			finishSession(ClientCertificateLoadError);
 		}
 		else {
 			tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this));
@@ -179,21 +139,21 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 		}
 	}
 	else if (dynamic_cast<StartTLSFailure*>(element.get())) {
-		setError(TLSError);
+		finishSession(TLSError);
 	}
 	else if (IQ* iq = dynamic_cast<IQ*>(element.get())) {
 		if (state_ == BindingResource) {
 			boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());
 			if (iq->getType() == IQ::Error && iq->getID() == "session-bind") {
-				setError(ResourceBindError);
+				finishSession(ResourceBindError);
 			}
 			else if (!resourceBind) {
-				setError(UnexpectedElementError);
+				finishSession(UnexpectedElementError);
 			}
 			else if (iq->getType() == IQ::Result) {
 				jid_ = resourceBind->getJID();
 				if (!jid_.isValid()) {
-					setError(ResourceBindError);
+					finishSession(ResourceBindError);
 				}
 				if (needSessionStart_) {
 					sendSessionStart();
@@ -203,47 +163,51 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 				}
 			}
 			else {
-				setError(UnexpectedElementError);
+				finishSession(UnexpectedElementError);
 			}
 		}
 		else if (state_ == StartingSession) {
 			if (iq->getType() == IQ::Result) {
 				state_ = SessionStarted;
-				onSessionStarted();
+				setInitialized();
 			}
 			else if (iq->getType() == IQ::Error) {
-				setError(SessionStartError);
+				finishSession(SessionStartError);
 			}
 			else {
-				setError(UnexpectedElementError);
+				finishSession(UnexpectedElementError);
 			}
 		}
 		else {
-			setError(UnexpectedElementError);
+			finishSession(UnexpectedElementError);
 		}
 	}
 	else {
 		// FIXME Not correct?
 		state_ = SessionStarted;
-		onSessionStarted();
+		setInitialized();
 	}
 }
 
 void ClientSession::sendSessionStart() {
 	state_ = StartingSession;
-	xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
+	getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
 }
 
-void ClientSession::setError(SessionError error) {
-	assert(error != NoError);
-	state_ = Error;
-	error_ = error;
-	onError(error);
+void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) {
+	if (error) {
+		assert(!error_);
+		state_ = Error;
+		error_ = error;
+	}
+	else {
+		state_ = Finished;
+	}
 }
 
 bool ClientSession::checkState(State state) {
 	if (state_ != state) {
-		setError(UnexpectedElementError);
+		finishSession(UnexpectedElementError);
 		return false;
 	}
 	return true;
@@ -252,22 +216,17 @@ bool ClientSession::checkState(State state) {
 void ClientSession::sendCredentials(const String& password) {
 	assert(WaitingForCredentials);
 	state_ = Authenticating;
-	xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue())));
-}
-
-void ClientSession::sendElement(boost::shared_ptr<Element> element) {
-	assert(SessionStarted);
-	xmppLayer_->writeElement(element);
+	getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue())));
 }
 
 void ClientSession::handleTLSConnected() {
 	state_ = WaitingForStreamStart;
-	xmppLayer_->resetParser();
+	getXMPPLayer()->resetParser();
 	sendStreamHeader();
 }
 
 void ClientSession::handleTLSError() {
-	setError(TLSError);
+	finishSession(TLSError);
 }
 
 }
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index 50dae24..22e4a88 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -3,6 +3,7 @@
 #include <boost/signal.hpp>
 #include <boost/shared_ptr.hpp>
 
+#include "Swiften/Session/Session.h"
 #include "Swiften/Base/String.h"
 #include "Swiften/JID/JID.h"
 #include "Swiften/Elements/Element.h"
@@ -21,7 +22,7 @@ namespace Swift {
 	class TLSLayer;
 	class WhitespacePingLayer;
 
-	class ClientSession {
+	class ClientSession : public Session {
 		public:
 			enum State {
 				Initial,
@@ -34,21 +35,8 @@ namespace Swift {
 				BindingResource,
 				StartingSession,
 				SessionStarted,
-				Error
-			};
-			enum SessionError {
-				NoError,
-				ConnectionReadError,
-				ConnectionWriteError,
-				XMLError,
-				AuthenticationFailedError,
-				NoSupportedAuthMechanismsError,
-				UnexpectedElementError,
-				ResourceBindError,
-				SessionStartError,
-				TLSError,
-				ClientCertificateLoadError,
-				ClientCertificateError
+				Error,
+				Finished
 			};
 
 			ClientSession(
@@ -57,13 +45,12 @@ namespace Swift {
 					TLSLayerFactory*, 
 					PayloadParserFactoryCollection*, 
 					PayloadSerializerCollection*);
-			~ClientSession();
 
 			State getState() const {
 				return state_;
 			}
 
-			SessionError getError() const {
+			boost::optional<SessionError> getError() const {
 				return error_;
 			}
 
@@ -71,26 +58,18 @@ namespace Swift {
 				return jid_;
 			}
 
-			void start();
-			void stop();
-
 			void sendCredentials(const String& password);
-			void sendElement(boost::shared_ptr<Element>);
 			void setCertificate(const PKCS12Certificate& certificate);
 
-		protected:
-			StreamStack* getStreamStack() const {
-				return streamStack_;
-			}
-
 		private:
-			void initializeStreamStack();
 			void sendStreamHeader();
 			void sendSessionStart();
 
-			void handleDisconnected(const boost::optional<Connection::Error>&);
-			void handleElement(boost::shared_ptr<Element>);
-			void handleStreamStart();
+			virtual void handleSessionStarted();
+			virtual void handleSessionFinished(const boost::optional<SessionError>& error);
+			virtual void handleElement(boost::shared_ptr<Element>);
+			virtual void handleStreamStart(const ProtocolHeader&);
+
 			void handleTLSConnected();
 			void handleTLSError();
 
@@ -98,26 +77,15 @@ namespace Swift {
 			bool checkState(State);
 
 		public:
-			boost::signal<void ()> onSessionStarted;
-			boost::signal<void (SessionError)> onError;
 			boost::signal<void ()> onNeedCredentials;
-			boost::signal<void (boost::shared_ptr<Element>) > onElementReceived;
-			boost::signal<void (const ByteArray&)> onDataWritten;
-			boost::signal<void (const ByteArray&)> onDataRead;
 		
 		private:
 			JID jid_;
 			TLSLayerFactory* tlsLayerFactory_;
-			PayloadParserFactoryCollection* payloadParserFactories_;
-			PayloadSerializerCollection* payloadSerializers_;
 			State state_;
-			SessionError error_;
-			boost::shared_ptr<Connection> connection_;
-			boost::shared_ptr<XMPPLayer> xmppLayer_;
+			boost::optional<SessionError> error_;
 			boost::shared_ptr<TLSLayer> tlsLayer_;
-			boost::shared_ptr<ConnectionLayer> connectionLayer_;
 			boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_;
-			StreamStack* streamStack_;
 			bool needSessionStart_;
 			PKCS12Certificate certificate_;
 	};
diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp
index 1e66019..c86442d 100644
--- a/Swiften/Client/UnitTest/ClientSessionTest.cpp
+++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp
@@ -78,15 +78,15 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testConstructor() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Initial, session->getState());
 		}
 
 		void testStart_Error() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 
 			getMockServer()->expectStreamStart();
-			session->start();
+			session->startSession();
 			processEvents();
 			CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState());
 
@@ -94,14 +94,14 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, *session->getError());
 		}
 
 		void testStart_XMLError() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 
 			getMockServer()->expectStreamStart();
-			session->start();
+			session->startSession();
 			processEvents();
 			CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState());
 
@@ -109,29 +109,29 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, *session->getError());
 		}
 
 		void testStartTLS_NoTLSSupport() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			tlsLayerFactory_->setTLSSupported(false);
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithStartTLS();
-			session->start();
+			session->startSession();
 			processEvents();
 			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
 		}
 
 		void testStartTLS() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithStartTLS();
 			getMockServer()->expectStartTLS();
 			// FIXME: Test 'encrypting' state
 			getMockServer()->sendTLSProceed();
-			session->start();
+			session->startSession();
 			processEvents();
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Encrypting, session->getState());
 			CPPUNIT_ASSERT(session->getTLSLayer());
@@ -147,42 +147,42 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testStartTLS_ServerError() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithStartTLS();
 			getMockServer()->expectStartTLS();
 			getMockServer()->sendTLSFailure();
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
 		}
 
 		void testStartTLS_ConnectError() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithStartTLS();
 			getMockServer()->expectStartTLS();
 			getMockServer()->sendTLSProceed();
-			session->start();
+			session->startSession();
 			processEvents();
 			session->getTLSLayer()->setError();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
 		}
 
 		void testStartTLS_ErrorAfterConnect() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithStartTLS();
 			getMockServer()->expectStartTLS();
 			getMockServer()->sendTLSProceed();
-			session->start();
+			session->startSession();
 			processEvents();
 			getMockServer()->resetParser();
 			getMockServer()->expectStreamStart();
@@ -193,16 +193,16 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			session->getTLSLayer()->setError();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
 		}
 
 		void testAuthenticate() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithAuthentication();
-			session->start();
+			session->startSession();
 			processEvents();
 			CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
 			CPPUNIT_ASSERT(needCredentials_);
@@ -218,11 +218,11 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testAuthenticate_Unauthorized() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithAuthentication();
-			session->start();
+			session->startSession();
 			processEvents();
 
 			getMockServer()->expectAuth("me", "mypass");
@@ -231,30 +231,30 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, *session->getError());
 		}
 
 		void testAuthenticate_NoValidAuthMechanisms() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication();
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, *session->getError());
 		}
 
 		void testResourceBind() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithResourceBind();
 			getMockServer()->expectResourceBind("Bar", "session-bind");
 			// FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::BindingResource, session->getState());
 			getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
-			session->start();
+			session->startSession();
 
 			processEvents();
 
@@ -263,13 +263,13 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testResourceBind_ChangeResource() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithResourceBind();
 			getMockServer()->expectResourceBind("Bar", "session-bind");
 			getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind");
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -277,13 +277,13 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testResourceBind_EmptyResource() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithResourceBind();
 			getMockServer()->expectResourceBind("", "session-bind");
 			getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind");
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -291,21 +291,21 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testResourceBind_Error() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithResourceBind();
 			getMockServer()->expectResourceBind("", "session-bind");
 			getMockServer()->sendError("session-bind");
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, *session->getError());
 		}
 
 		void testSessionStart() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
@@ -313,7 +313,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			getMockServer()->expectSessionStart("session-start");
 			// FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::StartingSession, session->getState());
 			getMockServer()->sendSessionStartResponse("session-start");
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -321,21 +321,21 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testSessionStart_Error() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeaturesWithSession();
 			getMockServer()->expectSessionStart("session-start");
 			getMockServer()->sendError("session-start");
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
-			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, session->getError());
+			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, *session->getError());
 		}
 
 		void testSessionStart_AfterResourceBind() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
@@ -344,7 +344,7 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
 			getMockServer()->expectSessionStart("session-start");
 			getMockServer()->sendSessionStartResponse("session-start");
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
@@ -352,21 +352,21 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testWhitespacePing() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeatures();
-			session->start();
+			session->startSession();
 			processEvents();
 			CPPUNIT_ASSERT(session->getWhitespacePingLayer());
 		}
 
 		void testReceiveElementAfterSessionStarted() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeatures();
-			session->start();
+			session->startSession();
 			processEvents();
 
 			getMockServer()->expectMessage();
@@ -374,13 +374,13 @@ class ClientSessionTest : public CppUnit::TestFixture {
 		}
 
 		void testSendElement() {
-			std::auto_ptr<MockSession> session(createSession("me@foo.com/Bar"));
+			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
 			session->onElementReceived.connect(boost::bind(&ClientSessionTest::addReceivedElement, this, _1));
 			getMockServer()->expectStreamStart();
 			getMockServer()->sendStreamStart();
 			getMockServer()->sendStreamFeatures();
 			getMockServer()->sendMessage();
-			session->start();
+			session->startSession();
 			processEvents();
 
 			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size()));
@@ -684,8 +684,8 @@ class ClientSessionTest : public CppUnit::TestFixture {
 			}
 		};
 
-		MockSession* createSession(const String& jid) {
-			return new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_);
+		boost::shared_ptr<MockSession> createSession(const String& jid) {
+			return boost::shared_ptr<MockSession>(new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_));
 		}
 
 
diff --git a/Swiften/Makefile.inc b/Swiften/Makefile.inc
index 6fa2df8..d66a2b9 100644
--- a/Swiften/Makefile.inc
+++ b/Swiften/Makefile.inc
@@ -10,6 +10,7 @@ include Swiften/Serializer/Makefile.inc
 include Swiften/Parser/Makefile.inc
 include Swiften/MUC/Makefile.inc
 include Swiften/Network/Makefile.inc
+include Swiften/Session/Makefile.inc
 include Swiften/Client/Makefile.inc
 include Swiften/TLS/Makefile.inc
 include Swiften/SASL/Makefile.inc
diff --git a/Swiften/Session/Session.cpp b/Swiften/Session/Session.cpp
index 9ab8e4d..84354e5 100644
--- a/Swiften/Session/Session.cpp
+++ b/Swiften/Session/Session.cpp
@@ -28,12 +28,14 @@ void Session::startSession() {
 
 void Session::finishSession() {
 	connection->disconnect();
-	onSessionFinished(boost::optional<Error>());
+	handleSessionFinished(boost::optional<SessionError>());
+	onSessionFinished(boost::optional<SessionError>());
 }
 
-void Session::finishSession(const Error& error) {
+void Session::finishSession(const SessionError& error) {
 	connection->disconnect();
-	onSessionFinished(boost::optional<Error>(error));
+	handleSessionFinished(boost::optional<SessionError>(error));
+	onSessionFinished(boost::optional<SessionError>(error));
 }
 
 void Session::initializeStreamStack() {
@@ -41,23 +43,31 @@ void Session::initializeStreamStack() {
 			new XMPPLayer(payloadParserFactories, payloadSerializers));
 	xmppLayer->onStreamStart.connect(
 			boost::bind(&Session::handleStreamStart, this, _1));
-	xmppLayer->onElement.connect(
-			boost::bind(&Session::handleElement, this, _1));
+	xmppLayer->onElement.connect(boost::bind(&Session::handleElement, this, _1));
 	xmppLayer->onError.connect(
 			boost::bind(&Session::finishSession, this, XMLError));
+	xmppLayer->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1));
+	xmppLayer->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1));
 	connection->onDisconnected.connect(
 			boost::bind(&Session::handleDisconnected, shared_from_this(), _1));
 	connectionLayer = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection));
 	streamStack = new StreamStack(xmppLayer, connectionLayer);
 }
 
-void Session::sendStanza(boost::shared_ptr<Stanza> stanza) {
+void Session::sendElement(boost::shared_ptr<Element> stanza) {
 	xmppLayer->writeElement(stanza);
 }
 
 void Session::handleDisconnected(const boost::optional<Connection::Error>& connectionError) {
 	if (connectionError) {
-		finishSession(ConnectionError);
+		switch (*connectionError) {
+			case Connection::ReadError:
+				finishSession(ConnectionReadError);
+				break;
+			case Connection::WriteError:
+				finishSession(ConnectionWriteError);
+				break;
+		}
 	}
 	else {
 		finishSession();
diff --git a/Swiften/Session/Session.h b/Swiften/Session/Session.h
index bf8049a..b35179c 100644
--- a/Swiften/Session/Session.h
+++ b/Swiften/Session/Session.h
@@ -14,7 +14,7 @@ namespace Swift {
 	class ProtocolHeader;
 	class StreamStack;
 	class JID;
-	class Stanza;
+	class Element;
 	class ByteArray;
 	class PayloadParserFactoryCollection;
 	class PayloadSerializerCollection;
@@ -22,9 +22,18 @@ namespace Swift {
 
 	class Session : public boost::enable_shared_from_this<Session> {
 		public:
-			enum Error {
-				ConnectionError,
-				XMLError
+			enum SessionError {
+				ConnectionReadError,
+				ConnectionWriteError,
+				XMLError,
+				AuthenticationFailedError,
+				NoSupportedAuthMechanismsError,
+				UnexpectedElementError,
+				ResourceBindError,
+				SessionStartError,
+				TLSError,
+				ClientCertificateLoadError,
+				ClientCertificateError
 			};
 
 			Session(
@@ -35,18 +44,19 @@ namespace Swift {
 
 			void startSession();
 			void finishSession();
-			void sendStanza(boost::shared_ptr<Stanza>);
+			void sendElement(boost::shared_ptr<Element>);
 
-			boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived;
+			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
 			boost::signal<void ()> onSessionStarted;
-			boost::signal<void (const boost::optional<Error>&)> onSessionFinished;
+			boost::signal<void (const boost::optional<SessionError>&)> onSessionFinished;
 			boost::signal<void (const ByteArray&)> onDataWritten;
 			boost::signal<void (const ByteArray&)> onDataRead;
 
 		protected:
-			void finishSession(const Error&);
+			void finishSession(const SessionError&);
 
 			virtual void handleSessionStarted() {}
+			virtual void handleSessionFinished(const boost::optional<SessionError>&) {}
 			virtual void handleElement(boost::shared_ptr<Element>) = 0;
 			virtual void handleStreamStart(const ProtocolHeader&) = 0;
 
@@ -56,11 +66,17 @@ namespace Swift {
 				return xmppLayer;
 			}
 
+			StreamStack* getStreamStack() const {
+				return streamStack;
+			}
+
 			void setInitialized();
 			bool isInitialized() const { 
 				return initialized; 
 			}
 
+			void setFinished();
+
 		private:
 			void handleDisconnected(const boost::optional<Connection::Error>& error);
 
-- 
cgit v0.10.2-6-g49f6