diff options
Diffstat (limited to 'Swiften/Client')
-rw-r--r-- | Swiften/Client/ClientSession.cpp | 18 | ||||
-rw-r--r-- | Swiften/Client/ClientSession.h | 1 | ||||
-rw-r--r-- | Swiften/Client/UnitTest/ClientSessionTest.cpp | 30 |
3 files changed, 45 insertions, 4 deletions
diff --git a/Swiften/Client/ClientSession.cpp b/Swiften/Client/ClientSession.cpp index a199a84..9e6db5d 100644 --- a/Swiften/Client/ClientSession.cpp +++ b/Swiften/Client/ClientSession.cpp @@ -37,6 +37,7 @@ #include "Swiften/SASL/DIGESTMD5ClientAuthenticator.h" #include "Swiften/Session/SessionStream.h" #include "Swiften/TLS/CertificateTrustChecker.h" +#include "Swiften/TLS/ServerIdentityVerifier.h" namespace Swift { @@ -330,16 +331,27 @@ void ClientSession::handleTLSEncrypted() { Certificate::ref certificate = stream->getPeerCertificate(); boost::shared_ptr<CertificateVerificationError> verificationError = stream->getPeerCertificateVerificationError(); if (verificationError) { - if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificate, localJID.getDomain())) { + checkTrustOrFinish(certificate, verificationError); + } + else { + ServerIdentityVerifier identityVerifier(localJID); + if (identityVerifier.certificateVerifies(certificate)) { continueAfterTLSEncrypted(); } else { - finishSession(verificationError); + boost::shared_ptr<CertificateVerificationError> identityError(new CertificateVerificationError(CertificateVerificationError::InvalidServerIdentity)); + checkTrustOrFinish(certificate, identityError); } } - else { +} + +void ClientSession::checkTrustOrFinish(Certificate::ref certificate, boost::shared_ptr<CertificateVerificationError> error) { + if (certificateTrustChecker && certificateTrustChecker->isCertificateTrusted(certificate, localJID.getDomain())) { continueAfterTLSEncrypted(); } + else { + finishSession(error); + } } void ClientSession::continueAfterTLSEncrypted() { diff --git a/Swiften/Client/ClientSession.h b/Swiften/Client/ClientSession.h index 6acd9a3..20573a0 100644 --- a/Swiften/Client/ClientSession.h +++ b/Swiften/Client/ClientSession.h @@ -121,6 +121,7 @@ namespace Swift { void handleStanzaAcked(boost::shared_ptr<Stanza> stanza); void ack(unsigned int handledStanzasCount); void continueAfterTLSEncrypted(); + void checkTrustOrFinish(Certificate::ref certificate, boost::shared_ptr<CertificateVerificationError> error); private: JID localJID; diff --git a/Swiften/Client/UnitTest/ClientSessionTest.cpp b/Swiften/Client/UnitTest/ClientSessionTest.cpp index 11e4992..74f3376 100644 --- a/Swiften/Client/UnitTest/ClientSessionTest.cpp +++ b/Swiften/Client/UnitTest/ClientSessionTest.cpp @@ -24,6 +24,8 @@ #include "Swiften/Elements/EnableStreamManagement.h" #include "Swiften/Elements/IQ.h" #include "Swiften/Elements/ResourceBind.h" +#include "Swiften/TLS/SimpleCertificate.h" +#include "Swiften/TLS/BlindCertificateTrustChecker.h" using namespace Swift; @@ -33,6 +35,7 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_TEST(testStartTLS); CPPUNIT_TEST(testStartTLS_ServerError); CPPUNIT_TEST(testStartTLS_ConnectError); + CPPUNIT_TEST(testStartTLS_InvalidIdentity); CPPUNIT_TEST(testAuthenticate); CPPUNIT_TEST(testAuthenticate_Unauthorized); CPPUNIT_TEST(testAuthenticate_NoValidAuthMechanisms); @@ -57,6 +60,11 @@ class ClientSessionTest : public CppUnit::TestFixture { server = boost::shared_ptr<MockSessionStream>(new MockSessionStream()); sessionFinishedReceived = false; needCredentials = false; + blindCertificateTrustChecker = new BlindCertificateTrustChecker(); + } + + void tearDown() { + delete blindCertificateTrustChecker; } void testStart_Error() { @@ -71,6 +79,7 @@ class ClientSessionTest : public CppUnit::TestFixture { void testStartTLS() { boost::shared_ptr<ClientSession> session(createSession()); + session->setCertificateTrustChecker(blindCertificateTrustChecker); session->start(); server->receiveStreamStart(); server->sendStreamStart(); @@ -116,6 +125,24 @@ class ClientSessionTest : public CppUnit::TestFixture { CPPUNIT_ASSERT(sessionFinishedError); } + void testStartTLS_InvalidIdentity() { + boost::shared_ptr<ClientSession> session(createSession()); + session->start(); + server->receiveStreamStart(); + server->sendStreamStart(); + server->sendStreamFeaturesWithStartTLS(); + server->receiveStartTLS(); + CPPUNIT_ASSERT(!server->tlsEncrypted); + server->sendTLSProceed(); + CPPUNIT_ASSERT(server->tlsEncrypted); + server->onTLSEncrypted(); + + CPPUNIT_ASSERT_EQUAL(ClientSession::Finished, session->getState()); + CPPUNIT_ASSERT(sessionFinishedReceived); + CPPUNIT_ASSERT(sessionFinishedError); + CPPUNIT_ASSERT_EQUAL(CertificateVerificationError::InvalidServerIdentity, boost::dynamic_pointer_cast<CertificateVerificationError>(sessionFinishedError)->getType()); + } + void testAuthenticate() { boost::shared_ptr<ClientSession> session(createSession()); session->start(); @@ -284,7 +311,7 @@ class ClientSessionTest : public CppUnit::TestFixture { } virtual Certificate::ref getPeerCertificate() const { - return Certificate::ref(); + return Certificate::ref(new SimpleCertificate()); } virtual boost::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const { @@ -429,6 +456,7 @@ class ClientSessionTest : public CppUnit::TestFixture { bool sessionFinishedReceived; bool needCredentials; boost::shared_ptr<Error> sessionFinishedError; + BlindCertificateTrustChecker* blindCertificateTrustChecker; }; CPPUNIT_TEST_SUITE_REGISTRATION(ClientSessionTest); |