diff --git a/api/healthcheck/healthcheck_controller.go b/api/healthcheck/healthcheck_controller.go index 9e55f611e..3e3a09a23 100644 --- a/api/healthcheck/healthcheck_controller.go +++ b/api/healthcheck/healthcheck_controller.go @@ -17,23 +17,23 @@ package healthcheck import ( - h "github.com/InVisionApp/go-health" + "context" + gohealth "github.com/InVisionApp/go-health" "github.com/Peripli/service-manager/pkg/health" - "github.com/Peripli/service-manager/pkg/util" - "net/http" - "github.com/Peripli/service-manager/pkg/log" + "github.com/Peripli/service-manager/pkg/util" "github.com/Peripli/service-manager/pkg/web" + "net/http" ) // controller platform controller type controller struct { - health h.IHealth + health gohealth.IHealth thresholds map[string]int64 } // NewController returns a new healthcheck controller with the given health and thresholds -func NewController(health h.IHealth, thresholds map[string]int64) web.Controller { +func NewController(health gohealth.IHealth, thresholds map[string]int64) web.Controller { return &controller{ health: health, thresholds: thresholds, @@ -46,7 +46,7 @@ func (c *controller) healthCheck(r *web.Request) (*web.Response, error) { logger := log.C(ctx) logger.Debugf("Performing health check...") healthState, _, _ := c.health.State() - healthResult := c.aggregate(healthState) + healthResult := c.aggregate(ctx, healthState) var status int if healthResult.Status == health.StatusUp { status = http.StatusOK @@ -56,22 +56,25 @@ func (c *controller) healthCheck(r *web.Request) (*web.Response, error) { return util.NewJSONResponse(status, healthResult) } -func (c *controller) aggregate(overallState map[string]h.State) *health.Health { - if len(overallState) == 0 { +func (c *controller) aggregate(ctx context.Context, healthState map[string]gohealth.State) *health.Health { + if len(healthState) == 0 { return health.New().WithStatus(health.StatusUp) } + + details := make(map[string]interface{}) overallStatus := health.StatusUp - for name, state := range overallState { + for name, state := range healthState { if state.Fatal && state.ContiguousFailures >= c.thresholds[name] { overallStatus = health.StatusDown - break } - } - details := make(map[string]interface{}) - for name, state := range overallState { state.Status = convertStatus(state.Status) + if !web.IsAuthorized(ctx) { + state.Details = nil + state.Err = "" + } details[name] = state } + return health.New().WithStatus(overallStatus).WithDetails(details) } diff --git a/api/healthcheck/healthcheck_controller_test.go b/api/healthcheck/healthcheck_controller_test.go index fc353d341..826c95934 100644 --- a/api/healthcheck/healthcheck_controller_test.go +++ b/api/healthcheck/healthcheck_controller_test.go @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package healthcheck import ( + "context" "fmt" h "github.com/InVisionApp/go-health" "github.com/Peripli/service-manager/pkg/health" @@ -60,14 +62,17 @@ var _ = Describe("Healthcheck controller", func() { }) Describe("aggregation", func() { + var ctx context.Context var c *controller var healths map[string]h.State var thresholds map[string]int64 BeforeEach(func() { + ctx = context.TODO() healths = map[string]h.State{ - "test1": {Status: "ok"}, - "test2": {Status: "ok"}, + "test1": {Status: "ok", Details: "details"}, + "test2": {Status: "ok", Details: "details"}, + "failedState": {Status: "failed", Details: "details", Err: "err"}, } thresholds = map[string]int64{ "test1": 3, @@ -81,7 +86,7 @@ var _ = Describe("Healthcheck controller", func() { When("No healths are provided", func() { It("Returns UP", func() { - aggregatedHealth := c.aggregate(nil) + aggregatedHealth := c.aggregate(ctx, nil) Expect(aggregatedHealth.Status).To(Equal(health.StatusUp)) }) }) @@ -90,7 +95,7 @@ var _ = Describe("Healthcheck controller", func() { It("Returns DOWN", func() { healths["test3"] = h.State{Status: "failed", Fatal: true, ContiguousFailures: 4} c.thresholds["test3"] = 3 - aggregatedHealth := c.aggregate(healths) + aggregatedHealth := c.aggregate(ctx, healths) Expect(aggregatedHealth.Status).To(Equal(health.StatusDown)) }) }) @@ -98,7 +103,7 @@ var _ = Describe("Healthcheck controller", func() { When("At least one health is DOWN and is not Fatal", func() { It("Returns UP", func() { healths["test3"] = h.State{Status: "failed", Fatal: false, ContiguousFailures: 4} - aggregatedHealth := c.aggregate(healths) + aggregatedHealth := c.aggregate(ctx, healths) Expect(aggregatedHealth.Status).To(Equal(health.StatusUp)) }) }) @@ -107,21 +112,34 @@ var _ = Describe("Healthcheck controller", func() { It("Returns UP", func() { healths["test3"] = h.State{Status: "failed"} c.thresholds["test3"] = 3 - aggregatedHealth := c.aggregate(healths) + aggregatedHealth := c.aggregate(ctx, healths) Expect(aggregatedHealth.Status).To(Equal(health.StatusUp)) }) }) When("All healths are UP", func() { It("Returns UP", func() { - aggregatedHealth := c.aggregate(healths) + aggregatedHealth := c.aggregate(ctx, healths) Expect(aggregatedHealth.Status).To(Equal(health.StatusUp)) }) }) - When("Aggregating healths", func() { - It("Includes them as overall details", func() { - aggregatedHealth := c.aggregate(healths) + When("Aggregating health as unauthorized user", func() { + It("should strip details and error", func() { + aggregatedHealth := c.aggregate(ctx, healths) + for name, h := range healths { + h.Status = convertStatus(h.Status) + h.Details = nil + h.Err = "" + Expect(aggregatedHealth.Details[name]).To(Equal(h)) + } + }) + }) + + When("Aggregating health as authorized user", func() { + It("should include all details and errors", func() { + ctx = web.ContextWithAuthorization(ctx) + aggregatedHealth := c.aggregate(ctx, healths) for name, h := range healths { h.Status = convertStatus(h.Status) Expect(aggregatedHealth.Details[name]).To(Equal(h)) diff --git a/api/healthcheck/platform_indicator.go b/api/healthcheck/platform_indicator.go new file mode 100644 index 000000000..6688090b8 --- /dev/null +++ b/api/healthcheck/platform_indicator.go @@ -0,0 +1,83 @@ +/* + * Copyright 2018 The Service Manager Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package healthcheck + +import ( + "context" + "fmt" + "github.com/Peripli/service-manager/pkg/health" + "github.com/Peripli/service-manager/pkg/types" + "github.com/Peripli/service-manager/storage" +) + +// NewPlatformIndicator returns new health indicator for platforms of given type +func NewPlatformIndicator(ctx context.Context, repository storage.Repository, fatal func(*types.Platform) bool) health.Indicator { + if fatal == nil { + fatal = func(platform *types.Platform) bool { + return true + } + } + return &platformIndicator{ + ctx: ctx, + repository: repository, + fatal: fatal, + } +} + +type platformIndicator struct { + repository storage.Repository + ctx context.Context + fatal func(*types.Platform) bool +} + +// Name returns the name of the indicator +func (pi *platformIndicator) Name() string { + return health.PlatformsIndicatorName +} + +// Status returns status of the health check +func (pi *platformIndicator) Status() (interface{}, error) { + objList, err := pi.repository.List(pi.ctx, types.PlatformType) + if err != nil { + return nil, fmt.Errorf("could not fetch platforms health from storage: %v", err) + } + platforms := objList.(*types.Platforms).Platforms + + details := make(map[string]*health.Health) + inactivePlatforms := 0 + fatalInactivePlatforms := 0 + for _, platform := range platforms { + if platform.Active { + details[platform.Name] = health.New().WithStatus(health.StatusUp). + WithDetail("type", platform.Type) + } else { + details[platform.Name] = health.New().WithStatus(health.StatusDown). + WithDetail("since", platform.LastActive). + WithDetail("type", platform.Type) + inactivePlatforms++ + if pi.fatal(platform) { + fatalInactivePlatforms++ + } + } + } + + if fatalInactivePlatforms > 0 { + err = fmt.Errorf("there are %d inactive platforms %d of them are fatal", inactivePlatforms, fatalInactivePlatforms) + } + + return details, err +} diff --git a/api/healthcheck/platform_indicator_test.go b/api/healthcheck/platform_indicator_test.go new file mode 100644 index 000000000..94ff4d4ea --- /dev/null +++ b/api/healthcheck/platform_indicator_test.go @@ -0,0 +1,91 @@ +/* + * Copyright 2018 The Service Manager Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package healthcheck + +import ( + "context" + "errors" + "github.com/Peripli/service-manager/pkg/health" + "github.com/Peripli/service-manager/pkg/types" + storagefakes2 "github.com/Peripli/service-manager/storage/storagefakes" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "time" +) + +var _ = Describe("Platforms Indicator", func() { + var indicator health.Indicator + var repository *storagefakes2.FakeStorage + var ctx context.Context + var platform *types.Platform + + BeforeEach(func() { + ctx = context.TODO() + repository = &storagefakes2.FakeStorage{} + platform = &types.Platform{ + Name: "test-platform", + Type: "kubernetes", + Active: false, + LastActive: time.Now(), + } + indicator = NewPlatformIndicator(ctx, repository, nil) + }) + + Context("Name", func() { + It("should not be empty", func() { + Expect(indicator.Name()).Should(Equal(health.PlatformsIndicatorName)) + }) + }) + + Context("There are inactive platforms", func() { + BeforeEach(func() { + objectList := &types.Platforms{[]*types.Platform{platform}} + repository.ListReturns(objectList, nil) + }) + It("should return error", func() { + details, err := indicator.Status() + health := details.(map[string]*health.Health)[platform.Name] + Expect(err).Should(HaveOccurred()) + Expect(health.Details["since"]).ShouldNot(BeNil()) + }) + }) + + Context("Storage returns error", func() { + var expectedErr error + BeforeEach(func() { + expectedErr = errors.New("storage err") + repository.ListReturns(nil, expectedErr) + }) + It("should return error", func() { + _, err := indicator.Status() + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring(expectedErr.Error())) + }) + }) + + Context("All platforms are active", func() { + BeforeEach(func() { + platform.Active = true + objectList := &types.Platforms{[]*types.Platform{platform}} + repository.ListReturns(objectList, nil) + }) + It("should not return error", func() { + _, err := indicator.Status() + Expect(err).ShouldNot(HaveOccurred()) + }) + }) +}) diff --git a/pkg/env/env_test.go b/pkg/env/env_test.go index f38eb03ef..bdbee1634 100644 --- a/pkg/env/env_test.go +++ b/pkg/env/env_test.go @@ -19,12 +19,11 @@ package env_test import ( "context" "fmt" - "github.com/Peripli/service-manager/pkg/log" + "github.com/fatih/structs" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/sirupsen/logrus" - "io/ioutil" "os" "testing" @@ -35,7 +34,6 @@ import ( "github.com/Peripli/service-manager/config" "github.com/Peripli/service-manager/pkg/env" - "github.com/fatih/structs" "github.com/spf13/cast" "github.com/spf13/pflag" "gopkg.in/yaml.v2" @@ -48,6 +46,7 @@ func TestEnv(t *testing.T) { var _ = Describe("Env", func() { const ( + mapKey = "mapkey" key = "key" description = "desc" flagDefaultValue = "pflagDefaultValue" @@ -67,6 +66,18 @@ var _ = Describe("Env", func() { keyNslice = "nest.nslice" keyNmappedVal = "nest.n_mapped_val" + keySquashNbool = "nbool" + keySquashNint = "nint" + keySquashNstring = "nstring" + keySquashNslice = "nslice" + keySquashNmappedVal = "n_mapped_val" + + keyMapNbool = "wmapnest" + "." + mapKey + "." + "nbool" + keyMapNint = "wmapnest" + "." + mapKey + "." + "nint" + keyMapNstring = "wmapnest" + "." + mapKey + "." + "nstring" + keyMapNslice = "wmapnest" + "." + mapKey + "." + "nslice" + keyMapNmappedVal = "wmapnest" + "." + mapKey + "." + "n_mapped_val" + keyLogFormat = "log.format" keyLogLevel = "log.level" ) @@ -83,17 +94,38 @@ var _ = Describe("Env", func() { WInt int WString string WMappedVal string `mapstructure:"w_mapped_val" structs:"w_mapped_val" yaml:"w_mapped_val"` + WMapNest map[string]Nest Nest Nest + Squash Nest `mapstructure:",squash"` Log log.Settings } + type FlatOuter struct { + WBool bool + WInt int + WString string + WMappedVal string `mapstructure:"w_mapped_val" structs:"w_mapped_val" yaml:"w_mapped_val"` + WMapNest map[string]Nest + Nest Nest + + // Flattened Nest fields due to squash tag + NBool bool + NInt int + NString string + NSlice []string + NMappedVal string `mapstructure:"n_mapped_val" structs:"n_mapped_val" yaml:"n_mapped_val"` + + Log log.Settings + } + type testFile struct { env.File content interface{} } var ( - structure Outer + outer Outer + flatOuter FlatOuter cfgFile testFile testFlags *pflag.FlagSet @@ -117,12 +149,24 @@ var _ = Describe("Env", func() { set.String(keyWstring, s.WString, description) set.String(keyWmappedVal, s.WMappedVal, description) + set.Bool(keySquashNbool, s.Squash.NBool, description) + set.Int(keySquashNint, s.Squash.NInt, description) + set.String(keySquashNstring, s.Squash.NString, description) + set.StringSlice(keySquashNslice, s.Squash.NSlice, description) + set.String(keySquashNmappedVal, s.Squash.NMappedVal, description) + set.Bool(keyNbool, s.Nest.NBool, description) set.Int(keyNint, s.Nest.NInt, description) set.String(keyNstring, s.Nest.NString, description) set.StringSlice(keyNslice, s.Nest.NSlice, description) set.String(keyNmappedVal, s.Nest.NMappedVal, description) + set.Bool(keyMapNbool, s.WMapNest[mapKey].NBool, description) + set.Int(keyMapNint, s.WMapNest[mapKey].NInt, description) + set.String(keyMapNstring, s.WMapNest[mapKey].NString, description) + set.StringSlice(keyMapNslice, s.WMapNest[mapKey].NSlice, description) + set.String(keyMapNmappedVal, s.WMapNest[mapKey].NMappedVal, description) + set.String(keyLogLevel, s.Log.Level, description) set.String(keyLogFormat, s.Log.Format, description) @@ -142,29 +186,52 @@ var _ = Describe("Env", func() { Expect(testFlags.Set(keyWstring, o.WString)).ShouldNot(HaveOccurred()) Expect(testFlags.Set(keyWmappedVal, o.WMappedVal)).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keySquashNbool, cast.ToString(o.Squash.NBool))).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keySquashNint, cast.ToString(o.Squash.NInt))).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keySquashNstring, o.Squash.NString)).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keySquashNslice, strings.Join(o.Squash.NSlice, ","))).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keySquashNmappedVal, o.Squash.NMappedVal)).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keyNbool, cast.ToString(o.Nest.NBool))).ShouldNot(HaveOccurred()) Expect(testFlags.Set(keyNint, cast.ToString(o.Nest.NInt))).ShouldNot(HaveOccurred()) Expect(testFlags.Set(keyNstring, o.Nest.NString)).ShouldNot(HaveOccurred()) Expect(testFlags.Set(keyNmappedVal, o.Nest.NMappedVal)).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keyMapNbool, cast.ToString(o.WMapNest[mapKey].NBool))).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keyMapNint, cast.ToString(o.WMapNest[mapKey].NInt))).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keyMapNstring, o.WMapNest[mapKey].NString)).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keyMapNmappedVal, o.WMapNest[mapKey].NMappedVal)).ShouldNot(HaveOccurred()) + Expect(testFlags.Set(keyLogFormat, o.Log.Format)).ShouldNot(HaveOccurred()) Expect(testFlags.Set(keyLogLevel, o.Log.Level)).ShouldNot(HaveOccurred()) } setEnvVars := func() { - Expect(os.Setenv(strings.ToTitle(keyWbool), cast.ToString(structure.WBool))).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.ToTitle(keyWint), cast.ToString(structure.WInt))).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.ToTitle(keyWstring), structure.WString)).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.ToTitle(keyWmappedVal), structure.WMappedVal)).ShouldNot(HaveOccurred()) - - Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNbool), ".", "_", 1), cast.ToString(structure.Nest.NBool))).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNint), ".", "_", 1), cast.ToString(structure.Nest.NInt))).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNstring), ".", "_", 1), structure.Nest.NString)).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNslice), ".", "_", 1), strings.Join(structure.Nest.NSlice, ","))).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNmappedVal), ".", "_", 1), structure.Nest.NMappedVal)).ShouldNot(HaveOccurred()) - - Expect(os.Setenv(strings.Replace(strings.ToTitle(keyLogFormat), ".", "_", 1), structure.Log.Format)).ShouldNot(HaveOccurred()) - Expect(os.Setenv(strings.Replace(strings.ToTitle(keyLogLevel), ".", "_", 1), structure.Log.Level)).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keyWbool), cast.ToString(outer.WBool))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keyWint), cast.ToString(outer.WInt))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keyWstring), outer.WString)).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keyWmappedVal), outer.WMappedVal)).ShouldNot(HaveOccurred()) + + Expect(os.Setenv(strings.ToTitle(keySquashNbool), cast.ToString(outer.Squash.NBool))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keySquashNint), cast.ToString(outer.Squash.NInt))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keySquashNstring), outer.Squash.NString)).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keySquashNslice), strings.Join(outer.Squash.NSlice, ","))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.ToTitle(keySquashNmappedVal), outer.Squash.NMappedVal)).ShouldNot(HaveOccurred()) + + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNbool), ".", "_", -1), cast.ToString(outer.Nest.NBool))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNint), ".", "_", -1), cast.ToString(outer.Nest.NInt))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNstring), ".", "_", -1), outer.Nest.NString)).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNslice), ".", "_", -1), strings.Join(outer.Nest.NSlice, ","))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyNmappedVal), ".", "_", -1), outer.Nest.NMappedVal)).ShouldNot(HaveOccurred()) + + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyMapNbool), ".", "_", -1), cast.ToString(outer.WMapNest[mapKey].NBool))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyMapNint), ".", "_", -1), cast.ToString(outer.WMapNest[mapKey].NInt))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyMapNstring), ".", "_", -1), outer.WMapNest[mapKey].NString)).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyMapNslice), ".", "_", -1), strings.Join(outer.WMapNest[mapKey].NSlice, ","))).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyMapNmappedVal), ".", "_", -1), outer.WMapNest[mapKey].NMappedVal)).ShouldNot(HaveOccurred()) + + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyLogFormat), ".", "_", -1), outer.Log.Format)).ShouldNot(HaveOccurred()) + Expect(os.Setenv(strings.Replace(strings.ToTitle(keyLogLevel), ".", "_", -1), outer.Log.Level)).ShouldNot(HaveOccurred()) } cleanUpEnvVars := func() { @@ -173,14 +240,26 @@ var _ = Describe("Env", func() { Expect(os.Unsetenv(strings.ToTitle(keyWstring))).ShouldNot(HaveOccurred()) Expect(os.Unsetenv(strings.ToTitle(keyWmappedVal))).ShouldNot(HaveOccurred()) - Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNbool), ".", "_", 1))).ShouldNot(HaveOccurred()) - Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNint), ".", "_", 1))).ShouldNot(HaveOccurred()) - Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNstring), ".", "_", 1))).ShouldNot(HaveOccurred()) - Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNslice), ".", "_", 1))).ShouldNot(HaveOccurred()) - Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNmappedVal), ".", "_", 1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.ToTitle(keySquashNbool))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.ToTitle(keySquashNint))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.ToTitle(keySquashNstring))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.ToTitle(keySquashNslice))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.ToTitle(keySquashNmappedVal))).ShouldNot(HaveOccurred()) - Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyLogFormat), ".", "_", 1))).ShouldNot(HaveOccurred()) - Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyLogLevel), ".", "_", 1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNbool), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNint), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNstring), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNslice), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyNmappedVal), ".", "_", -1))).ShouldNot(HaveOccurred()) + + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyMapNbool), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyMapNint), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyMapNstring), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyMapNslice), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyMapNmappedVal), ".", "_", -1))).ShouldNot(HaveOccurred()) + + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyLogFormat), ".", "_", -1))).ShouldNot(HaveOccurred()) + Expect(os.Unsetenv(strings.Replace(strings.ToTitle(keyLogLevel), ".", "_", -1))).ShouldNot(HaveOccurred()) Expect(os.Unsetenv(strings.ToTitle(key))).ShouldNot(HaveOccurred()) } @@ -214,55 +293,77 @@ var _ = Describe("Env", func() { Expect(createEnv()).ShouldNot(HaveOccurred()) } - verifyEnvContainsValues := func(expected interface{}) { - props := structs.Map(expected) - for key, expectedValue := range props { - switch v := expectedValue.(type) { - case map[string]interface{}: - for nestedKey, nestedExpectedValue := range v { - expectedValue, ok := nestedExpectedValue.([]string) - if ok { - nestedExpectedValue = strings.Join(expectedValue, ",") - } - - envValue := environment.Get(key + "." + nestedKey) - switch actualValue := envValue.(type) { - case []string: - envValue = strings.Join(actualValue, ",") - case []interface{}: - temp := make([]string, len(actualValue)) - for i, v := range actualValue { - temp[i] = fmt.Sprint(v) - } - envValue = strings.Join(temp, ",") - } + var verifyValues func(expected map[string]interface{}, prefix string) - Expect(cast.ToString(envValue)).Should(Equal(cast.ToString(nestedExpectedValue))) + verifyValues = func(fields map[string]interface{}, prefix string) { + for name, value := range fields { + switch v := value.(type) { + case map[string]interface{}: + verifyValues(v, prefix+name+".") + case []string: + switch envVar := environment.Get(prefix + name).(type) { + case string: + Expect(envVar).To(Equal(strings.Join(v, ","))) + case []string, []interface{}: + Expect(fmt.Sprint(envVar)).To(Equal(fmt.Sprint(v))) + default: + Fail(fmt.Sprintf("Expected env value of type []string but got: %T", envVar)) } default: - Expect(cast.ToString(environment.Get(key))).To(Equal(cast.ToString(expectedValue))) + Expect(cast.ToString(environment.Get(prefix+name))).To(Equal(cast.ToString(v)), prefix+name) } } } + verifyEnvContainsValues := func(expected interface{}) { + fields := structs.Map(expected) + verifyValues(fields, "") + } + BeforeEach(func() { testFlags = env.EmptyFlagSet() - structure = Outer{ + nest := Nest{ + NBool: true, + NInt: 4321, + NString: "nstringval", + NSlice: []string{"nval1", "nval2", "nval3"}, + NMappedVal: "nmappedval", + } + + outer = Outer{ WBool: true, WInt: 1234, WString: "wstringval", WMappedVal: "wmappedval", + Squash: nest, Log: log.Settings{ Level: "error", Format: "text", }, - Nest: Nest{ - NBool: true, - NInt: 4321, - NString: "nstringval", - NSlice: []string{"nval1", "nval2", "nval3"}, - NMappedVal: "nmappedval", + Nest: nest, + WMapNest: map[string]Nest{ + mapKey: nest, + }, + } + + flatOuter = FlatOuter{ + WBool: true, + WInt: 1234, + WString: "wstringval", + WMappedVal: "wmappedval", + NBool: true, + NInt: 4321, + NString: "nstringval", + NSlice: []string{"nval1", "nval2", "nval3"}, + NMappedVal: "nmappedval", + Log: log.Settings{ + Level: "error", + Format: "text", + }, + Nest: nest, + WMapNest: map[string]Nest{ + mapKey: nest, }, } }) @@ -280,19 +381,19 @@ var _ = Describe("Env", func() { ) It("adds viper bindings for the provided flags", func() { - testFlags.AddFlagSet(standardPFlagsSet(structure)) + testFlags.AddFlagSet(standardPFlagsSet(outer)) cfgFile.content = nil verifyEnvCreated() - verifyEnvContainsValues(structure) + verifyEnvContainsValues(flatOuter) }) Context("when SM config file exists", func() { BeforeEach(func() { cfgFile = testFile{ File: env.DefaultConfigFile(), - content: structure, + content: flatOuter, } }) @@ -348,7 +449,7 @@ var _ = Describe("Env", func() { It("reads the file in the environment", func() { verifyEnvCreated() - verifyEnvContainsValues(structure) + verifyEnvContainsValues(flatOuter) }) It("returns an err if config file loading fails", func() { @@ -370,7 +471,7 @@ var _ = Describe("Env", func() { Expect(log.D().Logger.Out.(*os.File).Name()).ToNot(Equal(newOutput)) f := cfgFile.Location + string(filepath.Separator) + cfgFile.Name + "." + cfgFile.Format - fileContent := cfgFile.content.(Outer) + fileContent := cfgFile.content.(FlatOuter) fileContent.Log.Level = logrus.DebugLevel.String() fileContent.Log.Output = newOutput cfgFile.content = fileContent @@ -425,20 +526,37 @@ var _ = Describe("Env", func() { Describe("Get", func() { var overrideStructure Outer + var overrideStructureOutput FlatOuter BeforeEach(func() { + nest := Nest{ + NBool: false, + NInt: 9999, + NString: "overrideval", + NSlice: []string{"nval1", "nval2", "nval3"}, + NMappedVal: "overrideval", + } + overrideStructure = Outer{ WBool: false, WInt: 8888, WString: "overrideval", WMappedVal: "overrideval", - Nest: Nest{ - NBool: false, - NInt: 9999, - NString: "overrideval", - NSlice: []string{"nval1", "nval2", "nval3"}, - NMappedVal: "overrideval", - }, + Nest: nest, + Squash: nest, + } + + overrideStructureOutput = FlatOuter{ + WBool: false, + WInt: 8888, + WString: "overrideval", + WMappedVal: "overrideval", + Nest: nest, + NBool: false, + NInt: 9999, + NString: "overrideval", + NSlice: []string{"nval1", "nval2", "nval3"}, + NMappedVal: "overrideval", } }) @@ -452,34 +570,34 @@ var _ = Describe("Env", func() { Context("when properties are loaded via standard pflags", func() { BeforeEach(func() { - testFlags.AddFlagSet(standardPFlagsSet(structure)) + testFlags.AddFlagSet(standardPFlagsSet(outer)) }) It("returns the default flag value if the flag is not set", func() { - verifyEnvContainsValues(structure) + verifyEnvContainsValues(flatOuter) }) It("returns the flags values if the flags are set", func() { setPFlags(overrideStructure) - verifyEnvContainsValues(overrideStructure) + verifyEnvContainsValues(overrideStructureOutput) }) }) Context("when properties are loaded via generated pflags", func() { BeforeEach(func() { - testFlags.AddFlagSet(generatedPFlagsSet(structure)) + testFlags.AddFlagSet(generatedPFlagsSet(outer)) }) It("returns the default flag value if the flag is not set", func() { - verifyEnvContainsValues(structure) + verifyEnvContainsValues(flatOuter) }) It("returns the flags values if the flags are set", func() { setPFlags(overrideStructure) - verifyEnvContainsValues(overrideStructure) + verifyEnvContainsValues(overrideStructureOutput) }) }) @@ -487,14 +605,14 @@ var _ = Describe("Env", func() { BeforeEach(func() { cfgFile = testFile{ File: env.DefaultConfigFile(), - content: structure, + content: flatOuter, } config.AddPFlags(testFlags) verifyEnvCreated() }) It("returns values from the config file", func() { - verifyEnvContainsValues(structure) + verifyEnvContainsValues(flatOuter) }) }) @@ -504,7 +622,7 @@ var _ = Describe("Env", func() { }) It("returns values from environment", func() { - verifyEnvContainsValues(structure) + verifyEnvContainsValues(flatOuter) }) }) @@ -552,9 +670,9 @@ var _ = Describe("Env", func() { It("has highest priority", func() { testFlags.AddFlagSet(singlePFlagSet(key, flagDefaultValue, description)) - os.Setenv(key, envValue) + Expect(os.Setenv(key, envValue)).ToNot(HaveOccurred()) verifyEnvCreated() - testFlags.Set(key, flagValue) + Expect(testFlags.Set(key, flagValue)).ToNot(HaveOccurred()) environment.Set(key, overrideValue) @@ -566,7 +684,11 @@ var _ = Describe("Env", func() { var actual Outer BeforeEach(func() { - actual = Outer{} + actual = Outer{ + WMapNest: map[string]Nest{ + mapKey: {}, + }, + } }) JustBeforeEach(func() { @@ -596,21 +718,21 @@ var _ = Describe("Env", func() { Context("when properties are loaded via standard pflags", func() { BeforeEach(func() { - testFlags.AddFlagSet(standardPFlagsSet(structure)) + testFlags.AddFlagSet(standardPFlagsSet(outer)) }) It("unmarshals correctly", func() { - verifyUnmarshallingIsCorrect(&actual, &structure) + verifyUnmarshallingIsCorrect(&actual, &outer) }) }) Context("when properties are loaded via generated pflags", func() { BeforeEach(func() { - testFlags.AddFlagSet(generatedPFlagsSet(structure)) + testFlags.AddFlagSet(generatedPFlagsSet(outer)) }) It("unmarshals correctly", func() { - verifyUnmarshallingIsCorrect(&actual, &structure) + verifyUnmarshallingIsCorrect(&actual, &outer) }) }) @@ -618,13 +740,13 @@ var _ = Describe("Env", func() { BeforeEach(func() { cfgFile = testFile{ File: env.DefaultConfigFile(), - content: structure, + content: flatOuter, } config.AddPFlags(testFlags) }) It("unmarshals correctly", func() { - verifyUnmarshallingIsCorrect(&actual, &structure) + verifyUnmarshallingIsCorrect(&actual, &outer) }) }) @@ -634,7 +756,7 @@ var _ = Describe("Env", func() { }) It("unmarshals correctly", func() { - verifyUnmarshallingIsCorrect(&actual, &structure) + verifyUnmarshallingIsCorrect(&actual, &outer) }) }) diff --git a/pkg/env/helpers.go b/pkg/env/helpers.go index 3e6930fb3..61322aba7 100644 --- a/pkg/env/helpers.go +++ b/pkg/env/helpers.go @@ -89,47 +89,63 @@ func buildDescriptionPaths(root *descriptionTree, path []*descriptionTree) []str } func buildDescriptionTreeWithParameters(value interface{}, tree *descriptionTree, buffer string, result *[]configurationParameter) { - if !structs.IsStruct(value) { - index := strings.LastIndex(buffer, ".") - if index == -1 { - index = 0 - } - key := strings.ToLower(buffer[0:index]) - *result = append(*result, configurationParameter{Name: key, DefaultValue: value}) - tree.Children = nil - return - } - s := structs.New(value) - k := 0 - for _, field := range s.Fields() { - if isValidField(field) { - var name string - if field.Tag("mapstructure") != "" { - name = field.Tag("mapstructure") - } else { - name = field.Name() + v := reflect.ValueOf(value) + switch v.Kind() { + case reflect.Map: + for _, key := range v.MapKeys() { + field := v.MapIndex(key).Interface() + if isValidField(field) { + name := key.String() + buffer += name + "." + buildDescriptionTreeWithParameters(field, tree, buffer, result) + buffer = buffer[0:strings.LastIndex(buffer, name)] } - - if name == "-" || name == ",squash" { - continue + } + default: + if !structs.IsStruct(value) { + index := strings.LastIndex(buffer, ".") + if index == -1 { + index = 0 } - - description := "" - if field.Tag("description") != "" { - description = field.Tag("description") + key := strings.ToLower(buffer[0:index]) + *result = append(*result, configurationParameter{Name: key, DefaultValue: value}) + tree.Children = nil + return + } + s := structs.New(value) + for _, field := range s.Fields() { + if isValidField(field) { + var name string + if field.Tag("mapstructure") != "" { + name = field.Tag("mapstructure") + } else { + name = field.Name() + } + if name == "-" { + continue + } + if name == ",squash" { + buildDescriptionTreeWithParameters(field.Value(), tree, buffer, result) + continue + } + description := "" + if field.Tag("description") != "" { + description = field.Tag("description") + } + + baseTree := newDescriptionTree(description) + tree.AddNode(baseTree) + buildDescriptionTreeWithParameters(field.Value(), baseTree, buffer+name+".", result) } - - baseTree := newDescriptionTree(description) - tree.AddNode(baseTree) - buffer += name + "." - buildDescriptionTreeWithParameters(field.Value(), tree.Children[k], buffer, result) - k++ - buffer = buffer[0:strings.LastIndex(buffer, name)] } } } -func isValidField(field *structs.Field) bool { - kind := field.Kind() - return field.IsExported() && kind != reflect.Interface && kind != reflect.Func +func isValidField(field interface{}) bool { + if fieldStruct, ok := field.(*structs.Field); ok { + kind := fieldStruct.Kind() + return fieldStruct.IsExported() && kind != reflect.Interface && kind != reflect.Func + } + kind := reflect.ValueOf(field).Kind() + return kind != reflect.Interface && kind != reflect.Func } diff --git a/pkg/health/types.go b/pkg/health/types.go index dac0ec81d..dbf5ac851 100644 --- a/pkg/health/types.go +++ b/pkg/health/types.go @@ -26,16 +26,36 @@ import ( "time" ) +// StorageIndicatorName is the name of storage indicator +const StorageIndicatorName = "storage" + +// PlatformsIndicatorName is the name of platforms indicator +const PlatformsIndicatorName = "platforms" + +// indicatorNames is a list of names of indicators which will be registered with default settings +// as part of default health settings, this will allow binding them as part of environment. +// If an indicator is registered but not specified in this list, it will be configured with +// default settings again, but this defaults could be overridden only via application.yml, +// env variables and pflags won't have any effect. If an indicator is specified in this list +// but later not registered nothing will happen. +var indicatorNames = [...]string{ + StorageIndicatorName, + PlatformsIndicatorName, +} + // Settings type to be loaded from the environment type Settings struct { - Indicators map[string]*IndicatorSettings `mapstructure:"indicators,omitempty"` + Indicators map[string]*IndicatorSettings `mapstructure:"indicators"` } // DefaultSettings returns default values for health settings func DefaultSettings() *Settings { - emptySettings := make(map[string]*IndicatorSettings) + defaultIndicatorSettings := make(map[string]*IndicatorSettings) + for _, name := range indicatorNames { + defaultIndicatorSettings[name] = DefaultIndicatorSettings() + } return &Settings{ - Indicators: emptySettings, + Indicators: defaultIndicatorSettings, } } @@ -68,7 +88,7 @@ func DefaultIndicatorSettings() *IndicatorSettings { // Validate validates indicator settings func (is *IndicatorSettings) Validate() error { if !is.Fatal && is.FailuresThreshold != 0 { - return fmt.Errorf("validate Settings: FailuresThreshold not applicable for non-fatal indicators") + return fmt.Errorf("validate Settings: FailuresThreshold must be 0 for non-fatal indicators") } if is.Fatal && is.FailuresThreshold <= 0 { return fmt.Errorf("validate Settings: FailuresThreshold must be > 0 for fatal indicators") @@ -94,7 +114,12 @@ const ( type StatusListener struct{} func (sl *StatusListener) HealthCheckFailed(state *health.State) { - log.D().Errorf("Health check for %v failed with: %v", state.Name, state.Err) + msg := fmt.Sprintf("Health check for %v failed with: %v", state.Name, state.Err) + if state.Fatal { + log.D().Error(msg) + } else { + log.D().Warn(msg) + } } func (sl *StatusListener) HealthCheckRecovered(state *health.State, numberOfFailures int64, unavailableDuration float64) { @@ -164,6 +189,17 @@ type Registry struct { HealthIndicators []Indicator } +// SetIndicator adds or replaces existing indicator with same name in registry +func (r *Registry) SetIndicator(healthIndicator Indicator) { + for i, indicator := range r.HealthIndicators { + if indicator.Name() == healthIndicator.Name() { + r.HealthIndicators[i] = healthIndicator + return + } + } + r.HealthIndicators = append(r.HealthIndicators, healthIndicator) +} + // Configure creates new health using provided settings. func Configure(ctx context.Context, indicators []Indicator, settings *Settings) (*h.Health, map[string]int64, error) { healthz := h.New() diff --git a/pkg/sm/sm.go b/pkg/sm/sm.go index 97808b447..f13496209 100644 --- a/pkg/sm/sm.go +++ b/pkg/sm/sm.go @@ -112,7 +112,7 @@ func New(ctx context.Context, cancel context.CancelFunc, cfg *config.Settings) ( // Setup core API log.C(ctx).Info("Setting up Service Manager core API...") - pgNotificator, err := postgres.NewNotificator(smStorage, cfg.Storage) + pgNotificator, err := postgres.NewNotificator(smStorage, interceptableRepository, cfg.Storage) if err != nil { return nil, fmt.Errorf("could not create notificator: %v", err) } @@ -133,7 +133,8 @@ func New(ctx context.Context, cancel context.CancelFunc, cfg *config.Settings) ( return nil, fmt.Errorf("error creating storage health indicator: %s", err) } - API.HealthIndicators = append(API.HealthIndicators, storageHealthIndicator) + API.SetIndicator(storageHealthIndicator) + API.SetIndicator(healthcheck.NewPlatformIndicator(ctx, interceptableRepository, nil)) notificationCleaner := &storage.NotificationCleaner{ Storage: interceptableRepository, diff --git a/pkg/types/platform.go b/pkg/types/platform.go index 0653b6e69..4372aeda7 100644 --- a/pkg/types/platform.go +++ b/pkg/types/platform.go @@ -19,6 +19,7 @@ package types import ( "errors" "fmt" + "time" "github.com/Peripli/service-manager/pkg/util" ) @@ -32,6 +33,8 @@ type Platform struct { Name string `json:"name"` Description string `json:"description"` Credentials *Credentials `json:"credentials,omitempty"` + Active bool `json:"-"` + LastActive time.Time `json:"-"` } func (e *Platform) SetCredentials(credentials *Credentials) { diff --git a/storage/healthcheck.go b/storage/healthcheck.go index d74d07396..9ae7b53fe 100644 --- a/storage/healthcheck.go +++ b/storage/healthcheck.go @@ -45,5 +45,5 @@ type SQLHealthIndicator struct { // Name returns the name of the storage component func (i *SQLHealthIndicator) Name() string { - return "storage" + return health.StorageIndicatorName } diff --git a/storage/postgres/keystore_test.go b/storage/postgres/keystore_test.go index 93f69fe7f..8a7a13f33 100644 --- a/storage/postgres/keystore_test.go +++ b/storage/postgres/keystore_test.go @@ -59,7 +59,7 @@ var _ = Describe("Secured Storage", func() { mock.ExpectQuery(`SELECT CURRENT_DATABASE()`).WillReturnRows(sqlmock.NewRows([]string{"mock"}).FromCSVString("mock")) mock.ExpectQuery(`SELECT COUNT(1)*`).WillReturnRows(sqlmock.NewRows([]string{"mock"}).FromCSVString("1")) mock.ExpectExec("SELECT pg_advisory_lock*").WithArgs(sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectQuery(`SELECT version, dirty FROM "schema_migrations" LIMIT 1`).WillReturnRows(sqlmock.NewRows([]string{"version", "dirty"}).FromCSVString("20190816162000,false")) + mock.ExpectQuery(`SELECT version, dirty FROM "schema_migrations" LIMIT 1`).WillReturnRows(sqlmock.NewRows([]string{"version", "dirty"}).FromCSVString("20190829101500,false")) mock.ExpectExec("SELECT pg_advisory_unlock*").WithArgs(sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) options := storage.DefaultSettings() options.EncryptionKey = string(envEncryptionKey) diff --git a/storage/postgres/migrations/20190829101500_add_platforms_status.down.sql b/storage/postgres/migrations/20190829101500_add_platforms_status.down.sql new file mode 100644 index 000000000..0627176db --- /dev/null +++ b/storage/postgres/migrations/20190829101500_add_platforms_status.down.sql @@ -0,0 +1,6 @@ +BEGIN; + +ALTER TABLE platforms DROP COLUMN active; +ALTER TABLE platforms DROP COLUMN last_active; + +COMMIT; \ No newline at end of file diff --git a/storage/postgres/migrations/20190829101500_add_platforms_status.up.sql b/storage/postgres/migrations/20190829101500_add_platforms_status.up.sql new file mode 100644 index 000000000..aa3ccbcb4 --- /dev/null +++ b/storage/postgres/migrations/20190829101500_add_platforms_status.up.sql @@ -0,0 +1,6 @@ +BEGIN; + +ALTER TABLE platforms ADD COLUMN active boolean NOT NULL DEFAULT '0'; +ALTER TABLE platforms ADD COLUMN last_active TIMESTAMP NOT NULL DEFAULT '0001-01-01 00:00:00+00'; + +COMMIT; \ No newline at end of file diff --git a/storage/postgres/notificator.go b/storage/postgres/notificator.go index 166aac41c..3477d3901 100644 --- a/storage/postgres/notificator.go +++ b/storage/postgres/notificator.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/Peripli/service-manager/pkg/query" "sync" "sync/atomic" "time" @@ -66,7 +67,7 @@ type Notificator struct { } // NewNotificator returns new Notificator based on a given NotificatorStorage and desired queue size -func NewNotificator(st storage.Storage, settings *storage.Settings) (*Notificator, error) { +func NewNotificator(st storage.Storage, repository storage.TransactionalRepository, settings *storage.Settings) (*Notificator, error) { ns, err := NewNotificationStorage(st) connectionCreator := ¬ificationConnectionCreatorImpl{ storageURI: settings.URI, @@ -82,8 +83,9 @@ func NewNotificator(st storage.Storage, settings *storage.Settings) (*Notificato connectionMutex: &sync.Mutex{}, consumersMutex: &sync.Mutex{}, consumers: &consumers{ - queues: make(map[string][]storage.NotificationQueue), - platforms: make([]*types.Platform, 0), + repository: repository, + queues: make(map[string][]storage.NotificationQueue), + platforms: make([]*types.Platform, 0), }, storage: ns, connectionCreator: connectionCreator, @@ -99,6 +101,7 @@ func (n *Notificator) Start(ctx context.Context, group *sync.WaitGroup) error { return errors.New("notificator already started") } n.ctx = ctx + n.consumers.ctx = ctx n.setConnection(n.connectionCreator.NewConnection(func(isConnected bool, err error) { if isConnected { atomic.StoreInt32(&n.isConnected, aTrue) @@ -145,7 +148,9 @@ func (n *Notificator) addConsumer(platform *types.Platform, queue storage.Notifi } n.consumersMutex.Lock() defer n.consumersMutex.Unlock() - n.consumers.Add(platform, queue) + if err := n.consumers.Add(platform, queue); err != nil { + return types.InvalidRevision, err + } return atomic.LoadInt64(&n.lastKnownRevision), nil } @@ -259,7 +264,9 @@ func (n *Notificator) UnregisterConsumer(queue storage.NotificationQueue) error if n.consumers.Len() == 0 { return nil // Consumer already unregistered } - n.consumers.Delete(queue) + if err := n.consumers.Delete(queue); err != nil { + return err + } if n.consumers.Len() == 0 { log.C(n.ctx).Debugf("No notification consumers left. Stop listening to channel %s", postgresChannel) n.stopProcessing() // stop processing notifications as there are no consumers @@ -419,8 +426,10 @@ func (n *Notificator) stopConnection() { } type consumers struct { - queues map[string][]storage.NotificationQueue - platforms []*types.Platform + repository storage.TransactionalRepository + ctx context.Context + queues map[string][]storage.NotificationQueue + platforms []*types.Platform } func (c *consumers) find(queueID string) (string, int) { @@ -443,10 +452,10 @@ func (c *consumers) ReplaceQueue(queueID string, newQueue storage.NotificationQu return nil } -func (c *consumers) Delete(queue storage.NotificationQueue) { +func (c *consumers) Delete(queue storage.NotificationQueue) error { platformIDToDelete, queueIndex := c.find(queue.ID()) if queueIndex == -1 { - return + return nil } platformConsumers := c.queues[platformIDToDelete] c.queues[platformIDToDelete] = append(platformConsumers[:queueIndex], platformConsumers[queueIndex+1:]...) @@ -456,17 +465,32 @@ func (c *consumers) Delete(queue storage.NotificationQueue) { for index, platform := range c.platforms { if platform.ID == platformIDToDelete { c.platforms = append(c.platforms[:index], c.platforms[index+1:]...) + err := c.updatePlatform(platform.ID, func(p *types.Platform) { + p.Active = false + p.LastActive = time.Now() + }) + if err != nil { + return err + } break } } } + return nil } -func (c *consumers) Add(platform *types.Platform, queue storage.NotificationQueue) { +func (c *consumers) Add(platform *types.Platform, queue storage.NotificationQueue) error { if len(c.queues[platform.ID]) == 0 { c.platforms = append(c.platforms, platform) + err := c.updatePlatform(platform.ID, func(p *types.Platform) { + p.Active = true + }) + if err != nil { + return err + } } c.queues[platform.ID] = append(c.queues[platform.ID], queue) + return nil } func (c *consumers) Clear() map[string][]storage.NotificationQueue { @@ -492,3 +516,29 @@ func (c *consumers) GetPlatform(platformID string) *types.Platform { func (c *consumers) GetQueuesForPlatform(platformID string) []storage.NotificationQueue { return c.queues[platformID] } + +func (c *consumers) updatePlatform(platformID string, updatePlatformFunc func(p *types.Platform)) error { + if err := c.repository.InTransaction(c.ctx, func(ctx context.Context, storage storage.Repository) error { + idCriteria := query.Criterion{ + LeftOp: "id", + Operator: query.EqualsOperator, + RightOp: []string{platformID}, + Type: query.FieldQuery, + } + obj, err := storage.Get(ctx, types.PlatformType, idCriteria) + if err != nil { + return err + } + + platform := obj.(*types.Platform) + updatePlatformFunc(platform) + + if _, err := storage.Update(ctx, platform, nil); err != nil { + return err + } + return nil + }); err != nil { + return err + } + return nil +} diff --git a/storage/postgres/notificator_test.go b/storage/postgres/notificator_test.go index c3b3a3072..da9e99b85 100644 --- a/storage/postgres/notificator_test.go +++ b/storage/postgres/notificator_test.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "errors" + "github.com/Peripli/service-manager/storage/storagefakes" "sync" "time" @@ -52,7 +53,8 @@ var _ = Describe("Notificator", func() { ctx context.Context cancel context.CancelFunc wg *sync.WaitGroup - fakeStorage *postgresfakes.FakeNotificationStorage + fakeStorage *storagefakes.FakeStorage + fakeNotificationStorage *postgresfakes.FakeNotificationStorage fakeConnectionCreator *postgresfakes.FakeNotificationConnectionCreator testNotificator storage.Notificator fakeNotificationConnection *notificationConnectionFakes.FakeNotificationConnection @@ -95,10 +97,11 @@ var _ = Describe("Notificator", func() { connectionMutex: &sync.Mutex{}, consumersMutex: &sync.Mutex{}, consumers: &consumers{ - queues: make(map[string][]storage.NotificationQueue), - platforms: make([]*types.Platform, 0), + repository: fakeStorage, + queues: make(map[string][]storage.NotificationQueue), + platforms: make([]*types.Platform, 0), }, - storage: fakeStorage, + storage: fakeNotificationStorage, connectionCreator: fakeConnectionCreator, stopProcessing: func() {}, lastKnownRevision: types.InvalidRevision, @@ -149,8 +152,9 @@ var _ = Describe("Notificator", func() { ID: "platformID", }, } - fakeStorage = &postgresfakes.FakeNotificationStorage{} - fakeStorage.GetLastRevisionReturns(defaultLastRevision, nil) + fakeStorage = &storagefakes.FakeStorage{} + fakeNotificationStorage = &postgresfakes.FakeNotificationStorage{} + fakeNotificationStorage.GetLastRevisionReturns(defaultLastRevision, nil) fakeNotificationConnection = ¬ificationConnectionFakes.FakeNotificationConnection{} fakeNotificationConnection.ListenReturns(nil) fakeNotificationConnection.UnlistenReturns(nil) @@ -215,9 +219,9 @@ var _ = Describe("Notificator", func() { } }) notification = createNotification("") - fakeStorage.GetNotificationByRevisionReturns(notification, nil) - fakeStorage.ListNotificationsReturns([]*types.Notification{notification}, nil) - fakeStorage.GetNotificationReturns(notification, nil) + fakeNotificationStorage.GetNotificationByRevisionReturns(notification, nil) + fakeNotificationStorage.ListNotificationsReturns([]*types.Notification{notification}, nil) + fakeNotificationStorage.GetNotificationReturns(notification, nil) Expect(testNotificator.Start(ctx, wg)).ToNot(HaveOccurred()) runningFunc(true, nil) @@ -333,7 +337,7 @@ var _ = Describe("Notificator", func() { Context("When storage GetLastRevision fails", func() { BeforeEach(func() { - fakeStorage.GetLastRevisionReturns(types.InvalidRevision, expectedError) + fakeNotificationStorage.GetLastRevisionReturns(types.InvalidRevision, expectedError) }) It("Should return error", func() { @@ -373,7 +377,7 @@ var _ = Describe("Notificator", func() { unlistenCalled <- struct{}{} return nil } - fakeStorage.GetNotificationByRevisionReturns(nil, expectedError) + fakeNotificationStorage.GetNotificationByRevisionReturns(nil, expectedError) expectRegisterConsumerFail(expectedError.Error(), defaultLastRevision-1) expectUnlistenCalled(unlistenCalled) }) @@ -387,7 +391,7 @@ var _ = Describe("Notificator", func() { unlistenCalled <- struct{}{} return errors.New("unlisten error") } - fakeStorage.GetNotificationByRevisionReturns(nil, expectedError) + fakeNotificationStorage.GetNotificationByRevisionReturns(nil, expectedError) expectRegisterConsumerFail(expectedError.Error(), defaultLastRevision-1) expectUnlistenCalled(unlistenCalled) }) @@ -395,14 +399,14 @@ var _ = Describe("Notificator", func() { Context("When storage returns \"not found\" error when getting notification with revision", func() { It("Should return ErrInvalidNotificationRevision", func() { - fakeStorage.GetNotificationByRevisionReturns(nil, util.ErrNotFoundInStorage) + fakeNotificationStorage.GetNotificationByRevisionReturns(nil, util.ErrNotFoundInStorage) expectRegisterConsumerFail(util.ErrInvalidNotificationRevision.Error(), defaultLastRevision-1) }) }) Context("When storage returns error on notification list", func() { It("Should return the error", func() { - fakeStorage.ListNotificationsReturns(nil, expectedError) + fakeNotificationStorage.ListNotificationsReturns(nil, expectedError) expectRegisterConsumerFail(expectedError.Error(), defaultLastRevision-1) }) }) @@ -413,7 +417,7 @@ var _ = Describe("Notificator", func() { for i := 0; i < defaultQueueSize+1; i++ { notificationsToReturn = append(notificationsToReturn, createNotification("")) } - fakeStorage.ListNotificationsReturns(notificationsToReturn, nil) + fakeNotificationStorage.ListNotificationsReturns(notificationsToReturn, nil) expectRegisterConsumerFail(util.ErrInvalidNotificationRevision.Error(), defaultLastRevision-1) }) }) @@ -422,9 +426,9 @@ var _ = Describe("Notificator", func() { It("Should be in the returned queue", func() { n1 := createNotification("") n2 := createNotification("") - fakeStorage.GetNotificationByRevisionReturns(n1, nil) - fakeStorage.ListNotificationsReturns([]*types.Notification{n1}, nil) - fakeStorage.GetNotificationReturns(n2, nil) + fakeNotificationStorage.GetNotificationByRevisionReturns(n1, nil) + fakeNotificationStorage.ListNotificationsReturns([]*types.Notification{n1}, nil) + fakeNotificationStorage.GetNotificationReturns(n2, nil) queue = expectRegisterConsumerSuccess(defaultPlatform, defaultLastRevision-1) queueChannel := queue.Channel() Expect(<-queueChannel).To(Equal(n1)) @@ -477,7 +481,7 @@ var _ = Describe("Notificator", func() { Context("When notification is sent", func() { It("Should be received in the queue", func() { notification := createNotification(defaultPlatform.ID) - fakeStorage.GetNotificationReturns(notification, nil) + fakeNotificationStorage.GetNotificationReturns(notification, nil) notificationChannel <- &pq.Notification{ Extra: createNotificationPayload(defaultPlatform.ID, notification.ID), } @@ -487,7 +491,7 @@ var _ = Describe("Notificator", func() { Context("When notification cannot be fetched from db", func() { fetchNotificationFromDBFail := func(platformID string) { - fakeStorage.GetNotificationReturns(nil, expectedError) + fakeNotificationStorage.GetNotificationReturns(nil, expectedError) ch := queue.Channel() notificationChannel <- &pq.Notification{ Extra: createNotificationPayload(platformID, "some_id"), @@ -514,7 +518,7 @@ var _ = Describe("Notificator", func() { q := registerDefaultPlatform() notification := createNotification("") - fakeStorage.GetNotificationReturns(notification, nil) + fakeNotificationStorage.GetNotificationReturns(notification, nil) notificationChannel <- &pq.Notification{ Extra: createNotificationPayload("", notification.ID), } @@ -526,7 +530,7 @@ var _ = Describe("Notificator", func() { Context("When notification is sent with unregistered platform ID", func() { It("Should call storage once", func() { notification := createNotification(defaultPlatform.ID) - fakeStorage.GetNotificationReturns(notification, nil) + fakeNotificationStorage.GetNotificationReturns(notification, nil) notificationChannel <- &pq.Notification{ Extra: createNotificationPayload("not_registered", "some_id"), } @@ -534,7 +538,7 @@ var _ = Describe("Notificator", func() { Extra: createNotificationPayload(defaultPlatform.ID, notification.ID), } expectReceivedNotification(notification, queue) - Expect(fakeStorage.GetNotificationCallCount()).To(Equal(1)) + Expect(fakeNotificationStorage.GetNotificationCallCount()).To(Equal(1)) }) }) @@ -552,7 +556,7 @@ var _ = Describe("Notificator", func() { Context("When notification is null", func() { It("Should not send notification", func() { notification := createNotification(defaultPlatform.ID) - fakeStorage.GetNotificationReturns(notification, nil) + fakeNotificationStorage.GetNotificationReturns(notification, nil) notificationChannel <- nil notificationChannel <- &pq.Notification{ Extra: createNotificationPayload(defaultPlatform.ID, notification.ID), @@ -576,7 +580,7 @@ var _ = Describe("Notificator", func() { It("Should close notification queue", func() { q := registerDefaultPlatform() notification := createNotification(defaultPlatform.ID) - fakeStorage.GetNotificationReturns(notification, nil) + fakeNotificationStorage.GetNotificationReturns(notification, nil) ch := q.Channel() notificationChannel <- &pq.Notification{ Extra: createNotificationPayload(defaultPlatform.ID, notification.ID), diff --git a/storage/postgres/platform.go b/storage/postgres/platform.go index 2f43e19b1..4921c02f6 100644 --- a/storage/postgres/platform.go +++ b/storage/postgres/platform.go @@ -18,6 +18,7 @@ package postgres import ( "database/sql" + "time" "github.com/Peripli/service-manager/storage" @@ -33,6 +34,8 @@ type Platform struct { Description sql.NullString `db:"description"` Username string `db:"username"` Password string `db:"password"` + Active bool `db:"active"` + LastActive time.Time `db:"last_active"` } func (p *Platform) FromObject(object types.Object) (storage.Entity, bool) { @@ -49,6 +52,8 @@ func (p *Platform) FromObject(object types.Object) (storage.Entity, bool) { Type: platform.Type, Name: platform.Name, Description: toNullString(platform.Description), + Active: platform.Active, + LastActive: platform.LastActive, } if platform.Description != "" { @@ -77,5 +82,7 @@ func (p *Platform) ToObject() types.Object { Password: p.Password, }, }, + Active: p.Active, + LastActive: p.LastActive, } } diff --git a/test/ws_notification_test/ws_notification_test.go b/test/ws_notification_test/ws_notification_test.go index 20bd9fe11..e0b548b81 100644 --- a/test/ws_notification_test/ws_notification_test.go +++ b/test/ws_notification_test/ws_notification_test.go @@ -235,6 +235,55 @@ var _ = Describe("WS", func() { }) }) + Context("when platform is connected", func() { + It("should switch platform's active status to true", func() { + Expect(platform.Active).To(BeFalse()) + _, _, err := ctx.ConnectWebSocket(platform, queryParams) + Expect(err).ShouldNot(HaveOccurred()) + + idCriteria := query.Criterion{ + LeftOp: "id", + Operator: query.EqualsOperator, + RightOp: []string{platform.ID}, + Type: query.FieldQuery, + } + obj, err := repository.Get(context.TODO(), types.PlatformType, idCriteria) + Expect(err).ShouldNot(HaveOccurred()) + Expect(obj.(*types.Platform).Active).To(BeTrue()) + }) + }) + + Context("when platform disconnects", func() { + It("should switch platform's active status to false", func() { + conn, _, err := ctx.ConnectWebSocket(platform, queryParams) + Expect(err).ShouldNot(HaveOccurred()) + + ctx.CloseWebSocket(conn) + + idCriteria := query.Criterion{ + LeftOp: "id", + Operator: query.EqualsOperator, + RightOp: []string{platform.ID}, + Type: query.FieldQuery, + } + ctx, _ := context.WithTimeout(context.TODO(), 5*time.Second) + ticker := time.NewTicker(500 * time.Millisecond) + for { + select { + case <-ticker.C: + obj, err := repository.Get(context.TODO(), types.PlatformType, idCriteria) + Expect(err).ShouldNot(HaveOccurred()) + p := obj.(*types.Platform) + if p.Active == false && !p.LastActive.IsZero() { + return + } + case <-ctx.Done(): + Fail("Timeout: platform active status not set to false") + } + } + }) + }) + Context("when same platform is connected twice", func() { It("should send same notifications to both", func() { conn, _, err := ctx.ConnectWebSocket(platform, queryParams)