#include "Swiften/Client/Session.h" #include <boost/bind.hpp> #include "Swiften/Network/ConnectionFactory.h" #include "Swiften/Elements/ProtocolHeader.h" #include "Swiften/StreamStack/StreamStack.h" #include "Swiften/StreamStack/ConnectionLayer.h" #include "Swiften/StreamStack/XMPPLayer.h" #include "Swiften/StreamStack/TLSLayer.h" #include "Swiften/StreamStack/TLSLayerFactory.h" #include "Swiften/Elements/StreamFeatures.h" #include "Swiften/Elements/StartTLSRequest.h" #include "Swiften/Elements/StartTLSFailure.h" #include "Swiften/Elements/TLSProceed.h" #include "Swiften/Elements/AuthRequest.h" #include "Swiften/Elements/AuthSuccess.h" #include "Swiften/Elements/AuthFailure.h" #include "Swiften/Elements/StartSession.h" #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/ResourceBind.h" #include "Swiften/SASL/PLAINMessage.h" #include "Swiften/StreamStack/WhitespacePingLayer.h" namespace Swift { Session::Session( const JID& jid, boost::shared_ptr<Connection> connection, TLSLayerFactory* tlsLayerFactory, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers) : jid_(jid), tlsLayerFactory_(tlsLayerFactory), payloadParserFactories_(payloadParserFactories), payloadSerializers_(payloadSerializers), state_(Initial), error_(NoError), connection_(connection), streamStack_(0), needSessionStart_(false) { } Session::~Session() { delete streamStack_; } void Session::start() { assert(state_ == Initial); connection_->onDisconnected.connect(boost::bind(&Session::handleDisconnected, this, _1)); initializeStreamStack(); state_ = WaitingForStreamStart; sendStreamHeader(); } void Session::stop() { // TODO: Send end stream header if applicable connection_->disconnect(); } void Session::sendStreamHeader() { ProtocolHeader header; header.setTo(jid_.getDomain()); xmppLayer_->writeHeader(header); } void Session::initializeStreamStack() { xmppLayer_ = boost::shared_ptr<XMPPLayer>(new XMPPLayer(payloadParserFactories_, payloadSerializers_)); xmppLayer_->onStreamStart.connect(boost::bind(&Session::handleStreamStart, this)); xmppLayer_->onElement.connect(boost::bind(&Session::handleElement, this, _1)); xmppLayer_->onError.connect(boost::bind(&Session::setError, this, XMLError)); xmppLayer_->onDataRead.connect(boost::bind(boost::ref(onDataRead), _1)); xmppLayer_->onWriteData.connect(boost::bind(boost::ref(onDataWritten), _1)); connectionLayer_ = boost::shared_ptr<ConnectionLayer>(new ConnectionLayer(connection_)); streamStack_ = new StreamStack(xmppLayer_, connectionLayer_); } void Session::handleDisconnected(const boost::optional<Connection::Error>& error) { if (error) { switch (*error) { case Connection::ReadError: setError(ConnectionReadError); break; case Connection::WriteError: setError(ConnectionWriteError); break; } } } void Session::setCertificate(const PKCS12Certificate& certificate) { certificate_ = certificate; } void Session::handleStreamStart() { checkState(WaitingForStreamStart); state_ = Negotiating; } void Session::handleElement(boost::shared_ptr<Element> element) { if (getState() == SessionStarted) { onElementReceived(element); } else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) { if (!checkState(Negotiating)) { return; } if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) { state_ = Encrypting; xmppLayer_->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest())); } else if (streamFeatures->hasAuthenticationMechanisms()) { if (!certificate_.isNull()) { if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) { state_ = Authenticating; xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", ""))); } else { setError(ClientCertificateError); } } else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) { state_ = WaitingForCredentials; onNeedCredentials(); } else { setError(NoSupportedAuthMechanismsError); } } else { // Start the session // Add a whitespace ping layer whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer()); streamStack_->addLayer(whitespacePingLayer_); if (streamFeatures->hasSession()) { needSessionStart_ = true; } if (streamFeatures->hasResourceBind()) { state_ = BindingResource; boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind()); if (!jid_.getResource().isEmpty()) { resourceBind->setResource(jid_.getResource()); } xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind)); } else if (needSessionStart_) { sendSessionStart(); } else { state_ = SessionStarted; onSessionStarted(); } } } else if (dynamic_cast<AuthSuccess*>(element.get())) { checkState(Authenticating); state_ = WaitingForStreamStart; xmppLayer_->resetParser(); sendStreamHeader(); } else if (dynamic_cast<AuthFailure*>(element.get())) { setError(AuthenticationFailedError); } else if (dynamic_cast<TLSProceed*>(element.get())) { tlsLayer_ = tlsLayerFactory_->createTLSLayer(); streamStack_->addLayer(tlsLayer_); if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) { setError(ClientCertificateLoadError); } else { tlsLayer_->onConnected.connect(boost::bind(&Session::handleTLSConnected, this)); tlsLayer_->onError.connect(boost::bind(&Session::handleTLSError, this)); tlsLayer_->connect(); } } else if (dynamic_cast<StartTLSFailure*>(element.get())) { setError(TLSError); } else if (IQ* iq = dynamic_cast<IQ*>(element.get())) { if (state_ == BindingResource) { boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>()); if (iq->getType() == IQ::Error && iq->getID() == "session-bind") { setError(ResourceBindError); } else if (!resourceBind) { setError(UnexpectedElementError); } else if (iq->getType() == IQ::Result) { jid_ = resourceBind->getJID(); if (!jid_.isValid()) { setError(ResourceBindError); } if (needSessionStart_) { sendSessionStart(); } else { state_ = SessionStarted; } } else { setError(UnexpectedElementError); } } else if (state_ == StartingSession) { if (iq->getType() == IQ::Result) { state_ = SessionStarted; onSessionStarted(); } else if (iq->getType() == IQ::Error) { setError(SessionStartError); } else { setError(UnexpectedElementError); } } else { setError(UnexpectedElementError); } } else { // FIXME Not correct? state_ = SessionStarted; onSessionStarted(); } } void Session::sendSessionStart() { state_ = StartingSession; xmppLayer_->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession()))); } void Session::setError(SessionError error) { assert(error != NoError); state_ = Error; error_ = error; onError(error); } bool Session::checkState(State state) { if (state_ != state) { setError(UnexpectedElementError); return false; } return true; } void Session::sendCredentials(const String& password) { assert(WaitingForCredentials); state_ = Authenticating; xmppLayer_->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(jid_.getNode(), password).getValue()))); } void Session::sendElement(boost::shared_ptr<Element> element) { assert(SessionStarted); xmppLayer_->writeElement(element); } void Session::handleTLSConnected() { state_ = WaitingForStreamStart; xmppLayer_->resetParser(); sendStreamHeader(); } void Session::handleTLSError() { setError(TLSError); } }