From 9eaa75b907a515a65ccb2002632fbf2f30c5aee8 Mon Sep 17 00:00:00 2001
From: Tobias Markmann <tm@ayena.de>
Date: Fri, 5 Jan 2018 16:45:34 +0100
Subject: Modernize OpenSSL crypto backend

* use std::unique_ptr for memory management of dynamic OpenSSL
  objects
* use an initializer class and static instance of it to correctly
  initialize/finalize OpenSSL on first use
* use enum class instead of simple enum for state
* use nullptr instead of NULL

Test-Information:

Builds and tests pass on macOS 10.13.2 with clang-trunk and
ASAN.

Change-Id: I346f14e21c34871c1900a8e1ac000450770a0bbe

diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
index 17ac8cc..8d2d965 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
@@ -30,7 +30,7 @@ OpenSSLCertificate::OpenSSLCertificate(const ByteArray& der) {
 #else
     const unsigned char* p = vecptr(der);
 #endif
-    cert = std::shared_ptr<X509>(d2i_X509(NULL, &p, der.size()), X509_free);
+    cert = std::shared_ptr<X509>(d2i_X509(nullptr, &p, der.size()), X509_free);
     if (!cert) {
         SWIFT_LOG(warning) << "Error creating certificate from DER data" << std::endl;
     }
@@ -42,7 +42,7 @@ ByteArray OpenSSLCertificate::toDER() const {
     if (!cert) {
         return result;
     }
-    result.resize(i2d_X509(cert.get(), NULL));
+    result.resize(i2d_X509(cert.get(), nullptr));
     unsigned char* p = vecptr(result);
     i2d_X509(cert.get(), &p);
     return result;
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 0805917..6f15edf 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2010-2016 Isode Limited.
+ * Copyright (c) 2010-2018 Isode Limited.
  * All rights reserved.
  * See the COPYING file for more information.
  */
@@ -10,10 +10,12 @@
 #include <wincrypt.h>
 #endif
 
+#include <cassert>
+#include <memory>
 #include <vector>
+
 #include <openssl/err.h>
 #include <openssl/pkcs12.h>
-#include <memory>
 
 #if defined(SWIFTEN_PLATFORM_MACOSX)
 #include <Security/Security.h>
@@ -39,10 +41,32 @@ static void freeX509Stack(STACK_OF(X509)* stack) {
     sk_X509_free(stack);
 }
 
-OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readBIO_(0), writeBIO_(0) {
+namespace {
+    class OpenSSLInitializerFinalizer {
+        public:
+            OpenSSLInitializerFinalizer() {
+                SSL_load_error_strings();
+                SSL_library_init();
+                OpenSSL_add_all_algorithms();
+
+                // Disable compression
+                /*
+                STACK_OF(SSL_COMP)* compressionMethods = SSL_COMP_get_compression_methods();
+                sk_SSL_COMP_zero(compressionMethods);*/
+            }
+
+            ~OpenSSLInitializerFinalizer() {
+                EVP_cleanup();
+            }
+
+            OpenSSLInitializerFinalizer(const OpenSSLInitializerFinalizer &) = delete;
+    };
+}
+
+OpenSSLContext::OpenSSLContext() : state_(State::Start) {
     ensureLibraryInitialized();
-    context_ = SSL_CTX_new(SSLv23_client_method());
-    SSL_CTX_set_options(context_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
+    context_ = std::unique_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()));
+    SSL_CTX_set_options(context_.get(), SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
 
     // TODO: implement CRL checking
     // TODO: download CRL (HTTP transport)
@@ -52,7 +76,7 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB
     // TODO: handle OCSP stapling see https://www.rfc-editor.org/rfc/rfc4366.txt
     // Load system certs
 #if defined(SWIFTEN_PLATFORM_WINDOWS)
-    X509_STORE* store = SSL_CTX_get_cert_store(context_);
+    X509_STORE* store = SSL_CTX_get_cert_store(context_.get());
     HCERTSTORE systemStore = CertOpenSystemStore(0, "ROOT");
     if (systemStore) {
         PCCERT_CONTEXT certContext = NULL;
@@ -68,7 +92,7 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB
         }
     }
 #elif !defined(SWIFTEN_PLATFORM_MACOSX)
-    SSL_CTX_set_default_verify_paths(context_);
+    SSL_CTX_set_default_verify_paths(context_.get());
 #elif defined(SWIFTEN_PLATFORM_MACOSX) && !defined(SWIFTEN_PLATFORM_IPHONE)
     // On Mac OS X 10.5 (OpenSSL < 0.9.8), OpenSSL does not automatically look in the system store.
     // On Mac OS X 10.6 (OpenSSL >= 0.9.8), OpenSSL *does* look in the system store to determine trust.
@@ -76,7 +100,7 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB
     // the certificates first. See
     //        http://opensource.apple.com/source/OpenSSL098/OpenSSL098-27/src/crypto/x509/x509_vfy_apple.c
     // to understand why. We therefore add all certs from the system store ourselves.
-    X509_STORE* store = SSL_CTX_get_cert_store(context_);
+    X509_STORE* store = SSL_CTX_get_cert_store(context_.get());
     CFArrayRef anchorCertificates;
     if (SecTrustCopyAnchorCertificates(&anchorCertificates) == 0) {
         for (int i = 0; i < CFArrayGetCount(anchorCertificates); ++i) {
@@ -99,51 +123,37 @@ OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readB
 }
 
 OpenSSLContext::~OpenSSLContext() {
-    SSL_free(handle_);
-    SSL_CTX_free(context_);
 }
 
 void OpenSSLContext::ensureLibraryInitialized() {
-    static bool isLibraryInitialized = false;
-    if (!isLibraryInitialized) {
-        SSL_load_error_strings();
-        SSL_library_init();
-        OpenSSL_add_all_algorithms();
-
-        // Disable compression
-        /*
-        STACK_OF(SSL_COMP)* compressionMethods = SSL_COMP_get_compression_methods();
-        sk_SSL_COMP_zero(compressionMethods);*/
-
-        isLibraryInitialized = true;
-    }
+    static OpenSSLInitializerFinalizer openSSLInit;
 }
 
 void OpenSSLContext::connect() {
-    handle_ = SSL_new(context_);
-    if (handle_ == nullptr) {
-        state_ = Error;
+    handle_ = std::unique_ptr<SSL>(SSL_new(context_.get()));
+    if (!handle_) {
+        state_ = State::Error;
         onError(std::make_shared<TLSError>());
         return;
     }
 
-    // Ownership of BIOs is ransferred
+    // Ownership of BIOs is transferred
     readBIO_ = BIO_new(BIO_s_mem());
     writeBIO_ = BIO_new(BIO_s_mem());
-    SSL_set_bio(handle_, readBIO_, writeBIO_);
+    SSL_set_bio(handle_.get(), readBIO_, writeBIO_);
 
-    state_ = Connecting;
+    state_ = State::Connecting;
     doConnect();
 }
 
 void OpenSSLContext::doConnect() {
-    int connectResult = SSL_connect(handle_);
-    int error = SSL_get_error(handle_, connectResult);
+    int connectResult = SSL_connect(handle_.get());
+    int error = SSL_get_error(handle_.get(), connectResult);
     switch (error) {
         case SSL_ERROR_NONE: {
-            state_ = Connected;
+            state_ = State::Connected;
             //std::cout << x->name << std::endl;
-            //const char* comp = SSL_get_current_compression(handle_);
+            //const char* comp = SSL_get_current_compression(handle_.get());
             //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl;
             onConnected();
             break;
@@ -152,7 +162,7 @@ void OpenSSLContext::doConnect() {
             sendPendingDataToNetwork();
             break;
         default:
-            state_ = Error;
+            state_ = State::Error;
             onError(std::make_shared<TLSError>());
     }
 }
@@ -170,23 +180,23 @@ void OpenSSLContext::sendPendingDataToNetwork() {
 void OpenSSLContext::handleDataFromNetwork(const SafeByteArray& data) {
     BIO_write(readBIO_, vecptr(data), data.size());
     switch (state_) {
-        case Connecting:
+        case State::Connecting:
             doConnect();
             break;
-        case Connected:
+        case State::Connected:
             sendPendingDataToApplication();
             break;
-        case Start: assert(false); break;
-        case Error: /*assert(false);*/ break;
+        case State::Start: assert(false); break;
+        case State::Error: /*assert(false);*/ break;
     }
 }
 
 void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) {
-    if (SSL_write(handle_, vecptr(data), data.size()) >= 0) {
+    if (SSL_write(handle_.get(), vecptr(data), data.size()) >= 0) {
         sendPendingDataToNetwork();
     }
     else {
-        state_ = Error;
+        state_ = State::Error;
         onError(std::make_shared<TLSError>());
     }
 }
@@ -194,15 +204,15 @@ void OpenSSLContext::handleDataFromApplication(const SafeByteArray& data) {
 void OpenSSLContext::sendPendingDataToApplication() {
     SafeByteArray data;
     data.resize(SSL_READ_BUFFERSIZE);
-    int ret = SSL_read(handle_, vecptr(data), data.size());
+    int ret = SSL_read(handle_.get(), vecptr(data), data.size());
     while (ret > 0) {
         data.resize(ret);
         onDataForApplication(data);
         data.resize(SSL_READ_BUFFERSIZE);
-        ret = SSL_read(handle_, vecptr(data), data.size());
+        ret = SSL_read(handle_.get(), vecptr(data), data.size());
     }
-    if (ret < 0 && SSL_get_error(handle_, ret) != SSL_ERROR_WANT_READ) {
-        state_ = Error;
+    if (ret < 0 && SSL_get_error(handle_.get(), ret) != SSL_ERROR_WANT_READ) {
+        state_ = State::Error;
         onError(std::make_shared<TLSError>());
     }
 }
@@ -216,16 +226,16 @@ bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {
     // Create a PKCS12 structure
     BIO* bio = BIO_new(BIO_s_mem());
     BIO_write(bio, vecptr(pkcs12Certificate->getData()), pkcs12Certificate->getData().size());
-    std::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free);
+    std::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, nullptr), PKCS12_free);
     BIO_free(bio);
     if (!pkcs12) {
         return false;
     }
 
     // Parse PKCS12
-    X509 *certPtr = 0;
-    EVP_PKEY* privateKeyPtr = 0;
-    STACK_OF(X509)* caCertsPtr = 0;
+    X509 *certPtr = nullptr;
+    EVP_PKEY* privateKeyPtr = nullptr;
+    STACK_OF(X509)* caCertsPtr = nullptr;
     SafeByteArray password(pkcs12Certificate->getPassword());
     password.push_back(0);
     int result = PKCS12_parse(pkcs12.get(), reinterpret_cast<const char*>(vecptr(password)), &privateKeyPtr, &certPtr, &caCertsPtr);
@@ -237,21 +247,21 @@ bool OpenSSLContext::setClientCertificate(CertificateWithKey::ref certificate) {
     std::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack);
 
     // Use the key & certificates
-    if (SSL_CTX_use_certificate(context_, cert.get()) != 1) {
+    if (SSL_CTX_use_certificate(context_.get(), cert.get()) != 1) {
         return false;
     }
-    if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) {
+    if (SSL_CTX_use_PrivateKey(context_.get(), privateKey.get()) != 1) {
         return false;
     }
     for (int i = 0;  i < sk_X509_num(caCerts.get()); ++i) {
-        SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i));
+        SSL_CTX_add_extra_chain_cert(context_.get(), sk_X509_value(caCerts.get(), i));
     }
     return true;
 }
 
 std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const {
     std::vector<Certificate::ref> result;
-    STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_);
+    STACK_OF(X509)* chain = SSL_get_peer_cert_chain(handle_.get());
     for (int i = 0; i < sk_X509_num(chain); ++i) {
         std::shared_ptr<X509> x509Cert(X509_dup(sk_X509_value(chain, i)), X509_free);
 
@@ -262,7 +272,7 @@ std::vector<Certificate::ref> OpenSSLContext::getPeerCertificateChain() const {
 }
 
 std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificateVerificationError() const {
-    int verifyResult = SSL_get_verify_result(handle_);
+    int verifyResult = SSL_get_verify_result(handle_.get());
     if (verifyResult != X509_V_OK) {
         return std::make_shared<CertificateVerificationError>(getVerificationErrorTypeForResult(verifyResult));
     }
@@ -274,7 +284,7 @@ std::shared_ptr<CertificateVerificationError> OpenSSLContext::getPeerCertificate
 ByteArray OpenSSLContext::getFinishMessage() const {
     ByteArray data;
     data.resize(MAX_FINISHED_SIZE);
-    size_t size = SSL_get_finished(handle_, vecptr(data), data.size());
+    size_t size = SSL_get_finished(handle_.get(), vecptr(data), data.size());
     data.resize(size);
     return data;
 }
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index e75b3c9..49ada51 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -1,11 +1,13 @@
 /*
- * Copyright (c) 2010-2016 Isode Limited.
+ * Copyright (c) 2010-2018 Isode Limited.
  * All rights reserved.
  * See the COPYING file for more information.
  */
 
 #pragma once
 
+#include <memory>
+
 #include <boost/noncopyable.hpp>
 #include <boost/signals2.hpp>
 
@@ -15,23 +17,40 @@
 #include <Swiften/TLS/CertificateWithKey.h>
 #include <Swiften/TLS/TLSContext.h>
 
-namespace Swift {
+namespace std {
+    template<>
+    class default_delete<SSL_CTX> {
+    public:
+        void operator()(SSL_CTX *ptr) {
+            SSL_CTX_free(ptr);
+        }
+    };
 
+    template<>
+    class default_delete<SSL> {
+    public:
+        void operator()(SSL *ptr) {
+            SSL_free(ptr);
+        }
+    };
+}
+
+namespace Swift {
     class OpenSSLContext : public TLSContext, boost::noncopyable {
         public:
             OpenSSLContext();
-            virtual ~OpenSSLContext();
+            virtual ~OpenSSLContext() override final;
 
-            void connect();
-            bool setClientCertificate(CertificateWithKey::ref cert);
+            void connect() override final;
+            bool setClientCertificate(CertificateWithKey::ref cert) override final;
 
-            void handleDataFromNetwork(const SafeByteArray&);
-            void handleDataFromApplication(const SafeByteArray&);
+            void handleDataFromNetwork(const SafeByteArray&) override final;
+            void handleDataFromApplication(const SafeByteArray&) override final;
 
-            std::vector<Certificate::ref> getPeerCertificateChain() const;
-            std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const;
+            std::vector<Certificate::ref> getPeerCertificateChain() const override final;
+            std::shared_ptr<CertificateVerificationError> getPeerCertificateVerificationError() const override final;
 
-            virtual ByteArray getFinishMessage() const;
+            virtual ByteArray getFinishMessage() const override final;
 
         private:
             static void ensureLibraryInitialized();
@@ -43,12 +62,12 @@ namespace Swift {
             void sendPendingDataToApplication();
 
         private:
-            enum State { Start, Connecting, Connected, Error };
+            enum class State { Start, Connecting, Connected, Error };
 
             State state_;
-            SSL_CTX* context_;
-            SSL* handle_;
-            BIO* readBIO_;
-            BIO* writeBIO_;
+            std::unique_ptr<SSL_CTX> context_;
+            std::unique_ptr<SSL> handle_;
+            BIO* readBIO_ = nullptr;
+            BIO* writeBIO_ = nullptr;
     };
 }
-- 
cgit v0.10.2-6-g49f6