65.9K
CodeProject 正在变化。 阅读更多。
Home

C++11 的编译时循环 - 创建通用的 static_for 实现

starIconstarIconstarIconstarIcon
emptyStarIcon
starIcon

4.80/5 (16投票s)

2014 年 12 月 28 日

CPOL

11分钟阅读

viewsIcon

49986

downloadIcon

229

你是否曾经在使用模板或 constexpr 时想要执行循环?或者你想展开一个循环,看看是否能提高程序速度?欢迎来到 static_for。

引言
要求
问题
相关主题
入门 - 尝试使用递归
递归问题的细节和修复
精髓 - 编写通用的 static_for 实现
代码
一个测试程序
代码注释
我该如何选择编译时循环和运行时循环?
问题和评论
历史

引言

我最近遇到一个有趣的问题。我该如何轻松地在编译时进行循环?

过去我曾几次考虑编写一个 static_for 方法来简化一些 template 代码,但直到现在还没有迫切的需要。

我最近找到了编写它的理由。我需要测试一些编译时代码,因此需要循环遍历它。内置工具还不够,所以……我来了,这是我的 static_for

通过在编译时循环,你可以与编译时代码(如模板)进行交互。

static_for 还可以展开循环。

要求

支持 C++11 或更高版本的编译器。下载内容包含一个变通方法,用于解决 Visual Studio 当前缺乏 constexpr 支持的问题,方法是使用 enum 和更多的 template

问题

编译时循环。运行时循环很简单

for (int i = 0; i < upper_bound; ++i)
{
    do_stuff(i);
}

由于我们在编译时工作,所以解决方案并不那么简单。我想这样做

template<int n> void do_stuff()
{
    // stuff
}
for (int i = 0; i < 1000; i++)
{
    do_stuff<i>();
}

但是……这行不通。i 是一个运行时计数器,而 do_stuff<int>() 需要一个编译时计数器。因为我不想写

do_stuff<0>
do_stuff<1>
do_stuff<2>
do_stuff<3>
do_stuff<4>
...
do_stuff<997>
do_stuff<998>
do_stuff<999>

当然,这仍然需要我硬编码迭代次数。我需要循环到一个通过编译时表达式计算的数字,同时仍具有与上面硬编码相同的运行时性能。

今天,我将带你了解我创建此问题通用解决方案过程的一些细节,并向你展示一些潜在的陷阱以及我需要创建如此复杂的东西的原因。

或者,你也可以直接下载代码,然后自己玩玩 :D

展开循环会提高性能吗?

函数签名

重载解析

完美转发

标签分派

尾递归

模板参数推导

模板实例化

模板元编程

泛型编程

n元/k元树

入门 - 尝试使用递归

假设我想使用从 0 到 99 的值来实例化这个 template,并调用函数

#include <iostream>
template<int index> void do_stuff()
{
    std::cout << index << std::endl;
}

这是我想在这里解决的问题类型,因为这不是我们可以用正常的运行时循环来处理的问题。最初解决这个问题的简单方法将涉及递归。也许这样好?

template<int max_index, int index = 0> void stuff_helper()
{
    if (index <= max_index)
    {
        do_stuff<index>();
        stuff_helper<max_index, index + 1>();
    }
}
int main()
{
    stuff_helper<100>();
    return 0;
}

取决于你对 C++ 中 template 的熟悉程度,你可能认为这看起来不错,或者你已经知道为什么这行不通。或者你可能认为这看起来不错并且知道为什么它行不通。

表面上看,if 语句似乎负责终止递归,就像它在“正常”基于运行时的递归算法中工作一样。但这就是问题所在。运行时有效的东西在编译时无效。

这是一个无限循环,只有在编译器将自身限制到一定的递归深度时才会停止。在 clang 中,我会收到错误 fatal error: recursive template instantiation exceeded maximum depth of 256。你可以在你选择的编译器中得到类似的错误。

幸运的是,有一个合理的方法来解决这个问题。

递归问题的细节和修复

所以第一个尝试的问题在于,template 并不总是像你可能直观认为的那样工作。我发现学习它有点令人沮丧,这似乎是一个常见的困难。然而,一旦理解并接受了这些怪癖,它们就不再是问题了。

你值得信赖的 C++ 编译器将尝试实例化 template,而不考虑基于运行时的“正常”逻辑。

这意味着,当我们这样做时

