Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,12 @@ func newUpCommand() *cobra.Command {
// Build speculative config if any speculative flags are set
var speculativeConfig *inference.SpeculativeDecodingConfig
if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
normalizedDraftModel := dmrm.NormalizeModelName(draftModel)
speculativeConfig = &inference.SpeculativeDecodingConfig{
DraftModel: normalizedDraftModel,
DraftModel: draftModel,
NumTokens: numTokens,
MinAcceptanceRate: minAcceptanceRate,
}
sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", normalizedDraftModel))
sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", draftModel))
}

for _, model := range models {
Expand Down
6 changes: 3 additions & 3 deletions cmd/cli/commands/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -39,15 +39,15 @@ func newConfigureCmd() *cobra.Command {
argsBeforeDash)
}
}
opts.Model = models.NormalizeModelName(args[0])
opts.Model = args[0]
opts.RuntimeFlags = args[1:]
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
// Build the speculative config if any speculative flags are set
if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
opts.Speculative = &inference.SpeculativeDecodingConfig{
DraftModel: models.NormalizeModelName(draftModel),
DraftModel: draftModel,
NumTokens: numTokens,
MinAcceptanceRate: minAcceptanceRate,
}
Expand Down
5 changes: 2 additions & 3 deletions cmd/cli/commands/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/commands/formatter"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -48,8 +48,7 @@ func newInspectCmd() *cobra.Command {
}

