Skip to content

Commit

Permalink
Refactor cloning for tenant overrides (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
fraliv13 committed Jul 5, 2023
1 parent 97eccd2 commit 89540c3
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 59 deletions.
10 changes: 5 additions & 5 deletions internal/controllers/provisioning/provisioning_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureDbs)
err = applyTenantOverrides(azureDbs, tenant.Name)
azureDbs, err = applyTenantOverrides(azureDbs, tenant.Name)
if err != nil {
return err
}
Expand All @@ -267,7 +267,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureManagedDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureManagedDbs)
err = applyTenantOverrides(azureManagedDbs, tenant.Name)
azureManagedDbs, err = applyTenantOverrides(azureManagedDbs, tenant.Name)
if err != nil {
return err
}
Expand All @@ -277,7 +277,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
helmReleases = selectItemsInPlatformAndDomain(platformKey, domainKey, helmReleases)
err = applyTenantOverrides(helmReleases, tenant.Name)
helmReleases, err = applyTenantOverrides(helmReleases, tenant.Name)
if err != nil {
return err
}
Expand All @@ -287,7 +287,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureVirtualMachines = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualMachines)
err = applyTenantOverrides(azureVirtualMachines, tenant.Name)
azureVirtualMachines, err = applyTenantOverrides(azureVirtualMachines, tenant.Name)
if err != nil {
return err
}
Expand All @@ -297,7 +297,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureVirtualDesktops = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualDesktops)
err = applyTenantOverrides(azureVirtualDesktops, tenant.Name)
azureVirtualDesktops, err = applyTenantOverrides(azureVirtualDesktops, tenant.Name)
if err != nil {
return err
}
Expand Down
78 changes: 64 additions & 14 deletions internal/controllers/provisioning/provisioning_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,13 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.AzureManagedDatabase{&db}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.AzureManagedDatabase{&db}, tenantName)
if err != nil {
t.Error(err)
}

assert.Equal(t, "afterBackupFileName", db.Spec.RestoreFrom.BackupFileName)
assert.Len(t, result, 1)
assert.Equal(t, "afterBackupFileName", result[0].Spec.RestoreFrom.BackupFileName)
})

t.Run("override helmRelease version", func(t *testing.T) {
Expand Down Expand Up @@ -342,12 +343,13 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
if err != nil {
t.Error(err)
}

assert.Equal(t, hr.Spec.Release.Chart.Spec.Version, "1.1.1")
assert.Len(t, result, 1)
assert.Equal(t, "1.1.1", result[0].Spec.Release.Chart.Spec.Version)
})

t.Run("override avd params", func(t *testing.T) {
Expand Down Expand Up @@ -384,16 +386,17 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.AzureVirtualDesktop{&avd}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.AzureVirtualDesktop{&avd}, tenantName)
if err != nil {
t.Error(err)
}

assert.Len(t, result, 1)
assert.Equal(t, []provisioningv1.InitScriptArgs{{Name: "arg1NameAfter", Value: "arg1ValueAfter"}},
avd.Spec.InitScriptArguments)
result[0].Spec.InitScriptArguments)

assert.Equal(t, []string{"user1After", "user2After"}, avd.Spec.Users.ApplicationUsers)
assert.Equal(t, 2, avd.Spec.VmNumberOfInstances)
assert.Equal(t, []string{"user1After", "user2After"}, result[0].Spec.Users.ApplicationUsers)
assert.Equal(t, 2, result[0].Spec.VmNumberOfInstances)
})

t.Run("override contents of JSON field", func(t *testing.T) {
Expand Down Expand Up @@ -431,13 +434,15 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
if err != nil {
t.Error(err)
}

assert.Len(t, result, 1)

var valuesMap map[string]any
if err := json.Unmarshal(hr.Spec.Release.Values.Raw, &valuesMap); err != nil {
if err := json.Unmarshal(result[0].Spec.Release.Values.Raw, &valuesMap); err != nil {
t.Error(err)
}

Expand Down Expand Up @@ -476,14 +481,59 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
if err != nil {
t.Error(err)
}

assert.Len(t, result, 1)
assert.Equal(t, "", result[0].Spec.Release.TargetNamespace)
assert.Equal(t, "storageNamespaceBefore", result[0].Spec.Release.StorageNamespace)
assert.Nil(t, result[0].Spec.Release.MaxHistory)
})

t.Run("overrides don't mutate source", func(t *testing.T) {
tenantName := "tenant1"
overrides := map[string]any{
"release": map[string]any{
"releaseName": "releaseNameAfter0",
},
}
overridesBytes, _ := json.Marshal(overrides)

hrs := []*provisioningv1.HelmRelease{
{
Spec: provisioningv1.HelmReleaseSpec{
ProvisioningMeta: provisioningv1.ProvisioningMeta{
TenantOverrides: map[string]*v1.JSON{
tenantName: {Raw: overridesBytes},
},
},
Release: v2beta1.HelmReleaseSpec{
ReleaseName: "releaseNameBefore0",
},
},
},
{
Spec: provisioningv1.HelmReleaseSpec{
Release: v2beta1.HelmReleaseSpec{
ReleaseName: "releaseNameBefore1",
},
},
},
}

result, err := applyTenantOverrides(hrs, tenantName)
if err != nil {
t.Error(err)
}

assert.Equal(t, "", hr.Spec.Release.TargetNamespace)
assert.Equal(t, "storageNamespaceBefore", hr.Spec.Release.StorageNamespace)
assert.Nil(t, hr.Spec.Release.MaxHistory)
assert.Len(t, result, 2)
assert.NotSame(t, hrs[0], result[0])
assert.Equal(t, "releaseNameBefore0", hrs[0].Spec.Release.ReleaseName)
assert.Equal(t, "releaseNameAfter0", result[0].Spec.Release.ReleaseName)
assert.Same(t, hrs[1], result[1])
assert.Equal(t, "releaseNameBefore1", result[1].Spec.Release.ReleaseName)
})

}
Expand Down
52 changes: 32 additions & 20 deletions internal/controllers/provisioning/provisioning_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ type ProvisioningResource interface {
GetSpec() any
GetName() string
GetNamespace() string
Clone() any
}

