From 9173ea9c7d9e35a6b0fd87ee51a07f4e96b53fd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Remko=20Tron=C3=A7on?= <git@el-tramo.be>
Date: Fri, 15 Oct 2010 21:42:01 +0200
Subject: Added ComponentSession.


diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h
index 2c1bda8..83744e0 100644
--- a/Swiften/Client/ClientSession.h
+++ b/Swiften/Client/ClientSession.h
@@ -12,7 +12,6 @@
 
 #include "Swiften/Base/Error.h"
 #include "Swiften/Session/SessionStream.h"
-#include "Swiften/Session/BasicSessionStream.h"
 #include "Swiften/Base/String.h"
 #include "Swiften/JID/JID.h"
 #include "Swiften/Elements/Element.h"
diff --git a/Swiften/Component/ComponentSession.cpp b/Swiften/Component/ComponentSession.cpp
new file mode 100644
index 0000000..75ee467
--- /dev/null
+++ b/Swiften/Component/ComponentSession.cpp
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include "Swiften/Component/ComponentSession.h"
+
+#include <boost/bind.hpp>
+
+#include "Swiften/Elements/ProtocolHeader.h"
+#include "Swiften/Elements/ComponentHandshake.h"
+#include "Swiften/Session/SessionStream.h"
+#include "Swiften/Component/ComponentHandshakeGenerator.h"
+
+namespace Swift {
+
+ComponentSession::ComponentSession(const JID& jid, const String& secret, boost::shared_ptr<SessionStream> stream) : jid(jid), secret(secret), stream(stream), state(Initial) {
+}
+
+ComponentSession::~ComponentSession() {
+}
+
+void ComponentSession::start() {
+	stream->onStreamStartReceived.connect(boost::bind(&ComponentSession::handleStreamStart, shared_from_this(), _1));
+	stream->onElementReceived.connect(boost::bind(&ComponentSession::handleElement, shared_from_this(), _1));
+	stream->onError.connect(boost::bind(&ComponentSession::handleStreamError, shared_from_this(), _1));
+
+	assert(state == Initial);
+	state = WaitingForStreamStart;
+	sendStreamHeader();
+}
+
+void ComponentSession::sendStreamHeader() {
+	ProtocolHeader header;
+	header.setTo(jid);
+	stream->writeHeader(header);
+}
+
+void ComponentSession::sendStanza(boost::shared_ptr<Stanza> stanza) {
+	stream->writeElement(stanza);
+}
+
+void ComponentSession::handleStreamStart(const ProtocolHeader& header) {
+	checkState(WaitingForStreamStart);
+	state = Authenticating;
+	stream->writeElement(ComponentHandshake::ref(new ComponentHandshake(ComponentHandshakeGenerator::getHandshake(header.getID(), secret))));
+}
+
+void ComponentSession::handleElement(boost::shared_ptr<Element> element) {
+	if (boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element)) {
+		if (getState() == Initialized) {
+			onStanzaReceived(stanza);
+		}
+		else {
+			finishSession(Error::UnexpectedElementError);
+		}
+	}
+	else if (ComponentHandshake::cast(element)) {
+		if (!checkState(Authenticating)) {
+			return;
+		}
+		stream->setWhitespacePingEnabled(true);
+		state = Initialized;
+		onInitialized();
+	}
+	else if (getState() == Authenticating) {
+		// FIXME: We should actually check the element received
+		finishSession(Error::AuthenticationFailedError);
+	}
+	else {
+		finishSession(Error::UnexpectedElementError);
+	}
+}
+
+bool ComponentSession::checkState(State state) {
+	if (this->state != state) {
+		finishSession(Error::UnexpectedElementError);
+		return false;
+	}
+	return true;
+}
+
+void ComponentSession::handleStreamError(boost::shared_ptr<Swift::Error> error) {
+	finishSession(error);
+}
+
+void ComponentSession::finish() {
+	finishSession(boost::shared_ptr<Error>());
+}
+
+void ComponentSession::finishSession(Error::Type error) {
+	finishSession(boost::shared_ptr<Swift::ComponentSession::Error>(new Swift::ComponentSession::Error(error)));
+}
+
+void ComponentSession::finishSession(boost::shared_ptr<Swift::Error> error) {
+	state = Finished;
+	stream->setWhitespacePingEnabled(false);
+	stream->onStreamStartReceived.disconnect(boost::bind(&ComponentSession::handleStreamStart, shared_from_this(), _1));
+	stream->onElementReceived.disconnect(boost::bind(&ComponentSession::handleElement, shared_from_this(), _1));
+	stream->onError.disconnect(boost::bind(&ComponentSession::handleStreamError, shared_from_this(), _1));
+	if (stream->isAvailable()) {
+		stream->writeFooter();
+	}
+	onFinished(error);
+}
+
+}
diff --git a/Swiften/Component/ComponentSession.h b/Swiften/Component/ComponentSession.h
new file mode 100644
index 0000000..cbfa227
--- /dev/null
+++ b/Swiften/Component/ComponentSession.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#pragma once
+
+#include <boost/shared_ptr.hpp>
+#include <boost/enable_shared_from_this.hpp>
+
+#include "Swiften/JID/JID.h"
+#include "Swiften/Base/boost_bsignals.h"
+#include "Swiften/Base/Error.h"
+#include "Swiften/Base/String.h"
+#include "Swiften/Elements/Element.h"
+#include "Swiften/Elements/Stanza.h"
+#include "Swiften/Session/SessionStream.h"
+
+namespace Swift {
+	class ComponentAuthenticator;
+
+	class ComponentSession : public boost::enable_shared_from_this<ComponentSession> {
+		public:
+			enum State {
+				Initial,
+				WaitingForStreamStart,
+				Authenticating,
+				Initialized,
+				Finished
+			};
+
+			struct Error : public Swift::Error {
+				enum Type {
+					AuthenticationFailedError,
+					UnexpectedElementError,
+				} type;
+				Error(Type type) : type(type) {}
+			};
+
+			~ComponentSession();
+
+			static boost::shared_ptr<ComponentSession> create(const JID& jid, const String& secret, boost::shared_ptr<SessionStream> stream) {
+				return boost::shared_ptr<ComponentSession>(new ComponentSession(jid, secret, stream));
+			}
+
+			State getState() const {
+				return state;
+			}
+
+			void start();
+			void finish();
+
+			void sendStanza(boost::shared_ptr<Stanza>);
+
+		public:
+			boost::signal<void ()> onInitialized;
+			boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished;
+			boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived;
+		
+		private:
+			ComponentSession(const JID& jid, const String& secret, boost::shared_ptr<SessionStream>);
+
+			void finishSession(Error::Type error);
+			void finishSession(boost::shared_ptr<Swift::Error> error);
+
+			void sendStreamHeader();
+
+			void handleElement(boost::shared_ptr<Element>);
+			void handleStreamStart(const ProtocolHeader&);
+			void handleStreamError(boost::shared_ptr<Swift::Error>);
+
+			bool checkState(State);
+
+		private:
+			JID jid;
+			String secret;
+			boost::shared_ptr<SessionStream> stream;
+			State state;
+	};
+}
diff --git a/Swiften/Component/SConscript b/Swiften/Component/SConscript
index 6d86575..1f08301 100644
--- a/Swiften/Component/SConscript
+++ b/Swiften/Component/SConscript
@@ -3,6 +3,7 @@ Import("swiften_env")
 sources = [
 		"ComponentHandshakeGenerator.cpp",
 		"ComponentConnector.cpp",
+		"ComponentSession.cpp",
 	]
 
