Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cloning for tenant overrides #59

Merged
merged 1 commit into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions internal/controllers/provisioning/provisioning_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureDbs)
err = applyTenantOverrides(azureDbs, tenant.Name)
azureDbs, err = applyTenantOverrides(azureDbs, tenant.Name)
if err != nil {
return err
}
Expand All @@ -267,7 +267,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureManagedDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureManagedDbs)
err = applyTenantOverrides(azureManagedDbs, tenant.Name)
azureManagedDbs, err = applyTenantOverrides(azureManagedDbs, tenant.Name)
if err != nil {
return err
}
Expand All @@ -277,7 +277,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
helmReleases = selectItemsInPlatformAndDomain(platformKey, domainKey, helmReleases)
err = applyTenantOverrides(helmReleases, tenant.Name)
helmReleases, err = applyTenantOverrides(helmReleases, tenant.Name)
if err != nil {
return err
}
Expand All @@ -287,7 +287,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureVirtualMachines = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualMachines)
err = applyTenantOverrides(azureVirtualMachines, tenant.Name)
azureVirtualMachines, err = applyTenantOverrides(azureVirtualMachines, tenant.Name)
if err != nil {
return err
}
Expand All @@ -297,7 +297,7 @@ func (c *ProvisioningController) syncHandler(key string) error {
return err
}
azureVirtualDesktops = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualDesktops)
err = applyTenantOverrides(azureVirtualDesktops, tenant.Name)
azureVirtualDesktops, err = applyTenantOverrides(azureVirtualDesktops, tenant.Name)
if err != nil {
return err
}
Expand Down
78 changes: 64 additions & 14 deletions internal/controllers/provisioning/provisioning_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,13 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.AzureManagedDatabase{&db}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.AzureManagedDatabase{&db}, tenantName)
if err != nil {
t.Error(err)
}

assert.Equal(t, "afterBackupFileName", db.Spec.RestoreFrom.BackupFileName)
assert.Len(t, result, 1)
assert.Equal(t, "afterBackupFileName", result[0].Spec.RestoreFrom.BackupFileName)
})

t.Run("override helmRelease version", func(t *testing.T) {
Expand Down Expand Up @@ -342,12 +343,13 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
if err != nil {
t.Error(err)
}

assert.Equal(t, hr.Spec.Release.Chart.Spec.Version, "1.1.1")
assert.Len(t, result, 1)
assert.Equal(t, "1.1.1", result[0].Spec.Release.Chart.Spec.Version)
})

t.Run("override avd params", func(t *testing.T) {
Expand Down Expand Up @@ -384,16 +386,17 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.AzureVirtualDesktop{&avd}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.AzureVirtualDesktop{&avd}, tenantName)
if err != nil {
t.Error(err)
}

assert.Len(t, result, 1)
assert.Equal(t, []provisioningv1.InitScriptArgs{{Name: "arg1NameAfter", Value: "arg1ValueAfter"}},
avd.Spec.InitScriptArguments)
result[0].Spec.InitScriptArguments)

assert.Equal(t, []string{"user1After", "user2After"}, avd.Spec.Users.ApplicationUsers)
assert.Equal(t, 2, avd.Spec.VmNumberOfInstances)
assert.Equal(t, []string{"user1After", "user2After"}, result[0].Spec.Users.ApplicationUsers)
assert.Equal(t, 2, result[0].Spec.VmNumberOfInstances)
})

t.Run("override contents of JSON field", func(t *testing.T) {
Expand Down Expand Up @@ -431,13 +434,15 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
if err != nil {
t.Error(err)
}

assert.Len(t, result, 1)

var valuesMap map[string]any
if err := json.Unmarshal(hr.Spec.Release.Values.Raw, &valuesMap); err != nil {
if err := json.Unmarshal(result[0].Spec.Release.Values.Raw, &valuesMap); err != nil {
t.Error(err)
}

Expand Down Expand Up @@ -476,14 +481,59 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) {
},
}

err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
result, err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName)
if err != nil {
t.Error(err)
}

assert.Len(t, result, 1)
assert.Equal(t, "", result[0].Spec.Release.TargetNamespace)
assert.Equal(t, "storageNamespaceBefore", result[0].Spec.Release.StorageNamespace)
assert.Nil(t, result[0].Spec.Release.MaxHistory)
})

