From 538ffdc8929b72ea3c236ce7b23d53e49aad6ad7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Fri, 22 Oct 2021 17:15:10 +0200
Subject: [PATCH] Add an AVX512 kernel

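The new lssi2_avx512.S kernel processes 8 particles at a time in zmm
registers, using AVX512F gathers, scatters and embedded broadcasts.
Runtime dispatch in local_state_space_iteration_2.cc now tries the
AVX512 kernel first (when the CPU supports AVX512F, or when
DYNARE_LSSI2_KERNEL=avx512 forces it), then falls back to the AVX2+FMA
kernel and finally to the generic C++ routine. When the number of
particles is not a multiple of 8, the remaining particles are handled
by the generic routine. The CFI push/pop macros previously defined in
lssi2_avx2.S are moved to lssi2_common.S so that both assembly kernels
can share them, and the test now runs and cross-checks the avx512,
avx2 and generic kernels.

As the updated test does, a specific kernel can be forced from
Octave/MATLAB through the DYNARE_LSSI2_KERNEL environment variable
(the inputs and the thread option below are the ones defined in that
test):

  setenv("DYNARE_LSSI2_KERNEL", "avx512")   % force the new kernel (requires AVX512F hardware)
  ynext = local_state_space_iteration_2(yhat, epsilon, rf_ghx, rf_ghu, rf_constant, rf_ghxx, rf_ghuu, rf_ghxu, options_.threads.local_state_space_iteration_2);
  setenv("DYNARE_LSSI2_KERNEL", "auto")     % restore automatic selection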
---
 mex/build/local_state_space_iterations.am     |   2 +-
 .../local_state_space_iteration_2.cc          |  25 +-
 .../local_state_space_iterations/lssi2.hh     |   5 +
 .../local_state_space_iterations/lssi2_avx2.S |  17 +-
 .../lssi2_avx512.S                            | 327 ++++++++++++++++++
 .../lssi2_common.S                            |  15 +
 .../local_state_space_iteration_k_test.mod    |  12 +-
 7 files changed, 381 insertions(+), 22 deletions(-)
 create mode 100644 mex/sources/local_state_space_iterations/lssi2_avx512.S
 create mode 100644 mex/sources/local_state_space_iterations/lssi2_common.S

diff --git a/mex/build/local_state_space_iterations.am b/mex/build/local_state_space_iterations.am
index df711fe19..380a3c542 100644
--- a/mex/build/local_state_space_iterations.am
+++ b/mex/build/local_state_space_iterations.am
@@ -1,6 +1,6 @@
 mex_PROGRAMS = local_state_space_iteration_2 local_state_space_iteration_k
 
-nodist_local_state_space_iteration_2_SOURCES = local_state_space_iteration_2.cc lssi2_avx2.S
+nodist_local_state_space_iteration_2_SOURCES = local_state_space_iteration_2.cc lssi2_avx2.S lssi2_avx512.S
 nodist_local_state_space_iteration_k_SOURCES = local_state_space_iteration_k.cc
 
 local_state_space_iteration_2_CPPFLAGS = $(AM_CPPFLAGS) -I$(top_srcdir)/../../sources/local_state_space_iterations
diff --git a/mex/sources/local_state_space_iterations/local_state_space_iteration_2.cc b/mex/sources/local_state_space_iterations/local_state_space_iteration_2.cc
index c4d3947f4..ac3347d12 100644
--- a/mex/sources/local_state_space_iterations/local_state_space_iteration_2.cc
+++ b/mex/sources/local_state_space_iterations/local_state_space_iteration_2.cc
@@ -217,8 +217,29 @@ ss2Iteration(double *y, const double *yhat, const double *epsilon,
 
   // Runtime selection of kernel
 #if defined(__x86_64__) && defined(__LP64__)
-  if ((kernel == "auto" &&__builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma"))
-      || kernel == "avx2")
+  if ((kernel == "auto" && __builtin_cpu_supports("avx512f"))
+      || kernel == "avx512")
+    {
+      int it_total = s/8;
+      int it_per_thread = it_total / number_of_threads;
+      std::vector<std::thread> threads(number_of_threads);
+      for (int i = 0; i < number_of_threads; i++)
+        threads[i] = std::thread{[&, i] { // i is captured by value, since it changes
+          int offset = i*it_per_thread*8;
+          int s2 = i == number_of_threads - 1 ? it_total*8 - offset : it_per_thread*8;
+          lssi2_avx512(y+offset*m, yhat+offset*n, epsilon+offset*q,
+                       ghx, ghu, constant, ghxx, ghuu, ghxu,
+                       static_cast<int>(m), static_cast<int>(n), static_cast<int>(q), s2);
+        }};
+      for (int i = 0; i < number_of_threads; i++)
+        threads[i].join();
+      if (int rem = s % 8; rem != 0)
+        // If s is not a multiple of 8, use the generic routine to finish the computation
+        lssi2_generic(y+(s-rem)*m, yhat+(s-rem)*n, epsilon+(s-rem)*q,
+                      ghx, ghu, constant, ghxx, ghuu, ghxu, m, n, q, rem, 1);
+    }
+  else if ((kernel == "auto" && __builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma"))
+           || kernel == "avx2")
     {
       int it_total = s/4;
       int it_per_thread = it_total / number_of_threads;
diff --git a/mex/sources/local_state_space_iterations/lssi2.hh b/mex/sources/local_state_space_iterations/lssi2.hh
index 36ba04f57..877930e00 100644
--- a/mex/sources/local_state_space_iterations/lssi2.hh
+++ b/mex/sources/local_state_space_iterations/lssi2.hh
@@ -40,4 +40,9 @@ extern "C" void lssi2_avx2(double *y, const double *yhat, const double *epsilon,
                            const double *ghxx, const double *ghuu, const double *ghxu,
                            int m, int n, int q, int s);
 
+extern "C" void lssi2_avx512(double *y, const double *yhat, const double *epsilon,
+                             const double *ghx, const double *ghu, const double *constant,
+                             const double *ghxx, const double *ghuu, const double *ghxu,
+                             int m, int n, int q, int s);
+
 #endif
diff --git a/mex/sources/local_state_space_iterations/lssi2_avx2.S b/mex/sources/local_state_space_iterations/lssi2_avx2.S
index 692d02d8e..1392d31bd 100644
--- a/mex/sources/local_state_space_iterations/lssi2_avx2.S
+++ b/mex/sources/local_state_space_iterations/lssi2_avx2.S
@@ -37,22 +37,7 @@
 
 
 ### Some useful macros
-
-	## Push a register to the stack and adjust CFI information accordingly
-        ## Also increment a counter
-	.macro push_cfi_reg reg
-	push \reg
-	.cfi_adjust_cfa_offset 8
-	.cfi_rel_offset \reg, 0
-        .set pushed_regs,pushed_regs+1
-	.endm
-
-	## Pop a register from the stack and adjust CFI information accordingly
-	.macro pop_cfi_reg reg
-	pop \reg
-	.cfi_adjust_cfa_offset -8
-	.cfi_restore \reg
-	.endm
+#include "lssi2_common.S"
 
 ### Some pre-defined vector constants
         .section .rodata
diff --git a/mex/sources/local_state_space_iterations/lssi2_avx512.S b/mex/sources/local_state_space_iterations/lssi2_avx512.S
new file mode 100644
index 000000000..df39c2bcc
--- /dev/null
+++ b/mex/sources/local_state_space_iterations/lssi2_avx512.S
@@ -0,0 +1,327 @@
+/*
+ * Local state-space iteration at order 2, without pruning.
+ * Specialized kernel in x86-64 assembly using AVX512F.
+ * See the function prototype in lssi2.hh.
+ *
+ * WARNING: the number of particles must be a multiple of 8.
+ */
+
+/*
+ * Copyright © 2021 Dynare Team
+ *
+ * This file is part of Dynare.
+ *
+ * Dynare is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dynare is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with Dynare.  If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#if defined(__x86_64__) && defined(__LP64__)
+
+        .intel_syntax noprefix
+
+	## Enforce required CPU features, so that GNU as errors out if we use
+        ## instructions not in that feature set
+        .arch generic64
+        .arch .avx512f
+
+
+### Some useful macros
+#include "lssi2_common.S"
+
+
+### Some pre-defined vector constants
+        .section .rodata
+
+        .align 64
+v8_double_0p5:
+        .rept 8
+        .double 0.5
+        .endr
+
+        .align 32
+v8_int_0_to_7:
+        .int 0, 1, 2, 3, 4, 5, 6, 7
+v8_int_1:
+        .rept 8
+        .int 1
+        .endr
+
+
+### Layout of the current stack frame and of the beginning of the previous
+### one (for arguments passed through the stack).
+### This defines offset symbols through which variables and arguments can
+### be conveniently referred.
+        .struct 0
+constant_offset:
+        .space 8
+ghu_offset:
+        .space 8
+ghx_offset:
+        .space 8
+saved_regs_offset:
+        .space 6*8
+return_address_offset:
+        .space 8
+        ## Beginning of the previous stack frame
+ghxx_offset:
+        .space 8
+ghuu_offset:
+        .space 8
+ghxu_offset:
+        .space 8
+m_offset:
+        .space 8
+n_offset:
+        .space 8
+q_offset:
+        .space 8
+s_offset:
+        .space 8
+
+
+### Function body
+	.text
+	.global lssi2_avx512
+lssi2_avx512:
+        .cfi_startproc
+	.set pushed_regs,0
+	push_cfi_reg rbp
+        push_cfi_reg rbx
+        push_cfi_reg r12
+        push_cfi_reg r13
+        push_cfi_reg r14
+        push_cfi_reg r15
+        .if return_address_offset-saved_regs_offset != pushed_regs*8
+        .error "Incorrect stack frame layout regarding saved registers"
+        .endif
+
+	## Allocate space for variables in the current stack frame
+        sub rsp,saved_regs_offset
+        .cfi_adjust_cfa_offset saved_regs_offset
+
+	## Define more explicit names for registers that will remain associated
+        ## to input arguments
+	.set y,rdi
+        .set yhat,rsi
+        .set epsilon,rdx
+        .set m,r12d
+        .set n,r13d
+        .set q,r14d
+        .set s,r15d
+
+	## Initialize those registers that need it
+        mov m,[rsp+m_offset]                    # number of states + observed vars
+        mov n,[rsp+n_offset]                    # number of states
+        mov q,[rsp+q_offset]                    # number of shocks
+	mov s,[rsp+s_offset]                    # number of particles
+
+	## Conversely save some arguments to the stack
+        mov [rsp+ghx_offset],rcx
+        mov [rsp+ghu_offset],r8
+        mov [rsp+constant_offset],r9
+
+	## Do nothing if s=0
+	test s,s
+        jz done
+
+	## Precompute 8m (used for moving pointers by one column inside y and gh* matrices)
+	mov r11d,m
+        shl r11,3                               # Use a 64-bit op, since m could theoretically be 2³¹-1
+
+	## Pre-compute some vectorized constants
+        vmovd xmm15,n
+        vpbroadcastd ymm15,xmm15                # ymm15 = [ int32: n (×8) ]
+	vmovd xmm14,q
+        vpbroadcastd ymm14,xmm14                # ymm14 = [ int32: q (×8) ]
+        vmovd xmm13,m
+        vpbroadcastd ymm13,xmm13                # ymm13 = [ int32: m (×8) ]
+
+	vmovdqa ymm12,[rip+v8_int_0_to_7]
+        vpmulld ymm12,ymm12,ymm15               # ymm12 = base VSIB for current particles in ŷ
+	vmovdqa ymm11,[rip+v8_int_0_to_7]
+        vpmulld ymm11,ymm11,ymm14               # ymm11 = base VSIB for current particles in ε
+	vmovdqa ymm10,[rip+v8_int_0_to_7]
+        vpmulld ymm10,ymm10,ymm13               # ymm10 = base VSIB for particles in y
+
+	vpslld ymm15,ymm15,3                    # ymm15 = [ int32: 8n (×8) ]
+	vpslld ymm14,ymm14,3                    # ymm14 = [ int32: 8q (×8) ]
+	vpslld ymm13,ymm13,3                    # ymm13 = [ int32: 8m (×8) ]
+
+	vmovdqa ymm8,[rip+v8_int_1]             # ymm8 = [ int32: 1 (×8) ]
+        vmovapd zmm16,[rip+v8_double_0p5]       # zmm16 = [ double: 0.5 (×8) ]
+
+	## Enter the main loop
+	xor ebx,ebx                             # ebx = particle counter
+
+next_8_particles:
+        xor ecx,ecx                             # ecx = variable counter
+	vmovdqa ymm9,ymm10                      # ymm9 = VSIB for current particles in y
+
+next_variable:
+	## Use zmm0 to store the 8 next-period particles for this specific variable
+        ## Initialize zmm0 to the constant
+	mov r8,[rsp+constant_offset]
+	vbroadcastsd zmm0,[r8+rcx*8]
+
+        ## Add ghx·ŷ
+	mov ebp,n                               # ebp = number of remaining state variables
+        vmovdqa ymm1,ymm12                      # ymm1 = VSIB for current particles in ŷ
+        mov r8,[rsp+ghx_offset]
+        lea r8,[r8+rcx*8]                       # r8 = pointer to ghxᵢⱼ
+	.align 16                               # Recommendation p. 537 of Kusswurm (2018)
+next_state:
+        kxnorw k1,k1,k1
+	vgatherdpd zmm2{k1},[yhat+ymm1*8]       # zmm2 = current particles for ŷⱼ
+	vfmadd231pd zmm0,zmm2,[r8]{1to8}        # Add ghxᵢⱼ·ŷⱼ to current particles
+        add r8,r11                              # Move to next column in ghx
+	vpaddd ymm1,ymm1,ymm8                   # Update VSIB for next state (j=j+1)
+	sub ebp,1
+        jnz next_state
+
+	## Add ghu·ε
+	mov ebp,q                               # ebp = number of remaining shocks
+        vmovdqa ymm1,ymm11                      # ymm1 = VSIB for current particles in ε
+        mov r8,[rsp+ghu_offset]
+        lea r8,[r8+rcx*8]                       # r8 = pointer to ghuᵢⱼ
+        .align 16
+next_shock:
+        kxnorw k1,k1,k1
+	vgatherdpd zmm2{k1},[epsilon+ymm1*8]    # zmm2 = current particles for εⱼ
+	vfmadd231pd zmm0,zmm2,[r8]{1to8}        # Add ghuᵢⱼ·εⱼ to current particles
+        add r8,r11                              # Move to next column in ghu
+	vpaddd ymm1,ymm1,ymm8                   # Update VSIB for next shock (j=j+1)
+        sub ebp,1
+        jnz next_shock
+
+        ## Add ½ghxx·ŷ⊗ŷ
+	xor ebp,ebp                             # Index of first state (j₁)
+        mov r8,[rsp+ghxx_offset]
+        lea r8,[r8+rcx*8]                       # r8 = pointer to ghxxᵢⱼ
+        vmovdqa ymm1,ymm12                      # ymm1 = VSIB for current particles in ŷⱼ₁
+        .align 16
+next_state_state_1:
+        mov r10d,ebp                            # Index of second state (j₂)
+        vmovdqa ymm2,ymm1                       # ymm2 = VSIB for current particles in ŷⱼ₂
+next_state_state_2:
+        kxnorw k1,k1,k1
+        vgatherdpd zmm4{k1},[yhat+ymm1*8]       # zmm4 = current particles for ŷⱼ₁
+        kxnorw k2,k2,k2
+        vgatherdpd zmm6{k2},[yhat+ymm2*8]       # zmm6 = current particles for ŷⱼ₂
+        vmulpd zmm3,zmm4,zmm6                   # zmm3 = particles for ŷⱼ₁·ŷⱼ₂
+        ## NB: We compare the zmm registers to avoid requiring AVX512VL
+	vpcmpeqd k3,zmm1,zmm2                   # k3[0:7]=1 if diagonal, 0 otherwise
+        vmulpd zmm3{k3},zmm3,zmm16              # zmm3 = ½ŷⱼ₁·ŷⱼ₂ if diagonal, ŷⱼ₁·ŷⱼ₂ otherwise
+	vfmadd231pd zmm0,zmm3,[r8]{1to8}        # Add (½?)ghxxᵢⱼ·ŷⱼ₁·ŷⱼ₂ to current particles
+	vpaddd ymm2,ymm2,ymm8                   # Update VSIB for next second state (j₂=j₂+1)
+        add r8,r11                              # Move to next column in ghxx
+	add r10d,1
+	cmp r10d,n
+        jl next_state_state_2
+	vpaddd ymm1,ymm1,ymm8                   # Update VSIB for next first state (j₁=j₁+1)
+	add ebp,1
+	mov eax,ebp
+        imul rax,r11
+        add r8,rax                              # Jump several columns in ghxx
+	cmp ebp,n
+        jl next_state_state_1
+
+        ## Add ½ghuu·ε⊗ε
+	xor ebp,ebp                             # Index of first shock (j₁)
+        mov r8,[rsp+ghuu_offset]
+	lea r8,[r8+rcx*8]                       # r8 = pointer to ghuuᵢⱼ
+        vmovdqa ymm1,ymm11                      # ymm1 = VSIB for current particles in εⱼ₁
+        .align 16
+next_shock_shock_1:
+        mov r10d,ebp                            # Index of second shock (j₂)
+        vmovdqa ymm2,ymm1                       # ymm2 = VSIB for current particles in εⱼ₂
+next_shock_shock_2:
+        kxnorw k1,k1,k1
+        vgatherdpd zmm4{k1},[epsilon+ymm1*8]    # zmm4 = current particles for εⱼ₁
+        kxnorw k2,k2,k2
+        vgatherdpd zmm6{k2},[epsilon+ymm2*8]    # zmm6 = current particles for εⱼ₂
+        vmulpd zmm3,zmm4,zmm6                   # zmm3 = particles for εⱼ₁·εⱼ₂
+        ## NB: We compare the zmm registers to avoid requiring AVX512VL
+	vpcmpeqd k3,zmm1,zmm2                   # k3[0:7]=1 if diagonal, 0 otherwise
+        vmulpd zmm3{k3},zmm3,zmm16              # zmm3 = ½εⱼ₁εⱼ₂ if diagonal, εⱼ₁εⱼ₂ otherwise
+	vfmadd231pd zmm0,zmm3,[r8]{1to8}        # Add (½?)ghuuᵢⱼ·εⱼ₁εⱼ₂ to current particles
+	vpaddd ymm2,ymm2,ymm8                   # Update VSIB for next second shock (j₂=j₂+1)
+        add r8,r11                              # Move to next column in ghuu
+	add r10d,1
+	cmp r10d,q
+        jl next_shock_shock_2
+	vpaddd ymm1,ymm1,ymm8                   # Update VSIB for next first shock (j₁=j₁+1)
+	add ebp,1
+	mov eax,ebp
+        imul rax,r11
+        add r8,rax                              # Jump several columns in ghuu
+	cmp ebp,q
+        jl next_shock_shock_1
+
+	## Add ghxu·ŷ⊗ε
+	mov ebp,n                               # ebp = number of remaining states
+        vmovdqa ymm1,ymm12                      # ymm1 = VSIB for current particles in ŷ
+        mov r8,[rsp+ghxu_offset]
+        lea r8,[r8+rcx*8]                       # r8 = pointer to ghxuᵢⱼ
+	.align 16
+next_state_2:
+        mov eax,q                               # eax = number of remaining shocks
+        vmovdqa ymm2,ymm11                      # ymm2 = VSIB for current particles in ε
+        kxnorw k1,k1,k1
+	vgatherdpd zmm4{k1},[yhat+ymm1*8]       # zmm4 = current particles for ŷⱼ₁
+	.align 16
+next_shock_2:
+        kxnorw k2,k2,k2
+	vgatherdpd zmm6{k2},[epsilon+ymm2*8]    # zmm6 = current particles for εⱼ₂
+        vmulpd zmm6,zmm4,zmm6                   # zmm6 = particles for ŷⱼ₁εⱼ₂
+        vfmadd231pd zmm0,zmm6,[r8]{1to8}        # Add ghxuᵢⱼ·ŷⱼ₁εⱼ₂ to current particles
+        add r8,r11                              # Move to next column in ghxu
+	vpaddd ymm2,ymm2,ymm8                   # Update VSIB for next shock (j₂=j₂+1)
+        sub eax,1
+        jnz next_shock_2
+	vpaddd ymm1,ymm1,ymm8                   # Update VSIB for next state (j₁=j₁+1)
+	sub ebp,1
+        jnz next_state_2
+
+	## Save updated particles to memory
+        kxnorw k1,k1,k1
+        vscatterdpd [y+ymm9*8]{k1},zmm0
+
+        ## Loop over variables
+        vpaddd ymm9,ymm9,ymm8                   # Update VSIB for next variable in y
+	add ecx,1
+        cmp ecx,m
+        jl next_variable
+
+	## Loop over particles
+	add ebx,8
+        vpaddd ymm12,ymm12,ymm15                # Update base VSIB for ŷ
+        vpaddd ymm11,ymm11,ymm14                # Update base VSIB for ε
+        vpaddd ymm10,ymm10,ymm13                # Update base VSIB for y
+        cmp ebx,s
+        jl next_8_particles
+
+done:
+	## Cleanup
+	add rsp,saved_regs_offset
+        .cfi_adjust_cfa_offset -saved_regs_offset
+	pop_cfi_reg r15
+        pop_cfi_reg r14
+        pop_cfi_reg r13
+        pop_cfi_reg r12
+        pop_cfi_reg rbx
+	pop_cfi_reg rbp
+        vzeroupper
+        ret
+        .cfi_endproc
+
+#endif
diff --git a/mex/sources/local_state_space_iterations/lssi2_common.S b/mex/sources/local_state_space_iterations/lssi2_common.S
new file mode 100644
index 000000000..9d7f7742c
--- /dev/null
+++ b/mex/sources/local_state_space_iterations/lssi2_common.S
@@ -0,0 +1,15 @@
+	## Push a register to the stack and adjust CFI information accordingly
+        ## Also increment a counter
+	.macro push_cfi_reg reg
+	push \reg
+	.cfi_adjust_cfa_offset 8
+	.cfi_rel_offset \reg, 0
+        .set pushed_regs,pushed_regs+1
+	.endm
+
+	## Pop a register from the stack and adjust CFI information accordingly
+	.macro pop_cfi_reg reg
+	pop \reg
+	.cfi_adjust_cfa_offset -8
+	.cfi_restore \reg
+	.endm
diff --git a/tests/particle/local_state_space_iteration_k_test.mod b/tests/particle/local_state_space_iteration_k_test.mod
index 446aedd13..7705f9ed5 100644
--- a/tests/particle/local_state_space_iteration_k_test.mod
+++ b/tests/particle/local_state_space_iteration_k_test.mod
@@ -47,19 +47,25 @@ rf_ghxx = dr.ghxx(dr.restrict_var_list, :);
 rf_ghuu = dr.ghuu(dr.restrict_var_list, :);
 rf_ghxu = dr.ghxu(dr.restrict_var_list, :);
 
-setenv("DYNARE_LSSI2_KERNEL", "avx2")
+setenv("DYNARE_LSSI2_KERNEL", "avx512")
 tic;
 ynext1 = local_state_space_iteration_2(yhat, epsilon, rf_ghx, rf_ghu, rf_constant, rf_ghxx, rf_ghuu, rf_ghxu, options_.threads.local_state_space_iteration_2);
 toc;
 
-setenv("DYNARE_LSSI2_KERNEL", "generic")
+setenv("DYNARE_LSSI2_KERNEL", "avx2")
 tic;
 ynext2 = local_state_space_iteration_2(yhat, epsilon, rf_ghx, rf_ghu, rf_constant, rf_ghxx, rf_ghuu, rf_ghxu, options_.threads.local_state_space_iteration_2);
 toc;
 
+setenv("DYNARE_LSSI2_KERNEL", "generic")
+tic;
+ynext3 = local_state_space_iteration_2(yhat, epsilon, rf_ghx, rf_ghu, rf_constant, rf_ghxx, rf_ghuu, rf_ghxu, options_.threads.local_state_space_iteration_2);
+toc;
+
 setenv("DYNARE_LSSI2_KERNEL", "auto")
 
-max(max(abs(ynext2-ynext1)))
+max(max(abs(ynext1-ynext3)))
+max(max(abs(ynext2-ynext3)))
 
 /*
 expected = rf_constant+rf_ghx*yhat+rf_ghu*epsilon;
-- 
GitLab