mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
CUDA: mmq CLI option, fixed mmq build issues (#2453)
This commit is contained in:
parent
1215ed7d5c
commit
0728c5a8b9
10 changed files with 67 additions and 27 deletions
|
@ -68,7 +68,7 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework
|
||||||
option(LLAMA_BLAS "llama: use BLAS" OFF)
|
option(LLAMA_BLAS "llama: use BLAS" OFF)
|
||||||
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
|
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
|
||||||
option(LLAMA_CUBLAS "llama: use CUDA" OFF)
|
option(LLAMA_CUBLAS "llama: use CUDA" OFF)
|
||||||
option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF)
|
#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF)
|
||||||
set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
|
set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
|
||||||
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
|
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
|
||||||
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
|
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
|
||||||
|
@ -253,9 +253,9 @@ if (LLAMA_CUBLAS)
|
||||||
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
|
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
|
||||||
|
|
||||||
add_compile_definitions(GGML_USE_CUBLAS)
|
add_compile_definitions(GGML_USE_CUBLAS)
|
||||||
if (LLAMA_CUDA_CUBLAS)
|
# if (LLAMA_CUDA_CUBLAS)
|
||||||
add_compile_definitions(GGML_CUDA_CUBLAS)
|
# add_compile_definitions(GGML_CUDA_CUBLAS)
|
||||||
endif()
|
# endif()
|
||||||
add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
|
add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
|
||||||
if (LLAMA_CUDA_FORCE_DMMV)
|
if (LLAMA_CUDA_FORCE_DMMV)
|
||||||
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
|
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
|
||||||
|
@ -277,10 +277,14 @@ if (LLAMA_CUBLAS)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||||
|
# 52 == lowest CUDA 12 standard
|
||||||
|
# 60 == f16 CUDA intrinsics
|
||||||
|
# 61 == integer CUDA intrinsics
|
||||||
|
# 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
|
||||||
if (LLAMA_CUDA_DMMV_F16)
|
if (LLAMA_CUDA_DMMV_F16)
|
||||||
set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics
|
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
|
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||||
|
|
6
Makefile
6
Makefile
|
@ -236,9 +236,9 @@ ifdef LLAMA_CUDA_MMQ_Y
|
||||||
else
|
else
|
||||||
NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
|
NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
|
||||||
endif # LLAMA_CUDA_MMQ_Y
|
endif # LLAMA_CUDA_MMQ_Y
|
||||||
ifdef LLAMA_CUDA_CUBLAS
|
#ifdef LLAMA_CUDA_CUBLAS
|
||||||
NVCCFLAGS += -DGGML_CUDA_CUBLAS
|
# NVCCFLAGS += -DGGML_CUDA_CUBLAS
|
||||||
endif # LLAMA_CUDA_CUBLAS
|
#endif # LLAMA_CUDA_CUBLAS
|
||||||
ifdef LLAMA_CUDA_CCBIN
|
ifdef LLAMA_CUDA_CCBIN
|
||||||
NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
|
NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
|
||||||
endif
|
endif
|
||||||
|
|
|
@ -400,9 +400,11 @@ Building the program with BLAS support may lead to some performance improvements
|
||||||
|
|
||||||
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:
|
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:
|
||||||
|
|
||||||
|
<!---
|
||||||
|
| LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
|
||||||
|
--->
|
||||||
| Option | Legal values | Default | Description |
|
| Option | Legal values | Default | Description |
|
||||||
|-------------------------|------------------------|---------|-------------|
|
|-------------------------|------------------------|---------|-------------|
|
||||||
| LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
|
|
||||||
| LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
|
| LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
|
||||||
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
|
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
|
||||||
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
|
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
|
||||||
|
|
|
@ -377,6 +377,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
|
||||||
|
#endif // GGML_USE_CUBLAS
|
||||||
|
} else if (arg == "--mul-mat-q" || arg == "-mmq") {
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
params.mul_mat_q = true;
|
||||||
|
#else
|
||||||
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n");
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
} else if (arg == "--low-vram" || arg == "-lv") {
|
} else if (arg == "--low-vram" || arg == "-lv") {
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
@ -585,6 +591,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
||||||
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
|
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
|
||||||
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n" );
|
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n" );
|
||||||
|
fprintf(stdout, " -mmq, --mul-mat-q use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
|
||||||
|
fprintf(stdout, " Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
|
||||||
|
fprintf(stdout, " is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
|
||||||
#endif
|
#endif
|
||||||
fprintf(stdout, " --mtest compute maximum memory usage\n");
|
fprintf(stdout, " --mtest compute maximum memory usage\n");
|
||||||
fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n");
|
fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n");
|
||||||
|
@ -637,6 +646,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||||
lparams.main_gpu = params.main_gpu;
|
lparams.main_gpu = params.main_gpu;
|
||||||
lparams.tensor_split = params.tensor_split;
|
lparams.tensor_split = params.tensor_split;
|
||||||
lparams.low_vram = params.low_vram;
|
lparams.low_vram = params.low_vram;
|
||||||
|
lparams.mul_mat_q = params.mul_mat_q;
|
||||||
lparams.seed = params.seed;
|
lparams.seed = params.seed;
|
||||||
lparams.f16_kv = params.memory_f16;
|
lparams.f16_kv = params.memory_f16;
|
||||||
lparams.use_mmap = params.use_mmap;
|
lparams.use_mmap = params.use_mmap;
|
||||||
|
|
|
@ -74,6 +74,7 @@ struct gpt_params {
|
||||||
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
|
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
|
||||||
|
|
||||||
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
|
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
|
||||||
|
bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels
|
||||||
bool memory_f16 = true; // use f16 instead of f32 for memory kv
|
bool memory_f16 = true; // use f16 instead of f32 for memory kv
|
||||||
bool random_prompt = false; // do not randomize prompt if none provided
|
bool random_prompt = false; // do not randomize prompt if none provided
|
||||||
bool use_color = false; // use color to distinguish generations and inputs
|
bool use_color = false; // use color to distinguish generations and inputs
|
||||||
|
|
|
@ -631,6 +631,9 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||||
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
||||||
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
|
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
|
||||||
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
|
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
|
||||||
|
fprintf(stdout, " -mmq, --mul-mat-q use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
|
||||||
|
fprintf(stdout, " Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
|
||||||
|
fprintf(stdout, " is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
|
||||||
#endif
|
#endif
|
||||||
fprintf(stdout, " -m FNAME, --model FNAME\n");
|
fprintf(stdout, " -m FNAME, --model FNAME\n");
|
||||||
fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
|
fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
|
||||||
|
@ -827,7 +830,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.", {});
|
LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
}
|
}
|
||||||
else if (arg == "--low-vram" || arg == "-lv")
|
else if (arg == "--low-vram" || arg == "-lv")
|
||||||
|
@ -835,7 +838,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
params.low_vram = true;
|
params.low_vram = true;
|
||||||
#else
|
#else
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
|
LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {});
|
||||||
|
#endif // GGML_USE_CUBLAS
|
||||||
|
}
|
||||||
|
else if (arg == "--mul-mat-q" || arg == "-mmq")
|
||||||
|
{
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
params.mul_mat_q = true;
|
||||||
|
#else
|
||||||
|
LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n", {});
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
}
|
}
|
||||||
else if (arg == "--main-gpu" || arg == "-mg")
|
else if (arg == "--main-gpu" || arg == "-mg")
|
||||||
|
|
24
ggml-cuda.cu
24
ggml-cuda.cu
|
@ -3898,10 +3898,9 @@ static size_t g_scratch_offset = 0;
|
||||||
|
|
||||||
static int g_device_count = -1;
|
static int g_device_count = -1;
|
||||||
static int g_main_device = 0;
|
static int g_main_device = 0;
|
||||||
#ifndef GGML_CUDA_FORCE_DMMV
|
|
||||||
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
|
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
|
||||||
#endif
|
|
||||||
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
|
static bool g_mul_mat_q = false;
|
||||||
|
|
||||||
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||||
|
|
||||||
|
@ -3923,9 +3922,7 @@ void ggml_init_cublas() {
|
||||||
g_tensor_split[id] = total_vram;
|
g_tensor_split[id] = total_vram;
|
||||||
total_vram += prop.totalGlobalMem;
|
total_vram += prop.totalGlobalMem;
|
||||||
|
|
||||||
#ifndef GGML_CUDA_FORCE_DMMV
|
|
||||||
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
|
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
g_tensor_split[id] /= total_vram;
|
g_tensor_split[id] /= total_vram;
|
||||||
|
@ -4278,6 +4275,7 @@ inline void ggml_cuda_op_mul_mat_vec(
|
||||||
|
|
||||||
#ifdef GGML_CUDA_FORCE_DMMV
|
#ifdef GGML_CUDA_FORCE_DMMV
|
||||||
const bool use_mul_mat_vec_q = false;
|
const bool use_mul_mat_vec_q = false;
|
||||||
|
(void) g_compute_capabilities[0];
|
||||||
#else
|
#else
|
||||||
int id;
|
int id;
|
||||||
CUDA_CHECK(cudaGetDevice(&id));
|
CUDA_CHECK(cudaGetDevice(&id));
|
||||||
|
@ -5021,12 +5019,14 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
|
||||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
|
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
|
||||||
} else {
|
} else {
|
||||||
#ifdef GGML_CUDA_CUBLAS
|
int min_compute_capability = INT_MAX;
|
||||||
const bool use_mul_mat_q = false;
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
#else
|
if (min_compute_capability > g_compute_capabilities[id]) {
|
||||||
const bool use_mul_mat_q = ggml_is_quantized(src0->type);
|
min_compute_capability = g_compute_capabilities[id];
|
||||||
#endif // GGML_CUDA_CUBLAS
|
}
|
||||||
if (use_mul_mat_q) {
|
}
|
||||||
|
|
||||||
|
if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
||||||
|
@ -5320,6 +5320,10 @@ void ggml_cuda_set_main_device(int main_device) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
|
||||||
|
g_mul_mat_q = mul_mat_q;
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_set_scratch_size(size_t scratch_size) {
|
void ggml_cuda_set_scratch_size(size_t scratch_size) {
|
||||||
g_scratch_size = scratch_size;
|
g_scratch_size = scratch_size;
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
|
||||||
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
||||||
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
|
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
|
||||||
void ggml_cuda_set_main_device(int main_device);
|
void ggml_cuda_set_main_device(int main_device);
|
||||||
|
void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
|
||||||
void ggml_cuda_set_scratch_size(size_t scratch_size);
|
void ggml_cuda_set_scratch_size(size_t scratch_size);
|
||||||
void ggml_cuda_free_scratch(void);
|
void ggml_cuda_free_scratch(void);
|
||||||
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
||||||
|
|
10
llama.cpp
10
llama.cpp
|
@ -901,6 +901,7 @@ struct llama_context_params llama_context_default_params() {
|
||||||
/*.progress_callback =*/ nullptr,
|
/*.progress_callback =*/ nullptr,
|
||||||
/*.progress_callback_user_data =*/ nullptr,
|
/*.progress_callback_user_data =*/ nullptr,
|
||||||
/*.low_vram =*/ false,
|
/*.low_vram =*/ false,
|
||||||
|
/*.mul_mat_q =*/ false,
|
||||||
/*.f16_kv =*/ true,
|
/*.f16_kv =*/ true,
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.vocab_only =*/ false,
|
/*.vocab_only =*/ false,
|
||||||
|
@ -1028,6 +1029,7 @@ static void llama_model_load_internal(
|
||||||
int n_gpu_layers,
|
int n_gpu_layers,
|
||||||
int main_gpu,
|
int main_gpu,
|
||||||
const float * tensor_split,
|
const float * tensor_split,
|
||||||
|
const bool mul_mat_q,
|
||||||
float rope_freq_base,
|
float rope_freq_base,
|
||||||
float rope_freq_scale,
|
float rope_freq_scale,
|
||||||
bool low_vram,
|
bool low_vram,
|
||||||
|
@ -1156,9 +1158,11 @@ static void llama_model_load_internal(
|
||||||
}
|
}
|
||||||
|
|
||||||
(void) main_gpu;
|
(void) main_gpu;
|
||||||
|
(void) mul_mat_q;
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
|
fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
|
||||||
ggml_cuda_set_main_device(main_gpu);
|
ggml_cuda_set_main_device(main_gpu);
|
||||||
|
ggml_cuda_set_mul_mat_q(mul_mat_q);
|
||||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
|
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
|
||||||
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
|
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
|
@ -1367,6 +1371,7 @@ static bool llama_model_load(
|
||||||
int n_gpu_layers,
|
int n_gpu_layers,
|
||||||
int main_gpu,
|
int main_gpu,
|
||||||
const float * tensor_split,
|
const float * tensor_split,
|
||||||
|
const bool mul_mat_q,
|
||||||
float rope_freq_base,
|
float rope_freq_base,
|
||||||
float rope_freq_scale,
|
float rope_freq_scale,
|
||||||
bool low_vram,
|
bool low_vram,
|
||||||
|
@ -1377,7 +1382,8 @@ static bool llama_model_load(
|
||||||
llama_progress_callback progress_callback,
|
llama_progress_callback progress_callback,
|
||||||
void *progress_callback_user_data) {
|
void *progress_callback_user_data) {
|
||||||
try {
|
try {
|
||||||
llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
|
llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers,
|
||||||
|
main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type,
|
||||||
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
|
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
|
@ -3192,7 +3198,7 @@ struct llama_model * llama_load_model_from_file(
|
||||||
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||||
|
|
||||||
if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
|
if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
|
||||||
params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
|
params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
|
||||||
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
|
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
|
||||||
params.progress_callback_user_data)) {
|
params.progress_callback_user_data)) {
|
||||||
delete model;
|
delete model;
|
||||||
|
|
1
llama.h
1
llama.h
|
@ -108,6 +108,7 @@ extern "C" {
|
||||||
|
|
||||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||||
bool low_vram; // if true, reduce VRAM usage at the cost of performance
|
bool low_vram; // if true, reduce VRAM usage at the cost of performance
|
||||||
|
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
|
||||||
bool f16_kv; // use fp16 for KV cache
|
bool f16_kv; // use fp16 for KV cache
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
|
|
Loading…
Reference in a new issue