/*
 * Copyright (c) 2011 Soren Dreijer
 * Licensed under the simplified BSD license.
 * See Documentation/Licenses/BSD-simplified.txt for more information.
 */

/*
 * Copyright (c) 2016 Isode Limited.
 * All rights reserved.
 * See the COPYING file for more information.
 */

#pragma once

#include <memory>

#define SECURITY_WIN32
#include <Windows.h>
#include <Schannel.h>
#include <security.h>
#include <schnlsp.h>

#include <boost/noncopyable.hpp>

namespace Swift
{
    //
    // Convenience wrapper around the Schannel CredHandle struct.
    //
    class ScopedCredHandle
    {
    private:
        struct HandleContext
        {
            HandleContext()
            {
                ZeroMemory(&m_h, sizeof(m_h));
            }

            HandleContext(const CredHandle& h)
            {
                memcpy(&m_h, &h, sizeof(m_h));
            }

            ~HandleContext()
            {
                ::FreeCredentialsHandle(&m_h);
            }

            CredHandle m_h;
        };

    public:
        ScopedCredHandle()
        : m_pHandle( new HandleContext )
        {
        }

        explicit ScopedCredHandle(const CredHandle& h)
        : m_pHandle( new HandleContext(h) )
        {
        }

        // Copy constructor
        explicit ScopedCredHandle(const ScopedCredHandle& rhs)
        {
            m_pHandle = rhs.m_pHandle;
        }

        ~ScopedCredHandle()
        {
            m_pHandle.reset();
        }

        PCredHandle Reset()
        {
            CloseHandle();
            return &m_pHandle->m_h;
        }

        operator PCredHandle() const
        {
            return &m_pHandle->m_h;
        }

        ScopedCredHandle& operator=(const ScopedCredHandle& sh)
        {
            // Only update the internal handle if it's different
            if (&m_pHandle->m_h != &sh.m_pHandle->m_h)
            {
                m_pHandle = sh.m_pHandle;
            }

            return *this;
        }

        void CloseHandle()
        {
            m_pHandle.reset( new HandleContext );
        }

    private:
        std::shared_ptr<HandleContext> m_pHandle;
    };

    //------------------------------------------------------------------------

    //
    // Convenience wrapper around the Schannel CtxtHandle struct.
    //
    class ScopedCtxtHandle
    {
    private:
        struct HandleContext
        {
            HandleContext()
            {
                ZeroMemory(&m_h, sizeof(m_h));
            }

            ~HandleContext()
            {
                ::DeleteSecurityContext(&m_h);
            }

            CtxtHandle m_h;
        };

    public:
        ScopedCtxtHandle()
        : m_pHandle( new HandleContext )
        {
        }

        explicit ScopedCtxtHandle(CredHandle h)
        : m_pHandle( new HandleContext )
        {
        }

        // Copy constructor
        explicit ScopedCtxtHandle(const ScopedCtxtHandle& rhs)
        {
            m_pHandle = rhs.m_pHandle;
        }

        ~ScopedCtxtHandle()
        {
            m_pHandle.reset();
        }

        PCredHandle Reset()
        {
            CloseHandle();
            return &m_pHandle->m_h;
        }

        operator PCredHandle() const
        {
            return &m_pHandle->m_h;
        }

        ScopedCtxtHandle& operator=(const ScopedCtxtHandle& sh)
        {
            // Only update the internal handle if it's different
            if (&m_pHandle->m_h != &sh.m_pHandle->m_h)
            {
                m_pHandle = sh.m_pHandle;
            }

            return *this;
        }

        void CloseHandle()
        {
            m_pHandle.reset( new HandleContext );
        }

    private:
        std::shared_ptr<HandleContext> m_pHandle;
    };

    //------------------------------------------------------------------------

    //
    // Convenience wrapper around the SSPI SecBuffer struct (as used by Schannel).
    //
    class ScopedSecBuffer : boost::noncopyable
    {
    public:
        /// Wraps @p pSecBuffer. The SecBuffer struct itself is not owned; only
        /// the memory its pvBuffer points to is released on destruction.
        ScopedSecBuffer(PSecBuffer pSecBuffer)
        : m_pSecBuffer(pSecBuffer)
        {
        }

        ~ScopedSecBuffer()
        {
            // Free the single buffer this object wraps. A null wrapper, or a
            // buffer whose pvBuffer was never filled in, is left alone.
            if (m_pSecBuffer && m_pSecBuffer->pvBuffer)
                FreeContextBuffer(m_pSecBuffer->pvBuffer);
        }

        /// The wrapped SecBuffer pointer (not owned by the caller).
        PSecBuffer AsPtr()
        {
            return m_pSecBuffer;
        }

        /// Member access to the wrapped SecBuffer.
        PSecBuffer operator->()
        {
            return m_pSecBuffer;
        }

    private:
        PSecBuffer m_pSecBuffer;
    };

    //------------------------------------------------------------------------

    //
    // Convenience wrapper around the CryptoAPI PCCERT_CONTEXT.
    //
    class ScopedCertContext
    {
    private:
        struct HandleContext
        {
            HandleContext()
            : m_pCertCtxt(NULL)
            {
            }

            HandleContext(PCCERT_CONTEXT pCert)
            : m_pCertCtxt(pCert)
            {
            }

            ~HandleContext()
            {
                if (m_pCertCtxt)
                    CertFreeCertificateContext(m_pCertCtxt);
            }

            PCCERT_CONTEXT m_pCertCtxt;
        };