t.Run("overrides don't mutate source", func(t *testing.T) {
tenantName := "tenant1"
overrides := map[string]any{
"release": map[string]any{
"releaseName": "releaseNameAfter0",
},
}
overridesBytes, _ := json.Marshal(overrides)

hrs := []*provisioningv1.HelmRelease{
{
Spec: provisioningv1.HelmReleaseSpec{
ProvisioningMeta: provisioningv1.ProvisioningMeta{
TenantOverrides: map[string]*v1.JSON{
tenantName: {Raw: overridesBytes},
},
},
Release: v2beta1.HelmReleaseSpec{
ReleaseName: "releaseNameBefore0",
},
},
},
{
Spec: provisioningv1.HelmReleaseSpec{
Release: v2beta1.HelmReleaseSpec{
ReleaseName: "releaseNameBefore1",
},
},
},
}

result, err := applyTenantOverrides(hrs, tenantName)
if err != nil {
t.Error(err)
}

assert.Equal(t, "", hr.Spec.Release.TargetNamespace)
assert.Equal(t, "storageNamespaceBefore", hr.Spec.Release.StorageNamespace)
assert.Nil(t, hr.Spec.Release.MaxHistory)
assert.Len(t, result, 2)
assert.NotSame(t, hrs[0], result[0])
assert.Equal(t, "releaseNameBefore0", hrs[0].Spec.Release.ReleaseName)
assert.Equal(t, "releaseNameAfter0", result[0].Spec.Release.ReleaseName)
assert.Same(t, hrs[1], result[1])
assert.Equal(t, "releaseNameBefore1", result[1].Spec.Release.ReleaseName)
})

}
Expand Down
52 changes: 32 additions & 20 deletions internal/controllers/provisioning/provisioning_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ type ProvisioningResource interface {
GetSpec() any
GetName() string
GetNamespace() string
Clone() any
}

type Cloner[C any] interface {
DeepCopy() C
}

func getPlatformAndDomain[R ProvisioningResource](res R) (platform, domain string, ok bool) {
Expand All @@ -34,61 +37,70 @@ func selectItemsInPlatformAndDomain[R ProvisioningResource](platform, domain str
for _, res := range source {
if res.GetProvisioningMeta().PlatformRef == platform && res.GetProvisioningMeta().DomainRef == domain {

result = append(result, res.Clone().(R))
result = append(result, res)
}
}
return result
}

func applyTenantOverrides[T ProvisioningResource](targets []T, tenantName string) error {
if targets == nil {
return nil
func applyTenantOverrides[R interface {
ProvisioningResource
Cloner[R]
}](source []R, tenantName string) ([]R, error) {
if source == nil {
return source, nil
}

for _, target := range targets {
overrides := target.GetProvisioningMeta().TenantOverrides
result := []R{}

for _, res := range source {
overrides := res.GetProvisioningMeta().TenantOverrides

if overrides == nil {
result = append(result, res)
continue
}

tenantOverridesJson, exists := overrides[tenantName]
if !exists {
result = append(result, res)
continue
}

var tenantOverridesMap map[string]any
if err := json.Unmarshal(tenantOverridesJson.Raw, &tenantOverridesMap); err != nil {
return err
return nil, err
}

targetSpec := target.GetSpec()

targetSpecJsonBytes, err := json.Marshal(targetSpec)
resSpecJsonBytes, err := json.Marshal(res.GetSpec())
if err != nil {
return err
return nil, err
}

var targetSpecMap map[string]any
if err := json.Unmarshal(targetSpecJsonBytes, &targetSpecMap); err != nil {
return err
if err := json.Unmarshal(resSpecJsonBytes, &targetSpecMap); err != nil {
return nil, err
}

if err := mergo.Merge(&targetSpecMap, tenantOverridesMap, mergo.WithOverride, mergo.WithTransformers(jsonTransformer{})); err != nil {
return err
return nil, err
}

targetSpecJsonBytes, err = json.Marshal(targetSpecMap)
resSpecJsonBytes, err = json.Marshal(targetSpecMap)
if err != nil {
return err
return nil, err
}

if err := json.Unmarshal(targetSpecJsonBytes, targetSpec); err != nil {
return err
resClone := res.DeepCopy()

if err := json.Unmarshal(resSpecJsonBytes, resClone.GetSpec()); err != nil {
return nil, err
}

result = append(result, resClone)
}

return nil
return result, nil
}

type jsonTransformer struct {
Expand Down
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureDatabaseTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,3 @@ func (db *AzureDatabase) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureDatabase) GetSpec() any {
return &db.Spec
}

func (db *AzureDatabase) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureManagedDatabaseTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,3 @@ func (db *AzureManagedDatabase) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureManagedDatabase) GetSpec() any {
return &db.Spec
}

func (db *AzureManagedDatabase) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureVirtualDesktopTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,3 @@ func (db *AzureVirtualDesktop) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureVirtualDesktop) GetSpec() any {
return &db.Spec
}

func (db *AzureVirtualDesktop) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/azureVirtualMachineTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,3 @@ func (db *AzureVirtualMachine) GetProvisioningMeta() *ProvisioningMeta {
func (db *AzureVirtualMachine) GetSpec() any {
return &db.Spec
}

func (db *AzureVirtualMachine) Clone() any {
return db.DeepCopy()
}
4 changes: 0 additions & 4 deletions pkg/apis/provisioning/v1alpha1/helmReleaseTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,3 @@ func (db *HelmRelease) GetProvisioningMeta() *ProvisioningMeta {
func (db *HelmRelease) GetSpec() any {
return &db.Spec
}

func (db *HelmRelease) Clone() any {
return db.DeepCopy()
}