38 #include "../../include/linalg/_gpu_mtx.h"
51 #define CL_USE_DEPRECATED_OPENCL_1_2_APIS
52 #define CL_TARGET_OPENCL_VERSION 300
56 #include <OpenCL/opencl.h>
// Size constant; no use of MEM_SIZE appears in the visible part of this file.
61 #define MEM_SIZE (128)
// 0x100000 == 1,048,576. Presumably the upper bound on the kernel source text
// loaded from disk -- TODO confirm against the file-reading code, which is not
// visible in this excerpt.
62 #define MAX_SOURCE_SIZE (0x100000)
// Prints a banner line containing `title` through printf. The replacement text
// already ends in ';', so writing `PRINT_LINE(x);` expands to the printf
// statement followed by an extra empty statement.
63 #define PRINT_LINE(title) printf("\n========== %s ==========\n", title);
// NOTE(review): repeats the MAX_SOURCE_SIZE definition above with an identical
// replacement list; the second definition is redundant.
65 #define MAX_SOURCE_SIZE (0x100000)
// Fragment of an error-reporting routine (its header and closing lines are not
// visible in this excerpt): when `status` differs from CL_SUCCESS, the supplied
// message is written to standard output. Whatever follows the print (exit,
// throw, return) is not visible here.
69 if (status != CL_SUCCESS) {
// std::endl also flushes the stream after the message.
70 std::cout << errorMsg << std::endl;
// Tail of a parameter list and the body that follows it. The function header
// and first parameter are not visible in this excerpt, and many argument lines
// of the OpenCL calls below are missing from the listing.
// B is a read-only matrix (vector of rows); C is a mutable matrix that receives
// the values read back from the device at the end of the body.
76 const std::vector<std::vector<int>> &
B,
77 std::vector<std::vector<int>> &
C) {
// N = number of rows of A, M = number of elements in A's first row.
// NOTE(review): A[0] is read without any emptiness check visible in this
// excerpt -- confirm callers never pass an empty A.
79 const int N =
A.size();
80 const int M =
A[0].size();
// Host-side 1-D buffers holding N * M ints each; flat_C starts zero-filled.
83 std::vector<int> flat_A(
N * M);
84 std::vector<int> flat_B(
N * M);
85 std::vector<int> flat_C(
N * M, 0);
// Copy A and B into the flat buffers in row-major order (index i * M + j).
// NOTE(review): B is indexed with A's dimensions; no check that B has at least
// N rows of at least M elements is visible here.
88 for (
int i = 0; i <
N; i++) {
89 for (
int j = 0; j < M; j++) {
90 flat_A[i * M + j] =
A[i][j];
91 flat_B[i * M + j] =
B[i][j];
// Start of the timed region. Everything up to the matching end timestamp is
// measured: opening the kernel file, OpenCL setup, program build, transfers
// and the kernel launch.
94 std::chrono::steady_clock::time_point start_time_u =
95 std::chrono::steady_clock::now();
// Open the kernel source file for reading; the path is relative, so it is
// resolved against the process's current working directory.
97 FILE *file = fopen(
"_gpu_mtx_kernel.c",
"r");
// Failure message for the kernel file; the condition guarding it and what
// happens afterwards are on lines missing from this excerpt.
99 std::cout <<
"Failed to load kernel." << std::endl;
// Handles and counters filled in by the platform/device queries below.
107 cl_platform_id platform_id = NULL;
108 cl_device_id device_id = NULL;
109 cl_uint ret_num_devices;
110 cl_uint ret_num_platforms;
// Request at most one platform, then query a device on it (the remaining
// clGetDeviceIDs arguments are not visible). The two status codes are OR-ed so
// one check covers both calls; a nonzero result means at least one call failed
// but need not equal any single OpenCL error code.
112 cl_int status = clGetPlatformIDs(1, &platform_id, &ret_num_platforms);
113 status |= clGetDeviceIDs(platform_id,
118 checkError(status,
"Error getting platform and device information.");
// Create a context for the single selected device, with no properties and no
// error callback. The variable receiving the result is declared on a line not
// visible here; later lines refer to it as `context`.
122 clCreateContext(NULL, 1, &device_id, NULL, NULL, &status);
123 checkError(status,
"Error creating context.");
// Command queue on that device, created with a properties value of 0.
126 cl_command_queue command_queue =
127 clCreateCommandQueue(context, device_id, 0, &status);
128 checkError(status,
"Error creating command queue.");
// Three device buffers, one per matrix; their flags and sizes are on lines
// missing from this excerpt.
131 cl_mem mem_obj_A = clCreateBuffer(context,
136 cl_mem mem_obj_B = clCreateBuffer(context,
141 cl_mem mem_obj_C = clCreateBuffer(context,
// Two host-to-device writes, presumably uploading flat_A and flat_B (arguments
// not visible -- TODO confirm). Status codes are OR-ed and checked once.
148 status = clEnqueueWriteBuffer(command_queue,
157 status |= clEnqueueWriteBuffer(command_queue,
166 checkError(status,
"Error writing matrices to device memory.");
// Create the program object from the source text in source_str / source_size,
// both declared on lines not visible here.
// NOTE(review): the (const size_t *) cast assumes source_size is stored as a
// size_t -- confirm against its declaration.
169 cl_program program = clCreateProgramWithSource(context,
171 (
const char **)&source_str,
172 (
const size_t *)&source_size,
174 checkError(status,
"Error creating program.");
// Build for the one device, with no build options and no completion callback.
177 status = clBuildProgram(program, 1, &device_id, NULL, NULL, NULL);
178 checkError(status,
"Error building program.");
// Look up the kernel named "gpu_mtx_add" in the built program.
181 cl_kernel kernel = clCreateKernel(program,
"gpu_mtx_add", &status);
// Bind kernel arguments 0, 1, 2 to the A, B and C device buffers; the three
// status codes are OR-ed and checked together.
185 status = clSetKernelArg(kernel, 0,
sizeof(cl_mem), (
void *)&mem_obj_A);
186 status |= clSetKernelArg(kernel, 1,
sizeof(cl_mem), (
void *)&mem_obj_B);
187 status |= clSetKernelArg(kernel, 2,
sizeof(cl_mem), (
void *)&mem_obj_C);
188 checkError(status,
"Error setting kernel arguments.");
// 2-D index space of N x M work-items, i.e. one per matrix element.
// NOTE(review): N and M are non-constant ints placed in a braced size_t
// initializer; compilers may diagnose this as a narrowing conversion.
191 size_t global_work_size[2] = {
N, M};
// Work-group size of 1 x 1: every work-group holds a single work-item.
// NOTE(review): this likely under-uses the device -- consider a larger size or
// letting the runtime choose; confirm by measurement.
192 size_t local_work_size[2] = {1, 1};
// Launch the kernel; the remaining arguments are not visible in this excerpt.
195 status = clEnqueueNDRangeKernel(command_queue,
204 checkError(status,
"Error enqueueing kernel.");
// Read the result back from the device, presumably into flat_C since that is
// what gets unpacked below (arguments not visible -- TODO confirm, including
// whether the read is blocking).
207 status = clEnqueueReadBuffer(command_queue,
216 checkError(status,
"Error reading result from device memory.");
// End of the timed region; report the elapsed time in milliseconds. One line
// of the output expression is missing from this excerpt.
217 std::chrono::steady_clock::time_point end_time_u =
218 std::chrono::steady_clock::now();
219 std::cout <<
"Time elapsed: "
220 << std::chrono::duration_cast<std::chrono::milliseconds>(
221 end_time_u - start_time_u)
223 <<
" ms" << std::endl;
// Unpack the flat row-major result into C.
// NOTE(review): C is indexed directly; no resize of C is visible here, so it
// presumably must already be at least N x M on entry -- confirm with callers.
226 for (
int i = 0; i <
N; i++) {
227 for (
int j = 0; j < M; j++) {
228 C[i][j] = flat_C[i * M + j];
// Release the OpenCL objects created above; the return codes of these calls
// are not checked.
233 clReleaseKernel(kernel);
234 clReleaseProgram(program);
235 clReleaseMemObject(mem_obj_A);
236 clReleaseMemObject(mem_obj_B);
237 clReleaseMemObject(mem_obj_C);
238 clReleaseCommandQueue(command_queue);
239 clReleaseContext(context);
// Fragment of a driver routine; its header, the declarations of A and B, the
// surrounding loops and the call being timed are not visible in this excerpt.
// Random source: a Mersenne Twister engine seeded from std::random_device,
// producing integers uniformly in the closed range [1, 100].
253 std::random_device rd;
254 std::mt19937 gen(rd());
255 std::uniform_int_distribution<int> distribution(1, 100);
// Fill element (i, j) of A and of B with independent random draws.
260 A[i][j] = distribution(gen);
261 B[i][j] = distribution(gen);
// Timestamps taken before and after the measured work (the work itself is on
// lines missing from this excerpt).
265 std::chrono::steady_clock::time_point start_time_u =
266 std::chrono::steady_clock::now();
269 std::chrono::steady_clock::time_point end_time_u =
270 std::chrono::steady_clock::now();
// Report the elapsed time in milliseconds; one line of the output expression
// is missing from this excerpt.
298 std::cout <<
"Time elapsed: "
299 << std::chrono::duration_cast<std::chrono::milliseconds>(
300 end_time_u - start_time_u)
302 <<
" ms" << std::endl;
// Signature as listed: A and B are taken by const reference (read-only inputs)
// and C by non-const reference (written by the function).
void gpu_mtx_add(const std::vector< std::vector< int >> &A, const std::vector< std::vector< int >> &B, std::vector< std::vector< int >> &C)
// Signature as listed: takes an OpenCL status code and a C-string message.
void checkError(cl_int status, const char *errorMsg)