Skip to content

Commit

Permalink
sampling : refactor init to use llama_sampling_params (ggerganov#3696)
Browse files Browse the repository at this point in the history
* sampling : refactor init to use llama_sampling_params

* llama : combine repetition, frequency and presence penalties in 1 call

* examples : remove embd-input and gptneox-wip

* sampling : rename penalty params + reduce size of "prev" vector

* sampling : add llama_sampling_print helper

* sampling : hide prev behind API and apply ggerganov#3661

ggml-ci
  • Loading branch information
ggerganov authored Oct 20, 2023
1 parent 8cf19d6 commit d1031cf
Show file tree
Hide file tree
Showing 30 changed files with 364 additions and 4,501 deletions.
9 changes: 1 addition & 8 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Define the default target now so that it is always the first target
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server embd-input-test gguf llama-bench llava baby-llama beam-search \
simple batched batched-bench save-load-state server gguf llama-bench llava baby-llama beam-search \
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o

# Binaries only useful for tests
Expand Down Expand Up @@ -608,13 +608,6 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)

$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)


embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput

gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,6 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /

- [main](./examples/main/README.md)
- [server](./examples/server/README.md)
- [embd-input](./examples/embd-input/README.md)
- [jeopardy](./examples/jeopardy/README.md)
- [BLIS](./docs/BLIS.md)
- [Performance troubleshooting](./docs/token_generation_performance_tips.md)
Expand Down
69 changes: 35 additions & 34 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string arg;
gpt_params default_params;
const std::string arg_prefix = "--";
llama_sampling_params & sparams = params.sampling_params;
llama_sampling_params & sparams = params.sparams;

for (int i = 1; i < argc; i++) {
arg = argv[i];
Expand Down Expand Up @@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
sparams.repeat_last_n = std::stoi(argv[i]);
sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.repeat_penalty = std::stof(argv[i]);
sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.frequency_penalty = std::stof(argv[i]);
sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.presence_penalty = std::stof(argv[i]);
sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
params.grammar = argv[i];
sparams.grammar = argv[i];
} else if (arg == "--grammar-file") {
if (++i >= argc) {
invalid_param = true;
Expand All @@ -587,7 +588,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.grammar)
std::back_inserter(sparams.grammar)
);
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
Expand Down Expand Up @@ -640,7 +641,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}

void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sampling_params;
const llama_sampling_params & sparams = params.sparams;

printf("usage: %s [options]\n", argv[0]);
printf("\n");
Expand Down Expand Up @@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
Expand Down Expand Up @@ -878,7 +879,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}

if (params.ignore_eos) {
params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
}

{
Expand Down Expand Up @@ -1123,28 +1124,28 @@ std::string get_sortable_timestamp() {

void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
const llama_sampling_params & sparams = params.sampling_params;
const llama_sampling_params & sparams = params.sparams;

fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");

#ifdef NDEBUG
fprintf(stream, "debug: false\n");
Expand Down Expand Up @@ -1178,8 +1179,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
Expand Down Expand Up @@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);

fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) {
Expand Down
3 changes: 1 addition & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct gpt_params {
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor

// sampling parameters
struct llama_sampling_params sampling_params;
struct llama_sampling_params sparams;

std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding
Expand All @@ -66,7 +66,6 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files

Expand Down
73 changes: 51 additions & 22 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "sampling.h"

struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();

result->params = params.sampling_params;
result->params = params;
result->grammar = nullptr;

// if there is a grammar, parse it
Expand All @@ -23,7 +23,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_params & pa
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}

result->prev.resize(params.n_ctx);
result->prev.resize(params.n_prev);

return result;
}
Expand Down Expand Up @@ -66,25 +66,56 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
dst->prev = src->prev;
}

// Returns the final element of the sampling context's `prev` token history.
// NOTE(review): back() is called without an emptiness check - presumably the
// history is always non-empty when this is called; confirm against callers.
llama_token llama_sampling_last(llama_sampling_context * ctx) {
    const auto & history = ctx->prev;

    return history.back();
}

// Builds a string from the trailing `n` tokens stored in the sampling history.
//
// ctx_sampling : sampling context whose `prev` history is read
// ctx_main     : context handed to llama_token_to_piece() for each token
// n            : number of trailing tokens requested; clamped to the history size
//
// Returns the concatenation of the pieces, oldest of the selected tokens first.
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
    const auto & history = ctx_sampling->prev;

    const int n_total = history.size();
    const int n_take  = n < n_total ? n : n_total; // never take more than is stored

    std::string text;

    int pos = n_total - n_take;
    while (pos < n_total) {
        text += llama_token_to_piece(ctx_main, history[pos]);
        ++pos;
    }

    return text;
}

// Formats the sampling parameters as a three-line, tab-indented summary:
//   line 1: penalty settings (last_n, repeat, frequency, presence)
//   line 2: top_k, tfs_z, top_p, typical_p, temp
//   line 3: mirostat mode, then mirostat_eta (labelled "mirostat_lr") and
//           mirostat_tau (labelled "mirostat_ent")
// The text is written into a fixed 1024-byte buffer via snprintf, so anything
// longer than that would be truncated.
std::string llama_sampling_print(const llama_sampling_params & params) {
    char buf[1024];

    // floats are promoted to double when passed through varargs; the casts only make that explicit
    snprintf(buf, sizeof(buf),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
            params.penalty_last_n,
            static_cast<double>(params.penalty_repeat),
            static_cast<double>(params.penalty_freq),
            static_cast<double>(params.penalty_present),
            params.top_k,
            static_cast<double>(params.tfs_z),
            static_cast<double>(params.top_p),
            static_cast<double>(params.typical_p),
            static_cast<double>(params.temp),
            params.mirostat,
            static_cast<double>(params.mirostat_eta),
            static_cast<double>(params.mirostat_tau));

    std::string summary(buf);
    return summary;
}

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int n_ctx = llama_n_ctx(ctx_main);
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
Expand All @@ -97,7 +128,7 @@ llama_token llama_sampling_sample(

float * logits = llama_get_logits_ith(ctx_main, idx);

// Apply params.logit_bias map
// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
Expand All @@ -117,14 +148,10 @@ llama_token llama_sampling_sample(
// apply penalties
if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)];
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);

llama_sample_repetition_penalty(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - penalty_last_n,
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
Expand All @@ -141,7 +168,7 @@ llama_token llama_sampling_sample(
}

if (temp <= 0) {
// Greedy sampling
// greedy sampling
id = llama_sample_token_greedy(ctx_main, &cur_p);
} else {
if (mirostat == 1) {
Expand All @@ -152,8 +179,9 @@ llama_token llama_sampling_sample(
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else {
// Temperature sampling
// temperature sampling
size_t min_keep = std::max(1, params.n_probs);

llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
Expand Down Expand Up @@ -183,11 +211,12 @@ llama_token llama_sampling_sample(
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id) {
llama_token id,
bool apply_grammar) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(id);

if (ctx_sampling->grammar != NULL) {
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
}
Loading

0 comments on commit d1031cf

Please sign in to comment.