
/*  ************************************************************************  *
 *                                profile.cpp                                 *
 *  ************************************************************************  */

#include    "stdinc.h"

#include    <stdio.h>

#include    "puterror.h"

#include    "profapi.h"
#include    "profile.h"
#include    "profseg.h"

/*  This source file's use of various undocumented NTDLL imports assumes
    they will be resolved from an import library, saving us the trouble of
    mucking around with GetProcAddress. The following direction to use a
    suitable import library, supplied with all driver kits, is ignored if
    building with the Windows Driver Kit's BUILD tool, but may save trouble
    when building some other way.  */

#pragma comment (lib, "ntdll.lib")

/*  ************************************************************************  */
/*  Forward references  */

/*  Helpers defined at the end of this file. Declaring them static here
    gives them internal linkage; their definitions below omit static but
    inherit it from these first declarations.  */

/*  Print an NTSTATUS and a printf-style description of what failed.  */

static VOID PutStatus (NTSTATUS, PCSTR, ...);

/*  Change, and later put back, the profile interval for a profile source.  */

static NTSTATUS SetInterval (ULONG, KPROFILE_SOURCE);
static NTSTATUS RestoreInterval (ULONG, KPROFILE_SOURCE);

/*  Enable one named privilege in the current process's token, returning a
    Win32 error code.  */

static DWORD EnableSinglePrivilege (PCWSTR);

/*  ************************************************************************  */
/*  Markers for the start and end of the profiled region  */

/*  Each marker is placed, by __declspec (allocate), in its own section:
    PROFILE_START and PROFILE_END, presumably defined in profseg.h. The
    extern gives these const objects external linkage, which a const object
    at namespace scope in C++ otherwise lacks. NOTE(review): presumably this
    is so that the PROFILE_BASE and PROFILE_SIZE macros can refer to them
    and they are not discarded; confirm against profseg.h.  */

extern __declspec (allocate (PROFILE_START)) BYTE const ProfileStart = 0;
extern __declspec (allocate (PROFILE_END)) BYTE const ProfileEnd = 0;

/*  ************************************************************************  */
/*  CProfile implementation  */

CProfile :: CProfile (VOID)
{
    /*  Give every member a defined value, not just the ones the destructor
        tests. Until Init succeeds there is no buffer (so a size and count
        of zero), no profile handle and no changed interval to restore.  */

    m_Buffer = NULL;
    m_BufferSize = 0;
    m_BucketCount = 0;
    m_OldInterval = 0;
    m_Source = (KPROFILE_SOURCE) 0;
    m_ProfileHandle = NULL;
}

CProfile :: ~CProfile (VOID)
{
    /*  Undo what Init achieved, in the reverse order. Close the profile
        object first: presumably this stops any profiling still running, so
        nothing more is written to the buffer before it is freed. Then put
        back the interval if Init changed it.  */

    if (m_ProfileHandle != NULL) CloseHandle (m_ProfileHandle);
    if (m_OldInterval != 0) RestoreInterval (m_OldInterval, m_Source);

    /*  The buffer is an array from new [], so it must be released with
        delete [] (scalar delete on it is undefined behaviour). Deleting a
        NULL pointer is harmless, so no test is needed.  */

    delete [] m_Buffer;
}