    public:
        ScopedCertContext()
        : m_pHandle( new HandleContext )
        {
        }

        explicit ScopedCertContext(PCCERT_CONTEXT pCert)
        : m_pHandle( new HandleContext(pCert) )
        {
        }

        // Copy constructor
        ScopedCertContext(const ScopedCertContext& rhs)
        {
            m_pHandle = rhs.m_pHandle;
        }

        ~ScopedCertContext()
        {
            m_pHandle.reset();
        }

        PCCERT_CONTEXT* Reset()
        {
            FreeContext();
            return &m_pHandle->m_pCertCtxt;
        }

        operator PCCERT_CONTEXT() const
        {
            return m_pHandle->m_pCertCtxt;
        }

        PCCERT_CONTEXT* GetPointer() const
        {
            return &m_pHandle->m_pCertCtxt;
        }

        PCCERT_CONTEXT operator->() const
        {
            return m_pHandle->m_pCertCtxt;
        }

        ScopedCertContext& operator=(const ScopedCertContext& sh)
        {
            // Only update the internal handle if it's different
            if (&m_pHandle->m_pCertCtxt != &sh.m_pHandle->m_pCertCtxt)
            {
                m_pHandle = sh.m_pHandle;
            }

            return *this;
        }

        ScopedCertContext& operator=(PCCERT_CONTEXT pCertCtxt)
        {
            // Only update the internal handle if it's different
            if (m_pHandle && m_pHandle->m_pCertCtxt != pCertCtxt)
                m_pHandle.reset( new HandleContext(pCertCtxt) );

            return *this;
        }

        void FreeContext()
        {
            m_pHandle.reset( new HandleContext );
        }

    private:
        std::shared_ptr<HandleContext> m_pHandle;
    };

    //------------------------------------------------------------------------

    //
    // Convenience wrapper around the CryptoAPI HCERTSTORE.
    //
    class ScopedCertStore : boost::noncopyable
    {
    public:
        /// Takes ownership of @p hCertStore. A NULL store is accepted and simply
        /// not closed.
        ScopedCertStore(HCERTSTORE hCertStore)
        : m_hCertStore(hCertStore)
        {
        }

        ~ScopedCertStore()
        {
            if (!m_hCertStore)
                return;

            // CERT_CLOSE_STORE_FORCE_FLAG forcefully frees all memory related to the
            // store, so every CertContext opened via this store must already have been
            // closed by the time we get here.
            CertCloseStore(m_hCertStore, CERT_CLOSE_STORE_FORCE_FLAG);
        }

        /// The wrapped store handle (still owned by this object).
        operator HCERTSTORE() const
        {
            return m_hCertStore;
        }

    private:
        HCERTSTORE m_hCertStore;
    };

    //------------------------------------------------------------------------

    //
    // Convenience wrapper around the CryptoAPI CERT_CHAIN_CONTEXT.
    //
    class ScopedCertChainContext
    {
    private:
        // Reference-counted owner of a single certificate chain context. A
        // non-null chain is released with CertFreeCertificateChain() on destruction.
        struct HandleContext
        {
            HandleContext()
            : m_pCertChainCtxt(nullptr)
            {
            }

            HandleContext(PCCERT_CHAIN_CONTEXT pCert)
            : m_pCertChainCtxt(pCert)
            {
            }

            ~HandleContext()
            {
                if (m_pCertChainCtxt == nullptr)
                    return;

                CertFreeCertificateChain(m_pCertChainCtxt);
            }

            PCCERT_CHAIN_CONTEXT m_pCertChainCtxt;
        };

    public:
        /// Creates a wrapper holding no chain context.
        ScopedCertChainContext()
        : m_pHandle(std::make_shared<HandleContext>())
        {
        }

        /// Takes ownership of @p pCert.
        explicit ScopedCertChainContext(PCCERT_CHAIN_CONTEXT pCert)
        : m_pHandle(std::make_shared<HandleContext>(pCert))
        {
        }

        /// Shares ownership of @p rhs's chain. The shared_ptr member releases
        /// this wrapper's reference automatically on destruction.
        ScopedCertChainContext(const ScopedCertChainContext& rhs)
        : m_pHandle(rhs.m_pHandle)
        {
        }

        /// Drops the current chain and returns the address of a fresh, empty
        /// slot that the caller can fill in.
        PCCERT_CHAIN_CONTEXT* Reset()
        {
            FreeContext();
            PCCERT_CHAIN_CONTEXT* slot = &m_pHandle->m_pCertChainCtxt;
            return slot;
        }

        /// The wrapped chain context (may be NULL).
        operator PCCERT_CHAIN_CONTEXT() const
        {
            return m_pHandle->m_pCertChainCtxt;
        }

        /// Member access to the wrapped chain context.
        PCCERT_CHAIN_CONTEXT operator->() const
        {
            return m_pHandle->m_pCertChainCtxt;
        }

        /// Shares @p sh's chain unless both already share the same owner.
        ScopedCertChainContext& operator=(const ScopedCertChainContext& sh)
        {
            const bool sameOwner = (m_pHandle.get() == sh.m_pHandle.get());
            if (!sameOwner)
                m_pHandle = sh.m_pHandle;

            return *this;
        }

        /// Replaces the wrapped chain with an empty one. The previous chain is
        /// freed once no other copy shares it.
        void FreeContext()
        {
            m_pHandle = std::make_shared<HandleContext>();
        }

    private:
        std::shared_ptr<HandleContext> m_pHandle;
    };
}