From a3d9785418229f85f7ff0a66e737f3655b08261b Mon Sep 17 00:00:00 2001
From: Aditya Thebe
Date: Tue, 15 Oct 2024 15:50:05 +0545
Subject: [PATCH] feat: set leader label and remove from other replicas on lead

When a replica starts leading it now sets the leader=true label on its
own pod and clears that label from every other pod that shares its
pod-template-hash. OnStoppedLeading no longer touches the label and just
calls the caller's onStoppedLead callback.

Register now builds the elector with leaderelection.NewLeaderElector and
returns an error instead of calling leaderelection.RunOrDie. The elector
runs in a goroutine under a cancellable context; a shutdown hook cancels
that context and sleeps for two seconds so the lease can be released.
Register blocks until ctx is done.

The hostname can be overridden with MC_HOSTNAME_OVERRIDE for local
testing.
---
 leader/election.go | 115 ++++++++++++++++++++++++++++++++++++---------
 1 file changed, 94 insertions(+), 21 deletions(-)

diff --git a/leader/election.go b/leader/election.go
index e0d1f0b7..87e42edb 100644
--- a/leader/election.go
+++ b/leader/election.go
@@ -9,14 +9,16 @@ import (
 	"strings"
 	"time"
 
-	"github.com/samber/lo"
 	"github.com/sethvargo/go-retry"
+	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	types "k8s.io/apimachinery/pkg/types"
 	"k8s.io/client-go/tools/leaderelection"
 	"k8s.io/client-go/tools/leaderelection/resourcelock"
 
+	"github.com/flanksource/commons/logger"
 	"github.com/flanksource/duty/context"
+	"github.com/flanksource/duty/shutdown"
 )
 
 var (
@@ -53,10 +55,17 @@ func init() {
 		log.Fatalf("failed to get hostname: %v", err)
 	}
 
+	// To test locally
+	if v, ok := os.LookupEnv("MC_HOSTNAME_OVERRIDE"); ok {
+		logger.Infof("hostname overridden by MC_HOSTNAME_OVERRIDE: %s", v)
+		hostname = v
+	}
+
 	if n, err := getPodNamespace(); err == nil {
 		podNamespace = n
 	}
 
+	// Not sure if this is a very reliable way to get the service name
 	service = strings.Split(hostname, "-")[0]
 }
 
@@ -66,7 +75,7 @@ func Register(
 	onLead func(ctx gocontext.Context),
 	onStoppedLead func(),
 	onNewLeader func(identity string),
-) {
+) error {
 	if namespace == "" {
 		namespace = podNamespace
 	}
@@ -84,7 +93,7 @@ func Register(
 		},
 	}
 
-	leaderelection.RunOrDie(ctx, leaderelection.LeaderElectionConfig{
+	electionConfig := leaderelection.LeaderElectionConfig{
 		Lock:            lock,
 		ReleaseOnCancel: true,
 		LeaseDuration:   ctx.Properties().Duration("leader.lease.duration", 30*time.Second),
@@ -92,13 +101,10 @@ func Register(
 		RetryPeriod:     5 * time.Second,
 		Callbacks: leaderelection.LeaderCallbacks{
 			OnStartedLeading: func(leadCtx gocontext.Context) {
-				updateLeaderLabel(ctx, true)
+				updateLeaderLabel(ctx)
 				onLead(leadCtx)
 			},
-			OnStoppedLeading: func() {
-				updateLeaderLabel(ctx, false)
-				onStoppedLead()
-			},
+			OnStoppedLeading: onStoppedLead,
 			OnNewLeader: func(identity string) {
 				if identity == hostname {
 					return
@@ -107,25 +113,92 @@ func Register(
 				onNewLeader(identity)
 			},
 		},
-	})
-}
+	}
 
-func updateLeaderLabel(ctx context.Context, set bool) {
-	payload := `{"metadata":{"labels":{"leader":"true"}}}`
-	if !set {
-		payload = `{"metadata":{"labels":{"leader": null}}}`
+	elector, err := leaderelection.NewLeaderElector(electionConfig)
+	if err != nil {
+		return err
 	}
 
+	leaderContext, cancel := gocontext.WithCancel(ctx)
+	shutdown.AddHook(func() {
+		cancel()
+
+		// give the elector some time to release the lease
+		time.Sleep(time.Second * 2)
+	})
+
+	go elector.Run(leaderContext)
+	<-ctx.Done()
+
+	return nil
+}
+
+// updateLeaderLabel sets leader:true label on the current pod
+// and also removes that label from all other replicas.
+func updateLeaderLabel(ctx context.Context) {
 	backoff := retry.WithMaxRetries(3, retry.NewExponential(time.Second))
 	err := retry.Do(ctx, backoff, func(_ctx gocontext.Context) error {
-		_, err := ctx.Kubernetes().CoreV1().Pods(ctx.GetNamespace()).Patch(ctx,
-			hostname,
-			types.MergePatchType,
-			[]byte(payload),
-			metav1.PatchOptions{})
-		return retry.RetryableError(err)
+		pods, err := getAllReplicas(ctx, hostname)
+		if err != nil {
+			return retry.RetryableError(fmt.Errorf("failed to get replicas: %w", err))
+		}
+
+		for _, pod := range pods.Items {
+			var payload string
+			if pod.Name == hostname {
+				ctx.Infof("adding leader metadata to pod: %s", pod.Name)
+				payload = `{"metadata":{"labels":{"leader":"true"}}}`
+			} else {
+				ctx.Infof("removing leader metadata from pod: %s", pod.Name)
+				payload = `{"metadata":{"labels":{"leader": null}}}`
+			}
+
+			_, err := ctx.Kubernetes().CoreV1().Pods(ctx.GetNamespace()).Patch(ctx,
+				pod.Name,
+				types.MergePatchType,
+				[]byte(payload),
+				metav1.PatchOptions{})
+			if err != nil {
+				return retry.RetryableError(err)
+			}
+		}
+
+		return nil
 	})
 	if err != nil {
-		ctx.Errorf("failed to %sset label", lo.Ternary(set, "", "un"))
+		ctx.Errorf("failed to set label: %v", err)
 	}
 }
+
+// getAllReplicas returns all the pods from its parent ReplicaSet
+func getAllReplicas(ctx context.Context, thisPod string) (*corev1.PodList, error) {
+	pod, err := ctx.Kubernetes().CoreV1().Pods(ctx.GetNamespace()).Get(ctx, thisPod, metav1.GetOptions{})
+	if err != nil {
+		return nil, err
+	}
+
+	// Get the ReplicaSet owner reference
+	var replicaSetName string
+	for _, ownerRef := range pod.OwnerReferences {
+		if ownerRef.Kind == "ReplicaSet" {
+			replicaSetName = ownerRef.Name
+			break
+		}
+	}
+
+	if replicaSetName == "" {
+		return nil, errors.New("this pod is not managed by a ReplicaSet")
+	}
+
+	// List all pods with the same ReplicaSet label
+	labelSelector := fmt.Sprintf("pod-template-hash=%s", pod.Labels["pod-template-hash"])
+	podList, err := ctx.Kubernetes().CoreV1().Pods(ctx.GetNamespace()).List(ctx, metav1.ListOptions{
+		LabelSelector: labelSelector,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to list pods with labelSelector(%s): %w", labelSelector, err)
+	}
+
+	return podList, nil
+}