From 958fe81b045e54ed6dadfe1fa9b14ac317811abf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Sun, 19 Jul 2009 15:48:19 +0200
Subject: Factor out common code from ServerFromClientSession.


diff --git a/Nim/main.cpp b/Nim/main.cpp
index 7ff4954..127878c 100644
--- a/Nim/main.cpp
+++ b/Nim/main.cpp
@@ -15,7 +15,7 @@
 #include "Swiften/EventLoop/MainEventLoop.h"
 #include "Swiften/EventLoop/SimpleEventLoop.h"
 #include "Swiften/EventLoop/EventOwner.h"
-#include "Swiften/Elements/Stanza.h"
+#include "Swiften/Elements/Element.h"
 #include "Swiften/LinkLocal/LinkLocalServiceInfo.h"
 #include "Swiften/LinkLocal/LinkLocalRoster.h"
 #include "Swiften/LinkLocal/LinkLocalSession.h"
@@ -69,9 +69,9 @@ class Server {
 				c->disconnect();
 			}
 			serverFromClientSession_ = boost::shared_ptr<ServerFromClientSession>(new ServerFromClientSession(idGenerator_.generateID(), c, &payloadParserFactories_, &payloadSerializers_, &userRegistry_));
-			serverFromClientSession_->onStanzaReceived.connect(boost::bind(&Server::handleStanzaReceived, this, _1, serverFromClientSession_));
+			serverFromClientSession_->onElementReceived.connect(boost::bind(&Server::handleElementReceived, this, _1, serverFromClientSession_));
 			serverFromClientSession_->onSessionFinished.connect(boost::bind(&Server::handleSessionFinished, this, serverFromClientSession_));
-			serverFromClientSession_->start();
+			serverFromClientSession_->startSession();
 		}
 
 		void handleNewLinkLocalConnection(boost::shared_ptr<Connection> connection) {
@@ -99,13 +99,15 @@ class Server {
 			linkLocalSessions_.erase(std::remove(linkLocalSessions_.begin(), linkLocalSessions_.end(), session), linkLocalSessions_.end());
 		}
 
-		void handleLinkLocalStanzaReceived(boost::shared_ptr<Stanza> stanza, boost::shared_ptr<LinkLocalSession> session) {
-			JID fromJID = session->getRemoteJID();
-			if (!linkLocalRoster_->hasItem(fromJID)) {
-				return; // TODO: Queue
+		void handleLinkLocalElementReceived(boost::shared_ptr<Element> element, boost::shared_ptr<LinkLocalSession> session) {
+			if (boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element)) {
+				JID fromJID = session->getRemoteJID();
+				if (!linkLocalRoster_->hasItem(fromJID)) {
+					return; // TODO: Queue
+				}
+				stanza->setFrom(fromJID);
+				serverFromClientSession_->sendElement(stanza);
 			}
-			stanza->setFrom(fromJID);
-			serverFromClientSession_->sendStanza(stanza);
 		}
 
 		void unregisterService() {
@@ -115,7 +117,12 @@ class Server {
 			}
 		}
 
-		void handleStanzaReceived(boost::shared_ptr<Stanza> stanza, boost::shared_ptr<ServerFromClientSession> session) {
+		void handleElementReceived(boost::shared_ptr<Element> element, boost::shared_ptr<ServerFromClientSession> session) {
+			boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element);
+			if (!stanza) {
+				return;
+			}
+
 			stanza->setFrom(session->getJID());
 			if (!stanza->getTo().isValid()) {
 				stanza->setTo(JID(session->getDomain()));
@@ -139,28 +146,28 @@ class Server {
 				if (boost::shared_ptr<IQ> iq = boost::dynamic_pointer_cast<IQ>(stanza)) {
 					if (iq->getPayload<RosterPayload>()) {
 						if (iq->getType() == IQ::Get) {
-							session->sendStanza(IQ::createResult(iq->getFrom(), iq->getID(), linkLocalRoster_->getRoster()));
+							session->sendElement(IQ::createResult(iq->getFrom(), iq->getID(), linkLocalRoster_->getRoster()));
 							rosterRequested_ = true;
 							foreach(const boost::shared_ptr<Presence> presence, linkLocalRoster_->getAllPresence()) {
-								session->sendStanza(presence);
+								session->sendElement(presence);
 							}
 						}
 						else {
-							session->sendStanza(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel));
+							session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel));
 						}
 					}
 					if (iq->getPayload<VCard>()) {
 						if (iq->getType() == IQ::Get) {
 							boost::shared_ptr<VCard> vcard(new VCard());
 							vcard->setNickname(iq->getFrom().getNode());
-							session->sendStanza(IQ::createResult(iq->getFrom(), iq->getID(), vcard));
+							session->sendElement(IQ::createResult(iq->getFrom(), iq->getID(), vcard));
 						}
 						else {
-							session->sendStanza(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel));
+							session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::Forbidden, Error::Cancel));
 						}
 					}
 					else {
-						session->sendStanza(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel));
+						session->sendElement(IQ::createError(iq->getFrom(), iq->getID(), Error::FeatureNotImplemented, Error::Cancel));
 					}
 				}
 			}