type Cloner[C any] interface {
DeepCopy() C
}

func getPlatformAndDomain[R ProvisioningResource](res R) (platform, domain string, ok bool) {
Expand All @@ -34,61 +37,70 @@ func selectItemsInPlatformAndDomain[R ProvisioningResource](platform, domain str
for _, res := range source {
if res.GetProvisioningMeta().PlatformRef == platform && res.GetProvisioningMeta().DomainRef == domain {

result = append(result, res.Clone().(R))
result = append(result, res)
}
}
return result
}

func applyTenantOverrides[T ProvisioningResource](targets []T, tenantName string) error {
if targets == nil {
return nil
func applyTenantOverrides[R interface {
ProvisioningResource
Cloner[R]
}](source []R, tenantName string) ([]R, error) {
if source == nil {
return source, nil
}

for _, target := range targets {
overrides := target.GetProvisioningMeta().TenantOverrides
result := []R{}

for _, res := range source {
overrides := res.GetProvisioningMeta().TenantOverrides

if overrides == nil {
result = append(result, res)
continue
}

tenantOverridesJson, exists := overrides[tenantName]
if !exists {
result = append(result, res)
continue
}

var tenantOverridesMap map[string]any
if err := json.Unmarshal(tenantOverridesJson.Raw, &tenantOverridesMap); err != nil {
return err
return nil, err
}

targetSpec := target.GetSpec()

targetSpecJsonBytes, err := json.Marshal(targetSpec)
resSpecJsonBytes, err := json.Marshal(res.GetSpec())
if err != nil {
return err
return nil, err
}

var targetSpecMap map[string]any
if err := json.Unmarshal(targetSpecJsonBytes, &targetSpecMap); err != nil {
return err
if err := json.Unmarshal(resSpecJsonBytes, &targetSpecMap); err != nil {
return nil, err
}

if err := mergo.Merge(&targetSpecMap, tenantOverridesMap, mergo.WithOverride, mergo.WithTransformers(jsonTransformer{})); err != nil {
return err
return nil, err
}

targetSpecJsonBytes, err = json.Marshal(targetSpecMap)
resSpecJsonBytes, err = json.Marshal(targetSpecMap)
if err != nil {
return err
return nil, err
}

if err := json.Unmarshal(targetSpecJsonBytes, targetSpec); err != nil {
return err
resClone := res.DeepCopy()

if err := json.Unmarshal(resSpecJsonBytes, resClone.GetSpec()); err != nil {
return nil, err
}

result = append(result, resClone)
}

return nil
return result, nil
}

type jsonTransformer struct {
Expand Down
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureDatabaseTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,3 @@ func (db *AzureDatabase) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureDatabase) GetSpec() any {
return &db.Spec
}

func (db *AzureDatabase) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureManagedDatabaseTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,3 @@ func (db *AzureManagedDatabase) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureManagedDatabase) GetSpec() any {
return &db.Spec
}

func (db *AzureManagedDatabase) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureVirtualDesktopTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,3 @@ func (db *AzureVirtualDesktop) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureVirtualDesktop) GetSpec() any {
return &db.Spec
}

func (db *AzureVirtualDesktop) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureVirtualMachineTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,3 @@ func (db *AzureVirtualMachine) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureVirtualMachine) GetSpec() any {
return &db.Spec
}

func (db *AzureVirtualMachine) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/helmReleaseTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,3 @@ func (db *HelmRelease) GetProvisioningMeta() *ProvisioningMeta {
func (db *HelmRelease) GetSpec() any {
return &db.Spec
}

func (db *HelmRelease) Clone() any {
return db.DeepCopy()
}

0 comments on commit 89540c3

Please sign in to comment.