From d019f0aaba76687ab67dcbff3e7075f532089fbe Mon Sep 17 00:00:00 2001 From: Radu Popovici Date: Tue, 4 Jul 2023 15:05:49 +0300 Subject: [PATCH] domain stacks (#56) * tests passing * skip migrator on disabled tenants * resource group per domain --- .../provisioning/migration/migration.go | 20 +- .../provisioning/migration/migration_test.go | 22 +- .../pulumi/azure_resource_group.go | 4 +- .../provisioners/pulumi/pulumi.go | 11 +- .../provisioning/provisioners/types.go | 4 +- .../provisioning/provisioning_controller.go | 557 ++++-------------- .../provisioning_controller_test.go | 73 ++- .../provisioning/provisioning_types.go | 126 ++++ pkg/apis/provisioning/v1alpha1/commonTypes.go | 4 - 9 files changed, 315 insertions(+), 506 deletions(-) create mode 100644 internal/controllers/provisioning/provisioning_types.go diff --git a/internal/controllers/provisioning/migration/migration.go b/internal/controllers/provisioning/migration/migration.go index ff36b98..aee8abd 100644 --- a/internal/controllers/provisioning/migration/migration.go +++ b/internal/controllers/provisioning/migration/migration.go @@ -8,26 +8,36 @@ import ( v1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/kubernetes" "k8s.io/klog/v2" platformv1 "totalsoft.ro/platform-controllers/pkg/apis/platform/v1alpha1" ) const ( - JobLabelSelectorKey = "provisioning.totalsoft.ro/migration-job-template" + jobTemplateLabel = "provisioning.totalsoft.ro/migration-job-template" + domainLabel = "platform.totalsoft.ro/domain" ) func KubeJobsMigrationForTenant(kubeClient kubernetes.Interface, - nsFilter func(string, string) bool) func(platform string, tenant *platformv1.Tenant) error { + nsFilter func(string, string) bool) func(platform string, tenant *platformv1.Tenant, domain string) error { namer := func(jName, tenant string) string { return fmt.Sprintf("%s-%s-%d", jName, tenant, time.Now().Unix()) } - return func(platform string, tenant *platformv1.Tenant) error { - klog.InfoS("Creating migrations jobs", "tenant", tenant.Name) + return func(platform string, tenant *platformv1.Tenant, domain string) error { + klog.InfoS("Creating migrations jobs", "tenant", tenant.Name, "domain", domain) + + labelSelector, err := labels.ValidatedSelectorFromSet(map[string]string{ + domainLabel: domain, + jobTemplateLabel: "true", + }) + if err != nil { + return err + } jobs, err := kubeClient.BatchV1().Jobs("").List(context.TODO(), metav1.ListOptions{ - LabelSelector: JobLabelSelectorKey + "=true", + LabelSelector: labelSelector.String(), }) if err != nil { return err diff --git a/internal/controllers/provisioning/migration/migration_test.go b/internal/controllers/provisioning/migration/migration_test.go index b1580a5..df78b9c 100644 --- a/internal/controllers/provisioning/migration/migration_test.go +++ b/internal/controllers/provisioning/migration/migration_test.go @@ -12,25 +12,28 @@ import ( ) func TestKubeJobsMigrationForTenant(t *testing.T) { + domain := "test-domain" objects := []runtime.Object{ - newJob("dev1", true), - newJob("dev2", true), - newJob("dev3", false), + newJob("dev1", domain, true), + newJob("dev2", domain, true), + newJob("dev3", domain, false), + newJob("dev4", "some-other-domain", true), } kubeClient := fake.NewSimpleClientset(objects...) migrator := KubeJobsMigrationForTenant(kubeClient, func(s string, s2 string) bool { return true }) t.Run("test job selection by label", func(t *testing.T) { - migrator("test", newTenant("qa", "qa")) + migrator("test", newTenant("qa", "qa"), domain) jobs, _ := kubeClient.BatchV1().Jobs(metav1.NamespaceDefault).List(context.TODO(), metav1.ListOptions{}) - if len(jobs.Items) != 5 { - t.Errorf("Error running migration, expected 5 jobs but found %d", len(jobs.Items)) + expectedNoOfJobs := 4 + 2 //4 existing + 2 new jobs + if len(jobs.Items) != expectedNoOfJobs { + t.Errorf("Error running migration, expected %d jobs but found %d", expectedNoOfJobs, len(jobs.Items)) } }) } -func newJob(name string, template bool) *v1.Job { +func newJob(name, domain string, template bool) *v1.Job { j := &v1.Job{ TypeMeta: metav1.TypeMeta{APIVersion: platformv1.SchemeGroupVersion.String()}, ObjectMeta: metav1.ObjectMeta{ @@ -40,7 +43,10 @@ func newJob(name string, template bool) *v1.Job { Spec: v1.JobSpec{}, } if template { - j.SetLabels(map[string]string{JobLabelSelectorKey: "true"}) + j.SetLabels(map[string]string{ + jobTemplateLabel: "true", + domainLabel: domain, + }) } return j } diff --git a/internal/controllers/provisioning/provisioners/pulumi/azure_resource_group.go b/internal/controllers/provisioning/provisioners/pulumi/azure_resource_group.go index 6c472a7..ec59dc9 100644 --- a/internal/controllers/provisioning/provisioners/pulumi/azure_resource_group.go +++ b/internal/controllers/provisioning/provisioners/pulumi/azure_resource_group.go @@ -8,9 +8,9 @@ import ( platformv1 "totalsoft.ro/platform-controllers/pkg/apis/platform/v1alpha1" ) -func azureRGDeployFunc(platform string, tenant *platformv1.Tenant) func(ctx *pulumi.Context) (pulumi.StringOutput, error) { +func azureRGDeployFunc(platform string, tenant *platformv1.Tenant, domain string) func(ctx *pulumi.Context) (pulumi.StringOutput, error) { return func(ctx *pulumi.Context) (pulumi.StringOutput, error) { - resourceGroupName := fmt.Sprintf("%s-%s", platform, tenant.Name) + resourceGroupName := fmt.Sprintf("%s-%s-%s", platform, tenant.Name, domain) resourceGroup, err := azureResources.NewResourceGroup(ctx, resourceGroupName, &azureResources.ResourceGroupArgs{ ResourceGroupName: pulumi.String(resourceGroupName), }, pulumi.RetainOnDelete(true)) diff --git a/internal/controllers/provisioning/provisioners/pulumi/pulumi.go b/internal/controllers/provisioning/provisioners/pulumi/pulumi.go index e1150bc..d5133ca 100644 --- a/internal/controllers/provisioning/provisioners/pulumi/pulumi.go +++ b/internal/controllers/provisioning/provisioners/pulumi/pulumi.go @@ -4,6 +4,7 @@ package pulumi import ( "context" + "fmt" "github.com/pulumi/pulumi/sdk/v3/go/common/apitype" @@ -22,7 +23,7 @@ const ( PulumiRetainOnDelete = true ) -func Create(platform string, tenant *platformv1.Tenant, infra *provisioners.InfrastructureManifests) provisioners.ProvisioningResult { +func Create(platform string, tenant *platformv1.Tenant, domain string, infra *provisioners.InfrastructureManifests) provisioners.ProvisioningResult { result := provisioners.ProvisioningResult{} upRes := auto.UpResult{} destroyRes := auto.DestroyResult{} @@ -37,9 +38,9 @@ func Create(platform string, tenant *platformv1.Tenant, infra *provisioners.Infr anyResource := anyAzureDb || anyManagedAzureDb || anyHelmRelease || anyVirtualMachine || anyVirtualDesktop needsResourceGroup := anyVirtualMachine || anyVirtualDesktop - stackName := tenant.Name + stackName := fmt.Sprintf("%s-%s", tenant.Name, domain) if anyResource { - upRes, result.Error = updateStack(stackName, platform, deployFunc(platform, tenant, infra, needsResourceGroup)) + upRes, result.Error = updateStack(stackName, platform, deployFunc(platform, tenant, domain, infra, needsResourceGroup)) if result.Error != nil { return result } @@ -188,7 +189,7 @@ func createOrSelectStack(ctx context.Context, stackName, projectName string, dep return s, nil } -func deployFunc(platform string, tenant *platformv1.Tenant, +func deployFunc(platform string, tenant *platformv1.Tenant, domain string, infra *provisioners.InfrastructureManifests, needsResourceGroup bool) pulumi.RunFunc { return func(ctx *pulumi.Context) error { @@ -208,7 +209,7 @@ func deployFunc(platform string, tenant *platformv1.Tenant, } if needsResourceGroup { - rgName, err := azureRGDeployFunc(platform, tenant)(ctx) + rgName, err := azureRGDeployFunc(platform, tenant, domain)(ctx) if err != nil { return err } diff --git a/internal/controllers/provisioning/provisioners/types.go b/internal/controllers/provisioning/provisioners/types.go index 14d5eba..774f095 100644 --- a/internal/controllers/provisioning/provisioners/types.go +++ b/internal/controllers/provisioning/provisioners/types.go @@ -5,8 +5,10 @@ import ( provisioningv1 "totalsoft.ro/platform-controllers/pkg/apis/provisioning/v1alpha1" ) -type CreateInfrastructureFunc func(platform string, +type CreateInfrastructureFunc func( + platform string, tenant *platformv1.Tenant, + domain string, infra *InfrastructureManifests) ProvisioningResult type InfrastructureManifests struct { diff --git a/internal/controllers/provisioning/provisioning_controller.go b/internal/controllers/provisioning/provisioning_controller.go index 4513510..1364f18 100644 --- a/internal/controllers/provisioning/provisioning_controller.go +++ b/internal/controllers/provisioning/provisioning_controller.go @@ -2,18 +2,13 @@ package provisioning import ( "context" - "encoding/json" "fmt" "reflect" "strings" "time" - "k8s.io/utils/strings/slices" - corev1 "k8s.io/api/core/v1" - apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" "k8s.io/apimachinery/pkg/api/errors" - apimeta "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" utilruntime "k8s.io/apimachinery/pkg/util/runtime" @@ -23,7 +18,7 @@ import ( "k8s.io/client-go/tools/record" "k8s.io/client-go/util/workqueue" "k8s.io/klog/v2" - controllers "totalsoft.ro/platform-controllers/internal/controllers" + "k8s.io/utils/strings/slices" provisioners "totalsoft.ro/platform-controllers/internal/controllers/provisioning/provisioners" messaging "totalsoft.ro/platform-controllers/internal/messaging" platformv1 "totalsoft.ro/platform-controllers/pkg/apis/platform/v1alpha1" @@ -33,8 +28,6 @@ import ( informers "totalsoft.ro/platform-controllers/pkg/generated/informers/externalversions" platformInformersv1 "totalsoft.ro/platform-controllers/pkg/generated/informers/externalversions/platform/v1alpha1" provisioningInformersv1 "totalsoft.ro/platform-controllers/pkg/generated/informers/externalversions/provisioning/v1alpha1" - - "dario.cat/mergo" ) const ( @@ -44,13 +37,16 @@ const ( tenantProvisionedSuccessfullyTopic = "PlatformControllers.ProvisioningController.TenantProvisionedSuccessfully" tenantProvisionningFailedTopic = "PlatformControllers.ProvisioningController.TenantProvisionningFailed" + + DomainProvisionedSuccessfullyFormat string = "%s domain provisioned successfully" + DomainProvisionningFailedFormat string = "%s domain provisionning failed" ) // ProvisioningController is the controller implementation for Tenant resources type ProvisioningController struct { factory informers.SharedInformerFactory clientset clientset.Interface - migrator func(platform string, tenant *platformv1.Tenant) error + migrator func(platform string, tenant *platformv1.Tenant, domain string) error // workqueue is a rate limited work queue. This is used to queue work to be // processed instead of performing it as soon as a change happens. This @@ -77,7 +73,7 @@ type ProvisioningController struct { func NewProvisioningController(clientSet clientset.Interface, provisioner provisioners.CreateInfrastructureFunc, - migrator func(platform string, tenant *platformv1.Tenant) error, + migrator func(platform string, tenant *platformv1.Tenant, domain string) error, eventBroadcaster record.EventBroadcaster, messagingPublisher messaging.MessagingPublisher) *ProvisioningController { @@ -114,11 +110,12 @@ func NewProvisioningController(clientSet clientset.Interface, addTenantHandlers(c.tenantInformer, c.enqueueTenant) addPlatformHandlers(c.platformInformer) - addAzureDbHandlers(c.azureDbInformer, c.enqueueAllTenant) - addAzureManagedDbHandlers(c.azureManagedDbInformer, c.enqueueAllTenant) - addHelmReleaseHandlers(c.helmReleaseInformer, c.enqueueAllTenant) - addAzureVirtualMachineHandlers(c.azureVirtualMachineInformer, c.enqueueAllTenant) - addAzureVirtualDesktopHandlers(c.azureVirtualDesktopInformer, c.enqueueAllTenant) + + addResourceHandlers[*provisioningv1.AzureDatabase]("Azure database", c.azureDbInformer.Informer(), c.enqueueDomain) + addResourceHandlers[*provisioningv1.AzureManagedDatabase]("Azure managed database", c.azureManagedDbInformer.Informer(), c.enqueueDomain) + addResourceHandlers[*provisioningv1.HelmRelease]("Helm release", c.helmReleaseInformer.Informer(), c.enqueueDomain) + addResourceHandlers[*provisioningv1.AzureVirtualMachine]("Azure virtual machine", c.azureVirtualMachineInformer.Informer(), c.enqueueDomain) + addResourceHandlers[*provisioningv1.AzureVirtualDesktop]("Azure virtual Desktop", c.azureVirtualDesktopInformer.Informer(), c.enqueueDomain) return c } @@ -212,8 +209,8 @@ func (c *ProvisioningController) processNextWorkItem(i int) bool { // with the current status of the resource. func (c *ProvisioningController) syncHandler(key string) error { // Convert the namespace/name string into a distinct namespace and name - platform, tenantKey, _ := decodeKey(key) - namespace, name, err := cache.SplitMetaNamespaceKey(tenantKey) + platformKey, tenantKey, domainKey, _ := decodeKey(key) + tenantNamespace, tenantName, err := cache.SplitMetaNamespaceKey(tenantKey) if err != nil { utilruntime.HandleError(fmt.Errorf("invalid tenant key: %s", tenantKey)) return nil @@ -221,11 +218,17 @@ func (c *ProvisioningController) syncHandler(key string) error { // Get the Tenant resource with this namespace/name // use the live query API, to get the latest version instead of listers which are cached - tenant, err := c.clientset.PlatformV1alpha1().Tenants(namespace).Get(context.TODO(), name, metav1.GetOptions{}) - if shouldCleanupTenantResources := (err != nil && errors.IsNotFound(err)) || (err == nil && tenant.Spec.PlatformRef != platform); shouldCleanupTenantResources { - cleanupResult := c.provisioner(platform, &platformv1.Tenant{ - ObjectMeta: metav1.ObjectMeta{Name: name}, - Spec: platformv1.TenantSpec{PlatformRef: platform}}, + tenant, err := c.clientset.PlatformV1alpha1().Tenants(tenantNamespace).Get(context.TODO(), tenantName, metav1.GetOptions{}) + shouldCleanupResources := + (err != nil && errors.IsNotFound(err)) || + (err == nil && (tenant.Spec.PlatformRef != platformKey || + !slices.Contains(tenant.Spec.DomainRefs, domainKey))) + + if shouldCleanupResources { + cleanupResult := c.provisioner(platformKey, &platformv1.Tenant{ + ObjectMeta: metav1.ObjectMeta{Name: tenantName}, + Spec: platformv1.TenantSpec{PlatformRef: platformKey}}, + domainKey, &provisioners.InfrastructureManifests{ AzureDbs: []*provisioningv1.AzureDatabase{}, AzureManagedDbs: []*provisioningv1.AzureManagedDatabase{}, @@ -253,98 +256,53 @@ func (c *ProvisioningController) syncHandler(key string) error { if err != nil { return err } - - n := 0 - for _, db := range azureDbs { - if db.Spec.PlatformRef == platform && slices.Contains(tenant.Spec.DomainRefs, db.Spec.DomainRef) { - err := applyTenantOverrides(db, tenant.Name) - if err != nil { - return err - } - - azureDbs[n] = db - n++ - } + azureDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureDbs) + err = applyTenantOverrides(azureDbs, tenant.Name) + if err != nil { + return err } - azureDbs = azureDbs[:n] azureManagedDbs, err := c.azureManagedDbInformer.Lister().List(skipTenantLabelSelector) if err != nil { return err } - - n = 0 - for _, db := range azureManagedDbs { - if db.Spec.PlatformRef == platform && slices.Contains(tenant.Spec.DomainRefs, db.Spec.DomainRef) { - err := applyTenantOverrides(db, tenant.Name) - if err != nil { - return err - } - - azureManagedDbs[n] = db - n++ - } + azureManagedDbs = selectItemsInPlatformAndDomain(platformKey, domainKey, azureManagedDbs) + err = applyTenantOverrides(azureManagedDbs, tenant.Name) + if err != nil { + return err } - azureManagedDbs = azureManagedDbs[:n] helmReleases, err := c.helmReleaseInformer.Lister().List(skipTenantLabelSelector) if err != nil { return err } - - n = 0 - for _, hr := range helmReleases { - if hr.Spec.PlatformRef == platform && slices.Contains(tenant.Spec.DomainRefs, hr.Spec.DomainRef) { - err := applyTenantOverrides(hr, tenant.Name) - if err != nil { - return err - } - - helmReleases[n] = hr - n++ - } + helmReleases = selectItemsInPlatformAndDomain(platformKey, domainKey, helmReleases) + err = applyTenantOverrides(helmReleases, tenant.Name) + if err != nil { + return err } - helmReleases = helmReleases[:n] azureVirtualMachines, err := c.azureVirtualMachineInformer.Lister().List(skipTenantLabelSelector) if err != nil { return err } - - n = 0 - for _, vm := range azureVirtualMachines { - if vm.Spec.PlatformRef == platform && slices.Contains(tenant.Spec.DomainRefs, vm.Spec.DomainRef) { - err := applyTenantOverrides(vm, tenant.Name) - if err != nil { - return err - } - - azureVirtualMachines[n] = vm - n++ - } + azureVirtualMachines = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualMachines) + err = applyTenantOverrides(azureVirtualMachines, tenant.Name) + if err != nil { + return err } - azureVirtualMachines = azureVirtualMachines[:n] azureVirtualDesktops, err := c.azureVirtualDesktopInformer.Lister().List(skipTenantLabelSelector) if err != nil { return err } - - n = 0 - for _, avd := range azureVirtualDesktops { - if avd.Spec.PlatformRef == platform && slices.Contains(tenant.Spec.DomainRefs, avd.Spec.DomainRef) { - err := applyTenantOverrides(avd, tenant.Name) - if err != nil { - return err - } - - azureVirtualDesktops[n] = avd - n++ - } + azureVirtualDesktops = selectItemsInPlatformAndDomain(platformKey, domainKey, azureVirtualDesktops) + err = applyTenantOverrides(azureVirtualDesktops, tenant.Name) + if err != nil { + return err } - azureVirtualDesktops = azureVirtualDesktops[:n] - result := c.provisioner(platform, tenant, &provisioners.InfrastructureManifests{ + result := c.provisioner(platformKey, tenant, domainKey, &provisioners.InfrastructureManifests{ AzureDbs: azureDbs, AzureManagedDbs: azureManagedDbs, HelmReleases: helmReleases, @@ -353,10 +311,10 @@ func (c *ProvisioningController) syncHandler(key string) error { }) if result.Error == nil { - if c.migrator != nil && result.HasChanges { - p, err := c.clientset.PlatformV1alpha1().Platforms().Get(context.TODO(), platform, metav1.GetOptions{}) + if c.migrator != nil && result.HasChanges && tenant.Spec.Enabled { + platform, err := c.clientset.PlatformV1alpha1().Platforms().Get(context.TODO(), platformKey, metav1.GetOptions{}) if err == nil { - result.Error = c.migrator(p.Spec.TargetNamespace, tenant) + result.Error = c.migrator(platform.Spec.TargetNamespace, tenant, domainKey) } else { klog.ErrorS(err, "platform not found") } @@ -364,25 +322,27 @@ func (c *ProvisioningController) syncHandler(key string) error { } if result.Error == nil { - c.recorder.Event(tenant, corev1.EventTypeNormal, controllers.SuccessSynced, controllers.SuccessSynced) + c.recorder.Event(tenant, corev1.EventTypeNormal, fmt.Sprintf(DomainProvisionedSuccessfullyFormat, domainKey), fmt.Sprintf(DomainProvisionedSuccessfullyFormat, domainKey)) var ev = struct { TenantId string TenantName string TenantDescription string Platform string + Domain string }{ TenantId: tenant.Spec.Id, TenantName: tenant.Name, TenantDescription: tenant.Spec.Description, - Platform: platform, + Platform: platformKey, + Domain: domainKey, } - err = c.messagingPublisher(context.TODO(), tenantProvisionedSuccessfullyTopic, ev, platform) + err = c.messagingPublisher(context.TODO(), tenantProvisionedSuccessfullyTopic, ev, platformKey) if err != nil { klog.ErrorS(err, "message publisher error") } } else { - c.recorder.Event(tenant, corev1.EventTypeWarning, controllers.ErrorSynced, result.Error.Error()) + c.recorder.Event(tenant, corev1.EventTypeWarning, fmt.Sprintf(DomainProvisionningFailedFormat, domainKey), result.Error.Error()) var ev = struct { TenantId string @@ -390,58 +350,25 @@ func (c *ProvisioningController) syncHandler(key string) error { TenantDescription string Platform string Error string + Domain string }{ TenantId: tenant.Spec.Id, TenantName: tenant.Name, TenantDescription: tenant.Spec.Description, - Platform: platform, + Platform: platformKey, + Domain: domainKey, Error: result.Error.Error(), } - err = c.messagingPublisher(context.TODO(), tenantProvisionningFailedTopic, ev, platform) + err = c.messagingPublisher(context.TODO(), tenantProvisionningFailedTopic, ev, platformKey) if err != nil { klog.ErrorS(err, "message publisher error") } } - _, e := c.updateTenantStatus(tenant, result.Error) - if e != nil { - //just log this error, don't propagate - utilruntime.HandleError(e) - } return result.Error } -func (c *ProvisioningController) updateTenantStatus(tenant *platformv1.Tenant, err error) (*platformv1.Tenant, error) { - // NEVER modify objects from the store. It's a read-only, local cache. - // You can use DeepCopy() to make a deep copy of original object and modify this copy - // Or create a copy manually for better performance - tenantCopy := tenant.DeepCopy() - tenantCopy.Status.LastResyncTime = metav1.Now() - - if err != nil { - apimeta.SetStatusCondition(&tenantCopy.Status.Conditions, metav1.Condition{ - Type: "Ready", - Status: metav1.ConditionFalse, - Reason: controllers.FailedReason, - Message: err.Error(), - }) - } else { - apimeta.SetStatusCondition(&tenantCopy.Status.Conditions, metav1.Condition{ - Type: "Ready", - Status: metav1.ConditionTrue, - Reason: controllers.SucceededReason, - Message: controllers.SuccessSynced, - }) - } - - // If the CustomResourceSubresources feature gate is not enabled, - // we must use Update instead of UpdateStatus to update the Status block of the resource. - // UpdateStatus will not allow changes to the Spec of the resource, - // which is ideal for ensuring nothing other than resource status has been updated. - return c.clientset.PlatformV1alpha1().Tenants(tenant.Namespace).UpdateStatus(context.TODO(), tenantCopy, metav1.UpdateOptions{}) -} - -func (c *ProvisioningController) enqueueAllTenant(platform string) { +func (c *ProvisioningController) enqueueDomain(platform, domain string) { tenants, err := c.tenantInformer.Lister().List(labels.Everything()) if err != nil { utilruntime.HandleError(err) @@ -449,7 +376,7 @@ func (c *ProvisioningController) enqueueAllTenant(platform string) { } for _, tenant := range tenants { if tenant.Spec.PlatformRef == platform { - c.enqueueTenant(tenant) + c.enqueueTenantDomain(tenant, domain) } } } @@ -466,19 +393,40 @@ func (c *ProvisioningController) enqueueTenant(tenant *platformv1.Tenant) { utilruntime.HandleError(err) return } - c.workqueue.Add(encodeKey(tenant.Spec.PlatformRef, tenantKey)) + + for _, domain := range tenant.Spec.DomainRefs { + c.workqueue.Add(encodeKey(tenant.Spec.PlatformRef, tenantKey, domain)) + } +} + +func (c *ProvisioningController) enqueueTenantDomain(tenant *platformv1.Tenant, domain string) { + var tenantKey string + var err error + + if v, ok := tenant.Labels[SkipProvisioningLabel]; ok && v == "true" { + return + } + + if tenantKey, err = cache.MetaNamespaceKeyFunc(tenant); err != nil { + utilruntime.HandleError(err) + return + } + + if slices.Contains(tenant.Spec.DomainRefs, domain) { + c.workqueue.Add(encodeKey(tenant.Spec.PlatformRef, tenantKey, domain)) + } } -func encodeKey(platformKey, tenantKey string) (key string) { - return fmt.Sprintf("%s::%s", platformKey, tenantKey) +func encodeKey(platform, tenant, domain string) (key string) { + return fmt.Sprintf("%s::%s::%s", platform, tenant, domain) } -func decodeKey(key string) (platformKey, tenantKey string, err error) { +func decodeKey(key string) (platform, tenant, domain string, err error) { res := strings.Split(key, "::") - if len(res) == 2 { - return res[0], res[1], nil + if len(res) == 3 { + return res[0], res[1], res[2], nil } - return "", "", fmt.Errorf("cannot decode key: %v", key) + return "", "", "", fmt.Errorf("cannot decode key: %v", key) } func addPlatformHandlers(informer platformInformersv1.PlatformInformer) { @@ -534,181 +482,41 @@ func addTenantHandlers(informer platformInformersv1.TenantInformer, handler func }) } -func addAzureDbHandlers(informer provisioningInformersv1.AzureDatabaseInformer, handler func(platform string)) { - informer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureDatabase) - if platform, ok := getAzureDbPlatform(comp); ok { - klog.V(4).InfoS("Azure database added", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - UpdateFunc: func(oldObj, newObj interface{}) { - oldComp := oldObj.(*provisioningv1.AzureDatabase) - newComp := newObj.(*provisioningv1.AzureDatabase) - oldPlatform, oldOk := getAzureDbPlatform(oldComp) - newPlatform, newOk := getAzureDbPlatform(newComp) - platformChanged := oldPlatform != newPlatform - - if oldOk && platformChanged { - klog.V(4).InfoS("Azure database invalidated", "name", oldComp.Name, "namespace", oldComp.Namespace, "platform", oldPlatform) - handler(oldPlatform) - } - - if newOk { - klog.V(4).InfoS("Azure database updated", "name", newComp.Name, "namespace", newComp.Namespace, "platform", newPlatform) - handler(newPlatform) - } - }, - DeleteFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureDatabase) - if platform, ok := getAzureDbPlatform(comp); ok { - klog.V(4).InfoS("Azure database deleted", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - }) -} - -func addAzureManagedDbHandlers(informer provisioningInformersv1.AzureManagedDatabaseInformer, handler func(platform string)) { - informer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ +func addResourceHandlers[R ProvisioningResource](resType string, informer cache.SharedIndexInformer, handler func(platform, domain string)) { + informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ AddFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureManagedDatabase) - if platform, ok := getAzureManagedDbPlatform(comp); ok { - klog.V(4).InfoS("Azure managed database added", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) + comp := obj.(R) + if platform, domain, ok := getPlatformAndDomain(comp); ok { + msg := fmt.Sprintf("%s added", resType) + klog.V(4).InfoS(msg, "name", comp.GetName(), "namespace", comp.GetNamespace(), "platform", platform, "domain", domain) + handler(platform, domain) } }, UpdateFunc: func(oldObj, newObj interface{}) { - oldComp := oldObj.(*provisioningv1.AzureManagedDatabase) - newComp := newObj.(*provisioningv1.AzureManagedDatabase) - oldPlatform, oldOk := getAzureManagedDbPlatform(oldComp) - newPlatform, newOk := getAzureManagedDbPlatform(newComp) - platformChanged := oldPlatform != newPlatform - - if oldOk && platformChanged { - klog.V(4).InfoS("Azure managed database invalidated", "name", oldComp.Name, "namespace", oldComp.Namespace, "platform", oldPlatform) - handler(oldPlatform) + oldComp := oldObj.(R) + newComp := newObj.(R) + oldPlatform, oldDomain, oldOk := getPlatformAndDomain(oldComp) + newPlatform, newDomain, newOk := getPlatformAndDomain(newComp) + platformOrDomainChanged := oldPlatform != newPlatform || oldDomain != newDomain + + if oldOk && platformOrDomainChanged { + msg := fmt.Sprintf("%s invalidated", resType) + klog.V(4).InfoS(msg, "name", oldComp.GetName(), "namespace", oldComp.GetNamespace(), "platform", oldPlatform, "domain", oldDomain) + handler(oldPlatform, oldDomain) } if newOk { - klog.V(4).InfoS("Azure managed database updated", "name", newComp.Name, "namespace", newComp.Namespace, "platform", newPlatform) - handler(newPlatform) + msg := fmt.Sprintf("%s updated", resType) + klog.V(4).InfoS(msg, "name", newComp.GetName(), "namespace", newComp.GetNamespace(), "platform", newPlatform, "domain", newDomain) + handler(newPlatform, newDomain) } }, DeleteFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureManagedDatabase) - if platform, ok := getAzureManagedDbPlatform(comp); ok { - klog.V(4).InfoS("Azure managed database deleted", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - }) -} - -func addHelmReleaseHandlers(informer provisioningInformersv1.HelmReleaseInformer, handler func(platform string)) { - informer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.HelmRelease) - if platform, ok := getHelmReleasePlatform(comp); ok { - klog.V(4).InfoS("Helm release added", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - UpdateFunc: func(oldObj, newObj interface{}) { - oldComp := oldObj.(*provisioningv1.HelmRelease) - newComp := newObj.(*provisioningv1.HelmRelease) - oldPlatform, oldOk := getHelmReleasePlatform(oldComp) - newPlatform, newOk := getHelmReleasePlatform(newComp) - platformChanged := oldPlatform != newPlatform - - if oldOk && platformChanged { - klog.V(4).InfoS("Helm release invalidated", "name", oldComp.Name, "namespace", oldComp.Namespace, "platform", oldPlatform) - handler(oldPlatform) - } - - if newOk { - klog.V(4).InfoS("Helm release updated", "name", newComp.Name, "namespace", newComp.Namespace, "platform", newPlatform) - handler(newPlatform) - } - }, - DeleteFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.HelmRelease) - if platform, ok := getHelmReleasePlatform(comp); ok { - klog.V(4).InfoS("Helm release deleted", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - }) -} - -func addAzureVirtualDesktopHandlers(informer provisioningInformersv1.AzureVirtualDesktopInformer, handler func(platform string)) { - informer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureVirtualDesktop) - if platform, ok := getAzureVirtualDesktopPlatform(comp); ok { - klog.V(4).InfoS("Azure virtual Desktop added", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - UpdateFunc: func(oldObj, newObj interface{}) { - oldComp := oldObj.(*provisioningv1.AzureVirtualDesktop) - newComp := newObj.(*provisioningv1.AzureVirtualDesktop) - oldPlatform, oldOk := getAzureVirtualDesktopPlatform(oldComp) - newPlatform, newOk := getAzureVirtualDesktopPlatform(newComp) - platformChanged := oldPlatform != newPlatform - - if oldOk && platformChanged { - klog.V(4).InfoS("Azure virtual desktop invalidated", "name", oldComp.Name, "namespace", oldComp.Namespace, "platform", oldPlatform) - handler(oldPlatform) - } - - if newOk { - klog.V(4).InfoS("Azure virtual desktop updated", "name", newComp.Name, "namespace", newComp.Namespace, "platform", newPlatform) - handler(newPlatform) - } - }, - DeleteFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureVirtualDesktop) - if platform, ok := getAzureVirtualDesktopPlatform(comp); ok { - klog.V(4).InfoS("Azure virtual desktop deleted", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - }) -} - -func addAzureVirtualMachineHandlers(informer provisioningInformersv1.AzureVirtualMachineInformer, handler func(platform string)) { - informer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureVirtualMachine) - if platform, ok := getAzureVirtualMachinePlatform(comp); ok { - klog.V(4).InfoS("Azure virtual machine added", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) - } - }, - UpdateFunc: func(oldObj, newObj interface{}) { - oldComp := oldObj.(*provisioningv1.AzureVirtualMachine) - newComp := newObj.(*provisioningv1.AzureVirtualMachine) - oldPlatform, oldOk := getAzureVirtualMachinePlatform(oldComp) - newPlatform, newOk := getAzureVirtualMachinePlatform(newComp) - platformChanged := oldPlatform != newPlatform - - if oldOk && platformChanged { - klog.V(4).InfoS("Azure virtual machine invalidated", "name", oldComp.Name, "namespace", oldComp.Namespace, "platform", oldPlatform) - handler(oldPlatform) - } - - if newOk { - klog.V(4).InfoS("Azure virtual machine updated", "name", newComp.Name, "namespace", newComp.Namespace, "platform", newPlatform) - handler(newPlatform) - } - }, - DeleteFunc: func(obj interface{}) { - comp := obj.(*provisioningv1.AzureVirtualMachine) - if platform, ok := getAzureVirtualMachinePlatform(comp); ok { - klog.V(4).InfoS("Azure virtual machine deleted", "name", comp.Name, "namespace", comp.Namespace, "platform", platform) - handler(platform) + comp := obj.(R) + if platform, domain, ok := getPlatformAndDomain(comp); ok { + msg := fmt.Sprintf("%s deleted", resType) + klog.V(4).InfoS(msg, "name", comp.GetName(), "namespace", comp.GetNamespace(), "platform", platform, "domain", domain) + handler(platform, domain) } }, }) @@ -722,140 +530,3 @@ func getTenantPlatform(tenant *platformv1.Tenant) (platform string, ok bool) { return platform, true } - -func getAzureDbPlatform(azureDb *provisioningv1.AzureDatabase) (platform string, ok bool) { - platform = azureDb.Spec.PlatformRef - if len(platform) == 0 { - return platform, false - } - - return platform, true -} - -func getAzureManagedDbPlatform(azureManagedDb *provisioningv1.AzureManagedDatabase) (platform string, ok bool) { - platform = azureManagedDb.Spec.PlatformRef - if len(platform) == 0 { - return platform, false - } - - return platform, true -} - -func getHelmReleasePlatform(helmRelease *provisioningv1.HelmRelease) (platform string, ok bool) { - platform = helmRelease.Spec.PlatformRef - if len(platform) == 0 { - return platform, false - } - - return platform, true -} - -func getAzureVirtualMachinePlatform(azureVirtualMachine *provisioningv1.AzureVirtualMachine) (platform string, ok bool) { - platform = azureVirtualMachine.Spec.PlatformRef - if len(platform) == 0 { - return platform, false - } - - return platform, true -} - -func getAzureVirtualDesktopPlatform(azureVirtualDesktop *provisioningv1.AzureVirtualDesktop) (platform string, ok bool) { - platform = azureVirtualDesktop.Spec.PlatformRef - if len(platform) == 0 { - return platform, false - } - - return platform, true -} - -type ProvisioningResource interface { - *provisioningv1.AzureDatabase | *provisioningv1.AzureManagedDatabase | *provisioningv1.HelmRelease | *provisioningv1.AzureVirtualMachine | *provisioningv1.AzureVirtualDesktop - - GetProvisioningMeta() *provisioningv1.ProvisioningMeta - GetSpec() any -} - -func applyTenantOverrides[T ProvisioningResource](target T, tenantName string) error { - if target == nil { - return nil - } - - overrides := target.GetProvisioningMeta().TenantOverrides - - if overrides == nil { - return nil - } - - tenantOverridesJson, exists := overrides[tenantName] - if !exists { - return nil - } - - var tenantOverridesMap map[string]any - if err := json.Unmarshal(tenantOverridesJson.Raw, &tenantOverridesMap); err != nil { - return err - } - - targetSpec := target.GetSpec() - - targetSpecJsonBytes, err := json.Marshal(targetSpec) - if err != nil { - return err - } - - var targetSpecMap map[string]any - if err := json.Unmarshal(targetSpecJsonBytes, &targetSpecMap); err != nil { - return err - } - - if err := mergo.Merge(&targetSpecMap, tenantOverridesMap, mergo.WithOverride, mergo.WithTransformers(jsonTransformer{})); err != nil { - return err - } - - targetSpecJsonBytes, err = json.Marshal(targetSpecMap) - if err != nil { - return err - } - - if err := json.Unmarshal(targetSpecJsonBytes, targetSpec); err != nil { - return err - } - - return nil -} - -type jsonTransformer struct { -} - -func (t jsonTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { - if typ == reflect.TypeOf(apiextensionsv1.JSON{}) { - return func(dst, src reflect.Value) error { - if dst.CanSet() { - srcRaw := src.FieldByName("Raw").Bytes() - var srcMap map[string]interface{} - if err := json.Unmarshal(srcRaw, &srcMap); err != nil { - return err - } - - dstRaw := dst.FieldByName("Raw").Bytes() - var dstMap map[string]interface{} - if err := json.Unmarshal(dstRaw, &dstMap); err != nil { - return err - } - - if err := mergo.Merge(&dstMap, srcMap, mergo.WithOverride, mergo.WithTransformers(jsonTransformer{})); err != nil { - return err - } - - dstRaw, err := json.Marshal(dstMap) - if err != nil { - return err - } - - dst.FieldByName("Raw").SetBytes(dstRaw) - } - return nil - } - } - return nil -} diff --git a/internal/controllers/provisioning/provisioning_controller_test.go b/internal/controllers/provisioning/provisioning_controller_test.go index 0f2c03d..506ef13 100644 --- a/internal/controllers/provisioning/provisioning_controller_test.go +++ b/internal/controllers/provisioning/provisioning_controller_test.go @@ -23,10 +23,11 @@ import ( func TestProvisioningController_processNextWorkItem(t *testing.T) { t.Run("add three tenants", func(t *testing.T) { + domain := "my-domain" objects := []runtime.Object{ - newTenant("dev1", "dev"), - newTenant("dev2", "dev"), - newTenant("dev3", "qa"), + newTenant("dev1", "dev", domain), + newTenant("dev2", "dev", domain), + newTenant("dev3", "qa", domain), } c, outputs, msgChan := runControllerWithDefaultFakes(objects) @@ -85,17 +86,18 @@ func TestProvisioningController_processNextWorkItem(t *testing.T) { }) - t.Run("ignores same tenant updates while an update for same tenant in progress", func(t *testing.T) { + t.Run("ignores same tenant-domain updates while an update for same tenant-domain in progress", func(t *testing.T) { //Arrange - tenant := newTenant("dev1", "dev") + domain := "my-domain" + tenant := newTenant("dev1", "dev", domain) objects := []runtime.Object{tenant} wg := &sync.WaitGroup{} wg.Add(1) var outputs []provisionerResult - infraCreator := func(platform string, tenant *platformv1.Tenant, infra *provisioners.InfrastructureManifests) provisioners.ProvisioningResult { - outputs = append(outputs, provisionerResult{platform, tenant, infra}) + infraCreator := func(platform string, tenant *platformv1.Tenant, domain string, infra *provisioners.InfrastructureManifests) provisioners.ProvisioningResult { + outputs = append(outputs, provisionerResult{platform, tenant, domain, infra}) wg.Wait() //wait for other tenant updates return provisioners.ProvisioningResult{} } @@ -128,10 +130,11 @@ func TestProvisioningController_processNextWorkItem(t *testing.T) { }) t.Run("add one tenant and one database same platform", func(t *testing.T) { + domain := "my-domain" objects := []runtime.Object{ - newTenant("dev1", "dev"), - newAzureDb("db1", "dev"), - newAzureManagedDb("db1", "dev"), + newTenant("dev1", "dev", domain), + newAzureDb("db1", "dev", domain), + newAzureManagedDb("db1", "dev", domain), } c, outputs, msgChan := runControllerWithDefaultFakes(objects) @@ -168,10 +171,11 @@ func TestProvisioningController_processNextWorkItem(t *testing.T) { }) t.Run("add one tenant and one database different platforms", func(t *testing.T) { + domain := "my-domain" objects := []runtime.Object{ - newTenant("dev1", "dev"), - newAzureDb("db1", "dev2"), - newAzureManagedDb("db1", "dev2"), + newTenant("dev1", "dev", domain), + newAzureDb("db1", "dev2", domain), + newAzureManagedDb("db1", "dev2", domain), } c, outputs, msgChan := runControllerWithDefaultFakes(objects) @@ -208,9 +212,9 @@ func TestProvisioningController_processNextWorkItem(t *testing.T) { }) t.Run("skip tenant resource provisioning", func(t *testing.T) { - - tenant := newTenant("dev1", "dev") - azureDb := newAzureDb("db1", "dev") + domain := "my-domain" + tenant := newTenant("dev1", "dev", domain) + azureDb := newAzureDb("db1", "dev", domain) azureDb.ObjectMeta.Labels = map[string]string{ "provisioning.totalsoft.ro/skip-tenant-dev1": "true", } @@ -243,9 +247,8 @@ func TestProvisioningController_processNextWorkItem(t *testing.T) { }) t.Run("filter resource by Domain", func(t *testing.T) { - - tenant := newTenantWithService("dev1", "dev", "p1") - azureDb := newAzureDbWithService("db1", "dev", "p2") + tenant := newTenant("dev1", "dev", "p1") + azureDb := newAzureDb("db1", "dev", "p2") objects := []runtime.Object{ tenant, azureDb, @@ -301,7 +304,7 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides(&db, tenantName) + err := applyTenantOverrides([]*provisioningv1.AzureManagedDatabase{&db}, tenantName) if err != nil { t.Error(err) } @@ -339,7 +342,7 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides(&hr, tenantName) + err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) if err != nil { t.Error(err) } @@ -381,7 +384,7 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides(&avd, tenantName) + err := applyTenantOverrides([]*provisioningv1.AzureVirtualDesktop{&avd}, tenantName) if err != nil { t.Error(err) } @@ -428,7 +431,7 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides(&hr, tenantName) + err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) if err != nil { t.Error(err) } @@ -473,7 +476,7 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { }, } - err := applyTenantOverrides(&hr, tenantName) + err := applyTenantOverrides([]*provisioningv1.HelmRelease{&hr}, tenantName) if err != nil { t.Error(err) } @@ -485,7 +488,7 @@ func TestProvisioningController_applyTenantOverrides(t *testing.T) { } -func newTenantWithService(name, platform, domain string) *platformv1.Tenant { +func newTenant(name, platform string, domains ...string) *platformv1.Tenant { return &platformv1.Tenant{ TypeMeta: metav1.TypeMeta{APIVersion: provisioningv1.SchemeGroupVersion.String()}, ObjectMeta: metav1.ObjectMeta{ @@ -495,16 +498,12 @@ func newTenantWithService(name, platform, domain string) *platformv1.Tenant { Spec: platformv1.TenantSpec{ PlatformRef: platform, Description: name + " description", - DomainRefs: []string{domain}, + DomainRefs: domains, }, } } -func newTenant(name, platform string) *platformv1.Tenant { - return newTenantWithService(name, platform, "") -} - -func newAzureDbWithService(name, platform, domain string) *provisioningv1.AzureDatabase { +func newAzureDb(name, platform, domain string) *provisioningv1.AzureDatabase { return &provisioningv1.AzureDatabase{ TypeMeta: metav1.TypeMeta{APIVersion: provisioningv1.SchemeGroupVersion.String()}, ObjectMeta: metav1.ObjectMeta{ @@ -521,11 +520,7 @@ func newAzureDbWithService(name, platform, domain string) *provisioningv1.AzureD } } -func newAzureDb(name, platform string) *provisioningv1.AzureDatabase { - return newAzureDbWithService(name, platform, "") -} - -func newAzureManagedDb(dbName, platform string) *provisioningv1.AzureManagedDatabase { +func newAzureManagedDb(dbName, platform string, domain string) *provisioningv1.AzureManagedDatabase { return &provisioningv1.AzureManagedDatabase{ TypeMeta: metav1.TypeMeta{APIVersion: provisioningv1.SchemeGroupVersion.String()}, ObjectMeta: metav1.ObjectMeta{ @@ -535,6 +530,7 @@ func newAzureManagedDb(dbName, platform string) *provisioningv1.AzureManagedData Spec: provisioningv1.AzureManagedDatabaseSpec{ ProvisioningMeta: provisioningv1.ProvisioningMeta{ PlatformRef: platform, + DomainRef: domain, }, DbName: dbName, }, @@ -554,8 +550,8 @@ func runController(objects []runtime.Object, provisioner provisioners.CreateInfr func runControllerWithDefaultFakes(objects []runtime.Object) (*ProvisioningController, *[]provisionerResult, chan messagingMock.RcvMsg) { var outputs []provisionerResult - infraCreator := func(platform string, tenant *platformv1.Tenant, infra *provisioners.InfrastructureManifests) provisioners.ProvisioningResult { - outputs = append(outputs, provisionerResult{platform, tenant, infra}) + infraCreator := func(platform string, tenant *platformv1.Tenant, domain string, infra *provisioners.InfrastructureManifests) provisioners.ProvisioningResult { + outputs = append(outputs, provisionerResult{platform, tenant, domain, infra}) return provisioners.ProvisioningResult{} } @@ -572,5 +568,6 @@ func runControllerWithDefaultFakes(objects []runtime.Object) (*ProvisioningContr type provisionerResult struct { platform string tenant *platformv1.Tenant + domain string infra *provisioners.InfrastructureManifests } diff --git a/internal/controllers/provisioning/provisioning_types.go b/internal/controllers/provisioning/provisioning_types.go new file mode 100644 index 0000000..35b96cf --- /dev/null +++ b/internal/controllers/provisioning/provisioning_types.go @@ -0,0 +1,126 @@ +package provisioning + +import ( + "encoding/json" + "reflect" + + "dario.cat/mergo" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + provisioningv1 "totalsoft.ro/platform-controllers/pkg/apis/provisioning/v1alpha1" +) + +type ProvisioningResource interface { + *provisioningv1.AzureDatabase | *provisioningv1.AzureManagedDatabase | *provisioningv1.HelmRelease | *provisioningv1.AzureVirtualMachine | *provisioningv1.AzureVirtualDesktop + + GetProvisioningMeta() *provisioningv1.ProvisioningMeta + GetSpec() any + GetName() string + GetNamespace() string +} + +func getPlatformAndDomain[R ProvisioningResource](res R) (platform, domain string, ok bool) { + platform = res.GetProvisioningMeta().PlatformRef + domain = res.GetProvisioningMeta().DomainRef + if len(platform) == 0 || len(domain) == 0 { + return platform, domain, false + } + + return platform, domain, true +} + +func selectItemsInPlatformAndDomain[R ProvisioningResource](platform, domain string, source []R) []R { + result := []R{} + for _, res := range source { + if res.GetProvisioningMeta().PlatformRef == platform && res.GetProvisioningMeta().DomainRef == domain { + result = append(result, res) + } + } + return result +} + +func applyTenantOverrides[T ProvisioningResource](targets []T, tenantName string) error { + if targets == nil { + return nil + } + + for _, target := range targets { + overrides := target.GetProvisioningMeta().TenantOverrides + + if overrides == nil { + continue + } + + tenantOverridesJson, exists := overrides[tenantName] + if !exists { + continue + } + + var tenantOverridesMap map[string]any + if err := json.Unmarshal(tenantOverridesJson.Raw, &tenantOverridesMap); err != nil { + return err + } + + targetSpec := target.GetSpec() + + targetSpecJsonBytes, err := json.Marshal(targetSpec) + if err != nil { + return err + } + + var targetSpecMap map[string]any + if err := json.Unmarshal(targetSpecJsonBytes, &targetSpecMap); err != nil { + return err + } + + if err := mergo.Merge(&targetSpecMap, tenantOverridesMap, mergo.WithOverride, mergo.WithTransformers(jsonTransformer{})); err != nil { + return err + } + + targetSpecJsonBytes, err = json.Marshal(targetSpecMap) + if err != nil { + return err + } + + if err := json.Unmarshal(targetSpecJsonBytes, targetSpec); err != nil { + return err + } + } + + return nil +} + +type jsonTransformer struct { +} + +func (t jsonTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { + if typ == reflect.TypeOf(apiextensionsv1.JSON{}) { + return func(dst, src reflect.Value) error { + if dst.CanSet() { + srcRaw := src.FieldByName("Raw").Bytes() + var srcMap map[string]interface{} + if err := json.Unmarshal(srcRaw, &srcMap); err != nil { + return err + } + + dstRaw := dst.FieldByName("Raw").Bytes() + var dstMap map[string]interface{} + if err := json.Unmarshal(dstRaw, &dstMap); err != nil { + return err + } + + if err := mergo.Merge(&dstMap, srcMap, mergo.WithOverride, mergo.WithTransformers(jsonTransformer{})); err != nil { + return err + } + + dstRaw, err := json.Marshal(dstMap) + if err != nil { + return err + } + + dst.FieldByName("Raw").SetBytes(dstRaw) + } + return nil + } + } + return nil +} diff --git a/pkg/apis/provisioning/v1alpha1/commonTypes.go b/pkg/apis/provisioning/v1alpha1/commonTypes.go index 71129e3..4308286 100644 --- a/pkg/apis/provisioning/v1alpha1/commonTypes.go +++ b/pkg/apis/provisioning/v1alpha1/commonTypes.go @@ -29,7 +29,3 @@ type ProvisioningMeta struct { // +optional TenantOverrides map[string]*apiextensionsv1.JSON `json:"tenantOverrides,omitempty"` } - -func (meta *ProvisioningMeta) GetTenantOverrides() map[string]*apiextensionsv1.JSON { - return meta.TenantOverrides -}