diff --git a/internal/controllers/provisioning/provisioning_controller.go b/internal/controllers/provisioning/provisioning_controller.go index 1364f18..55b2326 100644 --- a/internal/controllers/provisioning/provisioning_controller.go +++ b/internal/controllers/provisioning/provisioning_controller.go @@ -257,7 +257,7 @@ func (c *ProvisioningController) syncHandler(key string) error { return err } azureDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureDbs) - err = applyTenantOverrides(azureDbs, tenant.Name) + azureDbs, err = applyTenantOverrides(azureDbs, tenant.Name) if err != nil { return err } @@ -267,7 +267,7 @@ func (c *ProvisioningController) syncHandler(key string) error { return err } azureManagedDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureManagedDbs) - err = applyTenantOverrides(azureManagedDbs, tenant.Name) + azureManagedDbs, err = applyTenantOverrides(azureManagedDbs, tenant.Name) if err != nil { return err } @@ -277,7 +277,7 @@ func (c *ProvisioningController) syncHandler(key string) error { return err } helmReleases = selectItemsInPlatformAndDomain(platformKey, domainKey, helmReleases) - err = applyTenantOverrides(helmReleases, tenant.Name) + helmReleases, err = applyTenantOverrides(helmReleases, tenant.Name) if err != nil { return err } @@ -287,7 +287,7 @@ func (c *ProvisioningController) syncHandler(key string) error { return err } azureVirtualMachines = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualMachines) - err = applyTenantOverrides(azureVirtualMachines, tenant.Name) + azureVirtualMachines, err = applyTenantOverrides(azureVirtualMachines, tenant.Name) if err != nil { return err } @@ -297,7 +297,7 @@ func (c *ProvisioningController) syncHandler(key string) error { return err } azureVirtualDesktops = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualDesktops) - err = applyTenantOverrides(azureVirtualDesktops, tenant.Name) + azureVirtualDesktops, err = applyTenantOverrides(azureVirtualDesktops, tenant.Name) if err != nil { return err } diff --git a/internal/controllers/provisioning/provisioning_controller_test.go b/internal/controllers/provisioning/provisioning_controller_test.go index 506ef13..5b35d68 100644 --- a/internal/controllers/provisioning/provisioning_controller_test.go +++ b/internal/controllers/provisioning/provisioning_controller_test.go @@ -304,12 +304,13 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides([]*provisioningv1.AzureManagedDatabase{&db}, tenantName) + result, err := applyTenantOverrides([]*provisioningv1.AzureManagedDatabase{&db}, tenantName) if err != nil { t.Error(err) } - assert.Equal(t, "afterBackupFileName", db.Spec.RestoreFrom.BackupFileName) + assert.Len(t, result, 1) + assert.Equal(t, "afterBackupFileName", result[0].Spec.RestoreFrom.BackupFileName) }) t.Run("override helmRelease version", func(t *testing.T) { @@ -342,12 +343,13 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) + result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) if err != nil { t.Error(err) } - assert.Equal(t, hr.Spec.Release.Chart.Spec.Version, "1.1.1") + assert.Len(t, result, 1) + assert.Equal(t, "1.1.1", result[0].Spec.Release.Chart.Spec.Version) }) t.Run("override avd params", func(t *testing.T) { @@ -384,16 +386,17 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides([]*provisioningv1.AzureVirtualDesktop{&avd}, tenantName) + result, err := applyTenantOverrides([]*provisioningv1.AzureVirtualDesktop{&avd}, tenantName) if err != nil { t.Error(err) } + assert.Len(t, result, 1) assert.Equal(t, []provisioningv1.InitScriptArgs{{Name: "arg1NameAfter", Value: "arg1ValueAfter"}}, - avd.Spec.InitScriptArguments) + result[0].Spec.InitScriptArguments) - assert.Equal(t, []string{"user1After", "user2After"}, avd.Spec.Users.ApplicationUsers) - assert.Equal(t, 2, avd.Spec.VmNumberOfInstances) + assert.Equal(t, []string{"user1After", "user2After"}, result[0].Spec.Users.ApplicationUsers) + assert.Equal(t, 2, result[0].Spec.VmNumberOfInstances) }) t.Run("override contents of JSON field", func(t *testing.T) { @@ -431,13 +434,15 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) + result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) if err != nil { t.Error(err) } + assert.Len(t, result, 1) + var valuesMap map[string]any - if err := json.Unmarshal(hr.Spec.Release.Values.Raw, &valuesMap); err != nil { + if err := json.Unmarshal(result[0].Spec.Release.Values.Raw, &valuesMap); err != nil { t.Error(err) } @@ -476,14 +481,59 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) + result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) + if err != nil { + t.Error(err) + } + + assert.Len(t, result, 1) + assert.Equal(t, "", result[0].Spec.Release.TargetNamespace) + assert.Equal(t, "storageNamespaceBefore", result[0].Spec.Release.StorageNamespace) + assert.Nil(t, result[0].Spec.Release.MaxHistory) + }) + + t.Run("overrides don't mutate source", func(t *testing.T) { + tenantName := "tenant1" + overrides := map[string]any{ + "release": map[string]any{ + "releaseName": "releaseNameAfter0", + }, + } + overridesBytes, _ := json.Marshal(overrides) + + hrs := []*provisioningv1.HelmRelease{ + { + Spec: provisioningv1.HelmReleaseSpec{ + ProvisioningMeta: provisioningv1.ProvisioningMeta{ + TenantOverrides: map[string]*v1.JSON{ + tenantName: {Raw: overridesBytes}, + }, + }, + Release: v2beta1.HelmReleaseSpec{ + ReleaseName: "releaseNameBefore0", + }, + }, + }, + { + Spec: provisioningv1.HelmReleaseSpec{ + Release: v2beta1.HelmReleaseSpec{ + ReleaseName: "releaseNameBefore1", + }, + }, + }, + } + + result, err := applyTenantOverrides(hrs, tenantName) if err != nil { t.Error(err) } - assert.Equal(t, "", hr.Spec.Release.TargetNamespace) - assert.Equal(t, "storageNamespaceBefore", hr.Spec.Release.StorageNamespace) - assert.Nil(t, hr.Spec.Release.MaxHistory) + assert.Len(t, result, 2) + assert.NotSame(t, hrs[0], result[0]) + assert.Equal(t, "releaseNameBefore0", hrs[0].Spec.Release.ReleaseName) + assert.Equal(t, "releaseNameAfter0", result[0].Spec.Release.ReleaseName) + assert.Same(t, hrs[1], result[1]) + assert.Equal(t, "releaseNameBefore1", result[1].Spec.Release.ReleaseName) }) } diff --git a/internal/controllers/provisioning/provisioning_types.go b/internal/controllers/provisioning/provisioning_types.go index 80bac1a..9b1a9fa 100644 --- a/internal/controllers/provisioning/provisioning_types.go +++ b/internal/controllers/provisioning/provisioning_types.go @@ -16,7 +16,10 @@ type ProvisioningResource interface { GetSpec() any GetName() string GetNamespace() string - Clone() any +} + +type Cloner[C any] interface { + DeepCopy() C } func getPlatformAndDomain[R ProvisioningResource](res R) (platform, domain string, ok bool) { @@ -34,61 +37,70 @@ func selectItemsInPlatformAndDomain[R ProvisioningResource](platform, domain str for _, res := range source { if res.GetProvisioningMeta().PlatformRef == platform && res.GetProvisioningMeta().DomainRef == domain { - result = append(result, res.Clone().(R)) + result = append(result, res) } } return result } -func applyTenantOverrides[T ProvisioningResource](targets []T, tenantName string) error { - if targets == nil { - return nil +func applyTenantOverrides[R interface { + ProvisioningResource + Cloner[R] +}](source []R, tenantName string) ([]R, error) { + if source == nil { + return source, nil } - for _, target := range targets { - overrides := target.GetProvisioningMeta().TenantOverrides + result := []R{} + + for _, res := range source { + overrides := res.GetProvisioningMeta().TenantOverrides if overrides == nil { + result = append(result, res) continue } tenantOverridesJson, exists := overrides[tenantName] if !exists { + result = append(result, res) continue } var tenantOverridesMap map[string]any if err := json.Unmarshal(tenantOverridesJson.Raw, &tenantOverridesMap); err != nil { - return err + return nil, err } - targetSpec := target.GetSpec() - - targetSpecJsonBytes, err := json.Marshal(targetSpec) + resSpecJsonBytes, err := json.Marshal(res.GetSpec()) if err != nil { - return err + return nil, err } var targetSpecMap map[string]any - if err := json.Unmarshal(targetSpecJsonBytes, &targetSpecMap); err != nil { - return err + if err := json.Unmarshal(resSpecJsonBytes, &targetSpecMap); err != nil { + return nil, err } if err := mergo.Merge(&targetSpecMap, tenantOverridesMap, mergo.WithOverride, mergo.WithTransformers(jsonTransformer{})); err != nil { - return err + return nil, err } - targetSpecJsonBytes, err = json.Marshal(targetSpecMap) + resSpecJsonBytes, err = json.Marshal(targetSpecMap) if err != nil { - return err + return nil, err } - if err := json.Unmarshal(targetSpecJsonBytes, targetSpec); err != nil { - return err + resClone := res.DeepCopy() + + if err := json.Unmarshal(resSpecJsonBytes, resClone.GetSpec()); err != nil { + return nil, err } + + result = append(result, resClone) } - return nil + return result, nil } type jsonTransformer struct { diff --git a/pkg/apis/provisioning/v1alpha1/azureDatabaseTypes.go b/pkg/apis/provisioning/v1alpha1/azureDatabaseTypes.go index abf95b6..efbb674 100644 --- a/pkg/apis/provisioning/v1alpha1/azureDatabaseTypes.go +++ b/pkg/apis/provisioning/v1alpha1/azureDatabaseTypes.go @@ -74,7 +74,3 @@ func (db *AzureDatabase) GetProvisioningMeta() *ProvisioningMeta { func (db *AzureDatabase) GetSpec() any { return &db.Spec } - -func (db *AzureDatabase) Clone() any { - return db.DeepCopy() -} diff --git a/pkg/apis/provisioning/v1alpha1/azureManagedDatabaseTypes.go b/pkg/apis/provisioning/v1alpha1/azureManagedDatabaseTypes.go index fd9dce1..8986550 100644 --- a/pkg/apis/provisioning/v1alpha1/azureManagedDatabaseTypes.go +++ b/pkg/apis/provisioning/v1alpha1/azureManagedDatabaseTypes.go @@ -78,7 +78,3 @@ func (db *AzureManagedDatabase) GetProvisioningMeta() *ProvisioningMeta { func (db *AzureManagedDatabase) GetSpec() any { return &db.Spec } - -func (db *AzureManagedDatabase) Clone() any { - return db.DeepCopy() -} diff --git a/pkg/apis/provisioning/v1alpha1/azureVirtualDesktopTypes.go b/pkg/apis/provisioning/v1alpha1/azureVirtualDesktopTypes.go index f5346d6..8b835ec 100644 --- a/pkg/apis/provisioning/v1alpha1/azureVirtualDesktopTypes.go +++ b/pkg/apis/provisioning/v1alpha1/azureVirtualDesktopTypes.go @@ -112,7 +112,3 @@ func (db *AzureVirtualDesktop) GetProvisioningMeta() *ProvisioningMeta { func (db *AzureVirtualDesktop) GetSpec() any { return &db.Spec } - -func (db *AzureVirtualDesktop) Clone() any { - return db.DeepCopy() -} diff --git a/pkg/apis/provisioning/v1alpha1/azureVirtualMachineTypes.go b/pkg/apis/provisioning/v1alpha1/azureVirtualMachineTypes.go index 0f1a73b..be3a928 100644 --- a/pkg/apis/provisioning/v1alpha1/azureVirtualMachineTypes.go +++ b/pkg/apis/provisioning/v1alpha1/azureVirtualMachineTypes.go @@ -77,7 +77,3 @@ func (db *AzureVirtualMachine) GetProvisioningMeta() *ProvisioningMeta { func (db *AzureVirtualMachine) GetSpec() any { return &db.Spec } - -func (db *AzureVirtualMachine) Clone() any { - return db.DeepCopy() -} diff --git a/pkg/apis/provisioning/v1alpha1/helmReleaseTypes.go b/pkg/apis/provisioning/v1alpha1/helmReleaseTypes.go index cc6c15e..ecc9dcd 100644 --- a/pkg/apis/provisioning/v1alpha1/helmReleaseTypes.go +++ b/pkg/apis/provisioning/v1alpha1/helmReleaseTypes.go @@ -46,7 +46,3 @@ func (db *HelmRelease) GetProvisioningMeta() *ProvisioningMeta { func (db *HelmRelease) GetSpec() any { return &db.Spec } - -func (db *HelmRelease) Clone() any { - return db.DeepCopy() -}