Skip to content

Commit

Permalink
Merge pull request #234 from SokolovAnatoliy/updating-token-handling-…
Browse files Browse the repository at this point in the history
…Azure

updates token handling for Azure
  • Loading branch information
JamesHWade authored Dec 17, 2024
2 parents 8443f1e + 6488ac8 commit 1b28eb7
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 44 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Imports:
utils,
yaml
Suggests:
AzureRMR,
AzureGraph,
future,
grDevices,
knitr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export(create_completion_huggingface)
export(get_available_endpoints)
export(get_available_models)
export(get_ide_theme_info)
export(gptstudio_cache_directory)
export(gptstudio_chat)
export(gptstudio_chat_in_source_addin)
export(gptstudio_comment_code)
Expand Down
5 changes: 5 additions & 0 deletions R/cache.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' a function that determines the appropriate directory to cache a token
#' @export
gptstudio_cache_directory <- function() {
tools::R_user_dir(package = "gptstudio")
}
64 changes: 42 additions & 22 deletions R/service-azure_openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,33 +108,54 @@ query_api_azure_openai <-
}

retrieve_azure_token <- function() {
rlang::check_installed("AzureRMR")

token <- tryCatch(
{
AzureRMR::get_azure_login(
tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"),
app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"),
scopes = ".default"
)
},
error = function(e) NULL
)

if (is.null(token)) {
token <- AzureRMR::create_azure_login(
tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"),
app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"),
password = Sys.getenv("AZURE_OPENAI_CLIENT_SECRET"),
host = "https://cognitiveservices.azure.com/",
scopes = ".default"
)
token <- retrieve_azure_token_object() |> suppressMessages()

invisible(token$credentials$access_token)
}


retrieve_azure_token_object <- function() {
rlang::check_installed("AzureGraph")

## Set this so that get_graph_login properly caches
azure_data_env <- Sys.getenv("R_AZURE_DATA_DIR")

Sys.setenv("R_AZURE_DATA_DIR" = gptstudio_cache_directory())

login <- try(AzureGraph::get_graph_login(tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"),
app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"),
scopes = NULL,
refresh = FALSE),
silent = TRUE) |>
suppressMessages()

if (inherits(login, "try-error")) {

if (!dir.exists(gptstudio_cache_directory())) {
dir.create(gptstudio_cache_directory()) |>
suppressWarnings()
}


login <- AzureGraph::create_graph_login(tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"),
app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"),
host = Sys.getenv("AZURE_OPENAI_SCOPE"),
scopes = NULL,
auth_type = "client_credentials",
password = Sys.getenv("AZURE_OPENAI_CLIENT_SECRET")) |>
suppressMessages()
}

invisible(token$token$credentials$access_token)
## Set this so that get_graph_login properly caches
Sys.setenv("R_AZURE_DATA_DIR" = azure_data_env)

invisible(login$token)
}




stream_azure_openai <- function(messages = list(list(role = "user", content = "hi there")),
element_callback = cat) {
body <- list(
Expand All @@ -155,6 +176,5 @@ stream_azure_openai <- function(messages = list(list(role = "user", content = "h
},
round = "line"
)

invisible(response)
}
11 changes: 11 additions & 0 deletions man/gptstudio_cache_directory.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 34 additions & 21 deletions tests/testthat/test-service-azure_openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,68 +153,81 @@ test_that("query_api_azure_openai handles error response", {
# Test token retrieval --------------------------------------------------------

test_that("retrieve_azure_token successfully gets existing token", {
skip_on_ci()

local_mocked_bindings(
get_azure_login = function(...) {
list(token = list(credentials = list(access_token = "existing_token")))
get_graph_login = function(...) {
list(credentials = list(access_token = "existing_token"))
},
create_azure_login = function(...) stop("Should not be called"),
.package = "AzureRMR"
create_graph_login = function(...) stop("Should not be called"),
.package = "AzureGraph"
)

token <- retrieve_azure_token()

expect_equal(token, "existing_token")
})

test_that("retrieve_azure_token creates new token when get_azure_login fails", {
test_that("retrieve_azure_token creates new token when get_graph_login fails", {
skip_on_ci()

local_mocked_bindings(
get_azure_login = function(...) stop("Error"),
create_azure_login = function(...) {
list(token = list(credentials = list(access_token = "new_token")))
get_graph_login = function(...) stop("Error"),
create_graph_login = function(...) {
list(credentials = list(access_token = "new_token"))
},
.package = "AzureRMR"
.package = "AzureGraph"
)

token <- retrieve_azure_token()

expect_equal(token, "new_token")
})


test_that("retrieve_azure_token uses correct environment variables", {
mock_get_azure_login <- function(tenant, app, scopes) {
skip_on_ci()

mock_get_graph_login <- function(tenant, app, scopes, refresh) {
expect_equal(tenant, "test_tenant")
expect_equal(app, "test_client")
expect_equal(scopes, ".default")
expect_equal(scopes, NULL)
expect_equal(refresh, FALSE)
stop("Error")
}

mock_create_azure_login <- function(tenant, app, password, host, scopes) {
mock_create_graph_login <- function(tenant, app, host, scopes, auth_type, password) {
expect_equal(tenant, "test_tenant")
expect_equal(app, "test_client")
expect_equal(host, "https://cognitiveservices.azure.com/.default")
expect_equal(scopes, NULL)
expect_equal(auth_type, "client_credentials")
expect_equal(password, "test_secret")
expect_equal(host, "https://cognitiveservices.azure.com/")
expect_equal(scopes, ".default")
list(token = list(credentials = list(access_token = "new_token")))
list(credentials = list(access_token = "new_token"))
}

local_mocked_bindings(
get_azure_login = mock_get_azure_login,
create_azure_login = mock_create_azure_login,
.package = "AzureRMR"
get_graph_login = mock_get_graph_login,
create_graph_login = mock_create_graph_login,
.package = "AzureGraph"
)

withr::local_envvar(
AZURE_OPENAI_TENANT_ID = "test_tenant",
AZURE_OPENAI_CLIENT_ID = "test_client",
AZURE_OPENAI_CLIENT_SECRET = "test_secret"
AZURE_OPENAI_CLIENT_SECRET = "test_secret",
AZURE_OPENAI_SCOPE = "https://cognitiveservices.azure.com/.default"
)

expect_no_error(retrieve_azure_token())
})

test_that("retrieve_azure_token checks for AzureRMR installation", {



test_that("retrieve_azure_token checks for AzureGraph installation", {
mock_check_installed <- function(pkg) {
expect_equal(pkg, "AzureRMR")
expect_equal(pkg, "AzureGraph")
}

local_mocked_bindings(
Expand Down

0 comments on commit 1b28eb7

Please sign in to comment.