@@ -169,7 +176,7 @@ class Server {
 				boost::shared_ptr<LinkLocalSession> outgoingSession = 
 						getLinkLocalSessionForJID(toJID);
 				if (outgoingSession) {
-					outgoingSession->sendStanza(stanza);
+					outgoingSession->sendElement(stanza);
 				}
 				else {
 					if (linkLocalRoster_->hasItem(toJID)) {
@@ -182,10 +189,10 @@ class Server {
 									&payloadParserFactories_, &payloadSerializers_,
 									&boostConnectionFactory_));
 						registerLinkLocalSession(outgoingSession);
-						outgoingSession->sendStanza(stanza);
+						outgoingSession->sendElement(stanza);
 					}
 					else {
-						session->sendStanza(IQ::createError(
+						session->sendElement(IQ::createError(
 								stanza->getFrom(), stanza->getID(), 
 								Error::RecipientUnavailable, Error::Wait));
 					}
@@ -195,7 +202,7 @@ class Server {
 
 		void registerLinkLocalSession(boost::shared_ptr<LinkLocalSession> session) {
 			session->onSessionFinished.connect(boost::bind(&Server::handleLinkLocalSessionFinished, this, session));
-			session->onStanzaReceived.connect(boost::bind(&Server::handleLinkLocalStanzaReceived, this, _1, session));
+			session->onElementReceived.connect(boost::bind(&Server::handleLinkLocalElementReceived, this, _1, session));
 			linkLocalSessions_.push_back(session);
 			session->start();
 		}
@@ -213,13 +220,13 @@ class Server {
 			if (rosterRequested_) {
 				boost::shared_ptr<IQ> iq = IQ::createRequest(IQ::Set, serverFromClientSession_->getJID(), idGenerator_.generateID(), roster);
 				iq->setFrom(serverFromClientSession_->getJID().toBare());
-				serverFromClientSession_->sendStanza(iq);
+				serverFromClientSession_->sendElement(iq);
 			}
 		}
 
 		void handlePresenceChanged(boost::shared_ptr<Presence> presence) {
 			if (rosterRequested_) {
-				serverFromClientSession_->sendStanza(presence);
+				serverFromClientSession_->sendElement(presence);
 			}
 		}
 
@@ -260,7 +267,6 @@ class Server {
 		boost::shared_ptr<ServerFromClientSession> serverFromClientSession_;
 		boost::shared_ptr<BoostConnectionServer> serverFromNetworkConnectionServer_;
 		std::vector< boost::shared_ptr<LinkLocalSession> > linkLocalSessions_;
-		std::vector< boost::shared_ptr<Stanza> > queuedOutgoingStanzas_;
 		FullPayloadParserFactoryCollection payloadParserFactories_;
 		FullPayloadSerializerCollection payloadSerializers_;
 		bool dnsSDServiceRegistered_;
diff --git a/Swiften/LinkLocal/IncomingLinkLocalSession.cpp b/Swiften/LinkLocal/IncomingLinkLocalSession.cpp
index db4b007..b73e979 100644
--- a/Swiften/LinkLocal/IncomingLinkLocalSession.cpp
+++ b/Swiften/LinkLocal/IncomingLinkLocalSession.cpp
@@ -57,7 +57,7 @@ void IncomingLinkLocalSession::handleElement(boost::shared_ptr<Element> element)
 	
 	if (isInitialized()) {
 		if (stanza) {
-			onStanzaReceived(stanza);
+			onElementReceived(stanza);
 		}
 		else {
 			std::cerr << "Received unexpected element" << std::endl;
diff --git a/Swiften/LinkLocal/LinkLocalSession.cpp b/Swiften/LinkLocal/LinkLocalSession.cpp
index 0f106ae..60227a7 100644
--- a/Swiften/LinkLocal/LinkLocalSession.cpp
+++ b/Swiften/LinkLocal/LinkLocalSession.cpp
@@ -43,7 +43,7 @@ void LinkLocalSession::finishSession() {
 	connection->disconnect();
 }
 
-void LinkLocalSession::sendStanza(boost::shared_ptr<Stanza> stanza) {
+void LinkLocalSession::sendElement(boost::shared_ptr<Element> stanza) {
 	xmppLayer->writeElement(stanza);
 }
 
diff --git a/Swiften/LinkLocal/LinkLocalSession.h b/Swiften/LinkLocal/LinkLocalSession.h
index 6629a2a..4bec14d 100644
--- a/Swiften/LinkLocal/LinkLocalSession.h
+++ b/Swiften/LinkLocal/LinkLocalSession.h
@@ -35,13 +35,13 @@ namespace Swift {
 			void finishSession();
 
 			// TODO: Make non-virtual when OutgoingSession is fixed
-			virtual void sendStanza(boost::shared_ptr<Stanza>);
+			virtual void sendElement(boost::shared_ptr<Element>);
 
 			virtual const JID& getRemoteJID() const = 0;
 
 			virtual void start() = 0;
 
-			boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived;
+			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
 			boost::signal<void (boost::optional<Error>)> onSessionFinished;
 			boost::signal<void ()> onSessionStarted;
 			boost::signal<void (const ByteArray&)> onDataWritten;
diff --git a/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp b/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp
index f97520a..7415174 100644
--- a/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp
+++ b/Swiften/LinkLocal/OutgoingLinkLocalSession.cpp
@@ -77,28 +77,28 @@ void OutgoingLinkLocalSession::handleConnected(bool error) {
 }
 
 void OutgoingLinkLocalSession::handleStreamStart(const ProtocolHeader&) {
-	foreach(const boost::shared_ptr<Stanza>& stanza, queuedStanzas_) {
-		LinkLocalSession::sendStanza(stanza);
+	foreach(const boost::shared_ptr<Element>& stanza, queuedElements_) {
+		LinkLocalSession::sendElement(stanza);
 	}
-	queuedStanzas_.clear();
+	queuedElements_.clear();
 	setInitialized();
 }
 
 void OutgoingLinkLocalSession::handleElement(boost::shared_ptr<Element> element) {
 	if (isInitialized()) {
-		boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element);
+		boost::shared_ptr<Element> stanza = boost::dynamic_pointer_cast<Element>(element);
 		if (stanza) {
-			onStanzaReceived(stanza);
+			onElementReceived(stanza);
 		}
 	}
 }
 
-void OutgoingLinkLocalSession::sendStanza(boost::shared_ptr<Stanza> stanza) {
+void OutgoingLinkLocalSession::sendElement(boost::shared_ptr<Element> stanza) {
 	if (isInitialized()) {
-		LinkLocalSession::sendStanza(stanza);
+		LinkLocalSession::sendElement(stanza);
 	}
 	else {
-		queuedStanzas_.push_back(stanza);
+		queuedElements_.push_back(stanza);
 	}
 }
 
diff --git a/Swiften/LinkLocal/OutgoingLinkLocalSession.h b/Swiften/LinkLocal/OutgoingLinkLocalSession.h
index 76ab803..d3fed0b 100644
--- a/Swiften/LinkLocal/OutgoingLinkLocalSession.h
+++ b/Swiften/LinkLocal/OutgoingLinkLocalSession.h
@@ -35,7 +35,7 @@ namespace Swift {
 
 			void start();
 
-			void sendStanza(boost::shared_ptr<Stanza> stanza);
+			void sendElement(boost::shared_ptr<Element> stanza);
 
 		private:
 			void handleElement(boost::shared_ptr<Element>);
@@ -49,7 +49,7 @@ namespace Swift {
 			String hostname_;
 			int port_;
 			boost::shared_ptr<DNSSDService> resolver_;
-			std::vector<boost::shared_ptr<Stanza> > queuedStanzas_;
+			std::vector<boost::shared_ptr<Element> > queuedElements_;
 			ConnectionFactory* connectionFactory_;
 	};
 }
diff --git a/Swiften/Server/ServerFromClientSession.cpp b/Swiften/Server/ServerFromClientSession.cpp
index 45df3be..4489654 100644
--- a/Swiften/Server/ServerFromClientSession.cpp
+++ b/Swiften/Server/ServerFromClientSession.cpp
@@ -5,8 +5,6 @@
 #include "Swiften/Elements/ProtocolHeader.h"
 #include "Swiften/Server/UserRegistry.h"
 #include "Swiften/Network/Connection.h"
-#include "Swiften/StreamStack/StreamStack.h"
-#include "Swiften/StreamStack/ConnectionLayer.h"
 #include "Swiften/StreamStack/XMPPLayer.h"
 #include "Swiften/Elements/StreamFeatures.h"
 #include "Swiften/Elements/ResourceBind.h"
@@ -25,62 +23,34 @@ ServerFromClientSession::ServerFromClientSession(
 		PayloadParserFactoryCollection* payloadParserFactories, 
 		PayloadSerializerCollection* payloadSerializers,
 		UserRegistry* userRegistry) : 
+			Session(connection, payloadParserFactories, payloadSerializers),
 			id_(id),
-			connection_(connection), 
-			payloadParserFactories_(payloadParserFactories), 
-			payloadSerializers_(payloadSerializers),
 			userRegistry_(userRegistry),
-			authenticated_(false),
-			initialized_(false) {
-	xmppLayer_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(payloadParserFactories_, payloadSerializers_));
-	connectionLayer_ = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection_));
-	streamStack_ = new StreamStack(xmppLayer_, connectionLayer_);
+			authenticated_(false) {
 }
 
-ServerFromClientSession::~ServerFromClientSession() {
-	delete streamStack_;
-}
-
-void ServerFromClientSession::start() {
-	xmppLayer_->onStreamStart.connect(
-			boost::bind(&ServerFromClientSession::handleStreamStart, this, _1));
-	xmppLayer_->onElement.connect(
-			boost::bind(&ServerFromClientSession::handleElement, this, _1));
-	//xmppLayer_->onError.connect(
-	//		boost::bind(&ServerFromClientSession::setError, this, XMLError));
-	xmppLayer_->onDataRead.connect(
-			boost::bind(boost::ref(onDataRead), _1));
-	xmppLayer_->onWriteData.connect(
-			boost::bind(boost::ref(onDataWritten), _1));
-	connection_->onDisconnected.connect(boost::bind(&ServerFromClientSession::handleDisconnected, shared_from_this(), _1));
-}
 
 void ServerFromClientSession::handleElement(boost::shared_ptr<Element> element) {
-	if (initialized_) {
-		if (boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element)) {
-			onStanzaReceived(stanza);
-		}
-		else {
-			std::cerr << "Received unexpected element" << std::endl;
-		}
+	if (isInitialized()) {
+		onElementReceived(element);
 	}
 	else {
 		if (AuthRequest* authRequest = dynamic_cast<AuthRequest*>(element.get())) {
 			if (authRequest->getMechanism() != "PLAIN") {
-				xmppLayer_->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure));
-				onSessionFinished();
+				getXMPPLayer()->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure));
+				finishSession(NoSupportedAuthMechanismsError);
 			}
 			else {
 				PLAINMessage plainMessage(authRequest->getMessage());
 				if (userRegistry_->isValidUserPassword(JID(plainMessage.getAuthenticationID(), domain_.getDomain()), plainMessage.getPassword())) {
-					xmppLayer_->writeElement(boost::shared_ptr<AuthSuccess>(new AuthSuccess()));
+					getXMPPLayer()->writeElement(boost::shared_ptr<AuthSuccess>(new AuthSuccess()));
 					user_ = plainMessage.getAuthenticationID();
 					authenticated_ = true;
-					xmppLayer_->resetParser();
+					getXMPPLayer()->resetParser();
 				}
 				else {
-					xmppLayer_->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure));
-					onSessionFinished();
+					getXMPPLayer()->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure));
+					finishSession(AuthenticationFailedError);
 				}
 			}
 		}
@@ -89,12 +59,11 @@ void ServerFromClientSession::handleElement(boost::shared_ptr<Element> element)
 				jid_ = JID(user_, domain_.getDomain(), resourceBind->getResource());
 				boost::shared_ptr<ResourceBind> resultResourceBind(new ResourceBind());
 				resultResourceBind->setJID(jid_);
-				xmppLayer_->writeElement(IQ::createResult(JID(), iq->getID(), resultResourceBind));
+				getXMPPLayer()->writeElement(IQ::createResult(JID(), iq->getID(), resultResourceBind));
 			}
 			else if (iq->getPayload<StartSession>()) {
-				initialized_ = true;
-				xmppLayer_->writeElement(IQ::createResult(jid_, iq->getID()));
-				onSessionStarted();
+				getXMPPLayer()->writeElement(IQ::createResult(jid_, iq->getID()));
+				setInitialized();
 			}
 		}
 	}
