#include "Swiften/Client/ClientSession.h"

#include <cassert>

#include <boost/bind.hpp>

#include "Swiften/Network/ConnectionFactory.h"
#include "Swiften/Elements/ProtocolHeader.h"
#include "Swiften/StreamStack/StreamStack.h"
#include "Swiften/StreamStack/ConnectionLayer.h"
#include "Swiften/StreamStack/XMPPLayer.h"
#include "Swiften/StreamStack/TLSLayer.h"
#include "Swiften/StreamStack/TLSLayerFactory.h"
#include "Swiften/Elements/StreamFeatures.h"
#include "Swiften/Elements/StartTLSRequest.h"
#include "Swiften/Elements/StartTLSFailure.h"
#include "Swiften/Elements/TLSProceed.h"
#include "Swiften/Elements/AuthRequest.h"
#include "Swiften/Elements/AuthSuccess.h"
#include "Swiften/Elements/AuthFailure.h"
#include "Swiften/Elements/StartSession.h"
#include "Swiften/Elements/IQ.h"
#include "Swiften/Elements/ResourceBind.h"
#include "Swiften/SASL/PLAINMessage.h"
#include "Swiften/StreamStack/WhitespacePingLayer.h"

namespace Swift {

/**
 * Creates a client session in the Initial state.
 *
 * The local JID is set to \p jid; the remote JID is set to the bare domain
 * of \p jid. The connection and the payload parser/serializer collections
 * are forwarded to the Session base class. \p tlsLayerFactory is stored and
 * later queried when the server offers StartTLS.
 */
ClientSession::ClientSession(
        const JID& jid,
        boost::shared_ptr<Connection> connection,
        TLSLayerFactory* tlsLayerFactory,
        PayloadParserFactoryCollection* payloadParserFactories,
        PayloadSerializerCollection* payloadSerializers) :
            Session(connection, payloadParserFactories, payloadSerializers),
            tlsLayerFactory_(tlsLayerFactory),
            state_(Initial),
            needSessionStart_(false) {
    setLocalJID(jid);
    setRemoteJID(JID("", jid.getDomain()));
}

/**
 * Moves from Initial to WaitingForStreamStart and sends the opening stream
 * header. Must only be invoked while the session is in the Initial state.
 */
void ClientSession::handleSessionStarted() {
    assert(state_ == Initial);
    state_ = WaitingForStreamStart;
    sendStreamHeader();
}

/**
 * Writes a stream header addressed to the remote JID on the XMPP layer.
 */
void ClientSession::sendStreamHeader() {
    ProtocolHeader header;
    header.setTo(getRemoteJID());
    getXMPPLayer()->writeHeader(header);
}

/**
 * Stores the client certificate. A non-null certificate is handed to the
 * TLS layer on TLSProceed and makes feature negotiation require the
 * "EXTERNAL" authentication mechanism.
 */
void ClientSession::setCertificate(const PKCS12Certificate& certificate) {
    certificate_ = certificate;
}

/**
 * Handles the peer's stream header: expected only in WaitingForStreamStart,
 * after which the session enters Negotiating.
 */
void ClientSession::handleStreamStart(const ProtocolHeader&) {
    // On a state mismatch checkState() has already called finishSession();
    // do not overwrite the resulting state with Negotiating.
    if (!checkState(WaitingForStreamStart)) {
        return;
    }
    state_ = Negotiating;
}

/**
 * Drives session negotiation from incoming elements.
 *
 * Once the session is in SessionStarted, elements are forwarded through
 * onElementReceived. Before that, the element type selects the step:
 * stream features (StartTLS / authentication / resource bind / session
 * start), authentication result, TLS proceed or failure, and IQ replies to
 * the bind and session-start requests.
 */
void ClientSession::handleElement(boost::shared_ptr<Element> element) {
    if (getState() == SessionStarted) {
        onElementReceived(element);
    }
    else if (StreamFeatures* streamFeatures = dynamic_cast<StreamFeatures*>(element.get())) {
        if (!checkState(Negotiating)) {
            return;
        }

        if (streamFeatures->hasStartTLS() && tlsLayerFactory_->canCreate()) {
            state_ = Encrypting;
            getXMPPLayer()->writeElement(boost::shared_ptr<StartTLSRequest>(new StartTLSRequest()));
        }
        else if (streamFeatures->hasAuthenticationMechanisms()) {
            if (!certificate_.isNull()) {
                // A client certificate was configured: only EXTERNAL is acceptable.
                if (streamFeatures->hasAuthenticationMechanism("EXTERNAL")) {
                    state_ = Authenticating;
                    getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("EXTERNAL", "")));
                }
                else {
                    finishSession(ClientCertificateError);
                }
            }
            else if (streamFeatures->hasAuthenticationMechanism("PLAIN")) {
                // Ask the owner for a password; sendCredentials() continues from here.
                state_ = WaitingForCredentials;
                onNeedCredentials();
            }
            else {
                finishSession(NoSupportedAuthMechanismsError);
            }
        }
        else {
            // Start the session

            // Add a whitespace ping layer
            whitespacePingLayer_ = boost::shared_ptr<WhitespacePingLayer>(new WhitespacePingLayer());
            getStreamStack()->addLayer(whitespacePingLayer_);

            if (streamFeatures->hasSession()) {
                needSessionStart_ = true;
            }

            if (streamFeatures->hasResourceBind()) {
                state_ = BindingResource;
                boost::shared_ptr<ResourceBind> resourceBind(new ResourceBind());
                // Only request a specific resource when the local JID carries one.
                if (!getLocalJID().getResource().isEmpty()) {
                    resourceBind->setResource(getLocalJID().getResource());
                }
                getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-bind", resourceBind));
            }
            else if (needSessionStart_) {
                sendSessionStart();
            }
            else {
                state_ = SessionStarted;
                setInitialized();
            }
        }
    }
    else if (dynamic_cast<AuthSuccess*>(element.get())) {
        // On a state mismatch checkState() has already called finishSession();
        // do not restart the stream on a session that was just finished.
        if (!checkState(Authenticating)) {
            return;
        }
        state_ = WaitingForStreamStart;
        getXMPPLayer()->resetParser();
        sendStreamHeader();
    }
    else if (dynamic_cast<AuthFailure*>(element.get())) {
        finishSession(AuthenticationFailedError);
    }
    else if (dynamic_cast<TLSProceed*>(element.get())) {
        tlsLayer_ = tlsLayerFactory_->createTLSLayer();
        getStreamStack()->addLayer(tlsLayer_);
        if (!certificate_.isNull() && !tlsLayer_->setClientCertificate(certificate_)) {
            finishSession(ClientCertificateLoadError);
        }
        else {
            tlsLayer_->onConnected.connect(boost::bind(&ClientSession::handleTLSConnected, this));
            tlsLayer_->onError.connect(boost::bind(&ClientSession::handleTLSError, this));
            tlsLayer_->connect();
        }
    }
    else if (dynamic_cast<StartTLSFailure*>(element.get())) {
        finishSession(TLSError);
    }
    else if (IQ* iq = dynamic_cast<IQ*>(element.get())) {
        if (state_ == BindingResource) {
            boost::shared_ptr<ResourceBind> resourceBind(iq->getPayload<ResourceBind>());
            if (iq->getType() == IQ::Error && iq->getID() == "session-bind") {
                finishSession(ResourceBindError);
            }
            else if (!resourceBind) {
                finishSession(UnexpectedElementError);
            }
            else if (iq->getType() == IQ::Result) {
                setLocalJID(resourceBind->getJID());
                if (!getLocalJID().isValid()) {
                    finishSession(ResourceBindError);
                    // The session has been finished; do not go on to start it.
                    return;
                }
                if (needSessionStart_) {
                    sendSessionStart();
                }
                else {
                    state_ = SessionStarted;
                    // Mirror the other paths that reach SessionStarted.
                    setInitialized();
                }
            }
            else {
                finishSession(UnexpectedElementError);
            }
        }
        else if (state_ == StartingSession) {
            if (iq->getType() == IQ::Result) {
                state_ = SessionStarted;
                setInitialized();
            }
            else if (iq->getType() == IQ::Error) {
                finishSession(SessionStartError);
            }
            else {
                finishSession(UnexpectedElementError);
            }
        }
        else {
            finishSession(UnexpectedElementError);
        }
    }
    else {
        // FIXME Not correct?
        // Any other element received before SessionStarted is currently
        // treated as the start of the session.
        state_ = SessionStarted;
        setInitialized();
    }
}

