65.9K
CodeProject 正在变化。 阅读更多。
Home

颠覆 3D 数学:第一部分 - 矩阵乘法

starIconstarIconstarIconstarIconstarIcon

5.00/5 (25投票s)

2017年4月18日

CPOL

9分钟阅读

viewsIcon

24368

downloadIcon

235

跳出固有思维模式可以带来关键的性能提升

敷衍了事还是精益求精?

软件开发界(尤其是学术界)一直以来都崇尚懒惰和逃避工作的原则。使用别人的库!不要重复造轮子!让编译器去做!教条主义有一个糟糕的习惯,那就是不断地自我延续,导致产出质量越来越差,直到最终产品变得如此糟糕以至于几乎无法使用。

曾经,技能、知识和工艺是一项值得尊敬的传统。对自己工作的自豪感是懒惰和敷衍了事的彻底对立面。如果这些概念被应用于今天的 3D 数学领域,会发生什么呢?

如果基本的矩阵乘法是衡量标准,那么可以带来巨大的性能提升。我创建了一个自己的 XMMatrixMultiply 函数版本,速度提升了 657%。在 3D 场景中,XMMatrixMultiply 被调用非常频繁。将其速度提高 6.57 倍(根据我的测试机器上的测量结果)是会带来显著效果的。而这仅仅是数学库中一个函数,该库单独用于矩阵处理的函数就有 42 个。

包含伪代码

本文提供的代码是 ASM(汇编语言),主要原因是它比过度装饰和复杂的内在函数更容易阅读和理解。为了阐述一个概念,使用 ASM 比试图在冗长的内在函数列表中摸索要清晰得多。本文中的代码可以轻松地适应任何支持 SSE 的语言,特别是 AVX。

本文随附一个完整的 .DLL、.LIB、C++ 头文件和说明,可供下载。本文中提供的 PCC_XMMatrixMultiply 函数包含在该下载文件中。它接受与 DirectXMath 函数 XMMatrixMultiply 相同的输入,并提供相同的输出。

跳出矩阵思维

SSE(Streaming SIMD [Single Instruction Multiple Data] Extensions)是一把双刃剑。它能在极短的时间内处理大量数据。理论上,处理 SSE 并不难,但真正的挑战在于在大规模操作中精确跟踪所有数据的变化。

随着 SSE 发展到 AVX(Advanced Vector Extensions)、AVX2,再到 AVX-512,添加了越来越多的功能,这些功能非常适合向量数学。然而,对这些过程的支持取决于 CPU 版本。本文假设支持 AVX,并使用 256 位 YMM 寄存器,其中包含四个“槽”的双精度浮点值。

少即是多是误区

一种常见的误解,尤其是在使用高级语言时,是认为减少指令数量自然等同于更好、更快的代码。事实远非如此。十条逻辑位运算指令的执行时间,还不到一次内存访问完成所需时间的一小部分。指令数量无关紧要;指令数量少可以慢,也可以快,这完全取决于所编码的指令。归根结底,唯一重要的是函数整体的运行速度。这就是本文概述函数的依据——指令数量可能看起来令人望而生畏,但当每条指令都以五个或六个 CPU 周期执行时,整个函数会以闪电般的速度完成,其执行时间只是 DirectXMath 等效函数所需时间的一小部分。

远离内存

最大化性能的关键在于避免内存访问。在 CPU 寄存器之间移动数据速度极快,但利用内存辅助处理会严重拖慢矩阵乘法(或其他使用 SSE 的操作)的速度。

在一个理想的世界里,矩阵乘法只涉及十二次内存访问操作,假设使用双精度浮点值并在 YMM 寄存器上进行操作(单精度浮点数使用 XMM 会产生相同的结果,但精度较低)。会有 8 次读取操作,将乘法矩阵 M1 的四行和被乘数矩阵 M2 的四行加载到目标寄存器中。最后四次内存访问是写入操作,用于输出最终结果矩阵。所有其他操作都将在 SSE 寄存器内完成。

这可以做到吗?

翻转 M2

矩阵乘法的本质,从表面上看,由于需要处理列而非行(针对第二个矩阵 M2,即被乘数),因此阻碍了 SSE 指令的高效使用。然而,如果 M2 能绕着一条从左上角到右下角的假想对角线旋转 180 度,那么该矩阵的四列就变成了行,然后就可以被 SSE 寄存器完全访问。

Figure 1: http://www.starjourneygames.com/images/part_1_figure_1.png

在翻转过程中,M2[r0,c0]M2[r1,c1]M2[r2,c2]M2[r3,c3] 不会移动。对于剩余的值,可以通过直接交换寄存器内容来完成翻转。

