examples : add llama_init_from_gpt_params() common function (#1290)

Signed-off-by: deadprogram <ron@hybridgroup.com>
This commit is contained in:
Ron Evans 2023-05-02 22:39:51 +02:00 committed by GitHub
parent 0e6cbff1b7
commit 67c77799e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 76 deletions

View file

@ -405,6 +405,37 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_parts = params.n_parts;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return NULL;
}
if (!params.lora_adapter.empty()) {
int err = llama_apply_lora_from_file(lctx,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return NULL;
}
}
return lctx;
}
/* Keep track of current color of output, and emit ANSI code if it changes. */
void set_console_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {

View file

@ -77,6 +77,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
//
// Model utils
//
struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
//
// Console utils
//

View file

@ -35,25 +35,11 @@ int main(int argc, char ** argv) {
llama_context * ctx;
// load the model
{
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_parts = params.n_parts;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.embedding = params.embedding;
ctx = llama_init_from_file(params.model.c_str(), lparams);
ctx = llama_init_from_gpt_params(params);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}
}
// print system information
{

View file

@ -101,35 +101,12 @@ int main(int argc, char ** argv) {
llama_context * ctx;
g_ctx = &ctx;
// load the model
{
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_parts = params.n_parts;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
ctx = llama_init_from_file(params.model.c_str(), lparams);
// load the model and apply lora adapter, if any
ctx = llama_init_from_gpt_params(params);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}
}
if (!params.lora_adapter.empty()) {
int err = llama_apply_lora_from_file(ctx,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return 1;
}
}
// print system information
{

View file

@ -122,37 +122,12 @@ int main(int argc, char ** argv) {
llama_context * ctx;
// load the model
{
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_parts = params.n_parts;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.embedding = params.embedding;
ctx = llama_init_from_file(params.model.c_str(), lparams);
// load the model and apply lora adapter, if any
ctx = llama_init_from_gpt_params(params);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}
}
if (!params.lora_adapter.empty()) {
int err = llama_apply_lora_from_file(ctx,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return 1;
}
}
// print system information
{