diff options
-rw-r--r-- | Swiften/Client/ClientSession.h | 1 | ||||
-rw-r--r-- | Swiften/Component/ComponentSession.cpp | 108 | ||||
-rw-r--r-- | Swiften/Component/ComponentSession.h | 81 | ||||
-rw-r--r-- | Swiften/Component/SConscript | 3 | ||||
-rw-r--r-- | Swiften/Component/UnitTest/ComponentSessionTest.cpp | 190 | ||||
-rw-r--r-- | Swiften/Elements/AuthFailure.h | 8 | ||||
-rw-r--r-- | Swiften/Elements/ComponentHandshake.h | 3 | ||||
-rw-r--r-- | Swiften/SConscript | 1 |
8 files changed, 387 insertions, 8 deletions
diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index 2c1bda8..83744e0 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -12,7 +12,6 @@ #include "Swiften/Base/Error.h" #include "Swiften/Session/SessionStream.h" -#include "Swiften/Session/BasicSessionStream.h" #include "Swiften/Base/String.h" #include "Swiften/JID/JID.h" #include "Swiften/Elements/Element.h" diff --git a/Swiften/Component/ComponentSession.cpp b/Swiften/Component/ComponentSession.cpp new file mode 100644 index 0000000..75ee467 --- /dev/null +++ b/Swiften/Component/ComponentSession.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include "Swiften/Component/ComponentSession.h" + +#include <boost/bind.hpp> + +#include "Swiften/Elements/ProtocolHeader.h" +#include "Swiften/Elements/ComponentHandshake.h" +#include "Swiften/Session/SessionStream.h" +#include "Swiften/Component/ComponentHandshakeGenerator.h" + +namespace Swift { + +ComponentSession::ComponentSession(const JID& jid, const String& secret, boost::shared_ptr<SessionStream> stream) : jid(jid), secret(secret), stream(stream), state(Initial) { +} + +ComponentSession::~ComponentSession() { +} + +void ComponentSession::start() { + stream->onStreamStartReceived.connect(boost::bind(&ComponentSession::handleStreamStart, shared_from_this(), _1)); + stream->onElementReceived.connect(boost::bind(&ComponentSession::handleElement, shared_from_this(), _1)); + stream->onError.connect(boost::bind(&ComponentSession::handleStreamError, shared_from_this(), _1)); + + assert(state == Initial); + state = WaitingForStreamStart; + sendStreamHeader(); +} + +void ComponentSession::sendStreamHeader() { + ProtocolHeader header; + header.setTo(jid); + stream->writeHeader(header); +} + +void ComponentSession::sendStanza(boost::shared_ptr<Stanza> stanza) { + stream->writeElement(stanza); +} + +void ComponentSession::handleStreamStart(const ProtocolHeader& header) { + checkState(WaitingForStreamStart); + state = Authenticating; + stream->writeElement(ComponentHandshake::ref(new ComponentHandshake(ComponentHandshakeGenerator::getHandshake(header.getID(), secret)))); +} + +void ComponentSession::handleElement(boost::shared_ptr<Element> element) { + if (boost::shared_ptr<Stanza> stanza = boost::dynamic_pointer_cast<Stanza>(element)) { + if (getState() == Initialized) { + onStanzaReceived(stanza); + } + else { + finishSession(Error::UnexpectedElementError); + } + } + else if (ComponentHandshake::cast(element)) { + if (!checkState(Authenticating)) { + return; + } + stream->setWhitespacePingEnabled(true); + state = Initialized; + onInitialized(); + } + else if (getState() == Authenticating) { + // FIXME: We should actually check the element received + finishSession(Error::AuthenticationFailedError); + } + else { + finishSession(Error::UnexpectedElementError); + } +} + +bool ComponentSession::checkState(State state) { + if (this->state != state) { + finishSession(Error::UnexpectedElementError); + return false; + } + return true; +} + +void ComponentSession::handleStreamError(boost::shared_ptr<Swift::Error> error) { + finishSession(error); +} + +void ComponentSession::finish() { + finishSession(boost::shared_ptr<Error>()); +} + +void ComponentSession::finishSession(Error::Type error) { + finishSession(boost::shared_ptr<Swift::ComponentSession::Error>(new Swift::ComponentSession::Error(error))); +} + +void ComponentSession::finishSession(boost::shared_ptr<Swift::Error> error) { + state = Finished; + stream->setWhitespacePingEnabled(false); + stream->onStreamStartReceived.disconnect(boost::bind(&ComponentSession::handleStreamStart, shared_from_this(), _1)); + stream->onElementReceived.disconnect(boost::bind(&ComponentSession::handleElement, shared_from_this(), _1)); + stream->onError.disconnect(boost::bind(&ComponentSession::handleStreamError, shared_from_this(), _1)); + if (stream->isAvailable()) { + stream->writeFooter(); + } + onFinished(error); +} + +} diff --git a/Swiften/Component/ComponentSession.h b/Swiften/Component/ComponentSession.h new file mode 100644 index 0000000..cbfa227 --- /dev/null +++ b/Swiften/Component/ComponentSession.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#pragma once + +#include <boost/shared_ptr.hpp> +#include <boost/enable_shared_from_this.hpp> + +#include "Swiften/JID/JID.h" +#include "Swiften/Base/boost_bsignals.h" +#include "Swiften/Base/Error.h" +#include "Swiften/Base/String.h" +#include "Swiften/Elements/Element.h" +#include "Swiften/Elements/Stanza.h" +#include "Swiften/Session/SessionStream.h" + +namespace Swift { + class ComponentAuthenticator; + + class ComponentSession : public boost::enable_shared_from_this<ComponentSession> { + public: + enum State { + Initial, + WaitingForStreamStart, + Authenticating, + Initialized, + Finished + }; + + struct Error : public Swift::Error { + enum Type { + AuthenticationFailedError, + UnexpectedElementError, + } type; + Error(Type type) : type(type) {} + }; + + ~ComponentSession(); + + static boost::shared_ptr<ComponentSession> create(const JID& jid, const String& secret, boost::shared_ptr<SessionStream> stream) { + return boost::shared_ptr<ComponentSession>(new ComponentSession(jid, secret, stream)); + } + + State getState() const { + return state; + } + + void start(); + void finish(); + + void sendStanza(boost::shared_ptr<Stanza>); + + public: + boost::signal<void ()> onInitialized; + boost::signal<void (boost::shared_ptr<Swift::Error>)> onFinished; + boost::signal<void (boost::shared_ptr<Stanza>)> onStanzaReceived; + + private: + ComponentSession(const JID& jid, const String& secret, boost::shared_ptr<SessionStream>); + + void finishSession(Error::Type error); + void finishSession(boost::shared_ptr<Swift::Error> error); + + void sendStreamHeader(); + + void handleElement(boost::shared_ptr<Element>); + void handleStreamStart(const ProtocolHeader&); + void handleStreamError(boost::shared_ptr<Swift::Error>); + + bool checkState(State); + + private: + JID jid; + String secret; + boost::shared_ptr<SessionStream> stream; + State state; + }; +} diff --git a/Swiften/Component/SConscript b/Swiften/Component/SConscript index 6d86575..1f08301 100644 --- a/Swiften/Component/SConscript +++ b/Swiften/Component/SConscript @@ -3,6 +3,7 @@ Import("swiften_env") sources = [ "ComponentHandshakeGenerator.cpp", "ComponentConnector.cpp", + "ComponentSession.cpp", ] -swiften_env.Append(SWIFTEN_OBJECTS = swiften_env.StaticObject(sources))
\ No newline at end of file +swiften_env.Append(SWIFTEN_OBJECTS = swiften_env.StaticObject(sources)) diff --git a/Swiften/Component/UnitTest/ComponentSessionTest.cpp b/Swiften/Component/UnitTest/ComponentSessionTest.cpp new file mode 100644 index 0000000..1722ad4 --- /dev/null +++ b/Swiften/Component/UnitTest/ComponentSessionTest.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2010 Remko Tronçon + * Licensed under the GNU General Public License v3. + * See Documentation/Licenses/GPLv3.txt for more information. + */ + +#include <cppunit/extensions/HelperMacros.h> +#include <cppunit/extensions/TestFactoryRegistry.h> +#include <deque> +#include <boost/bind.hpp> +#include <boost/optional.hpp> + +#include "Swiften/Session/SessionStream.h" +#include "Swiften/Component/ComponentSession.h" +#include "Swiften/Elements/ComponentHandshake.h" +#include "Swiften/Elements/AuthFailure.h" + +using namespace Swift; + +class ComponentSessionTest : public CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ComponentSessionTest); + CPPUNIT_TEST(testStart); + CPPUNIT_TEST(testStart_Error); + CPPUNIT_TEST(testStart_Unauthorized); + CPPUNIT_TEST_SUITE_END(); + + public: + void setUp() { + server = boost::shared_ptr<MockSessionStream>(new MockSessionStream()); + sessionFinishedReceived = false; + } + + void testStart() { + boost::shared_ptr<ComponentSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->receiveHandshake(); + server->sendHandshakeResponse(); + + CPPUNIT_ASSERT(server->whitespacePingEnabled); + + session->finish(); + CPPUNIT_ASSERT(!server->whitespacePingEnabled); + + } + + void testStart_Error() { + boost::shared_ptr<ComponentSession> session(createSession()); + session->start(); + server->breakConnection(); + + CPPUNIT_ASSERT_EQUAL(ComponentSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + + void testStart_Unauthorized() { + boost::shared_ptr<ComponentSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->receiveHandshake(); + server->sendHandshakeError(); + + CPPUNIT_ASSERT_EQUAL(ComponentSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + } + + private: + boost::shared_ptr<ComponentSession> createSession() { + boost::shared_ptr<ComponentSession> session = ComponentSession::create(JID("service.foo.com"), "servicesecret", server); + session->onFinished.connect(boost::bind(&ComponentSessionTest::handleSessionFinished, this, _1)); + return session; + } + + void handleSessionFinished(boost::shared_ptr<Error> error) { + sessionFinishedReceived = true; + sessionFinishedError = error; + } + + class MockSessionStream : public SessionStream { + public: + struct Event { + Event(boost::shared_ptr<Element> element) : element(element), footer(false) {} + Event(const ProtocolHeader& header) : header(header), footer(false) {} + Event() : footer(true) {} + + boost::shared_ptr<Element> element; + boost::optional<ProtocolHeader> header; + bool footer; + }; + + MockSessionStream() : available(true), whitespacePingEnabled(false), resetCount(0) { + } + + virtual bool isAvailable() { + return available; + } + + virtual void writeHeader(const ProtocolHeader& header) { + receivedEvents.push_back(Event(header)); + } + + virtual void writeFooter() { + receivedEvents.push_back(Event()); + } + + virtual void writeElement(boost::shared_ptr<Element> element) { + receivedEvents.push_back(Event(element)); + } + + virtual bool supportsTLSEncryption() { + return false; + } + + virtual void addTLSEncryption() { + assert(false); + } + + virtual bool isTLSEncrypted() { + return false; + } + + virtual void addZLibCompression() { + assert(false); + } + + virtual void setWhitespacePingEnabled(bool enabled) { + whitespacePingEnabled = enabled; + } + + virtual void resetXMPPParser() { + resetCount++; + } + + void breakConnection() { + onError(boost::shared_ptr<SessionStream::Error>(new SessionStream::Error(SessionStream::Error::ConnectionReadError))); + } + + void sendStreamStart() { + ProtocolHeader header; + header.setFrom("service.foo.com"); + return onStreamStartReceived(header); + } + + void sendHandshakeResponse() { + onElementReceived(ComponentHandshake::ref(new ComponentHandshake())); + } + + void sendHandshakeError() { + // FIXME: This isn't the correct element + onElementReceived(AuthFailure::ref(new AuthFailure())); + } + + void receiveStreamStart() { + Event event = popEvent(); + CPPUNIT_ASSERT(event.header); + } + + void receiveHandshake() { + Event event = popEvent(); + CPPUNIT_ASSERT(event.element); + ComponentHandshake::ref handshake(ComponentHandshake::cast(event.element)); + CPPUNIT_ASSERT(handshake); + CPPUNIT_ASSERT_EQUAL(String("4c4f8a41141722c8bbfbdd92d827f7b2fc0a542b"), handshake->getData()); + } + + Event popEvent() { + CPPUNIT_ASSERT(receivedEvents.size() > 0); + Event event = receivedEvents.front(); + receivedEvents.pop_front(); + return event; + } + + bool available; + bool whitespacePingEnabled; + String bindID; + int resetCount; + std::deque<Event> receivedEvents; + }; + + boost::shared_ptr<MockSessionStream> server; + bool sessionFinishedReceived; + bool needCredentials; + boost::shared_ptr<Error> sessionFinishedError; +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ComponentSessionTest); diff --git a/Swiften/Elements/AuthFailure.h b/Swiften/Elements/AuthFailure.h index ff1468e..348a19b 100644 --- a/Swiften/Elements/AuthFailure.h +++ b/Swiften/Elements/AuthFailure.h @@ -4,16 +4,14 @@ * See Documentation/Licenses/GPLv3.txt for more information. */ -#ifndef SWIFTEN_AuthFailure_H -#define SWIFTEN_AuthFailure_H +#pragma once #include "Swiften/Elements/Element.h" +#include "Swiften/Base/Shared.h" namespace Swift { - class AuthFailure : public Element { + class AuthFailure : public Element, public Shared<AuthFailure> { public: AuthFailure() {} }; } - -#endif diff --git a/Swiften/Elements/ComponentHandshake.h b/Swiften/Elements/ComponentHandshake.h index 1067310..d9088e0 100644 --- a/Swiften/Elements/ComponentHandshake.h +++ b/Swiften/Elements/ComponentHandshake.h @@ -8,9 +8,10 @@ #include "Swiften/Elements/Element.h" #include "Swiften/Base/String.h" +#include "Swiften/Base/Shared.h" namespace Swift { - class ComponentHandshake : public Element { + class ComponentHandshake : public Element, public Shared<ComponentHandshake> { public: ComponentHandshake(const String& data = "") : data(data) { } diff --git a/Swiften/SConscript b/Swiften/SConscript index 84d9a67..ca4f91a 100644 --- a/Swiften/SConscript +++ b/Swiften/SConscript @@ -159,6 +159,7 @@ if env["SCONS_STAGE"] == "build" : File("Compress/UnitTest/ZLibDecompressorTest.cpp"), File("Component/UnitTest/ComponentHandshakeGeneratorTest.cpp"), File("Component/UnitTest/ComponentConnectorTest.cpp"), + File("Component/UnitTest/ComponentSessionTest.cpp"), File("Disco/UnitTest/CapsInfoGeneratorTest.cpp"), File("Disco/UnitTest/CapsManagerTest.cpp"), File("Disco/UnitTest/EntityCapsManagerTest.cpp"), |