mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-12 16:29:44 +00:00
Only one CUDA stream per device for async compute (#1898)
This commit is contained in:
parent
051e1b0e6a
commit
2c9380dd2f
3 changed files with 20 additions and 38 deletions
|
@ -336,7 +336,6 @@ Building the program with BLAS support may lead to some performance improvements
|
||||||
cmake .. -DLLAMA_CUBLAS=ON
|
cmake .. -DLLAMA_CUBLAS=ON
|
||||||
cmake --build . --config Release
|
cmake --build . --config Release
|
||||||
```
|
```
|
||||||
Note: Because llama.cpp uses multiple CUDA streams for matrix multiplication results [are not guaranteed to be reproducible](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility). If you need reproducibility, set `GGML_CUDA_MAX_STREAMS` in the file `ggml-cuda.cu` to 1.
|
|
||||||
|
|
||||||
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
|
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
|
||||||
|
|
||||||
|
|
|
@ -106,9 +106,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (arg == "-s" || arg == "--seed") {
|
if (arg == "-s" || arg == "--seed") {
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
|
|
||||||
#endif
|
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
|
|
54
ggml-cuda.cu
54
ggml-cuda.cu
|
@ -1467,19 +1467,13 @@ static void * g_scratch_buffer = nullptr;
|
||||||
static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
|
static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
|
||||||
static size_t g_scratch_offset = 0;
|
static size_t g_scratch_offset = 0;
|
||||||
|
|
||||||
#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
|
|
||||||
#define GGML_CUDA_MAX_EVENTS 64
|
|
||||||
|
|
||||||
static int g_device_count = -1;
|
static int g_device_count = -1;
|
||||||
static int g_main_device = 0;
|
static int g_main_device = 0;
|
||||||
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
|
|
||||||
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||||
|
|
||||||
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
|
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
|
||||||
|
|
||||||
static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
|
|
||||||
static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
|
|
||||||
|
|
||||||
void ggml_init_cublas() {
|
void ggml_init_cublas() {
|
||||||
static bool initialized = false;
|
static bool initialized = false;
|
||||||
|
@ -1503,15 +1497,8 @@ void ggml_init_cublas() {
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
CUDA_CHECK(cudaSetDevice(id));
|
CUDA_CHECK(cudaSetDevice(id));
|
||||||
|
|
||||||
// create streams
|
// create main stream
|
||||||
for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
|
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
|
||||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
|
|
||||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
|
|
||||||
}
|
|
||||||
// create events
|
|
||||||
for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
|
|
||||||
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
|
|
||||||
}
|
|
||||||
|
|
||||||
// create cublas handle
|
// create cublas handle
|
||||||
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
||||||
|
@ -1978,6 +1965,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
||||||
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
|
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
|
size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
|
|
||||||
|
// if multiple GPUs are used they need to wait for the main GPU to finish
|
||||||
|
if (split && g_device_count > 1) {
|
||||||
|
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
}
|
||||||
|
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
if (!split && id != g_main_device) {
|
if (!split && id != g_main_device) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -2076,9 +2069,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
||||||
}
|
}
|
||||||
const int64_t i11 = i13*ne12 + i12;
|
const int64_t i11 = i13*ne12 + i12;
|
||||||
|
|
||||||
cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
|
cudaStream_t cudaStream_main = g_cudaStreams_main[id];
|
||||||
cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
|
|
||||||
cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
|
|
||||||
|
|
||||||
// for split tensors the data begins at i0 == i0_offset_low
|
// for split tensors the data begins at i0 == i0_offset_low
|
||||||
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
|
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
|
||||||
|
@ -2106,14 +2097,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
||||||
if (src1->backend == GGML_BACKEND_CPU) {
|
if (src1->backend == GGML_BACKEND_CPU) {
|
||||||
GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
|
GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
|
||||||
int64_t nrows1 = flatten_rows ? nrows0 : ne11;
|
int64_t nrows1 = flatten_rows ? nrows0 : ne11;
|
||||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
|
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
|
||||||
} else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
|
} else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
|
||||||
if (id != g_main_device) {
|
if (id != g_main_device) {
|
||||||
GGML_ASSERT(!flatten_rows);
|
GGML_ASSERT(!flatten_rows);
|
||||||
float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
|
float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
|
||||||
src1_ddf_i_source += i11*src1_stride;
|
src1_ddf_i_source += i11*src1_stride;
|
||||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
|
CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
|
||||||
cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
|
cudaMemcpyDeviceToDevice, cudaStream_main));
|
||||||
}
|
}
|
||||||
} else if (src1_on_device && !src1_is_contiguous) {
|
} else if (src1_on_device && !src1_is_contiguous) {
|
||||||
GGML_ASSERT(!split);
|
GGML_ASSERT(!split);
|
||||||
|
@ -2122,7 +2113,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
|
|
||||||
|
|
||||||
if (!src0_on_device || !src0_is_contiguous) {
|
if (!src0_on_device || !src0_is_contiguous) {
|
||||||
if (src0_is_f32) {
|
if (src0_is_f32) {
|
||||||
|
@ -2138,9 +2128,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait with main stream until src1 memcpy is done
|
|
||||||
CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
|
|
||||||
|
|
||||||
// do the computation
|
// do the computation
|
||||||
op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
|
op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
|
||||||
|
|
||||||
|
@ -2178,8 +2165,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
||||||
|
|
||||||
// wait until each device is finished, then free their buffers
|
// wait until each device is finished, then free their buffers
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
|
if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
CUDA_CHECK(cudaSetDevice(id));
|
CUDA_CHECK(cudaSetDevice(id));
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
if (src0_asq[id] > 0) {
|
if (src0_asq[id] > 0) {
|
||||||
ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
|
ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
|
||||||
}
|
}
|
||||||
|
@ -2245,7 +2237,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
const int64_t ne02 = src0->ne[2];
|
const int64_t ne02 = src0->ne[2];
|
||||||
|
|
||||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
|
||||||
|
|
||||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||||
|
@ -2257,8 +2249,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||||
|
|
||||||
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
|
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
|
||||||
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||||
|
@ -2276,7 +2266,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
|
||||||
const int64_t nb02 = src0->nb[2];
|
const int64_t nb02 = src0->nb[2];
|
||||||
|
|
||||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
|
||||||
|
|
||||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||||
|
@ -2291,8 +2281,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
|
||||||
const int channel_stride_x = nb02 / sizeof(half);
|
const int channel_stride_x = nb02 / sizeof(half);
|
||||||
|
|
||||||
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
|
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
|
||||||
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
@ -2348,7 +2336,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
||||||
const int64_t nb12 = src1->nb[2];
|
const int64_t nb12 = src1->nb[2];
|
||||||
|
|
||||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
|
||||||
|
|
||||||
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||||
const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||||
|
@ -2366,8 +2354,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
|
||||||
|
|
||||||
(void) dst;
|
(void) dst;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue