diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..4adf30b --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,63 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Build + run: go build -v ./... + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Test + run: go test -coverprofile=coverage.out ./... + + - name: Upload coverage report + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage.out + fail_ci_if_error: true + + lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Install golangci-lint + run: | + go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.64.5 + + - name: Run golangci-lint + run: $(go env GOPATH)/bin/golangci-lint run \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..f637590 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,152 @@ +linters-settings: + gosec: + excludes: + - G306 + - G115 + depguard: + # new configuration + rules: + logger: + deny: + # logging is allowed only by logutils.Log, + # logrus is allowed to use only in logutils package. + - pkg: "github.com/sirupsen/logrus" + desc: logging is allowed only by logutils.Log + dupl: + threshold: 100 + funlen: + lines: -1 # the number of lines (code + empty lines) is not a right metric and leads to code without empty line or one-liner. + statements: 75 + goconst: + min-len: 2 + min-occurrences: 3 + gocritic: + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + disabled-checks: + - dupImport # https://github.com/go-critic/go-critic/issues/845 + - ifElseChain + - octalLiteral + - whyNoLint + gocyclo: + min-complexity: 15 + gofmt: + rewrite-rules: + - pattern: 'interface{}' + replacement: 'any' + goimports: + local-prefixes: github.com/golangci/golangci-lint + + govet: + settings: + printf: + funcs: + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf + enable: + - nilness + - shadow + errorlint: + asserts: false + lll: + line-length: 140 + nolintlint: + allow-unused: false # report any unused nolint directives + require-explanation: false # don't require an explanation for nolint directives + require-specific: false # don't require nolint directives to be specific about which linter is being skipped + revive: + rules: + - name: unexported-return + disabled: true + - name: unused-parameter + +linters: + disable-all: true + enable: + - bodyclose + - depguard + - dogsled + - dupl + - errcheck + - errorlint + - funlen + - gocheckcompilerdirectives + - gochecknoinits + - goconst + - gocritic + - gocyclo + - gofmt + - goimports + - goprintffuncname + - gosec + - gosimple + - govet + - ineffassign + - lll + - nakedret + - noctx + - nolintlint + - revive + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - unused + - whitespace + + # don't enable: + # - asciicheck + # - scopelint + # - gochecknoglobals + # - gocognit + # - godot + # - godox + # - goerr113 + # - interfacer + # - maligned + # - nestif + # - prealloc + # - testpackage + # - wsl + +issues: + # Excluding configuration per-path, per-linter, per-text and per-source + exclude-rules: + - path: pkg/golinters/errcheck.go + text: "SA1019: errCfg.Exclude is deprecated: use ExcludeFunctions instead" + - path: pkg/commands/run.go + text: "SA1019: lsc.Errcheck.Exclude is deprecated: use ExcludeFunctions instead" + - path: pkg/commands/run.go + text: "SA1019: e.cfg.Run.Deadline is deprecated: Deadline exists for historical compatibility and should not be used." + + - path: pkg/golinters/gofumpt.go + text: "SA1019: settings.LangVersion is deprecated: use the global `run.go` instead." + - path: pkg/golinters/staticcheck_common.go + text: "SA1019: settings.GoVersion is deprecated: use the global `run.go` instead." + - path: pkg/lint/lintersdb/manager.go + text: "SA1019: (.+).(GoVersion|LangVersion) is deprecated: use the global `run.go` instead." + - path: pkg/golinters/unused.go + text: "rangeValCopy: each iteration copies 160 bytes \\(consider pointers or indexing\\)" + - path: test/(fix|linters)_test.go + text: "string `gocritic.go` has 3 occurrences, make it a constant" + + # Due to a change inside go-critic v0.10.0, some reports have been removed, + # but as we run analysis with the previous version of golangci-lint this leads to a paradoxical situation. + # This exclusion will be removed when the next version of golangci-lint (v1.56.0) will be released. + - path: pkg/golinters/nolintlint/nolintlint.go + text: "hugeParam: (i|b) is heavy \\(\\d+ bytes\\); consider passing it by pointer" + exclude-dirs: + - test/testdata_etc # test files + - internal/cache # extracted from Go code + - internal/renameio # extracted from Go code + - internal/robustio # extracted from Go code + +run: + timeout: 5m diff --git a/README.md b/README.md index e69de29..47f0ad7 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,52 @@ +# go-inject + +## Description + +`go-inject` is a dependency injection library that support contextual scope. + +## Example + +```go +package myapp + +import ( + "context" + + "github.com/illuin-tech/goinject" +) + +type key int +const myScopeKey key = 0 + +const MyScope = "MyScope" + +// define function to declare your own scope in context +func WithMyScopeEnabled(ctx context.Context) context.Context { + return goinject.WithContextualScopeEnabled(ctx, myScopeKey) +} + +func ShutdownMyContextScoped(ctx context.Context) { + goinject.ShutdownContextualScope(ctx, myScopeKey) +} + +// define injection modules +var Module = goinject.Module("myModule", + goinject.RegisterScope(MyScope, goinject.NewContextualScope(myScopeKey)), + goinject.Provide(func() string { + return "Hello world from scope" + }, goinject.In(MyScope)), +) + +func main() { + ctx := context.Background() + + // enable scope + ctx = WithMyScopeEnabled(ctx) + defer ShutdownMyContextScoped(ctx) + + injector, _ := goinject.NewInjector(Module) + _ = injector.Invoke(ctx, func(hello string) { + println(hello) + }) +} +``` \ No newline at end of file diff --git a/binding.go b/binding.go new file mode 100644 index 0000000..5fee497 --- /dev/null +++ b/binding.go @@ -0,0 +1,36 @@ +package goinject + +import ( + "context" + "fmt" + "reflect" +) + +// binding defines a type mapped to a more concrete type +type binding struct { + typeof reflect.Type + provider reflect.Value + providedType reflect.Type + annotatedWith string + scope string + destroyMethod func(value reflect.Value) +} + +func (b *binding) create(ctx context.Context, injector *Injector) (reflect.Value, error) { + res, err := injector.callFunctionWithArgumentInstance(ctx, b.provider) + if err != nil { + return reflect.Value{}, + fmt.Errorf("failed to call provider function for type %q: %w", b.providedType.String(), err) + } + if b.provider.Type().NumOut() == 2 { + errValue := res[1].Interface() + if errValue != nil { + err, _ = errValue.(error) + } + } + if err != nil { + return res[0], fmt.Errorf("provider for type %q returned error: %w", b.providedType.String(), err) + } else { + return res[0], nil + } +} diff --git a/condition.go b/condition.go new file mode 100644 index 0000000..33a4d77 --- /dev/null +++ b/condition.go @@ -0,0 +1,29 @@ +package goinject + +import "os" + +type Conditional interface { + evaluate() bool +} + +type environmentVariableConditional struct { + name string + havingValue string + matchIfMissing bool +} + +func (c *environmentVariableConditional) evaluate() bool { + val, ok := os.LookupEnv(c.name) + if !ok { + return c.matchIfMissing + } + return val == c.havingValue +} + +func OnEnvironmentVariable(name, havingValue string, matchIfMissing bool) Conditional { + return &environmentVariableConditional{ + name: name, + havingValue: havingValue, + matchIfMissing: matchIfMissing, + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..afb2918 --- /dev/null +++ b/errors.go @@ -0,0 +1,68 @@ +package goinject + +import ( + "fmt" + "reflect" +) + +type invalidInputError struct { + message string +} + +var _ error = &invalidInputError{} + +func newInvalidInputError(msg string) *invalidInputError { + return &invalidInputError{msg} +} + +func (e *invalidInputError) Error() string { return e.message } + +type injectionError struct { + rType reflect.Type + annotation string + cause error +} + +var _ error = &injectionError{} + +func newInjectionError(typ reflect.Type, annotation string, cause error) *injectionError { + return &injectionError{typ, annotation, cause} +} + +func (e *injectionError) Error() string { + return fmt.Sprintf("Got error while resolving type %s (with annotation %q):\n%s", e.rType.String(), e.annotation, e.cause) +} + +func (e *injectionError) Unwrap() error { return e.cause } + +type contextScopedNotActiveError struct { +} + +var _ error = &contextScopedNotActiveError{} + +func newContextScopedNotActiveError() *contextScopedNotActiveError { + return &contextScopedNotActiveError{} +} + +func (e *contextScopedNotActiveError) Error() string { return "Scope is not active" } + +type injectorConfigurationError struct { + message string + cause error +} + +var _ error = &injectorConfigurationError{} + +func newInjectorConfigurationError(message string, cause error) *injectorConfigurationError { + return &injectorConfigurationError{message, cause} +} + +func (e *injectorConfigurationError) Error() string { + if e.cause == nil { + return e.message + } else { + return fmt.Sprintf("%s:\n%s", e.message, e.cause) + } +} + +func (e *injectorConfigurationError) Unwrap() error { return e.cause } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..364c9fb --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/illuin-tech/goinject + +go 1.24.0 + +require github.com/stretchr/testify v1.10.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..713a0b4 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/injector.go b/injector.go new file mode 100644 index 0000000..d4976ff --- /dev/null +++ b/injector.go @@ -0,0 +1,323 @@ +package goinject + +import ( + "context" + "fmt" + "reflect" + "strings" +) + +var errorReflectType = reflect.TypeFor[error]() +var invocationContextReflectType = reflect.TypeFor[InvocationContext]() + +// Injector defines bindings & scopes +type Injector struct { + bindings map[reflect.Type]map[string][]*binding // list of available bindings by type and annotations + scopes map[string]Scope // Scope by names + singletonScope *singletonScope +} + +// NewInjector builds up a new Injector out of a list of Modules with singleton scope +func NewInjector(options ...Option) (*Injector, error) { + mod := &configuration{ + bindings: make(map[*binding]bool), + scopes: make(map[string]Scope), + } + + for _, o := range options { + err := o.apply(mod) + if err != nil { + return nil, err + } + } + + singletonScope := newSingletonScope() + mod.scopes[Singleton] = singletonScope + mod.scopes[PerLookUp] = newPerLookUpScope() + + injector := &Injector{ + bindings: make(map[reflect.Type]map[string][]*binding), + scopes: make(map[string]Scope), + singletonScope: singletonScope, + } + + injectorType := reflect.TypeFor[*Injector]() + injectorBinding := &binding{ + typeof: injectorType, + provider: reflect.ValueOf(func() *Injector { return injector }), + providedType: injectorType, + scope: Singleton, + } + + injector.scopes = mod.scopes + for b := range mod.bindings { + _, ok := injector.bindings[b.typeof] + if !ok { + injector.bindings[b.typeof] = make(map[string][]*binding) + } + injector.bindings[b.typeof][b.annotatedWith] = append(injector.bindings[b.typeof][b.annotatedWith], b) + } + + injector.bindings[injectorType] = make(map[string][]*binding) + injector.bindings[injectorType][""] = []*binding{injectorBinding} + + err := injector.eagerlyCreateSingletons() + if err != nil { + return nil, err + } + return injector, nil +} + +// Shutdown clear underlying singleton scope +func (injector *Injector) Shutdown() { + injector.singletonScope.Shutdown() + injector.bindings = make(map[reflect.Type]map[string][]*binding) + injector.scopes = make(map[string]Scope) +} + +// Invoke will execute the parameter function (which must be a function that optionally can return an error). +// argument of function will be resolved by the injector using configured providers & scope. +func (injector *Injector) Invoke(ctx context.Context, function any) error { + if function == nil { + return newInvalidInputError("can't invoke on nil") + } + fvalue := reflect.ValueOf(function) + ftype := fvalue.Type() + if ftype.Kind() != reflect.Func { + return newInvalidInputError( + fmt.Sprintf("can't invoke non-function %v (type %v)", function, ftype)) + } + + if ftype.NumOut() > 1 || (ftype.NumOut() == 1 && !ftype.Out(0).AssignableTo(errorReflectType)) { + return newInvalidInputError("can't invoke on function whose return type is not error or no return type") + } + + res, err := injector.callFunctionWithArgumentInstance(ctx, fvalue) + if err != nil { + return fmt.Errorf("failed to call invokation function: %w", err) + } + if ftype.NumOut() == 1 { + invokationError := res[0].Interface().(error) + if invokationError != nil { + return fmt.Errorf("invokation returned error: %w", invokationError) + } + } + return nil +} + +func (injector *Injector) eagerlyCreateSingletons() error { + for _, bindingsByAnnotation := range injector.bindings { + for _, bindingList := range bindingsByAnnotation { + for _, b := range bindingList { + if b.scope == Singleton { + _, err := injector.getScopedInstanceFromBinding(nil, b) //nolint:staticcheck + if err != nil { + return fmt.Errorf("failed to get singleton instance: %w", err) + } + } + } + } + } + return nil +} + +func (injector *Injector) callFunctionWithArgumentInstance( + ctx context.Context, + fValue reflect.Value, +) ([]reflect.Value, error) { + fType := fValue.Type() + in := make([]reflect.Value, fType.NumIn()) + var err error + for i := 0; i < fType.NumIn(); i++ { + if in[i], err = injector.getFunctionArgumentInstance(ctx, fType.In(i)); err != nil { + return []reflect.Value{}, fmt.Errorf("failed to resolve function argument #%d: %w", i, err) + } + } + + res := fValue.Call(in) + return res, nil +} + +func (injector *Injector) getFunctionArgumentInstance(ctx context.Context, argType reflect.Type) (reflect.Value, error) { + if EmbedsParams(argType) { + return injector.createEmbeddedParams(ctx, argType) + } else { + return injector.getInstanceOfAnnotatedType(ctx, argType, "", false) + } +} + +func (injector *Injector) createEmbeddedParams(ctx context.Context, embeddedType reflect.Type) (reflect.Value, error) { + if embeddedType.Kind() == reflect.Ptr { + n := reflect.New(embeddedType.Elem()) + return n, injector.setParamFields(ctx, n.Elem()) + } else { // struct + n := reflect.New(embeddedType).Elem() + return n, injector.setParamFields(ctx, n) + } +} + +func (injector *Injector) setParamFields( + ctx context.Context, + paramValue reflect.Value, +) error { + embeddedType := paramValue.Type() + for fieldIndex := 0; fieldIndex < embeddedType.NumField(); fieldIndex++ { + field := paramValue.Field(fieldIndex) + if field.Type() == _paramType { + continue + } + if tag, ok := embeddedType.Field(fieldIndex).Tag.Lookup("inject"); ok { + if !field.CanSet() { + return newInjectionError(field.Type(), tag, fmt.Errorf("use inject tag on unsettable field")) + } + + var optional bool + for _, option := range strings.Split(tag, ",") { + if strings.TrimSpace(option) == "optional" { + optional = true + } + } + tag = strings.Split(tag, ",")[0] + + instance, err := injector.getInstanceOfAnnotatedType(ctx, field.Type(), tag, optional) + if err != nil { + return newInjectionError(field.Type(), tag, err) + } + if instance.IsValid() { + field.Set(instance) + } else if optional { + continue + } else { + return newInjectionError(field.Type(), tag, fmt.Errorf("cannot get valid instance from scope")) + } + } + } + return nil +} + +// getInstanceOfAnnotatedType resolves a type request within the injector +func (injector *Injector) getInstanceOfAnnotatedType( + ctx context.Context, + t reflect.Type, + annotation string, + optional bool, +) (reflect.Value, error) { + // if is slice, return as multi bindings + if t.Kind() == reflect.Slice { + bindings := injector.findBindingsForAnnotatedType(t.Elem(), annotation) + if len(bindings) > 0 { + n := reflect.MakeSlice(t, 0, len(bindings)) + for _, binding := range bindings { + r, err := injector.getScopedInstanceFromBinding(ctx, binding) + if err != nil { + return reflect.Value{}, err + } + n = reflect.Append(n, r) + } + return n, nil + } else if optional { + return reflect.MakeSlice(t, 0, 0), nil + } else { + return reflect.MakeSlice(t, 0, 0), newInjectionError(t.Elem(), annotation, + fmt.Errorf("did not found binding, expected at least one")) + } + } + + // check if there is a binding for this type & annotation + bindings := injector.findBindingsForAnnotatedType(t, annotation) + if len(bindings) > 1 { + return reflect.Value{}, + newInjectionError(t, annotation, fmt.Errorf("found multiple bindings expected one")) + } else if len(bindings) == 1 { + return injector.getScopedInstanceFromBinding(ctx, bindings[0]) + } else if injector.isProviderType(t) { + return injector.createProviderValue(t, annotation, optional), nil + } else if t == invocationContextReflectType { + return reflect.ValueOf(ctx), nil + } else if optional { + return reflect.Value{}, nil + } else { + return reflect.Value{}, + newInjectionError(t, annotation, fmt.Errorf("did not found binding, expected one")) + } +} + +func (injector *Injector) isProviderType(t reflect.Type) bool { + return t.Kind() == reflect.Func && + t.NumIn() == 1 && t.In(0) == invocationContextReflectType && + t.NumOut() == 2 && t.Out(1) == errorReflectType +} + +func (injector *Injector) createProviderValue( + t reflect.Type, + annotation string, + optional bool, +) reflect.Value { + bindingType := t.Out(0) + return reflect.MakeFunc(t, func(args []reflect.Value) (results []reflect.Value) { + ctx := args[0].Interface().(context.Context) + instance, err := injector.getInstanceOfAnnotatedType(ctx, bindingType, annotation, optional) + var instanceVal reflect.Value + if instance.IsValid() { + instanceVal = instance + } else { + instanceVal = reflect.Zero(bindingType) + } + var errVal reflect.Value + if err != nil { + errVal = reflect.ValueOf(err) + } else { + errVal = reflect.Zero(errorReflectType) + } + return []reflect.Value{ + instanceVal, + errVal, + } + }) +} + +func (injector *Injector) findBindingsForAnnotatedType( + t reflect.Type, + annotation string, +) []*binding { + if _, ok := injector.bindings[t]; ok && len(injector.bindings[t][annotation]) > 0 { + bindings := injector.bindings[t][annotation] + res := make([]*binding, len(bindings)) + copy(res, bindings) + return res + } + + return []*binding{} +} + +func (injector *Injector) getScopedInstanceFromBinding( + ctx context.Context, + binding *binding, +) (reflect.Value, error) { + scope, err := injector.getScopeFromBinding(binding) + if err != nil { + return reflect.Value{}, err + } + val, err := scope.ResolveBinding(ctx, binding, func() (Instance, error) { + val, creationError := binding.create(ctx, injector) + destroyMethod := binding.destroyMethod + if creationError == nil && destroyMethod != nil && !val.IsZero() { + scope.RegisterDestructionCallback( + ctx, + func() { destroyMethod(val) }, + ) + } + return Instance(val), creationError + }) + return reflect.Value(val), err +} + +func (injector *Injector) getScopeFromBinding( + binding *binding, +) (Scope, error) { + if scope, ok := injector.scopes[binding.scope]; ok { + return scope, nil + } + return nil, newInjectionError( + binding.typeof, binding.annotatedWith, fmt.Errorf("unknown scope %q for binding", binding.scope)) +} diff --git a/injector_test.go b/injector_test.go new file mode 100644 index 0000000..1f73eb5 --- /dev/null +++ b/injector_test.go @@ -0,0 +1,687 @@ +package goinject + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type Parent struct { +} + +type Child struct { + parent *Parent +} + +func TestShouldReturnFromProvider(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Parent { return &Parent{} }), + Provide(func(parent *Parent) *Child { return &Child{parent: parent} }), + ) + assert.Nil(t, err) + ctx := context.Background() + var parent *Parent + err = injector.Invoke(ctx, func(p *Parent) { + parent = p + }) + assert.Nil(t, err) + var child *Child + err = injector.Invoke(ctx, func(c *Child) { + child = c + }) + assert.Nil(t, err) + assert.Same(t, parent, child.parent) + }) +} + +func TestProvideShouldAcceptErrorReturnProviders(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() (*Parent, error) { return &Parent{}, nil }, In(PerLookUp)), + Provide(func(_ *Parent) (*Child, error) { return nil, fmt.Errorf("failed to create child") }, In(PerLookUp)), + ) + assert.Nil(t, err) + ctx := context.Background() + t.Run("And return type if no error", func(t *testing.T) { + err = injector.Invoke(ctx, func(parent *Parent) { + assert.NotNil(t, parent) + }) + assert.Nil(t, err) + }) + t.Run("And return error otherwise", func(t *testing.T) { + err = injector.Invoke(ctx, func(_ *Child) { + assert.Fail(t, "should not be reached") + }) + assert.ErrorContains(t, err, "failed to create child") + }) + }) +} + +func TestUseUnknownScopeShouldReturnError(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Parent { return &Parent{} }, In("unknown")), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ *Parent) { + assert.Fail(t, "should not be reached") + }) + assert.ErrorContains(t, err, "unknown scope \"unknown\" for binding") + }) +} + +type TestInvokeParamOptional struct { + Params + ParentA *Parent `inject:", optional"` + ParentB *Parent `inject:"B"` +} + +func TestInvokeWithOptional(t *testing.T) { + assert.NotPanics(t, func() { + t.Run("using param struct argument", func(t *testing.T) { + injector, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, Named("B")), + ) + assert.Nil(t, err) + var parentA *Parent + var parentB *Parent + ctx := context.Background() + err = injector.Invoke(ctx, func(param TestInvokeParamOptional) { + parentA = param.ParentA + parentB = param.ParentB + }) + assert.Nil(t, err) + assert.Nil(t, parentA) + assert.NotNil(t, parentB) + }) + + t.Run("using param pointer argument", func(t *testing.T) { + injector, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, Named("B")), + ) + assert.Nil(t, err) + var parentA *Parent + var parentB *Parent + ctx := context.Background() + err = injector.Invoke(ctx, func(param *TestInvokeParamOptional) { + parentA = param.ParentA + parentB = param.ParentB + }) + assert.Nil(t, err) + assert.Nil(t, parentA) + assert.NotNil(t, parentB) + }) + }) +} + +type Color struct { + name string +} + +type TestInvokeParamAnnotated struct { + Params + Color *Color `inject:"red"` +} + +func TestInvokeWithAnnotation(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Color { return &Color{name: "red"} }, Named("red")), + Provide(func() *Color { return &Color{name: "blue"} }, Named("blue")), + ) + assert.Nil(t, err) + var color *Color + ctx := context.Background() + err = injector.Invoke(ctx, func(param TestInvokeParamAnnotated) { + color = param.Color + }) + assert.NotNil(t, color) + assert.Equal(t, "red", color.name) + assert.Nil(t, err) + }) +} + +func TestInvokeShouldReturnErrorIfExpectedSingleBindingButMultipleFound(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Color { return &Color{name: "blue"} }), + Provide(func() *Color { return &Color{name: "red"} }), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ *Color) { + assert.Fail(t, "should not be reached") + }) + assert.NotNil(t, err) + // verify error tree contains an injection error + var expectedErrorType *injectionError + assert.ErrorAs(t, err, &expectedErrorType) + assert.Equal(t, + "failed to call invokation function: failed to resolve"+ + " function argument #0: Got error while resolving type *goinject.Color"+ + " (with annotation \"\"):\nfound multiple bindings expected one", + err.Error()) + }) +} + +type Red *Color +type Blue *Color + +func TestInvokeUsingTypeDefinition(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Color { return &Color{name: "blue"} }, As(Type[Blue]())), + Provide(func() *Color { return &Color{name: "red"} }, As(Type[Red]())), + ) + assert.Nil(t, err) + var color Red + ctx := context.Background() + err = injector.Invoke(ctx, func(c Red) { + color = c + }) + assert.NotNil(t, color) + assert.Equal(t, "red", color.name) + assert.Nil(t, err) + }) +} + +func TestInstallModuleShouldInstallBindingsOnce(t *testing.T) { + assert.NotPanics(t, func() { + subModule := Module("sub", Provide(func() *Parent { + return &Parent{} + }, Named("parent-in-sub"))) + parentModuleA := Module("parent-a", subModule) + parentModuleB := Module("parent-b", subModule) + injector, err := NewInjector( + parentModuleA, + parentModuleB, + ) + assert.Nil(t, err) + assert.NotNil(t, injector) + assert.Equal(t, 1, len(injector.bindings[reflect.TypeFor[*Parent]()])) + assert.Equal(t, 2, len(injector.bindings)) // we add a binding for *Injector + }) +} + +type Shape interface { + Name() string +} + +type Rectangle struct { +} + +func (r *Rectangle) Name() string { + return "rectangle" +} + +type Square struct { +} + +func (s *Square) Name() string { + return "square" +} + +func TestBindToInterface(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Rectangle { + return &Rectangle{} + }, As(Type[Shape]())), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(s Shape) { + assert.IsType(t, &Rectangle{}, s) + }) + assert.Nil(t, err) + }) +} + +func TestInjectorShouldBeProvided(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(i *Injector) { + assert.Same(t, i, injector) + }) + assert.Nil(t, err) + }) +} + +type WithRefCount struct { + refCount int +} + +func TestInjectorShutdownShouldShutdownSingletonScope(t *testing.T) { + assert.NotPanics(t, func() { + refCount := 0 + injector, err := NewInjector( + Provide(func() *WithRefCount { + res := &WithRefCount{refCount: refCount} + refCount++ + return res + }, WithDestroy(func(_ *WithRefCount) { + refCount-- + }), In(Singleton)), + ) + assert.Nil(t, err) + ctx := context.Background() + + // singleton should be created eagerly + assert.Equal(t, 1, refCount) + + err = injector.Invoke(ctx, func(c *WithRefCount) { + assert.Equal(t, 1, refCount) + assert.Equal(t, 0, c.refCount) + }) + + assert.Nil(t, err) + assert.Equal(t, 1, refCount) + injector.Shutdown() + assert.Equal(t, 0, refCount) + assert.Equal(t, 0, len(injector.bindings)) + }) +} + +func TestNewInjectorShouldReturnErrorIfEagerlyCreatedSingletonReturnError(t *testing.T) { + returnedErr := fmt.Errorf("provider error") + assert.NotPanics(t, func() { + _, err := NewInjector( + Provide(func() (*WithRefCount, error) { + return nil, returnedErr + }), + ) + assert.ErrorIs(t, err, returnedErr) + assert.Equal(t, "failed to get singleton instance: provider for type \"*goinject.WithRefCount\" "+ + "returned error: provider error", err.Error()) + }) +} + +type MultiBindOptionalInvokeParams struct { + Params + Shapes []Shape `inject:",optional"` +} + +func TestMultiBind(t *testing.T) { + t.Run("Using multiple interface implementation", func(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Rectangle { + return &Rectangle{} + }, As(Type[Shape]())), + Provide(func() *Square { + return &Square{} + }, As(Type[Shape]())), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(shapes []Shape) { + var names []string + for _, shape := range shapes { + names = append(names, shape.Name()) + } + assert.Contains(t, names, "square") + assert.Contains(t, names, "rectangle") + }) + assert.Nil(t, err) + }) + }) + + t.Run("Should not throw error if not found and optional", func(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(params MultiBindOptionalInvokeParams) { + assert.Empty(t, params.Shapes) + }) + assert.Nil(t, err) + }) + }) + + t.Run("Should throw error if not found and not optional", func(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ []Shape) { + assert.Fail(t, "should not be reached") + }) + assert.NotNil(t, err) + var expectedErrorType *injectionError + assert.ErrorAs(t, err, &expectedErrorType) + assert.Equal(t, "failed to call invokation function: failed to resolve function argument #0: "+ + "Got error while resolving type goinject.Shape (with annotation \"\"):\n"+ + "did not found binding, expected at least one", err.Error()) + }) + }) +} + +type WithProvider struct { + provider Provider[*WithRefCount] +} + +type WithProviderParam struct { + Params + Provider Provider[*WithRefCount] `inject:",optional"` +} + +func TestProvider(t *testing.T) { + assert.NotPanics(t, func() { + t.Run("Get from provider should re-ask scope (with per-lookup)", func(t *testing.T) { + refCount := 0 + injector, rootError := NewInjector( + Provide(func() *WithRefCount { + res := &WithRefCount{refCount: refCount} + refCount++ + return res + }, In(PerLookUp)), + Provide(func(p Provider[*WithRefCount]) *WithProvider { + return &WithProvider{ + provider: p, + } + }), + ) + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w *WithProvider) { + ref1, err := w.provider(ctx) + assert.Nil(t, err) + ref2, err := w.provider(ctx) + assert.Nil(t, err) + assert.NotEqual(t, ref2, ref1) + assert.Equal(t, 0, ref1.refCount) + assert.Equal(t, 1, ref2.refCount) + }) + assert.Nil(t, rootError) + }) + + t.Run("Get from provider should re-ask scope (with singleton)", func(t *testing.T) { + refCount := 0 + injector, rootError := NewInjector( + Provide(func() *WithRefCount { + res := &WithRefCount{refCount: refCount} + refCount++ + return res + }, In(Singleton)), + Provide(func(p Provider[*WithRefCount]) *WithProvider { + return &WithProvider{ + provider: p, + } + }), + ) + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w *WithProvider) { + ref1, err := w.provider(ctx) + assert.Nil(t, err) + ref2, err := w.provider(ctx) + assert.Nil(t, err) + assert.Same(t, ref2, ref1) + }) + assert.Nil(t, rootError) + }) + + t.Run("Provider with optional should return zero value if not present", func(t *testing.T) { + injector, rootError := NewInjector() + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w WithProviderParam) { + ref, err := w.Provider(ctx) + assert.Nil(t, err) + assert.Nil(t, ref) + }) + assert.Nil(t, rootError) + }) + + t.Run("Provider should return error", func(t *testing.T) { + injector, rootError := NewInjector( + Provide(func() (*WithRefCount, error) { + return nil, fmt.Errorf("test error") + }, In(PerLookUp)), + Provide(func(p Provider[*WithRefCount]) *WithProvider { + return &WithProvider{ + provider: p, + } + }, In(PerLookUp)), + ) + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w *WithProvider) { + ref, err := w.provider(ctx) + assert.Nil(t, ref) + assert.NotNil(t, err) + assert.Equal(t, "provider for type \"*goinject.WithRefCount\" returned error: test error", err.Error()) + }) + assert.Nil(t, rootError) + }) + }) +} + +func TestConditional(t *testing.T) { + t.Run("Test conditional env var should not register binding if no match", func(t *testing.T) { + t.Setenv("TEST", "CASE-KO") + injector, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OK", false), + Provide(func() (*Parent, error) { return &Parent{}, nil }), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ *Parent) { + assert.Fail(t, "inaccessible") + }) + assert.NotNil(t, err) + var expectedErrorType *injectionError + assert.ErrorAs(t, err, &expectedErrorType) + assert.Equal(t, + "failed to call invokation function: failed to resolve function argument #0: "+ + "Got error while resolving type *goinject.Parent (with annotation \"\"):\ndid not found binding, "+ + "expected one", + err.Error(), + ) + }) + + t.Run("Test conditional env var should register binding if match", func(t *testing.T) { + t.Setenv("TEST", "CASE-OK") + injector, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OK", false), + Provide(func() (*Parent, error) { return &Parent{}, nil }), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(parent *Parent) { + assert.NotNil(t, parent) + }) + assert.Nil(t, err) + }) + + t.Run("Test conditional env var should register binding if no match but match missing", func(t *testing.T) { + injector, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OO", true), + Provide(func() (*Parent, error) { return &Parent{}, nil }), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(parent *Parent) { + assert.NotNil(t, parent) + }) + assert.Nil(t, err) + }) + + t.Run("Test When should return binding configuration errors", func(t *testing.T) { + _, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OK", true), + Provide(nil), + ), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "cannot accept nil provider", err.Error()) + }) +} + +func TestInvokeError(t *testing.T) { + t.Run("Invoke should not accept nil", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, nil) + assert.NotNil(t, err) + assert.IsType(t, err, &invalidInputError{}) + assert.Equal(t, "can't invoke on nil", err.Error()) + }) + + t.Run("Invoke should only accept function", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, true) + assert.NotNil(t, err) + assert.IsType(t, err, &invalidInputError{}) + assert.Equal(t, "can't invoke non-function true (type bool)", err.Error()) + }) + + t.Run("Invoke should only accept function returning error", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func() *Parent { return nil }) + assert.NotNil(t, err) + assert.IsType(t, err, &invalidInputError{}) + assert.Equal(t, "can't invoke on function whose return type is not error or no return type", err.Error()) + }) + + t.Run("Invoke should return error if function return error", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + invokationFnReturnedError := fmt.Errorf("returned error") + err = injector.Invoke(ctx, func() error { return invokationFnReturnedError }) + assert.NotNil(t, err) + assert.ErrorIs(t, err, invokationFnReturnedError) + }) +} + +func TestInjectorConfigurationError(t *testing.T) { + t.Run("Provide cannot accept nil", func(t *testing.T) { + _, err := NewInjector( + Provide(nil)) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "cannot accept nil provider", err.Error()) + }) + + t.Run("Provider should use function as argument", func(t *testing.T) { + _, err := NewInjector( + Provide(true)) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "provider argument should be a function", err.Error()) + }) + + t.Run("Provider function should return an instance", func(t *testing.T) { + _, err := NewInjector( + Provide(func() {})) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "expected a function that return an instance and optionally an error", err.Error()) + }) + + t.Run("Provider function cannot return multiple types (except error)", func(t *testing.T) { + _, err := NewInjector( + Provide(func() (*Parent, *Child) { + return &Parent{}, &Child{} + })) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "second return type of provider should be an error", err.Error()) + }) + + t.Run("Module should return nested errors", func(t *testing.T) { + _, err := NewInjector( + Module("test.Module", + Provide(nil)), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "error while installing module test.Module:\ncannot accept nil provider", err.Error()) + }) + + t.Run("As provider annotation should raise error if not assignable", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, As(Type[*Child]())), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\ncannot assign "+ + "*goinject.Parent to *goinject.Child as specified in As argument", + err.Error(), + ) + }) + + t.Run("WithDestroy should raise an error if not a function", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, WithDestroy(true)), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\nargument of WithDestroy"+ + " must be a function with one argument returning void", + err.Error(), + ) + }) + + t.Run("WithDestroy should raise an error if not a function of provided type", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, WithDestroy(func(_ *Child) {})), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\nargument of WithDestroy"+ + " must be a function with one argument returning void", + err.Error(), + ) + }) + + t.Run("WithDestroy should raise an error if not a void function of provided type", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, WithDestroy(func(_ *Parent) error { + return nil + })), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\nargument of WithDestroy "+ + "must be a function with one argument returning void", err.Error(), + ) + }) +} diff --git a/module.go b/module.go new file mode 100644 index 0000000..98fb697 --- /dev/null +++ b/module.go @@ -0,0 +1,235 @@ +package goinject + +import ( + "fmt" + "reflect" +) + +type configuration struct { + bindings map[*binding]bool + scopes map[string]Scope +} + +// Option enable to configure the given injector +type Option interface { + apply(*configuration) error +} + +type moduleOption struct { + name string + options []Option +} + +func (o *moduleOption) apply(mod *configuration) error { + for _, opt := range o.options { + err := opt.apply(mod) + if err != nil { + return newInjectorConfigurationError( + fmt.Sprintf("error while installing module %s", o.name), err) + } + } + return nil +} + +// Module group a list of Option in order to easily reuse them. +// the Module name is used in error when applying Option to easily find misconfigured options. +func Module(name string, opts ...Option) Option { + mo := &moduleOption{ + name: name, + options: opts, + } + return mo +} + +type provideOption struct { + constructor any + annotations []Annotation +} + +func (o *provideOption) apply(mod *configuration) error { + if o.constructor == nil { + return newInjectorConfigurationError("cannot accept nil provider", nil) + } + providerFncValue := reflect.ValueOf(o.constructor) + fncType := providerFncValue.Type() + if fncType.Kind() != reflect.Func { + return newInjectorConfigurationError("provider argument should be a function", nil) + } + if fncType.NumOut() > 2 || fncType.NumOut() == 0 { + return newInjectorConfigurationError("expected a function that return an instance and optionally an error", nil) + } + if fncType.NumOut() == 2 && !fncType.Out(1).AssignableTo(reflect.TypeOf(new(error)).Elem()) { + return newInjectorConfigurationError("second return type of provider should be an error", nil) + } + b := &binding{} + b.provider = providerFncValue + b.providedType = fncType.Out(0) + b.typeof = b.providedType + b.scope = Singleton + + for _, a := range o.annotations { + err := a.apply(b) + if err != nil { + return newInjectorConfigurationError( + fmt.Sprintf("got error while configuring provider for provided type %s", b.providedType), + err, + ) + } + } + + mod.bindings[b] = true + return nil +} + +// Provide define a binding from a function constructor that must return the provided instance (and optionally an error) +// arguments of the constructor parameter will be resolved by the injector itself. +// Provide enable to annotate the created binding using Annotation +func Provide(constructor any, annotations ...Annotation) Option { + return &provideOption{ + constructor: constructor, + annotations: annotations, + } +} + +type registerScopeOption struct { + name string + scope Scope +} + +func (o *registerScopeOption) apply(mod *configuration) error { + mod.scopes[o.name] = o.scope + return nil +} + +// RegisterScope register a new Scope with a name +func RegisterScope(name string, scope Scope) Option { + return ®isterScopeOption{ + name: name, + scope: scope, + } +} + +type whenOption struct { + condition Conditional + options []Option +} + +func (o *whenOption) apply(mod *configuration) error { + if o.condition.evaluate() { + for _, opt := range o.options { + if err := opt.apply(mod); err != nil { + return err + } + } + } + + return nil +} + +// When enable to group a list of Option that will be applied only if the given Conditional evaluate to true +func When(condition Conditional, options ...Option) Option { + return &whenOption{ + condition: condition, + options: options, + } +} + +// Annotation are used to configured bindings created by the Provide function +type Annotation interface { + apply(b *binding) error +} + +type asAnnotation struct { + target AsType +} + +func (a *asAnnotation) apply(b *binding) error { + targetType := a.target.getType() + if !b.providedType.AssignableTo(targetType) { + return newInjectorConfigurationError( + fmt.Sprintf("cannot assign %s to %s as specified in As argument", b.providedType, targetType), + nil, + ) + } + b.typeof = targetType + return nil +} + +// AsType is used in As function as an argument to register a provided type to another given assignable type +type AsType interface { + getType() reflect.Type +} + +type typeFor[T any] struct { +} + +func (t *typeFor[T]) getType() reflect.Type { + return reflect.TypeFor[T]() +} + +// Type return an AsType for a given type T +func Type[T any]() AsType { + return &typeFor[T]{} +} + +// As return an annotation that is used to override the binding registration type. +// Use it to bind a concrete type to an interface. +func As(target AsType) Annotation { + return &asAnnotation{target: target} +} + +type nameAnnotation struct { + name string +} + +func (a *nameAnnotation) apply(b *binding) error { + b.annotatedWith = a.name + return nil +} + +// Named return an annotation that is used to define the binding annotation name. +func Named(name string) Annotation { + return &nameAnnotation{name: name} +} + +type inAnnotation struct { + scope string +} + +// In return an annotation that is used to define the binding scope +func In(scope string) Annotation { + return &inAnnotation{scope: scope} +} + +func (a *inAnnotation) apply(b *binding) error { + b.scope = a.scope + return nil +} + +type withDestroyAnnotation struct { + destroyMethod any +} + +func (a *withDestroyAnnotation) apply(b *binding) error { + destroyMethodFnVal := reflect.ValueOf(a.destroyMethod) + if destroyMethodFnVal.Kind() != reflect.Func || + destroyMethodFnVal.Type().NumIn() != 1 || + destroyMethodFnVal.Type().In(0) != b.providedType || + destroyMethodFnVal.Type().NumOut() != 0 { + return newInjectorConfigurationError( + "argument of WithDestroy must be a function with one argument returning void", + nil, + ) + } + b.destroyMethod = func(val reflect.Value) { + destroyMethodFnVal.Call([]reflect.Value{val}) + } + return nil +} + +// WithDestroy return an annotation that declare a destroyMethod that will be used when closing a scope +func WithDestroy(destroyMethod any) Annotation { + return &withDestroyAnnotation{ + destroyMethod: destroyMethod, + } +} diff --git a/scope.go b/scope.go new file mode 100644 index 0000000..9c3d84a --- /dev/null +++ b/scope.go @@ -0,0 +1,199 @@ +package goinject + +import ( + "context" + "reflect" + "sync" +) + +// Instance is the return type for Scope ResolveBinding method. +// It is used to hidde the usage of reflect.Value in the public API +type Instance reflect.Value + +type instanceRegistry struct { + mu sync.Mutex // lock guarding instanceLock + instanceLock map[*binding]*sync.RWMutex // lock guarding instances + instances sync.Map + destroyMethodsLock sync.Mutex + destroyMethods []func() +} + +func (r *instanceRegistry) resolveBinding( + binding *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + r.mu.Lock() + + if l, ok := r.instanceLock[binding]; ok { + r.mu.Unlock() + l.RLock() + defer l.RUnlock() + + instance, _ := r.instances.Load(binding) + return instance.(Instance), nil + } + + r.instanceLock[binding] = new(sync.RWMutex) + l := r.instanceLock[binding] + l.Lock() + r.mu.Unlock() + + instance, err := instanceCreator() + r.instances.Store(binding, instance) + + defer l.Unlock() + + return instance, err +} + +func (r *instanceRegistry) registerDestructionCallback( + destroyCallback func(), +) { + r.destroyMethodsLock.Lock() + defer r.destroyMethodsLock.Unlock() + r.destroyMethods = append(r.destroyMethods, destroyCallback) +} + +func (r *instanceRegistry) shutdown() { + r.destroyMethodsLock.Lock() + defer r.destroyMethodsLock.Unlock() + + for i := len(r.destroyMethods) - 1; i >= 0; i-- { + r.destroyMethods[i]() + } + + r.destroyMethods = []func(){} +} + +func newInstanceRegistry() *instanceRegistry { + return &instanceRegistry{ + instanceLock: make(map[*binding]*sync.RWMutex), + destroyMethods: []func(){}, + } +} + +// Scope defines a scope's behaviour +type Scope interface { + // ResolveBinding resolve a dependency injection context for current scope + ResolveBinding( + ctx context.Context, + binding *binding, + instanceCreator func() (Instance, error), + ) (Instance, error) + + // RegisterDestructionCallback register a destruction callback. It is the responsibility of the Scope to call + // this callback when destroying the Scope + RegisterDestructionCallback( + ctx context.Context, + destroyCallback func(), + ) +} + +const PerLookUp = "inject.PerLookUp" + +// perLookUpScope is a Scope that return a new instance when requested +type perLookUpScope struct { +} + +var _ Scope = new(perLookUpScope) + +func newPerLookUpScope() Scope { + return &perLookUpScope{} +} + +func (s *perLookUpScope) ResolveBinding( + _ context.Context, + _ *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + return instanceCreator() +} + +func (s *perLookUpScope) RegisterDestructionCallback( + _ context.Context, + _ func(), +) { + // nothing to do, per lookup provided need to close destroy method themselves +} + +const Singleton = "inject.Singleton" + +// singletonScope is our Scope to handle Singletons +type singletonScope struct { + instanceRegistry *instanceRegistry +} + +var _ Scope = new(singletonScope) + +func newSingletonScope() *singletonScope { + return &singletonScope{ + instanceRegistry: newInstanceRegistry(), + } +} + +func (s *singletonScope) ResolveBinding( + _ context.Context, + binding *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + return s.instanceRegistry.resolveBinding(binding, instanceCreator) +} + +func (s *singletonScope) RegisterDestructionCallback( + _ context.Context, + destroyCallback func(), +) { + s.instanceRegistry.registerDestructionCallback(destroyCallback) +} + +func (s *singletonScope) Shutdown() { + s.instanceRegistry.shutdown() +} + +// contextualScope is an abstract scope to handle context attached scoped (request, session, ...) +type contextualScope struct { + key any +} + +var _ Scope = new(contextualScope) + +func (s *contextualScope) ResolveBinding( + ctx context.Context, + binding *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + if ctx == nil { + return Instance{}, newContextScopedNotActiveError() + } + scopeHolder, ok := ctx.Value(s.key).(*instanceRegistry) + if !ok { + return Instance{}, newContextScopedNotActiveError() + } + return scopeHolder.resolveBinding(binding, instanceCreator) +} + +func (s *contextualScope) RegisterDestructionCallback( + ctx context.Context, + destroyCallback func(), +) { + if scopeHolder, ok := ctx.Value(s.key).(*instanceRegistry); ok { + scopeHolder.registerDestructionCallback(destroyCallback) + } +} + +func NewContextualScope(key any) Scope { + return &contextualScope{ + key: key, + } +} + +func WithContextualScopeEnabled(ctx context.Context, key any) context.Context { + return context.WithValue(ctx, key, newInstanceRegistry()) +} + +func ShutdownContextualScope(ctx context.Context, key any) { + holder, ok := ctx.Value(key).(*instanceRegistry) + if ok { + holder.shutdown() + } +} diff --git a/scope_test.go b/scope_test.go new file mode 100644 index 0000000..b14673a --- /dev/null +++ b/scope_test.go @@ -0,0 +1,295 @@ +package goinject + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +type sessionScopeKey int + +const sessionScopeKeyVal sessionScopeKey = 0 + +type requestScopeKey int + +const requestScopeKeyVal requestScopeKey = 0 + +type Request struct { + ID int +} + +type Session struct { + ID int +} + +type ContextualScopesParams struct { + Params + Request *Request `inject:""` +} + +type ctxKey int + +const requestKey ctxKey = iota + +func TestContextualScopesUsingContextValue(t *testing.T) { + notAwareContextError := errors.New("not running in value aware context") + assert.NotPanics(t, func() { + injector, err := NewInjector( + Module("contextualScopeTest", + RegisterScope("request", NewContextualScope(requestScopeKeyVal)), + Provide(func(ctx InvocationContext) (*Request, error) { + if r, ok := ctx.Value(requestKey).(*Request); ok { + return r, nil + } else { + return nil, notAwareContextError + } + }, In("request")), + ), + ) + assert.Nil(t, err) + + t.Run("Provider should be able to provide from InvocationContext", func(t *testing.T) { + ctx := context.Background() + requestCtx := WithContextualScopeEnabled( + context.WithValue(ctx, requestKey, &Request{ID: 42}), + requestScopeKeyVal, + ) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + invokeErr := injector.Invoke(requestCtx, func(r *Request) { + assert.Equal(t, 42, r.ID) + }) + assert.Nil(t, invokeErr) + }) + + t.Run("Provider should be able to provide from InvocationContext with error", func(t *testing.T) { + ctx := context.Background() + requestCtx := WithContextualScopeEnabled( + ctx, + requestScopeKeyVal, + ) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + invokeErr := injector.Invoke(requestCtx, func(*Request) { + assert.Fail(t, "should not be called") + }) + assert.ErrorIs(t, invokeErr, notAwareContextError) + }) + }) +} + +func TestContextualScopes(t *testing.T) { + assert.NotPanics(t, func() { + count := 0 + injector, err := NewInjector( + Module("contextualScopeTest", + RegisterScope("request", NewContextualScope(requestScopeKeyVal)), + RegisterScope("session", NewContextualScope(sessionScopeKeyVal)), + Provide(func() *Request { + res := &Request{ID: count} + count++ + return res + }, In("request")), + Provide(func() *Session { + res := &Session{ID: count} + count++ + return res + }, In("session")), + ), + ) + assert.Nil(t, err) + + ctx := context.Background() + + t.Run("Contextual scope should return error if not active", func(t *testing.T) { + err = injector.Invoke(ctx, func(_ *Request) { + assert.Fail(t, "Should not be reached") + }) + assert.True(t, errors.Is(err, &contextScopedNotActiveError{})) + }) + + t.Run("Contextual scope should return error if not active (using Params)", func(t *testing.T) { + err = injector.Invoke(ctx, func(_ ContextualScopesParams) { + assert.Fail(t, "Should not be reached") + }) + assert.True(t, errors.Is(err, &contextScopedNotActiveError{})) + }) + + var sessionID int + var sessionID2 int + + t.Run("Test session with multiple request should keep same session scope but different request scope", + func(t *testing.T) { + sessionCtx := WithContextualScopeEnabled(ctx, sessionScopeKeyVal) + defer ShutdownContextualScope(sessionCtx, sessionScopeKeyVal) + + var request1ID int + var request2ID int + var sessionIDBis int + + t.Run("Test request 1", func(t *testing.T) { + requestCtx := WithContextualScopeEnabled(sessionCtx, requestScopeKeyVal) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + err := injector.Invoke(requestCtx, func(session *Session, request *Request) { + sessionID = session.ID + request1ID = request.ID + }) + assert.Nil(t, err) + }) + + t.Run("Test request 2", func(t *testing.T) { + requestCtx := WithContextualScopeEnabled(sessionCtx, requestScopeKeyVal) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + err := injector.Invoke(requestCtx, func(session *Session, request *Request) { + sessionIDBis = session.ID + request2ID = request.ID + }) + assert.Nil(t, err) + }) + + assert.NotZero(t, request1ID) + assert.NotZero(t, request2ID) + assert.NotEqual(t, request2ID, request1ID) + + assert.Equal(t, sessionID, sessionIDBis) + }) + + t.Run("Test session 2 (without request scope)", func(t *testing.T) { + sessionCtx := WithContextualScopeEnabled(ctx, sessionScopeKeyVal) + defer ShutdownContextualScope(sessionCtx, sessionScopeKeyVal) + + err := injector.Invoke(sessionCtx, func(session *Session) { + sessionID2 = session.ID + }) + assert.Nil(t, err) + }) + + assert.NotEqual(t, sessionID, sessionID2) + }) +} + +func TestContextualScopeDestroy(t *testing.T) { + assert.NotPanics(t, func() { + count := 0 + injector, err := NewInjector( + Module("contextualScopeTest", + RegisterScope("session", NewContextualScope(sessionScopeKeyVal)), + Provide(func() *Session { + res := &Session{ID: count} + count++ + return res + }, In("session"), WithDestroy(func(_ *Session) { + count-- + })), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + + t.Run("Run session", func(t *testing.T) { + sessionCtx := WithContextualScopeEnabled(ctx, sessionScopeKeyVal) + defer ShutdownContextualScope(sessionCtx, sessionScopeKeyVal) + + err := injector.Invoke(sessionCtx, func(_ *Session) { + assert.Equal(t, 1, count) + }) + assert.Nil(t, err) + }) + + assert.Equal(t, 0, count) + }) +} + +type SingletonInjectee struct { + ID int +} + +func TestSingletonScope(t *testing.T) { + count := 0 + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *SingletonInjectee { + res := &SingletonInjectee{ID: count} + count++ + return res + }, In(Singleton)), + ) + assert.Nil(t, err) + + ctx := context.Background() + var fetch1 *SingletonInjectee + var fetch2 *SingletonInjectee + err = injector.Invoke(ctx, func(s *SingletonInjectee) { + fetch1 = s + }) + assert.Nil(t, err) + err = injector.Invoke(ctx, func(s *SingletonInjectee) { + fetch2 = s + }) + assert.Nil(t, err) + assert.NotNil(t, fetch1) + assert.NotNil(t, fetch2) + assert.Same(t, fetch1, fetch2) + }) +} + +type PerLookUpInjectee struct { + ID int +} + +func TestPerLookUpScope(t *testing.T) { + assert.NotPanics(t, func() { + t.Run("Should return new instance on each request", func(t *testing.T) { + count := 0 + injector, err := NewInjector( + Provide(func() *PerLookUpInjectee { + res := &PerLookUpInjectee{ID: count} + count++ + return res + }, In(PerLookUp)), + ) + assert.Nil(t, err) + + ctx := context.Background() + var fetch1 *PerLookUpInjectee + var fetch2 *PerLookUpInjectee + err = injector.Invoke(ctx, func(s *PerLookUpInjectee) { + fetch1 = s + }) + assert.Nil(t, err) + err = injector.Invoke(ctx, func(s *PerLookUpInjectee) { + fetch2 = s + }) + assert.Nil(t, err) + assert.NotNil(t, fetch1) + assert.NotNil(t, fetch2) + assert.NotEqual(t, fetch1, fetch2) + }) + + t.Run("Should ignore destroy instance methods", func(t *testing.T) { + count := 0 + injector, err := NewInjector( + Provide(func() *PerLookUpInjectee { + res := &PerLookUpInjectee{ID: count} + count++ + return res + }, In(PerLookUp), WithDestroy(func(_ *PerLookUpInjectee) { count-- })), + ) + assert.Nil(t, err) + + ctx := context.Background() + err = injector.Invoke(ctx, func(s *PerLookUpInjectee) { + assert.Equal(t, 0, s.ID) + assert.Equal(t, 1, count) + }) + assert.Nil(t, err) + assert.Equal(t, 1, count) + injector.Shutdown() + assert.Equal(t, 1, count) + }) + }) +} diff --git a/special.go b/special.go new file mode 100644 index 0000000..35d8d6c --- /dev/null +++ b/special.go @@ -0,0 +1,60 @@ +package goinject + +import ( + "context" + "reflect" +) + +// Params may be embedded in struct to request the injector to create it +// as special struct. When a constructor accepts such a struct, instead of the +// struct becoming a dependency for that constructor, all its fields become +// dependencies instead. +// +// Fields of the struct may optionally be tagged. +// The following tags are supported, +// +// annotation Requests a value with the same name and type from the +// container. See Named Values for more information. +// optional If set to true, indicates that the dependency is optional and +// the constructor gracefully handles its absence. +type Params struct{} + +var _paramType = reflect.TypeOf(Params{}) + +// EmbedsParams checks whether the given struct is an inject.Params struct. A struct qualifies +// as an inject.Params struct if it embeds inject.Params type. +// +// A struct MUST qualify as an inject.Params struct for its fields to be treated +// specially by the injector. +func EmbedsParams(o reflect.Type) bool { + return embedsType(o, _paramType) +} + +// Returns true if t embeds e +func embedsType(t, e reflect.Type) bool { + if t.Kind() == reflect.Ptr { + return embedsType(t.Elem(), e) + } + + if t.Kind() != reflect.Struct { + // for now, only struct are supported, it might be a good idea to support pointer too + return false + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Anonymous && f.Type == e { + return true + } + } + + return false +} + +type Provider[T any] func(ctx InvocationContext) (T, error) + +// InvocationContext wrap context.Context. +// Use this interface to retrieve the context pass to the Invoke method of the injector in providers +type InvocationContext interface { + context.Context +}