Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Honor Discovery URL for flows other than M2M + fallback for discovery URL #388

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
import java.lang.reflect.Field;
import java.util.*;
import org.apache.http.HttpMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DatabricksConfig {
private static final Logger LOG = LoggerFactory.getLogger(DatabricksConfig.class);
private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider();

@ConfigAttribute(env = "DATABRICKS_HOST")
Expand Down Expand Up @@ -545,7 +548,19 @@ public OpenIDConnectEndpoints getOidcEndpoints() throws IOException {
if (discoveryUrl == null) {
return fetchDefaultOidcEndpoints();
}
return fetchOidcEndpointsFromDiscovery();
try {
OpenIDConnectEndpoints oidcEndpoints = fetchOidcEndpointsFromDiscovery();
if (oidcEndpoints != null) {
return oidcEndpoints;
}
} catch (Exception e) {
LOG.warn(
"Failed to fetch OIDC Endpoints using discovery URL: {}. Error: {}. \nDefaulting to fetch OIDC using default endpoint.",
discoveryUrl,
e.getMessage(),
e);
}
return fetchDefaultOidcEndpoints();
}

private OpenIDConnectEndpoints fetchOidcEndpointsFromDiscovery() {
Expand Down Expand Up @@ -632,6 +647,7 @@ public DatabricksEnvironment getDatabricksEnvironment() {
}

private DatabricksConfig clone(Set<String> fieldsToSkip) {
fieldsToSkip.add("LOG");
DatabricksConfig newConfig = new DatabricksConfig();
for (Field f : DatabricksConfig.class.getDeclaredFields()) {
if (fieldsToSkip.contains(f.getName())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class OAuthClient {
public static class Builder {
private String host;
private String clientId;
private String discoveryUrl;
private String redirectUrl;
private List<String> scopes;
private String clientSecret;
Expand All @@ -53,6 +54,11 @@ public Builder withClientId(String clientId) {
return this;
}

public Builder withDiscoveryUrl(String discoveryUrl) {
this.discoveryUrl = discoveryUrl;
return this;
}

public Builder withClientSecret(String clientSecret) {
this.clientSecret = clientSecret;
return this;
Expand Down Expand Up @@ -91,6 +97,7 @@ public OAuthClient(DatabricksConfig config) throws IOException {
.withHttpClient(config.getHttpClient())
.withClientId(config.getClientId())
.withClientSecret(config.getClientSecret())
.withDiscoveryUrl(config.getDiscoveryUrl())
.withHost(config.getHost())
.withRedirectUrl(
config.getOAuthRedirectUrl() != null
Expand All @@ -106,7 +113,8 @@ private OAuthClient(Builder b) throws IOException {
this.host = b.host;
this.hc = b.hc;

DatabricksConfig config = new DatabricksConfig().setHost(b.host).resolve();
DatabricksConfig config =
new DatabricksConfig().setHost(b.host).setDiscoveryUrl(b.discoveryUrl).resolve();
OpenIDConnectEndpoints oidc = config.getOidcEndpoints();
if (oidc == null) {
throw new DatabricksException(b.host + " does not support OAuth");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,35 @@ public void testDiscoveryEndpoint() throws IOException {
}
}

@Test
public void testDiscoveryEndpointFetchThrowsError() throws IOException {
String discoveryUrlSuffix = "/test.discovery.url";
String OIDCResponse =
"{\n"
+ " \"authorization_endpoint\": \"https://test.auth.endpoint/oidc/v1/authorize\",\n"
+ " \"token_endpoint\": \"https://test.auth.endpoint/oidc/v1/token\"\n"
+ "}";

try (FixtureServer server =
new FixtureServer()
.with("GET", discoveryUrlSuffix, "", 400)
.with("GET", "/oidc/.well-known/oauth-authorization-server", OIDCResponse, 200)) {

String discoveryUrl = server.getUrl() + discoveryUrlSuffix;

OpenIDConnectEndpoints oidcEndpoints =
new DatabricksConfig()
.setHost(server.getUrl())
.setDiscoveryUrl(discoveryUrl)
.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build())
.getOidcEndpoints();

assertEquals(
oidcEndpoints.getAuthorizationEndpoint(), "https://test.auth.endpoint/oidc/v1/authorize");
assertEquals(oidcEndpoints.getTokenEndpoint(), "https://test.auth.endpoint/oidc/v1/token");
}
}

@Test
public void testNewWithWorkspaceHost() {
DatabricksConfig config =
Expand Down
Loading