func inspectModel(args []string, openai bool, remote bool, desktopClient *desktop.Client) (string, error) {
// Normalize model name to add default org and tag if missing
modelName := models.NormalizeModelName(args[0])
modelName := args[0]
if openai {
model, err := desktopClient.InspectOpenAI(modelName)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion cmd/cli/commands/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ func listModels(openai bool, desktopClient *desktop.Client, quiet bool, jsonForm
}

if modelFilter != "" {
// Normalize the filter to match stored model names
// Normalize the filter to match stored model names (backend normalizes when storing)
normalizedFilter := dmrm.NormalizeModelName(modelFilter)
var filteredModels []dmrm.Model
for _, m := range models {
hasMatchingTag := false
for _, tag := range m.Tags {
// Tags are stored in normalized format by the backend
if tag == normalizedFilter {
hasMatchingTag = true
break
Expand Down
7 changes: 3 additions & 4 deletions cmd/cli/commands/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/tarball"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/google/go-containerregistry/pkg/name"
"github.com/spf13/cobra"

Expand Down Expand Up @@ -325,6 +325,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
}

cmd.PrintErrln("Model variant created successfully")
return nil // Return early to avoid the Build operation in lightweight case
} else {
// Process directory tar archives
if len(opts.dirTarPaths) > 0 {
Expand Down Expand Up @@ -412,9 +413,7 @@ func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarge
}
if tag != "" {
var err error
// Normalize the tag to add default namespace (ai/) and tag (:latest) if missing
normalizedTag := models.NormalizeModelName(tag)
target.tag, err = name.NewTag(normalizedTag)
target.tag, err = name.NewTag(tag)
if err != nil {
return nil, fmt.Errorf("invalid tag: %w", err)
}
Expand Down
4 changes: 1 addition & 3 deletions cmd/cli/commands/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/mattn/go-isatty"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -42,8 +42,6 @@ func newPullCmd() *cobra.Command {
}

func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error {
// Normalize model name to add default org and tag if missing
model = models.NormalizeModelName(model)
var progress func(string)
if isatty.IsTerminal(os.Stdout.Fd()) {
progress = TUIProgress
Expand Down
4 changes: 1 addition & 3 deletions cmd/cli/commands/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -35,8 +35,6 @@ func newPushCmd() *cobra.Command {
}

func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
// Normalize model name to add default org and tag if missing
model = models.NormalizeModelName(model)
response, progressShown, err := desktopClient.Push(model, TUIProgress)

// Add a newline before any output (success or error) if progress was shown.
Expand Down
9 changes: 2 additions & 7 deletions cmd/cli/commands/rm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/spf13/cobra"
)

Expand All @@ -28,12 +28,7 @@ func newRemoveCmd() *cobra.Command {
if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil {
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
}
// Normalize model names to add default org and tag if missing
normalizedArgs := make([]string, len(args))
for i, arg := range args {
normalizedArgs[i] = models.NormalizeModelName(arg)
}
response, err := desktopClient.Remove(normalizedArgs, force)
response, err := desktopClient.Remove(args, force)
if response != "" {
cmd.Print(response)
}
Expand Down
5 changes: 2 additions & 3 deletions cmd/cli/commands/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/readline"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/fatih/color"
"github.com/spf13/cobra"
"golang.org/x/term"
Expand Down Expand Up @@ -586,8 +586,7 @@ func newRunCmd() *cobra.Command {
}
},
RunE: func(cmd *cobra.Command, args []string) error {
// Normalize model name to add default org and tag if missing
model := models.NormalizeModelName(args[0])
model := args[0]
prompt := ""
argsLen := len(args)
if argsLen > 1 {
Expand Down
6 changes: 1 addition & 5 deletions cmd/cli/commands/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/google/go-containerregistry/pkg/name"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -37,10 +37,6 @@ func newTagCmd() *cobra.Command {
}

func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error {
// Normalize source model name to add default org and tag if missing
source = models.NormalizeModelName(source)
// Normalize target model name to add default org and tag if missing
target = models.NormalizeModelName(target)
// Ensure tag is valid
tag, err := name.NewTag(target)
if err != nil {
Expand Down
9 changes: 2 additions & 7 deletions cmd/cli/commands/unload.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/pkg/inference/models"

"github.com/spf13/cobra"
)

Expand All @@ -18,12 +18,7 @@ func newUnloadCmd() *cobra.Command {
Use: "unload " + cmdArgs,
Short: "Unload running models",
RunE: func(cmd *cobra.Command, modelArgs []string) error {
// Normalize model names
normalizedModels := make([]string, len(modelArgs))
for i, model := range modelArgs {
normalizedModels[i] = models.NormalizeModelName(model)
}
unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Models: normalizedModels})
unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Models: modelArgs})
if err != nil {
return handleClientError(err, "Failed to unload models")
}
Expand Down
66 changes: 43 additions & 23 deletions cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ type Status struct {
Error error `json:"error"`
}

// normalizeHuggingFaceModelName lowercases model references that start with
// the "hf.co/" prefix and returns every other reference unchanged.
func normalizeHuggingFaceModelName(model string) string {
	const hfPrefix = "hf.co/"
	if !strings.HasPrefix(model, hfPrefix) {
		// Not prefixed with "hf.co/": keep the caller's casing as-is.
		return model
	}
	return strings.ToLower(model)
}

func (c *Client) Status() Status {
// TODO: Query "/".
resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil)
Expand Down Expand Up @@ -98,7 +106,7 @@ func (c *Client) Status() Status {
}

func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
model = dmrm.NormalizeModelName(model)
model = normalizeHuggingFaceModelName(model)
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
if err != nil {
return "", false, fmt.Errorf("error marshaling request: %w", err)
Expand Down Expand Up @@ -166,7 +174,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
}

func (c *Client) Push(model string, progress func(string)) (string, bool, error) {
model = dmrm.NormalizeModelName(model)
model = normalizeHuggingFaceModelName(model)
pushPath := inference.ModelsPrefix + "/" + model + "/push"
resp, err := c.doRequest(
http.MethodPost,
Expand Down Expand Up @@ -247,7 +255,7 @@ func (c *Client) ListOpenAI() (dmrm.OpenAIModelList, error) {
}

func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) {
model = dmrm.NormalizeModelName(model)
model = normalizeHuggingFaceModelName(model)
if model != "" {
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
Expand All @@ -271,7 +279,7 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) {
}

func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) {
model = dmrm.NormalizeModelName(model)
model = normalizeHuggingFaceModelName(model)
modelsRoute := inference.InferencePrefix + "/v1/models"
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
Expand Down Expand Up @@ -335,6 +343,31 @@ func (c *Client) fullModelID(id string) (string, error) {
if m.ID[7:19] == id || strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id {
return m.ID, nil
}
// Check if the ID matches any of the model's tags using exact match first
for _, tag := range m.Tags {
if tag == id {
return m.ID, nil
}
}
// If not found with exact match, try partial name matching
for _, tag := range m.Tags {
// Extract the model name without tag part (e.g., from "ai/smollm2:latest" get "ai/smollm2")
tagWithoutVersion := tag
if idx := strings.LastIndex(tag, ":"); idx != -1 {
tagWithoutVersion = tag[:idx]
}

// Get just the name part without organization (e.g., from "ai/smollm2" get "smollm2")
namePart := tagWithoutVersion
if idx := strings.LastIndex(tagWithoutVersion, "/"); idx != -1 {
namePart = tagWithoutVersion[idx+1:]
}

// Check if the ID matches the name part
if namePart == id {
return m.ID, nil
}
Comment on lines +354 to +369
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can use name.ParseReference(tag) to get the same results:

Suggested change
// Extract the model name without tag part (e.g., from "ai/smollm2:latest" get "ai/smollm2")
tagWithoutVersion := tag
if idx := strings.LastIndex(tag, ":"); idx != -1 {
tagWithoutVersion = tag[:idx]
}
// Get just the name part without organization (e.g., from "ai/smollm2" get "smollm2")
namePart := tagWithoutVersion
if idx := strings.LastIndex(tagWithoutVersion, "/"); idx != -1 {
namePart = tagWithoutVersion[idx+1:]
}
// Check if the ID matches the name part
if namePart == id {
return m.ID, nil
}
// Try to parse both the stored tag and the input ID
tagRef, err1 := name.ParseReference(tag)
idRef, err2 := name.ParseReference(id)
if err1 != nil || err2 != nil {
// fallback to string comparison if parsing fails
if tag == id {
return m.ID, nil
}
continue
}
// Compare normalized names (without tag or digest)
if tagRef.Context().Name() == idRef.Context().Name() {
return m.ID, nil
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think this change also fixes the case where name and tag are specified but not the repository:

```
MODEL_RUNNER_HOST=http://localhost:13434 docker model inspect ai/smollm2:latest
{
    "id": "sha256:354bf30d0aa3af413d2aa5ae4f23c66d78980072d1e07a5b0d776e9606a2f0b9",
    "tags": [
        "ai/smollm2:latest"
    ],
    "created": 1742816981,
    "config": {
        "format": "gguf",
        "quantization": "IQ2_XXS/Q4_K_M",
        "parameters": "361.82 M",
        "architecture": "llama",
        "size": "256.35 MiB"
    }
}
MODEL_RUNNER_HOST=http://localhost:13434 docker model inspect smollm2:latest
Failed to get model smollm2:latest: invalid model name: smollm2:latest
```

}
}

return "", fmt.Errorf("model with ID %s not found", id)
Expand All @@ -347,7 +380,7 @@ func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func(

// ChatWithContext performs a chat request with context support for cancellation and streams the response content with selective markdown rendering.
func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) error {
model = dmrm.NormalizeModelName(model)
model = normalizeHuggingFaceModelName(model)
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
if expanded, err := c.fullModelID(model); err == nil {
Expand Down Expand Up @@ -523,7 +556,7 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag
func (c *Client) Remove(modelArgs []string, force bool) (string, error) {
modelRemoved := ""
for _, model := range modelArgs {
model = dmrm.NormalizeModelName(model)
model = normalizeHuggingFaceModelName(model)
// Check if not a model ID passed as parameter.
if !strings.Contains(model, "/") {
if expanded, err := c.fullModelID(model); err == nil {
Expand Down Expand Up @@ -819,26 +852,13 @@ func (c *Client) handleQueryError(err error, path string) error {
return fmt.Errorf("error querying %s: %w", path, err)
}

// normalizeHuggingFaceModelName returns model in lowercase when it begins with
// "hf.co/"; any other model name is returned exactly as given.
func normalizeHuggingFaceModelName(model string) string {
	isHuggingFace := strings.HasPrefix(model, "hf.co/")
	if !isHuggingFace {
		return model
	}

	lowered := strings.ToLower(model)
	return lowered
}

func (c *Client) Tag(source, targetRepo, targetTag string) error {
source = normalizeHuggingFaceModelName(source)
// Check if the source is a model ID, and expand it if necessary
if !strings.Contains(strings.Trim(source, "/"), "/") {
// Do an extra API call to check if the model parameter might be a model ID
if expanded, err := c.fullModelID(source); err == nil {
source = expanded
}
}
// For tag operations, let the daemon handle name resolution to support
// partial name matching like "smollm2" -> "ai/smollm2:latest"
// Don't do client-side ID expansion which can cause issues with tagging

// Construct the URL with query parameters
// Construct the URL with query parameters using the normalized source
tagPath := fmt.Sprintf("%s/%s/tag?repo=%s&tag=%s",
inference.ModelsPrefix,
source,
Expand Down
Loading
Loading