From b08c34f6bd7d4018990412845379b7f3c6535047 Mon Sep 17 00:00:00 2001 From: Matthew Anderson <42154938+matoszz@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:36:16 -0700 Subject: [PATCH] add generic context utility --- contextx/contextx.go | 54 ++++++++++++++++++++ contextx/contextx_test.go | 102 ++++++++++++++++++++++++++++++++++++++ contextx/doc.go | 8 +++ 3 files changed, 164 insertions(+) create mode 100644 contextx/contextx.go create mode 100644 contextx/contextx_test.go create mode 100644 contextx/doc.go diff --git a/contextx/contextx.go b/contextx/contextx.go new file mode 100644 index 0000000..bce614b --- /dev/null +++ b/contextx/contextx.go @@ -0,0 +1,54 @@ +package contextx + +import ( + "context" +) + +// key is a unique type that we can use as a key in a context +type key[T any] struct{} + +// With returns a copy of parent that contains the given value which can be retrieved by calling From with the resulting context +// The function uses a generic key type to ensure that the stored value is type-safe and can be uniquely identified and retrieved without +// risk of key collisions +func With[T any](ctx context.Context, v T) context.Context { + return context.WithValue(ctx, key[T]{}, v) +} + +// From returns the value associated with the wanted type from the context +// It performs a type assertion to convert the value to the desired type T +// If the type assertion is successful, it returns the value and true +// If the type assertion fails, it returns the zero value of type T and false +func From[T any](ctx context.Context) (T, bool) { + v, ok := ctx.Value(key[T]{}).(T) + + return v, ok +} + +// MustFrom is similar to from, except that it panics if the type assertion fails / the value is not in the context +func MustFrom[T any](ctx context.Context) T { + return ctx.Value(key[T]{}).(T) +} + +// FromOr returns the value associated with the wanted type or the given default value if the type is not found +// This function is useful when you want to ensure that a value is always returned from the context, even if the +// context does not contain a value of the desired type. By providing a default value, you can avoid handling +// the case where the value is missing and ensure that your code has a fallback value to use +func FromOr[T any](ctx context.Context, def T) T { + v, ok := From[T](ctx) + if !ok { + return def + } + + return v +} + +// FromOrFunc returns the value associated with the wanted type or the result of the given function if the type is not found +// This function is useful when the default value is expensive to compute or when the default value depends on some runtime conditions +func FromOrFunc[T any](ctx context.Context, f func() T) T { + v, ok := From[T](ctx) + if !ok { + return f() + } + + return v +} diff --git a/contextx/contextx_test.go b/contextx/contextx_test.go new file mode 100644 index 0000000..ab38339 --- /dev/null +++ b/contextx/contextx_test.go @@ -0,0 +1,102 @@ +package contextx + +import ( + "context" + "reflect" + "testing" +) + +func TestNormalOperation(t *testing.T) { + ctx := context.Background() + ctx = With(ctx, 10) + + if MustFrom[int](ctx) != 10 { + t.FailNow() + } + + if _, ok := From[float64](ctx); ok { + t.FailNow() + } +} + +func TestIsolatedFromExplicitTypeReflection(t *testing.T) { + ctx := context.Background() + + ctx = With(ctx, 10) + + ctx = context.WithValue(ctx, reflect.TypeOf(20), 20) + + if MustFrom[int](ctx) != 10 { + t.FailNow() + } +} + +func TestPanicIfNoValue(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + MustFrom[int](context.Background()) +} + +type x interface { + a() +} + +type y struct{ v int } + +func (y) a() {} + +type z struct{ f func() } + +func (z z) a() { z.f() } + +func TestShouldWorkOnInterface(t *testing.T) { + var a x = y{10} + + ctx := context.Background() + ctx = With(ctx, a) + + b := MustFrom[x](ctx) + if b.(y).v != 10 { + t.FailNow() + } + + r := "" + a = z{func() { r = "hello" }} + + ctx = With(ctx, a) + + MustFrom[x](ctx).a() + + if r != "hello" { + t.FailNow() + } +} +func TestFromOr(t *testing.T) { + ctx := context.Background() + ctx = With(ctx, 10) + + if FromOr(ctx, 20) != 10 { + t.FailNow() + } + + if FromOr(context.Background(), 20) != 20 { + t.FailNow() + } +} + +func TestFromOrFunc(t *testing.T) { + ctx := context.Background() + ctx = With(ctx, 10) + + if FromOrFunc(ctx, func() int { return 20 }) != 10 { + t.FailNow() + } + + if FromOrFunc(context.Background(), func() int { return 20 }) != 20 { + t.FailNow() + } +} diff --git a/contextx/doc.go b/contextx/doc.go new file mode 100644 index 0000000..d2fe06b --- /dev/null +++ b/contextx/doc.go @@ -0,0 +1,8 @@ +// Package contextx is a helper package for managing context values +// Most **request-scoped data** is a singleton per request +// That is, it doesn't make sense for a request to carry around multiple loggers, users, traces +// you want to carry the _same one_ with you from function call to function call +// the way we've handled this historically is a separate context key per type you want to carry in the struct +// but with generics, instead of having to make a new zero-sized type for every struct +// we can just make a single generic type and use it for everything which is what this helper package is intended to do +package contextx