Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Store additional claims in the QueryUserInfoFromAccessToken path #512

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"strings"
"time"

_struct "github.com/golang/protobuf/ptypes/struct"
"google.golang.org/protobuf/encoding/protojson"

"github.com/flyteorg/flyteadmin/auth/interfaces"
"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
Expand Down Expand Up @@ -318,6 +321,8 @@ func GetHTTPRequestCookieToMetadataHandler(authCtx interfaces.AuthenticationCont
logger.Infof(ctx, "Failed to retrieve user info cookie. Ignoring. Error: %v", err)
}

logger.Debugf(ctx, "Retrieved [%v] Additional Claims: [%+v]", len(userInfo.AdditionalClaims.AsMap()), userInfo.AdditionalClaims.AsMap())

raw, err := json.Marshal(userInfo)
if err != nil {
logger.Infof(ctx, "Failed to marshal user info. Ignoring. Error: %v", err)
Expand Down Expand Up @@ -373,6 +378,8 @@ func IdentityContextFromRequest(ctx context.Context, req *http.Request, authCtx
return nil, fmt.Errorf("unauthenticated request. Error: %w", err)
}

logger.Debugf(ctx, "Retrieved2 [%v] Additional Claims: [%+v]", len(userInfo.AdditionalClaims.AsMap()), userInfo.AdditionalClaims.AsMap())

return IdentityContextFromIDTokenToken(ctx, idToken, authCtx.Options().UserAuth.OpenID.ClientID,
authCtx.OidcProvider(), userInfo)
}
Expand Down Expand Up @@ -410,16 +417,45 @@ func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Re
userInfo, err := authCtx.OidcProvider().UserInfo(ctx, tokenSource)
if err != nil {
logger.Errorf(ctx, "Error getting user info from IDP %s", err)
return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP")
return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err)
}

resp := &service.UserInfoResponse{}
err = userInfo.Claims(&resp)
if err != nil {
logger.Errorf(ctx, "Error getting user info from IDP %s", err)
return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP")
return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err)
}

allClaims := make(map[string]any, 10)
err = userInfo.Claims(&allClaims)
if err != nil {
logger.Errorf(ctx, "Error unmarshalling raw claims %s", err)
return &service.UserInfoResponse{}, fmt.Errorf("error unmarshalling raw claims. Error: %w", err)
}

logger.Debugf(ctx, "Unmarshalled a total of [%v] claims: [%+v]", len(allClaims), allClaims)

alreadyRead := []string{"subject", "name", "preferred_username", "given_name", "family_name", "email", "picture"}
for _, existing := range alreadyRead {
delete(allClaims, existing)
}

logger.Debugf(ctx, "Remaining a total of [%v] additional claims: [%+v]", len(allClaims), allClaims)

var response _struct.Struct
b, err := json.Marshal(allClaims)
if err != nil {
return &service.UserInfoResponse{}, fmt.Errorf("failed to marshal additional claims to json. Error: %w", err)
}

err = protojson.Unmarshal(b, &response)
if err != nil {
return nil, fmt.Errorf("failed to unamarshal additional claims to proto.struct. Error: %w", err)
}

resp.AdditionalClaims = &response

return resp, err
}

Expand Down Expand Up @@ -449,6 +485,7 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler {
return func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error {
info, ok := m.(*service.UserInfoResponse)
if ok {
logger.Debugf(ctx, "GetUserInfoForwardResponseHandler: Additional claims: [%+v]", info.AdditionalClaims.GetFields())
if info.AdditionalClaims != nil {
for k, v := range info.AdditionalClaims.GetFields() {
jsonBytes, err := v.MarshalJSON()
Expand All @@ -457,6 +494,7 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler {
continue
}
header := fmt.Sprintf("X-User-Claim-%s", strings.ReplaceAll(k, "_", "-"))
logger.Debugf(ctx, "Setting header [%v: %v]", header, string(jsonBytes))
w.Header().Set(header, string(jsonBytes))
}
}
Expand Down