/**
 * Enters StartingSession and sends the "session-start" IQ set request.
 */
void ClientSession::sendSessionStart() {
    state_ = StartingSession;
    getXMPPLayer()->writeElement(IQ::createRequest(IQ::Set, JID(), "session-start", boost::shared_ptr<StartSession>(new StartSession())));
}

/**
 * Records the outcome of a finished session: Error (remembering \p error)
 * when an error is given, Finished otherwise.
 */
void ClientSession::handleSessionFinished(const boost::optional<SessionError>& error) {
    if (error) {
        //assert(!error_);
        state_ = Error;
        error_ = error;
    }
    else {
        state_ = Finished;
    }
}

/**
 * Verifies that the session is in \p state.
 *
 * @return true when the current state matches; otherwise finishes the
 *         session with UnexpectedElementError and returns false.
 */
bool ClientSession::checkState(State state) {
    if (state_ != state) {
        finishSession(UnexpectedElementError);
        return false;
    }
    return true;
}

/**
 * Sends a PLAIN authentication request built from the local JID's node and
 * \p password, and enters Authenticating. Must only be called while the
 * session is in WaitingForCredentials.
 */
void ClientSession::sendCredentials(const String& password) {
    // Compare against the current state; asserting the bare enumerator
    // would not check anything about the session.
    assert(state_ == WaitingForCredentials);
    state_ = Authenticating;
    getXMPPLayer()->writeElement(boost::shared_ptr<Element>(new AuthRequest("PLAIN", PLAINMessage(getLocalJID().getNode(), password).getValue())));
}

/**
 * Called when the TLS layer reports it is connected: resets the parser and
 * sends a fresh stream header, then waits for the peer's stream start.
 */
void ClientSession::handleTLSConnected() {
    state_ = WaitingForStreamStart;
    getXMPPLayer()->resetParser();
    sendStreamHeader();
}

/**
 * Called when the TLS layer reports an error: finishes the session with
 * TLSError.
 */
void ClientSession::handleTLSError() {
    finishSession(TLSError);
}

}