Skip to content

Commit

Permalink
refactor(dslx): pass Maybe[T] to Func.Apply (#1382)
Browse files Browse the repository at this point in the history
  • Loading branch information
bassosimone authored Oct 25, 2023
1 parent 430f1c2 commit 227bea1
Show file tree
Hide file tree
Showing 21 changed files with 257 additions and 206 deletions.
8 changes: 4 additions & 4 deletions internal/dslx/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ type ResolvedAddresses struct {

// DNSLookupGetaddrinfo returns a function that resolves a domain name to
// IP addresses using libc's getaddrinfo function.
func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] {
return FuncAdapter[*DomainToResolve, *Maybe[*ResolvedAddresses]](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] {
func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *ResolvedAddresses] {
return Operation[*DomainToResolve, *ResolvedAddresses](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] {
// create trace
trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...)

Expand Down Expand Up @@ -115,8 +115,8 @@ func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAdd

// DNSLookupUDP returns a function that resolves a domain name to
// IP addresses using the given DNS-over-UDP resolver.
func DNSLookupUDP(rt Runtime, endpoint string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] {
return FuncAdapter[*DomainToResolve, *Maybe[*ResolvedAddresses]](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] {
func DNSLookupUDP(rt Runtime, endpoint string) Func[*DomainToResolve, *ResolvedAddresses] {
return Operation[*DomainToResolve, *ResolvedAddresses](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] {
// create trace
trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...)

Expand Down
12 changes: 6 additions & 6 deletions internal/dslx/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestGetaddrinfo(t *testing.T) {
)
ctx, cancel := context.WithCancel(context.Background())
cancel() // immediately cancel the lookup
res := f.Apply(ctx, domain)
res := f.Apply(ctx, NewMaybeWithValue(domain))
if res.Observations == nil || len(res.Observations) <= 0 {
t.Fatal("unexpected empty observations")
}
Expand All @@ -88,7 +88,7 @@ func TestGetaddrinfo(t *testing.T) {
},
})),
)
res := f.Apply(context.Background(), domain)
res := f.Apply(context.Background(), NewMaybeWithValue(domain))
if res.Observations == nil || len(res.Observations) <= 0 {
t.Fatal("unexpected empty observations")
}
Expand All @@ -115,7 +115,7 @@ func TestGetaddrinfo(t *testing.T) {
},
})),
)
res := f.Apply(context.Background(), domain)
res := f.Apply(context.Background(), NewMaybeWithValue(domain))
if res.Observations == nil || len(res.Observations) <= 0 {
t.Fatal("unexpected empty observations")
}
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestLookupUDP(t *testing.T) {
f := DNSLookupUDP(NewMinimalRuntime(model.DiscardLogger, time.Now()), "1.1.1.1:53")
ctx, cancel := context.WithCancel(context.Background())
cancel()
res := f.Apply(ctx, domain)
res := f.Apply(ctx, NewMaybeWithValue(domain))
if res.Observations == nil || len(res.Observations) <= 0 {
t.Fatal("unexpected empty observations")
}
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestLookupUDP(t *testing.T) {
})),
"1.1.1.1:53",
)
res := f.Apply(context.Background(), domain)
res := f.Apply(context.Background(), NewMaybeWithValue(domain))
if res.Observations == nil || len(res.Observations) <= 0 {
t.Fatal("unexpected empty observations")
}
Expand Down Expand Up @@ -219,7 +219,7 @@ func TestLookupUDP(t *testing.T) {
})),
"1.1.1.1:53",
)
res := f.Apply(context.Background(), domain)
res := f.Apply(context.Background(), NewMaybeWithValue(domain))
if res.Observations == nil || len(res.Observations) <= 0 {
t.Fatal("unexpected empty observations")
}
Expand Down
12 changes: 6 additions & 6 deletions internal/dslx/fxasync.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Parallelism int
func Map[A, B any](
ctx context.Context,
parallelism Parallelism,
fx Func[A, *Maybe[B]],
fx Func[A, B],
inputs <-chan A,
) <-chan *Maybe[B] {
// create channel for returning results
Expand All @@ -49,7 +49,7 @@ func Map[A, B any](
go func() {
defer wg.Done()
for a := range inputs {
r <- fx.Apply(ctx, a)
r <- fx.Apply(ctx, NewMaybeWithValue(a))
}
}()
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func Parallel[A, B any](
ctx context.Context,
parallelism Parallelism,
input A,
fn ...Func[A, *Maybe[B]],
fn ...Func[A, B],
) []*Maybe[B] {
c := ParallelAsync(ctx, parallelism, input, StreamList(fn...))
return Collect(c)
Expand All @@ -94,7 +94,7 @@ func ParallelAsync[A, B any](
ctx context.Context,
parallelism Parallelism,
input A,
funcs <-chan Func[A, *Maybe[B]],
funcs <-chan Func[A, B],
) <-chan *Maybe[B] {
// create channel for returning results
r := make(chan *Maybe[B])
Expand All @@ -109,7 +109,7 @@ func ParallelAsync[A, B any](
go func() {
defer wg.Done()
for fx := range funcs {
r <- fx.Apply(ctx, input)
r <- fx.Apply(ctx, NewMaybeWithValue(input))
}
}()
}
Expand All @@ -126,7 +126,7 @@ func ParallelAsync[A, B any](
// ApplyAsync is equivalent to calling Apply but returns a channel.
func ApplyAsync[A, B any](
ctx context.Context,
fx Func[A, *Maybe[B]],
fx Func[A, B],
input A,
) <-chan *Maybe[B] {
return Map(ctx, Parallelism(1), fx, StreamList(input))
Expand Down
11 changes: 7 additions & 4 deletions internal/dslx/fxasync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ import (
"context"
"sync"
"testing"

"github.com/ooni/probe-cli/v3/internal/runtimex"
)

func getFnWait(wg *sync.WaitGroup) Func[int, *Maybe[int]] {
func getFnWait(wg *sync.WaitGroup) Func[int, int] {
return &fnWait{wg}
}

type fnWait struct {
wg *sync.WaitGroup // set to n corresponding to the number of used goroutines
}

func (f *fnWait) Apply(ctx context.Context, i int) *Maybe[int] {
func (f *fnWait) Apply(ctx context.Context, i *Maybe[int]) *Maybe[int] {
runtimex.Assert(i.Error == nil, "did not expect to see an error here")
f.wg.Done()
f.wg.Wait() // continue when n goroutines have reached this point
return &Maybe[int]{State: i + 1}
return &Maybe[int]{State: i.State + 1}
}

/*
Expand Down Expand Up @@ -86,7 +89,7 @@ func TestParallel(t *testing.T) {
t.Run(name, func(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(tt.funcs)
funcs := []Func[int, *Maybe[int]]{}
funcs := []Func[int, int]{}
for i := 0; i < tt.funcs; i++ {
funcs = append(funcs, getFnWait(&wg))
}
Expand Down
69 changes: 41 additions & 28 deletions internal/dslx/fxcore.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,23 @@ import (

// Func is a function f: (context.Context, A) -> B.
type Func[A, B any] interface {
Apply(ctx context.Context, a A) B
Apply(ctx context.Context, a *Maybe[A]) *Maybe[B]
}

// FuncAdapter adapts a func to be a Func.
type FuncAdapter[A, B any] func(ctx context.Context, a A) B
// Operation adapts a golang function to behave like a Func.
type Operation[A, B any] func(ctx context.Context, a A) *Maybe[B]

// Apply implements Func.
func (fa FuncAdapter[A, B]) Apply(ctx context.Context, a A) B {
return fa(ctx, a)
func (op Operation[A, B]) Apply(ctx context.Context, a *Maybe[A]) *Maybe[B] {
if a.Error != nil {
return &Maybe[B]{
Error: a.Error,
Observations: a.Observations,
Operation: a.Operation,
State: *new(B), // zero value
}
}
return op(ctx, a.State)
}

// Maybe is the result of an operation implemented by this package
Expand All @@ -42,8 +50,18 @@ type Maybe[State any] struct {
State State
}

// NewMaybeWithValue constructs a Maybe containing the given value.
func NewMaybeWithValue[State any](value State) *Maybe[State] {
return &Maybe[State]{
Error: nil,
Observations: []*Observations{},
Operation: "",
State: value,
}
}

// Compose2 composes two operations such as [TCPConnect] and [TLSHandshake].
func Compose2[A, B, C any](f Func[A, *Maybe[B]], g Func[B, *Maybe[C]]) Func[A, *Maybe[C]] {
func Compose2[A, B, C any](f Func[A, B], g Func[B, C]) Func[A, C] {
return &compose2Func[A, B, C]{
f: f,
g: g,
Expand All @@ -52,14 +70,15 @@ func Compose2[A, B, C any](f Func[A, *Maybe[B]], g Func[B, *Maybe[C]]) Func[A, *

// compose2Func is the type returned by [Compose2].
type compose2Func[A, B, C any] struct {
f Func[A, *Maybe[B]]
g Func[B, *Maybe[C]]
f Func[A, B]
g Func[B, C]
}

// Apply implements Func
func (h *compose2Func[A, B, C]) Apply(ctx context.Context, a A) *Maybe[C] {
func (h *compose2Func[A, B, C]) Apply(ctx context.Context, a *Maybe[A]) *Maybe[C] {
mb := h.f.Apply(ctx, a)
runtimex.Assert(mb != nil, "h.f.Apply returned a nil pointer")

if mb.Error != nil {
return &Maybe[C]{
Error: mb.Error,
Expand All @@ -68,8 +87,10 @@ func (h *compose2Func[A, B, C]) Apply(ctx context.Context, a A) *Maybe[C] {
State: *new(C), // zero value
}
}
mc := h.g.Apply(ctx, mb.State)

mc := h.g.Apply(ctx, mb)
runtimex.Assert(mc != nil, "h.g.Apply returned a nil pointer")

op := mc.Operation
if op == "" { // propagate the previous operation name, if this operation has none
op = mb.Operation
Expand Down Expand Up @@ -99,24 +120,16 @@ func (c *Counter[T]) Value() int64 {
}

// Func returns a Func[T, *Maybe[T]] that updates the counter.
func (c *Counter[T]) Func() Func[T, *Maybe[T]] {
return &counterFunc[T]{c}
}

// counterFunc is the Func returned by CounterFunc.Func.
type counterFunc[T any] struct {
c *Counter[T]
}

// Apply implements Func.
func (c *counterFunc[T]) Apply(ctx context.Context, value T) *Maybe[T] {
c.c.n.Add(1)
return &Maybe[T]{
Error: nil,
Observations: nil,
Operation: "", // we cannot fail, so no need to store operation name
State: value,
}
func (c *Counter[T]) Func() Func[T, T] {
return Operation[T, T](func(ctx context.Context, value T) *Maybe[T] {
c.n.Add(1)
return &Maybe[T]{
Error: nil,
Observations: nil,
Operation: "", // we cannot fail, so no need to store operation name
State: value,
}
})
}

// FirstErrorExcludingBrokenIPv6Errors returns the first error and failed operation in a list of
Expand Down
51 changes: 43 additions & 8 deletions internal/dslx/fxcore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"context"
"errors"
"testing"
"time"

"github.com/ooni/probe-cli/v3/internal/mocks"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)

func getFn(err error, name string) Func[int, *Maybe[int]] {
func getFn(err error, name string) Func[int, int] {
return &fn{err: err, name: name}
}

Expand All @@ -18,10 +21,11 @@ type fn struct {
name string
}

func (f *fn) Apply(ctx context.Context, i int) *Maybe[int] {
func (f *fn) Apply(ctx context.Context, i *Maybe[int]) *Maybe[int] {
runtimex.Assert(i.Error == nil, "did not expect to see an error here")
return &Maybe[int]{
Error: f.err,
State: i + 1,
State: i.State + 1,
Observations: []*Observations{
{
NetworkEvents: []*model.ArchivalNetworkEvent{{Tags: []string{"apply"}}},
Expand All @@ -31,6 +35,37 @@ func (f *fn) Apply(ctx context.Context, i int) *Maybe[int] {
}
}

func TestStageAdapter(t *testing.T) {
t.Run("make sure that we handle a previous stage failure", func(t *testing.T) {
unet := &mocks.UnderlyingNetwork{
// explicitly empty so we crash if we try using underlying network functionality
}
netx := &netxlite.Netx{Underlying: unet}

// create runtime
rt := NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(netx))

// create measurement pipeline where we run DNS lookups
pipeline := DNSLookupGetaddrinfo(rt)

// create input that contains an error
input := &Maybe[*DomainToResolve]{
Error: errors.New("mocked error"),
Observations: []*Observations{},
Operation: "",
State: nil,
}

// run the pipeline
output := pipeline.Apply(context.Background(), input)

// make sure the output contains the same error as the input
if !errors.Is(output.Error, input.Error) {
t.Fatal("unexpected error")
}
})
}

/*
Test cases:
- Compose 2 functions:
Expand All @@ -53,7 +88,7 @@ func TestCompose2(t *testing.T) {
f1 := getFn(tt.err, "maybe fail")
f2 := getFn(nil, "succeed")
composit := Compose2(f1, f2)
r := composit.Apply(context.Background(), tt.input)
r := composit.Apply(context.Background(), NewMaybeWithValue(tt.input))
if r.Error != tt.err {
t.Fatalf("unexpected error")
}
Expand All @@ -73,7 +108,7 @@ func TestGen(t *testing.T) {
incFunc := getFn(nil, "succeed")
composit := Compose14(incFunc, incFunc, incFunc, incFunc, incFunc, incFunc, incFunc, incFunc,
incFunc, incFunc, incFunc, incFunc, incFunc, incFunc)
r := composit.Apply(context.Background(), 0)
r := composit.Apply(context.Background(), NewMaybeWithValue(0))
if r.Error != nil {
t.Fatalf("unexpected error: %s", r.Error)
}
Expand All @@ -91,8 +126,8 @@ func TestObservations(t *testing.T) {
fn1 := getFn(nil, "succeed")
fn2 := getFn(nil, "succeed")
composit := Compose2(fn1, fn2)
r1 := composit.Apply(context.Background(), 3)
r2 := composit.Apply(context.Background(), 42)
r1 := composit.Apply(context.Background(), NewMaybeWithValue(3))
r2 := composit.Apply(context.Background(), NewMaybeWithValue(42))
if len(r1.Observations) != 2 || len(r2.Observations) != 2 {
t.Fatalf("unexpected number of observations")
}
Expand Down Expand Up @@ -123,7 +158,7 @@ func TestCounter(t *testing.T) {
fn := getFn(tt.err, "maybe fail")
cnt := NewCounter[int]()
composit := Compose2(fn, cnt.Func())
r := composit.Apply(context.Background(), 42)
r := composit.Apply(context.Background(), NewMaybeWithValue(42))
cntVal := cnt.Value()
if cntVal != tt.expect {
t.Fatalf("unexpected counter value")
Expand Down
Loading

0 comments on commit 227bea1

Please sign in to comment.