mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 23:29:44 +00:00
llama : always sort logits before nucleus sampling (#812)
* Always sort logits before nucleus sampling * remove second normalization - fix windows build - remove normalization since std::discrete_distribution does not require it
This commit is contained in:
parent
cc9cee8e9e
commit
4953e9007f
1 changed files with 3 additions and 14 deletions
17
llama.cpp
17
llama.cpp
|
@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (top_k > 0 && top_k < n_logits) {
|
sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits);
|
||||||
sample_top_k(logits_id, top_k);
|
|
||||||
}
|
|
||||||
|
|
||||||
float maxl = -std::numeric_limits<float>::infinity();
|
|
||||||
for (const auto & kv : logits_id) {
|
|
||||||
maxl = Max(maxl, kv.first);
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute probs for the top k tokens
|
// compute probs for the top k tokens
|
||||||
std::vector<float> probs;
|
std::vector<float> probs;
|
||||||
probs.reserve(logits_id.size());
|
probs.reserve(logits_id.size());
|
||||||
|
|
||||||
|
float maxl = logits_id[0].first;
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for (const auto & kv : logits_id) {
|
for (const auto & kv : logits_id) {
|
||||||
const float p = expf(kv.first - maxl);
|
const float p = expf(kv.first - maxl);
|
||||||
|
@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cumsum = 1.0/cumsum;
|
|
||||||
for (int i = 0; i < (int) probs.size(); i++) {
|
|
||||||
probs[i] *= cumsum;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//printf("\n");
|
//printf("\n");
|
||||||
//for (int i = 0; i < (int) 10; i++) {
|
//for (int i = 0; i < (int) 10; i++) {
|
||||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
|
// printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
|
||||||
//}
|
//}
|
||||||
//printf("\n\n");
|
//printf("\n\n");
|
||||||
//exit(0);
|
//exit(0);
|
||||||
|
|
Loading…
Reference in a new issue