diff --git a/web/auth.go b/web/auth.go index 8ba50d4a..0b3f111d 100644 --- a/web/auth.go +++ b/web/auth.go @@ -580,12 +580,7 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.FormValue("backto") != "" { backto := r.FormValue("backto") - - // to prevent redirecting to an external URL we only set the session data when: - // 1. we fail to parse backto - // 2. backto does not include a hostname - parsed, err := url.ParseRequestURI(backto) - if err != nil || parsed.Hostname() == "" { + if IsValidBacktoURL(backto) { session.Data["backto"] = backto } } diff --git a/web/backto.go b/web/backto.go new file mode 100644 index 00000000..c69f3e62 --- /dev/null +++ b/web/backto.go @@ -0,0 +1,70 @@ +package web + +import ( + "net/url" + "regexp" + "strings" +) + +var hostnameWhitelist = map[string]bool{ + "localhost": true, + "app.viam.dev": true, + "app.viam.com": true, +} + +var prTempEnvPattern = "pr-(\\d+)-appmain-bplesliplq-uc.a.run.app" + +// isWhitelisted returns true if the passed hostname is whitelisted or a temporary PR environment. +func isWhitelisted(hostname string) bool { + isPRTempEnv, err := regexp.MatchString(prTempEnvPattern, hostname) + if err != nil { + return false + } + + if isPRTempEnv { + return true + } + + return hostnameWhitelist[hostname] +} + +// isAllowedURLScheme returns true if the passed URL is using a "https" schema, or "http" for "localhost" URLs. +func isAllowedURLScheme(url *url.URL) bool { + if url.Scheme == "https" { + return true + } + + if url.Hostname() == "localhost" && url.Scheme == "http" { + return true + } + + return false +} + +// IsValidBacktoURL returns true if the passed string is a secure URL to a whitelisted +// hostname. The whitelisted hostnames are: "localhost", "app.viam.dev", and "app.viam.com". +// +// - https://example.com -> false +// - http://app.viam.com/path/name -> false +// - https://app.viam.com/path/name -> true +// - http://localhost/path/name -> true +func IsValidBacktoURL(path string) bool { + normalized := strings.ReplaceAll(path, "\\", "/") + url, err := url.Parse(normalized) + if err != nil { + // ignore invalid URLs/URL components + return false + } + + if !isAllowedURLScheme(url) { + // ignore non-secure URLs + return false + } + + if isWhitelisted(url.Hostname()) { + // ignore non-whitelisted hosts + return true + } + + return false +} diff --git a/web/backto_test.go b/web/backto_test.go new file mode 100644 index 00000000..407f0d50 --- /dev/null +++ b/web/backto_test.go @@ -0,0 +1,82 @@ +package web + +import ( + "testing" + + "go.viam.com/test" +) + +func TestIsValidBacktoURL(t *testing.T) { + t.Run("rejects external URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("http://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("example.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("www.example.com"), test.ShouldBeFalse) + }) + + t.Run("rejects invalid production URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("http://app.viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://app.viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://app.viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//app.viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//app.viam.com/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("app.viam.com"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("app.viam.com/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("www.app.viam.com"), test.ShouldBeFalse) + }) + + t.Run("accepts valid production URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://app.viam.com"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("https://app.viam.com/some/path"), test.ShouldBeTrue) + }) + + t.Run("rejects invalid staging URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("http://app.viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://app.viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://app.viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//app.viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//app.viam.dev/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("app.viam.dev"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("app.viam.dev/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("www.app.viam.dev"), test.ShouldBeFalse) + }) + + t.Run("accepts valid staging URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://app.viam.dev"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("https://app.viam.dev/some/path"), test.ShouldBeTrue) + }) + + t.Run("rejects invalid local URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("ftp://localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//localhost/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("localhost"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("localhost/some/path"), test.ShouldBeFalse) + }) + + t.Run("accepts valid local URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://localhost"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("https://localhost/some/path"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("http://localhost"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("http://localhost/some/path"), test.ShouldBeTrue) + }) + + t.Run("rejects invalid temp PR env URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("http://pr-1-appmain-bplesliplq-uc.a.run.app"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("ftp://pr-12-appmain-bplesliplq-uc.a.run.app"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("://pr-123-appmain-bplesliplq-uc.a.run.app"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//pr-1234-appmain-bplesliplq-uc.a.run.app"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("//pr-12345-appmain-bplesliplq-uc.a.run.app/some/path"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("pr-1234-appmain-bplesliplq-uc.a.run.app"), test.ShouldBeFalse) + test.That(t, IsValidBacktoURL("pr-123-appmain-bplesliplq-uc.a.run.app/some/path"), test.ShouldBeFalse) + }) + + t.Run("accepts valid temp PR env URLs", func(t *testing.T) { + test.That(t, IsValidBacktoURL("https://pr-12345-appmain-bplesliplq-uc.a.run.app"), test.ShouldBeTrue) + test.That(t, IsValidBacktoURL("https://pr-6789-appmain-bplesliplq-uc.a.run.app/some/path"), test.ShouldBeTrue) + }) +}