#include <cppunit/extensions/HelperMacros.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <boost/bind.hpp>
#include <boost/function.hpp>
#include <boost/optional.hpp>

#include "Swiften/Parser/XMPPParser.h"
#include "Swiften/Parser/XMPPParserClient.h"
#include "Swiften/Serializer/XMPPSerializer.h"
#include "Swiften/StreamStack/TLSLayerFactory.h"
#include "Swiften/StreamStack/TLSLayer.h"
#include "Swiften/StreamStack/StreamStack.h"
#include "Swiften/StreamStack/WhitespacePingLayer.h"
#include "Swiften/Elements/ProtocolHeader.h"
#include "Swiften/Elements/StreamFeatures.h"
#include "Swiften/Elements/Element.h"
#include "Swiften/Elements/Error.h"
#include "Swiften/Elements/IQ.h"
#include "Swiften/Elements/AuthRequest.h"
#include "Swiften/Elements/AuthSuccess.h"
#include "Swiften/Elements/AuthFailure.h"
#include "Swiften/Elements/ResourceBind.h"
#include "Swiften/Elements/StartSession.h"
#include "Swiften/Elements/StartTLSRequest.h"
#include "Swiften/Elements/StartTLSFailure.h"
#include "Swiften/Elements/TLSProceed.h"
#include "Swiften/Elements/Message.h"
#include "Swiften/EventLoop/MainEventLoop.h"
#include "Swiften/EventLoop/DummyEventLoop.h"
#include "Swiften/Network/Connection.h"
#include "Swiften/Network/ConnectionFactory.h"
#include "Swiften/Client/ClientSession.h"
#include "Swiften/TLS/PKCS12Certificate.h"
#include "Swiften/Parser/PayloadParsers/FullPayloadParserFactoryCollection.h"
#include "Swiften/Serializer/PayloadSerializers/FullPayloadSerializerCollection.h"

using namespace Swift;

