Skip to content

Commit

Permalink
Update rag.R
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexChristensen committed Feb 1, 2024
1 parent 346c9fe commit f40dbf6
Showing 1 changed file with 63 additions and 61 deletions.
124 changes: 63 additions & 61 deletions R/rag.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,57 +125,59 @@ rag <- function(
envir = 1, progress = TRUE
)
{

# Check that input of 'text' argument is in the appropriate format for the analysis
non_text_warning(text) # see utils-transforEmotion.R for function

if(!is.null(text)){
non_text_warning(text) # see utils-transforEmotion.R for function
}

# Check that 'text' or 'path' are set
if(is.null(text) & is.null(path)){
stop("Argument 'text' or 'path' must be provided.", call. = FALSE)
}

# Check for 'transformer'
if(missing(transformer)){
transformer <- "tinyllama"
}else{transformer <- tolower(match.arg(transformer))}

# Check for 'query'
if(missing(query)){
stop("A 'query' must be provided")
}

# Set default for 'response_mode'
if(missing(response_mode)){
response_mode <- "tree_summarize"
}else{response_mode <- match.arg(response_mode)}

# Set default for 'device'
if(missing(response_mode)){
device <- "auto"
}else{device <- match.arg(device)}

# Run setup for modules
setup_modules()

# Check for llama_index in environment
if(!exists("llama_index", envir = as.environment(envir))){

# Import 'llama-index'
message("Importing llama-index module...")
llama_index <- reticulate::import("llama_index")

}

# Check for service context
if(exists("service_context", envir = as.environment(envir))){

# Check for service context LLM
if(attr(service_context, which = "transformer") != transformer){
rm(service_context, envir = as.environment(envir)); gc(verbose = FALSE)
}

}

# Get service context
if(!exists("service_context", envir = as.environment(envir))){

Expand All @@ -193,89 +195,89 @@ rag <- function(
"phi-2" = setup_phi2(llama_index, prompt, device),
stop(paste0("'", transformer, "' not found"), call. = FALSE)
)

}

# Add transformer attribute to `service_context`
attr(service_context, which = "transformer") <- transformer

# Load into environment
if(isTRUE(keep_in_env)){

# Keep llama-index module in environment
assign(
x = "llama_index",
value = llama_index,
envir = as.environment(envir)
)

# Keep service_context in the environment
assign(
x = "service_context",
value = service_context,
envir = as.environment(envir)
)

}

# Depending on where documents are, load them
if(!is.null(path)){

# Set documents
documents <- llama_index$SimpleDirectoryReader(path)$load_data()

}else if(!is.null(text)){

# Set documents
documents <- lapply(
text, function(x){
llama_index$Document(text = x)
}
)

}

# Send message to user
message("Indexing documents...")

# Set indices
index <- llama_index$VectorStoreIndex(
documents, service_context = service_context,
show_progress = progress
)

# Set up query engine
engine <- index$as_query_engine(
similarity_top_k = similarity_top_k,
response_mode = response_mode
)

# Send message to user
message("Querying...", appendLF = FALSE)

# Start time
start <- Sys.time()

# Get query
extracted_query <- engine$query(query)

# Stop time
message(paste0(" elapsed: ", round(Sys.time() - start), "s"))

# Organize Python output
output <- list(
response = response_cleanup(
extracted_query$response, transformer = transformer
),
content = content_cleanup(extracted_query$source_nodes)
)

# Set class
class(output) <- "rag"

# Return response
return(output)

}

