diff options
Diffstat (limited to 'Swiften/Server')
-rw-r--r-- | Swiften/Server/Makefile.inc | 4 | ||||
-rw-r--r-- | Swiften/Server/ServerFromClientSession.cpp | 64 | ||||
-rw-r--r-- | Swiften/Server/ServerFromClientSession.h | 13 | ||||
-rw-r--r-- | Swiften/Server/SimpleUserRegistry.cpp | 17 | ||||
-rw-r--r-- | Swiften/Server/SimpleUserRegistry.h | 22 | ||||
-rw-r--r-- | Swiften/Server/UserRegistry.cpp | 8 | ||||
-rw-r--r-- | Swiften/Server/UserRegistry.h | 13 |
7 files changed, 135 insertions, 6 deletions
diff --git a/Swiften/Server/Makefile.inc b/Swiften/Server/Makefile.inc index 1ab98d7..ae10bd2 100644 --- a/Swiften/Server/Makefile.inc +++ b/Swiften/Server/Makefile.inc @@ -2,6 +2,8 @@ SWIFTEN_SOURCES += \ Swiften/Server/ServerSession.cpp \ Swiften/Server/ServerFromClientSession.cpp \ Swiften/Server/ServerStanzaRouter.cpp \ - Swiften/Server/IncomingConnection.cpp + Swiften/Server/IncomingConnection.cpp \ + Swiften/Server/UserRegistry.cpp \ + Swiften/Server/SimpleUserRegistry.cpp include Swiften/Server/UnitTest/Makefile.inc diff --git a/Swiften/Server/ServerFromClientSession.cpp b/Swiften/Server/ServerFromClientSession.cpp index be8b601..4fc517f 100644 --- a/Swiften/Server/ServerFromClientSession.cpp +++ b/Swiften/Server/ServerFromClientSession.cpp @@ -2,10 +2,19 @@ #include <boost/bind.hpp> +#include "Swiften/Server/UserRegistry.h" #include "Swiften/Network/IncomingConnection.h" #include "Swiften/StreamStack/StreamStack.h" #include "Swiften/StreamStack/IncomingConnectionLayer.h" #include "Swiften/StreamStack/XMPPLayer.h" +#include "Swiften/Elements/StreamFeatures.h" +#include "Swiften/Elements/ResourceBind.h" +#include "Swiften/Elements/StartSession.h" +#include "Swiften/Elements/IQ.h" +#include "Swiften/Elements/AuthSuccess.h" +#include "Swiften/Elements/AuthFailure.h" +#include "Swiften/Elements/AuthRequest.h" +#include "Swiften/SASL/PLAINMessage.h" namespace Swift { @@ -13,11 +22,15 @@ ServerFromClientSession::ServerFromClientSession( const String& id, boost::shared_ptr<IncomingConnection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers) : + PayloadSerializerCollection* payloadSerializers, + UserRegistry* userRegistry) : id_(id), connection_(connection), payloadParserFactories_(payloadParserFactories), - payloadSerializers_(payloadSerializers) { + payloadSerializers_(payloadSerializers), + userRegistry_(userRegistry), + authenticated_(false), + initialized_(false) { xmppLayer_ = new XMPPLayer(payloadParserFactories_, payloadSerializers_); xmppLayer_->onStreamStart.connect( boost::bind(&ServerFromClientSession::handleStreamStart, this, _2)); @@ -39,12 +52,57 @@ ServerFromClientSession::~ServerFromClientSession() { delete xmppLayer_; } -void ServerFromClientSession::handleElement(boost::shared_ptr<Element>) { +void ServerFromClientSession::handleElement(boost::shared_ptr<Element> element) { + if (initialized_) { + onElementReceived(element); + } + else { + if (AuthRequest* authRequest = dynamic_cast<AuthRequest*>(element.get())) { + if (authRequest->getMechanism() != "PLAIN") { + xmppLayer_->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure)); + onSessionFinished(); + } + else { + PLAINMessage plainMessage(authRequest->getMessage()); + if (userRegistry_->isValidUserPassword(JID(plainMessage.getAuthenticationID(), domain_), plainMessage.getPassword())) { + xmppLayer_->writeElement(boost::shared_ptr<AuthSuccess>(new AuthSuccess())); + authenticated_ = true; + xmppLayer_->resetParser(); + } + else { + xmppLayer_->writeElement(boost::shared_ptr<AuthFailure>(new AuthFailure)); + onSessionFinished(); + } + } + } + else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { + if (boost::shared_ptr<ResourceBind> resourceBind = iq->getPayload<ResourceBind>()) { + jid_ = JID(user_, domain_, resourceBind->getResource()); + boost::shared_ptr<ResourceBind> resultResourceBind(new ResourceBind()); + resultResourceBind->setJID(jid_); + xmppLayer_->writeElement(IQ::createResult(JID(), iq->getID(), resultResourceBind)); + } + else if (iq->getPayload<StartSession>()) { + initialized_ = true; + xmppLayer_->writeElement(IQ::createResult(jid_, iq->getID())); + } + } + } } void ServerFromClientSession::handleStreamStart(const String& domain) { domain_ = domain; xmppLayer_->writeHeader(domain_, id_); + + boost::shared_ptr<StreamFeatures> features(new StreamFeatures()); + if (!authenticated_) { + features->addAuthenticationMechanism("PLAIN"); + } + else { + features->setHasResourceBind(); + features->setHasSession(); + } + xmppLayer_->writeElement(features); } } diff --git a/Swiften/Server/ServerFromClientSession.h b/Swiften/Server/ServerFromClientSession.h index 3413a03..9b340bc 100644 --- a/Swiften/Server/ServerFromClientSession.h +++ b/Swiften/Server/ServerFromClientSession.h @@ -4,12 +4,14 @@ #include <boost/signal.hpp> #include "Swiften/Base/String.h" +#include "Swiften/JID/JID.h" namespace Swift { class Element; class PayloadParserFactoryCollection; class PayloadSerializerCollection; class StreamStack; + class UserRegistry; class XMPPLayer; class IncomingConnectionLayer; class IncomingConnection; @@ -21,10 +23,12 @@ namespace Swift { const String& id, boost::shared_ptr<IncomingConnection> connection, PayloadParserFactoryCollection* payloadParserFactories, - PayloadSerializerCollection* payloadSerializers); + PayloadSerializerCollection* payloadSerializers, + UserRegistry* userRegistry); ~ServerFromClientSession(); - boost::signal<void()> onSessionFinished; + boost::signal<void (boost::shared_ptr<Element>)> onElementReceived; + boost::signal<void ()> onSessionFinished; boost::signal<void (const ByteArray&)> onDataWritten; boost::signal<void (const ByteArray&)> onDataRead; @@ -37,9 +41,14 @@ namespace Swift { boost::shared_ptr<IncomingConnection> connection_; PayloadParserFactoryCollection* payloadParserFactories_; PayloadSerializerCollection* payloadSerializers_; + UserRegistry* userRegistry_; + bool authenticated_; + bool initialized_; IncomingConnectionLayer* connectionLayer_; StreamStack* streamStack_; XMPPLayer* xmppLayer_; String domain_; + String user_; + JID jid_; }; } diff --git a/Swiften/Server/SimpleUserRegistry.cpp b/Swiften/Server/SimpleUserRegistry.cpp new file mode 100644 index 0000000..1a6743a --- /dev/null +++ b/Swiften/Server/SimpleUserRegistry.cpp @@ -0,0 +1,17 @@ +#include "Swiften/Server/SimpleUserRegistry.h" + +namespace Swift { + +SimpleUserRegistry::SimpleUserRegistry() { +} + +bool SimpleUserRegistry::isValidUserPassword(const JID& user, const String& password) const { + std::map<JID,String>::const_iterator i = users.find(user); + return i != users.end() ? i->second == password : false; +} + +void SimpleUserRegistry::addUser(const JID& user, const String& password) { + users.insert(std::make_pair(user, password)); +} + +} diff --git a/Swiften/Server/SimpleUserRegistry.h b/Swiften/Server/SimpleUserRegistry.h new file mode 100644 index 0000000..253025d --- /dev/null +++ b/Swiften/Server/SimpleUserRegistry.h @@ -0,0 +1,22 @@ +#pragma once + +#include <map> + +#include "Swiften/JID/JID.h" +#include "Swiften/Base/String.h" +#include "Swiften/Server/UserRegistry.h" + +namespace Swift { + class String; + + class SimpleUserRegistry : public UserRegistry { + public: + SimpleUserRegistry(); + + virtual bool isValidUserPassword(const JID& user, const String& password) const; + void addUser(const JID& user, const String& password); + + private: + std::map<JID, String> users; + }; +} diff --git a/Swiften/Server/UserRegistry.cpp b/Swiften/Server/UserRegistry.cpp new file mode 100644 index 0000000..d1a509f --- /dev/null +++ b/Swiften/Server/UserRegistry.cpp @@ -0,0 +1,8 @@ +#include "Swiften/Server/UserRegistry.h" + +namespace Swift { + +UserRegistry::~UserRegistry() { +} + +} diff --git a/Swiften/Server/UserRegistry.h b/Swiften/Server/UserRegistry.h new file mode 100644 index 0000000..5ced019 --- /dev/null +++ b/Swiften/Server/UserRegistry.h @@ -0,0 +1,13 @@ +#pragma once + +namespace Swift { + class String; + class JID; + + class UserRegistry { + public: + virtual ~UserRegistry(); + + virtual bool isValidUserPassword(const JID& user, const String& password) const = 0; + }; +} |