65.9K
CodeProject 正在变化。 阅读更多。
Home

通过 C++ 并行化优化 Win32 SID 转换

starIconstarIconstarIconstarIconstarIcon

5.00/5 (3投票s)

2022 年 10 月 6 日

MIT

6分钟阅读

viewsIcon

5570

downloadIcon

185

当需要执行多个相似操作时,如何通过并行化来改进耗时的操作

引言

Microsoft Visual Studio 提供了并行编程库,其中包含许多函数,可以轻松地并行执行代码。其中一个函数是 parallel_for,我们将在此对其进行探讨。

背景

在我之前的文章中,我提供了一种从用户访问令牌中提取所有特权和组的方法。为了在更复杂的环境中测试我的代码,我在我的工作笔记本电脑上使用我的公司用户帐户运行了它。正如大型企业环境中的常见情况一样,我的帐户是许多组的成员。具体来说是 256 个。实际上,我怀疑我的帐户会是更多组的成员,但在令牌中有一个 256 个组 SID 的硬性限制。

当我的计算机连接到公司网络时,此扫描耗时 0.8 秒,几乎难以察觉。但当我使用 VPN 从家连接时,耗时 4.5 秒。如果这涉及到 GUI 更新,这是不可接受的。

大部分时间都花在了 Win32 API 调用上。由于每次查找本质上都与其他查找无关,因此这是一个我们可以节省时间的好例子。

实现 parallel_for 循环

这是我们开始的代码,它是一个单线程的 for 循环。

ULONG w32_GetTokenGroups(HANDLE hToken, vector<w32_CUserGroup>& groups) {
    DWORD bufLength = 0;
    DWORD retVal = NO_ERROR;
    PTOKEN_GROUPS pGroups = NULL;

    retVal = w32_GetTokenInformation(hToken, TokenGroups, (PVOID&)pGroups, bufLength);

    if (retVal == NO_ERROR) {
        groups.resize(pGroups->GroupCount);

        TCHAR name[w32_MAX_GROUPNAME_LENGTH];
        DWORD nameSize = sizeof(name);
        TCHAR domain[w32_MAX_DOMAINNAME_LENGTH];
        DWORD domainSize = sizeof(domain);
        SID_NAME_USE sidType;

        for (DWORD i = 0; i < pGroups->GroupCount; i++) {
            //Allocate buffers

            DWORD nameSize = sizeof(name);
            DWORD domainSize = sizeof(domain);

            PSID pSid = pGroups->Groups[i].Sid;
            LPTSTR sidString = NULL;

            //get the name of the group and the domain it belongs to
            if (!LookupAccountSid(
                NULL, pSid, name, &nameSize, domain, &domainSize, &sidType)) {
                retVal = GetLastError();
                break;
            }

            //Get a human readable version of the SID
            if (!ConvertSidToStringSid(pSid, &sidString)) {
                retVal = GetLastError();
                break;
            }

            groups[i].Attributes = pGroups->Groups[i].Attributes;
            groups[i].Domain = domain;
            groups[i].Sid = sidString;
            groups[i].Name = name;
            groups[i].SidType = sidType;

            LocalFree(sidString);
        }
    }

    if (pGroups)
        HeapFree(GetProcessHeap(), 0, pGroups);

    if (retVal != NO_ERROR)
        groups.clear();

    return retVal;
}

有几点需要注意:

  1. 我们的操作结果存储在一个 vector 中。虽然 vector 在并发修改时不是线程安全的,但只要 1 个线程只访问 1 个项,修改 vector 中引用的项就是安全的。因为 vector 在 for 循环之前被调整了大小,所以这不是问题。
  2. ConvertSidToStringSid 的调用也是安全的,因为它的参数是局部于循环的。
  3. LookupAccountSid 的调用是不安全的,因为它引用了循环外部的变量。对于单线程循环来说,这是一个合理的决定,因为它减少了栈分配。但总体而言,与实际操作相比,效率低下的问题并不重要。

parallel_for 辅助函数具有此签名:

template <typename _Index_type, typename _Function>
void parallel_for(
    _Index_type first,
    _Index_type last,
    const _Function& _Func,
    const auto_partitioner& _Part = auto_partitioner());

就像普通的 for 循环一样,您需要提供起始索引和大小。它有点令人困惑,因为它说 'last',但文档规定它是最后一个元素的下一个索引,所以如果您从 0 开始,'last' 实际上是元素的数量。最后,它接受要并行执行的函数。分区器会划分工作负载。除非您有特定理由需要不同的处理方式,否则请使用默认设置。

如果我们实现它,我们会得到这个:

    ULONG w32_GetTokenGroups(HANDLE hToken, vector<CGroupIdentity>& groups) {
        DWORD bufLength = 0;
        DWORD retVal = NO_ERROR;
        PTOKEN_GROUPS pGroups = NULL;

        retVal = w32_GetTokenInformation
                 (hToken, TokenGroups, (PVOID&)pGroups, bufLength);

        if (retVal == NO_ERROR) {
            groups.resize(pGroups->GroupCount);
            concurrency::parallel_for(size_t(0), 
                         (size_t)pGroups->GroupCount, [&](size_t i) {
                if (retVal == NO_ERROR) {
                    TCHAR name[MAX_GROUPNAME_LENGTH];
                    DWORD nameSize = sizeof(name);
                    TCHAR domain[MAX_DOMAINNAME_LENGTH];
                    DWORD domainSize = sizeof(domain);
                    SID_NAME_USE sidType;
 
                    PSID pSid = pGroups->Groups[i].Sid;
                    LPTSTR sidString = NULL;

                    //get the name of the group and the domain it belongs to
                    if (!LookupAccountSid(
                        NULL, pSid, name, &nameSize, domain, &domainSize, &sidType)) {
                        retVal = GetLastError();
                        return;
                    }

                    //Get a human readable version of the SID
                    if (!ConvertSidToStringSid(pSid, &sidString)) {
                        retVal = GetLastError();
                        return;
                    }

                    groups[i].Attributes = pGroups->Groups[i].Attributes;
                    groups[i].Domain = domain;
                    groups[i].Sid = sidString;
                    groups[i].Name = name;
                    groups[i].SidType = sidType;

                    LocalFree(sidString);
                }
            });
        }

        if (pGroups)
            HeapFree(GetProcessHeap(), 0, pGroups);

        if (retVal != NO_ERROR)
            groups.clear();

        return retVal;
    }

我们需要将缓冲区移到循环内部以使其安全,但仅此而已。通过 VPN 转换 256 个 SID 的时间已从 4500 毫秒减少到 1000 毫秒。这仍然很长,但对于屏幕更新中的一次性操作(记住令牌是不可变的),我可以接受 1000 毫秒的单次更新。

也许关于错误处理要说几句:看起来存在竞态条件:如果并行执行过程中,有 2 个线程遇到错误会怎样?正确答案是:写入 retVal 确实存在竞态。但那并不重要。唯一重要的是,如果发生“一个”错误,我们在循环结束后就会知道。哪个特定的迭代失败并不重要。而且由于每次迭代都在自己的线程中并行运行,所以它们不会相互影响。一旦 retVal 被设置为 NO_ERROR 以外的任何值,将不再执行任何操作。

Ppl 支持取消,但这需要循环检查一个取消令牌,而在我们简单的 lambda 函数等任务中,从中获得的收益不大。

使并行执行可选

通常我会在文章的这个时候结束,但如这里所述,大约每 7 次中就有 1 次,Symantec 认为在不同线程中执行并行查找是恶意活动的迹象,并在没有任何提示的情况下将可执行文件从系统中移除。

在企业环境中,病毒扫描器等设备由企业团队管理,他们不想听取抱怨,除非发生重大生产故障,否则他们不感兴趣。

尽管如此,我不想因为有一个环境导致并行化出现问题而限制我的应用程序,将其全部放在单线程中。事实证明,使并行行为可选并受一个简单的布尔值控制很容易。

毕竟,parallel_for 已经在使用 lambda 表达式。我们所要做的就是将其提高一个级别。我们可以通过这种方式轻松地重构转换例程:

        if (retVal == NO_ERROR) {
            groups.resize(pGroups->GroupCount);

            //Lambda expression to perform translation of a SID and store the results
            //In a preallocated Vector item
            auto translator = [&](size_t i) {
                if (retVal == NO_ERROR) {
                    TCHAR name[MAX_GROUPNAME_LENGTH];
                    DWORD nameSize = sizeof(name);
                    TCHAR domain[MAX_DOMAINNAME_LENGTH];
                    DWORD domainSize = sizeof(domain);
                    SID_NAME_USE sidType;

                    PSID pSid = pGroups->Groups[i].Sid;
                    groups[i].Attributes = pGroups->Groups[i].Attributes;

                    //get the name of the group and the domain it belongs to
                    if (!LookupAccountSid(
                        NULL, pSid, name, &nameSize, domain, &domainSize, &sidType)) {
                        retVal = GetLastError();
                        return;
                    }
                    groups[i].SidType = sidType;
                    groups[i].Name = name;
                    groups[i].Domain = domain;

                    LPTSTR sidString = NULL;
                    //Get a human readable version of the SID
                    if (!ConvertSidToStringSid(pSid, &sidString)) {
                        retVal = GetLastError();
                        return;
                    }
                    groups[i].Sid = sidString;
                    LocalFree(sidString);
                }
            };

            if (parallelGroupLookup) {
                concurrency::parallel_for(size_t(0), 
                (size_t)pGroups->GroupCount, translator);
            }
            else {
                for (int i = 0; i < pGroups->GroupCount; i++) {
                     translator(i);
                }
            }

借助 lambda 表达式和 ppl 框架,实现强大且灵活的并行化非常简单。

测量执行时间

为了正确测量不同操作的时间,我需要一种方便的方法来以高精度测量执行时间。标准的 GetTickCount 太粗糙了。为此,Microsoft 实现了一个高分辨率计数器,该计数器计算 CPU 滴答。

我的秒表类很简单:

    class CStopWatch
    {
        LARGE_INTEGER m_StartTime;
        LARGE_INTEGER m_Frequency;

    public:
        void Start();
        LONGLONG ElapsedMilliSec();
        LONGLONG ElapsedMicroSec();
    };

功能非常基本。启动秒表后,我们可以请求自启动以来的毫秒数或微秒数。内部变量是启动时间和滴答频率,我们只需要读取一次。

    void CStopWatch::Start()
    {
        QueryPerformanceFrequency(&m_Frequency);
        QueryPerformanceCounter(&m_StartTime);
    }
        
    LONGLONG CStopWatch::ElapsedMilliSec()
    {
        LARGE_INTEGER endingTime;
        LARGE_INTEGER elapsedMilliseconds;
        QueryPerformanceCounter(&endingTime);

        elapsedMilliseconds.QuadPart = endingTime.QuadPart - m_StartTime.QuadPart;
        elapsedMilliseconds.QuadPart *= 1000;
        elapsedMilliseconds.QuadPart /= m_Frequency.QuadPart;

        return elapsedMilliseconds.QuadPart;
    }

    LONGLONG CStopWatch::ElapsedMicroSec()
    {
        LARGE_INTEGER endingTime;
        LARGE_INTEGER elapsedMicroseconds;
        QueryPerformanceCounter(&endingTime);

        elapsedMicroseconds.QuadPart = endingTime.QuadPart - m_StartTime.QuadPart;
        elapsedMicroseconds.QuadPart *= 1000000;
        elapsedMicroseconds.QuadPart /= m_Frequency.QuadPart;

        return elapsedMicroseconds.QuadPart;
    }

这里的繁重工作大部分由 QueryPerformanceFrequency(获取滴答频率)和 QueryPerformanceCounter(获取滴答数)完成。由于我们进行整数运算,因此在除以频率之前先将滴答数乘以频率,以提高精度。

使用此秒表,我们可以这样测量时间:

    vector<CGroupIdentity> groups;
    CStopWatch stopwatch;
    stopwatch.Start();
    w32_GetTokenGroups(GetCurrentThreadEffectiveToken(), groups, false);
    wcout << dec <<groups.size() << L" Groups for the current thread in " <<
        dec << stopwatch.ElapsedMilliSec() << L" milliseconds" << endl;
    for (auto group : groups) {
        wcout << group.Name << L"    " << group.Sid << endl;
    }

关注点

我测量了直接连接到网络时的执行时间,以及通过 VPN 连接时的执行时间。显然,通过 VPN 测量的会受到外部变化的影响。

直接在网络上,这是并行化的结果。

没有涉及人为延迟,而且我的笔记本电脑有 12 个 CPU 核心,因此 10 倍的速度提升大约是我们所能期望的。回想在拥有 2 个核心的系统在当时是“高大上”并且需要 2 个物理 CPU 和昂贵主板的时代学习多线程编程,我发现通过最少的重构就能获得这种改进,真是太棒了。

另一方面,通过 VPN 连接时,结果不太令人印象深刻。

速度提升“仅”为 5 倍。我重复了这个测试几次,每次测试都实现了 4 到 5 倍的提速。我不了解我们的 VPN 架构的具体实现细节,但我怀疑为了确保客户端不会导致 VPN 连接点过载,存在某种限速。无论如何,5 倍的提升仍然很可观。更重要的是,600 毫秒对于用户界面更新中的一次性操作来说是可接受的范围。

历史

  • 2022 年 9 月 7 日:第一个版本
© . All rights reserved.