summaryrefslogtreecommitdiffstats
blob: 7726c41517676f168930ce455d69d8120bcad6f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
 * Copyright (c) 2011 Soren Dreijer
 * Licensed under the simplified BSD license.
 * See Documentation/Licenses/BSD-simplified.txt for more information.
 */

#pragma once

#include "Swiften/Base/boost_bsignals.h"

#include "Swiften/TLS/TLSContext.h"
#include "Swiften/TLS/Schannel/SchannelUtil.h"
#include <Swiften/TLS/CertificateWithKey.h>
#include "Swiften/Base/ByteArray.h"

#define SECURITY_WIN32
#include <Windows.h>
#include <Schannel.h>
#include <security.h>
#include <schnlsp.h>

#include <boost/noncopyable.hpp>

namespace Swift 
{	
	// TLSContext subclass declared in terms of the Windows Schannel/SSPI types
	// pulled in by the includes at the top of this header.
	//
	// Inherits publicly from TLSContext and privately from boost::noncopyable,
	// so instances can be neither copy-constructed nor copy-assigned.
	//
	// NOTE(review): this header only shows declarations. Comments below that
	// describe behaviour are inferred from names and types and are marked as
	// such; confirm them against the implementation file.
	class SchannelContext : public TLSContext, boost::noncopyable 
	{
	public:
		// Shared-ownership handle type for this class.
		typedef boost::shared_ptr<SchannelContext> sp_t;

	public:
		SchannelContext();

		~SchannelContext();

		//
		// Virtual interface. The original grouping label here was "TLSContext",
		// so these are presumably overrides of TLSContext's virtuals (the base
		// class is not visible in this file -- TODO confirm).
		//

		// Presumably starts the TLS handshake -- confirm against implementation.
		virtual void	connect();
		// Supplies the client certificate (with key) to use. The meaning of the
		// bool result is not visible here; presumably true on success.
		virtual bool	setClientCertificate(CertificateWithKey::ref cert);

		// Entry points for the two data directions: bytes arriving from the
		// network side and bytes handed over by the application side.
		virtual void	handleDataFromNetwork(const SafeByteArray& data);
		virtual void	handleDataFromApplication(const SafeByteArray& data);

		// Accessors for the peer's certificate and for the verification error
		// associated with it, if any.
		virtual Certificate::ref getPeerCertificate() const;
		virtual CertificateVerificationError::ref getPeerCertificateVerificationError() const;

		// NOTE(review): presumably returns the TLS "Finished" message bytes
		// (e.g. for channel binding) -- confirm against implementation.
		virtual ByteArray getFinishMessage() const;

	private:
		// Internal helpers. Semantics below are inferred from names only.

		// Presumably fills in m_streamSizes once a context is established.
		void			determineStreamSizes();
		// Presumably advances the handshake using the given input bytes.
		void			continueHandshake(const SafeByteArray& data);
		// Presumably moves the context into the Error state and notifies
		// listeners -- TODO confirm.
		void			indicateError();

		// Raw-buffer output helpers: pData points at dataSize bytes to pass on
		// to the network side or the application side respectively.
		void			sendDataOnNetwork(const void* pData, size_t dataSize);
		void			forwardDataToApplication(const void* pData, size_t dataSize);

		// Presumably the Connected-state counterparts of the two public
		// handleDataFrom* entry points.
		void			decryptAndProcessData(const SafeByteArray& data);
		void			encryptAndSendData(const SafeByteArray& data);

		// Presumably appends incoming bytes to m_receivedData.
		void			appendNewData(const SafeByteArray& data);

	private:
		// Connection life-cycle states tracked in m_state. Transition rules
		// live in the implementation file and are not visible here.
		enum SchannelState
		{
			Start,
			Connecting,
			Connected,
			Error

		};

		// Current life-cycle state of this context.
		SchannelState		m_state;
		// Stored by value, whereas getPeerCertificateVerificationError()
		// returns a CertificateVerificationError::ref.
		CertificateVerificationError m_verificationError;

		// NOTE(review): purpose of this ULONG is not evident from the header.
		ULONG				m_secContext;
		// Credential and security-context handles, held via the Scoped*
		// wrapper types (presumably declared in SchannelUtil.h).
		ScopedCredHandle	m_credHandle;
		ScopedCtxtHandle	m_ctxtHandle;
		// NOTE(review): presumably the context attribute/request flags used
		// with the SSPI context calls -- confirm.
		DWORD				m_ctxtFlags;
		// SSPI stream size information (header/trailer/maximum message sizes)
		// for the established context.
		SecPkgContext_StreamSizes m_streamSizes;

		// Byte buffer of data received so far; see appendNewData().
		std::vector<char>	m_receivedData;

		// Certificate store handle plus the store/certificate names.
		// NOTE(review): m_my_cert_store is a raw HCERTSTORE rather than a
		// Scoped* wrapper; verify the implementation closes it.
		HCERTSTORE		m_my_cert_store;
		std::string		m_cert_store_name;
		std::string		m_cert_name;
	};
}