diff --git a/flyteadmin/auth/cookie.go b/flyteadmin/auth/cookie.go index 2470220d24..456eeb8580 100644 --- a/flyteadmin/auth/cookie.go +++ b/flyteadmin/auth/cookie.go @@ -12,6 +12,7 @@ import ( "github.com/gorilla/securecookie" "github.com/flyteorg/flyte/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyte/flyteadmin/pkg/config" "github.com/flyteorg/flyte/flytestdlib/errors" "github.com/flyteorg/flyte/flytestdlib/logger" ) @@ -68,6 +69,8 @@ func NewSecureCookie(cookieName, value string, hashKey, blockKey []byte, domain Value: encoded, Domain: domain, SameSite: sameSiteMode, + HttpOnly: true, + Secure: !config.GetConfig().Security.InsecureCookieHeader, }, nil } @@ -126,6 +129,7 @@ func NewCsrfCookie() http.Cookie { Value: csrfStateToken, SameSite: http.SameSiteLaxMode, HttpOnly: true, + Secure: !config.GetConfig().Security.InsecureCookieHeader, } } @@ -164,6 +168,7 @@ func NewRedirectCookie(ctx context.Context, redirectURL string) *http.Cookie { Value: urlObj.String(), SameSite: http.SameSiteLaxMode, HttpOnly: true, + Secure: !config.GetConfig().Security.InsecureCookieHeader, } } diff --git a/flyteadmin/auth/cookie_manager.go b/flyteadmin/auth/cookie_manager.go index ce360c9d3a..8a23272d01 100644 --- a/flyteadmin/auth/cookie_manager.go +++ b/flyteadmin/auth/cookie_manager.go @@ -11,6 +11,7 @@ import ( "golang.org/x/oauth2" "github.com/flyteorg/flyte/flyteadmin/auth/config" + serverConfig "github.com/flyteorg/flyte/flyteadmin/pkg/config" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flytestdlib/errors" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -218,6 +219,7 @@ func (c *CookieManager) getLogoutCookie(name string) *http.Cookie { Domain: c.domain, MaxAge: 0, HttpOnly: true, + Secure: !serverConfig.GetConfig().Security.InsecureCookieHeader, Expires: time.Now().Add(-1 * time.Hour), } } diff --git a/flyteadmin/auth/cookie_manager_test.go b/flyteadmin/auth/cookie_manager_test.go index 09d8468e83..444056ba8c 100644 --- a/flyteadmin/auth/cookie_manager_test.go +++ b/flyteadmin/auth/cookie_manager_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/oauth2" "github.com/flyteorg/flyte/flyteadmin/auth/config" + serverConfig "github.com/flyteorg/flyte/flyteadmin/pkg/config" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" ) @@ -199,34 +200,53 @@ func TestCookieManager(t *testing.T) { assert.EqualError(t, err, "[EMPTY_OAUTH_TOKEN] Error reading existing secure cookie [flyte_idt]. Error: [SECURE_COOKIE_ERROR] Error reading secure cookie flyte_idt, caused by: securecookie: error - caused by: crypto/aes: invalid key size 75") }) - t.Run("delete_cookies", func(t *testing.T) { - w := httptest.NewRecorder() - - manager.DeleteCookies(ctx, w) - - cookies := w.Result().Cookies() - require.Equal(t, 5, len(cookies)) - - assert.True(t, time.Now().After(cookies[0].Expires)) - assert.Equal(t, cookieSetting.Domain, cookies[0].Domain) - assert.Equal(t, accessTokenCookieName, cookies[0].Name) - - assert.True(t, time.Now().After(cookies[1].Expires)) - assert.Equal(t, cookieSetting.Domain, cookies[1].Domain) - assert.Equal(t, accessTokenCookieNameSplitFirst, cookies[1].Name) - - assert.True(t, time.Now().After(cookies[2].Expires)) - assert.Equal(t, cookieSetting.Domain, cookies[2].Domain) - assert.Equal(t, accessTokenCookieNameSplitSecond, cookies[2].Name) - - assert.True(t, time.Now().After(cookies[3].Expires)) - assert.Equal(t, cookieSetting.Domain, cookies[3].Domain) - assert.Equal(t, refreshTokenCookieName, cookies[3].Name) + tests := []struct { + name string + insecureCookieHeader bool + expectedSecure bool + }{ + { + name: "secure_cookies", + insecureCookieHeader: false, + expectedSecure: true, + }, + { + name: "insecure_cookies", + insecureCookieHeader: true, + expectedSecure: false, + }, + } - assert.True(t, time.Now().After(cookies[4].Expires)) - assert.Equal(t, cookieSetting.Domain, cookies[4].Domain) - assert.Equal(t, idTokenCookieName, cookies[4].Name) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + + serverConfig.SetConfig(&serverConfig.ServerConfig{ + Security: serverConfig.ServerSecurityOptions{ + InsecureCookieHeader: tt.insecureCookieHeader, + }, + }) + + manager.DeleteCookies(ctx, w) + + cookies := w.Result().Cookies() + require.Equal(t, 5, len(cookies)) + + // Check secure flag for each cookie + for _, cookie := range cookies { + assert.Equal(t, tt.expectedSecure, cookie.Secure) + assert.True(t, time.Now().After(cookie.Expires)) + assert.Equal(t, cookieSetting.Domain, cookie.Domain) + } + + // Check cookie names + assert.Equal(t, accessTokenCookieName, cookies[0].Name) + assert.Equal(t, accessTokenCookieNameSplitFirst, cookies[1].Name) + assert.Equal(t, accessTokenCookieNameSplitSecond, cookies[2].Name) + assert.Equal(t, refreshTokenCookieName, cookies[3].Name) + assert.Equal(t, idTokenCookieName, cookies[4].Name) + }) + } t.Run("get_http_same_site_policy", func(t *testing.T) { manager.sameSitePolicy = config.SameSiteLaxMode diff --git a/flyteadmin/auth/cookie_test.go b/flyteadmin/auth/cookie_test.go index a5c58ad2ff..1134e957dc 100644 --- a/flyteadmin/auth/cookie_test.go +++ b/flyteadmin/auth/cookie_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/base64" - "fmt" "net/http" "net/url" "testing" @@ -14,6 +13,7 @@ import ( "github.com/flyteorg/flyte/flyteadmin/auth/config" "github.com/flyteorg/flyte/flyteadmin/auth/interfaces/mocks" + serverConfig "github.com/flyteorg/flyte/flyteadmin/pkg/config" stdConfig "github.com/flyteorg/flyte/flytestdlib/config" ) @@ -26,22 +26,53 @@ func mustParseURL(t testing.TB, u string) url.URL { return *res } -// This function can also be called locally to generate new keys func TestSecureCookieLifecycle(t *testing.T) { - hashKey := securecookie.GenerateRandomKey(64) - assert.True(t, base64.RawStdEncoding.EncodeToString(hashKey) != "") - - blockKey := securecookie.GenerateRandomKey(32) - assert.True(t, base64.RawStdEncoding.EncodeToString(blockKey) != "") - fmt.Printf("Hash key: |%s| Block key: |%s|\n", - base64.RawStdEncoding.EncodeToString(hashKey), base64.RawStdEncoding.EncodeToString(blockKey)) - - cookie, err := NewSecureCookie("choc", "chip", hashKey, blockKey, "localhost", http.SameSiteLaxMode) - assert.NoError(t, err) + tests := []struct { + name string + insecureCookieHeader bool + expectedSecure bool + }{ + { + name: "secure_cookie", + insecureCookieHeader: false, + expectedSecure: true, + }, + { + name: "insecure_cookie", + insecureCookieHeader: true, + expectedSecure: false, + }, + } - value, err := ReadSecureCookie(context.Background(), cookie, hashKey, blockKey) - assert.NoError(t, err) - assert.Equal(t, "chip", value) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate hash and block keys for secure cookie + hashKey := securecookie.GenerateRandomKey(64) + assert.True(t, base64.RawStdEncoding.EncodeToString(hashKey) != "") + + blockKey := securecookie.GenerateRandomKey(32) + assert.True(t, base64.RawStdEncoding.EncodeToString(blockKey) != "") + + // Set up server configuration with insecureCookieHeader option + serverConfig.SetConfig(&serverConfig.ServerConfig{ + Security: serverConfig.ServerSecurityOptions{ + InsecureCookieHeader: tt.insecureCookieHeader, + }, + }) + + // Create a secure cookie + cookie, err := NewSecureCookie("choc", "chip", hashKey, blockKey, "localhost", http.SameSiteLaxMode) + assert.NoError(t, err) + + // Validate the Secure attribute of the cookie + assert.Equal(t, tt.expectedSecure, cookie.Secure) + + // Read and validate the secure cookie value + value, err := ReadSecureCookie(context.Background(), cookie, hashKey, blockKey) + assert.NoError(t, err) + assert.Equal(t, "chip", value) + }) + } } func TestNewCsrfToken(t *testing.T) { @@ -50,9 +81,41 @@ func TestNewCsrfToken(t *testing.T) { } func TestNewCsrfCookie(t *testing.T) { - cookie := NewCsrfCookie() - assert.Equal(t, "flyte_csrf_state", cookie.Name) - assert.True(t, cookie.HttpOnly) + tests := []struct { + name string + insecureCookieHeader bool + expectedSecure bool + }{ + { + name: "secure_csrf_cookie", + insecureCookieHeader: false, + expectedSecure: true, + }, + { + name: "insecure_csrf_cookie", + insecureCookieHeader: true, + expectedSecure: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up server configuration with insecureCookieHeader option + serverConfig.SetConfig(&serverConfig.ServerConfig{ + Security: serverConfig.ServerSecurityOptions{ + InsecureCookieHeader: tt.insecureCookieHeader, + }, + }) + + // Generate CSRF cookie + cookie := NewCsrfCookie() + + // Validate CSRF cookie properties + assert.Equal(t, "flyte_csrf_state", cookie.Name) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, tt.expectedSecure, cookie.Secure) + }) + } } func TestHashCsrfState(t *testing.T) { @@ -121,6 +184,36 @@ func TestNewRedirectCookie(t *testing.T) { assert.NotNil(t, cookie) assert.Equal(t, http.SameSiteLaxMode, cookie.SameSite) }) + + tests := []struct { + name string + insecureCookieHeader bool + expectedSecure bool + }{ + { + name: "secure_cookies", + insecureCookieHeader: false, + expectedSecure: true, + }, + { + name: "insecure_cookies", + insecureCookieHeader: true, + expectedSecure: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverConfig.SetConfig(&serverConfig.ServerConfig{ + Security: serverConfig.ServerSecurityOptions{ + InsecureCookieHeader: tt.insecureCookieHeader, + }, + }) + ctx := context.Background() + cookie := NewRedirectCookie(ctx, "http://www.example.com/postLogin") + assert.NotNil(t, cookie) + assert.Equal(t, cookie.Secure, tt.expectedSecure) + }) + } } func TestGetAuthFlowEndRedirect(t *testing.T) { diff --git a/flyteadmin/pkg/config/config.go b/flyteadmin/pkg/config/config.go index f6bdd27141..0e63eccb45 100644 --- a/flyteadmin/pkg/config/config.go +++ b/flyteadmin/pkg/config/config.go @@ -66,10 +66,13 @@ type KubeClientConfig struct { } type ServerSecurityOptions struct { - Secure bool `json:"secure"` - Ssl SslOptions `json:"ssl"` - UseAuth bool `json:"useAuth"` - AuditAccess bool `json:"auditAccess"` + Secure bool `json:"secure"` + Ssl SslOptions `json:"ssl"` + UseAuth bool `json:"useAuth"` + // InsecureCookieHeader should only be set in the case where we want to serve cookies with the header "Secure" set to false. + // This is useful for local development and *never* in production. + InsecureCookieHeader bool `json:"insecureCookieHeader"` + AuditAccess bool `json:"auditAccess"` // These options are here to allow deployments where the Flyte UI (Console) is served from a different domain/port. // Note that CORS only applies to Admin's API endpoints. The health check endpoint for instance is unaffected. diff --git a/flyteadmin/pkg/config/serverconfig_flags.go b/flyteadmin/pkg/config/serverconfig_flags.go index 10229a458a..09a5d70a26 100755 --- a/flyteadmin/pkg/config/serverconfig_flags.go +++ b/flyteadmin/pkg/config/serverconfig_flags.go @@ -59,6 +59,7 @@ func (cfg ServerConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "security.ssl.certificateFile"), defaultServerConfig.Security.Ssl.CertificateFile, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "security.ssl.keyFile"), defaultServerConfig.Security.Ssl.KeyFile, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.useAuth"), defaultServerConfig.Security.UseAuth, "") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.insecureCookieHeader"), defaultServerConfig.Security.InsecureCookieHeader, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.auditAccess"), defaultServerConfig.Security.AuditAccess, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.allowCors"), defaultServerConfig.Security.AllowCors, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedOrigins"), defaultServerConfig.Security.AllowedOrigins, "") diff --git a/flyteadmin/pkg/config/serverconfig_flags_test.go b/flyteadmin/pkg/config/serverconfig_flags_test.go index 6a95336f40..a18b56156e 100755 --- a/flyteadmin/pkg/config/serverconfig_flags_test.go +++ b/flyteadmin/pkg/config/serverconfig_flags_test.go @@ -225,6 +225,20 @@ func TestServerConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_security.insecureCookieHeader", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.insecureCookieHeader", testValue) + if vBool, err := cmdFlags.GetBool("security.insecureCookieHeader"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vBool), &actual.Security.InsecureCookieHeader) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_security.auditAccess", func(t *testing.T) { t.Run("Override", func(t *testing.T) {