2023-03-25 18:26:40 +00:00
|
|
|
#include "common.h"
|
|
|
|
#include "llama.h"
|
|
|
|
|
2023-03-28 16:48:20 +00:00
|
|
|
#include <algorithm>
#include <chrono>
#include <cmath>
|
|
|
|
|
|
|
|
// Convert raw logits into a probability distribution (softmax).
//
// Returns a vector of the same length as `logits` whose entries are
// exp(logits[i]) / sum_k exp(logits[k]). An empty input yields an empty
// output (previously this read logits[0], which is undefined behaviour).
std::vector<float> softmax(const std::vector<float>& logits) {
    std::vector<float> probs(logits.size());
    if (logits.empty()) {
        return probs;
    }

    // Subtract the maximum logit before exponentiating for numerical stability:
    // the largest term becomes exp(0) = 1, so no term can overflow.
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    // Accumulate in double to limit rounding error over large vocabularies.
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        const float exp_logit = std::exp(logits[i] - max_logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }

    for (float & p : probs) {
        p = static_cast<float>(p / sum_exp);
    }
    return probs;
}
|
|
|
|
|
|
|
|
// Negative log-likelihood of token `tok` under the distribution defined by the
// `n_vocab` raw logits starting at `logits`, i.e. -log(softmax(logits)[tok]).
//
// Computed in log space (log-sum-exp) rather than by taking the log of a float
// probability: a very unlikely token then yields a large finite value instead
// of -log(0) = inf, and no temporary vectors are allocated per token.
static double token_nll(const float * logits, int n_vocab, int tok) {
    const float max_logit = *std::max_element(logits, logits + n_vocab);

    double sum_exp = 0.0;
    for (int k = 0; k < n_vocab; ++k) {
        sum_exp += std::exp(logits[k] - max_logit);
    }

    // -log p(tok) = log(sum_k exp(l_k - max)) - (l_tok - max)
    return std::log(sum_exp) - static_cast<double>(logits[tok] - max_logit);
}

// Compute and print the perplexity of the model in `ctx` over `params.prompt`.
//
// The tokenized prompt is split into consecutive, non-overlapping chunks of
// `params.n_ctx` tokens (a trailing partial chunk is ignored). Each chunk is
// evaluated in batches of `params.n_batch` tokens.
//
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: the running perplexity after each chunk, printed to stdout as
// `[<chunk>]<perplexity>,` (for example `[1]4.5970,[2]5.1807,`).
void perplexity(llama_context * ctx, const gpt_params & params) {
    auto tokens = ::llama_tokenize(ctx, params.prompt, true);

    int count = 0;
    const int seq_count = static_cast<int>(tokens.size() / params.n_ctx);
    const int n_vocab = llama_n_vocab(ctx);

    // Accumulated over all chunks, so each printed value is a running perplexity.
    double nll = 0.0;

    fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);

    if (seq_count == 0) {
        fprintf(stderr, "%s : the input has %zu tokens, but at least n_ctx = %d are required\n",
                __func__, tokens.size(), params.n_ctx);
        return;
    }

    // One chunk's worth of logits: n_ctx rows of n_vocab floats. Allocated once
    // and reused (clear() keeps the capacity) across chunks.
    std::vector<float> logits;
    logits.reserve(static_cast<size_t>(params.n_ctx) * n_vocab);

    const int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;

    for (int i = 0; i < seq_count; ++i) {
        const int start = i * params.n_ctx;
        const int end = start + params.n_ctx;

        logits.clear();

        const auto start_t = std::chrono::high_resolution_clock::now();
        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * params.n_batch;
            const int batch_size = std::min(end - batch_start, params.n_batch);

            // n_past restarts at 0 for every chunk, so each chunk is scored
            // independently of the previous one.
            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return;
            }

            const float * batch_logits = llama_get_logits(ctx);
            logits.insert(logits.end(), batch_logits, batch_logits + static_cast<size_t>(batch_size) * n_vocab);
        }
        const auto end_t = std::chrono::high_resolution_clock::now();

        if (i == 0) {
            const float seconds = std::chrono::duration<float>(end_t - start_t).count();
            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
        }

        // We get the logits for all the tokens in the context window (params.n_ctx)
        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
        // score only the later part of the window so the model always has some
        // context to predict the token: positions from min(512, n_ctx / 2) onward,
        // which is the second half of the window when n_ctx <= 1024. The last
        // position is skipped because the token it predicts lies outside this chunk.
        //
        // We rely on the fact that attention in the forward pass only looks at previous
        // tokens here, so the logits returned for each token are an accurate representation
        // of what the model would have predicted at that point.
        //
        // Example, we have a context window of 512, we will compute perplexity for each of the
        // last 256 tokens. Then, we split the input up into context window size chunks to
        // process the entire prompt.
        const int first_scored = std::min(512, params.n_ctx / 2);
        for (int j = first_scored; j < params.n_ctx - 1; ++j) {
            // Logits at position j give the distribution over the token at j + 1.
            const float * row = logits.data() + static_cast<size_t>(j) * n_vocab;
            nll += token_nll(row, n_vocab, tokens[start + j + 1]);
            ++count;
        }

        // perplexity is e^(average negative log-likelihood)
        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
        fflush(stdout);
    }
    printf("\n");
}
|
|
|
|
|
|
|
|
int main(int argc, char ** argv) {
|
|
|
|
gpt_params params;
|
|
|
|
params.model = "models/llama-7B/ggml-model.bin";
|
|
|
|
|
2023-04-13 21:50:42 +00:00
|
|
|
params.n_batch = 512;
|
2023-03-25 18:26:40 +00:00
|
|
|
if (gpt_params_parse(argc, argv, params) == false) {
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
|
|
|
params.perplexity = true;
|
2023-04-13 21:50:42 +00:00
|
|
|
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
2023-03-25 18:26:40 +00:00
|
|
|
|
|
|
|
if (params.n_ctx > 2048) {
|
|
|
|
fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
|
|
|
|
"expect poor results\n", __func__, params.n_ctx);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (params.seed <= 0) {
|
|
|
|
params.seed = time(NULL);
|
|
|
|
}
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
|
|
|
|
|
|
|
|
std::mt19937 rng(params.seed);
|
|
|
|
if (params.random_prompt) {
|
|
|
|
params.prompt = gpt_random_prompt(rng);
|
|
|
|
}
|
|
|
|
|
|
|
|
llama_context * ctx;
|
|
|
|
|
|
|
|
// load the model
|
|
|
|
{
|
|
|
|
auto lparams = llama_context_default_params();
|
|
|
|
|
|
|
|
lparams.n_ctx = params.n_ctx;
|
|
|
|
lparams.n_parts = params.n_parts;
|
|
|
|
lparams.seed = params.seed;
|
|
|
|
lparams.f16_kv = params.memory_f16;
|
|
|
|
lparams.logits_all = params.perplexity;
|
Rewrite loading code to try to satisfy everyone:
- Support all three formats (ggml, ggmf, ggjt). (However, I didn't
include the hack needed to support GPT4All files without conversion.
Those can still be used after converting them with convert.py from my
other PR.)
- Support both mmap and read (mmap is used by default, but can be
disabled with `--no-mmap`, and is automatically disabled for pre-ggjt
files or on platforms where mmap is not supported).
- Support multi-file models like before, but automatically determine the
number of parts rather than requiring `--n_parts`.
- Improve validation and error checking.
- Stop using the per-file type field (f16) entirely in favor of just
relying on the per-tensor type/size fields. This has no immediate
benefit, but makes it easier to experiment with different formats, and
should make it easier to support the new GPTQ-for-LLaMa models in the
future (I have some work in progress on that front).
- Support VirtualLock on Windows (using the same `--mlock` option as on
Unix).
- Indicate loading progress when using mmap + mlock. (Which led me
to the interesting observation that on my Linux machine, with a
warm file cache, mlock actually takes some time, whereas mmap
without mlock starts almost instantly...)
- To help implement this, move mlock support from ggml to the
loading code.
- madvise/PrefetchVirtualMemory support (based on #740)
- Switch from ifstream to the `fopen` family of functions to avoid
unnecessary copying and, when mmap is enabled, allow reusing the same
file descriptor for both metadata reads and mmap (whereas the existing
implementation opens the file a second time to mmap).
- Quantization now produces a single-file output even with multi-file
inputs (not really a feature as much as 'it was easier this way').
Implementation notes:
I tried to factor the code into more discrete pieces than before.
Regarding code style: I tried to follow the code style, but I'm naughty
and used a few advanced C++ features repeatedly:
- Destructors to make it easier to ensure everything gets cleaned up.
- Exceptions. I don't even usually use exceptions when writing C++, and
I can remove them if desired... but here they make the loading code
much more succinct while still properly handling a variety of errors,
ranging from API calls failing to integer overflow and allocation
failure. The exceptions are converted to error codes at the
API boundary.)
Co-authored-by: Pavol Rusnak <pavol@rusnak.io> (for the bit I copied from #740)
2023-04-08 19:24:37 +00:00
|
|
|
lparams.use_mmap = params.use_mmap;
|
2023-03-25 18:26:40 +00:00
|
|
|
lparams.use_mlock = params.use_mlock;
|
|
|
|
lparams.embedding = params.embedding;
|
|
|
|
|
|
|
|
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
|
|
|
|
|
|
|
if (ctx == NULL) {
|
|
|
|
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// print system information
|
|
|
|
{
|
|
|
|
fprintf(stderr, "\n");
|
|
|
|
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
|
|
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
|
|
|
}
|
|
|
|
|
|
|
|
perplexity(ctx, params);
|
|
|
|
|
|
|
|
llama_print_timings(ctx);
|
|
|
|
llama_free(ctx);
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|