40 #if defined(__x86_64__) || defined(__amd64__) || defined(__amd64)
50 #include <immintrin.h>
59 void gpmp::linalg::Mtx::mtx_add(
const int *
A,
66 for (
int i = 0; i <
rows; ++i) {
69 for (; j <
cols - 7; j += 8) {
71 __m256i a = _mm256_loadu_si256(
72 reinterpret_cast<const __m256i *
>(&
A[i *
cols + j]));
73 __m256i b = _mm256_loadu_si256(
74 reinterpret_cast<const __m256i *
>(&
B[i *
cols + j]));
75 __m256i c = _mm256_loadu_si256(
76 reinterpret_cast<const __m256i *
>(&
C[i *
cols + j]));
79 c = _mm256_add_epi32(c, _mm256_add_epi32(a, b));
83 reinterpret_cast<__m256i *
>(&
C[i *
cols + j]),
88 for (; j <
cols; ++j) {
100 void gpmp::linalg::Mtx::mtx_sub(
const int *
A,
105 for (
int i = 0; i <
rows; ++i) {
107 for (; j <
cols - 7; j += 8) {
108 __m256i a = _mm256_loadu_si256(
109 reinterpret_cast<const __m256i *
>(&
A[i *
cols + j]));
110 __m256i b = _mm256_loadu_si256(
111 reinterpret_cast<const __m256i *
>(&
B[i *
cols + j]));
112 __m256i c = _mm256_loadu_si256(
113 reinterpret_cast<const __m256i *
>(&
C[i *
cols + j]));
116 c = _mm256_sub_epi32(a, b);
119 _mm256_storeu_si256(
reinterpret_cast<__m256i *
>(&
C[i *
cols + j]),
123 for (; j <
cols; ++j) {
135 for (
int i = 0; i < rows_a; ++i) {
136 for (
int j = 0; j < cols_b; j += 8) {
137 __m256i c = _mm256_setzero_si256();
139 for (
int k = 0; k < cols_a; ++k) {
140 __m256i a = _mm256_set1_epi32(
A[i * cols_a + k]);
141 __m256i b = _mm256_loadu_si256(
142 reinterpret_cast<const __m256i *
>(&
B[k * cols_b + j]));
144 __m256i prod = _mm256_mullo_epi32(a, b);
145 c = _mm256_add_epi32(c, prod);
148 _mm256_storeu_si256(
reinterpret_cast<__m256i *
>(&
C[i * cols_b + j]),
153 for (
int j = cols_b - cols_b % 8; j < cols_b; ++j) {
156 for (
int k = 0; k < cols_a; ++k) {
157 sum +=
A[i * cols_a + k] *
B[k * cols_b + j];
160 C[i * cols_b + j] = sum;
172 for (
int i = 0; i < rows_a; ++i) {
173 for (
int j = 0; j < cols_b; j += 4) {
174 __m256i c_lo = _mm256_setzero_si256();
175 __m256i c_hi = _mm256_setzero_si256();
177 for (
int k = 0; k < cols_a; ++k) {
178 __m256i a = _mm256_set1_epi32(
A[i * cols_a + k]);
179 __m256i b = _mm256_loadu_si256(
180 reinterpret_cast<const __m256i *
>(&
B[k * cols_b + j]));
183 __m256i prod = _mm256_mullo_epi32(a, b);
187 _mm256_cvtepi32_epi64(_mm256_extractf128_si256(prod, 0));
189 _mm256_cvtepi32_epi64(_mm256_extractf128_si256(prod, 1));
192 c_lo = _mm256_add_epi64(c_lo, prod_lo);
193 c_hi = _mm256_add_epi64(c_hi, prod_hi);
197 _mm256_storeu_si256(
reinterpret_cast<__m256i *
>(&
C[i * cols_b + j]),
200 reinterpret_cast<__m256i *
>(&
C[i * cols_b + j + 4]),
205 for (
int j = cols_b - cols_b % 4; j < cols_b; ++j) {
208 for (
int k = 0; k < cols_a; ++k) {
210 static_cast<int64_t
>(
A[i * cols_a + k]) *
B[k * cols_b + j];
213 C[i * cols_b + j] = sum;
218 void gpmp::linalg::Mtx::mtx_tpose(
const int *
A,
int *
C,
int rows,
int cols) {
220 for (
int i = 0; i <
rows; i += 8) {
221 for (
int j = 0; j <
cols; j += 8) {
223 __m256i a0 = _mm256_loadu_si256(
224 (__m256i *)(
const_cast<int *
>(
A) + i *
cols + j));
225 __m256i a1 = _mm256_loadu_si256(
226 (__m256i *)(
const_cast<int *
>(
A) + (i + 1) *
cols + j));
227 __m256i a2 = _mm256_loadu_si256(
228 (__m256i *)(
const_cast<int *
>(
A) + (i + 2) *
cols + j));
229 __m256i a3 = _mm256_loadu_si256(
230 (__m256i *)(
const_cast<int *
>(
A) + (i + 3) *
cols + j));
231 __m256i a4 = _mm256_loadu_si256(
232 (__m256i *)(
const_cast<int *
>(
A) + (i + 4) *
cols + j));
233 __m256i a5 = _mm256_loadu_si256(
234 (__m256i *)(
const_cast<int *
>(
A) + (i + 5) *
cols + j));
235 __m256i a6 = _mm256_loadu_si256(
236 (__m256i *)(
const_cast<int *
>(
A) + (i + 6) *
cols + j));
237 __m256i a7 = _mm256_loadu_si256(
238 (__m256i *)(
const_cast<int *
>(
A) + (i + 7) *
cols + j));
241 __m256i t0 = _mm256_unpacklo_epi32(a0, a1);
242 __m256i t1 = _mm256_unpacklo_epi32(a2, a3);
243 __m256i t2 = _mm256_unpacklo_epi32(a4, a5);
244 __m256i t3 = _mm256_unpacklo_epi32(a6, a7);
245 __m256i t4 = _mm256_unpackhi_epi32(a0, a1);
246 __m256i t5 = _mm256_unpackhi_epi32(a2, a3);
247 __m256i t6 = _mm256_unpackhi_epi32(a4, a5);
248 __m256i t7 = _mm256_unpackhi_epi32(a6, a7);
250 __m256i tt0 = _mm256_unpacklo_epi64(t0, t1);
251 __m256i tt1 = _mm256_unpackhi_epi64(t0, t1);
252 __m256i tt2 = _mm256_unpacklo_epi64(t2, t3);
253 __m256i tt3 = _mm256_unpackhi_epi64(t2, t3);
254 __m256i tt4 = _mm256_unpacklo_epi64(t4, t5);
255 __m256i tt5 = _mm256_unpackhi_epi64(t4, t5);
256 __m256i tt6 = _mm256_unpacklo_epi64(t6, t7);
257 __m256i tt7 = _mm256_unpackhi_epi64(t6, t7);
259 __m256i ttt0 = _mm256_permute2x128_si256(tt0, tt2, 0x20);
260 __m256i ttt1 = _mm256_permute2x128_si256(tt1, tt3, 0x20);
261 __m256i ttt2 = _mm256_permute2x128_si256(tt4, tt6, 0x20);
262 __m256i ttt3 = _mm256_permute2x128_si256(tt5, tt7, 0x20);
263 __m256i ttt4 = _mm256_permute2x128_si256(tt0, tt2, 0x31);
264 __m256i ttt5 = _mm256_permute2x128_si256(tt1, tt3, 0x31);
265 __m256i ttt6 = _mm256_permute2x128_si256(tt4, tt6, 0x31);
266 __m256i ttt7 = _mm256_permute2x128_si256(tt5, tt7, 0x31);
269 _mm256_storeu_si256((__m256i *)(
C + j *
rows + i), ttt0);
270 _mm256_storeu_si256((__m256i *)(
C + (j + 1) *
rows + i), ttt1);
271 _mm256_storeu_si256((__m256i *)(
C + (j + 2) *
rows + i), ttt2);
272 _mm256_storeu_si256((__m256i *)(
C + (j + 3) *
rows + i), ttt3);
273 _mm256_storeu_si256((__m256i *)(
C + (j + 4) *
rows + i), ttt4);
274 _mm256_storeu_si256((__m256i *)(
C + (j + 5) *
rows + i), ttt5);
275 _mm256_storeu_si256((__m256i *)(
C + (j + 6) *
rows + i), ttt6);
276 _mm256_storeu_si256((__m256i *)(
C + (j + 7) *
rows + i), ttt7);
void std_mtx_add(const T *A, const T *B, T *C, int rows, int cols)
Perform matrix addition on two matrices as flat arrays.
void mtx_mult(std::vector< double > A, std::vector< double > B, std::vector< double > C)