DWORD
CProfile :: Init (
    ULONG BucketSize,
    ULONG Interval,
    KPROFILE_SOURCE Source,
    ULONG *BucketCount)
{
    /*  Prepare to profile this process's execution of the code region given
        by PROFILE_BASE and PROFILE_SIZE. Execution counts go into one ULONG
        per BucketSize bytes of that region.

        BucketSize      bytes of code per bucket; must be a non-zero power
                        of 2
        Interval        profile interval to set for Source, or 0 to keep
                        whatever interval is current
        Source          profile source to use
        BucketCount     receives the number of buckets, on success only

        Returns ERROR_SUCCESS or a Win32 error code. On failure, anything
        already acquired (the buffer, a changed interval) is left for the
        destructor to release or restore.  */

    /*  For our purposes of inspecting how well some algorithm runs, we want
        the smallest bucket possible, which is 4 bytes. Still, anticipate
        some generalisation.  */

    if (BucketSize == 0 OR BucketSize & (BucketSize - 1)) {
        PutError ("Bucket size 0x%08X is not a power of 2", BucketSize);
        return ERROR_INVALID_PARAMETER;
    }

    /*  For a power of 2, the index of the one set bit is its logarithm,
        which is how the profile API wants the bucket size expressed.  */

    ULONG bucketshift;
    BitScanReverse (&bucketshift, BucketSize);

    /*  Retrieve the start address and size of the region of our code that
        we're to profile. Though we represent these by macros for
        convenience, neither is known at compile-time. Let's take it as
        unthinkable that the profiled area is too big to fit 32 bits.  */

    PVOID profilebase = (PVOID) PROFILE_BASE;
    ULONG profilesize = (ULONG) PROFILE_SIZE;

    /*  Compute what size of buffer will be needed for one ULONG execution
        count per bucket that spans the profiled region, rounding up so
        that a partial bucket at the end is still counted.  */

    ULONG numbuckets = profilesize >> bucketshift;
    if (profilesize & (BucketSize - 1)) numbuckets ++;

    /*  Get that buffer (trusting the compiler's generation of operator new
        to catch an, unlikely, overflow).

        NOTE(review): a standard-conforming operator new throws
        std::bad_alloc instead of returning NULL. The check below helps only
        with a runtime whose operator new returns NULL on failure; confirm
        this for the build in use.  */

    m_Buffer = new ULONG [numbuckets];
    if (m_Buffer == NULL) {
        PutMemoryError ();
        return ERROR_NOT_ENOUGH_MEMORY;
    }

    ULONG buffersize = numbuckets * sizeof (ULONG);

    /*  Verify that the given profile source is supported. The easiest way
        to do this also obtains the source's current setting for the
        profile interval, which we may as well report.

        From here on, use of the "native function" API means that we, like
        many a low-level user-mode program by Microsoft, are using NTSTATUS
        values not Win32 error codes.  */

    ULONG interval;
    NTSTATUS status = NtQueryIntervalProfile (Source, &interval);
    if (NOT NT_SUCCESS (status)) {
        PutStatus (status, "querying interval for source %u", (ULONG) Source);
    }
    else if (interval == 0) {
        PutError ("Source %u is not supported", (ULONG) Source);
        status = STATUS_NOT_SUPPORTED;
    }
    else {

        PutInfo ("Interval for source %u is 0x%08X", (ULONG) Source, interval);

        /*  If a different interval is specified, try setting it.  */

        if (Interval != 0 AND Interval != interval) {

            status = SetInterval (Interval, Source);

            /*  Remember the old interval only once the interval has
                actually been changed. Otherwise the destructor would try
                to "restore" an interval that never changed, which, for want
                of privilege, may fail and report a spurious error.  */

            if (NT_SUCCESS (status)) {
                m_OldInterval = interval;
                m_Source = Source;
            }
        }

        if (NT_SUCCESS (status)) {

            /*  Have Windows create a profile object from the given
                parameters and others. We profile our process only, whatever
                processors it gets to run on. We have a specified profile
                source and may already have set the corresponding interval
                for the recurrence of profile interrupts. Execution counts
                go into the buffer we've created.

                The use of a KAFFINITY that has all bits set to mean all
                processors, even ones that our process isn't currently
                affinitised to, is only partly supported in version 6.1 and
                higher. In these later versions it means all processors in
                the current processor group - and it only works for a 32-bit
                program on 32-bit Windows or a 64-bit program on 64-bit
                Windows, not for 32-bit on 64-bit. (Demonstrating this, as a
                side-line, is some of the point to coding this way.)  */

            #if _WIN32_WINNT < 0x0601

            status = NtCreateProfile (
                        &m_ProfileHandle,
                        GetCurrentProcess (),
                        profilebase,
                        profilesize,
                        bucketshift,
                        m_Buffer,
                        buffersize,
                        Source,
                        (KAFFINITY) -1);
            #else

            /*  If the program does not need to run on earlier versions than
                Windows 7, then "profile all processors" can be arranged by
                supplying no processor specification.  */

            status = NtCreateProfileEx (
                        &m_ProfileHandle,
                        GetCurrentProcess (),
                        profilebase,
                        profilesize,
                        bucketshift,
                        m_Buffer,
                        buffersize,
                        Source,
                        0,
                        NULL);
            #endif

            if (NOT NT_SUCCESS (status)) {
                PutStatus (status, "creating profile");
            }
            else {

                m_BufferSize = buffersize;
                m_BucketCount = numbuckets;

                *BucketCount = numbuckets;

                return ERROR_SUCCESS;
            }
        }
    }
    return RtlNtStatusToDosError (status);
}

DWORD CProfile :: Start (VOID)
{
    /*  Start profiling into the buffer prepared by Init. Returns a Win32
        error code.

        Init sets the profile handle only when it also records the buffer
        size. Without a handle, a successful Init has not happened, and the
        buffer pointer and size cannot be trusted for the reset below.  */

    if (m_ProfileHandle == NULL) {
        PutError ("Profile not initialised");
        return ERROR_INVALID_HANDLE;
    }

    /*  Once profiling starts, the execution counts in our buffer get
        incremented. What they first get incremented from is up to us:
        reset the buffer each time.  */

    RtlZeroMemory (m_Buffer, m_BufferSize);

    NTSTATUS status = NtStartProfile (m_ProfileHandle);
    if (NOT NT_SUCCESS (status)) {
        PutStatus (status, "starting profile");
    }

    return RtlNtStatusToDosError (status);
}

DWORD CProfile :: Stop (ULONG **Buckets)
{
    /*  Stop profiling and hand the caller a fresh copy of the execution
        counts. The copy is a new [] array of m_BucketCount ULONGs, which
        the caller then owns. *Buckets is written only on success. Returns
        a Win32 error code.  */

    NTSTATUS status = NtStopProfile (m_ProfileHandle);

    if (NT_SUCCESS (status)) {

        /*  The counts are no longer being incremented on receipt of
            profile interrupts, so a snapshot of them is stable.  */

        ULONG *snapshot = new ULONG [m_BucketCount];
        if (snapshot != NULL) {
            memcpy (snapshot, m_Buffer, m_BufferSize);
            *Buckets = snapshot;
            return ERROR_SUCCESS;
        }

        PutMemoryError ();
        return ERROR_NOT_ENOUGH_MEMORY;
    }

    PutStatus (status, "stopping profile");
    return RtlNtStatusToDosError (status);
}

/*  ************************************************************************  */
/*  Helpers  */

VOID PutStatus (NTSTATUS Status, PCSTR Text, ...)
{
    /*  Write, on a fresh line, "Status 0x" and the status as 8 hex digits,
        then the caller's printf-style description, then end the line.  */

    printf ("\nStatus 0x%08X ", Status);

    va_list args;
    va_start (args, Text);
    vprintf (Text, args);
    va_end (args);

    putchar ('\n');
}

/*  ========================================================================  */

NTSTATUS SetInterval (ULONG Interval, KPROFILE_SOURCE Source)
{
    /*  Set the profile interval for the given source. If this is refused
        for want of SeSystemProfilePrivilege, enable that privilege and try
        once more. Returns the status of the attempt to set the interval.  */

    NTSTATUS status;
    BOOL privilegeenabled = FALSE;

    /*  Setting the interval requires privilege. If we already have it,
        fine. But prepare a redo.  */

    for (;;) {

        status = NtSetIntervalProfile (Interval, Source);
        if (NT_SUCCESS (status)) break;

        /*  If the reason we can't set the interval is not that we don't
            have the privilege, there's nothing we can do about it. Neither
            is there if we have already enabled the privilege and are still
            refused. Trying again would only loop forever, which is what
            happens if enabling reports success without the privilege
            actually having been assigned.  */

        if (status != STATUS_PRIVILEGE_NOT_HELD OR privilegeenabled) {

            PutStatus (
                status,
                "setting interval 0x%08X for source %u",
                Interval,
                (ULONG) Source);

            break;
        }

        /*  Given that we don't have the privilege, see if we can enable
            it. If so, we can try again to set the interval. Here, for
            demonstration, we leave the privilege enabled. A real-world
            program might better disable it afterwards, until it's needed
            again for restoring the interval.

            The ugly, but official, way requires a string to name the
            privilege. Kernel-mode programmers get a constant.  */

        DWORD ec = EnableSinglePrivilege (SE_SYSTEM_PROFILE_NAME);
        if (ec != ERROR_SUCCESS) {

            PutError (ec, "enabling SeSystemProfilePrivilege");
            break;
        }
        privilegeenabled = TRUE;
    }

    /*  That we successfully set the interval does not mean that the
        interval is now exactly what we set. If only for informational
        purposes, report what the setting now seems to be (but don't let an
        error spoil what is otherwise our success: the query's status is
        kept separate from the status we return).  */

    if (NT_SUCCESS (status)) {

        ULONG newinterval;
        NTSTATUS querystatus = NtQueryIntervalProfile (Source, &newinterval);
        if (NOT NT_SUCCESS (querystatus)) {

            PutStatus (
                querystatus,
                "querying new interval for source %u",
                (ULONG) Source);
        }
        else {

            PutInfo (
                "Interval for source %u changed to 0x%08X",
                (ULONG) Source,
                newinterval);
        }
    }
    return status;
}

NTSTATUS RestoreInterval (ULONG OldInterval, KPROFILE_SOURCE Source)
{
    /*  Put back a profile interval saved before SetInterval changed it.
        Failure is reported but otherwise only passed back to the caller.  */

    NTSTATUS status = NtSetIntervalProfile (OldInterval, Source);
    if (NT_SUCCESS (status)) return status;

    PutStatus (
        status,
        "restoring interval 0x%08X for source %u",
        OldInterval,
        Source);

    return status;
}

/*  ========================================================================  */

DWORD EnableSinglePrivilege (PCWSTR PrivilegeName)
{
    /*  Enable the named privilege in the current process's token.

        Returns ERROR_SUCCESS only if the privilege really is now enabled.
        Otherwise it returns a Win32 error code. In particular, it returns
        ERROR_NOT_ALL_ASSIGNED if the token does not hold the privilege at
        all, which is typical when not running elevated.  */

    DWORD ec;

    HANDLE token;
    BOOL ok = OpenProcessToken (
                GetCurrentProcess (),
                TOKEN_ADJUST_PRIVILEGES,
                &token);
    if (NOT ok) {
        ec = GetLastError ();
    }
    else {

        TOKEN_PRIVILEGES tp;
        tp.PrivilegeCount = 1;

        /*  The parameter is a wide string, so name the wide-character
            function explicitly, whatever the UNICODE setting.  */

        ok = LookupPrivilegeValueW (
                NULL,
                PrivilegeName,
                &tp.Privileges [0].Luid);
        if (NOT ok) {
            ec = GetLastError ();
        }
        else {
            tp.Privileges [0].Attributes = SE_PRIVILEGE_ENABLED;

            ok = AdjustTokenPrivileges (
                    token,
                    FALSE,
                    &tp,
                    sizeof (tp),
                    NULL,
                    NULL);

            /*  Even when it succeeds, AdjustTokenPrivileges may not have
                enabled the privilege: if the token lacks it, the function
                still returns TRUE but sets the last error to
                ERROR_NOT_ALL_ASSIGNED. On success the last error is
                therefore the real answer in both cases.  */

            ec = GetLastError ();
            if (ok AND ec != ERROR_NOT_ALL_ASSIGNED) {
                ec = ERROR_SUCCESS;
            }
        }
        CloseHandle (token);
    }
    return ec;
}

/*  ************************************************************************  */