#' @exportS3Method
Expand All @@ -296,10 +298,10 @@ summary.rag <- function(object, ...){
#' @noRd
# Updated 29.01.2024
response_cleanup <- function(response, transformer){

# Trim whitespace first!
response <- trimws(response)

# Return on switch
return(
switch(
Expand All @@ -316,45 +318,45 @@ response_cleanup <- function(response, transformer){
"tinyllama" = response
)
)

}

#' Clean up content
#'
#' Collects the retrieved source nodes into a data frame with one row per
#' node.
#'
#' @param content List of source nodes. Each element is read through its
#' `id_`, `text` and `score` fields (accessed with `$`)
#'
#' @return A data frame with character columns `document` and `text` and a
#' numeric column `score`. A field that comes back as `NULL` is reported as
#' `NA`. An empty `content` yields a zero-row data frame with the same
#' columns and types
#'
#' @noRd
# Updated 28.01.2024
content_cleanup <- function(content){

  # Get number of documents
  n_documents <- length(content)

  # Extract one field from every node
  # 'template' fixes the column type and doubles as the value used when the
  # field is `NULL`, so each node always contributes exactly one value
  pull_field <- function(accessor, template){
    vapply(
      seq_len(n_documents), function(i){
        value <- accessor(content[[i]])
        if(is.null(value)){
          template
        }else{
          as.vector(value, mode = typeof(template))
        }
      }, FUN.VALUE = template
    )
  }

  # Build each column in its own type: the score never passes through a
  # character representation, so it keeps its full numeric precision
  content_df <- data.frame(
    document = pull_field(function(node){node$id_}, NA_character_),
    text = pull_field(function(node){node$text}, NA_character_),
    score = pull_field(function(node){node$score}, NA_real_),
    stringsAsFactors = FALSE
  )

  # Return data frame
  return(content_df)

}

#' Set up for LLAMA-2
Expand All @@ -363,7 +365,7 @@ content_cleanup <- function(content){
# Updated 28.01.2024
setup_llama2 <- function(llama_index, prompt, device)
{

# Return model
return(
llama_index$ServiceContext$from_defaults(
Expand All @@ -383,7 +385,7 @@ setup_llama2 <- function(llama_index, prompt, device)
embed_model = "local:BAAI/bge-small-en-v1.5"
)
)

}

#' Set up for Mistral-7B
Expand All @@ -392,7 +394,7 @@ setup_llama2 <- function(llama_index, prompt, device)
# Updated 28.01.2024
setup_mistral <- function(llama_index, prompt, device)
{

# Return model
return(
llama_index$ServiceContext$from_defaults(
Expand All @@ -408,7 +410,7 @@ setup_mistral <- function(llama_index, prompt, device)
embed_model = "local:BAAI/bge-small-en-v1.5"
)
)

}

#' Set up for OpenChat-3.5
Expand All @@ -417,7 +419,7 @@ setup_mistral <- function(llama_index, prompt, device)
# Updated 28.01.2024
setup_openchat <- function(llama_index, prompt, device)
{

# Return model
return(
llama_index$ServiceContext$from_defaults(
Expand All @@ -432,7 +434,7 @@ setup_openchat <- function(llama_index, prompt, device)
embed_model = "local:BAAI/bge-small-en-v1.5"
)
)

}

#' Set up for Orca-2
Expand All @@ -441,7 +443,7 @@ setup_openchat <- function(llama_index, prompt, device)
# Updated 28.01.2024
setup_orca2 <- function(llama_index, prompt, device)
{

# Return model
return(
llama_index$ServiceContext$from_defaults(
Expand All @@ -456,7 +458,7 @@ setup_orca2 <- function(llama_index, prompt, device)
embed_model = "local:BAAI/bge-small-en-v1.5"
)
)

}

#' Set up for Phi-2
Expand All @@ -465,7 +467,7 @@ setup_orca2 <- function(llama_index, prompt, device)
# Updated 28.01.2024
setup_phi2 <- function(llama_index, prompt, device)
{

# Return model
return(
llama_index$ServiceContext$from_defaults(
Expand All @@ -481,7 +483,7 @@ setup_phi2 <- function(llama_index, prompt, device)
embed_model = "local:BAAI/bge-small-en-v1.5"
)
)

}

#' Set up for TinyLLAMA
Expand All @@ -490,7 +492,7 @@ setup_phi2 <- function(llama_index, prompt, device)
# Updated 28.01.2024
setup_tinyllama <- function(llama_index, prompt, device)
{

# Return model
return(
llama_index$ServiceContext$from_defaults(
Expand All @@ -507,5 +509,5 @@ setup_tinyllama <- function(llama_index, prompt, device)
embed_model = "local:BAAI/bge-small-en-v1.5"
)
)

}

0 comments on commit f40dbf6

Please sign in to comment.