Line data Source code
1 : /************************************************************************* 2 : * 3 : * Project 4 : * _____ _____ __ __ _____ 5 : * / ____| __ \| \/ | __ \ 6 : * ___ _ __ ___ _ __ | | __| |__) | \ / | |__) | 7 : * / _ \| '_ \ / _ \ '_ \| | |_ | ___/| |\/| | ___/ 8 : *| (_) | |_) | __/ | | | |__| | | | | | | | 9 : * \___/| .__/ \___|_| |_|\_____|_| |_| |_|_| 10 : * | | 11 : * |_| 12 : * 13 : * Copyright (C) Akiel Aries, <akiel@akiel.org>, et al. 14 : * 15 : * This software is licensed as described in the file LICENSE, which 16 : * you should have received as part of this distribution. The terms 17 : * among other details are referenced in the official documentation 18 : * seen here : https://akielaries.github.io/openGPMP/ along with 19 : * important files seen in this project. 20 : * 21 : * You may opt to use, copy, modify, merge, publish, distribute 22 : * and/or sell copies of the Software, and permit persons to whom 23 : * the Software is furnished to do so, under the terms of the 24 : * LICENSE file. As this is an Open Source effort, all implementations 25 : * must be of the same methodology. 26 : * 27 : * 28 : * 29 : * This software is distributed on an AS IS basis, WITHOUT 30 : * WARRANTY OF ANY KIND, either express or implied. 31 : * 32 : ************************************************************************/ 33 : #include <cassert> 34 : #include <cstddef> 35 : #include <cstdint> 36 : #include <cstring> 37 : #include <iostream> 38 : #include <openGPMP/linalg/mtx.hpp> 39 : #include <vector> 40 : 41 : #if defined(__x86_64__) || defined(__amd64__) || defined(__amd64) 42 : 43 : /************************************************************************ 44 : * 45 : * Matrix Operations for AVX ISA 46 : * 47 : ************************************************************************/ 48 : #if defined(__AVX2__) 49 : 50 : // AVX family intrinsics 51 : #include <immintrin.h> 52 : 53 : /************************************************************************ 54 : * 55 : * Matrix Operations on Arrays 56 : * 57 : ************************************************************************/ 58 : // matrix addition for 8-bit integers using 256-bit SIMD registers 59 3 : void gpmp::linalg::Mtx::mtx_add(const int8_t *A, 60 : const int8_t *B, 61 : int8_t *C, 62 : int rows, 63 : int cols) { 64 : // BUG FIXME 65 2451 : for (int i = 0; i < rows; ++i) { 66 2448 : int j = 0; 67 72784 : for (; j < cols - 31; j += 32) { 68 70336 : __m256i a = _mm256_loadu_si256( 69 70336 : reinterpret_cast<const __m256i *>(&A[i * cols + j])); 70 70336 : __m256i b = _mm256_loadu_si256( 71 70336 : reinterpret_cast<const __m256i *>(&B[i * cols + j])); 72 70336 : __m256i c = _mm256_loadu_si256( 73 70336 : reinterpret_cast<const __m256i *>(&C[i * cols + j])); 74 : 75 : // Perform vectorized addition and accumulate the result 76 70336 : c = _mm256_add_epi8(c, _mm256_add_epi8(a, b)); 77 : 78 : // Store the result back to the C matrix 79 70336 : _mm256_storeu_si256(reinterpret_cast<__m256i *>(&C[i * cols + j]), 80 : c); 81 : } 82 : 83 8848 : for (; j < cols; ++j) { 84 6400 : C[i * cols + j] = A[i * cols + j] + B[i * cols + j]; 85 : } 86 : } 87 3 : } 88 : 89 3 : void gpmp::linalg::Mtx::mtx_sub(const int8_t *A, 90 : const int8_t *B, 91 : int8_t *C, 92 : int rows, 93 : int cols) { 94 2451 : for (int i = 0; i < rows; ++i) { 95 2448 : int j = 0; 96 72784 : for (; j < cols - 31; j += 32) { 97 70336 : __m256i a = _mm256_loadu_si256( 98 70336 : reinterpret_cast<const __m256i *>(&A[i * cols + j])); 99 70336 : __m256i b = _mm256_loadu_si256( 100 70336 : reinterpret_cast<const __m256i *>(&B[i * cols + j])); 101 70336 : __m256i c = _mm256_loadu_si256( 102 70336 : reinterpret_cast<const __m256i *>(&C[i * cols + j])); 103 : 104 : // Perform vectorized subtraction and accumulate the result 105 70336 : c = _mm256_sub_epi8(a, b); 106 : 107 : // Store the result back to the C matrix 108 70336 : _mm256_storeu_si256(reinterpret_cast<__m256i *>(&C[i * cols + j]), 109 : c); 110 : } 111 : 112 8848 : for (; j < cols; ++j) { 113 6400 : C[i * cols + j] = A[i * cols + j] - B[i * cols + j]; 114 : } 115 : } 116 3 : } 117 : 118 0 : void gpmp::linalg::Mtx::mtx_mult(const int8_t *A, 119 : const int8_t *B, 120 : int8_t *C, 121 : int rows_a, 122 : int cols_a, 123 : int cols_b) { 124 : 125 0 : for (int i = 0; i < rows_a; ++i) { 126 0 : for (int j = 0; j < cols_b; j += 32) { 127 0 : __m256i c = _mm256_setzero_si256(); 128 : 129 0 : for (int k = 0; k < cols_a; ++k) { 130 0 : __m256i a = _mm256_set1_epi8(A[i * cols_a + k]); 131 0 : __m256i b = _mm256_loadu_si256( 132 0 : reinterpret_cast<const __m256i *>(&B[k * cols_b + j])); 133 : 134 0 : __m256i prod = _mm256_maddubs_epi16(a, b); 135 0 : c = _mm256_add_epi16(c, prod); 136 : } 137 : 138 0 : c = _mm256_srai_epi16(c, 8); 139 0 : c = _mm256_packs_epi16(c, _mm256_setzero_si256()); 140 : 141 0 : _mm256_storeu_si256(reinterpret_cast<__m256i *>(&C[i * cols_b + j]), 142 : c); 143 : } 144 : 145 : // Handle remaining elements 146 0 : for (int j = cols_b - cols_b % 32; j < cols_b; ++j) { 147 0 : int sum = 0; 148 : 149 0 : for (int k = 0; k < cols_a; ++k) { 150 0 : sum += A[i * cols_a + k] * B[k * cols_b + j]; 151 : } 152 : 153 0 : C[i * cols_b + j] = sum; 154 : } 155 : } 156 0 : } 157 : 158 : #endif 159 : 160 : // x86 161 : #endif