SSE 处理两种基本数据类型:packed(将不同的值放置到 YMM 寄存器的每个 64 位槽中),或 scalar(仅使用低 64 位)。翻转 M2 的过程同时使用了这两种类型的指令。

vpermq 指令用于旋转 YMM 寄存器的内容,使得正在处理的每个连续值都被移动到其寄存器的低 64 位。一旦到达那里,就可以使用标量指令交换寄存器内容。通过使用这种方法,不会丢失或覆盖数据,也不会进行内存访问,并且可以非常快速地完成翻转 M2 的任务。指令数量可能看起来很多,但总执行时间却非常短。

首先,此示例将 M1 的第 0-3 行加载到 YMM0YMM1YMM2YMM3 中。M2 的行被加载到 YMM4YMM5YMM6YMM7 中。

Figure 2: http://www.starjourneygames.com/images/part_1_figure_2.png

(此代码假设进入函数时,RCX 指向 M1,RDX 指向 M2。这是 M1 和 M2 作为参数传递时的标准调用约定。)

vmovapd ymm4, ymmword ptr [ rdx ]      ; Load M2 [ r0 ] into ymm4
vmovapd ymm5, ymmword ptr [ rdx + 32 ] ; Load M2 [ r1 ] into ymm5
vmovapd ymm6, ymmword ptr [ rdx + 64 ] ; Load M2 [ r2 ] into ymm6
vmovapd ymm7, ymmword ptr [ rdx + 96 ] ; Load M2 [ r3 ] into ymm7

M2 加载到内存后,翻转过程就可以开始了。必须进行以下交换

r1, c0 : r0, c1
r2, c0 : r0, c2
r3, c0 : r0, c3
r2, c1 : r1, c2
r3, c1 : r1, c3
r3, c2 : r2, c3

总共需要进行六次交换。

需要记住的一个重要事项是 CPU 的小端序操作。第 3 列的值将占据其相应 YMM 寄存器的高 64 位;第 0 列的值将占据低 64 位。这与矩阵在图表中显示的视图是相反的。实际的 YMM 内容如下

Table 1: http://www.starjourneygames.com/images/part_1_figure_3.png

需要交换的每个值都必须旋转到其 YMM 寄存器的低位位置——位 63-0。对于第一次交换(r1,c0r0,c1),r1,c0 的值已在 YMM5 的低 64 位中。YMM4 需要进行 64 位右移,将 r0,c1 的值移动到其低 64 位。

vpermq ymm4, ymm4, 39h ; Rotate ymm4 right 64 bits
movsd xmm8, xmm5       ; Move r1c0 to hold
movsd xmm5, xmm4       ; Move r0c1 to r1c0 in xmm5
movsd xmm4, xmm8       ; Move r1c0 to r0c1

SSE 指令仅使用低 128 位,即 XMM 寄存器部分,进行标量移动,因为仅处理位 63-0。由于 YMM 寄存器是 XMM 的超集,因此可以实现所需的效果:值在 YMM 寄存器的低 64 位之间移动,即使引用的是 XMM 寄存器。

旋转后,r0c1(在 ymm8 中)已向右旋转,使 c1 现在位于位 63-0。下一次向右旋转 64 位会将 r0c2 放入位 63-0,然后该过程可以重复以交换 r0,c2r2,c0

vpermq ymm4, ymm4, 39h ; Rotate ymm4 right 64 bits
movsd  xmm8, xmm6      ; Move r2c0 to hold
movsd  xmm6, xmm4      ; Move r0c2 to r2c0
movsd  xmm4, xmm8      ; Move hold to r0c2

此过程继续进行,使用 YMM7 交换 r3,c0r0,c3。其余的 M2 值交换也以同样的方式进行。本文末尾的最终代码包含整个过程。

执行乘法

一旦 M2 被翻转,就可以进行实际的乘法。这才是真正的节省时间之处,因为只需要执行四次乘法运算。由于 M2 的每一列现在都连续地存储在 YMM 寄存器中,因此可以直接进行乘法。

首先,M1 被加载到 YMM0YMM3

vmovapd ymm0, ymmword ptr [ rcx ]      ; Load row 0 into ymm0
vmovapd ymm1, ymmword ptr [ rcx + 32 ] ; Load row 1 into ymm1
vmovapd ymm2, ymmword ptr [ rcx + 64 ] ; Load row 2 into ymm2
vmovapd ymm3, ymmword ptr [ rcx + 96 ] ; Load row 3 into ymm3

首先,将 M1[r0](加载到 YMM0 中)乘以 M2[c0](加载到 YMM4 中)。

