LCOV - code coverage report
Current view: top level - modules/linalg/avx - mtx_avx2_arr_f64.cpp (source / functions) Hit Total Coverage
Test: lcov.info Lines: 25 34 73.5 %
Date: 2024-05-13 05:06:06 Functions: 2 2 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*************************************************************************
       2             :  *
       3             :  *  Project
       4             :  *                         _____ _____  __  __ _____
       5             :  *                        / ____|  __ \|  \/  |  __ \
       6             :  *  ___  _ __   ___ _ __ | |  __| |__) | \  / | |__) |
       7             :  * / _ \| '_ \ / _ \ '_ \| | |_ |  ___/| |\/| |  ___/
       8             :  *| (_) | |_) |  __/ | | | |__| | |    | |  | | |
       9             :  * \___/| .__/ \___|_| |_|\_____|_|    |_|  |_|_|
      10             :  *      | |
      11             :  *      |_|
      12             :  *
      13             :  * Copyright (C) Akiel Aries, <akiel@akiel.org>, et al.
      14             :  *
      15             :  * This software is licensed as described in the file LICENSE, which
      16             :  * you should have received as part of this distribution. The terms
      17             :  * among other details are referenced in the official documentation
      18             :  * seen here : https://akielaries.github.io/openGPMP/ along with
      19             :  * important files seen in this project.
      20             :  *
      21             :  * You may opt to use, copy, modify, merge, publish, distribute
      22             :  * and/or sell copies of the Software, and permit persons to whom
      23             :  * the Software is furnished to do so, under the terms of the
      24             :  * LICENSE file. As this is an Open Source effort, all implementations
      25             :  * must be of the same methodology.
      26             :  *
      27             :  *
      28             :  *
      29             :  * This software is distributed on an AS IS basis, WITHOUT
      30             :  * WARRANTY OF ANY KIND, either express or implied.
      31             :  *
      32             :  ************************************************************************/
      33             : #include <cassert>
      34             : #include <cstddef>
      35             : #include <cstdint>
      36             : #include <iostream>
      37             : #include <openGPMP/linalg/mtx.hpp>
      38             : #include <vector>
      39             : 
      40             : #if defined(__x86_64__) || defined(__amd64__) || defined(__amd64)
      41             : 
      42             : /************************************************************************
      43             :  *
      44             :  * Matrix Operations for AVX ISA
      45             :  *
      46             :  ************************************************************************/
      47             : #if defined(__AVX2__)
      48             : 
      49             : // AVX family intrinsics
      50             : #include <immintrin.h>
      51             : 
      52             : /************************************************************************
      53             :  *
      54             :  * Matrix Operations on Arrays
      55             :  *
      56             :  ************************************************************************/
      57             : // matrix addition using Intel intrinsics, accepts double arrays as matrices
      58           3 : void gpmp::linalg::Mtx::mtx_add(const double *A,
      59             :                                 const double *B,
      60             :                                 double *C,
      61             :                                 int rows,
      62             :                                 int cols) {
      63           3 :     if (rows > 8) {
      64        2115 :         for (int i = 0; i < rows; ++i) {
      65        2112 :             int j = 0;
      66             :             // requires at least size 4x4 size matrices
      67      527424 :             for (; j < cols - 3; j += 4) {
      68             :                 // load 4 elements from A, B, and C matrices using SIMD
      69      525312 :                 __m256d a = _mm256_loadu_pd(&A[i * cols + j]);
      70      525312 :                 __m256d b = _mm256_loadu_pd(&B[i * cols + j]);
      71     1050624 :                 __m256d c = _mm256_loadu_pd(&C[i * cols + j]);
      72             :                 // perform vectorized addition and accumulate the result
      73      525312 :                 c = _mm256_add_pd(a, b);
      74             : 
      75             :                 // store the result back to the C matrix
      76      525312 :                 _mm256_storeu_pd(&C[i * cols + j], c);
      77             :             }
      78             : 
      79             :             // handle the remaining elements that are not multiples of 8
      80        2112 :             for (; j < cols; ++j) {
      81           0 :                 C[i * cols + j] = A[i * cols + j] + B[i * cols + j];
      82             :             }
      83             :         }
      84             :     } else {
      85             :         // use standard matrix addition
      86           0 :         std_mtx_add(A, B, C, rows, cols);
      87             :     }
      88           3 : }
      89             : 
      90           1 : void gpmp::linalg::Mtx::mtx_mult(const double *A,
      91             :                                  const double *B,
      92             :                                  double *C,
      93             :                                  int rows_a,
      94             :                                  int cols_a,
      95             :                                  int cols_b) {
      96           1 :     if (cols_a != rows_a) {
      97             :         // Matrix dimensions don't match for multiplication
      98           0 :         std::cerr << "Matching error";
      99           0 :         return;
     100             :     }
     101             : 
     102           1 :     if (rows_a > 8) {
     103             : 
     104        1025 :         for (int i = 0; i < rows_a; ++i) {
     105      263168 :             for (int j = 0; j < cols_b - 3; j += 4) {
     106             :                 // creat result vector of zeros
     107      262144 :                 __m256d sum_vec = _mm256_setzero_pd();
     108             : 
     109   268697600 :                 for (int k = 0; k < cols_a; ++k) {
     110   268435456 :                     __m256d a_vec = _mm256_set1_pd(A[i * cols_a + k]);
     111             : 
     112   536870912 :                     __m256d b_vec = _mm256_loadu_pd(&B[k * cols_b + j]);
     113             : 
     114   268435456 :                     __m256d prod = _mm256_mul_pd(a_vec, b_vec);
     115             : 
     116   268435456 :                     sum_vec = _mm256_add_pd(sum_vec, prod);
     117             :                 }
     118      262144 :                 _mm256_storeu_pd(&C[i * cols_b + j], sum_vec);
     119             :             }
     120             : 
     121             :             // handle remaining elements not multiples of 4
     122        1024 :             for (int j = cols_b - cols_b % 4; j < cols_b; ++j) {
     123           0 :                 double sum = 0.0;
     124             : 
     125           0 :                 for (int k = 0; k < cols_a; ++k) {
     126           0 :                     sum += A[i * cols_a + k] * B[k * cols_b + j];
     127             :                 }
     128           0 :                 C[i * cols_b + j] = sum;
     129             :             }
     130             :         }
     131             : 
     132             :     }
     133             : 
     134             :     else {
     135           0 :         std_mtx_mult(A, B, C, rows_a, cols_a, cols_b);
     136             :     }
     137             : }
     138             : 
     139             : #endif
     140             : 
     141             : // x86
     142             : #endif

Generated by: LCOV version 1.14