From 43eadfb5d884407c54ccd41cf46881ae374fdf15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Mon, 13 Jul 2009 21:30:56 +0200
Subject: Implemented session initialization.


diff --git a/Limber/main.cpp b/Limber/main.cpp
index 969a390..c3db296 100644
--- a/Limber/main.cpp
+++ b/Limber/main.cpp
@@ -8,6 +8,7 @@
 #include <boost/enable_shared_from_this.hpp>
 #include <boost/thread.hpp>
 
+#include "Swiften/Server/SimpleUserRegistry.h"
 #include "Swiften/Base/ByteArray.h"
 #include "Swiften/Base/IDGenerator.h"
 #include "Swiften/EventLoop/MainEventLoop.h"
@@ -122,7 +123,7 @@ class BoostConnectionServer : public ConnectionServer {
 
 class Server {
 	public:
-		Server() {
+		Server(UserRegistry* userRegistry) : userRegistry_(userRegistry) {
 			serverFromClientConnectionServer_ = new BoostConnectionServer(5222, boostIOServiceThread_.getIOService());
 			serverFromClientConnectionServer_->onNewConnection.connect(boost::bind(&Server::handleNewConnection, this, _1));
 		}
@@ -133,7 +134,7 @@ class Server {
 
 	private:
 		void handleNewConnection(boost::shared_ptr<IncomingConnection> c) {
-			ServerFromClientSession* session = new ServerFromClientSession(idGenerator_.generateID(), c, &payloadParserFactories_, &payloadSerializers_);
+			ServerFromClientSession* session = new ServerFromClientSession(idGenerator_.generateID(), c, &payloadParserFactories_, &payloadSerializers_, userRegistry_);
 			serverFromClientSessions_.push_back(session);
 			session->onSessionFinished.connect(boost::bind(&Server::handleSessionFinished, this, session));
 		}
@@ -145,6 +146,7 @@ class Server {
 
 	private:
 		IDGenerator idGenerator_;
+		UserRegistry* userRegistry_;
 		BoostIOServiceThread boostIOServiceThread_;
 		BoostConnectionServer* serverFromClientConnectionServer_;
 		std::vector<ServerFromClientSession*> serverFromClientSessions_;
@@ -154,7 +156,9 @@ class Server {
 
 int main() {
 	SimpleEventLoop eventLoop;
-	Server server;
+	SimpleUserRegistry userRegistry;
+	userRegistry.addUser(JID("remko@limber"), "pass");
+	Server server(&userRegistry);
 	eventLoop.run();
   return 0;
 }
diff --git a/Swiften/Server/Makefile.inc b/Swiften/Server/Makefile.inc
index 1ab98d7..ae10bd2 100644
--- a/Swiften/Server/Makefile.inc
+++ b/Swiften/Server/Makefile.inc
@@ -2,6 +2,8 @@ SWIFTEN_SOURCES += \
 	Swiften/Server/ServerSession.cpp \
 	Swiften/Server/ServerFromClientSession.cpp \
 	Swiften/Server/ServerStanzaRouter.cpp \
-	Swiften/Server/IncomingConnection.cpp
+	Swiften/Server/IncomingConnection.cpp \
+	Swiften/Server/UserRegistry.cpp \
+	Swiften/Server/SimpleUserRegistry.cpp
 
 include Swiften/Server/UnitTest/Makefile.inc
diff --git a/Swiften/Server/ServerFromClientSession.cpp b/Swiften/Server/ServerFromClientSession.cpp
index be8b601..4fc517f 100644
--- a/Swiften/Server/ServerFromClientSession.cpp
+++ b/Swiften/Server/ServerFromClientSession.cpp
@@ -2,10 +2,19 @@
 
 #include <boost/bind.hpp>
 
+#include "Swiften/Server/UserRegistry.h"
 #include "Swiften/Network/IncomingConnection.h"
 #include "Swiften/StreamStack/StreamStack.h"
 #include "Swiften/StreamStack/IncomingConnectionLayer.h"
 #include "Swiften/StreamStack/XMPPLayer.h"
+#include "Swiften/Elements/StreamFeatures.h"
+#include "Swiften/Elements/ResourceBind.h"
+#include "Swiften/Elements/StartSession.h"
+#include "Swiften/Elements/IQ.h"
+#include "Swiften/Elements/AuthSuccess.h"
+#include "Swiften/Elements/AuthFailure.h"
+#include "Swiften/Elements/AuthRequest.h"
+#include "Swiften/SASL/PLAINMessage.h"
 
 namespace Swift {
 
@@ -13,11 +22,15 @@ ServerFromClientSession::ServerFromClientSession(
 		const String& id,
 		boost::shared_ptr<IncomingConnection> connection, 
 		PayloadParserFactoryCollection* payloadParserFactories, 
-		PayloadSerializerCollection* payloadSerializers) : 
+		PayloadSerializerCollection* payloadSerializers,
+		UserRegistry* userRegistry) : 
 			id_(id),
 			connection_(connection), 
 			payloadParserFactories_(payloadParserFactories), 
-			payloadSerializers_(payloadSerializers) {
+			payloadSerializers_(payloadSerializers),
+			userRegistry_(userRegistry),
+			authenticated_(false),
+			initialized_(false) {
 	xmppLayer_ = new XMPPLayer(payloadParserFactories_, payloadSerializers_);
 	xmppLayer_->onStreamStart.connect(
 			boost::bind(&ServerFromClientSession::handleStreamStart, this, _2));
@@ -39,12 +52,57 @@ ServerFromClientSession::~ServerFromClientSession() {
 	delete xmppLayer_;
 }
 
-void ServerFromClientSession::handleElement(boost::shared_ptr<Element>) {
+void ServerFromClientSession::handleElement(boost::shared_ptr<Element> element) {
+	if (initialized_) {
+		onElementReceived(element);
+	}
+	else {
+		if (AuthRequest* authRequest = dynamic_cast<AuthRequest*>(element.get())) {
+			if (authRequest->getMechanism() != "PLAIN") {
+				xmppLayer_->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure));
+				onSessionFinished();
+			}
+			else {
+				PLAINMessage plainMessage(authRequest->getMessage());
+				if (userRegistry_->isValidUserPassword(JID(plainMessage.getAuthenticationID(), domain_), plainMessage.getPassword())) {
+					xmppLayer_->writeElement(boost::shared_ptr<AuthSuccess>(new AuthSuccess()));
+					authenticated_ = true;
+					xmppLayer_->resetParser();
+				}
+				else {
+					xmppLayer_->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure));
+					onSessionFinished();
+				}
+			}
+		}
+		else if (IQ* iq = dynamic_cast<IQ*>(element.get())) {
+			if (boost::shared_ptr<ResourceBind> resourceBind = iq->getPayload<ResourceBind>()) {
+				jid_ = JID(user_, domain_, resourceBind->getResource());
+				boost::shared_ptr<ResourceBind> resultResourceBind(new ResourceBind());
+				resultResourceBind->setJID(jid_);
+				xmppLayer_->writeElement(IQ::createResult(JID(), iq->getID(), resultResourceBind));
+			}
+			else if (iq->getPayload<StartSession>()) {
+				initialized_ = true;
+				xmppLayer_->writeElement(IQ::createResult(jid_, iq->getID()));
+			}
+		}
+	}
 }
 
 void ServerFromClientSession::handleStreamStart(const String& domain) {
 	domain_ = domain;
 	xmppLayer_->writeHeader(domain_, id_);
+
+	boost::shared_ptr<StreamFeatures> features(new StreamFeatures());
+	if (!authenticated_) {
+		features->addAuthenticationMechanism("PLAIN");
+	}
+	else {
+		features->setHasResourceBind();
+		features->setHasSession();
+	}
+	xmppLayer_->writeElement(features);
 }
 
 }
diff --git a/Swiften/Server/ServerFromClientSession.h b/Swiften/Server/ServerFromClientSession.h
index 3413a03..9b340bc 100644
--- a/Swiften/Server/ServerFromClientSession.h
+++ b/Swiften/Server/ServerFromClientSession.h
@@ -4,12 +4,14 @@
 #include <boost/signal.hpp>
 
 #include "Swiften/Base/String.h"
+#include "Swiften/JID/JID.h"
 
 namespace Swift {
 	class Element;
 	class PayloadParserFactoryCollection;
 	class PayloadSerializerCollection;
 	class StreamStack;
+	class UserRegistry;
 	class XMPPLayer;
 	class IncomingConnectionLayer;
 	class IncomingConnection;
@@ -21,10 +23,12 @@ namespace Swift {
 					const String& id,
 					boost::shared_ptr<IncomingConnection> connection, 
 					PayloadParserFactoryCollection* payloadParserFactories, 
-					PayloadSerializerCollection* payloadSerializers);
+					PayloadSerializerCollection* payloadSerializers,
+					UserRegistry* userRegistry);
 			~ServerFromClientSession();
 
-			boost::signal<void()> onSessionFinished;
+			boost::signal<void (boost::shared_ptr<Element>)> onElementReceived;
+			boost::signal<void ()> onSessionFinished;
 			boost::signal<void (const ByteArray&)> onDataWritten;
 			boost::signal<void (const ByteArray&)> onDataRead;
 
@@ -37,9 +41,14 @@ namespace Swift {
 			boost::shared_ptr<IncomingConnection> connection_;
 			PayloadParserFactoryCollection* payloadParserFactories_;
 			PayloadSerializerCollection* payloadSerializers_;
+			UserRegistry* userRegistry_;
+			bool authenticated_;
+			bool initialized_;
 			IncomingConnectionLayer* connectionLayer_;
 			StreamStack* streamStack_;
 			XMPPLayer* xmppLayer_;
 			String domain_;
+			String user_;
+			JID jid_;
 	};
 }
diff --git a/Swiften/Server/SimpleUserRegistry.cpp b/Swiften/Server/SimpleUserRegistry.cpp
new file mode 100644
index 0000000..1a6743a
--- /dev/null
+++ b/Swiften/Server/SimpleUserRegistry.cpp
@@ -0,0 +1,17 @@
+#include "Swiften/Server/SimpleUserRegistry.h"
+
+namespace Swift {
+
+SimpleUserRegistry::SimpleUserRegistry() {
+}
+
+bool SimpleUserRegistry::isValidUserPassword(const JID& user, const String& password) const {
+	std::map<JID,String>::const_iterator i = users.find(user);
+	return i != users.end() ? i->second == password : false;
+}
+
+void SimpleUserRegistry::addUser(const JID& user, const String& password) {
+	users.insert(std::make_pair(user, password));
+}
+
+}
diff --git a/Swiften/Server/SimpleUserRegistry.h b/Swiften/Server/SimpleUserRegistry.h
new file mode 100644
index 0000000..253025d
--- /dev/null
+++ b/Swiften/Server/SimpleUserRegistry.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <map>
+
+#include "Swiften/JID/JID.h"
+#include "Swiften/Base/String.h"
+#include "Swiften/Server/UserRegistry.h"
+
+namespace Swift {
+	class String;
+
+	class SimpleUserRegistry : public UserRegistry {
+		public:
+			SimpleUserRegistry();
+
+			virtual bool isValidUserPassword(const JID& user, const String& password) const;
+			void addUser(const JID& user, const String& password);
+
+		private:
+			std::map<JID, String> users;
+	};
+}
diff --git a/Swiften/Server/UserRegistry.cpp b/Swiften/Server/UserRegistry.cpp
new file mode 100644
index 0000000..d1a509f
--- /dev/null
+++ b/Swiften/Server/UserRegistry.cpp
@@ -0,0 +1,8 @@
+#include "Swiften/Server/UserRegistry.h"
+
+namespace Swift {
+
+UserRegistry::~UserRegistry() {
+}
+
+}
diff --git a/Swiften/Server/UserRegistry.h b/Swiften/Server/UserRegistry.h
new file mode 100644
index 0000000..5ced019
--- /dev/null
+++ b/Swiften/Server/UserRegistry.h
@@ -0,0 +1,13 @@
+#pragma once
+
+namespace Swift {
+	class String;
+	class JID;
+
+	class UserRegistry {
+		public:
+			virtual ~UserRegistry();
+
+			virtual bool isValidUserPassword(const JID& user, const String& password) const = 0;
+	};
+}
-- 
cgit v0.10.2-6-g49f6