diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go index d7f17db0..b192e2a2 100644 --- a/cmd/cli/commands/compose.go +++ b/cmd/cli/commands/compose.go @@ -75,13 +75,12 @@ func newUpCommand() *cobra.Command { // Build speculative config if any speculative flags are set var speculativeConfig *inference.SpeculativeDecodingConfig if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 { - normalizedDraftModel := dmrm.NormalizeModelName(draftModel) speculativeConfig = &inference.SpeculativeDecodingConfig{ - DraftModel: normalizedDraftModel, + DraftModel: draftModel, NumTokens: numTokens, MinAcceptanceRate: minAcceptanceRate, } - sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", normalizedDraftModel)) + sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", draftModel)) } for _, model := range models { diff --git a/cmd/cli/commands/configure.go b/cmd/cli/commands/configure.go index 4bb50872..a6fc461f 100644 --- a/cmd/cli/commands/configure.go +++ b/cmd/cli/commands/configure.go @@ -5,7 +5,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/pkg/inference" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/docker/model-runner/pkg/inference/scheduling" "github.com/spf13/cobra" ) @@ -39,7 +39,7 @@ func newConfigureCmd() *cobra.Command { argsBeforeDash) } } - opts.Model = models.NormalizeModelName(args[0]) + opts.Model = args[0] opts.RuntimeFlags = args[1:] return nil }, @@ -47,7 +47,7 @@ func newConfigureCmd() *cobra.Command { // Build the speculative config if any speculative flags are set if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 { opts.Speculative = &inference.SpeculativeDecodingConfig{ - DraftModel: models.NormalizeModelName(draftModel), + DraftModel: draftModel, NumTokens: numTokens, MinAcceptanceRate: minAcceptanceRate, } diff --git a/cmd/cli/commands/inspect.go b/cmd/cli/commands/inspect.go index fbc1d8c4..0b3f34fa 100644 --- a/cmd/cli/commands/inspect.go +++ b/cmd/cli/commands/inspect.go @@ -6,7 +6,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/commands/formatter" "github.com/docker/model-runner/cmd/cli/desktop" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/spf13/cobra" ) @@ -48,8 +48,7 @@ func newInspectCmd() *cobra.Command { } func inspectModel(args []string, openai bool, remote bool, desktopClient *desktop.Client) (string, error) { - // Normalize model name to add default org and tag if missing - modelName := models.NormalizeModelName(args[0]) + modelName := args[0] if openai { model, err := desktopClient.InspectOpenAI(modelName) if err != nil { diff --git a/cmd/cli/commands/list.go b/cmd/cli/commands/list.go index bed61fa6..67491ca7 100644 --- a/cmd/cli/commands/list.go +++ b/cmd/cli/commands/list.go @@ -72,12 +72,13 @@ func listModels(openai bool, desktopClient *desktop.Client, quiet bool, jsonForm } if modelFilter != "" { - // Normalize the filter to match stored model names + // Normalize the filter to match stored model names (backend normalizes when storing) normalizedFilter := dmrm.NormalizeModelName(modelFilter) var filteredModels []dmrm.Model for _, m := range models { hasMatchingTag := false for _, tag := range m.Tags { + // Tags are stored in normalized format by the backend if tag == normalizedFilter { hasMatchingTag = true break diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index d02504b6..8a3ea353 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -16,7 +16,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/docker/model-runner/pkg/distribution/types" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/google/go-containerregistry/pkg/name" "github.com/spf13/cobra" @@ -325,6 +325,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error { } cmd.PrintErrln("Model variant created successfully") + return nil // Return early to avoid the Build operation in lightweight case } else { // Process directory tar archives if len(opts.dirTarPaths) > 0 { @@ -412,9 +413,7 @@ func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarge } if tag != "" { var err error - // Normalize the tag to add default namespace (ai/) and tag (:latest) if missing - normalizedTag := models.NormalizeModelName(tag) - target.tag, err = name.NewTag(normalizedTag) + target.tag, err = name.NewTag(tag) if err != nil { return nil, fmt.Errorf("invalid tag: %w", err) } diff --git a/cmd/cli/commands/pull.go b/cmd/cli/commands/pull.go index 23e435f8..960e5c91 100644 --- a/cmd/cli/commands/pull.go +++ b/cmd/cli/commands/pull.go @@ -6,7 +6,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/mattn/go-isatty" "github.com/spf13/cobra" ) @@ -42,8 +42,6 @@ func newPullCmd() *cobra.Command { } func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error { - // Normalize model name to add default org and tag if missing - model = models.NormalizeModelName(model) var progress func(string) if isatty.IsTerminal(os.Stdout.Fd()) { progress = TUIProgress diff --git a/cmd/cli/commands/push.go b/cmd/cli/commands/push.go index 0dd91a0a..ed9da538 100644 --- a/cmd/cli/commands/push.go +++ b/cmd/cli/commands/push.go @@ -5,7 +5,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/spf13/cobra" ) @@ -35,8 +35,6 @@ func newPushCmd() *cobra.Command { } func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { - // Normalize model name to add default org and tag if missing - model = models.NormalizeModelName(model) response, progressShown, err := desktopClient.Push(model, TUIProgress) // Add a newline before any output (success or error) if progress was shown. diff --git a/cmd/cli/commands/rm.go b/cmd/cli/commands/rm.go index 4af2e3b9..8bc0d4ec 100644 --- a/cmd/cli/commands/rm.go +++ b/cmd/cli/commands/rm.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/docker/model-runner/cmd/cli/commands/completion" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/spf13/cobra" ) @@ -28,12 +28,7 @@ func newRemoveCmd() *cobra.Command { if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil { return fmt.Errorf("unable to initialize standalone model runner: %w", err) } - // Normalize model names to add default org and tag if missing - normalizedArgs := make([]string, len(args)) - for i, arg := range args { - normalizedArgs[i] = models.NormalizeModelName(arg) - } - response, err := desktopClient.Remove(normalizedArgs, force) + response, err := desktopClient.Remove(args, force) if response != "" { cmd.Print(response) } diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 082725e0..40d93223 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -15,7 +15,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/cmd/cli/readline" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/fatih/color" "github.com/spf13/cobra" "golang.org/x/term" @@ -586,8 +586,7 @@ func newRunCmd() *cobra.Command { } }, RunE: func(cmd *cobra.Command, args []string) error { - // Normalize model name to add default org and tag if missing - model := models.NormalizeModelName(args[0]) + model := args[0] prompt := "" argsLen := len(args) if argsLen > 1 { diff --git a/cmd/cli/commands/tag.go b/cmd/cli/commands/tag.go index 265c51ad..eca69d57 100644 --- a/cmd/cli/commands/tag.go +++ b/cmd/cli/commands/tag.go @@ -6,7 +6,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/google/go-containerregistry/pkg/name" "github.com/spf13/cobra" ) @@ -37,10 +37,6 @@ func newTagCmd() *cobra.Command { } func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error { - // Normalize source model name to add default org and tag if missing - source = models.NormalizeModelName(source) - // Normalize target model name to add default org and tag if missing - target = models.NormalizeModelName(target) // Ensure tag is valid tag, err := name.NewTag(target) if err != nil { diff --git a/cmd/cli/commands/unload.go b/cmd/cli/commands/unload.go index c76b4391..d9d3e296 100644 --- a/cmd/cli/commands/unload.go +++ b/cmd/cli/commands/unload.go @@ -5,7 +5,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" - "github.com/docker/model-runner/pkg/inference/models" + "github.com/spf13/cobra" ) @@ -18,12 +18,7 @@ func newUnloadCmd() *cobra.Command { Use: "unload " + cmdArgs, Short: "Unload running models", RunE: func(cmd *cobra.Command, modelArgs []string) error { - // Normalize model names - normalizedModels := make([]string, len(modelArgs)) - for i, model := range modelArgs { - normalizedModels[i] = models.NormalizeModelName(model) - } - unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Models: normalizedModels}) + unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Models: modelArgs}) if err != nil { return handleClientError(err, "Failed to unload models") } diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 751e765c..0c2e08d9 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -56,6 +56,14 @@ type Status struct { Error error `json:"error"` } +// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase +func normalizeHuggingFaceModelName(model string) string { + if strings.HasPrefix(model, "hf.co/") { + return strings.ToLower(model) + } + return model +} + func (c *Client) Status() Status { // TODO: Query "/". resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) @@ -98,7 +106,7 @@ func (c *Client) Status() Status { } func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) { - model = dmrm.NormalizeModelName(model) + model = normalizeHuggingFaceModelName(model) jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck}) if err != nil { return "", false, fmt.Errorf("error marshaling request: %w", err) @@ -166,7 +174,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func } func (c *Client) Push(model string, progress func(string)) (string, bool, error) { - model = dmrm.NormalizeModelName(model) + model = normalizeHuggingFaceModelName(model) pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( http.MethodPost, @@ -247,7 +255,7 @@ func (c *Client) ListOpenAI() (dmrm.OpenAIModelList, error) { } func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { - model = dmrm.NormalizeModelName(model) + model = normalizeHuggingFaceModelName(model) if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -271,7 +279,7 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { } func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) { - model = dmrm.NormalizeModelName(model) + model = normalizeHuggingFaceModelName(model) modelsRoute := inference.InferencePrefix + "/v1/models" if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -335,6 +343,31 @@ func (c *Client) fullModelID(id string) (string, error) { if m.ID[7:19] == id || strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id { return m.ID, nil } + // Check if the ID matches any of the model's tags using exact match first + for _, tag := range m.Tags { + if tag == id { + return m.ID, nil + } + } + // If not found with exact match, try partial name matching + for _, tag := range m.Tags { + // Extract the model name without tag part (e.g., from "ai/smollm2:latest" get "ai/smollm2") + tagWithoutVersion := tag + if idx := strings.LastIndex(tag, ":"); idx != -1 { + tagWithoutVersion = tag[:idx] + } + + // Get just the name part without organization (e.g., from "ai/smollm2" get "smollm2") + namePart := tagWithoutVersion + if idx := strings.LastIndex(tagWithoutVersion, "/"); idx != -1 { + namePart = tagWithoutVersion[idx+1:] + } + + // Check if the ID matches the name part + if namePart == id { + return m.ID, nil + } + } } return "", fmt.Errorf("model with ID %s not found", id) @@ -347,7 +380,7 @@ func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func( // ChatWithContext performs a chat request with context support for cancellation and streams the response content with selective markdown rendering. func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) error { - model = dmrm.NormalizeModelName(model) + model = normalizeHuggingFaceModelName(model) if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. if expanded, err := c.fullModelID(model); err == nil { @@ -523,7 +556,7 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag func (c *Client) Remove(modelArgs []string, force bool) (string, error) { modelRemoved := "" for _, model := range modelArgs { - model = dmrm.NormalizeModelName(model) + model = normalizeHuggingFaceModelName(model) // Check if not a model ID passed as parameter. if !strings.Contains(model, "/") { if expanded, err := c.fullModelID(model); err == nil { @@ -819,26 +852,13 @@ func (c *Client) handleQueryError(err error, path string) error { return fmt.Errorf("error querying %s: %w", path, err) } -// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase -func normalizeHuggingFaceModelName(model string) string { - if strings.HasPrefix(model, "hf.co/") { - return strings.ToLower(model) - } - - return model -} - func (c *Client) Tag(source, targetRepo, targetTag string) error { source = normalizeHuggingFaceModelName(source) - // Check if the source is a model ID, and expand it if necessary - if !strings.Contains(strings.Trim(source, "/"), "/") { - // Do an extra API call to check if the model parameter might be a model ID - if expanded, err := c.fullModelID(source); err == nil { - source = expanded - } - } + // For tag operations, let the daemon handle name resolution to support + // partial name matching like "smollm2" -> "ai/smollm2:latest" + // Don't do client-side ID expansion which can cause issues with tagging - // Construct the URL with query parameters + // Construct the URL with query parameters using the normalized source tagPath := fmt.Sprintf("%s/%s/tag?repo=%s&tag=%s", inference.ModelsPrefix, source, diff --git a/cmd/cli/desktop/desktop_test.go b/cmd/cli/desktop/desktop_test.go index 02e0032b..57dac089 100644 --- a/cmd/cli/desktop/desktop_test.go +++ b/cmd/cli/desktop/desktop_test.go @@ -20,7 +20,7 @@ func TestPullHuggingFaceModel(t *testing.T) { // Test case for pulling a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) mockContext := NewContextForMock(mockClient) @@ -46,7 +46,7 @@ func TestChatHuggingFaceModel(t *testing.T) { // Test case for chatting with a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" prompt := "Hello" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) @@ -63,7 +63,7 @@ func TestChatHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")), }, nil) - err := client.Chat(modelName, prompt, nil, func(s string) {}, false) + err := client.Chat(modelName, prompt, []string{}, func(s string) {}, false) assert.NoError(t, err) } @@ -73,7 +73,7 @@ func TestInspectHuggingFaceModel(t *testing.T) { // Test case for inspecting a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) mockContext := NewContextForMock(mockClient) @@ -108,7 +108,6 @@ func TestNonHuggingFaceModel(t *testing.T) { // Test case for a non-Hugging Face model (should not be converted to lowercase) modelName := "docker.io/library/llama2" - expectedWithTag := "docker.io/library/llama2:latest" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) mockContext := NewContextForMock(mockClient) client := New(mockContext) @@ -117,7 +116,7 @@ func TestNonHuggingFaceModel(t *testing.T) { var reqBody models.ModelCreateRequest err := json.NewDecoder(req.Body).Decode(&reqBody) require.NoError(t, err) - assert.Equal(t, expectedWithTag, reqBody.From) + assert.Equal(t, modelName, reqBody.From) }).Return(&http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), @@ -133,7 +132,7 @@ func TestPushHuggingFaceModel(t *testing.T) { // Test case for pushing a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) mockContext := NewContextForMock(mockClient) @@ -156,7 +155,7 @@ func TestRemoveHuggingFaceModel(t *testing.T) { // Test case for removing a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) mockContext := NewContextForMock(mockClient) @@ -203,7 +202,7 @@ func TestInspectOpenAIHuggingFaceModel(t *testing.T) { // Test case for inspecting a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) mockContext := NewContextForMock(mockClient) @@ -225,3 +224,4 @@ func TestInspectOpenAIHuggingFaceModel(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expectedLowercase, model.ID) } + diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go index d0897700..50ccc5e7 100644 --- a/pkg/inference/models/api.go +++ b/pkg/inference/models/api.go @@ -19,6 +19,32 @@ type ModelCreateRequest struct { IgnoreRuntimeMemoryCheck bool `json:"ignore-runtime-memory-check,omitempty"` } +// ModelPackageRequest represents a model package request, which creates a new model +// from an existing one with modified properties (e.g., context size). +type ModelPackageRequest struct { + // From is the name of the source model to package from. + From string `json:"from"` + // Tag is the name to give the new packaged model. + Tag string `json:"tag"` + // ContextSize specifies the context size to set for the new model. + ContextSize uint64 `json:"context-size,omitempty"` +} + +// SimpleModel is a wrapper that allows creating a model with modified configuration +type SimpleModel struct { + types.Model + ConfigValue types.Config + DescriptorValue types.Descriptor +} + +func (s *SimpleModel) Config() (types.Config, error) { + return s.ConfigValue, nil +} + +func (s *SimpleModel) Descriptor() (types.Descriptor, error) { + return s.DescriptorValue, nil +} + // ToOpenAIList converts the model list to its OpenAI API representation. This function never // returns a nil slice (though it may return an empty slice). func ToOpenAIList(l []types.Model) (*OpenAIModelList, error) { diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index e2e6e38b..b7128805 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -13,6 +13,7 @@ import ( "sync" "github.com/docker/model-runner/pkg/diskusage" + "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/distribution" "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/types" @@ -173,6 +174,7 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc { return map[string]http.HandlerFunc{ "POST " + inference.ModelsPrefix + "/create": m.handleCreateModel, "POST " + inference.ModelsPrefix + "/load": m.handleLoadModel, + "POST " + inference.ModelsPrefix + "/package": m.handlePackageModel, "GET " + inference.ModelsPrefix: m.handleGetModels, "GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel, "DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel, @@ -295,8 +297,7 @@ func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) { // handleGetModel handles GET /models/{name} requests. func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) { - // Normalize model name - modelName := NormalizeModelName(r.PathValue("name")) + modelRef := r.PathValue("name") // Parse remote query parameter remote := false @@ -317,9 +318,24 @@ func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) { var err error if remote { - apiModel, err = getRemoteModel(r.Context(), m, modelName) + // For remote lookups, always normalize the reference + normalizedRef := NormalizeModelName(modelRef) + apiModel, err = getRemoteModel(r.Context(), m, normalizedRef) } else { - apiModel, err = getLocalModel(m, modelName) + // For local lookups, first try without normalization (as ID), then with normalization + apiModel, err = getLocalModel(m, modelRef) + if err != nil && errors.Is(err, distribution.ErrModelNotFound) { + // If not found as-is, try with normalization + normalizedRef := NormalizeModelName(modelRef) + if normalizedRef != modelRef { // only try normalized if it's different + apiModel, err = getLocalModel(m, normalizedRef) + } + } + + // If still not found, try partial name matching (e.g., "smollm2" for "ai/smollm2:latest") + if err != nil && errors.Is(err, distribution.ErrModelNotFound) { + apiModel, err = findModelByPartialName(m, modelRef) + } } if err != nil { @@ -410,6 +426,40 @@ func getRemoteModel(ctx context.Context, m *Manager, name string) (*Model, error return apiModel, nil } +// findModelByPartialName looks for a model by matching the provided reference +// against model tags using partial name matching (e.g., "smollm2" matches "ai/smollm2:latest") +func findModelByPartialName(m *Manager, modelRef string) (*Model, error) { + // Get all models to search through their tags + models, err := m.distributionClient.ListModels() + if err != nil { + return nil, err + } + + // Look for a model whose tags match the reference + for _, model := range models { + for _, tag := range model.Tags() { + // Extract the model name without tag part (e.g., from "ai/smollm2:latest" get "ai/smollm2") + tagWithoutVersion := tag + if idx := strings.LastIndex(tag, ":"); idx != -1 { + tagWithoutVersion = tag[:idx] + } + + // Get just the name part without organization (e.g., from "ai/smollm2" get "smollm2") + namePart := tagWithoutVersion + if idx := strings.LastIndex(tagWithoutVersion, "/"); idx != -1 { + namePart = tagWithoutVersion[idx+1:] + } + + // Check if the reference matches the name part + if namePart == modelRef { + return ToModel(model) + } + } + } + + return nil, distribution.ErrModelNotFound +} + // handleDeleteModel handles DELETE /models/{name} requests. // query params: // - force: if true, delete the model even if it has multiple tags @@ -428,8 +478,7 @@ func (m *Manager) handleDeleteModel(w http.ResponseWriter, r *http.Request) { // the runner process exits (though this won't work for Windows, where we // might need some separate cleanup process). - // Normalize model name - modelName := NormalizeModelName(r.PathValue("name")) + modelRef := r.PathValue("name") var force bool if r.URL.Query().Has("force") { @@ -440,7 +489,16 @@ func (m *Manager) handleDeleteModel(w http.ResponseWriter, r *http.Request) { } } - resp, err := m.distributionClient.DeleteModel(modelName, force) + // First try to delete without normalization (as ID), then with normalization if not found + resp, err := m.distributionClient.DeleteModel(modelRef, force) + if err != nil && errors.Is(err, distribution.ErrModelNotFound) { + // If not found as-is, try with normalization + normalizedRef := NormalizeModelName(modelRef) + if normalizedRef != modelRef { // only try normalized if it's different + resp, err = m.distributionClient.DeleteModel(normalizedRef, force) + } + } + if err != nil { if errors.Is(err, distribution.ErrModelNotFound) { http.Error(w, err.Error(), http.StatusNotFound) @@ -497,11 +555,18 @@ func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) { return } - // Normalize model name - modelName := NormalizeModelName(r.PathValue("name")) + modelRef := r.PathValue("name") + + // Query the model - first try without normalization (as ID), then with normalization + model, err := m.GetModel(modelRef) + if err != nil && errors.Is(err, distribution.ErrModelNotFound) { + // If not found as-is, try with normalization + normalizedRef := NormalizeModelName(modelRef) + if normalizedRef != modelRef { // only try normalized if it's different + model, err = m.GetModel(normalizedRef) + } + } - // Query the model. - model, err := m.GetModel(modelName) if err != nil { if errors.Is(err, distribution.ErrModelNotFound) { http.Error(w, err.Error(), http.StatusNotFound) @@ -530,13 +595,15 @@ func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) { func (m *Manager) handleModelAction(w http.ResponseWriter, r *http.Request) { model, action := path.Split(r.PathValue("nameAndAction")) model = strings.TrimRight(model, "/") - // Normalize model name - model = NormalizeModelName(model) + + // For tag and push actions, we likely expect model references rather than IDs, + // so normalize the model name, but we'll handle both cases in the handlers + normalizedModel := NormalizeModelName(model) switch action { case "tag": - m.handleTagModel(w, r, model) + m.handleTagModel(w, r, normalizedModel) case "push": - m.handlePushModel(w, r, model) + m.handlePushModel(w, r, normalizedModel) default: http.Error(w, fmt.Sprintf("unknown action %q", action), http.StatusNotFound) } @@ -565,22 +632,106 @@ func (m *Manager) handleTagModel(w http.ResponseWriter, r *http.Request, model s // Construct the target string. target := fmt.Sprintf("%s:%s", repo, tag) - // Call the Tag method on the distribution client with source and modelName. - if err := m.distributionClient.Tag(model, target); err != nil { - m.log.Warnf("Failed to apply tag %q to model %q: %v", target, model, err) + // First try to tag using the provided model reference as-is + err := m.distributionClient.Tag(model, target) + if err != nil && errors.Is(err, distribution.ErrModelNotFound) { + // Check if the model parameter is a model ID (starts with sha256:) or is a partial name + var foundModelRef string + found := false + + // If it looks like an ID, try to find the model by ID + if strings.HasPrefix(model, "sha256:") || len(model) == 12 { // 12-char short ID + // Get all models and find the one matching this ID + models, listErr := m.distributionClient.ListModels() + if listErr != nil { + http.Error(w, fmt.Sprintf("error listing models: %v", listErr), http.StatusInternalServerError) + return + } + + for _, mModel := range models { + modelID, idErr := mModel.ID() + if idErr != nil { + m.log.Warnf("Failed to get model ID: %v", idErr) + continue + } + + // Check if the model ID matches (can be full or short ID) + if modelID == model || strings.HasPrefix(modelID, model) { + // Use the first tag of this model as the source reference + tags := mModel.Tags() + if len(tags) > 0 { + foundModelRef = tags[0] + found = true + break + } + } + } + } - if errors.Is(err, distribution.ErrModelNotFound) { + // If not found by ID, try partial name matching (similar to inspect) + if !found { + models, listErr := m.distributionClient.ListModels() + if listErr != nil { + http.Error(w, fmt.Sprintf("error listing models: %v", listErr), http.StatusInternalServerError) + return + } + + // Look for a model whose tags match the provided reference + for _, mModel := range models { + for _, tagStr := range mModel.Tags() { + // Extract the model name without tag part (e.g., from "ai/smollm2:latest" get "ai/smollm2") + tagWithoutVersion := tagStr + if idx := strings.LastIndex(tagStr, ":"); idx != -1 { + tagWithoutVersion = tagStr[:idx] + } + + // Get just the name part without organization (e.g., from "ai/smollm2" get "smollm2") + namePart := tagWithoutVersion + if idx := strings.LastIndex(tagWithoutVersion, "/"); idx != -1 { + namePart = tagWithoutVersion[idx+1:] + } + + // Check if the provided model matches the name part + if namePart == model { + // Found a match - use the tag string that matched as the source reference + foundModelRef = tagStr + found = true + break + } + } + if found { + break + } + } + } + + if !found { http.Error(w, err.Error(), http.StatusNotFound) return } + // Now tag using the found model reference (the matching tag) + if tagErr := m.distributionClient.Tag(foundModelRef, target); tagErr != nil { + m.log.Warnf("Failed to apply tag %q to resolved model %q: %v", target, foundModelRef, tagErr) + http.Error(w, tagErr.Error(), http.StatusInternalServerError) + return + } + } else if err != nil { + // If there's an error other than not found, return it http.Error(w, err.Error(), http.StatusInternalServerError) return } // Respond with success. + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) - w.Write([]byte(fmt.Sprintf("Model %q tagged successfully with %q", model, target))) + response := map[string]string{ + "message": fmt.Sprintf("Model tagged successfully with %q", target), + "target": target, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + m.log.Warnln("Error while encoding tag response:", err) + } } // handlePushModel handles POST /models/{name}/push requests. @@ -612,6 +763,90 @@ func (m *Manager) handlePushModel(w http.ResponseWriter, r *http.Request, model } } +// handlePackageModel handles POST /models/package requests. +func (m *Manager) handlePackageModel(w http.ResponseWriter, r *http.Request) { + if m.distributionClient == nil { + http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable) + return + } + + // Decode the request + var request ModelPackageRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + // Validate required fields + if request.From == "" || request.Tag == "" { + http.Error(w, "both 'from' and 'tag' fields are required", http.StatusBadRequest) + return + } + + // Normalize the source model name + request.From = NormalizeModelName(request.From) + + // Create a builder from an existing model by getting the bundle first + // Since ModelArtifact interface is needed to work with the builder + bundle, err := m.distributionClient.GetBundle(request.From) + if err != nil { + if errors.Is(err, distribution.ErrModelNotFound) { + http.Error(w, fmt.Sprintf("source model not found: %s", request.From), http.StatusNotFound) + } else { + http.Error(w, fmt.Sprintf("error getting source model bundle %s: %v", request.From, err), http.StatusInternalServerError) + } + return + } + + // Create a builder from the existing model artifact (from the bundle) + modelArtifact, ok := bundle.(types.ModelArtifact) + if !ok { + http.Error(w, "source model does not implement ModelArtifact interface", http.StatusInternalServerError) + return + } + + // Create a builder from the existing model + bldr, err := builder.FromModel(modelArtifact) + if err != nil { + http.Error(w, fmt.Sprintf("error creating builder from model: %v", err), http.StatusInternalServerError) + return + } + + // Apply context size if specified + if request.ContextSize > 0 { + bldr = bldr.WithContextSize(request.ContextSize) + } + + // Get the built model artifact + builtModel := bldr.Model() + + // Check if we can use lightweight repackaging (config-only changes from existing model) + useLightweight := bldr.HasOnlyConfigChanges() + + if useLightweight { + // Use the lightweight method to avoid re-transferring layers + if err := m.distributionClient.WriteLightweightModel(builtModel, []string{request.Tag}); err != nil { + http.Error(w, fmt.Sprintf("error creating lightweight model: %v", err), http.StatusInternalServerError) + return + } + } else { + // If there are layer changes, we need a different approach (this shouldn't happen with context size only) + // For now, return an error if we can't use lightweight + http.Error(w, "only config-only changes are supported for repackaging", http.StatusBadRequest) + return + } + + // Return success response + w.Header().Set("Content-Type", "application/json") + response := map[string]string{ + "message": fmt.Sprintf("Successfully packaged model from %s with tag %s", request.From, request.Tag), + "model": request.Tag, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + m.log.Warnln("Error while encoding package response:", err) + } +} + // handlePurge handles DELETE /models/purge requests. func (m *Manager) handlePurge(w http.ResponseWriter, _ *http.Request) { if m.distributionClient == nil { diff --git a/pkg/inference/models/manager_test.go b/pkg/inference/models/manager_test.go index a792c830..14f9f4da 100644 --- a/pkg/inference/models/manager_test.go +++ b/pkg/inference/models/manager_test.go @@ -214,7 +214,7 @@ func TestHandleGetModel(t *testing.T) { remote: false, modelName: "nonexistent:v1", expectedCode: http.StatusNotFound, - expectedError: "error while getting model", + expectedError: "model not found", }, { name: "get remote model - success", @@ -227,7 +227,7 @@ func TestHandleGetModel(t *testing.T) { remote: true, modelName: uri.Host + "/ai/nonexistent:v1", expectedCode: http.StatusNotFound, - expectedError: "not found", + expectedError: "failed to pull model", }, }