Skip to content

Commit

Permalink
a few updates for rag
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexChristensen committed Feb 10, 2024
1 parent eab84a6 commit c723e44
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 42 deletions.
18 changes: 8 additions & 10 deletions R/auto_device.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
#' @noRd
# Automatically detect appropriate device
# Updated 29.01.2024
# Updated 06.02.2024
auto_device <- function(device, transformer)
{

# Set transformer memory (MB)
# Numbers derived from overall memory usage on Alex's 1x A6000
# Single run of `rag` with each model using 3800 tweets
transformer_memory <- round(
switch(
transformer,
"tinyllama" = 5504, "llama-2" = 5964,
"mistral-7b" = 30018, "openchat-3.5" = 29238,
"orca-2" = 29836, "phi-2" = 13594
), digits = -2
# Assume GPU and 16-bit unless otherwise noted
transformer_memory <- switch(
transformer,
"tinyllama" = 2640, "llama-2" = 4200, # supported by 4-bit
"mistral-7b" = 16800, "openchat-3.5" = 16800,
"orca-2" = 33600, # supported by 32-bit
"phi-2" = 6480
)

# First, check for "auto"
Expand Down
105 changes: 81 additions & 24 deletions R/rag.R
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,6 @@ get_embedding <- function(index, output)
# Loop across documents
embedding <- do.call(cbind, lapply(output$content$document, index$vector_store$get))







# Initialize data frame
content_df <- matrix(
data = NA, nrow = n_documents, ncol = 3,
Expand All @@ -397,7 +391,6 @@ get_embedding <- function(index, output)
)
)


# Loop over content
for(i in seq_len(n_documents)){

Expand All @@ -421,35 +414,99 @@ get_embedding <- function(index, output)

#' @noRd
# LLAMA-2 ----
# Updated 28.01.2023
# Updated 06.02.2024
setup_llama2 <- function(llama_index, prompt, device)
{

# Return model
return(
llama_index$ServiceContext$from_defaults(
# Check for device
if(grepl("cuda", device)){

# Try to setup GPU modules
output <- try(setup_gpu_modules(), silent = TRUE)

# If error, then switch to "cpu"
if(is(device, "try-error")){
device <- "cpu"
}

}

# If GPU possible, try different models
if(grepl("cuda", device)){

# Order of models to try
MODEL <- c("GPTQ", "AWQ")

# Loop over and try
for(model in MODEL){

# Set up model
model_name <- paste0("TheBloke/Llama-2-7B-Chat-", model)

# Try to get and load model
load_model <- try(
llama_index$ServiceContext$from_defaults(
llm = llama_index$llms$HuggingFaceLLM(
model_name = model_name,
tokenizer_name = model_name,
query_wrapper_prompt = llama_index$PromptTemplate(
paste0(
"<|system|>\n", prompt,
"</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"
)
), device_map = device,
generate_kwargs = list(
temperature = as.double(0.1), do_sample = TRUE
)
), context_window = 8192L,
embed_model = "local:BAAI/bge-small-en-v1.5"
), silent = TRUE
)

# Check if load model failed
if(is(load_model, "try-error")){
delete_transformer(gsub("/", "--", model_name), TRUE)
}else{ # Successful load, break out of loop
break
}

}

# If by the end, still failing, switch to CPU
if(is(load_model, "try-error")){
device <- "cpu"
}

}

# Use CPU model
if(device == "cpu"){
load_model <- llama_index$ServiceContext$from_defaults(
llm = llama_index$llms$HuggingFaceLLM(
model_name = "TheBloke/Llama-2-7b-Chat-AWQ",
tokenizer_name = "TheBloke/Llama-2-7b-Chat-AWQ",
model_name = "TheBloke/Llama-2-7B-Chat-fp16",
tokenizer_name = "TheBloke/Llama-2-7B-Chat-fp16",
query_wrapper_prompt = llama_index$PromptTemplate(
paste0(
"<|system|>\n", prompt,
"</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"
)
), device_map = device,
generate_kwargs = list(
"temperature" = as.double(0.1), do_sample = TRUE
temperature = as.double(0.1), do_sample = TRUE
)
), context_window = 8192L,
embed_model = "local:BAAI/bge-small-en-v1.5"
)
)
}

# Return model
return(load_model)

}

#' @noRd
# Mistral-7B ----
# Updated 28.01.2023
# Updated 28.01.2024
setup_mistral <- function(llama_index, prompt, device)
{

Expand All @@ -461,7 +518,7 @@ setup_mistral <- function(llama_index, prompt, device)
tokenizer_name = "mistralai/Mistral-7B-v0.1",
device_map = device,
generate_kwargs = list(
"temperature" = as.double(0.1), do_sample = TRUE,
temperature = as.double(0.1), do_sample = TRUE,
pad_token_id = 2L, eos_token_id = 2L
)
), context_window = 8192L,
Expand All @@ -473,7 +530,7 @@ setup_mistral <- function(llama_index, prompt, device)

#' @noRd
# OpenChat-3.5 ----
# Updated 28.01.2023
# Updated 28.01.2024
setup_openchat <- function(llama_index, prompt, device)
{

Expand All @@ -485,7 +542,7 @@ setup_openchat <- function(llama_index, prompt, device)
tokenizer_name = "openchat/openchat_3.5",
device_map = device,
generate_kwargs = list(
"temperature" = as.double(0.1), do_sample = TRUE
temperature = as.double(0.1), do_sample = TRUE
)
), context_window = 8192L,
embed_model = "local:BAAI/bge-small-en-v1.5"
Expand All @@ -496,7 +553,7 @@ setup_openchat <- function(llama_index, prompt, device)

#' @noRd
# Orca-2 ----
# Updated 28.01.2023
# Updated 28.01.2024
setup_orca2 <- function(llama_index, prompt, device)
{

Expand All @@ -508,7 +565,7 @@ setup_orca2 <- function(llama_index, prompt, device)
tokenizer_name = "microsoft/Orca-2-7b",
device_map = device,
generate_kwargs = list(
"temperature" = as.double(0.1), do_sample = TRUE
temperature = as.double(0.1), do_sample = TRUE
)
), context_window = 4096L,
embed_model = "local:BAAI/bge-small-en-v1.5"
Expand All @@ -519,7 +576,7 @@ setup_orca2 <- function(llama_index, prompt, device)

#' @noRd
# Phi-2 ----
# Updated 28.01.2023
# Updated 28.01.2024
setup_phi2 <- function(llama_index, prompt, device)
{

Expand All @@ -531,7 +588,7 @@ setup_phi2 <- function(llama_index, prompt, device)
tokenizer_name = "microsoft/phi-2",
device_map = device,
generate_kwargs = list(
"temperature" = as.double(0.1), do_sample = TRUE,
temperature = as.double(0.1), do_sample = TRUE,
pad_token_id = 2L, eos_token_id = 2L
)
), context_window = 2048L,
Expand All @@ -543,7 +600,7 @@ setup_phi2 <- function(llama_index, prompt, device)

#' @noRd
# TinyLLAMA ----
# Updated 28.01.2023
# Updated 28.01.2024
setup_tinyllama <- function(llama_index, prompt, device)
{

Expand Down
9 changes: 7 additions & 2 deletions R/setup_gpu_modules.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
#' @export
#'
# Install GPU modules
# Updated 03.02.2024
# Updated 06.02.2024
setup_gpu_modules <- function()
{

# Set necessary modules
modules <- c(
"autoawq"
"autoawq", "auto-gptq", "optimum"
)

# Check for Linux
if(system.check()$OS == "linux"){
modules <- c(modules, "llama-cpp-python")
}

# Determine whether any modules need to be installed
installed_modules <- reticulate::py_list_packages(envname = "transforEmotion")

Expand Down
12 changes: 6 additions & 6 deletions R/setup_miniconda.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ conda_check <- function(){
#'
#' @author Alexander P. Christensen <[email protected]>
#' Aleksandar Tomašević <[email protected]>
#'
#'
#' @export
#'
# Install miniconda
# Updated 15.11.2023
setup_miniconda <- function()
{

# Install miniconda
path_to_miniconda <- try(
install_miniconda(),
silent = TRUE
)
if(any(class(path_to_miniconda) != "try-error")){
message("\nTo uninstall miniconda, use `reticulate::miniconda_uninstall()`")

if(any(class(path_to_miniconda) != "try-error")){
message("\nTo uninstall miniconda, use `reticulate::miniconda_uninstall()`")
}

# Create transformEmotion enviroment if it doesn't exist
Expand All @@ -52,7 +52,7 @@ setup_miniconda <- function()
# Activate the environment

reticulate::use_condaenv("transforEmotion", required = TRUE)

print("Installing missing Python libraries...")
setup_modules()
}
Expand Down

0 comments on commit c723e44

Please sign in to comment.