@@ -105,7 +74,7 @@ void ServerFromClientSession::handleStreamStart(const ProtocolHeader& incomingHe
 	ProtocolHeader header;
 	header.setFrom(incomingHeader.getTo());
 	header.setID(id_);
-	xmppLayer_->writeHeader(header);
+	getXMPPLayer()->writeHeader(header);
 
 	boost::shared_ptr<StreamFeatures> features(new StreamFeatures());
 	if (!authenticated_) {
@@ -115,16 +84,7 @@ void ServerFromClientSession::handleStreamStart(const ProtocolHeader& incomingHe
 		features->setHasResourceBind();
 		features->setHasSession();
 	}
-	xmppLayer_->writeElement(features);
+	getXMPPLayer()->writeElement(features);
 }
 
-void ServerFromClientSession::sendStanza(boost::shared_ptr<Stanza> stanza) {
-	xmppLayer_->writeElement(stanza);
-}
-
-void ServerFromClientSession::handleDisconnected(const boost::optional<Connection::Error>&) {
-	onSessionFinished();
-}
-
-
 }
diff --git a/Swiften/Server/ServerFromClientSession.h b/Swiften/Server/ServerFromClientSession.h
index 733c428..213f5c7 100644
--- a/Swiften/Server/ServerFromClientSession.h
+++ b/Swiften/Server/ServerFromClientSession.h
@@ -5,6 +5,7 @@
 #include <boost/enable_shared_from_this.hpp>
 
 #include "Swiften/Base/String.h"
