Skip to content

Commit

Permalink
use service lister instead of endpoints cache to get port from portName
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Wozniak <[email protected]>
  • Loading branch information
wozniakjan committed Oct 24, 2024
1 parent 36b9348 commit 17de409
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 38 deletions.
8 changes: 8 additions & 0 deletions config/interceptor/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ rules:
- get
- list
- watch
- apiGroups:
- ""
resources:
- services
verbs:
- get
- list
- watch
- apiGroups:
- http.keda.sh
resources:
Expand Down
17 changes: 13 additions & 4 deletions interceptor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/exp/maps"
"golang.org/x/sync/errgroup"
k8sinformers "k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
v1 "k8s.io/client-go/listers/core/v1"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

Expand All @@ -42,6 +44,7 @@ var (

// +kubebuilder:rbac:groups=http.keda.sh,resources=httpscaledobjects,verbs=get;list;watch
// +kubebuilder:rbac:groups="",resources=endpoints,verbs=get;list;watch
// +kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch

func main() {
timeoutCfg := config.MustParseTimeouts()
Expand Down Expand Up @@ -110,6 +113,8 @@ func main() {
setupLog.Error(err, "fetching routing table")
os.Exit(1)
}
k8sSharedInformerFactory := k8sinformers.NewSharedInformerFactory(cl, servingCfg.ConfigMapCacheRsyncPeriod)
svcLister := k8sSharedInformerFactory.Core().V1().Services().Lister()

setupLog.Info("Interceptor starting")

Expand All @@ -123,6 +128,7 @@ func main() {
setupLog.Info("starting the endpoints cache")

endpointsCache.Start(ctx)
k8sSharedInformerFactory.Start(ctx.Done())
return nil
})

Expand Down Expand Up @@ -173,10 +179,11 @@ func main() {
eg.Go(func() error {
proxyTLSConfig := map[string]string{"certificatePath": servingCfg.TLSCertPath, "keyPath": servingCfg.TLSKeyPath, "certstorePaths": servingCfg.TLSCertStorePaths}
proxyTLSPort := servingCfg.TLSPort
k8sSharedInformerFactory.WaitForCacheSync(ctx.Done())

setupLog.Info("starting the proxy server with TLS enabled", "port", proxyTLSPort)

if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, endpointsCache, timeoutCfg, proxyTLSPort, proxyTLSEnabled, proxyTLSConfig); !util.IsIgnoredErr(err) {
if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, svcLister, timeoutCfg, proxyTLSPort, proxyTLSEnabled, proxyTLSConfig); !util.IsIgnoredErr(err) {
setupLog.Error(err, "tls proxy server failed")
return err
}
Expand All @@ -186,9 +193,11 @@ func main() {

// start a proxy server without TLS.
eg.Go(func() error {
k8sSharedInformerFactory.WaitForCacheSync(ctx.Done())
setupLog.Info("starting the proxy server with TLS disabled", "port", proxyPort)

if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, endpointsCache, timeoutCfg, proxyPort, false, nil); !util.IsIgnoredErr(err) {
k8sSharedInformerFactory.WaitForCacheSync(ctx.Done())
if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, svcLister, timeoutCfg, proxyPort, false, nil); !util.IsIgnoredErr(err) {
setupLog.Error(err, "proxy server failed")
return err
}
Expand Down Expand Up @@ -369,7 +378,7 @@ func runProxyServer(
q queue.Counter,
waitFunc forwardWaitFunc,
routingTable routing.Table,
endpointsCache k8s.EndpointsCache,
svcLister v1.ServiceLister,
timeouts *config.Timeouts,
port int,
tlsEnabled bool,
Expand Down Expand Up @@ -417,7 +426,7 @@ func runProxyServer(
routingTable,
probeHandler,
upstreamHandler,
endpointsCache,
svcLister,
tlsEnabled,
)
rootHandler = middleware.NewLogging(
Expand Down
12 changes: 6 additions & 6 deletions interceptor/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestRunProxyServerCountMiddleware(t *testing.T) {
// server
routingTable := routingtest.NewTable()
routingTable.Memory[host] = httpso
endpointsCache := k8s.NewFakeEndpointsCache()
_, svcLister := k8s.NewFakeServiceLister()

timeouts := &config.Timeouts{}
waiterCh := make(chan struct{})
Expand All @@ -78,7 +78,7 @@ func TestRunProxyServerCountMiddleware(t *testing.T) {
q,
waitFunc,
routingTable,
endpointsCache,
svcLister,
timeouts,
port,
false,
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestRunProxyServerWithTLSCountMiddleware(t *testing.T) {
// server
routingTable := routingtest.NewTable()
routingTable.Memory[host] = httpso
endpointsCache := k8s.NewFakeEndpointsCache()
_, svcLister := k8s.NewFakeServiceLister()

timeouts := &config.Timeouts{}
waiterCh := make(chan struct{})
Expand All @@ -212,7 +212,7 @@ func TestRunProxyServerWithTLSCountMiddleware(t *testing.T) {
q,
waitFunc,
routingTable,
endpointsCache,
svcLister,
timeouts,
port,
true,
Expand Down Expand Up @@ -343,7 +343,7 @@ func TestRunProxyServerWithMultipleCertsTLSCountMiddleware(t *testing.T) {
// server
routingTable := routingtest.NewTable()
routingTable.Memory[host] = httpso
endpointsCache := k8s.NewFakeEndpointsCache()
_, svcLister := k8s.NewFakeServiceLister()

timeouts := &config.Timeouts{}
waiterCh := make(chan struct{})
Expand All @@ -359,7 +359,7 @@ func TestRunProxyServerWithMultipleCertsTLSCountMiddleware(t *testing.T) {
q,
waitFunc,
routingTable,
endpointsCache,
svcLister,
timeouts,
port,
true,
Expand Down
25 changes: 12 additions & 13 deletions interceptor/middleware/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
"net/url"
"regexp"

v1 "k8s.io/client-go/listers/core/v1"

"github.com/kedacore/http-add-on/interceptor/handler"
httpv1alpha1 "github.com/kedacore/http-add-on/operator/apis/http/v1alpha1"
"github.com/kedacore/http-add-on/pkg/k8s"
"github.com/kedacore/http-add-on/pkg/routing"
"github.com/kedacore/http-add-on/pkg/util"
)
Expand All @@ -22,16 +23,16 @@ type Routing struct {
routingTable routing.Table
probeHandler http.Handler
upstreamHandler http.Handler
endpointsCache k8s.EndpointsCache
svcLister v1.ServiceLister
tlsEnabled bool
}

func NewRouting(routingTable routing.Table, probeHandler http.Handler, upstreamHandler http.Handler, endpointsCache k8s.EndpointsCache, tlsEnabled bool) *Routing {
func NewRouting(routingTable routing.Table, probeHandler http.Handler, upstreamHandler http.Handler, svcLister v1.ServiceLister, tlsEnabled bool) *Routing {
return &Routing{
routingTable: routingTable,
probeHandler: probeHandler,
upstreamHandler: upstreamHandler,
endpointsCache: endpointsCache,
svcLister: svcLister,
tlsEnabled: tlsEnabled,
}
}
Expand Down Expand Up @@ -72,20 +73,18 @@ func (rm *Routing) getPort(httpso *httpv1alpha1.HTTPScaledObject) (int32, error)
return httpso.Spec.ScaleTargetRef.Port, nil
}
if httpso.Spec.ScaleTargetRef.PortName == "" {
return 0, fmt.Errorf("must specify either port or portName")
return 0, fmt.Errorf(`must specify either "port" or "portName"`)
}
endpoints, err := rm.endpointsCache.Get(httpso.GetNamespace(), httpso.Spec.ScaleTargetRef.Service)
svc, err := rm.svcLister.Services(httpso.GetNamespace()).Get(httpso.Spec.ScaleTargetRef.Service)
if err != nil {
return 0, fmt.Errorf("failed to get Endpoints: %w", err)
return 0, fmt.Errorf("failed to get Service: %w", err)
}
for _, subset := range endpoints.Subsets {
for _, port := range subset.Ports {
if port.Name == httpso.Spec.ScaleTargetRef.PortName {
return port.Port, nil
}
for _, port := range svc.Spec.Ports {
if port.Name == httpso.Spec.ScaleTargetRef.PortName {
return port.Port, nil
}
}
return 0, fmt.Errorf("portName %s not found in Endpoints", httpso.Spec.ScaleTargetRef.PortName)
return 0, fmt.Errorf("portName %q not found in Service", httpso.Spec.ScaleTargetRef.PortName)
}

func (rm *Routing) streamFromHTTPSO(httpso *httpv1alpha1.HTTPScaledObject) (*url.URL, error) {
Expand Down
32 changes: 18 additions & 14 deletions interceptor/middleware/routing_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package middleware

import (
"context"
"net/http"
"net/http/httptest"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
listersv1 "k8s.io/client-go/listers/core/v1"

httpv1alpha1 "github.com/kedacore/http-add-on/operator/apis/http/v1alpha1"
"github.com/kedacore/http-add-on/pkg/k8s"
Expand All @@ -25,9 +28,9 @@ var _ = Describe("RoutingMiddleware", func() {
emptyHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
probeHandler.Handle("/probe", emptyHandler)
upstreamHandler.Handle("/upstream", emptyHandler)
endpointsCache := k8s.NewFakeEndpointsCache()
_, svcLister := k8s.NewFakeServiceLister()

rm := NewRouting(routingTable, probeHandler, upstreamHandler, endpointsCache, false)
rm := NewRouting(routingTable, probeHandler, upstreamHandler, svcLister, false)
Expect(rm).NotTo(BeNil())
Expect(rm.routingTable).To(Equal(routingTable))
Expect(rm.probeHandler).To(Equal(probeHandler))
Expand All @@ -44,7 +47,8 @@ var _ = Describe("RoutingMiddleware", func() {
var (
upstreamHandler *http.ServeMux
probeHandler *http.ServeMux
endpointsCache *k8s.FakeEndpointsCache
cl *fake.Clientset
svcLister listersv1.ServiceLister
routingTable *routingtest.Table
routingMiddleware *Routing
w *httptest.ResponseRecorder
Expand Down Expand Up @@ -76,18 +80,16 @@ var _ = Describe("RoutingMiddleware", func() {
},
},
}
endpoints = corev1.Endpoints{
svc = &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: "keda-svc",
Namespace: "default",
},
Subsets: []corev1.EndpointSubset{
{
Ports: []corev1.EndpointPort{
{
Name: "http",
Port: 80,
},
Spec: corev1.ServiceSpec{
Ports: []corev1.ServicePort{
{
Name: "http",
Port: 80,
},
},
},
Expand All @@ -98,8 +100,8 @@ var _ = Describe("RoutingMiddleware", func() {
upstreamHandler = http.NewServeMux()
probeHandler = http.NewServeMux()
routingTable = routingtest.NewTable()
endpointsCache = k8s.NewFakeEndpointsCache()
routingMiddleware = NewRouting(routingTable, probeHandler, upstreamHandler, endpointsCache, false)
cl, svcLister = k8s.NewFakeServiceLister()
routingMiddleware = NewRouting(routingTable, probeHandler, upstreamHandler, svcLister, false)

w = httptest.NewRecorder()

Expand Down Expand Up @@ -141,7 +143,9 @@ var _ = Describe("RoutingMiddleware", func() {

When("route is found with portName", func() {
It("routes to the upstream handler", func() {
endpointsCache.Set(endpoints)
_, err := cl.CoreV1().Services(svc.Namespace).Create(context.Background(), svc, metav1.CreateOptions{})
Expect(err).NotTo(HaveOccurred())
Eventually(func() error { _, err := svcLister.Services(svc.Namespace).Get(svc.Name); return err }).Should(Succeed())
var (
sc = http.StatusTeapot
st = http.StatusText(sc)
Expand Down
3 changes: 2 additions & 1 deletion interceptor/proxy_handlers_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ func newHarness(
},
)

_, svcLister := k8s.NewFakeServiceLister()
endpCache := k8s.NewFakeEndpointsCache()
waitFunc := newWorkloadReplicasForwardWaitFunc(
logr.Discard(),
Expand Down Expand Up @@ -308,7 +309,7 @@ func newHarness(
respHeaderTimeout: time.Second,
},
&tls.Config{}),
endpCache,
svcLister,
false,
)

Expand Down
14 changes: 14 additions & 0 deletions pkg/k8s/endpoints_cache_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes/fake"
listersv1 "k8s.io/client-go/listers/core/v1"
)

// FakeEndpointsCache is a fake implementation of
Expand Down Expand Up @@ -121,3 +124,14 @@ func (f *FakeEndpointsCache) SetSubsets(ns, name string, num int) error {
func key(ns, name string) string {
return fmt.Sprintf("%s/%s", ns, name)
}

// NewFakeServiceLister returns a fake implementation of a ServiceLister
func NewFakeServiceLister() (*fake.Clientset, listersv1.ServiceLister) {
client := fake.NewSimpleClientset()
factory := informers.NewSharedInformerFactory(client, 0)
lister := factory.Core().V1().Services().Lister()
ch := make(chan struct{})
factory.Start(ch)
factory.WaitForCacheSync(ch)
return client, lister
}

0 comments on commit 17de409

Please sign in to comment.