diff options
Diffstat (limited to 'Swiften/TLS/OpenSSL')
-rw-r--r-- | Swiften/TLS/OpenSSL/Makefile.inc | 2 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.cpp | 159 | ||||
-rw-r--r-- | Swiften/TLS/OpenSSL/OpenSSLContext.h | 43 |
3 files changed, 204 insertions, 0 deletions
diff --git a/Swiften/TLS/OpenSSL/Makefile.inc b/Swiften/TLS/OpenSSL/Makefile.inc new file mode 100644 index 0000000..5bda771 --- /dev/null +++ b/Swiften/TLS/OpenSSL/Makefile.inc @@ -0,0 +1,2 @@ +SWIFTEN_SOURCES += \ + Swiften/TLS/OpenSSL/OpenSSLContext.cpp diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp new file mode 100644 index 0000000..3ce8d54 --- /dev/null +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp @@ -0,0 +1,159 @@ +#include <vector> +#include <openssl/err.h> +#include <openssl/pkcs12.h> + +#include "Swiften/TLS/OpenSSL/OpenSSLContext.h" +#include "Swiften/TLS/PKCS12Certificate.h" + +#pragma GCC diagnostic ignored "-Wold-style-cast" + +namespace Swift { + +static const int SSL_READ_BUFFERSIZE = 8192; + +void freeX509Stack(STACK_OF(X509)* stack) { + sk_X509_free(stack); +} + +OpenSSLContext::OpenSSLContext() : state_(Start), context_(0), handle_(0), readBIO_(0), writeBIO_(0) { + ensureLibraryInitialized(); + context_ = SSL_CTX_new(TLSv1_client_method()); +} + +OpenSSLContext::~OpenSSLContext() { + SSL_free(handle_); + SSL_CTX_free(context_); +} + +void OpenSSLContext::ensureLibraryInitialized() { + static bool isLibraryInitialized = false; + if (!isLibraryInitialized) { + SSL_load_error_strings(); + SSL_library_init(); + OpenSSL_add_all_algorithms(); + isLibraryInitialized = true; + } +} + +void OpenSSLContext::connect() { + handle_ = SSL_new(context_); + // Ownership of BIOs is ransferred + readBIO_ = BIO_new(BIO_s_mem()); + writeBIO_ = BIO_new(BIO_s_mem()); + SSL_set_bio(handle_, readBIO_, writeBIO_); + + state_ = Connecting; + doConnect(); +} + +void OpenSSLContext::doConnect() { + int connectResult = SSL_connect(handle_); + int error = SSL_get_error(handle_, connectResult); + switch (error) { + case SSL_ERROR_NONE: + state_ = Connected; + onConnected(); + //X509* x = SSL_get_peer_certificate(handle_); + //std::cout << x->name << std::endl; + //const char* comp = SSL_get_current_compression(handle_); + //std::cout << "Compression: " << SSL_COMP_get_name(comp) << std::endl; + break; + case SSL_ERROR_WANT_READ: + sendPendingDataToNetwork(); + break; + default: + state_ = Error; + onError(); + } +} + +void OpenSSLContext::sendPendingDataToNetwork() { + int size = BIO_pending(writeBIO_); + if (size > 0) { + ByteArray data; + data.resize(size); + BIO_read(writeBIO_, data.getData(), size); + onDataForNetwork(data); + } +} + +void OpenSSLContext::handleDataFromNetwork(const ByteArray& data) { + BIO_write(readBIO_, data.getData(), data.getSize()); + switch (state_) { + case Connecting: + doConnect(); + break; + case Connected: + sendPendingDataToApplication(); + break; + case Start: assert(false); break; + case Error: assert(false); break; + } +} + +void OpenSSLContext::handleDataFromApplication(const ByteArray& data) { + if (SSL_write(handle_, data.getData(), data.getSize()) >= 0) { + sendPendingDataToNetwork(); + } + else { + state_ = Error; + onError(); + } +} + +void OpenSSLContext::sendPendingDataToApplication() { + ByteArray data; + data.resize(SSL_READ_BUFFERSIZE); + int ret = SSL_read(handle_, data.getData(), data.getSize()); + while (ret > 0) { + data.resize(ret); + onDataForApplication(data); + data.resize(SSL_READ_BUFFERSIZE); + ret = SSL_read(handle_, data.getData(), data.getSize()); + } + if (ret < 0 && SSL_get_error(handle_, ret) != SSL_ERROR_WANT_READ) { + state_ = Error; + onError(); + } +} + +bool OpenSSLContext::setClientCertificate(const PKCS12Certificate& certificate) { + if (certificate.isNull()) { + return false; + } + + // Create a PKCS12 structure + BIO* bio = BIO_new(BIO_s_mem()); + BIO_write(bio, certificate.getData().getData(), certificate.getData().getSize()); + boost::shared_ptr<PKCS12> pkcs12(d2i_PKCS12_bio(bio, NULL), PKCS12_free); + BIO_free(bio); + if (!pkcs12) { + return false; + } + + // Parse PKCS12 + X509 *certPtr = 0; + EVP_PKEY* privateKeyPtr = 0; + STACK_OF(X509)* caCertsPtr = 0; + int result = PKCS12_parse(pkcs12.get(), certificate.getPassword().getUTF8Data(), &privateKeyPtr, &certPtr, &caCertsPtr); + if (result != 1) { + return false; + } + boost::shared_ptr<X509> cert(certPtr, X509_free); + boost::shared_ptr<EVP_PKEY> privateKey(privateKeyPtr, EVP_PKEY_free); + boost::shared_ptr<STACK_OF(X509)> caCerts(caCertsPtr, freeX509Stack); + + // Use the key & certificates + if (SSL_CTX_use_certificate(context_, cert.get()) != 1) { + return false; + } + if (SSL_CTX_use_PrivateKey(context_, privateKey.get()) != 1) { + return false; + } + for (int i = 0; i < sk_X509_num(caCerts.get()); ++i) { + SSL_CTX_add_extra_chain_cert(context_, sk_X509_value(caCerts.get(), i)); + } + return true; +} + +} diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h new file mode 100644 index 0000000..3e735df --- /dev/null +++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h @@ -0,0 +1,43 @@ +#include <openssl/ssl.h> +#include <boost/signal.hpp> +#include <boost/noncopyable.hpp> + +#include "Swiften/Base/ByteArray.h" + +namespace Swift { + class PKCS12Certificate; + + class OpenSSLContext : boost::noncopyable { + public: + OpenSSLContext(); + ~OpenSSLContext(); + + void connect(); + bool setClientCertificate(const PKCS12Certificate& cert); + + void handleDataFromNetwork(const ByteArray&); + void handleDataFromApplication(const ByteArray&); + + public: + boost::signal<void (const ByteArray&)> onDataForNetwork; + boost::signal<void (const ByteArray&)> onDataForApplication; + boost::signal<void ()> onError; + boost::signal<void ()> onConnected; + + private: + static void ensureLibraryInitialized(); + + void doConnect(); + void sendPendingDataToNetwork(); + void sendPendingDataToApplication(); + + private: + enum State { Start, Connecting, Connected, Error }; + + State state_; + SSL_CTX* context_; + SSL* handle_; + BIO* readBIO_; + BIO* writeBIO_; + }; +} |