From c682941cd230ad8caed3f3d457de3dc0cd7172d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Mon, 9 Nov 2009 21:34:44 +0100
Subject: Refactoring Client.


diff --git a/Swiften/Client/Client.cpp b/Swiften/Client/Client.cpp
index b3bea0d..adda6af 100644
--- a/Swiften/Client/Client.cpp
+++ b/Swiften/Client/Client.cpp
@@ -50,25 +50,26 @@ void Client::handleConnectionConnectFinished(bool error) {
 		assert(!sessionStream_);
 		sessionStream_ = boost::shared_ptr<BasicSessionStream>(new BasicSessionStream(connection_, &payloadParserFactories_, &payloadSerializers_, tlsLayerFactory_));
 		sessionStream_->initialize();
-
-		session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_));
 		if (!certificate_.isEmpty()) {
-			session_->setCertificate(PKCS12Certificate(certificate_, password_));
+			sessionStream_->setTLSCertificate(PKCS12Certificate(certificate_, password_));
 		}
-		session_->onSessionStarted.connect(boost::bind(boost::ref(onConnected)));
-		session_->onSessionFinished.connect(boost::bind(&Client::handleSessionFinished, this, _1));
-		session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, this));
-		session_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1));
-		session_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1));
-		session_->onElementReceived.connect(boost::bind(&Client::handleElement, this, _1));
-		session_->startSession();
+		//sessionStream_->onDataRead.connect(boost::bind(&Client::handleDataRead, this, _1));
+		//sessionStream_->onDataWritten.connect(boost::bind(&Client::handleDataWritten, this, _1));
+
+		session_ = boost::shared_ptr<ClientSession>(new ClientSession(jid_, sessionStream_));
+		session_->onInitialized.connect(boost::bind(boost::ref(onConnected)));
+		session_->onFinished.connect(boost::bind(&Client::handleSessionFinished, shared_from_this(), _1));
+		session_->onNeedCredentials.connect(boost::bind(&Client::handleNeedCredentials, shared_from_this()));
+		session_->onElementReceived.connect(boost::bind(&Client::handleElement, shared_from_this(), _1));
+		session_->start();
 	}
 }
 
 void Client::disconnect() {
-	if (session_) {
-		session_->finishSession();
-	}
+	// TODO
+	//if (session_) {
+	//	session_->finishSession();
+	//}
 }
 
 void Client::send(boost::shared_ptr<Stanza> stanza) {
@@ -115,9 +116,10 @@ void Client::setCertificate(const String& certificate) {
 	certificate_ = certificate;
 }
 
-void Client::handleSessionFinished(const boost::optional<Session::SessionError>& error) {
+void Client::handleSessionFinished(boost::shared_ptr<Error> error) {
 	if (error) {
 		ClientError clientError;
+		/*
 		switch (*error) {
 			case Session::ConnectionReadError:
 				clientError = ClientError(ClientError::ConnectionReadError);
@@ -153,6 +155,7 @@ void Client::handleSessionFinished(const boost::optional<Session::SessionError>&
 				clientError = ClientError(ClientError::ClientCertificateError);
 				break;
 		}
+		*/
 		onError(clientError);
 	}
 }
diff --git a/Swiften/Client/Client.h b/Swiften/Client/Client.h
index 0e68f55..27c2458 100644
--- a/Swiften/Client/Client.h
+++ b/Swiften/Client/Client.h
@@ -3,7 +3,9 @@
 
 #include <boost/signals.hpp>
 #include <boost/shared_ptr.hpp>
+#include <boost/enable_shared_from_this.hpp>
 
+#include "Swiften/Base/Error.h"
 #include "Swiften/Client/ClientSession.h"
 #include "Swiften/Client/ClientError.h"
 #include "Swiften/Elements/Presence.h"
@@ -22,7 +24,7 @@ namespace Swift {
 	class ClientSession;
 	class BasicSessionStream;
 
-	class Client : public StanzaChannel, public IQRouter, public boost::bsignals::trackable {
+	class Client : public StanzaChannel, public IQRouter, public boost::enable_shared_from_this<Client> {
 		public:
 			Client(const JID& jid, const String& password);
 			~Client();
@@ -39,7 +41,7 @@ namespace Swift {
 			virtual void sendPresence(boost::shared_ptr<Presence>);
 
 		public:
-			boost::signal<void (ClientError)> onError;
+			boost::signal<void (const ClientError&)> onError;
 			boost::signal<void ()> onConnected;
 			boost::signal<void (const String&)> onDataRead;
 			boost::signal<void (const String&)> onDataWritten;
@@ -49,7 +51,7 @@ namespace Swift {
 			void send(boost::shared_ptr<Stanza>);
 			virtual String getNewIQID();
 			void handleElement(boost::shared_ptr<Element>);
-			void handleSessionFinished(const boost::optional<Session::SessionError>& error);
+			void handleSessionFinished(boost::shared_ptr<Error>);
 			void handleNeedCredentials();
 			void handleDataRead(const ByteArray&);
 			void handleDataWritten(const ByteArray&);
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp
index a0e1289..ed5d27d 100644
--- a/Swiften/Client/ClientSession.cpp
+++ b/Swiften/Client/ClientSession.cpp
@@ -2,13 +2,7 @@
 
 #include <boost/bind.hpp>
 
-#include "Swiften/Network/ConnectionFactory.h"
 #include "Swiften/Elements/ProtocolHeader.h"
-#include "Swiften/StreamStack/StreamStack.h"
-#include "Swiften/StreamStack/ConnectionLayer.h"
-#include "Swiften/StreamStack/XMPPLayer.h"
-#include "Swiften/StreamStack/TLSLayer.h"
-#include "Swiften/StreamStack/TLSLayerFactory.h"
 #include "Swiften/Elements/StreamFeatures.h"
 #include "Swiften/Elements/StartTLSRequest.h"
 #include "Swiften/Elements/StartTLSFailure.h"
@@ -20,47 +14,46 @@
 #include "Swiften/Elements/IQ.h"
 #include "Swiften/Elements/ResourceBind.h"
 #include "Swiften/SASL/PLAINMessage.h"
-#include "Swiften/StreamStack/WhitespacePingLayer.h"
+#include "Swiften/Session/SessionStream.h"
 
 namespace Swift {
 
 ClientSession::ClientSession(
 		const JID& jid, 
-		boost::shared_ptr<Connection> connection,
-		TLSLayerFactory* tlsLayerFactory, 
-		PayloadParserFactoryCollection* payloadParserFactories, 
-		PayloadSerializerCollection* payloadSerializers) : 
-			Session(connection, payloadParserFactories, payloadSerializers),
-			tlsLayerFactory_(tlsLayerFactory),
-			state_(Initial), 
-			needSessionStart_(false) {
-	setLocalJID(jid);
-	setRemoteJID(JID("", jid.getDomain()));
+		boost::shared_ptr<SessionStream> stream) :
+			localJID(jid),	
+			state(Initial), 
+			stream(stream),
+			needSessionStart(false) {
+	stream->onStreamStartReceived.connect(boost::bind(&ClientSession::handleStreamStart, shared_from_this(), _1));
+	stream->onElementReceived.connect(boost::bind(&ClientSession::handleElement, shared_from_this(), _1));
+	stream->onError.connect(boost::bind(&ClientSession::handleStreamError, shared_from_this(), _1));
+	stream->onTLSEncrypted.connect(boost::bind(&ClientSession::handleTLSEncrypted, shared_from_this()));
 }
 
-void ClientSession::handleSessionStarted() {
-	assert(state_ == Initial);
-	state_ = WaitingForStreamStart;
+void ClientSession::start() {
+	assert(state == Initial);
+	state = WaitingForStreamStart;
 	sendStreamHeader();
 }
 
 void ClientSession::sendStreamHeader() {
 	ProtocolHeader header;
 	header.setTo(getRemoteJID());
-	getXMPPLayer()->writeHeader(header);
+	stream->writeHeader(header);
 }
 
-void ClientSession::setCertificate(const PKCS12Certificate& certificate) {
-	certificate_ = certificate;
+void ClientSession::sendElement(boost::shared_ptr<Element> element) {
+	stream->writeElement(element);
 }
 
 void ClientSession::handleStreamStart(const ProtocolHeader&) {
 	checkState(WaitingForStreamStart);
-	state_ = Negotiating;
+	state = Negotiating;
 }
 
 void ClientSession::handleElement(boost::shared_ptr<Element> element) {
-	if (getState() == SessionStarted) {
+	if (getState() == Initialized) {
 		onElementReceived(element);
 	}
 	else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) {
@@ -68,152 +61,121 @@ void ClientSession::handleElement(boost::shared_ptr<Element> element) {
 			return;
 		}
 
-		if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) {
-			state_ = Encrypting;
-			getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
+		if (streamFeatures->hasStartTLS() && stream->supportsTLSEncryption()) {
+			state = WaitingForEncrypt;
+			stream->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
 		}
 		else if (streamFeatures->hasAuthenticationMechanisms()) {
-			if (!certificate_.isNull()) {
-				if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
-					state_ = Authenticating;
-					getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
-				}
-				else {
-					finishSession(ClientCertificateError);
-				}
+			if (stream->hasTLSCertificate() && streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
+					state = Authenticating;
+					stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
 			}
 			else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) {
-				state_ = WaitingForCredentials;
+				state = WaitingForCredentials;
 				onNeedCredentials();
 			}
 			else {
-				finishSession(NoSupportedAuthMechanismsError);
+				finishSession(Error::NoSupportedAuthMechanismsError);
 			}
 		}
 		else {
 			// Start the session
-
-			// Add a whitespace ping layer
-			whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer());
-			getStreamStack()->addLayer(whitespacePingLayer_);
-			whitespacePingLayer_->setActive();
+			stream->setWhitespacePingEnabled(true);
 
 			if (streamFeatures->hasSession()) {
-				needSessionStart_ = true;
+				needSessionStart = true;
 			}
 
 			if (streamFeatures->hasResourceBind()) {
-				state_ = BindingResource;
+				state = BindingResource;
 				boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind());
-				if (!getLocalJID().getResource().isEmpty()) {
-					resourceBind->setResource(getLocalJID().getResource());
+				if (!localJID.getResource().isEmpty()) {
+					resourceBind->setResource(localJID.getResource());
 				}
-				getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
+				stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
 			}
-			else if (needSessionStart_) {
+			else if (needSessionStart) {
 				sendSessionStart();
 			}
 			else {
-				state_ = SessionStarted;
-				onSessionStarted();
+				state = Initialized;
+				onInitialized();
 			}
 		}
 	}
 	else if (dynamic_cast<AuthSuccess*>(element.get())) {
 		checkState(Authenticating);
-		state_ = WaitingForStreamStart;
-		getXMPPLayer()->resetParser();
+		state = WaitingForStreamStart;
+		stream->resetXMPPParser();
 		sendStreamHeader();
 	}
 	else if (dynamic_cast<AuthFailure*>(element.get())) {
-		finishSession(AuthenticationFailedError);
+		finishSession(Error::AuthenticationFailedError);
 	}
 	else if (dynamic_cast<TLSProceed*>(element.get())) {
-		tlsLayer_ = tlsLayerFactory_->createTLSLayer();
-		getStreamStack()->addLayer(tlsLayer_);
-		if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) {
-			finishSession(ClientCertificateLoadError);
-		}
-		else {
-			tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this));
-			tlsLayer_->onError.connect(boost::bind(&ClientSession::handleTLSError, this));
-			tlsLayer_->connect();
-		}
+		checkState(WaitingForEncrypt);
+		state = Encrypting;
+		stream->addTLSEncryption();
 	}
 	else if (dynamic_cast<StartTLSFailure*>(element.get())) {
-		finishSession(TLSError);
+		finishSession(Error::TLSError);
 	}
 	else if (IQ* iq = dynamic_cast<IQ*>(element.get())) {
-		if (state_ == BindingResource) {
+		if (state == BindingResource) {
 			boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());
 			if (iq->getType() == IQ::Error && iq->getID() == "session-bind") {
-				finishSession(ResourceBindError);
+				finishSession(Error::ResourceBindError);
 			}
 			else if (!resourceBind) {
-				finishSession(UnexpectedElementError);
+				finishSession(Error::UnexpectedElementError);
 			}
 			else if (iq->getType() == IQ::Result) {
-				setLocalJID(resourceBind->getJID());
-				if (!getLocalJID().isValid()) {
-					finishSession(ResourceBindError);
+				localJID = resourceBind->getJID();
+				if (!localJID.isValid()) {
+					finishSession(Error::ResourceBindError);
 				}
-				if (needSessionStart_) {
+				if (needSessionStart) {
 					sendSessionStart();
 				}
 				else {
-					state_ = SessionStarted;
+					state = Initialized;
 				}
 			}
 			else {
-				finishSession(UnexpectedElementError);
+				finishSession(Error::UnexpectedElementError);
 			}
 		}
-		else if (state_ == StartingSession) {
+		else if (state == StartingSession) {
 			if (iq->getType() == IQ::Result) {
-				state_ = SessionStarted;
-				onSessionStarted();
+				state = Initialized;
+				onInitialized();
 			}
 			else if (iq->getType() == IQ::Error) {
-				finishSession(SessionStartError);
+				finishSession(Error::SessionStartError);
 			}
 			else {
-				finishSession(UnexpectedElementError);
+				finishSession(Error::UnexpectedElementError);
 			}
 		}
 		else {
-			finishSession(UnexpectedElementError);
+			finishSession(Error::UnexpectedElementError);
 		}
 	}
 	else {
 		// FIXME Not correct?
-		state_ = SessionStarted;
-		onSessionStarted();
+		state = Initialized;
+		onInitialized();
 	}
 }
 
 void ClientSession::sendSessionStart() {
-	state_ = StartingSession;
-	getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
-}
-
-void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) {
-	if (whitespacePingLayer_) {
-		whitespacePingLayer_->setInactive();
-	}
-	
-	if (error) {
-		//assert(!error_);
-		state_ = Error;
-		error_ = error;
-	}
-	else {
-		state_ = Finished;
-	}
+	state = StartingSession;
+	stream->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
 }
 
 bool ClientSession::checkState(State state) {
-	if (state_ != state) {
-		finishSession(UnexpectedElementError);
+	if (state != state) {
+		finishSession(Error::UnexpectedElementError);
 		return false;
 	}
 	return true;
@@ -221,18 +183,36 @@ bool ClientSession::checkState(State state) {
 
 void ClientSession::sendCredentials(const String& password) {
 	assert(WaitingForCredentials);
-	state_ = Authenticating;
-	getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(getLocalJID().getNode(), password).getValue())));
+	state = Authenticating;
+	stream->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(localJID.getNode(), password).getValue())));
 }
 
-void ClientSession::handleTLSConnected() {
-	state_ = WaitingForStreamStart;
-	getXMPPLayer()->resetParser();
+void ClientSession::handleTLSEncrypted() {
+	checkState(WaitingForEncrypt);
+	state = WaitingForStreamStart;
+	stream->resetXMPPParser();
 	sendStreamHeader();
 }
 
-void ClientSession::handleTLSError() {
-	finishSession(TLSError);
+void ClientSession::handleStreamError(boost::shared_ptr<Swift::Error> error) {
+	finishSession(error);
+}
+
+void ClientSession::finish() {
+	if (stream->isAvailable()) {
+		stream->writeFooter();
+	}
+	finishSession(boost::shared_ptr<Error>());
+}
+
+void ClientSession::finishSession(Error::Type error) {
+	finishSession(boost::shared_ptr<Swift::ClientSession::Error>(new Swift::ClientSession::Error(error)));
 }
 
+void ClientSession::finishSession(boost::shared_ptr<Swift::Error> error) {
+	stream->setWhitespacePingEnabled(false);
+	onFinished(error);
+}
+
+
 }
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index fead182..1b01a66 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -2,87 +2,88 @@
 
 #include <boost/signal.hpp>
 #include <boost/shared_ptr.hpp>
+#include <boost/enable_shared_from_this.hpp>
 
-#include "Swiften/Session/Session.h"
+#include "Swiften/Base/Error.h"
+#include "Swiften/Session/SessionStream.h"
+#include "Swiften/Session/BasicSessionStream.h"
 #include "Swiften/Base/String.h"
 #include "Swiften/JID/JID.h"
 #include "Swiften/Elements/Element.h"
-#include "Swiften/Network/Connection.h"
-#include "Swiften/TLS/PKCS12Certificate.h"
 
 namespace Swift {
-	class PayloadParserFactoryCollection;
-	class PayloadSerializerCollection;
-	class ConnectionFactory;
-	class Connection;
-	class StreamStack;
-	class XMPPLayer;
-	class ConnectionLayer;
-	class TLSLayerFactory;
-	class TLSLayer;
-	class WhitespacePingLayer;
-
-	class ClientSession : public Session {
+	class ClientSession : public boost::enable_shared_from_this<ClientSession> {
 		public:
 			enum State {
 				Initial,
 				WaitingForStreamStart,
 				Negotiating,
 				Compressing,
+				WaitingForEncrypt,
 				Encrypting,
 				WaitingForCredentials,
 				Authenticating,
 				BindingResource,
 				StartingSession,
-				SessionStarted,
-				Error,
+				Initialized,
 				Finished
 			};
 
+			struct Error : public Swift::Error {
+				enum Type {
+					AuthenticationFailedError,
+					NoSupportedAuthMechanismsError,
+					UnexpectedElementError,
+					ResourceBindError,
+					SessionStartError,
+					TLSError,
+				} type;
+				Error(Type type) : type(type) {}
+			};
+
 			ClientSession(
 					const JID& jid, 
-					boost::shared_ptr<Connection>, 
-					TLSLayerFactory*, 
-					PayloadParserFactoryCollection*, 
-					PayloadSerializerCollection*);
+					boost::shared_ptr<SessionStream>);
 
 			State getState() const {
-				return state_;
+				return state;
 			}
 
-			boost::optional<SessionError> getError() const {
-				return error_;
-			}
+			void start();
+			void finish();
 
 			void sendCredentials(const String& password);
-			void setCertificate(const PKCS12Certificate& certificate);
+			void sendElement(boost::shared_ptr<Element> element);
 
 		private:
+			void finishSession(Error::Type error);
+			void finishSession(boost::shared_ptr<Swift::Error> error);
+
+			JID getRemoteJID() const {
+				return JID("", localJID.getDomain());
+			}
+
 			void sendStreamHeader();
 			void sendSessionStart();
 
-			virtual void handleSessionStarted();
-			virtual void handleSessionFinished(const boost::optional<SessionError>& error);
 			virtual void handleElement(boost::shared_ptr<Element>);
 			virtual void handleStreamStart(const ProtocolHeader&);
+			virtual void handleStreamError(boost::shared_ptr<Swift::Error>);
 
-			void handleTLSConnected();
-			void handleTLSError();
+			void handleTLSEncrypted();
 
-			void setError(SessionError);
 			bool checkState(State);
 
 		public:
 			boost::signal<void ()> onNeedCredentials;
-			boost::signal<void ()> onSessionStarted;
+			boost::signal<void ()> onInitialized;
+			boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished;
+			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
 		
 		private:
-			TLSLayerFactory* tlsLayerFactory_;
-			State state_;
-			boost::optional<SessionError> error_;
-			boost::shared_ptr<TLSLayer> tlsLayer_;
-			boost::shared_ptr<WhitespacePingLayer> whitespacePingLayer_;
-			bool needSessionStart_;
-			PKCS12Certificate certificate_;
+			JID localJID;
+			State state;
+			boost::shared_ptr<SessionStream> stream;
+			bool needSessionStart;
 	};
 }
diff --git a/Swiften/Session/BasicSessionStream.cpp b/Swiften/Session/BasicSessionStream.cpp
index 73eaf5b..115dc7c 100644
--- a/Swiften/Session/BasicSessionStream.cpp
+++ b/Swiften/Session/BasicSessionStream.cpp
@@ -1,4 +1,4 @@
-// TODO: whitespacePingLayer_->setInactive();
+// TODO: Send out better errors
 
 #include "Swiften/Session/BasicSessionStream.h"
 
@@ -13,7 +13,7 @@
 
 namespace Swift {
 
-BasicSessionStream::BasicSessionStream(boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory) : connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), tlsLayerFactory(tlsLayerFactory) {
+BasicSessionStream::BasicSessionStream(boost::shared_ptr<Connection> connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, TLSLayerFactory* tlsLayerFactory) : available(false), connection(connection), payloadParserFactories(payloadParserFactories), payloadSerializers(payloadSerializers), tlsLayerFactory(tlsLayerFactory) {
 }
 
 void BasicSessionStream::initialize() {
@@ -24,10 +24,13 @@ void BasicSessionStream::initialize() {
 	xmppLayer->onError.connect(boost::bind(
       &BasicSessionStream::handleXMPPError, shared_from_this()));
 
+	connection->onDisconnected.connect(boost::bind(&BasicSessionStream::handleConnectionError, shared_from_this(), _1));
 	connectionLayer = boost::shared_ptr<ConnectionLayer>(
       new ConnectionLayer(connection));
 
 	streamStack = new StreamStack(xmppLayer, connectionLayer);
+
+	available = true;
 }
 
 BasicSessionStream::~BasicSessionStream() {
@@ -35,29 +38,53 @@ BasicSessionStream::~BasicSessionStream() {
 }
 
 void BasicSessionStream::writeHeader(const ProtocolHeader& header) {
+	assert(available);
 	xmppLayer->writeHeader(header);
 }
 
 void BasicSessionStream::writeElement(boost::shared_ptr<Element> element) {
+	assert(available);
 	xmppLayer->writeElement(element);
 }
 
+void BasicSessionStream::writeFooter() {
+	assert(available);
+	xmppLayer->writeFooter();
+}
+
+bool BasicSessionStream::isAvailable() {
+	return available;
+}
+
 bool BasicSessionStream::supportsTLSEncryption() {
   return tlsLayerFactory && tlsLayerFactory->canCreate();
 }
 
 void BasicSessionStream::addTLSEncryption() {
+	assert(available);
 	tlsLayer = tlsLayerFactory->createTLSLayer();
-  streamStack->addLayer(tlsLayer);
-  // TODO: Add tls layer certificate if needed
-  tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, shared_from_this()));
-  tlsLayer->connect();
+	if (hasTLSCertificate() && !tlsLayer->setClientCertificate(getTLSCertificate())) {
+		onError(boost::shared_ptr<Error>(new Error()));
+	}
+	else {
+		streamStack->addLayer(tlsLayer);
+		tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, shared_from_this()));
+		tlsLayer->onConnected.connect(boost::bind(&BasicSessionStream::handleTLSConnected, shared_from_this()));
+		tlsLayer->connect();
+	}
 }
 
-void BasicSessionStream::addWhitespacePing() {
-  whitespacePingLayer = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer());
-  streamStack->addLayer(whitespacePingLayer);
-  whitespacePingLayer->setActive();
+void BasicSessionStream::setWhitespacePingEnabled(bool enabled) {
+	if (enabled && !whitespacePingLayer) {
+		whitespacePingLayer = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer());
+		streamStack->addLayer(whitespacePingLayer);
+	}
+	if (enabled) {
+		whitespacePingLayer->setActive();
+	}
+	else {
+		whitespacePingLayer->setInactive();
+	}
 }
 
 void BasicSessionStream::resetXMPPParser() {
@@ -73,10 +100,21 @@ void BasicSessionStream::handleElementReceived(boost::shared_ptr<Element> elemen
 }
 
 void BasicSessionStream::handleXMPPError() {
+	available = false;
 	onError(boost::shared_ptr<Error>(new Error()));
 }
 
+void BasicSessionStream::handleTLSConnected() {
+	onTLSEncrypted();
+}
+
 void BasicSessionStream::handleTLSError() {
+	available = false;
+	onError(boost::shared_ptr<Error>(new Error()));
+}
+
+void BasicSessionStream::handleConnectionError(const boost::optional<Connection::Error>&) {
+	available = false;
 	onError(boost::shared_ptr<Error>(new Error()));
 }
 
diff --git a/Swiften/Session/BasicSessionStream.h b/Swiften/Session/BasicSessionStream.h
index d248ebc..5fe0b4c 100644
--- a/Swiften/Session/BasicSessionStream.h
+++ b/Swiften/Session/BasicSessionStream.h
@@ -30,23 +30,29 @@ namespace Swift {
 
 			void initialize();
 
+			virtual bool isAvailable();
+
 			virtual void writeHeader(const ProtocolHeader& header);
 			virtual void writeElement(boost::shared_ptr<Element>);
+			virtual void writeFooter();
 
 			virtual bool supportsTLSEncryption();
 			virtual void addTLSEncryption();
 
-			virtual void addWhitespacePing();
+			virtual void setWhitespacePingEnabled(bool);
 
 			virtual void resetXMPPParser();
 
     private:
+			void handleConnectionError(const boost::optional<Connection::Error>& error);
       void handleXMPPError();
+			void handleTLSConnected();
       void handleTLSError();
 			void handleStreamStartReceived(const ProtocolHeader&);
 			void handleElementReceived(boost::shared_ptr<Element>);
 
     private:
+			bool available;
 			boost::shared_ptr<Connection> connection;
 			PayloadParserFactoryCollection* payloadParserFactories;
 			PayloadSerializerCollection* payloadSerializers;
diff --git a/Swiften/Session/SessionStream.h b/Swiften/Session/SessionStream.h
index 44a1980..b2444f5 100644
--- a/Swiften/Session/SessionStream.h
+++ b/Swiften/Session/SessionStream.h
@@ -6,6 +6,7 @@
 #include "Swiften/Elements/ProtocolHeader.h"
 #include "Swiften/Elements/Element.h"
 #include "Swiften/Base/Error.h"
+#include "Swiften/TLS/PKCS12Certificate.h"
 
 namespace Swift {
 	class SessionStream {
@@ -17,18 +18,38 @@ namespace Swift {
 
 			virtual ~SessionStream();
 
+			virtual bool isAvailable() = 0;
+
 			virtual void writeHeader(const ProtocolHeader& header) = 0;
+			virtual void writeFooter() = 0;
 			virtual void writeElement(boost::shared_ptr<Element>) = 0;
 
 			virtual bool supportsTLSEncryption() = 0;
 			virtual void addTLSEncryption() = 0;
-
-			virtual void addWhitespacePing() = 0;
+			virtual void setWhitespacePingEnabled(bool enabled) = 0;
 
 			virtual void resetXMPPParser() = 0;
 
+			void setTLSCertificate(const PKCS12Certificate& cert) {
+				certificate = cert;
+			}
+
+			virtual bool hasTLSCertificate() {
+				return !certificate.isNull();
+			}
+
+
 			boost::signal<void (const ProtocolHeader&)> onStreamStartReceived;
 			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
 			boost::signal<void (boost::shared_ptr<Error>)> onError;
+			boost::signal<void ()> onTLSEncrypted;
+
+		protected:
+			const PKCS12Certificate& getTLSCertificate() const {
+				return certificate;
+			}
+
+		private:
+			PKCS12Certificate certificate;
 	};
 }
-- 
cgit v0.10.2-6-g49f6