vmulpd  ymm8, ymm0, ymm4 ; Set M1r0 * M2c0 in ymm8

M1[r1] * M2[c0] 的结果现在存在于 YMM8 中,接下来的挑战是如何累加 YMM8 中的值,该寄存器保存乘法的结果。

我见过很多混乱的处理方式;我个人认为没有一种是合适的。对于这个函数,我再次转向系统地旋转结果寄存器 YMM8 三次,每次旋转 64 位,使用 addsd 指令累加移入位 63-0 的每个值。由于原始 M1 行在 YMM0YMM3 中,翻转后的 M2 列在 YMM4YMM7 中,因此 YMM8 在累加完成后将被丢弃;它可以保留其最终状态,右移 192 位。

再次请记住,没有进行任何内存访问。所有操作都是寄存器到寄存器的,这就是速度提升的来源。

movsd  xmm9, xmm8      ; Get the full result (only care about 63 - 0)
vpermq ymm8, ymm8, 39h ; Rotate ymm8 right 64 bits
addsd  xmm9, xmm8      ; Add 127 – 64
vpermq ymm8, ymm8, 39h ; Rotate ymm8 right 64 bits
addsd  xmm9, xmm8      ; Add 191 – 128
vpermq ymm8, ymm8, 39h ; Rotate ymm8 right 64 bits
addsd  xmm9, xmm8      ; Add 255 - 192

重复相同的过程:将 M1[r0] * M2[c1] 相乘。在过程开始时,累加器 YMM9 向左移动 64 位,以便将 M1[r0] * M2[c1] 的结果放入位 63-0

vmulpd  ymm8, ymm0, ymm5 ; Set M1r0 * M2c1
vpermq  ymm9, ymm9, 39h  ; Rotate accumulator left 64 bits
movsd   xmm9, xmm8       ; Get the full result
vpermq  ymm8, ymm8, 39h  ; Rotate ymm0 right 32 bits
addsd   xmm9, xmm8       ; Add 63 – 32
vpermq  ymm8, ymm8, 39h  ; Rotate ymm0 right 32 bits
addsd   xmm9, xmm8       ; Add 95 – 64
vpermq  ymm8, ymm8, 39h  ; Rotate ymm0 right 32 bits
addsd   xmm9, xmm8       ; Add 127 – 96

此过程系统地重复两次,使用 M2[c2] (YMM6)M2[c3] (YMM7) 作为被乘数。当该过程重复四次(每次使用 YMM0 作为乘数)后,累加器 YMM9 已右移三次。最后一次旋转使其处于最终状态,准备存储到最终输出中。

vpermq  ymm9, ymm9, 39h     ; Rotate accumulator left 32 bits
vmovapd ymmword ptr M, ymm9 ; Store output row 0

该过程重复用于输出行一、二和三,乘数从 YMM0 变为 YMM1,然后变为 YMM2,再变为 YMM3,对应于每一行。

最终结果是输出矩阵 M 在 XMMatrixMultiply 执行所需时间的一小部分内被正确计算。

注意:对于希望进行自己的速度比较的人来说,64 位 Visual Studio 应用程序不允许使用 64 位汇编,而内在函数可能会(很可能)扭曲计时结果(它们可能执行得更慢)。此外,尽管我无法绝对证明这一点,但我的经验一直表明,在 Visual Studio 中编写的应用程序与 Windows 之间存在一种特殊的关系,它们运行速度比用其他任何工具编写的应用程序都要快得多。这是我的经验;我无法在法庭上证明这一点,因此我不能权威地声称它是真实的。这是一种我个人持有的坚定信念。

下面展示了 PCC_XMMatrixMultiply 函数的完整代码列表。初步看来,该函数有太多的指令,不可能在任何可观的时间范围内运行。但事实并非如此。这些指令速度极快,延迟极低,而且它们能像闪电一样被执行。如前所述,唯一可以衡量的是函数整体的最终性能。

align qword
PCC_XMMatrixMultiply proc

; 64-bit calling conventions have RCX > M1 and RDX > M2 on entry to this function.

;*****[Load M2 into ymm4 / ymm5 / ymm6 / ymm7]*****

vmovapd ymm4, ymmword ptr [ rdx ]
vmovapd ymm5, ymmword ptr [ rdx + 32 ]
vmovapd ymm6, ymmword ptr [ rdx + 64 ]
vmovapd ymm7, ymmword ptr [ rdx + 96 ]

;*****[Swap r0,c1 and r1,c0]*************************

vpermq  ymm4, ymm4, 39h
movsd   xmm9, xmm5
movsd   xmm5, xmm4
movsd   xmm4, xmm9

