Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gqlgen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ models:
model: github.com/stashapp/stash/internal/identify.FieldStrategy
ScraperSource:
model: github.com/stashapp/stash/pkg/scraper.Source
RoleEnum:
model: github.com/stashapp/stash/pkg/models.RoleEnum
IdentifySourceInput:
model: github.com/stashapp/stash/internal/identify.Source
IdentifyFieldOptionsInput:
Expand Down
231 changes: 145 additions & 86 deletions graphql/schema/schema.graphql

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions graphql/schema/types/user.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
enum RoleEnum {
ADMIN
READ
MODIFY
}

directive @hasRole(role: RoleEnum!) on FIELD_DEFINITION
directive @isUserOwner on FIELD_DEFINITION

type User {
name: String!
"""
If the user has no roles, they are considered locked and cannot log in.
Should not be visible to other users
"""
roles: [RoleEnum!] @isUserOwner
"""
Should not be visible to other users
"""
api_key: String @isUserOwner
}

input UserCreateInput {
name: String!
"""
Password in plain text
"""
password: String!
roles: [RoleEnum!]!
}

input UserUpdateInput {
existingName: String!
name: String!
roles: [RoleEnum!]!
}

input UserDestroyInput {
name: String!
}

input UserChangePasswordInput {
"""
Password in plain text
"""
existingPassword: String!
newPassword: String!
}

input ChangeUserPasswordInput {
name: String!
newPassword: String!
}
38 changes: 31 additions & 7 deletions internal/api/authentication.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"errors"
"net"
"net/http"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
)

Expand All @@ -29,13 +31,18 @@ func allowUnauthenticated(r *http.Request) bool {
return strings.HasPrefix(r.URL.Path, loginEndpoint) || r.URL.Path == logoutEndpoint || r.URL.Path == "/css" || strings.HasPrefix(r.URL.Path, "/assets")
}

