diff --git a/context.go b/context.go index 48a0345..9e00d02 100644 --- a/context.go +++ b/context.go @@ -1,6 +1,7 @@ package di import ( + stdcontext "context" "errors" "fmt" "reflect" @@ -10,6 +11,91 @@ type Context struct { path map[string]int holdersByType map[reflect.Type][]*holder holdersByName map[string]*holder + initialized bool + shutdown bool +} + +func (ctx *Context) Initialize() { + err := ctx.InitializeOrErr() + if err != nil { + panic(err) + } +} + +func (ctx *Context) InitializeOrErr() *Error { + if ctx.initialized { + return newLifecycleError("context already initialized") + } + if ctx.shutdown { + return newLifecycleError("context already shutdown") + } + deps := ctx.GetAllByType(new(Initializable)) + for _, dep := range deps { + initializable := dep.(Initializable) + err := func() (suberr error) { + defer func() { + if r := recover(); r != nil { + switch x := r.(type) { + case string: + suberr = errors.New(x) + case error: + suberr = x + default: + suberr = errors.New("shutdown panic") + } + } + }() + initializable.Initialize() + return nil + }() + if err != nil { + depType := reflect.TypeOf(dep) + return newInitializationError(&depType, err) + } + } + ctx.initialized = true + return nil +} + +func (ctx *Context) Shutdown(context stdcontext.Context) { + err := ctx.ShutdownOrErr(context) + if err != nil { + panic(err) + } +} + +func (ctx *Context) ShutdownOrErr(context stdcontext.Context) *Error { + if ctx.shutdown { + return newLifecycleError("context already shutdown") + } + rtype := reflect.TypeOf(new(Shutdownable)).Elem() + holders := ctx.holdersByType[rtype] + for _, holder := range holders { + if holder.created { + shutdownable := holder.instance.(Shutdownable) + err := func() (suberr error) { + defer func() { + if r := recover(); r != nil { + switch x := r.(type) { + case string: + suberr = errors.New(x) + case error: + suberr = x + default: + suberr = errors.New("shutdown panic") + } + } + }() + shutdownable.Shutdown(context) + return nil + }() + if err != nil { + return newShutdownError(&holder.providesType, err) + } + } + } + ctx.shutdown = true + return nil } func (ctx *Context) GetNamed(name string) any { @@ -21,6 +107,9 @@ func (ctx *Context) GetNamed(name string) any { } func (ctx *Context) GetNamedOrErr(name string) (any, *Error) { + if ctx.shutdown { + return nil, newLifecycleError("context already shutdown") + } holder := ctx.holdersByName[name] if holder == nil { return empty[any](), newMissingDependencyError(&name, nil) @@ -68,6 +157,9 @@ func (ctx *Context) GetAllByTypeOrErr(atype any) ([]any, *Error) { } func (ctx *Context) getByRType(rtype reflect.Type) (any, *Error) { + if ctx.shutdown { + return nil, newLifecycleError("context already shutdown") + } holders := ctx.holdersByType[rtype] if holders == nil { return empty[any](), newMissingDependencyError(nil, &rtype) @@ -91,6 +183,9 @@ func (ctx *Context) getByRType(rtype reflect.Type) (any, *Error) { } func (ctx *Context) getAllByRType(rtype reflect.Type) ([]any, *Error) { + if ctx.shutdown { + return nil, newLifecycleError("context already shutdown") + } holders := ctx.holdersByType[rtype] result := make([]any, 0) for _, holder := range holders { diff --git a/context_builder.go b/context_builder.go index eb7e78c..bb53a72 100644 --- a/context_builder.go +++ b/context_builder.go @@ -57,6 +57,18 @@ func (ctxb *ContextBuilder) addOrErr(ctor any, lazy bool) *Error { if err != nil { return err } + if hldr.providesType.Implements(initializableRType) { + err = ctxb.addHolderForType(hldr, initializableRType) + if err != nil { + return err + } + } + if hldr.providesType.Implements(shutdownableRType) { + err = ctxb.addHolderForType(hldr, shutdownableRType) + if err != nil { + return err + } + } err = ctxb.addHolderForType(hldr, hldr.providesType) if err != nil { return err @@ -187,7 +199,7 @@ func (ctxb *ContextBuilder) addHolderForType(hldr *holder, rtype reflect.Type) * } func (ctxb *ContextBuilder) addHolderForName(hldr *holder, name string) *Error { - if ctxb.holdersByName[name] != nil { + if ctxb.holdersByName[name] != nil && ctxb.holdersByName[name] != hldr { return newDuplicatedNameError(name) } ctxb.holdersByName[name] = hldr diff --git a/errors.go b/errors.go index 6c6c4ed..8f1bbe3 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,9 @@ const ( ErrTypeInvalidType ErrTypeInvalidConstructor ErrTypeCyclicDependency + ErrTypeDependencyInitialization + ErrTypeDependencyShutdown + ErrTypeLifecycle ) type Error struct { @@ -50,6 +53,14 @@ func (e *Error) RootCause() error { return e.cause } +func newLifecycleError(cause string) *Error { + msg := fmt.Sprintf("context lifecycle error: %s", cause) + return &Error{ + errType: ErrTypeLifecycle, + message: msg, + } +} + func newDuplicatedRegistrationError() *Error { return &Error{ errType: ErrTypeDuplicatedRegistration, @@ -82,6 +93,24 @@ func newInvalidTypeError(objName *string, objType reflect.Type, expectedType ref } } +func newInitializationError(objType *reflect.Type, cause error) *Error { + msg := fmt.Sprintf("could not initialize dependency: %s, cause:\n%s", descriptor(nil, objType), cause) + return &Error{ + errType: ErrTypeDependencyInitialization, + message: msg, + cause: cause, + } +} + +func newShutdownError(objType *reflect.Type, cause error) *Error { + msg := fmt.Sprintf("could not shutdown dependency: %s, cause:\n%s", descriptor(nil, objType), cause) + return &Error{ + errType: ErrTypeDependencyShutdown, + message: msg, + cause: cause, + } +} + func newCyclicDependencyError(path []string) *Error { msg := "" for _, d := range path { diff --git a/lifecycle.go b/lifecycle.go new file mode 100644 index 0000000..debe8d7 --- /dev/null +++ b/lifecycle.go @@ -0,0 +1,21 @@ +package di + +import ( + stdcontext "context" + "reflect" +) + +type Shutdownable interface { + Shutdown(context stdcontext.Context) +} + +type Initializable interface { + Initialize() +} + +var ( + initializableType = new(Initializable) + initializableRType = reflect.TypeOf(initializableType).Elem() + shutdownableType = new(Shutdownable) + shutdownableRType = reflect.TypeOf(shutdownableType).Elem() +) diff --git a/test/lifecycle_test.go b/test/lifecycle_test.go new file mode 100644 index 0000000..4fdb599 --- /dev/null +++ b/test/lifecycle_test.go @@ -0,0 +1,140 @@ +package di_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + stdcontext "context" + + di "github.com/coditory/go-di" +) + +type CtxAwareFoo struct { + initialized int + shutdown int + errOnInitialize bool + errOnShutdown bool +} + +func (f *CtxAwareFoo) Initialize() { + if f.errOnInitialize { + panic(errSimulated) + } + f.initialized++ +} + +func (f *CtxAwareFoo) Shutdown(context stdcontext.Context) { + if f.errOnShutdown { + panic(errSimulated) + } + f.shutdown++ +} + +type LifecycleSuite struct { + suite.Suite +} + +func (suite *LifecycleSuite) TestDependencyInit() { + foo1 := CtxAwareFoo{} + foo2 := CtxAwareFoo{} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctxb.Add(&foo2) + ctx := ctxb.Build() + suite.Equal(foo1.initialized, 0) + suite.Equal(foo2.initialized, 0) + ctx.Initialize() + suite.Equal(foo1.initialized, 1) + suite.Equal(foo2.initialized, 1) +} + +func (suite *LifecycleSuite) TestDuplicatedInitError() { + foo1 := CtxAwareFoo{} + foo2 := CtxAwareFoo{} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctxb.Add(&foo2) + ctx := ctxb.Build() + ctx.Initialize() + err := ctx.InitializeOrErr() + suite.Equal("context lifecycle error: context already initialized", err.Error()) + suite.Equal(err.ErrType(), di.ErrTypeLifecycle) +} + +func (suite *LifecycleSuite) TestDependencyInitError() { + foo1 := CtxAwareFoo{} + foo2 := CtxAwareFoo{errOnInitialize: true} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctxb.Add(&foo2) + ctx := ctxb.Build() + err := ctx.InitializeOrErr() + suite.Equal("could not initialize dependency: *di_test.CtxAwareFoo, cause:\nsimulated", err.Error()) + suite.Equal(err.ErrType(), di.ErrTypeDependencyInitialization) +} + +func (suite *LifecycleSuite) TestDependencyShutdown() { + foo1 := CtxAwareFoo{} + foo2 := CtxAwareFoo{} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctxb.Add(&foo2) + ctx := ctxb.Build() + suite.Equal(foo1.shutdown, 0) + suite.Equal(foo2.shutdown, 0) + ctx.Shutdown(stdcontext.TODO()) + suite.Equal(foo1.shutdown, 1) + suite.Equal(foo2.shutdown, 1) +} + +func (suite *LifecycleSuite) TestDuplicatedShutdownError() { + foo1 := CtxAwareFoo{} + foo2 := CtxAwareFoo{} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctxb.Add(&foo2) + ctx := ctxb.Build() + ctx.Shutdown(stdcontext.TODO()) + err := ctx.ShutdownOrErr(stdcontext.TODO()) + suite.Equal("context lifecycle error: context already shutdown", err.Error()) + suite.Equal(err.ErrType(), di.ErrTypeLifecycle) +} + +func (suite *LifecycleSuite) TestDependencyShutdownError() { + foo1 := CtxAwareFoo{} + foo2 := CtxAwareFoo{errOnShutdown: true} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctxb.Add(&foo2) + ctx := ctxb.Build() + err := ctx.ShutdownOrErr(stdcontext.TODO()) + suite.Equal("could not shutdown dependency: *di_test.CtxAwareFoo, cause:\nsimulated", err.Error()) + suite.Equal(err.ErrType(), di.ErrTypeDependencyShutdown) +} + +func (suite *LifecycleSuite) TestInitAfterShutdownError() { + foo1 := CtxAwareFoo{} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctx := ctxb.Build() + ctx.Shutdown(stdcontext.TODO()) + err := ctx.InitializeOrErr() + suite.Equal("context lifecycle error: context already shutdown", err.Error()) + suite.Equal(err.ErrType(), di.ErrTypeLifecycle) +} + +func (suite *LifecycleSuite) TestGettingDependencyAfterShutdown() { + foo1 := CtxAwareFoo{} + ctxb := di.NewContextBuilder() + ctxb.Add(&foo1) + ctx := ctxb.Build() + ctx.Shutdown(stdcontext.TODO()) + _, err := ctx.GetByTypeOrErr(new(Foo)) + suite.Equal("context lifecycle error: context already shutdown", err.Error()) + suite.Equal(err.ErrType(), di.ErrTypeLifecycle) +} + +func TestLifecycleSuite(t *testing.T) { + suite.Run(t, new(LifecycleSuite)) +} diff --git a/test/named_dependency_test.go b/test/named_dependency_test.go index 90157e4..eebeec2 100644 --- a/test/named_dependency_test.go +++ b/test/named_dependency_test.go @@ -12,7 +12,7 @@ type NamedDependencySuite struct { suite.Suite } -func (suite *NamedDependencySuite) TestGetNamedDependencyByName() { +func (suite *LifecycleSuite) TestGetNamedDependencyByName() { foo1 := Foo{id: "foo1"} foo2 := Foo{id: "foo2"} ctxb := di.NewContextBuilder() @@ -26,7 +26,7 @@ func (suite *NamedDependencySuite) TestGetNamedDependencyByName() { suite.Equal(&foo2, result) } -func (suite *NamedDependencySuite) TestGetNamedDependencyByType() { +func (suite *LifecycleSuite) TestGetNamedDependencyByType() { foo1 := Foo{id: "foo1"} foo2 := Foo{id: "foo2"} ctxb := di.NewContextBuilder() @@ -40,14 +40,28 @@ func (suite *NamedDependencySuite) TestGetNamedDependencyByType() { suite.Equal([]*Foo{&foo1, &foo2, &foo}, all) } -func (suite *NamedDependencySuite) TestErrorOnDuplicatedName() { +func (suite *LifecycleSuite) TestRegisterNamedDependencyTwiceForDifferentTypes() { + foo1 := Foo{id: "foo1"} + ctxb := di.NewContextBuilder() + ctxb.AddNamedAs("foo", new(any), &foo1) + ctxb.AddNamed("foo", &foo1) + ctx := ctxb.Build() + result := di.Get[*Foo](ctx) + suite.Equal(&foo1, result) + resultAny := di.Get[any](ctx) + suite.Equal(&foo1, resultAny) + all := di.GetAll[*Foo](ctx) + suite.Equal([]*Foo{&foo1}, all) +} + +func (suite *LifecycleSuite) TestErrorOnDuplicatedName() { ctxb := di.NewContextBuilder() ctxb.AddNamed("foo", &Foo{id: "foo1"}) err := ctxb.AddNamedOrErr("foo", &Foo{id: "foo2"}) suite.Equal("duplicated dependency name: foo", err.Error()) } -func (suite *NamedDependencySuite) TestErrorOnInvalidType() { +func (suite *LifecycleSuite) TestErrorOnInvalidType() { ctxb := di.NewContextBuilder() ctxb.AddNamed("foo", &foo) ctx := ctxb.Build() @@ -57,5 +71,5 @@ func (suite *NamedDependencySuite) TestErrorOnInvalidType() { } func TestNamedDependencySuite(t *testing.T) { - suite.Run(t, new(NamedDependencySuite)) + suite.Run(t, new(LifecycleSuite)) }