template<int max_index, int index = 0> void stuff_helper()
{
    if (index <= max_index)
    {
        do_stuff<index>();

if 逻辑在我们想要的时候并没有生效。if 与编译时决策无关,你的可爱编译器有义务尝试编译

        stuff_helper<max_index, index + 1>();

……即使你已经到达或超过了 max_index

所以我们需要的是一些不同的东西。我们在这里需要的是一种在编译时而不是在运行时进行逻辑分支的方法。

如何使用特化?

#include <type_traits>

constexpr int terminate_index{ 100 };

template<int index = 0> void stuff_helper();
template<> void stuff_helper<terminate_index>()
{
}    
template<int index> void stuff_helper<index>()
{
    do_stuff<index>();
    stuff_helper<index + 1>();
}

int main()
{
    stuff_helper();
    return 0;
}

不行,那行不通。C++ 不允许函数的局部特化。幸运的是,在这种情况下有一个变通方法可以模拟局部特化:我们可以使用标签。这里的标签是空类型,用于将一个函数签名与另一个函数签名分开。所以让我们使用 std::integral_constant,然后尝试这样做:

#include <type_traits>

constexpr int terminate_index{ 100 };

void stuff_helper(std::integral_constant<int, terminate_index>)
{
}
    
template<int index = 0>
void stuff_helper(std::integral_constant<int, index> = std::integral_constant<int, 0>())
{
    do_stuff<index>();
    stuff_helper(std::integral_constant<int, index + 1>());
}

int main()
{
    stuff_helper();
    return 0;
}

成功! 但它是如何工作的?

这依赖于编译器将函数调用匹配到正确函数的机制,并利用template 参数推导标签分派

第一次调用 stuff_helper 时,只有第二个templated)函数匹配调用,因此选择它。

直到 terminate_index 的递归调用也相当直接。我们传递 std::integral_constant,编译器会推导出 template 参数以匹配 integral_constant 中的 int

std::integral_constant 在这里工作是因为每个索引都会创建一个新类型,以及一个新的函数签名。本质上,std::integral_constant<int, 0>std::integral_constant<int, 1> 是完全不同的类型,尽管它们共享一个底层 template

最后一个调用是特殊的,因为那里有两个有效的函数签名可供选择。当不可避免地调用 stuff_helper(std::integral_constant<int, max_index>()); 时,编译器必须选择正确的。但那是哪个呢?

编译器会优先选择没有 template 的函数,而不是有 template 的函数。实际上,由于编译器可以找到一个无需实例化模板即可工作的函数签名,因此它会选择该函数而不查看 templated 的函数。这个过程被称为重载解析,属于“最佳可行函数”的范畴。

哦,别担心传递未使用的 std::integral_constant 的运行时性能成本。据我所知的所有编译器在发布模式下都会优化掉匿名类型

但是,这段代码有其局限性。当我们改变代码多循环几次时会发生什么?

constexpr int max_index{ 1000 };

砰! 我们又遇到了 template 递归错误。至少我在递归深度限制为 256 的编译器上遇到了。

有一个简单的方法和一个有趣的方法来解决这个问题。

简单的方法是告诉你的编译器增加其最大的 template 深度,使其超过 256,以适应你的具体情况。

有趣的方法是保持在最大 template 深度 256 的限制内,并允许我们循环任意多次。这就像解决一个棘手的谜题。也许你不喜欢谜题。*耸肩* 另外一个好处是,有趣的方法提供了一个在其他地方也有用的通用解决方案。

另外需要注意的是:我正在编写库代码,我无法控制用户的编译器设置。因此,对我来说,能够以一种无需修改编译器标志就能在现代 C++ 编译器上“正常工作”的方式解决这种情况非常重要。

精髓 - 编写一个通用的编译时 static_for 方法

这是我创建的基本接口

template<size_t count, typename functor_template, typename... functor_types>
inline void static_for(functor_types&&... functor_args);

template<size_t start, size_t end, typename functor_template, typename... functor_types>
inline void static_for(functor_types&&... functor_args);

static_for 方法接受一个 count 或一个 startend 点,一个包含一个名为 functemplated 函数的 typename,该函数将被 static_for 实例化,以及一组函数参数,使用 C++11 的完美转发(这些参数会自动推导!)来传递。哦,它还接受一个可选参数 sequence_width,用于内部使用。稍后将对此进行详细介绍!

虽然声明看起来很复杂,但使用它很简单。例如

struct do_stuff
{
    template<int index> static inline void func()
    {
        std::cout << index << std::endl;
    }
};

int main()
{
    static_for<5000, do_stuff>();
    return 0;
}

看起来不错吧?现在我们有了接口。现在只是编写代码来实现它……

不幸的是,我们需要将 func() 包装在一个辅助结构中,因为无法直接传递模板函数的名称。这是 C++ 语言在编译时能力方面又一个不足之处。

代码

我解决在编译时循环而不触发 template 深度错误的方法是将问题分解成更小的部分。

基本上,最初的简单递归方法是直线进行的。无论你想循环 20 次还是 20,000 次,基本方法都会一直进行直到完成。

我所做的不同之处在于将它分解成更小的部分。template 参数 sequence_width 可用于设置这些小块的大小。这种结构被称为n元树,其中 n 等于 sequence_width。我最初是用二元树编写的,并将它改成了n元树作为编译时优化。

虽然下面的代码适用于大型循环,但你几乎肯定会遇到编译器中的其他限制,或者基于编译程序所需时间的硬性限制。

换句话说:这不是万能药,虽然它允许你进行比以前更大的编译时循环,但你仍然会遇到其他问题。另外:编译 template 以众所周知地慢。

我将在此处分享头文件,以及一些关于此功能如何工作的注释。完整的可运行代码如下。下载中包含一个测试函数和项目文件。

公共接口

template<size_t count, typename functor, size_t sequence_width = 70,
    typename... functor_types>
inline void static_for(functor_types&&... functor_args)
{
    static_for_impl<0, count-1, functor, sequence_width, functor_types...>::
        loop(std::forward<functor_types>(functor_args)...);
}

template<size_t start, size_t end, typename functor, size_t sequence_width = 70,
    typename... functor_types>
inline void static_for(functor_types&&... functor_args)
{
    static_for_impl<start, end, functor, sequence_width, functor_types...>::
        loop(std::forward<functor_types>(functor_args)...);
}

带注释的内部实现

template<size_t for_start, size_t for_end, typename functor, size_t sequence_width,
    typename... functor_types>
struct static_for_impl
{
    static inline void loop(functor_types&&... functor_args)
    {
        // The main sequence point is created, and then we call "next" on each point inside
        using sequence = point<for_start, for_end>;
        next<sequence>
            (std::integral_constant<bool, sequence::is_end_point_>(), 
             std::forward<functor_types>(functor_args)...);
    }

private:
    
    // A point is a node of an n-ary tree
    template<size_t pt_start, size_t pt_end> struct point
    {
        static constexpr size_t start_        { pt_start };
        static constexpr size_t end_          { pt_end };
        static constexpr size_t count_        { end_ - start_ + 1 };
        static constexpr bool is_end_point_   { count_ <= sequence_width };

        static constexpr size_t sequence_count()
        {
            return
                    points_in_sequence(sequence_width) > sequence_width
                ?
                    sequence_width
                :
                    points_in_sequence(sequence_width);
        }

    private:
        // Calculates the start and end indexes for a child node
        static constexpr size_t child_start(size_t index)
        {
            return
                    index == 0
                ?
                    pt_start
                :
                    child_end(index - 1) + 1;
        }
        static constexpr size_t child_end(size_t index)
        {
            return
                    index == sequence_count() - 1
                ?
                    pt_end
                :
                    pt_start + points_in_sequence(sequence_count()) * (index + 1) -
                        (index < count_
                    ?
                         1
                    :
                         0);
        }
        static constexpr size_t points_in_sequence(size_t max)
        {
            return count_ / max + (
                    (count_ % max) > 0
                ?
                    1
                :
                    0);
        }
           
    public:
        // Generates child nodes when needed
        template<size_t index> using child_point = point<child_start(index), child_end(index)>;
    };

    // flat_for is used to instantiate a section of our our main static_for::loop
    // A point is used to specify which numbers this instance of flat_for will use
    template<size_t flat_start, size_t flat_end, class flat_functor> struct flat_for
    {
        // This is the entry point for flat_for
        static inline void flat_loop(functor_types&&... functor_args)
        {
            flat_next(std::integral_constant<size_t, flat_start>(), 
                std::forward<functor_types>(functor_args)...);
        }

    private:
        // Loop termination
        static inline void flat_next
            (std::integral_constant<size_t, flat_end + 1>, functor_types&&...)
        {
        }
       
        // Loop function that calls the function passed to it, as well as recurses
        template<size_t index>
        static inline void flat_next
            (std::integral_constant<size_t, index>, functor_types&&... functor_args)
        {
            flat_functor::template func<index>(std::forward<functor_types>(functor_args)...);
            flat_next(std::integral_constant<size_t, index + 1>(),
                std::forward<functor_types>(functor_args)...);
        }
    };

    // This is what gets called when we run flat_for on a point
    // It will recurse to more finer grained point until the points are no bigger than sequence_width
    template<typename sequence> struct flat_sequence
    {
        template<size_t index> static inline void func(functor_types&&... functor_args)
        {
            using pt = typename sequence::template child_point<index>;
            next<pt>
                (std::integral_constant<bool, pt::is_end_point_>(),
                 std::forward<functor_types>(functor_args)...);
        }
    };

    // The true_type function is called when our sequence is small enough to run out
    // and call the main functor that was provided to us
    template<typename sequence> static inline void next
        (std::true_type, functor_types&&... functor_args)
    {
        flat_for<sequence::start_, sequence::end_, functor>::
            flat_loop(std::forward<functor_types>(functor_args)...);
    }

    // The false_type function is called when our sequence is still too big, and we need to
    // run an internal flat_for loop on the child sequence_points 
    template<typename sequence> static inline void next
        (std::false_type, functor_types&&... functor_args)
    {
        flat_for<0, sequence::sequence_count() - 1, flat_sequence<sequence>>::
            flat_loop(std::forward<functor_types>(functor_args)...);
    }
};

Visual Studio 的当前版本不包含对 constexpr 的完整支持。源代码下载中包含一个不需要 constexprstatic_for 版本。

一个测试程序

// This is a sample struct that is instantiated at compile time with static_for
struct do_stuff
{
    template<size_t index> static inline void func(bool* b)
    {
        assert(!b[index]);
        b[index] = true;
        std::cout << index << std::endl;
    }
};

int main()
{
#define for_count 2000

    // To show that this works, I create an array of bools

    // The bools are all initialized to false

    // On each loop iteration, the bool at the index is tested to make sure that do_stuff hasn't been
    // called at that index. After the loop, each bool is tested to ensure that it was set to true.

    // In effect, this test validates that each index is called once and only once.

    bool b[for_count + 100];
    memset(b, false, for_count);

    // ... and that we don't overshoot the end point
    memset(&b[for_count], true, 100);

    static_for<for_count, do_stuff>(b);

    for (size_t i{ 0 }; i < for_count; ++i)
        assert(b[i]);

    // When the program completes without error, our test has passed successfully

    return 0;
}

代码注释

static_for_impl 使用你传递给它的信息进行初始化,然后它创建一个 point。将 point 视为树的一个节点。

如果 point 包含 sequence_width 个或更少的项,则该 point 被标记为不需要拆分(该 point 的值 is_end_point_ 将等于 true)。

如果 point 包含的项超过 sequence_width,则其 start_end_ 值将被拆分为 child_point

检查 child_point 以确定它们是否需要进一步拆分。

如果是,那么每个子节点都会获得数量等于父节点的 count_ 除以 sequence_widthpoint 项(类型)。

否则,如果不需要进一步拆分,那么子节点的数量 point最少为父节点的 count_ 除以 sequence_width

每个 point 都通过内部 flat_for 模板进行处理。在那里,point 用于调用带有子 point 的更多 flat_for 实例,或者用于在调用 static_for 时提供的 functor 在正确的 index 处进行实例化。

总的来说,代码通过将 for 循环的任务分解为 point(每个 point 包含 sequence_width 个或更少的项)来工作。point 在内部构建为n元树。然后,它通过 flat_for 处理每个 pointflat_for 用于在调用 static_for 时请求的每个索引处调用你的代码。

我该如何选择编译时循环和运行时循环?

在某些情况下,你别无选择。

我之所以写这个,是因为我需要与其他的编译时代码进行循环,在这种情况下我不能使用运行时循环。

通常,有许多与循环相关的逻辑直到运行时才能确定。例如,一个需要外部数据来确定何时终止的循环。

很多时候,这个问题归结为一个工程问题,因为它需要权衡这两种选择的利弊。

编译时/展开的循环将花费更长的编译时间。它们可能允许略高的运行时性能。在某些情况下,这可能是一个值得付出更长编译时间的收益。

在许多情况下,节省几纳秒的性能差异无关紧要。然而,在其他情况下,节省几纳秒至关重要。

至少,了解这些权衡很重要,并知道你为什么选择一个选项而不是另一个。

问题和评论

我很想听听你对此的任何问题和评论。告诉我你喜欢什么,不喜欢什么,以及我如何能做得更好。

告诉我你对哪些模板元编程主题感兴趣,以及你想让我写一篇后续文章。

历史

v 1.0 - 2014 年 12 月 28 日 - 初始发布

v 1.1 - 2014 年 12 月 29 日 - 附加信息和清理

v 1.2 - 2014 年 12 月 30 日 - 代码清理,附加信息,添加了测试程序和更多“相关主题”

© . All rights reserved.