Line data Source code
1 : /** 2 : * helpers for matrix tests 3 : */ 4 : 5 : #ifndef T_MATRIX_HPP 6 : #define T_MATRIX_HPP 7 : 8 : #include <cmath> 9 : #include <cstdint> 10 : #include <gtest/gtest.h> 11 : #include <iostream> 12 : #include <limits.h> 13 : #include <string> 14 : #include <vector> 15 : 16 : // utility test helper function to compare two matrices. used for verifying 17 : // accelerated/non-standard implementations to the simple naive algorithm 18 : // for matrix arithmetic operations 19 : template <typename T> 20 7 : bool mtx_verif(const std::vector<std::vector<T>> &A, 21 : const std::vector<std::vector<T>> &B) { 22 7 : if (A.size() != B.size() || A[0].size() != B[0].size()) { 23 0 : return false; 24 : } 25 : 26 5255 : for (size_t i = 0; i < A.size(); ++i) { 27 5256320 : for (size_t j = 0; j < A[i].size(); ++j) { 28 10502144 : if (std::abs(A[i][j] - B[i][j]) > 29 5251072 : std::numeric_limits<T>::epsilon()) { 30 0 : return false; 31 : } 32 : } 33 : } 34 7 : return true; 35 : } 36 : 37 : template <typename T> 38 : bool mtx_verif(const std::vector<T> &A, const std::vector<T> &B) { 39 : int rows = A.size(); 40 : int cols = rows > 0 ? A.size() / rows : 0; 41 : 42 : if (A.size() != B.size()) { 43 : return false; 44 : } 45 : 46 : for (size_t i = 0; i < rows; ++i) { 47 : for (size_t j = 0; j < cols; ++j) { 48 : if (A[i * cols + j] != B[i * cols + j]) { 49 : return false; 50 : } 51 : } 52 : } 53 : return true; 54 : } 55 : 56 : template <typename T> 57 24 : bool mtx_verif(const T *A, const T *B, int rows, int cols) { 58 19808 : for (int i = 0; i < rows; ++i) { 59 19342120 : for (int j = 0; j < cols; ++j) { 60 31270768 : if (std::abs(A[i * cols + j] - B[i * cols + j]) > 61 19322336 : std::numeric_limits<T>::epsilon()) { 62 0 : return false; 63 : } 64 : } 65 : } 66 24 : return true; 67 : } 68 : 69 : template <typename T> void print_matrix(const T *matrix, int rows, int cols) { 70 : for (int i = 0; i < rows; ++i) { 71 : for (int j = 0; j < cols; ++j) { 72 : std::cout << matrix[i * cols + j] << " "; 73 : } 74 : std::cout << std::endl; 75 : } 76 : std::cout << std::endl; 77 : } 78 : 79 : #endif