;*****[Swap r0,c2 and r2,c0]*************************

vpermq  ymm4, ymm4, 39h
movsd   xmm9, xmm6
movsd   xmm6, xmm4
movsd   xmm4, xmm9

;*****[Swap r3,c0 and r0,c3]*************************

vpermq  ymm4, ymm4, 39h
movsd   xmm9, xmm7
movsd   xmm7, xmm4
movsd   xmm4, xmm9
vpermq  ymm4, ymm4, 39h

;*****[Swap r2,c1 and r1,c2]*************************

vpermq  ymm5, ymm5, 04Eh
vpermq  ymm6, ymm6, 039h
movsd   xmm9, xmm6
movsd   xmm6, xmm5
movsd   xmm5, xmm9

;*****[Swap r3,c1 and r1,c3]*************************

vpermq  ymm5, ymm5, 039h
vpermq  ymm7, ymm7, 039h
movsd   xmm9, xmm7
movsd   xmm7, xmm5
movsd   xmm5, xmm9
vpermq  ymm5, ymm5, 039h

;*****[Swap r3,c2 and r2,c3]*************************

vpermq  ymm6, ymm6, 04Eh
vpermq  ymm7, ymm7, 039h
movsd   xmm9, xmm7
movsd   xmm7, xmm6
movsd   xmm6, xmm9
vpermq  ymm6, ymm6, 039h
vpermq  ymm7, ymm7, 04Eh

;*****[Load M1 values]*******************************

vmovapd ymm0, ymmword ptr [ rcx ]
vmovapd ymm1, ymmword ptr [ rcx + 32 ]
vmovapd ymm2, ymmword ptr [ rcx + 64 ]
vmovapd ymm3, ymmword ptr [ rcx + 96 ]

;*****[Set Mr0c0 = M1r0 * M2c0 ]*********************

vmulpd  ymm8, ymm0, ymm4
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr0c1 = M1r0 * M2c1 ]*********************

vmulpd  ymm8, ymm0, ymm5
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr0c2 = M1r0 * M2c2 ]*********************

vmulpd  ymm8, ymm0, ymm6
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr0c3 = M1r0 * M2c3 ]*********************

vmulpd  ymm8, ymm0, ymm7
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Store output row 0]***************************

vpermq  ymm9, ymm9, 39h
lea     rdx, M
vmovapd ymmword ptr M, ymm9

;*****[Set Mr1c0 = M1r1 * M2c0 ]*********************

vmulpd  ymm8, ymm1, ymm4
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr1c1 = M1r1 * M2c1 ]*********************

vmulpd  ymm8, ymm1, ymm5
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr1c2 = M1r1 * M2c2 ]*********************

vmulpd  ymm8, ymm1, ymm6
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39haddsd   xmm9, xmm8

;*****[Set Mr1c3 = M1r1 * M2c3 ]*********************

vmulpd  ymm8, ymm1, ymm7
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Store output row 1]***************************

vpermq  ymm9, ymm9, 39h
vmovapd ymmword ptr M [ 32 ], ymm9

;*****[Set Mr1c0 = M1r2 * M2c0 ]*********************

vmulpd  ymm8, ymm2, ymm4
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr2c1 = M1r2 * M2c1 ]*********************

vmulpd  ymm8, ymm2, ymm5
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr2c2 = M1r2 * M2c2 ]*********************

vmulpd  ymm8, ymm2, ymm6
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr2c3 = M1r2 * M2c3 ]*********************

vmulpd  ymm8, ymm2, ymm7
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Store output row 2]***************************

vpermq  ymm9, ymm9, 39h
vmovapd ymmword ptr M [ 64 ], ymm9

;*****[Set Mr3c0 = M1r3 * M2c0 ]*********************

vmulpd  ymm8, ymm3, ymm4
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr3c1 = M1r3 * M2c1 ]*********************

vmulpd  ymm8, ymm3, ymm5
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr3c2 = M1r3 * M2c2 ]*********************

vmulpd  ymm8, ymm3, ymm6
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Set Mr3c3 = M1r3 * M2c3 ]*********************

vmulpd  ymm8, ymm3, ymm7
vpermq  ymm9, ymm9, 39h
movsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8
vpermq  ymm8, ymm8, 39h
addsd   xmm9, xmm8

;*****[Store output row 3]***************************

vpermq  ymm9, ymm9, 39h
vmovapd ymmword ptr M [ 96 ], ymm9

;*****[Set final return]*****************************

lea     rax, M
ret

PCC_XMMatrixMultiply endp
© . All rights reserved.