Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,19 @@ class BulkDeleteModelsResponse(BaseModel):
failed: List[dict] = Field(description="List of failed deletions with error messages")


class BulkReidentifyModelsRequest(BaseModel):
"""Request body for bulk model reidentification."""

keys: List[str] = Field(description="List of model keys to reidentify")


class BulkReidentifyModelsResponse(BaseModel):
"""Response body for bulk model reidentification."""

succeeded: List[str] = Field(description="List of successfully reidentified model keys")
failed: List[dict] = Field(description="List of failed reidentifications with error messages")


@model_manager_router.post(
"/i/bulk_delete",
operation_id="bulk_delete_models",
Expand Down Expand Up @@ -538,6 +551,65 @@ async def bulk_delete_models(
return BulkDeleteModelsResponse(deleted=deleted, failed=failed)


@model_manager_router.post(
"/i/bulk_reidentify",
operation_id="bulk_reidentify_models",
responses={
200: {"description": "Models reidentified (possibly with some failures)"},
},
status_code=200,
)
async def bulk_reidentify_models(
current_admin: AdminUserOrDefault,
request: BulkReidentifyModelsRequest = Body(description="List of model keys to reidentify"),
) -> BulkReidentifyModelsResponse:
"""
Reidentify multiple models by re-probing their weights files.

Returns a list of successfully reidentified keys and failed reidentifications with error messages.
"""
logger = ApiDependencies.invoker.services.logger
store = ApiDependencies.invoker.services.model_manager.store
models_path = ApiDependencies.invoker.services.configuration.models_path

succeeded = []
failed = []

for key in request.keys:
try:
config = store.get_model(key)
if pathlib.Path(config.path).is_relative_to(models_path):
model_path = pathlib.Path(config.path)
else:
model_path = models_path / config.path
mod = ModelOnDisk(model_path)
result = ModelConfigFactory.from_model_on_disk(mod)
if result.config is None:
raise InvalidModelException("Unable to identify model format")

# Retain user-editable fields from the original config
result.config.key = config.key
result.config.name = config.name
result.config.description = config.description
result.config.cover_image = config.cover_image
result.config.trigger_phrases = config.trigger_phrases
result.config.source = config.source
result.config.source_type = config.source_type

store.replace_model(config.key, result.config)
succeeded.append(key)
logger.info(f"Reidentified model: {key}")
except UnknownModelException as e:
logger.error(f"Failed to reidentify model {key}: {str(e)}")
failed.append({"key": key, "error": str(e)})
except Exception as e:
logger.error(f"Failed to reidentify model {key}: {str(e)}")
failed.append({"key": key, "error": str(e)})

logger.info(f"Bulk reidentify completed: {len(succeeded)} succeeded, {len(failed)} failed")
return BulkReidentifyModelsResponse(succeeded=succeeded, failed=failed)


@model_manager_router.delete(
"/i/{key}/image",
operation_id="delete_model_image",
Expand Down
9 changes: 9 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,15 @@
"reidentifySuccess": "Model reidentified successfully",
"reidentifyUnknown": "Unable to identify model",
"reidentifyError": "Error reidentifying model",
"reidentifyModels": "Reidentify Models",
"reidentifyModelsConfirm": "Are you sure you want to reidentify {{count}} model(s)? This will re-probe their weights files to determine the correct format and settings.",
"reidentifyWarning": "This will reset any custom settings you may have applied to these models.",
"modelsReidentified": "Successfully reidentified {{count}} model(s)",
"modelsReidentifyFailed": "Failed to reidentify models",
"someModelsFailedToReidentify": "{{count}} model(s) could not be reidentified",
"modelsReidentifiedPartial": "Partially completed",
"someModelsReidentified": "{{succeeded}} reidentified, {{failed}} failed",
"modelsReidentifyError": "Error reidentifying models",
"updatePath": "Update Path",
"updatePathTooltip": "Update the file path for this model if you have moved the model files to a new location.",
"updatePathDescription": "Enter the new path to the model file or directory. Use this if you have manually moved the model files on disk.",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Button,
Flex,
Text,
} from '@invoke-ai/ui-library';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';

type BulkReidentifyModelsModalProps = {
isOpen: boolean;
onClose: () => void;
onConfirm: () => void;
modelCount: number;
isReidentifying?: boolean;
};

export const BulkReidentifyModelsModal = memo(
({ isOpen, onClose, onConfirm, modelCount, isReidentifying = false }: BulkReidentifyModelsModalProps) => {
const { t } = useTranslation();
const cancelRef = useRef<HTMLButtonElement>(null);

return (
<AlertDialog isOpen={isOpen} onClose={onClose} leastDestructiveRef={cancelRef} isCentered>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('modelManager.reidentifyModels', {
count: modelCount,
defaultValue: 'Reidentify Models',
})}
</AlertDialogHeader>

<AlertDialogBody>
<Flex flexDir="column" gap={3}>
<Text>
{t('modelManager.reidentifyModelsConfirm', {
count: modelCount,
defaultValue: `Are you sure you want to reidentify ${modelCount} model(s)? This will re-probe their weights files to determine the correct format and settings.`,
})}
</Text>
<Text fontWeight="semibold" color="warning.400">
{t('modelManager.reidentifyWarning', {
defaultValue: 'This will reset any custom settings you may have applied to these models.',
})}
</Text>
</Flex>
</AlertDialogBody>

<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose} isDisabled={isReidentifying}>
{t('common.cancel')}
</Button>
<Button colorScheme="warning" onClick={onConfirm} ml={3} isLoading={isReidentifying}>
{t('modelManager.reidentify')}
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
}
);

BulkReidentifyModelsModal.displayName = 'BulkReidentifyModelsModal';
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,22 @@ import { serializeError } from 'serialize-error';
import {
modelConfigsAdapterSelectors,
useBulkDeleteModelsMutation,
useBulkReidentifyModelsMutation,
useGetMissingModelsQuery,
useGetModelConfigsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';

import { BulkDeleteModelsModal } from './BulkDeleteModelsModal';
import { BulkReidentifyModelsModal } from './BulkReidentifyModelsModal';
import { FetchingModelsLoader } from './FetchingModelsLoader';
import { MissingModelsProvider } from './MissingModelsContext';
import { ModelListWrapper } from './ModelListWrapper';

const log = logger('models');

export const [useBulkDeleteModal] = buildUseDisclosure(false);
export const [useBulkReidentifyModal] = buildUseDisclosure(false);

const ModelList = () => {
const dispatch = useAppDispatch();
Expand All @@ -40,11 +43,14 @@ const ModelList = () => {
const { t } = useTranslation();
const toast = useToast();
const { isOpen, close } = useBulkDeleteModal();
const { isOpen: isReidentifyOpen, close: closeReidentify } = useBulkReidentifyModal();
const [isDeleting, setIsDeleting] = useState(false);
const [isReidentifying, setIsReidentifying] = useState(false);

const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery();
const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery();
const [bulkDeleteModels] = useBulkDeleteModelsMutation();
const [bulkReidentifyModels] = useBulkReidentifyModelsMutation();

const data = filteredModelType === 'missing' ? missingModelsData : allModelsData;
const isLoading = filteredModelType === 'missing' ? isLoadingMissing : isLoadingAll;
Expand Down Expand Up @@ -148,6 +154,67 @@ const ModelList = () => {
}
}, [bulkDeleteModels, selectedModelKeys, dispatch, close, toast, t]);

const handleConfirmBulkReidentify = useCallback(async () => {
setIsReidentifying(true);
try {
const result = await bulkReidentifyModels({ keys: selectedModelKeys }).unwrap();

// Clear selection and close modal
dispatch(clearModelSelection());
dispatch(setSelectedModelKey(null));
closeReidentify();

if (result.failed.length === 0) {
toast({
id: 'BULK_REIDENTIFY_SUCCESS',
title: t('modelManager.modelsReidentified', {
count: result.succeeded.length,
defaultValue: `Successfully reidentified ${result.succeeded.length} model(s)`,
}),
status: 'success',
});
} else if (result.succeeded.length === 0) {
toast({
id: 'BULK_REIDENTIFY_FAILED',
title: t('modelManager.modelsReidentifyFailed', {
defaultValue: 'Failed to reidentify models',
}),
description: t('modelManager.someModelsFailedToReidentify', {
count: result.failed.length,
defaultValue: `${result.failed.length} model(s) could not be reidentified`,
}),
status: 'error',
});
} else {
toast({
id: 'BULK_REIDENTIFY_PARTIAL',
title: t('modelManager.modelsReidentifiedPartial', {
defaultValue: 'Partially completed',
}),
description: t('modelManager.someModelsReidentified', {
succeeded: result.succeeded.length,
failed: result.failed.length,
defaultValue: `${result.succeeded.length} reidentified, ${result.failed.length} failed`,
}),
status: 'warning',
});
}

log.info(`Bulk reidentify completed: ${result.succeeded.length} succeeded, ${result.failed.length} failed`);
} catch (err) {
log.error({ error: serializeError(err as Error) }, 'Bulk reidentify error');
toast({
id: 'BULK_REIDENTIFY_ERROR',
title: t('modelManager.modelsReidentifyError', {
defaultValue: 'Error reidentifying models',
}),
status: 'error',
});
} finally {
setIsReidentifying(false);
}
}, [bulkReidentifyModels, selectedModelKeys, dispatch, closeReidentify, toast, t]);

return (
<MissingModelsProvider>
<Flex flexDirection="column" w="full" h="full">
Expand All @@ -173,6 +240,13 @@ const ModelList = () => {
modelCount={selectedModelKeys.length}
isDeleting={isDeleting}
/>
<BulkReidentifyModelsModal
isOpen={isReidentifyOpen}
onClose={closeReidentify}
onConfirm={handleConfirmBulkReidentify}
modelCount={selectedModelKeys.length}
isReidentifying={isReidentifying}
/>
</MissingModelsProvider>
);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ import {
} from 'features/modelManagerV2/store/modelManagerV2Slice';
import { t } from 'i18next';
import { memo, useCallback, useMemo } from 'react';
import { PiCaretDownBold, PiTrashSimpleBold } from 'react-icons/pi';
import { PiCaretDownBold, PiSparkleFill, PiTrashSimpleBold } from 'react-icons/pi';
import {
modelConfigsAdapterSelectors,
useGetMissingModelsQuery,
useGetModelConfigsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';

import { useBulkDeleteModal } from './ModelList';
import { useBulkDeleteModal, useBulkReidentifyModal } from './ModelList';

const ModelListBulkActionsSx: SystemStyleObject = {
alignItems: 'center',
Expand All @@ -40,11 +40,16 @@ export const ModelListBulkActions = memo(({ sx }: ModelListBulkActionsProps) =>
const { data: allModelsData } = useGetModelConfigsQuery();
const { data: missingModelsData } = useGetMissingModelsQuery();
const bulkDeleteModal = useBulkDeleteModal();
const bulkReidentifyModal = useBulkReidentifyModal();

const handleBulkDelete = useCallback(() => {
bulkDeleteModal.open();
}, [bulkDeleteModal]);

const handleBulkReidentify = useCallback(() => {
bulkReidentifyModal.open();
}, [bulkReidentifyModal]);

// Calculate displayed (filtered) model keys
const displayedModelKeys = useMemo(() => {
// Use missing models data when the filter is 'missing'
Expand Down Expand Up @@ -125,6 +130,12 @@ export const ModelListBulkActions = memo(({ sx }: ModelListBulkActionsProps) =>
{t('modelManager.actions')}
</MenuButton>
<MenuList>
<MenuItem icon={<PiSparkleFill />} onClick={handleBulkReidentify}>
{t('modelManager.reidentifyModels', {
count: selectionCount,
defaultValue: 'Reidentify Models',
})}
</MenuItem>
<MenuItem icon={<PiTrashSimpleBold />} onClick={handleBulkDelete} color="error.300">
{t('modelManager.deleteModels', { count: selectionCount })}
</MenuItem>
Expand Down
19 changes: 19 additions & 0 deletions invokeai/frontend/web/src/services/api/endpoints/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ type BulkDeleteModelsResponse = {
failed: string[];
};

type BulkReidentifyModelsArg = {
keys: string[];
};
type BulkReidentifyModelsResponse = {
succeeded: string[];
failed: string[];
};

type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];

Expand Down Expand Up @@ -419,6 +427,16 @@ export const modelsApi = api.injectEndpoints({
}
},
}),
bulkReidentifyModels: build.mutation<BulkReidentifyModelsResponse, BulkReidentifyModelsArg>({
query: ({ keys }) => {
return {
url: buildModelsUrl('i/bulk_reidentify'),
method: 'POST',
body: { keys },
};
},
invalidatesTags: [{ type: 'ModelConfig', id: LIST_TAG }],
}),
getOrphanedModels: build.query<GetOrphanedModelsResponse, void>({
query: () => ({
url: buildModelsUrl('sync/orphaned'),
Expand Down Expand Up @@ -463,6 +481,7 @@ export const {
useResetHFTokenMutation,
useEmptyModelCacheMutation,
useReidentifyModelMutation,
useBulkReidentifyModelsMutation,
useGetOrphanedModelsQuery,
useDeleteOrphanedModelsMutation,
} = modelsApi;
Expand Down
Loading