diff --git a/examples/common.cpp b/examples/common.cpp index 476d565..0990195 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -260,12 +260,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.cfg_scale = std::stof(argv[i]); - } else if (arg == "--cfg-smooth-factor") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.cfg_smooth_factor = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -509,7 +503,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --cfg-negative-prompt PROMPT \n"); fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); - fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); diff --git a/examples/common.h b/examples/common.h index 037a4ee..69170df 100644 --- a/examples/common.h +++ b/examples/common.h @@ -55,7 +55,6 @@ struct gpt_params { // https://arxiv.org/abs/2306.17806 std::string cfg_negative_prompt; // string to help guidance float cfg_scale = 1.f; // How strong is guidance - float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits std::string model = "models/7B/ggml-model.bin"; // model path std::string model_alias = "unknown"; // model alias diff --git a/examples/main/main.cpp b/examples/main/main.cpp index bcbcf12..656382f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -557,7 +557,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; if (ctx_guidance) { - llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor); + llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale); } // Apply penalties diff --git a/llama.cpp b/llama.cpp index 23e746d..3b0024e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2218,8 +2218,7 @@ void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, - float scale, - float smooth_factor) { + float scale) { int64_t t_start_sample_us = ggml_time_us(); assert(ctx); @@ -2240,16 +2239,7 @@ void llama_sample_classifier_free_guidance( for (int i = 0; i < n_vocab; ++i) { float logit_guidance = logits_guidance[i]; float logit_base = logits_base[i]; - logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance; - } - - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - float logit_base = logits_base[i]; - float logit_guidance = logits_guidance[i]; - - candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base; + candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance; } if (ctx) { diff --git a/llama.h b/llama.h index c565f6a..bbf28e6 100644 --- a/llama.h +++ b/llama.h @@ -344,13 +344,11 @@ extern "C" { /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits. LLAMA_API void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, - float scale, - float smooth_factor); + float scale); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);