mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 23:29:44 +00:00
ggml : alternative fix for race condition bug in non-inplace ggml_compute_forward_diag_mask_f32 (#1454)
* fix race condition bug in non-inplace ggml_compute_forward_diag_mask_f32 memcpy needs to be synchronized across threads to avoid race conditions. => do it in INIT phase * remove trailing whitespace * Update ggml.c --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
13c351ad72
commit
79b2d5b69d
1 changed files with 17 additions and 23 deletions
40
ggml.c
40
ggml.c
|
@ -10501,34 +10501,28 @@ static void ggml_compute_forward_diag_mask_f32(
|
|||
assert(src1->type == GGML_TYPE_I32);
|
||||
assert(ggml_nelements(src1) == 2);
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const bool inplace = (bool)((int32_t *) src1->data)[1];
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
// TODO: this hack is not good, need a better way to handle this
|
||||
if (!inplace) {
|
||||
// use the init task to copy src -> dst
|
||||
struct ggml_compute_params params_cpy = *params;
|
||||
|
||||
params_cpy.ith = 0;
|
||||
params_cpy.nth = 1;
|
||||
params_cpy.type = GGML_TASK_COMPUTE;
|
||||
|
||||
ggml_compute_forward_dup_same_cont(¶ms_cpy, src0, dst);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const bool inplace = (bool)((int32_t *) src1->data)[1];
|
||||
assert(n_past >= 0);
|
||||
|
||||
if (!inplace && (params->type == GGML_TASK_INIT)) {
|
||||
// memcpy needs to be synchronized across threads to avoid race conditions.
|
||||
// => do it in INIT phase
|
||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
||||
memcpy(
|
||||
((char *) dst->data),
|
||||
((char *) src0->data),
|
||||
ggml_nbytes(dst));
|
||||
}
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: handle transposed/permuted matrices
|
||||
|
||||
const int n = ggml_nrows(src0);
|
||||
|
|
Loading…
Reference in a new issue