Skip to content

Commit

Permalink
server : fix smart selection of available slot (ggerganov#10120)
Browse files Browse the repository at this point in the history
* Fix smart selection of available slot

* minor fix

* replace vectors of tokens with shorthands
  • Loading branch information
sasha0552 authored Nov 1, 2024
1 parent 1804adb commit d865d14
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 28 deletions.
35 changes: 12 additions & 23 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,12 @@ struct server_context {
return nullptr;
}

server_slot * get_available_slot(const std::string & prompt) {
server_slot * get_available_slot(const server_task & task) {
server_slot * ret = nullptr;

// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
int max_lcp_len = 0;
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
int max_lcs_len = 0;
float similarity = 0;

for (server_slot & slot : slots) {
Expand All @@ -740,25 +740,25 @@ struct server_context {
}

// skip the slot if it does not contains cached tokens
if (slot.prompt_tokens.empty()) {
if (slot.cache_tokens.empty()) {
continue;
}

// length of the Longest Common Prefix between the current slot's prompt and the input prompt
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
int lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);

// fraction of the common substring length compared to the current slot's prompt length
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
// fraction of the common subsequence length compared to the current slot's prompt length
similarity = static_cast<float>(lcs_len) / static_cast<int>(slot.cache_tokens.size());

// select the current slot if the criteria match
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
max_lcp_len = lcp_len;
if (lcs_len > max_lcs_len && similarity > slot_prompt_similarity) {
max_lcs_len = lcs_len;
ret = &slot;
}
}

if (ret != nullptr) {
SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
SLT_DBG(*ret, "selected slot by lcs similarity, max_lcs_len = %d, similarity = %f\n", max_lcs_len, similarity);
}
}

Expand Down Expand Up @@ -1514,18 +1514,7 @@ struct server_context {
{
const int id_slot = json_value(task.data, "id_slot", -1);

server_slot * slot;

if (id_slot != -1) {
slot = get_slot_by_id(id_slot);
} else {
std::string prompt;
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
prompt = json_value(task.data, "prompt", std::string());
}

slot = get_available_slot(prompt);
}
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);

if (slot == nullptr) {
// if no slot is available, we defer this task for processing later
Expand Down
52 changes: 47 additions & 5 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,18 +439,60 @@ static std::string gen_chatcmplid() {
// other common utils
//

static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

return i;
}

static size_t longest_common_prefix(const std::string & a, const std::string & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}

return i;
// get the lengths of the input sequences
int a_len = a.size();
int b_len = b.size();

// initialize the maximum length of the longest common subsequence (LCS)
int max_length = 0;

// use two rows instead of a 2D matrix to optimize space
std::vector<int> prev_row(b_len + 1, 0);
std::vector<int> curr_row(b_len + 1, 0);

// iterate through the elements of a
for (int i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (int j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}

// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}

// update the previous row for the next iteration
prev_row = curr_row;
}

// return the maximum length of the LCS
return max_length;
}

static bool ends_with(const std::string & str, const std::string & suffix) {
Expand Down

0 comments on commit d865d14

Please sign in to comment.