-swiften_env.Append(SWIFTEN_OBJECTS = swiften_env.StaticObject(sources))
\ No newline at end of file
+swiften_env.Append(SWIFTEN_OBJECTS = swiften_env.StaticObject(sources))
diff --git a/Swiften/Component/UnitTest/ComponentSessionTest.cpp b/Swiften/Component/UnitTest/ComponentSessionTest.cpp
new file mode 100644
index 0000000..1722ad4
--- /dev/null
+++ b/Swiften/Component/UnitTest/ComponentSessionTest.cpp
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) 2010 Remko Tronçon
+ * Licensed under the GNU General Public License v3.
+ * See Documentation/Licenses/GPLv3.txt for more information.
+ */
+
+#include <cppunit/extensions/HelperMacros.h>
+#include <cppunit/extensions/TestFactoryRegistry.h>
+#include <deque>
+#include <boost/bind.hpp>
+#include <boost/optional.hpp>
+
+#include "Swiften/Session/SessionStream.h"
+#include "Swiften/Component/ComponentSession.h"
+#include "Swiften/Elements/ComponentHandshake.h"
+#include "Swiften/Elements/AuthFailure.h"
+
+using namespace Swift;
+
+class ComponentSessionTest : public CppUnit::TestFixture {
+		CPPUNIT_TEST_SUITE(ComponentSessionTest);
+		CPPUNIT_TEST(testStart);
+		CPPUNIT_TEST(testStart_Error);
+		CPPUNIT_TEST(testStart_Unauthorized);
+		CPPUNIT_TEST_SUITE_END();
+
+	public:
+		void setUp() {
+			server = boost::shared_ptr<MockSessionStream>(new MockSessionStream());
+			sessionFinishedReceived = false;
+		}
+
+		void testStart() {
+			boost::shared_ptr<ComponentSession> session(createSession());
+			session->start();
+			server->receiveStreamStart();
+			server->sendStreamStart();
+			server->receiveHandshake();
+			server->sendHandshakeResponse();
+
+			CPPUNIT_ASSERT(server->whitespacePingEnabled);
+
+			session->finish();
+			CPPUNIT_ASSERT(!server->whitespacePingEnabled);
+
+		}
+
+		void testStart_Error() {
+			boost::shared_ptr<ComponentSession> session(createSession());
+			session->start();
+			server->breakConnection();
+
+			CPPUNIT_ASSERT_EQUAL(ComponentSession::Finished, session->getState());
+			CPPUNIT_ASSERT(sessionFinishedReceived);
+			CPPUNIT_ASSERT(sessionFinishedError);
+		}
+
+		void testStart_Unauthorized() {
+			boost::shared_ptr<ComponentSession> session(createSession());
+			session->start();
+			server->receiveStreamStart();
+			server->sendStreamStart();
+			server->receiveHandshake();
+			server->sendHandshakeError();
+
+			CPPUNIT_ASSERT_EQUAL(ComponentSession::Finished, session->getState());
+			CPPUNIT_ASSERT(sessionFinishedReceived);
+			CPPUNIT_ASSERT(sessionFinishedError);
+		}
+
+	private:
+		boost::shared_ptr<ComponentSession> createSession() {
+			boost::shared_ptr<ComponentSession> session = ComponentSession::create(JID("service.foo.com"), "servicesecret", server);
+			session->onFinished.connect(boost::bind(&ComponentSessionTest::handleSessionFinished, this, _1));
+			return session;
+		}
+
+		void handleSessionFinished(boost::shared_ptr<Error> error) {
+			sessionFinishedReceived = true;
+			sessionFinishedError = error;
+		}
+
+		class MockSessionStream : public SessionStream {
+			public:
+				struct Event {
+					Event(boost::shared_ptr<Element> element) : element(element), footer(false) {}
+					Event(const ProtocolHeader& header) : header(header), footer(false) {}
+					Event() : footer(true) {}
+					
+					boost::shared_ptr<Element> element;
+					boost::optional<ProtocolHeader> header;
+					bool footer;
+				};
+
+				MockSessionStream() : available(true), whitespacePingEnabled(false), resetCount(0) {
+				}
+
+				virtual bool isAvailable() {
+					return available;
+				}
+
+				virtual void writeHeader(const ProtocolHeader& header) {
+					receivedEvents.push_back(Event(header));
+				}
+
+				virtual void writeFooter() {
+					receivedEvents.push_back(Event());
+				}
+
+				virtual void writeElement(boost::shared_ptr<Element> element) {
+					receivedEvents.push_back(Event(element));
+				}
+
+				virtual bool supportsTLSEncryption() {
+					return false;
+				}
+
+				virtual void addTLSEncryption() {
+					assert(false);
+				}
+
+				virtual bool isTLSEncrypted() {
+					return false;
+				}
+
+				virtual void addZLibCompression() {
+					assert(false);
+				}
+
+				virtual void setWhitespacePingEnabled(bool enabled) {
+					whitespacePingEnabled = enabled;
+				}
+
+				virtual void resetXMPPParser() {
+					resetCount++;
+				}
+
+				void breakConnection() {
+					onError(boost::shared_ptr<SessionStream::Error>(new SessionStream::Error(SessionStream::Error::ConnectionReadError)));
+				}
+
+				void sendStreamStart() {
+					ProtocolHeader header;
+					header.setFrom("service.foo.com");
+					return onStreamStartReceived(header);
+				}
+
+				void sendHandshakeResponse() {
+					onElementReceived(ComponentHandshake::ref(new ComponentHandshake()));
+				}
+
+				void sendHandshakeError() {
+					// FIXME: This isn't the correct element
+					onElementReceived(AuthFailure::ref(new AuthFailure()));
+				}
+
+				void receiveStreamStart() {
+					Event event = popEvent();
+					CPPUNIT_ASSERT(event.header);
+				}
+
+				void receiveHandshake() {
+					Event event = popEvent();
+					CPPUNIT_ASSERT(event.element);
+					ComponentHandshake::ref handshake(ComponentHandshake::cast(event.element));
+					CPPUNIT_ASSERT(handshake);
+					CPPUNIT_ASSERT_EQUAL(String("4c4f8a41141722c8bbfbdd92d827f7b2fc0a542b"), handshake->getData());
+				}
+
+				Event popEvent() {
+					CPPUNIT_ASSERT(receivedEvents.size() > 0);
+					Event event = receivedEvents.front();
+					receivedEvents.pop_front();
+					return event;
+				}
+
+				bool available;
+				bool whitespacePingEnabled;
+				String bindID;
+				int resetCount;
+				std::deque<Event> receivedEvents;
+		};
+
+		boost::shared_ptr<MockSessionStream> server;
+		bool sessionFinishedReceived;
+		bool needCredentials;
+		boost::shared_ptr<Error> sessionFinishedError;
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(ComponentSessionTest);
diff --git a/Swiften/Elements/AuthFailure.h b/Swiften/Elements/AuthFailure.h
index ff1468e..348a19b 100644
--- a/Swiften/Elements/AuthFailure.h
+++ b/Swiften/Elements/AuthFailure.h
@@ -4,16 +4,14 @@
  * See Documentation/Licenses/GPLv3.txt for more information.
  */
 
-#ifndef SWIFTEN_AuthFailure_H
-#define SWIFTEN_AuthFailure_H
+#pragma once
 
 #include "Swiften/Elements/Element.h"
+#include "Swiften/Base/Shared.h"
 
 namespace Swift {
-	class AuthFailure : public Element {
+	class AuthFailure : public Element, public Shared<AuthFailure> {
 		public:
 			AuthFailure() {}
 	};
 }
-
-#endif
diff --git a/Swiften/Elements/ComponentHandshake.h b/Swiften/Elements/ComponentHandshake.h
index 1067310..d9088e0 100644
--- a/Swiften/Elements/ComponentHandshake.h
+++ b/Swiften/Elements/ComponentHandshake.h
@@ -8,9 +8,10 @@
 
 #include "Swiften/Elements/Element.h"
 #include "Swiften/Base/String.h"
+#include "Swiften/Base/Shared.h"
 
 namespace Swift {
-	class ComponentHandshake : public Element {
+	class ComponentHandshake : public Element, public Shared<ComponentHandshake> {
 		public:
 			ComponentHandshake(const String& data = "") : data(data) {
 			}
diff --git a/Swiften/SConscript b/Swiften/SConscript
index 84d9a67..ca4f91a 100644
--- a/Swiften/SConscript
+++ b/Swiften/SConscript
@@ -159,6 +159,7 @@ if env["SCONS_STAGE"] == "build" :
 			File("Compress/UnitTest/ZLibDecompressorTest.cpp"),
 			File("Component/UnitTest/ComponentHandshakeGeneratorTest.cpp"),
 			File("Component/UnitTest/ComponentConnectorTest.cpp"),
+			File("Component/UnitTest/ComponentSessionTest.cpp"),
 			File("Disco/UnitTest/CapsInfoGeneratorTest.cpp"),
 			File("Disco/UnitTest/CapsManagerTest.cpp"),
 			File("Disco/UnitTest/EntityCapsManagerTest.cpp"),
-- 
cgit v0.10.2-6-g49f6