summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'Swiften/TLS')
-rw-r--r--Swiften/TLS/CertificateFactory.cpp4
-rw-r--r--Swiften/TLS/CertificateFactory.h2
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp8
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificate.h2
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp8
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h2
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.cpp9
-rw-r--r--Swiften/TLS/OpenSSL/OpenSSLContext.h2
-rw-r--r--Swiften/TLS/TLSContext.cpp2
-rw-r--r--Swiften/TLS/TLSContext.h2
10 files changed, 28 insertions, 13 deletions
diff --git a/Swiften/TLS/CertificateFactory.cpp b/Swiften/TLS/CertificateFactory.cpp
index aaf27d9..d4db3f4 100644
--- a/Swiften/TLS/CertificateFactory.cpp
+++ b/Swiften/TLS/CertificateFactory.cpp
@@ -23,9 +23,9 @@ namespace Swift {
CertificateFactory::~CertificateFactory() {
}
-std::vector<std::unique_ptr<Certificate>> CertificateFactory::createCertificateChain(const ByteArray& /* data */) {
+std::vector<std::shared_ptr<Certificate>> CertificateFactory::createCertificateChain(const ByteArray& /* data */) {
assert(false);
- return std::vector<std::unique_ptr<Certificate>>();
+ return std::vector<std::shared_ptr<Certificate>>();
}
PrivateKey::ref CertificateFactory::createPrivateKey(const SafeByteArray& data, boost::optional<SafeByteArray> password) {
diff --git a/Swiften/TLS/CertificateFactory.h b/Swiften/TLS/CertificateFactory.h
index 619031c..873c36b 100644
--- a/Swiften/TLS/CertificateFactory.h
+++ b/Swiften/TLS/CertificateFactory.h
@@ -19,7 +19,7 @@ namespace Swift {
virtual ~CertificateFactory();
virtual Certificate* createCertificateFromDER(const ByteArray& der) = 0;
- virtual std::vector<std::unique_ptr<Certificate>> createCertificateChain(const ByteArray& data);
+ virtual std::vector<std::shared_ptr<Certificate>> createCertificateChain(const ByteArray& data);
PrivateKey::ref createPrivateKey(const SafeByteArray& data, boost::optional<SafeByteArray> password = boost::optional<SafeByteArray>());
};
}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
index 8d2d965..bb51428 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.cpp
@@ -37,6 +37,14 @@ OpenSSLCertificate::OpenSSLCertificate(const ByteArray& der) {
parse();
}
+void OpenSSLCertificate::incrementReferenceCount() const {
+#if OPENSSL_VERSION_NUMBER >= 0x10100000L
+ X509_up_ref(cert.get());
+#else
+ CRYPTO_add(&(cert.get()->references), 1, CRYPTO_LOCK_EVP_PKEY);
+#endif
+}
+
ByteArray OpenSSLCertificate::toDER() const {
ByteArray result;
if (!cert) {
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificate.h b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
index 186caea..64da82a 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificate.h
@@ -45,6 +45,8 @@ namespace Swift {
return cert;
}
+ void incrementReferenceCount() const;
+
private:
void parse();
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp
index 5eb626b..fd94ec8 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.cpp
@@ -20,8 +20,8 @@ Certificate* OpenSSLCertificateFactory::createCertificateFromDER(const ByteArray
return new OpenSSLCertificate(der);
}
-std::vector<std::unique_ptr<Certificate>> OpenSSLCertificateFactory::createCertificateChain(const ByteArray& data) {
- std::vector<std::unique_ptr<Certificate>> certificateChain;
+std::vector<std::shared_ptr<Certificate>> OpenSSLCertificateFactory::createCertificateChain(const ByteArray& data) {
+ std::vector<std::shared_ptr<Certificate>> certificateChain;
if (data.size() > std::numeric_limits<int>::max()) {
return certificateChain;
@@ -35,11 +35,11 @@ std::vector<std::unique_ptr<Certificate>> OpenSSLCertificateFactory::createCerti
auto x509certFromPEM = PEM_read_bio_X509(bio.get(), &openSSLCert, nullptr, nullptr);
if (x509certFromPEM && openSSLCert) {
std::shared_ptr<X509> x509Cert(openSSLCert, X509_free);
- certificateChain.emplace_back(std::make_unique<OpenSSLCertificate>(x509Cert));
+ certificateChain.emplace_back(std::make_shared<OpenSSLCertificate>(x509Cert));
openSSLCert = nullptr;
while ((x509certFromPEM = PEM_read_bio_X509(bio.get(), &openSSLCert, nullptr, nullptr)) != nullptr) {
std::shared_ptr<X509> x509Cert(openSSLCert, X509_free);
- certificateChain.emplace_back(std::make_unique<OpenSSLCertificate>(x509Cert));
+ certificateChain.emplace_back(std::make_shared<OpenSSLCertificate>(x509Cert));
openSSLCert = nullptr;
}
}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h
index 48e9b2c..a6974c8 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLCertificateFactory.h
@@ -16,6 +16,6 @@ namespace Swift {
virtual ~OpenSSLCertificateFactory() override final;
virtual Certificate* createCertificateFromDER(const ByteArray& der) override final;
- virtual std::vector<std::unique_ptr<Certificate>> createCertificateChain(const ByteArray& data) override final;
+ virtual std::vector<std::shared_ptr<Certificate>> createCertificateChain(const ByteArray& data) override final;
};
}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
index 5c80976..32d6470 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.cpp
@@ -567,7 +567,7 @@ void OpenSSLContext::sendPendingDataToApplication() {
}
}
-bool OpenSSLContext::setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& certificateChain) {
+bool OpenSSLContext::setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& certificateChain) {
if (certificateChain.size() == 0) {
SWIFT_LOG(warning) << "Trying to load empty certificate chain." << std::endl;
return false;
@@ -583,17 +583,22 @@ bool OpenSSLContext::setCertificateChain(std::vector<std::unique_ptr<Certificate
return false;
}
+ // Increment reference count on certificate so that it does not get freed when the SSL context is destroyed
+ openSSLCert->incrementReferenceCount();
+
if (certificateChain.size() > 1) {
for (auto certificate = certificateChain.begin() + 1; certificate != certificateChain.end(); ++certificate) {
auto openSSLCert = dynamic_cast<OpenSSLCertificate*>(certificate->get());
if (!openSSLCert) {
return false;
}
+
if (SSL_CTX_add_extra_chain_cert(context_.get(), openSSLCert->getInternalX509().get()) != 1) {
SWIFT_LOG(warning) << "Trying to load empty certificate chain." << std::endl;
return false;
}
- certificate->release();
+
+ openSSLCert->incrementReferenceCount();
}
}
diff --git a/Swiften/TLS/OpenSSL/OpenSSLContext.h b/Swiften/TLS/OpenSSL/OpenSSLContext.h
index 885b1fe..8eb5758 100644
--- a/Swiften/TLS/OpenSSL/OpenSSLContext.h
+++ b/Swiften/TLS/OpenSSL/OpenSSLContext.h
@@ -46,7 +46,7 @@ namespace Swift {
void connect() override final;
void connect(const std::string& requestHostname) override final;
- bool setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& certificateChain) override final;
+ bool setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& certificateChain) override final;
bool setPrivateKey(const PrivateKey::ref& privateKey) override final;
bool setClientCertificate(CertificateWithKey::ref cert) override final;
void setAbortTLSHandshake(bool abort) override final;
diff --git a/Swiften/TLS/TLSContext.cpp b/Swiften/TLS/TLSContext.cpp
index 666ea7f..fd31c2d 100644
--- a/Swiften/TLS/TLSContext.cpp
+++ b/Swiften/TLS/TLSContext.cpp
@@ -21,7 +21,7 @@ void TLSContext::connect(const std::string& /* serverName */) {
assert(false);
}
-bool TLSContext::setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& /* certificateChain */) {
+bool TLSContext::setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& /* certificateChain */) {
assert(false);
return false;
}
diff --git a/Swiften/TLS/TLSContext.h b/Swiften/TLS/TLSContext.h
index 85776d8..f2dbdce 100644
--- a/Swiften/TLS/TLSContext.h
+++ b/Swiften/TLS/TLSContext.h
@@ -28,7 +28,7 @@ namespace Swift {
virtual void connect() = 0;
virtual void connect(const std::string& serverName);
- virtual bool setCertificateChain(std::vector<std::unique_ptr<Certificate>>&& /* certificateChain */);
+ virtual bool setCertificateChain(const std::vector<std::shared_ptr<Certificate>>& /* certificateChain */);
virtual bool setPrivateKey(const PrivateKey::ref& /* privateKey */);
virtual bool setClientCertificate(CertificateWithKey::ref cert) = 0;