summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/TLS')
-rw-r--r--Swiften/TLS/CertificateFactory.cpp4
-rw-r--r--Swiften/TLS/CertificateFactory.h2
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp8
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificate.h2
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp8
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h2
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.cpp9
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.h2
-rw-r--r--Swiften/TLS/TLSContext.cpp2
-rw-r--r--Swiften/TLS/TLSContext.h2
10 files changed, 28 insertions, 13 deletions
diff --git a/Swiften/TLS/CertificateFactory.cpp b/Swiften/TLS/CertificateFactory.cpp
index aaf27d9..d4db3f4 100644
--- a/Swiften/TLS/CertificateFactory.cpp
+++ b/Swiften/TLS/CertificateFactory.cpp
@@ -23,9 +23,9 @@ namespace Swift {
23CertificateFactory::~CertificateFactory() { 23CertificateFactory::~CertificateFactory() {
24} 24}
25 25
26std::vector<std::unique_ptr<Certificate>> CertificateFactory::createCertificateChain(const ByteArray& /* data */) { 26std::vector<std::shared_ptr<Certificate>> CertificateFactory::createCertificateChain(const ByteArray& /* data */) {
27 assert(false); 27 assert(false);
28 return std::vector<std::unique_ptr<Certificate>>(); 28 return std::vector<std::shared_ptr<Certificate>>();
29} 29}
30 30
31PrivateKey::ref CertificateFactory::createPrivateKey(const SafeByteArray& data, boost::optional<SafeByteArray> password) { 31PrivateKey::ref CertificateFactory::createPrivateKey(const SafeByteArray& data, boost::optional<SafeByteArray> password) {
diff --git a/Swiften/TLS/CertificateFactory.h b/Swiften/TLS/CertificateFactory.h
index 619031c..873c36b 100644
--- a/Swiften/TLS/CertificateFactory.h
+++ b/Swiften/TLS/CertificateFactory.h
@@ -19,7 +19,7 @@ namespace Swift {
19 virtual ~CertificateFactory(); 19 virtual ~CertificateFactory();
20 20
21 virtual Certificate* createCertificateFromDER(const ByteArray& der) = 0; 21 virtual Certificate* createCertificateFromDER(const ByteArray& der) = 0;
22 virtual std::vector<std::unique_ptr<Certificate>> createCertificateChain(const ByteArray& data); 22 virtual std::vector<std::shared_ptr<Certificate>> createCertificateChain(const ByteArray& data);
23 PrivateKey::ref createPrivateKey(const SafeByteArray& data, boost::optional<SafeByteArray> password = boost::optional<SafeByteArray>()); 23 PrivateKey::ref createPrivateKey(const SafeByteArray& data, boost::optional<SafeByteArray> password = boost::optional<SafeByteArray>());
24 }; 24 };
25} 25}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
index 8d2d965..bb51428 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
@@ -37,6 +37,14 @@ OpenSSLCertificate::OpenSSLCertificate(const ByteArray& der) {
37 parse(); 37 parse();
38} 38}
39 39
40void OpenSSLCertificate::incrementReferenceCount() const {
41#if OPENSSL_VERSION_NUMBER >= 0x10100000L
42 X509_up_ref(cert.get());
43#else
44 CRYPTO_add(&(cert.get()->references), 1, CRYPTO_LOCK_EVP_PKEY);
45#endif
46}
47
40ByteArray OpenSSLCertificate::toDER() const { 48ByteArray OpenSSLCertificate::toDER() const {
41 ByteArray result; 49 ByteArray result;
42 if (!cert) { 50 if (!cert) {
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.h b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
index 186caea..64da82a 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
@@ -45,6 +45,8 @@ namespace Swift {
45 return cert; 45 return cert;
46 } 46 }
47 47
48 void incrementReferenceCount() const;
49
48 private: 50 private:
49 void parse(); 51 void parse();
50 52
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp
index 5eb626b..fd94ec8 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp
@@ -20,8 +20,8 @@ Certificate* OpenSSLCertificateFactory::createCertificateFromDER(const ByteArray
20 return new OpenSSLCertificate(der); 20 return new OpenSSLCertificate(der);
21} 21}
22 22
23std::vector<std::unique_ptr<Certificate>> OpenSSLCertificateFactory::createCertificateChain(const ByteArray& data) { 23std::vector<std::shared_ptr<Certificate>> OpenSSLCertificateFactory::createCertificateChain(const ByteArray& data) {
24 std::vector<std::unique_ptr<Certificate>> certificateChain; 24 std::vector<std::shared_ptr<Certificate>> certificateChain;
25 25
26 if (data.size() > std::numeric_limits<int>::max()) { 26 if (data.size() > std::numeric_limits<int>::max()) {
27 return certificateChain; 27 return certificateChain;
@@ -35,11 +35,11 @@ std::vector<std::unique_ptr<Certificate>> OpenSSLCertificateFactory::createCerti
35 auto x509certFromPEM = PEM_read_bio_X509(bio.get(), &openSSLCert, nullptr, nullptr); 35 auto x509certFromPEM = PEM_read_bio_X509(bio.get(), &openSSLCert, nullptr, nullptr);
36 if (x509certFromPEM && openSSLCert) { 36 if (x509certFromPEM && openSSLCert) {
37 std::shared_ptr<X509> x509Cert(openSSLCert, X509_free); 37 std::shared_ptr<X509> x509Cert(openSSLCert, X509_free);
38 certificateChain.emplace_back(std::make_unique<OpenSSLCertificate>(x509Cert)); 38 certificateChain.emplace_back(std::make_shared<OpenSSLCertificate>(x509Cert));
39 openSSLCert = nullptr; 39 openSSLCert = nullptr;
40 while ((x509certFromPEM = PEM_read_bio_X509(bio.get(), &openSSLCert, nullptr, nullptr)) != nullptr) { 40 while ((x509certFromPEM = PEM_read_bio_X509(bio.get(), &openSSLCert, nullptr, nullptr)) != nullptr) {
41 std::shared_ptr<X509> x509Cert(openSSLCert, X509_free); 41 std::shared_ptr<X509> x509Cert(openSSLCert, X509_free);
42 certificateChain.emplace_back(std::make_unique<OpenSSLCertificate>(x509Cert)); 42 certificateChain.emplace_back(std::make_shared<OpenSSLCertificate>(x509Cert));
43 openSSLCert = nullptr; 43 openSSLCert = nullptr;
44 } 44 }
45 } 45 }
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h
index 48e9b2c..a6974c8 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h
@@ -16,6 +16,6 @@ namespace Swift {
16 virtual ~OpenSSLCertificateFactory() override final; 16 virtual ~OpenSSLCertificateFactory() override final;
17 17
18 virtual Certificate* createCertificateFromDER(const ByteArray& der) override final; 18 virtual Certificate* createCertificateFromDER(const ByteArray& der) override final;
19 virtual std::vector<std::unique_ptr<Certificate>> createCertificateChain(const ByteArray& data) override final; 19 virtual std::vector<std::shared_ptr<Certificate>> createCertificateChain(const ByteArray& data) override final;
20 }; 20 };
21} 21}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 5c80976..32d6470 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -567,7 +567,7 @@ void OpenSSLContext::sendPendingDataToApplication() {
567 } 567 }
568} 568}
569 569
570bool OpenSSLContext::setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& certificateChain) { 570bool OpenSSLContext::setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& certificateChain) {
571 if (certificateChain.size() == 0) { 571 if (certificateChain.size() == 0) {
572 SWIFT_LOG(warning) << "Trying to load empty certificate chain." << std::endl; 572 SWIFT_LOG(warning) << "Trying to load empty certificate chain." << std::endl;
573 return false; 573 return false;
@@ -583,17 +583,22 @@ bool OpenSSLContext::setCertificateChain(std::vector<std::unique_ptr<Certificate
583 return false; 583 return false;
584 } 584 }
585 585
586 // Increment reference count on certificate so that it does not get freed when the SSL context is destroyed
587 openSSLCert->incrementReferenceCount();
588
586 if (certificateChain.size() > 1) { 589 if (certificateChain.size() > 1) {
587 for (auto certificate = certificateChain.begin() + 1; certificate != certificateChain.end(); ++certificate) { 590 for (auto certificate = certificateChain.begin() + 1; certificate != certificateChain.end(); ++certificate) {
588 auto openSSLCert = dynamic_cast<OpenSSLCertificate*>(certificate->get()); 591 auto openSSLCert = dynamic_cast<OpenSSLCertificate*>(certificate->get());
589 if (!openSSLCert) { 592 if (!openSSLCert) {
590 return false; 593 return false;
591 } 594 }
595
592 if (SSL_CTX_add_extra_chain_cert(context_.get(), openSSLCert->getInternalX509().get()) != 1) { 596 if (SSL_CTX_add_extra_chain_cert(context_.get(), openSSLCert->getInternalX509().get()) != 1) {
593 SWIFT_LOG(warning) << "Trying to load empty certificate chain." << std::endl; 597 SWIFT_LOG(warning) << "Trying to load empty certificate chain." << std::endl;
594 return false; 598 return false;
595 } 599 }
596 certificate->release(); 600
601 openSSLCert->incrementReferenceCount();
597 } 602 }
598 } 603 }
599 604
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 885b1fe..8eb5758 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -46,7 +46,7 @@ namespace Swift {
46 void connect() override final; 46 void connect() override final;
47 void connect(const std::string& requestHostname) override final; 47 void connect(const std::string& requestHostname) override final;
48 48
49 bool setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& certificateChain) override final; 49 bool setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& certificateChain) override final;
50 bool setPrivateKey(const PrivateKey::ref& privateKey) override final; 50 bool setPrivateKey(const PrivateKey::ref& privateKey) override final;
51 bool setClientCertificate(CertificateWithKey::ref cert) override final; 51 bool setClientCertificate(CertificateWithKey::ref cert) override final;
52 void setAbortTLSHandshake(bool abort) override final; 52 void setAbortTLSHandshake(bool abort) override final;
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 666ea7f..fd31c2d 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -21,7 +21,7 @@ void TLSContext::connect(const std::string& /* serverName */) {
21 assert(false); 21 assert(false);
22} 22}
23 23
24bool TLSContext::setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& /* certificateChain */) { 24bool TLSContext::setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& /* certificateChain */) {
25 assert(false); 25 assert(false);
26 return false; 26 return false;
27} 27}
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 85776d8..f2dbdce 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -28,7 +28,7 @@ namespace Swift {
28 virtual void connect() = 0; 28 virtual void connect() = 0;
29 virtual void connect(const std::string& serverName); 29 virtual void connect(const std::string& serverName);
30 30
31 virtual bool setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& /* certificateChain */); 31 virtual bool setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& /* certificateChain */);
32 virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */); 32 virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */);
33 33
34 virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0; 34 virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0;