class ClientSessionTest : public CppUnit::TestFixture {
		CPPUNIT_TEST_SUITE(ClientSessionTest);
		CPPUNIT_TEST(testConstructor);
		CPPUNIT_TEST(testStart_Error);
		CPPUNIT_TEST(testStart_XMLError);
		CPPUNIT_TEST(testStartTLS);
		CPPUNIT_TEST(testStartTLS_ServerError);
		CPPUNIT_TEST(testStartTLS_NoTLSSupport);
		CPPUNIT_TEST(testStartTLS_ConnectError);
		CPPUNIT_TEST(testStartTLS_ErrorAfterConnect);
		CPPUNIT_TEST(testAuthenticate);
		CPPUNIT_TEST(testAuthenticate_Unauthorized);
		CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms);
		CPPUNIT_TEST(testResourceBind);
		CPPUNIT_TEST(testResourceBind_ChangeResource);
		CPPUNIT_TEST(testResourceBind_EmptyResource);
		CPPUNIT_TEST(testResourceBind_Error);
		CPPUNIT_TEST(testSessionStart);
		CPPUNIT_TEST(testSessionStart_Error);
		CPPUNIT_TEST(testSessionStart_AfterResourceBind);
		CPPUNIT_TEST(testWhitespacePing);
		CPPUNIT_TEST(testReceiveElementAfterSessionStarted);
		CPPUNIT_TEST(testSendElement);
		CPPUNIT_TEST_SUITE_END();

	public:
		ClientSessionTest() {}

		void setUp() {
			eventLoop_ = new DummyEventLoop();
			connection_ = boost::shared_ptr<MockConnection>(new MockConnection());
			tlsLayerFactory_ = new MockTLSLayerFactory();
			sessionStarted_ = false;
			needCredentials_ = false;
		}

		void tearDown() {
			delete tlsLayerFactory_;
			delete eventLoop_;
		}

		void testConstructor() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			CPPUNIT_ASSERT_EQUAL(ClientSession::Initial, session->getState());
		}

		void testStart_Error() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));

			getMockServer()->expectStreamStart();
			session->startSession();
			processEvents();
			CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState());

			getMockServer()->setError();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::ConnectionReadError, *session->getError());
		}

		void testStart_XMLError() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));

			getMockServer()->expectStreamStart();
			session->startSession();
			processEvents();
			CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForStreamStart, session->getState());

			getMockServer()->sendInvalidXML();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::XMLError, *session->getError());
		}

		void testStartTLS_NoTLSSupport() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			tlsLayerFactory_->setTLSSupported(false);
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithStartTLS();
			session->startSession();
			processEvents();
			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
		}

		void testStartTLS() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithStartTLS();
			getMockServer()->expectStartTLS();
			// FIXME: Test 'encrypting' state
			getMockServer()->sendTLSProceed();
			session->startSession();
			processEvents();
			CPPUNIT_ASSERT_EQUAL(ClientSession::Encrypting, session->getState());
			CPPUNIT_ASSERT(session->getTLSLayer());
			CPPUNIT_ASSERT(session->getTLSLayer()->isConnecting());

			getMockServer()->resetParser();
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			session->getTLSLayer()->setConnected();
			// FIXME: Test 'WatingForStreamStart' state
			processEvents();
			CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState());
		}

		void testStartTLS_ServerError() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithStartTLS();
			getMockServer()->expectStartTLS();
			getMockServer()->sendTLSFailure();
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
		}

		void testStartTLS_ConnectError() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithStartTLS();
			getMockServer()->expectStartTLS();
			getMockServer()->sendTLSProceed();
			session->startSession();
			processEvents();
			session->getTLSLayer()->setError();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
		}

		void testStartTLS_ErrorAfterConnect() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithStartTLS();
			getMockServer()->expectStartTLS();
			getMockServer()->sendTLSProceed();
			session->startSession();
			processEvents();
			getMockServer()->resetParser();
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			session->getTLSLayer()->setConnected();
			processEvents();

			session->getTLSLayer()->setError();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::TLSError, *session->getError());
		}

		void testAuthenticate() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			session->onNeedCredentials.connect(boost::bind(&ClientSessionTest::setNeedCredentials, this));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithAuthentication();
			session->startSession();
			processEvents();
			CPPUNIT_ASSERT_EQUAL(ClientSession::WaitingForCredentials, session->getState());
			CPPUNIT_ASSERT(needCredentials_);

			getMockServer()->expectAuth("me", "mypass");
			getMockServer()->sendAuthSuccess();
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			session->sendCredentials("mypass");
			CPPUNIT_ASSERT_EQUAL(ClientSession::Authenticating, session->getState());
			processEvents();
			CPPUNIT_ASSERT_EQUAL(ClientSession::Negotiating, session->getState());
		}

		void testAuthenticate_Unauthorized() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithAuthentication();
			session->startSession();
			processEvents();

			getMockServer()->expectAuth("me", "mypass");
			getMockServer()->sendAuthFailure();
			session->sendCredentials("mypass");
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::AuthenticationFailedError, *session->getError());
		}

		void testAuthenticate_NoValidAuthMechanisms() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithUnsupportedAuthentication();
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::NoSupportedAuthMechanismsError, *session->getError());
		}

		void testResourceBind() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithResourceBind();
			getMockServer()->expectResourceBind("Bar", "session-bind");
			// FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::BindingResource, session->getState());
			getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
			session->startSession();

			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
			CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar"), session->getLocalJID());
		}

		void testResourceBind_ChangeResource() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithResourceBind();
			getMockServer()->expectResourceBind("Bar", "session-bind");
			getMockServer()->sendResourceBindResponse("me@foo.com/Bar123", "session-bind");
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
			CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/Bar123"), session->getLocalJID());
		}

		void testResourceBind_EmptyResource() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithResourceBind();
			getMockServer()->expectResourceBind("", "session-bind");
			getMockServer()->sendResourceBindResponse("me@foo.com/NewResource", "session-bind");
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
			CPPUNIT_ASSERT_EQUAL(JID("me@foo.com/NewResource"), session->getLocalJID());
		}

		void testResourceBind_Error() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithResourceBind();
			getMockServer()->expectResourceBind("", "session-bind");
			getMockServer()->sendError("session-bind");
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::ResourceBindError, *session->getError());
		}

		void testSessionStart() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithSession();
			getMockServer()->expectSessionStart("session-start");
			// FIXME: Check CPPUNIT_ASSERT_EQUAL(ClientSession::StartingSession, session->getState());
			getMockServer()->sendSessionStartResponse("session-start");
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
			CPPUNIT_ASSERT(sessionStarted_);
		}

		void testSessionStart_Error() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithSession();
			getMockServer()->expectSessionStart("session-start");
			getMockServer()->sendError("session-start");
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::Error, session->getState());
			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStartError, *session->getError());
		}

		void testSessionStart_AfterResourceBind() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			session->onSessionStarted.connect(boost::bind(&ClientSessionTest::setSessionStarted, this));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeaturesWithResourceBindAndSession();
			getMockServer()->expectResourceBind("Bar", "session-bind");
			getMockServer()->sendResourceBindResponse("me@foo.com/Bar", "session-bind");
			getMockServer()->expectSessionStart("session-start");
			getMockServer()->sendSessionStartResponse("session-start");
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(ClientSession::SessionStarted, session->getState());
			CPPUNIT_ASSERT(sessionStarted_);
		}

		void testWhitespacePing() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeatures();
			session->startSession();
			processEvents();
			CPPUNIT_ASSERT(session->getWhitespacePingLayer());
		}

		void testReceiveElementAfterSessionStarted() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeatures();
			session->startSession();
			processEvents();

			getMockServer()->expectMessage();
			session->sendElement(boost::shared_ptr<Message>(new Message()));
		}

		void testSendElement() {
			boost::shared_ptr<MockSession> session(createSession("me@foo.com/Bar"));
			session->onElementReceived.connect(boost::bind(&ClientSessionTest::addReceivedElement, this, _1));
			getMockServer()->expectStreamStart();
			getMockServer()->sendStreamStart();
			getMockServer()->sendStreamFeatures();
			getMockServer()->sendMessage();
			session->startSession();
			processEvents();

			CPPUNIT_ASSERT_EQUAL(1, static_cast<int>(receivedElements_.size()));
			CPPUNIT_ASSERT(boost::dynamic_pointer_cast<Message>(receivedElements_[0]));
		}

	private:
		struct MockConnection;

		boost::shared_ptr<MockConnection> getMockServer() const {
			return connection_;
		}

		void processEvents() {
			eventLoop_->processEvents();
			getMockServer()->assertNoMoreExpectations();
		}

		void setSessionStarted() {
			sessionStarted_ = true;
		}

		void setNeedCredentials() {
			needCredentials_ = true;
		}

		void addReceivedElement(boost::shared_ptr<Element> element) {
			receivedElements_.push_back(element);
		}
	
	private:
		struct MockConnection : public Connection, public XMPPParserClient {
			struct Event {
				enum Direction { In, Out };
				enum Type { StreamStartEvent, StreamEndEvent, ElementEvent };

				Event(
					Direction direction,
					Type type, 
					boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) :
						direction(direction), type(type), element(element) {}

				Direction direction;
				Type	type;
				boost::shared_ptr<Element> element;
			};

			MockConnection() :
					resetParser_(false),
					domain_("foo.com"),
					parser_(0),
					serializer_(&payloadSerializers_) {
				parser_ = new XMPPParser(this, &payloadParserFactories_);
			}

			~MockConnection() {
				delete parser_;
			}

			void disconnect() { }

			void listen() {
				assert(false);
			}

			void connect(const HostAddressPort&) { assert(false); }
			void connect(const String&) { assert(false); }

			void setError() {
				MainEventLoop::postEvent(boost::bind(boost::ref(onDisconnected), Connection::ReadError));
			}

			void write(const ByteArray& data) {
				CPPUNIT_ASSERT(parser_->parse(data.toString()));
				if (resetParser_) {
					resetParser();
					resetParser_ = false;
				}
			}

			void resetParser() {
				delete parser_;
				parser_ = new XMPPParser(this, &payloadParserFactories_);
			}

			void handleStreamStart(const ProtocolHeader& header) {
				CPPUNIT_ASSERT_EQUAL(domain_, header.getTo());
				handleEvent(Event::StreamStartEvent);
			}

			void handleElement(boost::shared_ptr<Swift::Element> element) {
				handleEvent(Event::ElementEvent, element);
			}

			void handleStreamEnd() {
				handleEvent(Event::StreamEndEvent);
			}

			void handleEvent(Event::Type type, boost::shared_ptr<Element> element = boost::shared_ptr<Element>()) {
				CPPUNIT_ASSERT(!events_.empty());
				CPPUNIT_ASSERT_EQUAL(events_[0].direction, Event::In);
				CPPUNIT_ASSERT_EQUAL(events_[0].type, type);
				if (type == Event::ElementEvent) {
					CPPUNIT_ASSERT_EQUAL(serializer_.serializeElement(events_[0].element), serializer_.serializeElement(element));
				}
				events_.pop_front();

				while (!events_.empty() && events_[0].direction == Event::Out) {
					sendData(serializeEvent(events_[0]));
					events_.pop_front();
				}

				if (!events_.empty() && events_[0].type == Event::StreamStartEvent) {
					resetParser_ = true;
				}
			}

			String serializeEvent(const Event& event) {
				switch (event.type) {
					case Event::StreamStartEvent: 
						{
							ProtocolHeader header;
							header.setTo(domain_);
							return serializer_.serializeHeader(header);
						}
					case Event::ElementEvent:
						return serializer_.serializeElement(event.element);
					case Event::StreamEndEvent:
						return serializer_.serializeFooter();
				}
				assert(false);
				return "";
			}

			void assertNoMoreExpectations() {
				foreach (const Event& event, events_) {
					std::cout << "Unprocessed event: " << serializeEvent(event) << std::endl;
				}
				CPPUNIT_ASSERT(events_.empty());
			}

			void sendData(const ByteArray& data) {
				MainEventLoop::postEvent(boost::bind(boost::ref(onDataRead), data));
			}

			void expectStreamStart() {
				events_.push_back(Event(Event::In, Event::StreamStartEvent));
			}

			void expectStartTLS() {
				events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())));
			}

			void expectAuth(const String& user, const String& password) {
				String s	= String("") + '\0' + user + '\0' + password;
				events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<AuthRequest>(new AuthRequest("PLAIN", ByteArray(s.getUTF8Data(), s.getUTF8Size())))));
			}

			void expectResourceBind(const String& resource, const String& id) {
				boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind());
				sessionStart->setResource(resource);
				events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, sessionStart)));
			}

			void expectSessionStart(const String& id) {
				events_.push_back(Event(Event::In, Event::ElementEvent, IQ::createRequest(IQ::Set, JID(), id, boost::shared_ptr<StartSession>(new StartSession()))));
			}
			
			void expectMessage() {
				events_.push_back(Event(Event::In, Event::ElementEvent, boost::shared_ptr<Message>(new Message())));
			}

			void sendInvalidXML() {
				sendData("<invalid xml/>");
			}

			void sendStreamStart() {
				events_.push_back(Event(Event::Out, Event::StreamStartEvent));
			}

			void sendStreamFeatures() {
				boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
				events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
			}

			void sendStreamFeaturesWithStartTLS() {
				boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
				streamFeatures->setHasStartTLS();
				events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
			}

			void sendStreamFeaturesWithAuthentication() {
				boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
				streamFeatures->addAuthenticationMechanism("PLAIN");
				events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
			}

			void sendStreamFeaturesWithUnsupportedAuthentication() {
				boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
				streamFeatures->addAuthenticationMechanism("MY-UNSUPPORTED-MECH");
				events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
			}

			void sendStreamFeaturesWithResourceBind() {
				boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
				streamFeatures->setHasResourceBind();
				events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
			}

			void sendStreamFeaturesWithSession() {
				boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
				streamFeatures->setHasSession();
				events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
			}

			void sendStreamFeaturesWithResourceBindAndSession() {
				boost::shared_ptr<StreamFeatures> streamFeatures(new StreamFeatures());
				streamFeatures->setHasResourceBind();
				streamFeatures->setHasSession();
				events_.push_back(Event(Event::Out, Event::ElementEvent, streamFeatures));
			}

			void sendMessage() {
				events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<Message>(new Message())));
			}

			void sendTLSProceed() {
				events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<TLSProceed>(new TLSProceed())));
			}

			void sendTLSFailure() {
				events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<StartTLSFailure>(new StartTLSFailure())));
			}

			void sendAuthSuccess() {
				events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthSuccess>(new AuthSuccess())));
			}

			void sendAuthFailure() {
				events_.push_back(Event(Event::Out, Event::ElementEvent, boost::shared_ptr<AuthFailure>(new AuthFailure())));
			}

			void sendResourceBindResponse(const String& jid, const String& id) {
				boost::shared_ptr<ResourceBind> sessionStart(new ResourceBind());
				sessionStart->setJID(JID(jid));
				events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, sessionStart)));
			}

			void sendError(const String& id) {
				events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createError(JID(), id, Swift::Error::NotAllowed, Swift::Error::Cancel)));
			}

			void sendSessionStartResponse(const String& id) {
				events_.push_back(Event(Event::Out, Event::ElementEvent, IQ::createResult(JID(), id, boost::shared_ptr<StartSession>(new StartSession()))));
			}

			bool resetParser_;
			String domain_;
			FullPayloadParserFactoryCollection payloadParserFactories_;
			FullPayloadSerializerCollection payloadSerializers_;
			XMPPParser* parser_;
			XMPPSerializer serializer_;
			std::deque<Event> events_;
		};

		struct MockTLSLayer : public TLSLayer {
			MockTLSLayer() : connecting_(false) {}
			bool setClientCertificate(const PKCS12Certificate&) { return true; }
			void writeData(const ByteArray& data) { onWriteData(data); }
			void handleDataRead(const ByteArray& data) { onDataRead(data); }
			void setConnected() { onConnected(); }
			void setError() { onError(); }
			void connect() { connecting_ = true; }
			bool isConnecting() { return connecting_; }

			bool connecting_;
		};

		struct MockTLSLayerFactory : public TLSLayerFactory {
			MockTLSLayerFactory() : haveTLS_(true) {}
			void setTLSSupported(bool b) { haveTLS_ = b; }
			virtual bool canCreate() const { return haveTLS_; }
			virtual boost::shared_ptr<TLSLayer> createTLSLayer() { 
				assert(haveTLS_);
				boost::shared_ptr<MockTLSLayer> result(new MockTLSLayer());
				layers_.push_back(result);
				return result;
			}
			std::vector< boost::shared_ptr<MockTLSLayer> > layers_;
			bool haveTLS_;
		};

		struct MockSession : public ClientSession {
			MockSession(const JID& jid, boost::shared_ptr<Connection> connection, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : ClientSession(jid, connection, tlsLayerFactory, payloadParserFactories, payloadSerializers) {}

			boost::shared_ptr<MockTLSLayer> getTLSLayer() const {
				return getStreamStack()->getLayer<MockTLSLayer>();
			}
			boost::shared_ptr<WhitespacePingLayer> getWhitespacePingLayer() const {
				return getStreamStack()->getLayer<WhitespacePingLayer>();
			}
		};

		boost::shared_ptr<MockSession> createSession(const String& jid) {
			return boost::shared_ptr<MockSession>(new MockSession(JID(jid), connection_, tlsLayerFactory_, &payloadParserFactories_, &payloadSerializers_));
		}


		DummyEventLoop* eventLoop_;
		boost::shared_ptr<MockConnection> connection_;
		MockTLSLayerFactory* tlsLayerFactory_;
		FullPayloadParserFactoryCollection payloadParserFactories_;
		FullPayloadSerializerCollection payloadSerializers_;
		bool sessionStarted_;
		bool needCredentials_;
		std::vector< boost::shared_ptr<Element> > receivedElements_;
		typedef std::vector< boost::function<void ()> > EventQueue;
		EventQueue events_;
};

CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest);