65.9K
CodeProject 正在变化。 阅读更多。
Home

SSE 加速的忽略大小写的子字符串搜索

starIconstarIconstarIconstarIcon
emptyStarIcon
starIcon

4.89/5 (14投票s)

2012 年 5 月 13 日

CPOL

7分钟阅读

viewsIcon

30281

downloadIcon

301

经典问题的另一种解法,现在使用SSE指令

引言

我一直很好奇,为什么C/C++(由CRT库提供)中有一个`strstr`函数,却没有预期中的不区分大小写`stristr`函数(或者在宽字符集模式下:`wcsistr`)。

尽管在网上搜索能找到足够的结果,但大多数都写得很差,通常速度非常慢。最基本的实现,能给出正确结果的,首先将两个字符串都转换为小写(或大写),然后执行常规的子字符串搜索。不幸的是,内存分配和大小写转换例程使它们慢得令人痛苦。一些更快的替代方案使用汇编代码,例如Ralph Walden的这个。不幸的是,由于内联汇编代码,它不能用于64位构建,并且不能考虑不断变化的区域设置。其他方法则遵循完全不同的路径,尝试通过使用高级字符串搜索算法来加速搜索,例如Knuth–Morris–Pratt(如今在GNU C库中使用)或Boyer–Moore。然而,这种算法容易出错,并且在搜索小字符串时可能会导致显著的开销。

背景  

一个基本实现

让我们首先编写一个在不优化的情况下表现良好,并且能正确处理Unicode字符和选定区域设置的函数。

作为起点,我们将查看Microsoft CRT库中的常规子字符串搜索函数。您计算机上要查找的文件可能位于`C:\Program Files\Microsoft Visual Studio XX\VC\crt\src\intel\wcsstr.c`。该函数分为三个主要部分:第一个分支在检测到SSE4.2时执行,第二个分支用于SSE2,最后一个部分在这些都不存在时执行。

查看函数的最后一部分,它是纯C语言中常规子字符串搜索的一个非常基本的实现。请记住:`wcs1`是要搜索的字符串,`wcs2`是要查找的子字符串。对代码稍作修改,我们就可以获得一个不区分大小写的版本。

子字符串的第一个字符被转换为小写和大写,其值存储在`l`和`u`中。外层for循环遍历字符串,并将字符与`l`和`u`进行比较。如果字符匹配,则通过在内层while循环中即时将字符串和子字符串的字符转换为小写来比较子字符串的其余部分。

const wchar_t* wcsistr(const wchar_t *wcs1, const wchar_t *wcs2)
{
    const wchar_t *s1, *s2;
    const wchar_t l = towlower(*wcs2);
    const wchar_t u = towupper(*wcs2);
    
    if (!*wcs2)
        return wcs1; // an empty substring matches everything
    
    for (; *wcs1; ++wcs1)
    {
        if (*wcs1 == l || *wcs1 == u)
        {
            s1 = wcs1 + 1;
            s2 = wcs2 + 1;
            
            while (*s1 && *s2 && towlower(*s1) == towlower(*s2))
                ++s1, ++s2;
            
            if (!*s2)
                return wcs1;
        }
    }
 
    return NULL;
} 

所以我们有了它,一个已经相当好用的不区分大小写的子字符串搜索函数。尽管该函数表现良好,但它仍然比其常规对应项`wcsstr`慢约8~10倍。很明显,需要进行强力优化。

重要提示:字符的大小写转换是使用函数towlowertowupper完成的。为了使这些函数(以及其他涉及大小写转换的CRT函数,例如_wcsicmp)能够正确处理Unicode字符,您应该使用_wsetlocale设置区域设置。您可以使用以下方式设置默认系统区域设置:

_wsetlocale(LC_ALL, L"");     

缓存大小写转换

仔细检查函数后,发现`towlower`和`towupper`函数对性能至关重要,不幸的是它们相当慢。幸运的是,可以通过使用缓存机制来解决这个问题。一个合理的假设是,搜索函数处理的大多数文本很可能是西方语言,这基本上意味着只有0-255范围内的字符代码被广泛使用。

进一步阐述这个想法导致引入了三个新的`map`函数,它们应该取代其非`map`对应项。您现在需要在程序开始时调用`_wsetlocale_map`函数,因为除了设置区域设置外,它还将初始化缓存。

wchar_t g_mapToLower[256];
wchar_t g_mapToUpper[256];
 
wchar_t* _wsetlocale_map(int category, const wchar_t *locale)
{
	wchar_t *ret = _wsetlocale(category, locale);
	wchar_t c;
	
	for (c = 0; c < 256; ++c)
	{
		g_mapToLower[c] = towlower(c);
		g_mapToUpper[c] = towupper(c);
	}
 
	return ret;
}
 
wchar_t towlower_map(wchar_t c)
	{return (c < 256 ? g_mapToLower[c] : towlower(c));}
 
wchar_t towupper_map(wchar_t c)
	{return (c < 256 ? g_mapToUpper[c] : towupper(c));} 

通过使用这种缓存机制,速度得到了极大的提升,不区分大小写的搜索现在比`wcsstr`慢约3.1~3.4倍(在西方语言的情况下)。进一步的优化应该在C/C++语言的范围之外寻找。

检测SSE可用性

在使用SSE2指令之前,我们必须确保它在系统上可用,这可以使用__cpuid内在函数轻松完成。这应该只做一次,并且由于区域设置通常不会经常更改,我们修改`_wsetlocale_map`函数以避免额外的初始化函数。

#define __ISA_AVAILABLE_X86     0
#define __ISA_AVAILABLE_SSE2    1
#define __ISA_AVAILABLE_SSE42   2

int __isa_available = __ISA_AVAILABLE_X86; 
 
wchar_t* _wsetlocale_map(int category, const wchar_t *locale)
{
    wchar_t *ret = _wsetlocale(category, locale);
    wchar_t c;
    int CPUInfo[4];
    
    for (c = 0; c < 256; ++c)
    {
        g_mapToLower[c] = towlower(c);
        g_mapToUpper[c] = towupper(c);
    }
 
    __cpuid(CPUInfo, 1);
    if (CPUInfo[2] & (1 << 20))
        __isa_available = __ISA_AVAILABLE_SSE42;
    else if (CPUInfo[3] & (1 << 26))
        __isa_available = __ISA_AVAILABLE_SSE2;
    
    return ret;
}  

使用SSE2指令优化外层循环

为了优化外层for循环中字符的测试,我们首先加载三个XMM寄存器。每个寄存器有8个可用位置用于字符(128位寄存器/16位字符=8个位置)。第一个寄存器加载`l`,第二个加载`u`,第三个加载零(字符串终止空字符)。

字符串现在以8个字符的块遍历,在每个块中,我们可以通过生成掩码快速检测是否存在感兴趣的字符。以下函数中提供的注释应概述该过程。

//
// These macros come from the Microsoft CRT library (look up wcsstr.c)
//

#define XMM_SIZE sizeof(__m128i)
#define XMM_ALIGNED(p) (0 == ((XMM_SIZE-1) & (intptr_t)(p)))
#define XMM_CHARS (XMM_SIZE / sizeof(wchar_t))
 
#define PAGE_SIZE ((intptr_t)0x1000)
#define PAGE_OFFSET(p) ((PAGE_SIZE-1) & (intptr_t)(p))
#define XMM_PAGE_SAFE(p) (PAGE_OFFSET(p) <= (PAGE_SIZE - XMM_SIZE))
 
//
// The case-insensitive substring search function
//

const wchar_t* wcsistr(const wchar_t *wcs1, const wchar_t *wcs2)
{
    const wchar_t *s1, *s2;
    const wchar_t l = towlower_map(*wcs2);
    const wchar_t u = towupper_map(*wcs2);

    if (!*wcs2)
        return wcs1; // an empty substring matches everything

    if (__isa_available >= __ISA_AVAILABLE_SSE2)
    {
        __m128i patl, patu, zero, tmp1, tmp2, tmp3;
        int mask;
        unsigned long offset;

        // fill one XMM register completely with l, one with u and another with zeros

        patl = _mm_cvtsi32_si128(l);
        patl = _mm_shufflelo_epi16(patl, 0);
        patl = _mm_shuffle_epi32(patl, 0);
        patu = _mm_cvtsi32_si128(u);
        patu = _mm_shufflelo_epi16(patu, 0);
        patu = _mm_shuffle_epi32(patu, 0);
        zero = _mm_setzero_si128();

        // loop through string and try to match the first character

        for (;;)
        {
            // if page safe compare using the XMM registers

            if (XMM_PAGE_SAFE(wcs1))
            {
                // perform the comparison (also checks for end of string)

                tmp1 = _mm_loadu_si128((__m128i*)wcs1);  // load chunk of 8 characters at this position
                tmp2 = _mm_cmpeq_epi16(tmp1, patl);      // compare against lower case pattern, store mask
                tmp3 = _mm_cmpeq_epi16(tmp1, patu);      // compare against upper case pattern, store mask
                tmp2 = _mm_or_si128(tmp2, tmp3);         // combine both masks into tmp2
                tmp3 = _mm_cmpeq_epi16(tmp1, zero);      // compare against null character, store mask
                tmp2 = _mm_or_si128(tmp2, tmp3);         // combine both masks into tmp2
                mask = _mm_movemask_epi8(tmp2);          // convert to 32-bit mask

                // if no match found, continue with next chunk

                if (mask == 0)
                {
                    wcs1 += XMM_CHARS;
                    continue;
                }

                // advance string pointer to position of l, u or null

                _BitScanForward(&offset, mask);
                wcs1 += offset / sizeof(wchar_t);
            }

            // if at the end of string, no match found

            if (!*wcs1)
                return NULL;

            // if first character matches, compare against the full substring
            
            if (*wcs1 == l || *wcs1 == u)
            {
                s1 = wcs1 + 1;
                s2 = wcs2 + 1;

                while (*s1 && *s2 && towlower_map(*s1) == towlower_map(*s2))
                    ++s1, ++s2;

                if (!*s2)
                    return wcs1;
            }

            ++wcs1;
        }
    }
    else
    {
        for (; *wcs1; ++wcs1)
        {
            if (*wcs1 == l || *wcs1 == u)
            {
                s1 = wcs1 + 1;
                s2 = wcs2 + 1;
                
                while (*s1 && *s2 && towlower_map(*s1) == towlower_map(*s2))
                    ++s1, ++s2;
            
                if (!*s2)
                    return wcs1;
            }
        }
    }

    return NULL;
} 

我们再次提高了速度,这次比常规子字符串搜索慢约2.7~2.9倍。

使用SSE2优化内层循环

最后的改进涉及内层循环的优化,即子字符串剩余部分的匹配。尽管可以使用XMM寄存器通过创建掩码来比较整个子字符串与当前字符串位置,但这涉及到连续加载字符。这不会带来性能提升,因此我们保留原始的while循环。

然而,大部分时间都浪费在这个昂贵的while循环中:比较第二个字符,第三个字符甚至第四个字符,但匹配可能在第五个字符处失败。在大量文本中,子字符串的第一个字符出现很多次,前两个字符连续出现的也很多,但前三个字符出现的更少,前四个字符出现的则少得多。我想你可以明白这是怎么回事,在外部循环中匹配第一个字符后,我们将直接测试前8个字符是否与子字符串匹配。

为此,我们将缓存子字符串的前8个字符,一个XMM寄存器将保存小写版本,第二个寄存器将保存大写版本。当与目标字符串进行比较时,组合掩码应在每个位置设置一个位(1)(除非子字符串小于8个字符,这就是`cache_mask`变量的作用)。将这个想法付诸实践,得到了函数的最终版本。

const wchar_t* wcsistr(const wchar_t *wcs1, const wchar_t *wcs2)
{
    const wchar_t *s1, *s2;
    const wchar_t l = towlower_map(*wcs2);
    const wchar_t u = towupper_map(*wcs2);

    if (!*wcs2)
        return wcs1; // an empty substring matches everything

    if (__isa_available >= __ISA_AVAILABLE_SSE2)
    {
        __m128i patl, patu, zero, tmp1, tmp2, tmp3, cachel, cacheu;
        int mask, i, cache_mask;
        unsigned long offset;

        // fill one XMM register completely with l, one with u and another with zeros

        patl = _mm_cvtsi32_si128(l);
        patl = _mm_shufflelo_epi16(patl, 0);
        patl = _mm_shuffle_epi32(patl, 0);
        patu = _mm_cvtsi32_si128(u);
        patu = _mm_shufflelo_epi16(patu, 0);
        patu = _mm_shuffle_epi32(patu, 0);
        zero = _mm_setzero_si128();

        // convert first 8 characters of substring to lower and upper case while putting
        // them in XMM registers

        cachel = _mm_setzero_si128();
        cacheu = _mm_setzero_si128();
        cache_mask = 0;
        s2 = wcs2;

        for (i = 0; i < XMM_CHARS; ++i)
        {
            cachel = _mm_srli_si128(cachel, sizeof(wchar_t)); // shift by one character
            cacheu = _mm_srli_si128(cacheu, sizeof(wchar_t));
            
            cachel = _mm_insert_epi16(cachel, towlower_map(*s2), XMM_CHARS-1); // insert character
            cacheu = _mm_insert_epi16(cacheu, towupper_map(*s2), XMM_CHARS-1);

            if (*s2)
            {
                cache_mask |= 3 << (i << 1); // decimal 3 is binary 11, set two bits for each character
                ++s2;
            }
        }

        // loop through string and try to match the first character

        for (;;)
        {
            // if page safe compare using the XMM registers

            if (XMM_PAGE_SAFE(wcs1))
            {
                // perform the comparison (also checks for end of string)

                tmp1 = _mm_loadu_si128((__m128i*)wcs1);  // load chunk of 8 characters at this position
                tmp2 = _mm_cmpeq_epi16(tmp1, patl);      // compare against lower case pattern, store mask
                tmp3 = _mm_cmpeq_epi16(tmp1, patu);      // compare against upper case pattern, store mask
                tmp2 = _mm_or_si128(tmp2, tmp3);         // combine both masks into tmp2
                tmp3 = _mm_cmpeq_epi16(tmp1, zero);      // compare against null character, store mask
                tmp2 = _mm_or_si128(tmp2, tmp3);         // combine both masks into tmp2
                mask = _mm_movemask_epi8(tmp2);          // convert to 32-bit mask

                // if no match found, continue with next chunk

                if (mask == 0)
                {
                    wcs1 += XMM_CHARS;
                    continue;
                }

                // advance string pointer to position of l, u or null

                _BitScanForward(&offset, mask);
                wcs1 += offset / sizeof(wchar_t);

                // if not at end of string and page safe, quickly check whether the first chunk
                // matches the substring

                if (*wcs1 && XMM_PAGE_SAFE(wcs1))
                {
                    tmp1 = _mm_loadu_si128((__m128i*)wcs1);
                    tmp2 = _mm_cmpeq_epi16(tmp1, cachel);
                    tmp3 = _mm_cmpeq_epi16(tmp1, cacheu);
                    tmp2 = _mm_or_si128(tmp2, tmp3);
                    mask = _mm_movemask_epi8(tmp2);

                    if (cache_mask == 0xFFFF) // only first part of substring in cache
                    {
                        if (mask == cache_mask)
                        {
                            s1 = wcs1 + XMM_CHARS;
                            s2 = wcs2 + XMM_CHARS;

                            while (*s1 && *s2 && towlower_map(*s1) == towlower_map(*s2))
                                ++s1, ++s2;

                            if (!*s2)
                                return wcs1;
                        }
                    }
                    else // full substring is in cache
                    {
                        if ((mask & cache_mask) == cache_mask)
                            return wcs1;
                    }

                    // no full match found, try next character in string

                    ++wcs1;
                    continue; 
                }
            }

            // if at the end of string, no match found

            if (!*wcs1)
                return NULL;

            // if first character matches, compare against the full substring
            
            if (*wcs1 == l || *wcs1 == u)
            {
                s1 = wcs1 + 1;
                s2 = wcs2 + 1;

                while (*s1 && *s2 && towlower_map(*s1) == towlower_map(*s2))
                    ++s1, ++s2;

                if (!*s2)
                    return wcs1;
            }

            ++wcs1;
        }
    }
    else
    {
        for (; *wcs1; ++wcs1)
        {
            if (*wcs1 == l || *wcs1 == u)
            {
                s1 = wcs1 + 1;
                s2 = wcs2 + 1;
                
                while (*s1 && *s2 && towlower_map(*s1) == towlower_map(*s2))
                    ++s1, ++s2;
            
                if (!*s2)
                    return wcs1;
            }
        }
    }

    return NULL;
} 

最终结果:不区分大小写的子字符串搜索现在比常规子字符串搜索慢约2.2~2.4倍。

使用代码

本文附带了一个zip文件,其中包含C源文件及其附带的头文件。在您的项目中使用此代码非常简单:只需将这两个文件复制到您的项目目录中,并将它们添加到Visual Studio中的项目。要使用该函数,您应该通知编译器它的存在并通过以下方式包含头文件

#include "wcsistr.h"

不要忘记通过调用新引入的`_wsetlocale_map`函数来设置当前区域设置并初始化大小写转换缓存。该函数至少应该调用一次,最好在程序启动时调用。

_wsetlocale_map(LC_ALL, L""); 

现在执行典型的不区分大小写子字符串搜索如下所示

const wchar_t *p = wcsistr(L"HeLLo, wOrLD!", L"world");   

值得关注的点  

  • 涉及复杂字符的Unicode字符串比较需要额外注意。在Unicode中,字符通常表示为单个预组合字符,或表示为基本字母加上一个或多个非间距标记的分解序列。如果字符串和子字符串的表示不同,`wcsistr`函数显然会失败,因为它只是基于16位进行字符串比较。这意味着在使用`wcsistr`函数之前,您需要使用NormalizeString标准化两个字符串。更多信息可以在这里找到。但即便如此,我仍然高度怀疑这在100%的情况下都能奏效,Unicode确实很难正确处理。
  • 尽管SSE4.2有专门用于字符串比较的指令,但我使用这些指令获得的结果却令人失望。在最好的情况下,该函数仍然比本文中介绍的SSE2加速版本慢约10~20%。

希望您喜欢这篇文章。请随时留下关于错误、性能或其他方面的反馈。

历史  

  • 2012年5月14日 - 文章发布。
© . All rights reserved.