45 #if defined(__ARM_ARCH_ISA_A64) || defined(__ARM_NEON) || \
46 defined(__ARM_ARCH) || defined(__aarch64__)
56 void gpmp::linalg::Mtx::mtx_add(
const std::vector<std::vector<int>> &
A,
57 const std::vector<std::vector<int>> &
B,
58 std::vector<std::vector<int>> &
C) {
59 const int rows =
A.size();
60 const int cols =
A[0].size();
62 for (
int i = 0; i <
rows; ++i) {
66 for (; j <
cols - 15; j += 16) {
68 int32x4x4_t a = vld1q_s32_x4(&
A[i][j]);
69 int32x4x4_t b = vld1q_s32_x4(&
B[i][j]);
73 c.val[0] = vaddq_s32(a.val[0], b.val[0]);
74 c.val[1] = vaddq_s32(a.val[1], b.val[1]);
75 c.val[2] = vaddq_s32(a.val[2], b.val[2]);
76 c.val[3] = vaddq_s32(a.val[3], b.val[3]);
79 vst1q_s32_x4(&
C[i][j], c);
83 for (; j <
cols; ++j) {
84 C[i][j] =
A[i][j] +
B[i][j];
90 void gpmp::linalg::Mtx::mtx_sub(
const std::vector<std::vector<int>> &
A,
91 const std::vector<std::vector<int>> &
B,
92 std::vector<std::vector<int>> &
C) {
93 const int rows =
A.size();
94 const int cols =
A[0].size();
96 for (
int i = 0; i <
rows; ++i) {
99 for (; j <
cols - 7; j += 8) {
101 int32x4_t a_low = vld1q_s32(&
A[i][j]);
102 int32x4_t a_high = vld1q_s32(&
A[i][j + 4]);
103 int32x4_t b_low = vld1q_s32(&
B[i][j]);
104 int32x4_t b_high = vld1q_s32(&
B[i][j + 4]);
107 int32x4_t c_low = vsubq_s32(a_low, b_low);
108 int32x4_t c_high = vsubq_s32(a_high, b_high);
111 vst1q_s32(&
C[i][j], c_low);
112 vst1q_s32(&
C[i][j + 4], c_high);
116 for (; j <
cols; ++j) {
117 C[i][j] =
A[i][j] -
B[i][j];
123 void gpmp::linalg::Mtx::mtx_tpose(std::vector<std::vector<int>> &matrix) {
124 const int rows = matrix.size();
125 const int cols = matrix[0].size();
127 for (
int i = 0; i <
rows; i += 8) {
128 for (
int j = i; j <
cols; j += 8) {
129 int32x4x2_t row1 = vld2q_s32(&matrix[i][j]);
130 int32x4x2_t row2 = vld2q_s32(&matrix[i + 1][j]);
131 int32x4x2_t row3 = vld2q_s32(&matrix[i + 2][j]);
132 int32x4x2_t row4 = vld2q_s32(&matrix[i + 3][j]);
133 int32x4x2_t row5 = vld2q_s32(&matrix[i + 4][j]);
134 int32x4x2_t row6 = vld2q_s32(&matrix[i + 5][j]);
135 int32x4x2_t row7 = vld2q_s32(&matrix[i + 6][j]);
136 int32x4x2_t row8 = vld2q_s32(&matrix[i + 7][j]);
139 std::swap(row1.val[1], row2.val[0]);
140 std::swap(row3.val[1], row4.val[0]);
141 std::swap(row5.val[1], row6.val[0]);
142 std::swap(row7.val[1], row8.val[0]);
143 std::swap(row1.val[2], row3.val[0]);
144 std::swap(row5.val[2], row7.val[0]);
145 std::swap(row2.val[2], row4.val[0]);
146 std::swap(row6.val[2], row8.val[0]);
147 std::swap(row1.val[3], row5.val[0]);
148 std::swap(row2.val[3], row6.val[0]);
149 std::swap(row3.val[3], row7.val[0]);
150 std::swap(row4.val[3], row8.val[0]);
151 std::swap(row5.val[3], row7.val[2]);
152 std::swap(row6.val[3], row8.val[2]);
155 vst2q_s32(&matrix[i][j], row1);
156 vst2q_s32(&matrix[i + 1][j], row2);
157 vst2q_s32(&matrix[i + 2][j], row3);
158 vst2q_s32(&matrix[i + 3][j], row4);
159 vst2q_s32(&matrix[i + 4][j], row5);
160 vst2q_s32(&matrix[i + 5][j], row6);
161 vst2q_s32(&matrix[i + 6][j], row7);
162 vst2q_s32(&matrix[i + 7][j], row8);