func authenticateHandler() func(http.Handler) http.Handler {
type UserGetter interface {
GetUser(ctx context.Context, username string) (*models.User, error)
}

func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := config.GetInstance()
s := c.UserStore

// error if external access tripwire activated
if accessErr := session.CheckExternalAccessTripwire(c); accessErr != nil {
if accessErr := session.CheckExternalAccessTripwire(s, c); accessErr != nil {
http.Error(w, tripwireActivatedErrMsg, http.StatusForbidden)
return
}
Expand All @@ -53,7 +60,9 @@ func authenticateHandler() func(http.Handler) http.Handler {
return
}

if err := session.CheckAllowPublicWithoutAuth(c, r); err != nil {
ctx := r.Context()

if err := session.CheckAllowPublicWithoutAuth(s, c, r); err != nil {
var accessErr session.ExternalAccessError
if errors.As(err, &accessErr) {
session.LogExternalAccessError(accessErr)
Expand All @@ -71,11 +80,23 @@ func authenticateHandler() func(http.Handler) http.Handler {
return
}

ctx := r.Context()
var u *models.User
if userID != "" {
u, err = g.GetUser(ctx, userID)
if err != nil {
// if we can't get the user object, we just return a forbidden error
logger.Errorf("Error getting user object: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if u == nil {
logger.Errorf("[User] cookie user %q not found", userID)
}
}

if c.HasCredentials() {
if hc := s.LoginRequired(ctx); hc {
// authentication is required
if userID == "" && !allowUnauthenticated(r) {
if u == nil && !allowUnauthenticated(r) {
// if graphql or a non-webpage was requested, we just return a forbidden error
ext := path.Ext(r.URL.Path)
if r.URL.Path == gqlEndpoint || (ext != "" && ext != ".html") {
Expand All @@ -102,7 +123,10 @@ func authenticateHandler() func(http.Handler) http.Handler {
}
}

ctx = session.SetCurrentUserID(ctx, userID)
if u != nil {
// set the user object in the context
ctx = session.SetCurrentUser(ctx, *u)
}

r = r.WithContext(ctx)

Expand Down
47 changes: 47 additions & 0 deletions internal/api/directives.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package api

import (
"context"

"github.com/99designs/gqlgen/graphql"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
)

func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role models.RoleEnum) (interface{}, error) {
currentUser := session.GetCurrentUser(ctx)

// if there is no current user, this is an anonymous request
// we should not end up here unless there are no credentials required
if currentUser == nil {
return next(ctx)
}

if currentUser != nil && !currentUser.Roles.HasRole(role) {
return nil, session.ErrUnauthorized
}

return next(ctx)
}

func IsUserOwnerDirective(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) {
currentUser := session.GetCurrentUser(ctx)

// if there is no current user, this is an anonymous request
// we should not end up here unless there are no credentials required
if currentUser == nil {
return next(ctx)
}

// get the user from the object
userObj, ok := obj.(*models.User)
if !ok {
return nil, session.ErrUnauthorized
}

if currentUser.Username != userObj.Username {
return nil, session.ErrUnauthorized
}

return next(ctx)
}
5 changes: 5 additions & 0 deletions internal/api/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Resolver struct {
imageService manager.ImageService
galleryService manager.GalleryService
groupService manager.GroupService
userService manager.UserService

hookExecutor hookExecutor
}
Expand Down Expand Up @@ -110,6 +111,9 @@ func (r *Resolver) Plugin() PluginResolver {
func (r *Resolver) ConfigResult() ConfigResultResolver {
return &configResultResolver{r}
}
func (r *Resolver) User() UserResolver {
return &userResolver{r}
}

type mutationResolver struct{ *Resolver }
type queryResolver struct{ *Resolver }
Expand All @@ -136,6 +140,7 @@ type folderResolver struct{ *Resolver }
type savedFilterResolver struct{ *Resolver }
type pluginResolver struct{ *Resolver }
type configResultResolver struct{ *Resolver }
type userResolver struct{ *Resolver }

func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return r.repository.WithTxn(ctx, fn)
Expand Down
23 changes: 23 additions & 0 deletions internal/api/resolver_model_user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package api

import (
"context"

"github.com/stashapp/stash/pkg/models"
)

func (r *userResolver) Name(ctx context.Context, obj *models.User) (string, error) {
return obj.Username, nil
}

func (r *userResolver) Roles(ctx context.Context, obj *models.User) ([]models.RoleEnum, error) {
ret := make([]models.RoleEnum, len(obj.Roles))
for i, role := range obj.Roles {
ret[i] = models.RoleEnum(role)
}
return ret, nil
}

func (r *userResolver) APIKey(ctx context.Context, obj *models.User) (*string, error) {
return nil, nil
}
62 changes: 62 additions & 0 deletions internal/api/resolver_mutation_user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package api

import (
"context"

"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
)

func (r *mutationResolver) UserCreate(ctx context.Context, input UserCreateInput) (*models.User, error) {
err := r.userService.CreateUser(ctx, models.User{
Username: input.Name,
Roles: models.Roles(input.Roles),
}, input.Password)
if err != nil {
return nil, err
}

return r.userService.GetUser(ctx, input.Name)
}

func (r *mutationResolver) UserUpdate(ctx context.Context, input UserUpdateInput) (*models.User, error) {
err := r.userService.UpdateUser(ctx, input.ExistingName, models.User{
Username: input.Name,
Roles: models.Roles(input.Roles),
})
if err != nil {
return nil, err
}

return r.userService.GetUser(ctx, input.Name)
}

func (r *mutationResolver) UserDestroy(ctx context.Context, input UserDestroyInput) (bool, error) {
err := r.userService.DeleteUser(ctx, input.Name)
if err != nil {
return false, err
}

return true, nil
}

func (r *mutationResolver) ChangePassword(ctx context.Context, input UserChangePasswordInput) (bool, error) {
// get current user
u := session.GetCurrentUser(ctx)

err := r.userService.ChangePassword(ctx, u.Username, input.ExistingPassword, input.NewPassword)
if err != nil {
return false, err
}

return true, nil
}

func (r *mutationResolver) ChangeUserPassword(ctx context.Context, input ChangeUserPasswordInput) (bool, error) {
err := r.userService.ChangeUserPassword(ctx, input.Name, input.NewPassword)
if err != nil {
return false, err
}

return true, nil
}
17 changes: 17 additions & 0 deletions internal/api/resolver_query_user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package api

import (
"context"

"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
)

func (r *queryResolver) Users(ctx context.Context) ([]*models.User, error) {
return r.userService.AllUsers(ctx)
}

func (r *queryResolver) Me(ctx context.Context) (*models.User, error) {
// get current user
return session.GetCurrentUser(ctx), nil
}
24 changes: 19 additions & 5 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ func Initialize() (*Server, error) {
manager: mgr,
}

userStore := manager.GetInstance().UserService

r.Use(middleware.Heartbeat("/healthz"))
r.Use(cors.AllowAll().Handler)
r.Use(authenticateHandler())
r.Use(authenticateHandler(userStore))
visitedPluginHandler := mgr.SessionStore.VisitedPluginHandler()
r.Use(visitedPluginHandler)

Expand Down Expand Up @@ -162,16 +164,26 @@ func Initialize() (*Server, error) {
imageService := mgr.ImageService
galleryService := mgr.GalleryService
groupService := mgr.GroupService
userService := mgr.UserService
resolver := &Resolver{
repository: repo,
sceneService: sceneService,
imageService: imageService,
galleryService: galleryService,
groupService: groupService,
userService: userService,
hookExecutor: pluginCache,
}

gqlSrv := gqlHandler.New(NewExecutableSchema(Config{Resolvers: resolver}))
gqlCfg := Config{
Resolvers: resolver,
Directives: DirectiveRoot{
HasRole: HasRoleDirective,
IsUserOwner: IsUserOwnerDirective,
},
}

gqlSrv := gqlHandler.New(NewExecutableSchema(gqlCfg))
gqlSrv.SetRecoverFunc(recoverFunc)
gqlSrv.AddTransport(gqlTransport.Websocket{
Upgrader: websocket.Upgrader{
Expand Down Expand Up @@ -227,9 +239,11 @@ func Initialize() (*Server, error) {

staticLoginUI := statigz.FileServer(ui.LoginUIBox.(fs.ReadDirFS))

r.Get(loginEndpoint, handleLogin())
r.Post(loginEndpoint, handleLoginPost())
r.Get(logoutEndpoint, handleLogout())
sessionStore := mgr.SessionStore

r.Get(loginEndpoint, handleLogin(userService))
r.Post(loginEndpoint, handleLoginPost(sessionStore))
r.Get(logoutEndpoint, handleLogout(sessionStore))
r.Get(loginLocaleEndpoint, handleLoginLocale(cfg))
r.HandleFunc(loginEndpoint+"/*", func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, loginEndpoint)
Expand Down
Loading