diff --git a/cmd/state-svc/internal/resolver/resolver.go b/cmd/state-svc/internal/resolver/resolver.go index 4f218fefb8..1c4ebe9515 100644 --- a/cmd/state-svc/internal/resolver/resolver.go +++ b/cmd/state-svc/internal/resolver/resolver.go @@ -3,6 +3,7 @@ package resolver import ( "context" "encoding/json" + "errors" "os" "runtime/debug" "sort" @@ -22,8 +23,10 @@ import ( "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/graph" + "github.com/ActiveState/cli/internal/locale" "github.com/ActiveState/cli/internal/logging" configMediator "github.com/ActiveState/cli/internal/mediators/config" + "github.com/ActiveState/cli/internal/messages" "github.com/ActiveState/cli/internal/poller" "github.com/ActiveState/cli/internal/rtutils/ptr" "github.com/ActiveState/cli/internal/runbits/panics" @@ -33,9 +36,14 @@ import ( "github.com/patrickmn/go-cache" ) +type messageQueue struct { + messages []*graph.Message +} + type Resolver struct { cfg *config.Instance - messages *notifications.Notifications + notifications *notifications.Notifications + messages *messageQueue updatePoller *poller.Poller authPoller *poller.Poller projectIDCache *projectcache.ID @@ -50,11 +58,12 @@ type Resolver struct { // var _ genserver.ResolverRoot = &Resolver{} // Must implement ResolverRoot func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Resolver, error) { - msg, err := notifications.New(cfg, auth) + notif, err := notifications.New(cfg, auth) if err != nil { return nil, errs.Wrap(err, "Could not initialize messages") } + msg := &messageQueue{make([]*graph.Message, 0)} upchecker := updater.NewDefaultChecker(cfg, an) pollUpdate := poller.New(1*time.Hour, func() (interface{}, error) { defer func() { @@ -74,11 +83,26 @@ func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Res } pollAuth := poller.New(time.Duration(int64(time.Millisecond)*pollRate), func() (interface{}, error) { + logging.Debug("Polling for authenticated state") defer func() { panics.LogAndPanic(recover(), debug.Stack()) }() if auth.SyncRequired() { - return nil, auth.Sync() + logging.Debug("Sync required") + if err := auth.Sync(); err != nil { + logging.Debug("Syncing authenticated state: %s", err.Error()) + var invalidTokenErr *authentication.ErrInvalidToken + if errors.As(err, &invalidTokenErr) { + logging.Debug("Queuing invalid API token error") + msg.messages = append(msg.messages, &graph.Message{ + Topic: messages.TopicErrorAuthToken, + Message: locale.Tl("err_invalid_token_try_again", "Invalid API token. Please check your API token and try again."), + }) + } else { + logging.Warning("Could not sync authenticated state: %s", err.Error()) + } + } + return nil, nil } return nil, nil }) @@ -88,6 +112,7 @@ func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Res anForClient := sync.New(anaConsts.SrcStateTool, cfg, auth, nil) return &Resolver{ cfg, + notif, msg, pollUpdate, pollAuth, @@ -102,7 +127,7 @@ func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Res } func (r *Resolver) Close() error { - r.messages.Close() + r.notifications.Close() r.updatePoller.Close() r.authPoller.Close() r.anForClient.Close() @@ -250,7 +275,15 @@ func (r *Resolver) ReportRuntimeUsage(_ context.Context, pid int, exec, source s func (r *Resolver) CheckNotifications(ctx context.Context, command string, flags []string) ([]*graph.NotificationInfo, error) { defer func() { panics.LogAndPanic(recover(), debug.Stack()) }() logging.Debug("Check notifications resolver") - return r.messages.Check(command, flags) + return r.notifications.Check(command, flags) +} + +func (r *Resolver) CheckMessages(ctx context.Context) ([]*graph.Message, error) { + defer func() { panics.LogAndPanic(recover(), debug.Stack()) }() + logging.Debug("Check messages resolver") + messages := r.messages.messages + r.messages.messages = r.messages.messages[:0] // clear queue + return messages, nil } func (r *Resolver) ConfigChanged(ctx context.Context, key string) (*graph.ConfigChangedResponse, error) { diff --git a/cmd/state-svc/internal/server/generated/generated.go b/cmd/state-svc/internal/server/generated/generated.go index 9e2bbb7947..21e5330c93 100644 --- a/cmd/state-svc/internal/server/generated/generated.go +++ b/cmd/state-svc/internal/server/generated/generated.go @@ -79,6 +79,11 @@ type ComplexityRoot struct { User func(childComplexity int) int } + Message struct { + Message func(childComplexity int) int + Topic func(childComplexity int) int + } + Mutation struct { SetCache func(childComplexity int, key string, value string, expiry int) int } @@ -112,6 +117,7 @@ type ComplexityRoot struct { Query struct { AnalyticsEvent func(childComplexity int, category string, action string, source string, label *string, dimensionsJSON string) int AvailableUpdate func(childComplexity int, desiredChannel string, desiredVersion string) int + CheckMessages func(childComplexity int) int CheckNotifications func(childComplexity int, command string, flags []string) int ConfigChanged func(childComplexity int, key string) int FetchLogTail func(childComplexity int) int @@ -158,6 +164,7 @@ type QueryResolver interface { AnalyticsEvent(ctx context.Context, category string, action string, source string, label *string, dimensionsJSON string) (*graph.AnalyticsEventResponse, error) ReportRuntimeUsage(ctx context.Context, pid int, exec string, source string, dimensionsJSON string) (*graph.ReportRuntimeUsageResponse, error) CheckNotifications(ctx context.Context, command string, flags []string) ([]*graph.NotificationInfo, error) + CheckMessages(ctx context.Context) ([]*graph.Message, error) ConfigChanged(ctx context.Context, key string) (*graph.ConfigChangedResponse, error) FetchLogTail(ctx context.Context) (string, error) GetProcessesInUse(ctx context.Context, execDir string) ([]*graph.ProcessInfo, error) @@ -283,6 +290,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.JWT.User(childComplexity), true + case "Message.message": + if e.complexity.Message.Message == nil { + break + } + + return e.complexity.Message.Message(childComplexity), true + + case "Message.topic": + if e.complexity.Message.Topic == nil { + break + } + + return e.complexity.Message.Topic(childComplexity), true + case "Mutation.setCache": if e.complexity.Mutation.SetCache == nil { break @@ -417,6 +438,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.AvailableUpdate(childComplexity, args["desiredChannel"].(string), args["desiredVersion"].(string)), true + case "Query.checkMessages": + if e.complexity.Query.CheckMessages == nil { + break + } + + return e.complexity.Query.CheckMessages(childComplexity), true + case "Query.checkNotifications": if e.complexity.Query.CheckNotifications == nil { break @@ -762,6 +790,11 @@ type NotificationInfo { placement: NotificationPlacementType! } +type Message { + topic: String! + message: String! +} + type Organization { URLname: String! role: String! @@ -797,6 +830,7 @@ type Query { analyticsEvent(category: String!, action: String!, source: String!, label: String, dimensionsJson: String!): AnalyticsEventResponse reportRuntimeUsage(pid: Int!, exec: String!, source: String!, dimensionsJson: String!): ReportRuntimeUsageResponse checkNotifications(command: String!, flags: [String!]!): [NotificationInfo!]! + checkMessages: [Message!]! configChanged(key: String!): ConfigChangedResponse fetchLogTail: String! getProcessesInUse(execDir: String!): [ProcessInfo!]! @@ -1757,6 +1791,94 @@ func (ec *executionContext) fieldContext_JWT_user(_ context.Context, field graph return fc, nil } +func (ec *executionContext) _Message_topic(ctx context.Context, field graphql.CollectedField, obj *graph.Message) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Message_topic(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Topic, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Message_topic(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Message", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _Message_message(ctx context.Context, field graphql.CollectedField, obj *graph.Message) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Message_message(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Message, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Message_message(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Message", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Mutation_setCache(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Mutation_setCache(ctx, field) if err != nil { @@ -2769,6 +2891,56 @@ func (ec *executionContext) fieldContext_Query_checkNotifications(ctx context.Co return fc, nil } +func (ec *executionContext) _Query_checkMessages(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_checkMessages(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().CheckMessages(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]*graph.Message) + fc.Result = res + return ec.marshalNMessage2ᚕᚖgithubᚗcomᚋActiveStateᚋcliᚋinternalᚋgraphᚐMessageᚄ(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Query_checkMessages(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "topic": + return ec.fieldContext_Message_topic(ctx, field) + case "message": + return ec.fieldContext_Message_message(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Message", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _Query_configChanged(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Query_configChanged(ctx, field) if err != nil { @@ -5779,6 +5951,50 @@ func (ec *executionContext) _JWT(ctx context.Context, sel ast.SelectionSet, obj return out } +var messageImplementors = []string{"Message"} + +func (ec *executionContext) _Message(ctx context.Context, sel ast.SelectionSet, obj *graph.Message) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, messageImplementors) + + out := graphql.NewFieldSet(fields) + deferred := make(map[string]*graphql.FieldSet) + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("Message") + case "topic": + out.Values[i] = ec._Message_topic(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "message": + out.Values[i] = ec._Message_message(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), + FieldSet: dfs, + Context: ctx, + }) + } + + return out +} + var mutationImplementors = []string{"Mutation"} func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) graphql.Marshaler { @@ -6169,6 +6385,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) + case "checkMessages": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_checkMessages(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, + func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) case "configChanged": field := field @@ -6942,6 +7180,60 @@ func (ec *executionContext) marshalNInt2int(ctx context.Context, sel ast.Selecti return res } +func (ec *executionContext) marshalNMessage2ᚕᚖgithubᚗcomᚋActiveStateᚋcliᚋinternalᚋgraphᚐMessageᚄ(ctx context.Context, sel ast.SelectionSet, v []*graph.Message) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNMessage2ᚖgithubᚗcomᚋActiveStateᚋcliᚋinternalᚋgraphᚐMessage(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + +func (ec *executionContext) marshalNMessage2ᚖgithubᚗcomᚋActiveStateᚋcliᚋinternalᚋgraphᚐMessage(ctx context.Context, sel ast.SelectionSet, v *graph.Message) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._Message(ctx, sel, v) +} + func (ec *executionContext) marshalNNotificationInfo2ᚕᚖgithubᚗcomᚋActiveStateᚋcliᚋinternalᚋgraphᚐNotificationInfoᚄ(ctx context.Context, sel ast.SelectionSet, v []*graph.NotificationInfo) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup diff --git a/cmd/state-svc/main.go b/cmd/state-svc/main.go index 0bdfb950ff..bc99a39a9a 100644 --- a/cmd/state-svc/main.go +++ b/cmd/state-svc/main.go @@ -206,9 +206,6 @@ func run(cfg *config.Instance) error { }, func(ccmd *captain.Command, args []string) error { logging.Debug("Running CmdForeground") - if err := auth.Sync(); err != nil { - logging.Warning("Could not sync authenticated state: %s", err.Error()) - } return runForeground(cfg, an, auth, foregroundArgText) }, ), diff --git a/cmd/state-svc/schema/schema.graphqls b/cmd/state-svc/schema/schema.graphqls index 606eb5cfd3..1046a9e687 100644 --- a/cmd/state-svc/schema/schema.graphqls +++ b/cmd/state-svc/schema/schema.graphqls @@ -62,6 +62,11 @@ type NotificationInfo { placement: NotificationPlacementType! } +type Message { + topic: String! + message: String! +} + type Organization { URLname: String! role: String! @@ -97,6 +102,7 @@ type Query { analyticsEvent(category: String!, action: String!, source: String!, label: String, dimensionsJson: String!): AnalyticsEventResponse reportRuntimeUsage(pid: Int!, exec: String!, source: String!, dimensionsJson: String!): ReportRuntimeUsageResponse checkNotifications(command: String!, flags: [String!]!): [NotificationInfo!]! + checkMessages: [Message!]! configChanged(key: String!): ConfigChangedResponse fetchLogTail: String! getProcessesInUse(execDir: String!): [ProcessInfo!]! diff --git a/cmd/state/internal/cmdtree/exechandlers/messages/messenger.go b/cmd/state/internal/cmdtree/exechandlers/messages/messenger.go new file mode 100644 index 0000000000..ab5848d59d --- /dev/null +++ b/cmd/state/internal/cmdtree/exechandlers/messages/messenger.go @@ -0,0 +1,68 @@ +package messages + +import ( + "context" + "strings" + + "github.com/ActiveState/cli/internal/captain" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/graph" + "github.com/ActiveState/cli/internal/locale" + "github.com/ActiveState/cli/internal/logging" + msgs "github.com/ActiveState/cli/internal/messages" + "github.com/ActiveState/cli/internal/output" + "github.com/ActiveState/cli/pkg/platform/model" +) + +type Messenger struct { + out output.Outputer + svcModel *model.SvcModel +} + +func New(out output.Outputer, svcModel *model.SvcModel) *Messenger { + return &Messenger{ + out: out, + svcModel: svcModel, + } +} + +func (m *Messenger) OnExecStart(_ *captain.Command, _ []string) error { + logging.Debug("Checking for messages") + if m.out.Type().IsStructured() { + return nil + } + + messages, err := m.svcModel.CheckMessages(context.Background()) + if err != nil { + return errs.Wrap(err, "Could not get messages") + } + logging.Debug("Found %d messages", len(messages)) + + for _, message := range messages { + m.out.Notice("") // Line break before + + segments := strings.Split(message.Topic, ".") + if len(segments) > 0 { + switch segments[0] { + case msgs.TopicError: + m.handleErrorMessages(message) + case msgs.TopicInfo: + logging.Info("State Service reported an info message: %s", message.Message) + m.out.Notice(message.Message) + default: + logging.Debug("State Service reported an unknown message: %s", message.Topic) + m.out.Notice(message.Message) // fallback to notice for unknown types + } + } + + m.out.Notice("") // Line break after + } + + return nil +} + +func (m *Messenger) handleErrorMessages(message *graph.Message) { + logging.Warning("State Service reported a %s error: %s", message.Topic, message.Message) + err := locale.NewError("err_svc_message", "[WARNING]Warning:[/RESET] {{.V0}}", message.Message) + m.out.Error(err) +} diff --git a/cmd/state/main.go b/cmd/state/main.go index 42d2a2f9dd..cd61702560 100644 --- a/cmd/state/main.go +++ b/cmd/state/main.go @@ -10,6 +10,7 @@ import ( "time" "github.com/ActiveState/cli/cmd/state/internal/cmdtree" + "github.com/ActiveState/cli/cmd/state/internal/cmdtree/exechandlers/messages" "github.com/ActiveState/cli/cmd/state/internal/cmdtree/exechandlers/notifier" anAsync "github.com/ActiveState/cli/internal/analytics/client/async" anaConst "github.com/ActiveState/cli/internal/analytics/constants" @@ -250,6 +251,9 @@ func run(args []string, cfg *config.Instance, out output.Outputer) (rerr error) cmds.OnExecStart(notifier.OnExecStart) cmds.OnExecStop(notifier.OnExecStop) + messenger := messages.New(out, svcmodel) + cmds.OnExecStart(messenger.OnExecStart) + // Auto update to latest state tool version if possible. if updated, err := autoUpdate(svcmodel, args, childCmd, cfg, an, out); err == nil && updated { return nil // command will be run by updated exe diff --git a/internal/graph/generated.go b/internal/graph/generated.go index f61461ffa3..7350c32b84 100644 --- a/internal/graph/generated.go +++ b/internal/graph/generated.go @@ -40,6 +40,11 @@ type Jwt struct { User *User `json:"user"` } +type Message struct { + Topic string `json:"topic"` + Message string `json:"message"` +} + type Mutation struct { } diff --git a/internal/locale/locales/en-us.yaml b/internal/locale/locales/en-us.yaml index 2ad2a09797..8dc354a931 100644 --- a/internal/locale/locales/en-us.yaml +++ b/internal/locale/locales/en-us.yaml @@ -932,6 +932,10 @@ err_read_projectfile: other: The activestate.yaml at {{.V0}} could not be read. err_auth_fail_totp: other: A two-factor authentication code is required. +err_invalid_token: + other: Invalid API token. +err_invalid_credentials: + other: Invalid credentials cve_title: other: Vulnerability Summary cve_description: diff --git a/internal/messages/topics.go b/internal/messages/topics.go new file mode 100644 index 0000000000..638e69797a --- /dev/null +++ b/internal/messages/topics.go @@ -0,0 +1,8 @@ +package messages + +const ( + TopicError = "error" + TopicInfo = "info" + TopicErrorAuth = "error.auth" + TopicErrorAuthToken = "error.auth.token" +) diff --git a/pkg/platform/api/svc/request/message.go b/pkg/platform/api/svc/request/message.go new file mode 100644 index 0000000000..373cea2a3b --- /dev/null +++ b/pkg/platform/api/svc/request/message.go @@ -0,0 +1,21 @@ +package request + +type MessageRequest struct { +} + +func NewMessageRequest() *MessageRequest { + return &MessageRequest{} +} + +func (m *MessageRequest) Query() string { + return `query { + checkMessages { + topic + message + } + }` +} + +func (m *MessageRequest) Vars() (map[string]interface{}, error) { + return map[string]interface{}{}, nil +} diff --git a/pkg/platform/authentication/auth.go b/pkg/platform/authentication/auth.go index 8e1a1f8db3..a8379cac2c 100644 --- a/pkg/platform/authentication/auth.go +++ b/pkg/platform/authentication/auth.go @@ -34,6 +34,8 @@ type ErrUnauthorized struct{ *locale.LocalizedError } type ErrTokenRequired struct{ *locale.LocalizedError } +type ErrInvalidToken struct{ *locale.LocalizedError } + var errNotYetGranted = locale.NewInputError("err_auth_device_noauth") // jwtLifetime is the lifetime of the JWT. This is defined by the API, but the API doesn't communicate this. @@ -45,6 +47,7 @@ type Auth struct { client *mono_client.Mono clientAuth *runtime.ClientAuthInfoWriter bearerToken string + envToken string user *mono_models.User cfg Configurable lastRenewal *time.Time @@ -93,6 +96,7 @@ func New(cfg Configurable) *Auth { auth := &Auth{ cfg: cfg, jwtLifetime: jwtLifetime, + envToken: os.Getenv(constants.APIKeyEnvVarName), } return auth @@ -249,6 +253,11 @@ func (s *Auth) AuthenticateWithModel(credentials *mono_models.Credentials) error return errs.AddTips(&ErrUnauthorized{locale.WrapExternalError(err, "err_unauthorized")}, tips...) case *apiAuth.PostLoginRetryWith: return errs.AddTips(&ErrTokenRequired{locale.WrapExternalError(err, "err_auth_fail_totp")}, tips...) + case *apiAuth.PostLoginBadRequest: + if credentials.Token != "" { + return errs.AddTips(&ErrInvalidToken{locale.WrapExternalError(err, "err_invalid_token")}, tips...) + } + return errs.AddTips(locale.WrapExternalError(err, "err_invalid_credentials"), tips...) default: if os.IsTimeout(err) { return locale.NewExternalError("err_api_auth_timeout", "Timed out waiting for authentication response. Please try again.") @@ -303,9 +312,20 @@ func (s *Auth) AuthenticateWithDevicePolling(deviceCode strfmt.UUID, interval ti // AuthenticateWithToken will try to authenticate using the given token func (s *Auth) AuthenticateWithToken(token string) error { logging.Debug("AuthenticateWithToken") - return s.AuthenticateWithModel(&mono_models.Credentials{ + err := s.AuthenticateWithModel(&mono_models.Credentials{ Token: token, }) + if err != nil { + var invalidTokenErr *ErrInvalidToken + if errors.As(err, &invalidTokenErr) && s.envToken != "" { + logging.Debug("Invalid token, clearing stored token") + s.envToken = "" + return errs.Wrap(err, "Invalid API token") + } + return errs.Wrap(err, "Failed to authenticate with token") + } + + return nil } // UpdateSession authenticates with the given access token obtained via a Platform @@ -464,9 +484,9 @@ func (s *Auth) NewAPIKey(name string) (string, error) { } func (s *Auth) AvailableAPIToken() (v string) { - if tkn := os.Getenv(constants.APIKeyEnvVarName); tkn != "" { + if s.envToken != "" { logging.Debug("Using API token passed via env var") - return tkn + return s.envToken } return s.cfg.GetString(ApiTokenConfigKey) } diff --git a/pkg/platform/model/svc.go b/pkg/platform/model/svc.go index d6e4af9ae3..718dd20572 100644 --- a/pkg/platform/model/svc.go +++ b/pkg/platform/model/svc.go @@ -137,6 +137,14 @@ func (m *SvcModel) CheckNotifications(ctx context.Context, command string, flags return resp, nil } +func (m *SvcModel) CheckMessages(ctx context.Context) ([]*graph.Message, error) { + resp := []*graph.Message{} + if err := m.request(ctx, request.NewMessageRequest(), &resp); err != nil { + return nil, errs.Wrap(err, "Error sending messages request") + } + return resp, nil +} + func (m *SvcModel) ConfigChanged(ctx context.Context, key string) error { defer profile.Measure("svc:ConfigChanged", time.Now()) diff --git a/test/integration/auth_int_test.go b/test/integration/auth_int_test.go index c34dab906a..321444e38c 100644 --- a/test/integration/auth_int_test.go +++ b/test/integration/auth_int_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/testhelpers/suite" "github.com/ActiveState/termtest" "github.com/google/uuid" @@ -112,6 +113,27 @@ func (suite *AuthIntegrationTestSuite) TestAuth_JsonOutput() { suite.authOutput("json") } +func (suite *AuthIntegrationTestSuite) TestAuth_InvalidToken() { + suite.OnlyRunForTags(tagsuite.Auth, tagsuite.Critical) + ts := e2e.New(suite.T(), false) + defer ts.Close() + + cp := ts.SpawnWithOpts(e2e.OptArgs("--version"), e2e.OptAppendEnv(constants.APIKeyEnvVarName+"=bad-token")) + // Message is displayed + cp.Expect("Warning: Invalid API token") + // The version information is still displayed + cp.Expect("ActiveState CLI") + cp.Expect("Version") + cp.ExpectExitCode(0) + + // Running the command again shows no error message as the token has been cleared + cp = ts.SpawnWithOpts(e2e.OptArgs("--version")) + cp.ExpectExitCode(0) + cp.Expect("ActiveState CLI") + cp.Expect("Version") + ts.IgnoreLogErrors() +} + func TestAuthIntegrationTestSuite(t *testing.T) { suite.Run(t, new(AuthIntegrationTestSuite)) }