From 8311343c3f060789f7c26b13d9b670ce819c4079 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:40:41 +1100 Subject: [PATCH 1/5] Add custom fields tables schema migration --- pkg/sqlite/database.go | 2 +- pkg/sqlite/migrations/76_studio_custom_fields.up.sql | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 pkg/sqlite/migrations/76_studio_custom_fields.up.sql diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index 0ea3d71700..a87f6706fc 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -34,7 +34,7 @@ const ( cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE" ) -var appSchemaVersion uint = 75 +var appSchemaVersion uint = 76 //go:embed migrations/*.sql var migrationsBox embed.FS diff --git a/pkg/sqlite/migrations/76_studio_custom_fields.up.sql b/pkg/sqlite/migrations/76_studio_custom_fields.up.sql new file mode 100644 index 0000000000..81a72d4d4e --- /dev/null +++ b/pkg/sqlite/migrations/76_studio_custom_fields.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE `studio_custom_fields` ( + `studio_id` integer NOT NULL, + `field` varchar(64) NOT NULL, + `value` BLOB NOT NULL, + PRIMARY KEY (`studio_id`, `field`), + foreign key(`studio_id`) references `studios`(`id`) on delete CASCADE +); + +CREATE INDEX `index_studio_custom_fields_field_value` ON `studio_custom_fields` (`field`, `value`); From 60e9a07a7fb7750efea33c47015cbe7032ddf190 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:35:11 +1100 Subject: [PATCH 2/5] Studio custom field changes --- graphql/schema/types/studio.graphql | 6 +++ internal/api/loaders/dataloaders.go | 21 ++++++++- internal/api/resolver_model_studio.go | 13 ++++++ internal/api/resolver_mutation_studio.go | 8 +++- pkg/gallery/import.go | 2 +- pkg/group/import.go | 2 +- pkg/image/import.go | 2 +- pkg/models/mocks/StudioReaderWriter.go | 54 ++++++++++++++++++++++-- pkg/models/model_scraped_item.go | 4 +- pkg/models/model_studio.go | 21 +++++++++ pkg/models/repository_studio.go | 6 ++- pkg/models/studio.go | 7 +++ pkg/scene/import.go | 2 +- pkg/sqlite/studio.go | 30 ++++++++++--- pkg/sqlite/studio_filter.go | 7 +++ pkg/sqlite/tables.go | 1 + pkg/studio/import.go | 6 +-- pkg/studio/validate.go | 2 +- 18 files changed, 170 insertions(+), 24 deletions(-) diff --git a/graphql/schema/types/studio.graphql b/graphql/schema/types/studio.graphql index 4c5778c5b1..6da0ccc959 100644 --- a/graphql/schema/types/studio.graphql +++ b/graphql/schema/types/studio.graphql @@ -26,6 +26,8 @@ type Studio { groups: [Group!]! movies: [Movie!]! @deprecated(reason: "use groups instead") o_counter: Int + + custom_fields: Map! } input StudioCreateInput { @@ -43,6 +45,8 @@ input StudioCreateInput { aliases: [String!] tag_ids: [ID!] ignore_auto_tag: Boolean + + custom_fields: Map } input StudioUpdateInput { @@ -61,6 +65,8 @@ input StudioUpdateInput { aliases: [String!] tag_ids: [ID!] ignore_auto_tag: Boolean + + custom_fields: CustomFieldsInput } input BulkStudioUpdateInput { diff --git a/internal/api/loaders/dataloaders.go b/internal/api/loaders/dataloaders.go index 38f72b0a14..4676966c93 100644 --- a/internal/api/loaders/dataloaders.go +++ b/internal/api/loaders/dataloaders.go @@ -59,7 +59,9 @@ type Loaders struct { PerformerByID *PerformerLoader PerformerCustomFields *CustomFieldsLoader - StudioByID *StudioLoader + StudioByID *StudioLoader + StudioCustomFields *CustomFieldsLoader + TagByID *TagLoader GroupByID *GroupLoader FileByID *FileLoader @@ -99,6 +101,11 @@ func (m Middleware) Middleware(next http.Handler) http.Handler { maxBatch: maxBatch, fetch: m.fetchPerformerCustomFields(ctx), }, + StudioCustomFields: &CustomFieldsLoader{ + wait: wait, + maxBatch: maxBatch, + fetch: m.fetchStudioCustomFields(ctx), + }, StudioByID: &StudioLoader{ wait: wait, maxBatch: maxBatch, @@ -253,6 +260,18 @@ func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*model } } +func (m Middleware) fetchStudioCustomFields(ctx context.Context) func(keys []int) ([]models.CustomFieldMap, []error) { + return func(keys []int) (ret []models.CustomFieldMap, errs []error) { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { + var err error + ret, err = m.Repository.Studio.GetCustomFieldsBulk(ctx, keys) + return err + }) + + return ret, toErrorSlice(err) + } +} + func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.Tag, []error) { return func(keys []int) (ret []*models.Tag, errs []error) { err := m.Repository.WithDB(ctx, func(ctx context.Context) error { diff --git a/internal/api/resolver_model_studio.go b/internal/api/resolver_model_studio.go index fabcf38bd2..b54455920c 100644 --- a/internal/api/resolver_model_studio.go +++ b/internal/api/resolver_model_studio.go @@ -207,6 +207,19 @@ func (r *studioResolver) Groups(ctx context.Context, obj *models.Studio) (ret [] return ret, nil } +func (r *studioResolver) CustomFields(ctx context.Context, obj *models.Studio) (map[string]interface{}, error) { + m, err := loaders.From(ctx).StudioCustomFields.Load(obj.ID) + if err != nil { + return nil, err + } + + if m == nil { + return make(map[string]interface{}), nil + } + + return m, nil +} + // deprecated func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []*models.Group, err error) { return r.Groups(ctx, obj) diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go index 4b3316111a..3fc035f522 100644 --- a/internal/api/resolver_mutation_studio.go +++ b/internal/api/resolver_mutation_studio.go @@ -31,7 +31,7 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio } // Populate a new studio from the input - newStudio := models.NewStudio() + newStudio := models.NewCreateStudioInput() newStudio.Name = strings.TrimSpace(input.Name) newStudio.Rating = input.Rating100 @@ -61,6 +61,7 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio if err != nil { return nil, fmt.Errorf("converting tag ids: %w", err) } + newStudio.CustomFields = convertMapJSONNumbers(input.CustomFields) // Process the base 64 encoded image string var imageData []byte @@ -152,6 +153,11 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio } } + updatedStudio.CustomFields = input.CustomFields + // convert json.Numbers to int/float + updatedStudio.CustomFields.Full = convertMapJSONNumbers(updatedStudio.CustomFields.Full) + updatedStudio.CustomFields.Partial = convertMapJSONNumbers(updatedStudio.CustomFields.Partial) + // Process the base 64 encoded image string var imageData []byte imageIncluded := translator.hasField("image") diff --git a/pkg/gallery/import.go b/pkg/gallery/import.go index 0068b3f1cf..543d4cf486 100644 --- a/pkg/gallery/import.go +++ b/pkg/gallery/import.go @@ -126,7 +126,7 @@ func (i *Importer) populateStudio(ctx context.Context) error { } func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { - newStudio := models.NewStudio() + newStudio := models.NewCreateStudioInput() newStudio.Name = name err := i.StudioWriter.Create(ctx, &newStudio) diff --git a/pkg/group/import.go b/pkg/group/import.go index 3fc7db8f15..a73c3998e3 100644 --- a/pkg/group/import.go +++ b/pkg/group/import.go @@ -203,7 +203,7 @@ func (i *Importer) populateStudio(ctx context.Context) error { } func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { - newStudio := models.NewStudio() + newStudio := models.NewCreateStudioInput() newStudio.Name = name err := i.StudioWriter.Create(ctx, &newStudio) diff --git a/pkg/image/import.go b/pkg/image/import.go index bf92a6ae8c..77b6d74773 100644 --- a/pkg/image/import.go +++ b/pkg/image/import.go @@ -159,7 +159,7 @@ func (i *Importer) populateStudio(ctx context.Context) error { } func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { - newStudio := models.NewStudio() + newStudio := models.NewCreateStudioInput() newStudio.Name = name err := i.StudioWriter.Create(ctx, &newStudio) diff --git a/pkg/models/mocks/StudioReaderWriter.go b/pkg/models/mocks/StudioReaderWriter.go index 481565d6fc..f57a73aa1d 100644 --- a/pkg/models/mocks/StudioReaderWriter.go +++ b/pkg/models/mocks/StudioReaderWriter.go @@ -80,11 +80,11 @@ func (_m *StudioReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, } // Create provides a mock function with given fields: ctx, newStudio -func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.Studio) error { +func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.CreateStudioInput) error { ret := _m.Called(ctx, newStudio) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.CreateStudioInput) error); ok { r0 = rf(ctx, newStudio) } else { r0 = ret.Error(0) @@ -291,6 +291,52 @@ func (_m *StudioReaderWriter) GetAliases(ctx context.Context, relatedID int) ([] return r0, r1 } +// GetCustomFields provides a mock function with given fields: ctx, id +func (_m *StudioReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) { + ret := _m.Called(ctx, id) + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids +func (_m *StudioReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) { + ret := _m.Called(ctx, ids) + + var r0 []models.CustomFieldMap + if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok { + r0 = rf(ctx, ids) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.CustomFieldMap) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetImage provides a mock function with given fields: ctx, studioID func (_m *StudioReaderWriter) GetImage(ctx context.Context, studioID int) ([]byte, error) { ret := _m.Called(ctx, studioID) @@ -479,11 +525,11 @@ func (_m *StudioReaderWriter) QueryForAutoTag(ctx context.Context, words []strin } // Update provides a mock function with given fields: ctx, updatedStudio -func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio *models.Studio) error { +func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio *models.UpdateStudioInput) error { ret := _m.Called(ctx, updatedStudio) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.UpdateStudioInput) error); ok { r0 = rf(ctx, updatedStudio) } else { r0 = ret.Error(0) diff --git a/pkg/models/model_scraped_item.go b/pkg/models/model_scraped_item.go index 4254a98769..bd6db10c8a 100644 --- a/pkg/models/model_scraped_item.go +++ b/pkg/models/model_scraped_item.go @@ -27,9 +27,9 @@ type ScrapedStudio struct { func (ScrapedStudio) IsScrapedContent() {} -func (s *ScrapedStudio) ToStudio(endpoint string, excluded map[string]bool) *Studio { +func (s *ScrapedStudio) ToStudio(endpoint string, excluded map[string]bool) *CreateStudioInput { // Populate a new studio from the input - ret := NewStudio() + ret := NewCreateStudioInput() ret.Name = strings.TrimSpace(s.Name) if s.RemoteSiteID != nil && endpoint != "" && *s.RemoteSiteID != "" { diff --git a/pkg/models/model_studio.go b/pkg/models/model_studio.go index 8c7a687af5..ee6fae2d25 100644 --- a/pkg/models/model_studio.go +++ b/pkg/models/model_studio.go @@ -23,6 +23,18 @@ type Studio struct { StashIDs RelatedStashIDs `json:"stash_ids"` } +type CreateStudioInput struct { + *Studio + + CustomFields map[string]interface{} `json:"custom_fields"` +} + +type UpdateStudioInput struct { + *Studio + + CustomFields CustomFieldsInput `json:"custom_fields"` +} + func NewStudio() Studio { currentTime := time.Now() return Studio{ @@ -31,6 +43,13 @@ func NewStudio() Studio { } } +func NewCreateStudioInput() CreateStudioInput { + s := NewStudio() + return CreateStudioInput{ + Studio: &s, + } +} + // StudioPartial represents part of a Studio object. It is used to update the database entry. type StudioPartial struct { ID int @@ -48,6 +67,8 @@ type StudioPartial struct { URLs *UpdateStrings TagIDs *UpdateIDs StashIDs *UpdateStashIDs + + CustomFields CustomFieldsInput } func NewStudioPartial() StudioPartial { diff --git a/pkg/models/repository_studio.go b/pkg/models/repository_studio.go index 99f98bffca..54fb6ed479 100644 --- a/pkg/models/repository_studio.go +++ b/pkg/models/repository_studio.go @@ -42,12 +42,12 @@ type StudioCounter interface { // StudioCreator provides methods to create studios. type StudioCreator interface { - Create(ctx context.Context, newStudio *Studio) error + Create(ctx context.Context, newStudio *CreateStudioInput) error } // StudioUpdater provides methods to update studios. type StudioUpdater interface { - Update(ctx context.Context, updatedStudio *Studio) error + Update(ctx context.Context, updatedStudio *UpdateStudioInput) error UpdatePartial(ctx context.Context, updatedStudio StudioPartial) (*Studio, error) UpdateImage(ctx context.Context, studioID int, image []byte) error } @@ -79,6 +79,8 @@ type StudioReader interface { TagIDLoader URLLoader + CustomFieldsReader + All(ctx context.Context) ([]*Studio, error) GetImage(ctx context.Context, studioID int) ([]byte, error) HasImage(ctx context.Context, studioID int) (bool, error) diff --git a/pkg/models/studio.go b/pkg/models/studio.go index 1711681297..aed4ad53f8 100644 --- a/pkg/models/studio.go +++ b/pkg/models/studio.go @@ -44,6 +44,9 @@ type StudioFilterType struct { CreatedAt *TimestampCriterionInput `json:"created_at"` // Filter by updated at UpdatedAt *TimestampCriterionInput `json:"updated_at"` + + // Filter by custom fields + CustomFields []CustomFieldCriterionInput `json:"custom_fields"` } type StudioCreateInput struct { @@ -60,6 +63,8 @@ type StudioCreateInput struct { Aliases []string `json:"aliases"` TagIds []string `json:"tag_ids"` IgnoreAutoTag *bool `json:"ignore_auto_tag"` + + CustomFields map[string]interface{} `json:"custom_fields"` } type StudioUpdateInput struct { @@ -77,4 +82,6 @@ type StudioUpdateInput struct { Aliases []string `json:"aliases"` TagIds []string `json:"tag_ids"` IgnoreAutoTag *bool `json:"ignore_auto_tag"` + + CustomFields CustomFieldsInput `json:"custom_fields"` } diff --git a/pkg/scene/import.go b/pkg/scene/import.go index efffd380da..b3f0f1ff12 100644 --- a/pkg/scene/import.go +++ b/pkg/scene/import.go @@ -213,7 +213,7 @@ func (i *Importer) populateStudio(ctx context.Context) error { } func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { - newStudio := models.NewStudio() + newStudio := models.NewCreateStudioInput() newStudio.Name = name err := i.StudioWriter.Create(ctx, &newStudio) diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go index 1a05be6f36..74a85fc708 100644 --- a/pkg/sqlite/studio.go +++ b/pkg/sqlite/studio.go @@ -140,6 +140,7 @@ var ( type StudioStore struct { blobJoinQueryBuilder + customFieldsStore tagRelationshipStore tableMgr *table @@ -151,6 +152,10 @@ func NewStudioStore(blobStore *BlobStore) *StudioStore { blobStore: blobStore, joinTable: studioTable, }, + customFieldsStore: customFieldsStore{ + table: studiosCustomFieldsTable, + fk: studiosCustomFieldsTable.Col(studioIDColumn), + }, tagRelationshipStore: tagRelationshipStore{ idRelationshipStore: idRelationshipStore{ joinTable: studiosTagsTableMgr, @@ -169,11 +174,11 @@ func (qb *StudioStore) selectDataset() *goqu.SelectDataset { return dialect.From(qb.table()).Select(qb.table().All()) } -func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error { +func (qb *StudioStore) Create(ctx context.Context, newObject *models.CreateStudioInput) error { var err error var r studioRow - r.fromStudio(*newObject) + r.fromStudio(*newObject.Studio) id, err := qb.tableMgr.insertID(ctx, r) if err != nil { @@ -207,12 +212,17 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) err } } + const partial = false + if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil { + return err + } + updated, err := qb.find(ctx, id) if err != nil { return fmt.Errorf("finding after create: %w", err) } - *newObject = *updated + *newObject.Studio = *updated return nil } @@ -253,13 +263,17 @@ func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPar } } - return qb.Find(ctx, input.ID) + if err := qb.SetCustomFields(ctx, input.ID, input.CustomFields); err != nil { + return nil, err + } + + return qb.find(ctx, input.ID) } // This is only used by the Import/Export functionality -func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error { +func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.UpdateStudioInput) error { var r studioRow - r.fromStudio(*updatedObject) + r.fromStudio(*updatedObject.Studio) if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil { return err @@ -287,6 +301,10 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) } } + if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil { + return err + } + return nil } diff --git a/pkg/sqlite/studio_filter.go b/pkg/sqlite/studio_filter.go index 6ff7fcced9..130e318d87 100644 --- a/pkg/sqlite/studio_filter.go +++ b/pkg/sqlite/studio_filter.go @@ -111,6 +111,13 @@ func (qb *studioFilterHandler) criterionHandler() criterionHandler { studioRepository.galleries.innerJoin(f, "", "studios.id") }, }, + + &customFieldsFilterHandler{ + table: studiosCustomFieldsTable.GetTable(), + fkCol: studioIDColumn, + c: studioFilter.CustomFields, + idCol: "studios.id", + }, } } diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index 7cddf25ccc..bfc5199fee 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -40,6 +40,7 @@ var ( studiosURLsJoinTable = goqu.T(studioURLsTable) studiosTagsJoinTable = goqu.T(studiosTagsTable) studiosStashIDsJoinTable = goqu.T("studio_stash_ids") + studiosCustomFieldsTable = goqu.T("studio_custom_fields") groupsURLsJoinTable = goqu.T(groupURLsTable) groupsTagsJoinTable = goqu.T(groupsTagsTable) diff --git a/pkg/studio/import.go b/pkg/studio/import.go index 405852e539..d5284ce022 100644 --- a/pkg/studio/import.go +++ b/pkg/studio/import.go @@ -153,7 +153,7 @@ func (i *Importer) populateParentStudio(ctx context.Context) error { } func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) { - newStudio := models.NewStudio() + newStudio := models.NewCreateStudioInput() newStudio.Name = name err := i.ReaderWriter.Create(ctx, &newStudio) @@ -194,7 +194,7 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { } func (i *Importer) Create(ctx context.Context) (*int, error) { - err := i.ReaderWriter.Create(ctx, &i.studio) + err := i.ReaderWriter.Create(ctx, &models.CreateStudioInput{Studio: &i.studio}) if err != nil { return nil, fmt.Errorf("error creating studio: %v", err) } @@ -206,7 +206,7 @@ func (i *Importer) Create(ctx context.Context) (*int, error) { func (i *Importer) Update(ctx context.Context, id int) error { studio := i.studio studio.ID = id - err := i.ReaderWriter.Update(ctx, &studio) + err := i.ReaderWriter.Update(ctx, &models.UpdateStudioInput{Studio: &studio}) if err != nil { return fmt.Errorf("error updating existing studio: %v", err) } diff --git a/pkg/studio/validate.go b/pkg/studio/validate.go index 4e2f51c84c..dcc3138218 100644 --- a/pkg/studio/validate.go +++ b/pkg/studio/validate.go @@ -75,7 +75,7 @@ func ValidateAliases(ctx context.Context, id int, aliases []string, qb models.St return nil } -func ValidateCreate(ctx context.Context, studio models.Studio, qb models.StudioQueryer) error { +func ValidateCreate(ctx context.Context, studio models.CreateStudioInput, qb models.StudioQueryer) error { if err := validateName(ctx, 0, studio.Name, qb); err != nil { return err } From 2d068c10fc764352ca0f0cc1538489f247dc59e0 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:35:32 +1100 Subject: [PATCH 3/5] Update unit tests --- internal/autotag/integration_test.go | 7 +- internal/identify/scene_test.go | 2 +- internal/identify/studio_test.go | 6 +- pkg/gallery/import_test.go | 8 +- pkg/group/import_test.go | 8 +- pkg/image/import_test.go | 8 +- pkg/models/model_scraped_item_test.go | 2 +- pkg/scene/import_test.go | 8 +- pkg/sqlite/setup_test.go | 25 +- pkg/sqlite/studio_test.go | 559 +++++++++++++++++++++++++- pkg/studio/import_test.go | 18 +- 11 files changed, 605 insertions(+), 46 deletions(-) diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index fc83df848b..605082b98e 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -101,16 +101,15 @@ func createPerformer(ctx context.Context, pqb models.PerformerWriter) error { func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) { // create the studio - studio := models.Studio{ - Name: name, - } + studio := models.NewCreateStudioInput() + studio.Name = name err := qb.Create(ctx, &studio) if err != nil { return nil, err } - return &studio, nil + return studio.Studio, nil } func createTag(ctx context.Context, qb models.TagWriter) error { diff --git a/internal/identify/scene_test.go b/internal/identify/scene_test.go index a76aef516d..0eec61c4e6 100644 --- a/internal/identify/scene_test.go +++ b/internal/identify/scene_test.go @@ -27,7 +27,7 @@ func Test_sceneRelationships_studio(t *testing.T) { db := mocks.NewDatabase() db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) + s := args.Get(1).(*models.CreateStudioInput) s.ID = validStoredIDInt }).Return(nil) diff --git a/internal/identify/studio_test.go b/internal/identify/studio_test.go index 5424a6a93c..083675650a 100644 --- a/internal/identify/studio_test.go +++ b/internal/identify/studio_test.go @@ -21,13 +21,13 @@ func Test_createMissingStudio(t *testing.T) { db := mocks.NewDatabase() - db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { + db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool { return p.Name == validName })).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) + s := args.Get(1).(*models.CreateStudioInput) s.ID = createdID }).Return(nil) - db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { + db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool { return p.Name == invalidName })).Return(errors.New("error creating studio")) diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index b64f80d8f6..4248f51bcb 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -115,9 +115,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -147,7 +147,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/group/import_test.go b/pkg/group/import_test.go index c4ca47442a..50b8b2dd10 100644 --- a/pkg/group/import_test.go +++ b/pkg/group/import_test.go @@ -121,9 +121,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -156,7 +156,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 286e51fe34..98b3972b9a 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -77,9 +77,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -109,7 +109,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/models/model_scraped_item_test.go b/pkg/models/model_scraped_item_test.go index b6b44025f2..5455436529 100644 --- a/pkg/models/model_scraped_item_test.go +++ b/pkg/models/model_scraped_item_test.go @@ -113,7 +113,7 @@ func Test_scrapedToStudioInput(t *testing.T) { got.StashIDs.List()[stid].UpdatedAt = time.Time{} } } - assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, got.Studio) }) } } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index a6e3edcdfd..558b72ba26 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -241,9 +241,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -273,7 +273,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index 704dde8a2b..f8c8544133 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1738,7 +1738,19 @@ func getStudioNullStringValue(index int, field string) string { return ret.String } -func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) { +func getStudioCustomFields(index int) map[string]interface{} { + if index%5 == 0 { + return nil + } + + return map[string]interface{}{ + "string": getStudioStringValue(index, "custom"), + "int": int64(index % 5), + "real": float64(index) / 10, + } +} + +func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int, customFields map[string]interface{}) (*models.Studio, error) { studio := models.Studio{ Name: name, } @@ -1747,7 +1759,7 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par studio.ParentID = parentID } - err := createStudioFromModel(ctx, sqb, &studio) + err := createStudioFromModel(ctx, sqb, &studio, customFields) if err != nil { return nil, err } @@ -1755,8 +1767,11 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par return &studio, nil } -func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error { - err := sqb.Create(ctx, studio) +func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio, customFields map[string]interface{}) error { + err := sqb.Create(ctx, &models.CreateStudioInput{ + Studio: studio, + CustomFields: customFields, + }) if err != nil { return fmt.Errorf("Error creating studio %v+: %s", studio, err.Error()) @@ -1818,7 +1833,7 @@ func createStudios(ctx context.Context, n int, o int) error { alias := getStudioStringValue(i, "Alias") studio.Aliases = models.NewRelatedStrings([]string{alias}) } - err := createStudioFromModel(ctx, sqb, &studio) + err := createStudioFromModel(ctx, sqb, &studio, getStudioCustomFields(i)) if err != nil { return err diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go index 003877c779..954d0f5aa3 100644 --- a/pkg/sqlite/studio_test.go +++ b/pkg/sqlite/studio_test.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/stashapp/stash/pkg/models" "github.com/stretchr/testify/assert" @@ -47,6 +48,550 @@ func TestStudioFindByName(t *testing.T) { }) } +func loadStudioRelationships(ctx context.Context, expected models.Studio, actual *models.Studio) error { + if expected.Aliases.Loaded() { + if err := actual.LoadAliases(ctx, db.Studio); err != nil { + return err + } + } + if expected.TagIDs.Loaded() { + if err := actual.LoadTagIDs(ctx, db.Studio); err != nil { + return err + } + } + if expected.StashIDs.Loaded() { + if err := actual.LoadStashIDs(ctx, db.Studio); err != nil { + return err + } + } + + return nil +} + +func Test_StudioStore_Create(t *testing.T) { + var ( + name = "name" + details = "details" + url = "url" + rating = 3 + aliases = []string{"alias1", "alias2"} + ignoreAutoTag = true + favorite = true + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + ) + + tests := []struct { + name string + newObject models.CreateStudioInput + wantErr bool + }{ + { + "full", + models.CreateStudioInput{ + Studio: &models.Studio{ + Name: name, + URL: url, + Favorite: favorite, + Rating: &rating, + Details: details, + IgnoreAutoTag: ignoreAutoTag, + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]}), + Aliases: models.NewRelatedStrings(aliases), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + CustomFields: testCustomFields, + }, + false, + }, + { + "invalid tag id", + models.CreateStudioInput{ + Studio: &models.Studio{ + Name: name, + TagIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + } + + qb := db.Studio + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + p := tt.newObject + if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr { + t.Errorf("StudioStore.Create() error = %v, wantErr = %v", err, tt.wantErr) + } + + if tt.wantErr { + assert.Zero(p.ID) + return + } + + assert.NotZero(p.ID) + + copy := *tt.newObject.Studio + copy.ID = p.ID + + // load relationships + if err := loadStudioRelationships(ctx, copy, p.Studio); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(copy, *p.Studio) + + // ensure can find the Studio + found, err := qb.Find(ctx, p.ID) + if err != nil { + t.Errorf("StudioStore.Find() error = %v", err) + } + + if !assert.NotNil(found) { + return + } + + // load relationships + if err := loadStudioRelationships(ctx, copy, found); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + assert.Equal(copy, *found) + + // ensure custom fields are set + cf, err := qb.GetCustomFields(ctx, p.ID) + if err != nil { + t.Errorf("StudioStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.newObject.CustomFields, cf) + + return + }) + } +} + +func Test_StudioStore_Update(t *testing.T) { + var ( + name = "name" + details = "details" + url = "url" + rating = 3 + aliases = []string{"aliasX", "aliasY"} + ignoreAutoTag = true + favorite = true + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + ) + + tests := []struct { + name string + updatedObject models.UpdateStudioInput + wantErr bool + }{ + { + "full", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, + URL: url, + Favorite: favorite, + Rating: &rating, + Details: details, + IgnoreAutoTag: ignoreAutoTag, + Aliases: models.NewRelatedStrings(aliases), + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + }, + false, + }, + { + "clear nullables", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, // name is mandatory + Aliases: models.NewRelatedStrings([]string{}), + TagIDs: models.NewRelatedIDs([]int{}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{}), + }, + }, + false, + }, + { + "clear tag ids", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[sceneIdxWithTag], + Name: name, // name is mandatory + TagIDs: models.NewRelatedIDs([]int{}), + }, + }, + false, + }, + { + "set custom fields", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, // name is mandatory + }, + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + false, + }, + { + "clear custom fields", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, // name is mandatory + }, + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + false, + }, + { + "invalid tag id", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[sceneIdxWithGallery], + Name: name, // name is mandatory + TagIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + } + + qb := db.Studio + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + copy := *tt.updatedObject.Studio + + if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr { + t.Errorf("StudioStore.Update() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + s, err := qb.Find(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("StudioStore.Find() error = %v", err) + } + + // load relationships + if err := loadStudioRelationships(ctx, copy, s); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(copy, *s) + + // ensure custom fields are correct + if tt.updatedObject.CustomFields.Full != nil { + cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("StudioStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.updatedObject.CustomFields.Full, cf) + } + }) + } +} + +func clearStudioPartial() models.StudioPartial { + nullString := models.OptionalString{Set: true, Null: true} + nullInt := models.OptionalInt{Set: true, Null: true} + + // leave mandatory fields + return models.StudioPartial{ + URL: nullString, + Aliases: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, + Rating: nullInt, + Details: nullString, + TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet}, + StashIDs: &models.UpdateStashIDs{Mode: models.RelationshipUpdateModeSet}, + } +} + +func Test_StudioStore_UpdatePartial(t *testing.T) { + var ( + name = "name" + details = "details" + url = "url" + aliases = []string{"aliasX", "aliasY"} + rating = 3 + ignoreAutoTag = true + favorite = true + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + ) + + tests := []struct { + name string + id int + partial models.StudioPartial + want models.Studio + wantErr bool + }{ + { + "full", + studioIDs[studioIdxWithDupName], + models.StudioPartial{ + Name: models.NewOptionalString(name), + URL: models.NewOptionalString(url), + Aliases: &models.UpdateStrings{ + Values: aliases, + Mode: models.RelationshipUpdateModeSet, + }, + Favorite: models.NewOptionalBool(favorite), + Rating: models.NewOptionalInt(rating), + Details: models.NewOptionalString(details), + IgnoreAutoTag: models.NewOptionalBool(ignoreAutoTag), + TagIDs: &models.UpdateIDs{ + IDs: []int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]}, + Mode: models.RelationshipUpdateModeSet, + }, + StashIDs: &models.UpdateStashIDs{ + StashIDs: []models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }, + Mode: models.RelationshipUpdateModeSet, + }, + CreatedAt: models.NewOptionalTime(createdAt), + UpdatedAt: models.NewOptionalTime(updatedAt), + }, + models.Studio{ + ID: studioIDs[studioIdxWithDupName], + Name: name, + URL: url, + Aliases: models.NewRelatedStrings(aliases), + Favorite: favorite, + Rating: &rating, + Details: details, + IgnoreAutoTag: ignoreAutoTag, + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + false, + }, + { + "clear all", + studioIDs[studioIdxWithTwoTags], + clearStudioPartial(), + models.Studio{ + ID: studioIDs[studioIdxWithTwoTags], + Name: getStudioStringValue(studioIdxWithTwoTags, "Name"), + Favorite: getStudioBoolValue(studioIdxWithTwoTags), + Aliases: models.NewRelatedStrings([]string{}), + TagIDs: models.NewRelatedIDs([]int{}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{}), + IgnoreAutoTag: getIgnoreAutoTag(studioIdxWithTwoTags), + }, + false, + }, + { + "invalid id", + invalidID, + models.StudioPartial{Name: models.NewOptionalString(name)}, + models.Studio{}, + true, + }, + } + for _, tt := range tests { + qb := db.Studio + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + tt.partial.ID = tt.id + + got, err := qb.UpdatePartial(ctx, tt.partial) + if (err != nil) != tt.wantErr { + t.Errorf("StudioStore.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + if err := loadStudioRelationships(ctx, tt.want, got); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(tt.want, *got) + + s, err := qb.Find(ctx, tt.id) + if err != nil { + t.Errorf("StudioStore.Find() error = %v", err) + } + + // load relationships + if err := loadStudioRelationships(ctx, tt.want, s); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(tt.want, *s) + }) + } +} + +func Test_StudioStore_UpdatePartialCustomFields(t *testing.T) { + tests := []struct { + name string + id int + partial models.StudioPartial + expected map[string]interface{} // nil to use the partial + }{ + { + "set custom fields", + studioIDs[studioIdxWithGallery], + models.StudioPartial{ + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + nil, + }, + { + "clear custom fields", + studioIDs[studioIdxWithGallery], + models.StudioPartial{ + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + nil, + }, + { + "partial custom fields", + studioIDs[studioIdxWithGallery], + models.StudioPartial{ + CustomFields: models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "string": "bbb", + "new_field": "new", + }, + }, + }, + map[string]interface{}{ + "int": int64(2), + "real": 0.7, + "string": "bbb", + "new_field": "new", + }, + }, + } + for _, tt := range tests { + qb := db.Studio + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + tt.partial.ID = tt.id + + _, err := qb.UpdatePartial(ctx, tt.partial) + if err != nil { + t.Errorf("StudioStore.UpdatePartial() error = %v", err) + return + } + + // ensure custom fields are correct + cf, err := qb.GetCustomFields(ctx, tt.id) + if err != nil { + t.Errorf("StudioStore.GetCustomFields() error = %v", err) + return + } + if tt.expected == nil { + assert.Equal(tt.partial.CustomFields.Full, cf) + } else { + assert.Equal(tt.expected, cf) + } + }) + } +} + func TestStudioQueryNameOr(t *testing.T) { const studio1Idx = 1 const studio2Idx = 2 @@ -311,13 +856,13 @@ func TestStudioDestroyParent(t *testing.T) { // create parent and child studios if err := withTxn(func(ctx context.Context) error { - createdParent, err := createStudio(ctx, db.Studio, parentName, nil) + createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := createdParent.ID - createdChild, err := createStudio(ctx, db.Studio, childName, &parentID) + createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } @@ -373,13 +918,13 @@ func TestStudioUpdateClearParent(t *testing.T) { // create parent and child studios if err := withTxn(func(ctx context.Context) error { - createdParent, err := createStudio(ctx, db.Studio, parentName, nil) + createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := createdParent.ID - createdChild, err := createStudio(ctx, db.Studio, childName, &parentID) + createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } @@ -414,7 +959,7 @@ func TestStudioUpdateStudioImage(t *testing.T) { // create studio to test against const name = "TestStudioUpdateStudioImage" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.Studio, name, nil, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } @@ -578,7 +1123,7 @@ func TestStudioStashIDs(t *testing.T) { // create studio to test against const name = "TestStudioStashIDs" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.Studio, name, nil, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } @@ -990,7 +1535,7 @@ func TestStudioAlias(t *testing.T) { // create studio to test against const name = "TestStudioAlias" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.Studio, name, nil, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go index 882b8ca568..6648ebe0d4 100644 --- a/pkg/studio/import_test.go +++ b/pkg/studio/import_test.go @@ -206,9 +206,9 @@ func TestImporterPreImportWithMissingParent(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -240,7 +240,7 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) @@ -327,11 +327,11 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) + db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studio}).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) s.ID = studioID }).Return(nil).Once() - db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once() + db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studioErr}).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, studioID, *id) @@ -366,7 +366,7 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input studio.ID = studioID - db.Studio.On("Update", testCtx, &studio).Return(nil).Once() + db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studio}).Return(nil).Once() err := i.Update(testCtx, studioID) assert.Nil(t, err) @@ -375,7 +375,7 @@ func TestUpdate(t *testing.T) { // need to set id separately studioErr.ID = errImageID - db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once() + db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studioErr}).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err) From 02858dbab3650e13f1eb984ca31718412143732a Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:51:57 +1100 Subject: [PATCH 4/5] Add custom fields filtering on studio --- graphql/schema/types/filters.graphql | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 4eb91aa77b..dd1d633cc0 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -466,6 +466,8 @@ input StudioFilterType { created_at: TimestampCriterionInput "Filter by last update time" updated_at: TimestampCriterionInput + + custom_fields: [CustomFieldCriterionInput!] } input GalleryFilterType { From 02d3619352caaff373b6f5fd6dd662fcedf4d3e2 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:52:48 +1100 Subject: [PATCH 5/5] Update unit tests --- pkg/sqlite/studio_test.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go index 954d0f5aa3..074c77d6f7 100644 --- a/pkg/sqlite/studio_test.go +++ b/pkg/sqlite/studio_test.go @@ -54,6 +54,11 @@ func loadStudioRelationships(ctx context.Context, expected models.Studio, actual return err } } + if expected.URLs.Loaded() { + if err := actual.LoadURLs(ctx, db.Studio); err != nil { + return err + } + } if expected.TagIDs.Loaded() { if err := actual.LoadTagIDs(ctx, db.Studio); err != nil { return err @@ -95,7 +100,7 @@ func Test_StudioStore_Create(t *testing.T) { models.CreateStudioInput{ Studio: &models.Studio{ Name: name, - URL: url, + URLs: models.NewRelatedStrings([]string{url}), Favorite: favorite, Rating: &rating, Details: details, @@ -221,7 +226,7 @@ func Test_StudioStore_Update(t *testing.T) { Studio: &models.Studio{ ID: studioIDs[studioIdxWithGallery], Name: name, - URL: url, + URLs: models.NewRelatedStrings([]string{url}), Favorite: favorite, Rating: &rating, Details: details, @@ -252,6 +257,7 @@ func Test_StudioStore_Update(t *testing.T) { Studio: &models.Studio{ ID: studioIDs[studioIdxWithGallery], Name: name, // name is mandatory + URLs: models.NewRelatedStrings([]string{}), Aliases: models.NewRelatedStrings([]string{}), TagIDs: models.NewRelatedIDs([]int{}), StashIDs: models.NewRelatedStashIDs([]models.StashID{}), @@ -357,7 +363,7 @@ func clearStudioPartial() models.StudioPartial { // leave mandatory fields return models.StudioPartial{ - URL: nullString, + URLs: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, Aliases: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, Rating: nullInt, Details: nullString, @@ -395,7 +401,10 @@ func Test_StudioStore_UpdatePartial(t *testing.T) { studioIDs[studioIdxWithDupName], models.StudioPartial{ Name: models.NewOptionalString(name), - URL: models.NewOptionalString(url), + URLs: &models.UpdateStrings{ + Values: []string{url}, + Mode: models.RelationshipUpdateModeSet, + }, Aliases: &models.UpdateStrings{ Values: aliases, Mode: models.RelationshipUpdateModeSet, @@ -429,7 +438,7 @@ func Test_StudioStore_UpdatePartial(t *testing.T) { models.Studio{ ID: studioIDs[studioIdxWithDupName], Name: name, - URL: url, + URLs: models.NewRelatedStrings([]string{url}), Aliases: models.NewRelatedStrings(aliases), Favorite: favorite, Rating: &rating, @@ -627,14 +636,6 @@ func TestStudioQueryNameOr(t *testing.T) { }) } -func loadStudioRelationships(ctx context.Context, t *testing.T, s *models.Studio) error { - if err := s.LoadURLs(ctx, db.Studio); err != nil { - return err - } - - return nil -} - func TestStudioQueryNameAndUrl(t *testing.T) { const studioIdx = 1 studioName := getStudioStringValue(studioIdx, "Name")