diff --git a/cmd/local/start.go b/cmd/local/start.go index 831cf22..9cf4bea 100644 --- a/cmd/local/start.go +++ b/cmd/local/start.go @@ -18,8 +18,10 @@ package localcmd import ( "flag" + "fmt" "net/http" + crs "github.com/graphql-editor/stucco/pkg/cors" "github.com/graphql-editor/stucco/pkg/handlers" "github.com/graphql-editor/stucco/pkg/server" "github.com/graphql-editor/stucco/pkg/utils" @@ -63,21 +65,16 @@ func NewStartCommand() *cobra.Command { if err != nil { return err } + corsOptions := crs.NewCors() + fmt.Println(corsOptions.AllowedOrigins) middleware := func(next http.Handler) http.Handler { return handlers.RecoveryHandler( httplog.WithLogging( cors.New(cors.Options{ - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{ - http.MethodHead, - http.MethodGet, - http.MethodPost, - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - }, - AllowedHeaders: []string{"*"}, - AllowCredentials: true, + AllowedOrigins: corsOptions.AllowedOrigins, + AllowedMethods: corsOptions.AllowedMethods, + AllowedHeaders: corsOptions.AllowedHeaders, + AllowCredentials: corsOptions.AllowedCredentials, }).Handler(next), httplog.DefaultStacktracePred, ), diff --git a/pkg/cors/cors.go b/pkg/cors/cors.go new file mode 100644 index 0000000..877ef3a --- /dev/null +++ b/pkg/cors/cors.go @@ -0,0 +1,53 @@ +package cors + +import ( + "net/http" + "os" + "strconv" + "strings" +) + +type CorsOptions struct { + AllowedMethods, AllowedHeaders, AllowedOrigins []string + AllowedCredentials bool +} + +func retriveOriginEnv(name string) []string { + return strings.Split(os.Getenv(name), " ") +} + +func NewCors() CorsOptions { + allowedOrigins := []string{"*"} + if envOrigin := retriveOriginEnv("ALLOWED_ORIGINS"); envOrigin[0] != "" { + allowedOrigins = envOrigin + } + allowedMethods := []string{http.MethodHead, + http.MethodGet, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + } + if envMethod := retriveOriginEnv("ALLOWED_METHODS"); envMethod[0] == "" { + allowedMethods = []string{"POST", "GET", "OPTIONS"} + } + allowedHeaders := []string{"*"} + if envHeaders := retriveOriginEnv("ALLOWED_HEADERS"); envHeaders[0] == "" { + allowedHeaders = []string{"Accept", "Authorization", "Origin", "Content-Type"} + } + allowedCredentials := true + var err error + if envCredentials := os.Getenv("ALLOWED_CREDENTIALS"); envCredentials != "" { + allowedCredentials, err = strconv.ParseBool(envCredentials) + if err != nil { + panic("cannot parse ALLOWED_CREDENTIALS env to boolean") + } + } + c := CorsOptions{ + AllowedMethods: allowedMethods, + AllowedHeaders: allowedHeaders, + AllowedOrigins: allowedOrigins, + AllowedCredentials: allowedCredentials, + } + return c +}