diff --git a/controller.go b/controller.go index 81d4ac1..45cbf23 100644 --- a/controller.go +++ b/controller.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "reflect" "time" log "github.com/sirupsen/logrus" @@ -9,7 +10,7 @@ import ( "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/pkg/api/v1" + v1 "k8s.io/client-go/pkg/api/v1" "k8s.io/client-go/pkg/apis/apps/v1beta1" pv1beta1 "k8s.io/client-go/pkg/apis/policy/v1beta1" ) @@ -17,6 +18,9 @@ import ( const ( heritageLabel = "heritage" pdbController = "pdb-controller" + + deploymentOwnershipLabel = "pod-template-hash" + statefulsetOwnershipLabel = "statefulset.kubernetes.io/pod-name" ) var ( @@ -99,7 +103,7 @@ func (n *PDBController) addPDBs(namespace *v1.Namespace) error { return err } - addPDB := make([]interface{}, 0, len(deployments.Items)+len(statefulSets.Items)) + addPDB := make([]metav1.Object, 0, len(deployments.Items)+len(statefulSets.Items)) removePDB := make([]pv1beta1.PodDisruptionBudget, 0, len(deployments.Items)+len(statefulSets.Items)) nonReadTTL := time.Now().UTC().Add(-n.nonReadyTTL) @@ -111,7 +115,8 @@ func (n *PDBController) addPDBs(namespace *v1.Namespace) error { // ready, add one if deployment.Status.ReadyReplicas == *deployment.Spec.Replicas { if len(matchedPDBs) == 0 && *deployment.Spec.Replicas > 1 { - addPDB = append(addPDB, deployment) + obj := &deployment + addPDB = append(addPDB, obj) } } @@ -168,7 +173,8 @@ func (n *PDBController) addPDBs(namespace *v1.Namespace) error { // ready, add one if statefulSet.Status.ReadyReplicas == *statefulSet.Spec.Replicas { if len(matchedPDBs) == 0 && *statefulSet.Spec.Replicas > 1 { - addPDB = append(addPDB, statefulSet) + obj := &statefulSet + addPDB = append(addPDB, obj) } } @@ -215,50 +221,15 @@ func (n *PDBController) addPDBs(namespace *v1.Namespace) error { // add missing PDBs for _, resource := range addPDB { - maxUnavailable := intstr.FromInt(1) - pdb := &pv1beta1.PodDisruptionBudget{ - Spec: pv1beta1.PodDisruptionBudgetSpec{ - MaxUnavailable: &maxUnavailable, - }, - } + var pdb *pv1beta1.PodDisruptionBudget switch r := resource.(type) { - case v1beta1.Deployment: - if r.Labels == nil { - r.Labels = make(map[string]string) - } - labels := r.Labels - labels[heritageLabel] = pdbController - pdb.Name = r.Name - pdb.Namespace = r.Namespace - pdb.OwnerReferences = []metav1.OwnerReference{ - { - APIVersion: "apps/v1", - Kind: "Deployment", - Name: r.Name, - UID: r.UID, - }, - } - pdb.Labels = labels - pdb.Spec.Selector = r.Spec.Selector - case v1beta1.StatefulSet: - if r.Labels == nil { - r.Labels = make(map[string]string) - } - labels := r.Labels - labels[heritageLabel] = pdbController - pdb.Name = r.Name - pdb.Namespace = r.Namespace - pdb.OwnerReferences = []metav1.OwnerReference{ - { - APIVersion: "apps/v1", - Kind: "StatefulSet", - Name: r.Name, - UID: r.UID, - }, - } - pdb.Labels = labels - pdb.Spec.Selector = r.Spec.Selector + case *v1beta1.Deployment: + pdb = generatePDB(r.APIVersion, r.Kind, r, r.Spec.Selector, deploymentOwnershipLabel) + case *v1beta1.StatefulSet: + pdb = generatePDB(r.APIVersion, r.Kind, r, r.Spec.Selector, statefulsetOwnershipLabel) + default: + return fmt.Errorf("unknown type for %s/%s: %s", resource.GetNamespace(), resource.GetName(), reflect.TypeOf(r)) } if n.pdbNameSuffix != "" { @@ -298,6 +269,41 @@ func (n *PDBController) addPDBs(namespace *v1.Namespace) error { return nil } +func generatePDB(apiVersion, kind string, object metav1.Object, selector *metav1.LabelSelector, ownershipLabel string) *pv1beta1.PodDisruptionBudget { + maxUnavailable := intstr.FromInt(1) + pdb := &pv1beta1.PodDisruptionBudget{ + Spec: pv1beta1.PodDisruptionBudgetSpec{ + MaxUnavailable: &maxUnavailable, + }, + } + + pdb.Labels = object.GetLabels() + if pdb.Labels == nil { + pdb.Labels = make(map[string]string) + } + pdb.Labels[heritageLabel] = pdbController + + pdb.Name = object.GetName() + pdb.Namespace = object.GetNamespace() + pdb.OwnerReferences = []metav1.OwnerReference{ + { + APIVersion: apiVersion, + Kind: kind, + Name: object.GetName(), + UID: object.GetUID(), + }, + } + pdb.Spec.Selector = selector + if !hasOwnershipMatchExpression(pdb.Spec.Selector) { + pdb.Spec.Selector.MatchExpressions = append(pdb.Spec.Selector.MatchExpressions, metav1.LabelSelectorRequirement{ + Key: ownershipLabel, + Operator: metav1.LabelSelectorOpExists, + }) + } + + return pdb +} + // getPodsLastTransitionTime returns the latest transition time for the pod not // ready condition of all pods matched by the selector. func (n *PDBController) getPodsLastTransitionTime(namespace string, selector map[string]string) (time.Time, error) { @@ -331,9 +337,21 @@ func (n *PDBController) getPodsLastTransitionTime(namespace string, selector map return lastTransitionTime, nil } +func hasOwnershipMatchExpression(selector *metav1.LabelSelector) bool { + for _, expr := range selector.MatchExpressions { + if expr.Operator == metav1.LabelSelectorOpExists { + switch expr.Key { + case deploymentOwnershipLabel, statefulsetOwnershipLabel: + return true + } + } + } + return false +} + // pdbSpecValid returns true if the PDB spec is up-to-date func pdbSpecValid(pdb pv1beta1.PodDisruptionBudget) bool { - return pdb.Spec.MinAvailable == nil + return pdb.Spec.MinAvailable == nil && hasOwnershipMatchExpression(pdb.Spec.Selector) } // getPDBs gets matching PodDisruptionBudgets. diff --git a/controller_test.go b/controller_test.go index 6f452c4..1043456 100644 --- a/controller_test.go +++ b/controller_test.go @@ -8,7 +8,7 @@ import ( "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/fake" - "k8s.io/client-go/pkg/api/v1" + v1 "k8s.io/client-go/pkg/api/v1" "k8s.io/client-go/pkg/apis/apps/v1beta1" pv1beta1 "k8s.io/client-go/pkg/apis/policy/v1beta1" )