+#include "Swiften/Session/Session.h"
 #include "Swiften/JID/JID.h"
 #include "Swiften/Network/Connection.h"
 
@@ -21,7 +22,7 @@ namespace Swift {
 	class Connection;
 	class ByteArray;
 
-	class ServerFromClientSession : public boost::enable_shared_from_this<ServerFromClientSession> {
+	class ServerFromClientSession : public Session {
 		public:
 			ServerFromClientSession(
 					const String& id,
@@ -29,11 +30,6 @@ namespace Swift {
 					PayloadParserFactoryCollection* payloadParserFactories, 
 					PayloadSerializerCollection* payloadSerializers,
 					UserRegistry* userRegistry);
-			~ServerFromClientSession();
-
-			void start();
-
-			void sendStanza(boost::shared_ptr<Stanza>);
 
 			const JID& getJID() const {
 				return jid_;
@@ -43,28 +39,14 @@ namespace Swift {
 				return domain_;
 			}
 
-			boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived;
-			boost::signal<void ()> onSessionFinished;
-			boost::signal<void ()> onSessionStarted;
-			boost::signal<void (const ByteArray&)> onDataWritten;
-			boost::signal<void (const ByteArray&)> onDataRead;
-
 		private:
-			void handleDisconnected(const boost::optional<Connection::Error>& error);
 			void handleElement(boost::shared_ptr<Element>);
 			void handleStreamStart(const ProtocolHeader& header);
 
 		private:
 			String id_;
-			boost::shared_ptr<Connection> connection_;
-			PayloadParserFactoryCollection* payloadParserFactories_;
-			PayloadSerializerCollection* payloadSerializers_;
 			UserRegistry* userRegistry_;
 			bool authenticated_;
-			bool initialized_;
-			boost::shared_ptr<XMPPLayer> xmppLayer_;
-			boost::shared_ptr<ConnectionLayer> connectionLayer_;
-			StreamStack* streamStack_;
 			JID domain_;
 			String user_;
 			JID jid_;
diff --git a/Swiften/Session/Session.cpp b/Swiften/Session/Session.cpp
index 84354e5..5ee98e7 100644
--- a/Swiften/Session/Session.cpp
+++ b/Swiften/Session/Session.cpp
@@ -42,10 +42,10 @@ void Session::initializeStreamStack() {
 	xmppLayer = boost::shared_ptr<XMPPLayer>(
 			new XMPPLayer(payloadParserFactories, payloadSerializers));
 	xmppLayer->onStreamStart.connect(
-			boost::bind(&Session::handleStreamStart, this, _1));
-	xmppLayer->onElement.connect(boost::bind(&Session::handleElement, this, _1));
+			boost::bind(&Session::handleStreamStart, shared_from_this(), _1));
+	xmppLayer->onElement.connect(boost::bind(&Session::handleElement, shared_from_this(), _1));
 	xmppLayer->onError.connect(
-			boost::bind(&Session::finishSession, this, XMLError));
+			boost::bind(&Session::finishSession, shared_from_this(), XMLError));
 	xmppLayer->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1));
 	xmppLayer->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1));
 	connection->onDisconnected.connect(
-- 
cgit v0.10.2-6-g49f6