14 changes: 9 additions & 5 deletions cmd/cli/commands/run.go
@@ -743,11 +743,7 @@ func newRunCmd() *cobra.Command {

 	// Handle --detach flag: just load the model without interaction
 	if detach {
-		// Make a minimal request to load the model into memory
-		err := desktopClient.Chat(model, "", nil, func(content string) {
-			// Silently discard output in detach mode
-		}, false)
-		if err != nil {
+		if err := desktopClient.Preload(cmd.Context(), model); err != nil {
 			return handleClientError(err, "Failed to load model")
 		}
 		if debug {
@@ -764,6 +760,14 @@ func newRunCmd() *cobra.Command {
 		return nil
 	}
 
+	// For interactive mode, eagerly load the model in the background
+	// while the user types their first query
+	go func() {
+		if err := desktopClient.Preload(cmd.Context(), model); err != nil {
+			cmd.PrintErrf("background model preload failed: %v\n", err)
+		}
+	}()
+
 	// Initialize termenv with color caching before starting interactive session.
 	// This queries the terminal background color once and caches it, preventing
 	// OSC response sequences from appearing in stdin during the interactive loop.
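The new goroutine is fire-and-forget: it warms the model while the prompt loop starts, and a failure only prints a warning. The pattern in isolation, as a minimal sketch, where `preload` is a hypothetical stub standing in for `desktopClient.Preload` and the model name is invented:

```go
package main

import (
	"bufio"
	"context"
	"fmt"
	"os"
	"time"
)

// preload is a hypothetical stub for desktopClient.Preload: it blocks
// until the "model" is resident or the context is cancelled.
func preload(ctx context.Context, model string) error {
	select {
	case <-time.After(2 * time.Second): // simulated load time
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx := context.Background()

	// Start loading while the user is still typing; errors are
	// reported to stderr but never abort the session.
	go func() {
		if err := preload(ctx, "ai/example-model"); err != nil {
			fmt.Fprintf(os.Stderr, "background model preload failed: %v\n", err)
		}
	}()

	fmt.Print("> ")
	prompt, err := bufio.NewReader(os.Stdin).ReadString('\n')
	if err != nil {
		return
	}
	fmt.Println("first prompt:", prompt)
}
```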
39 changes: 39 additions & 0 deletions cmd/cli/desktop/desktop.go
@@ -350,6 +350,45 @@ func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func(
 	return c.ChatWithContext(context.Background(), model, prompt, imageURLs, outputFunc, shouldUseMarkdown)
 }
 
+// Preload loads a model into memory without running inference.
+// The model stays loaded for the idle timeout period.
+func (c *Client) Preload(ctx context.Context, model string) error {
+	reqBody := OpenAIChatRequest{
+		Model:    model,
+		Messages: []OpenAIChatMessage{},
+	}
+
+	jsonData, err := json.Marshal(reqBody)
+	if err != nil {
+		return fmt.Errorf("error marshaling request: %w", err)
+	}
+
+	completionsPath := c.modelRunner.OpenAIPathPrefix() + "/chat/completions"
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.modelRunner.URL(completionsPath), bytes.NewReader(jsonData))
+	if err != nil {
+		return fmt.Errorf("error creating request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("User-Agent", "docker-model-cli/"+Version)
+	req.Header.Set("X-Preload-Only", "true")
+
+	resp, err := c.modelRunner.Client().Do(req)
+	if err != nil {
+		return c.handleQueryError(err, completionsPath)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return fmt.Errorf("preload failed with status %d and could not read response body: %w", resp.StatusCode, err)
+		}
+		return fmt.Errorf("preload failed: status=%d body=%s", resp.StatusCode, body)
+	}
+
+	return nil
+}
+
 // ChatWithMessagesContext performs a chat request with conversation history and returns the assistant's response.
 // This allows maintaining conversation context across multiple exchanges.
 func (c *Client) ChatWithMessagesContext(ctx context.Context, model string, conversationHistory []OpenAIChatMessage, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) (string, error) {
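On the wire, `Preload` is an ordinary chat-completions POST with an empty messages array and the `X-Preload-Only` header set, so any HTTP client can trigger it. A sketch under assumptions: the base URL and model name below are illustrative only; the real client derives the path from `modelRunner.OpenAIPathPrefix()`:

```go
package main

import (
	"bytes"
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Assumed endpoint and model name, for illustration only; the CLI
	// resolves the real path via modelRunner.OpenAIPathPrefix().
	const url = "http://localhost:12434/engines/v1/chat/completions"
	body := []byte(`{"model":"ai/example-model","messages":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// The scheduler returns as soon as the runner is loaded, skipping inference.
	req.Header.Set("X-Preload-Only", "true")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("preload status:", resp.Status)
}
```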
45 changes: 44 additions & 1 deletion pkg/inference/scheduling/http_handler.go
@@ -8,8 +8,10 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"time"

"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/docker/model-runner/pkg/inference"
Expand All @@ -19,6 +21,10 @@ import (
"github.com/docker/model-runner/pkg/middleware"
)

type contextKey bool

const preloadOnlyKey contextKey = false

// HTTPHandler handles HTTP requests for the scheduler.
// It wraps the Scheduler to provide HTTP endpoint functionality without
// coupling the core scheduling logic to HTTP concerns.
@@ -223,6 +229,12 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Request) {
 	}
 	defer h.scheduler.loader.release(runner)
 
+	// If this is a preload-only request, return here without running inference.
+	// Can be triggered via context (internal) or X-Preload-Only header (external).
+	if r.Context().Value(preloadOnlyKey) != nil || r.Header.Get("X-Preload-Only") == "true" {
+		return
+	}
+
 	// Record the request in the OpenAI recorder.
 	recordID := h.scheduler.openAIRecorder.RecordRequest(request.Model, r, body)
 	w = h.scheduler.openAIRecorder.NewResponseRecorder(w)
@@ -357,7 +369,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	_, err = h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
+	backend, err = h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
 	if err != nil {
 		if errors.Is(err, errRunnerAlreadyActive) {
 			http.Error(w, err.Error(), http.StatusConflict)
@@ -367,6 +379,37 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	// Preload the model in the background by calling handleOpenAIInference with a preload-only context.
+	// This makes Compose preload the model too, since Compose calls `configure` by default.
+	go func() {
+		preloadBody, err := json.Marshal(OpenAIInferenceRequest{Model: configureRequest.Model})
+		if err != nil {
+			h.scheduler.log.Warnf("failed to marshal preload request body: %v", err)
+			return
+		}
+		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+		defer cancel()
+		preloadReq, err := http.NewRequestWithContext(
+			context.WithValue(ctx, preloadOnlyKey, true),
+			http.MethodPost,
+			inference.InferencePrefix+"/v1/chat/completions",
+			bytes.NewReader(preloadBody),
+		)
+		if err != nil {
+			h.scheduler.log.Warnf("failed to create preload request: %v", err)
+			return
+		}
+		preloadReq.Header.Set("User-Agent", r.UserAgent())
+		if backend != nil {
+			preloadReq.SetPathValue("backend", backend.Name())
+		}
+		recorder := httptest.NewRecorder()
+		h.handleOpenAIInference(recorder, preloadReq)
+		if recorder.Code != http.StatusOK {
+			h.scheduler.log.Warnf("background model preload failed with status %d: %s", recorder.Code, recorder.Body.String())
+		}
+	}()
+
 	w.WriteHeader(http.StatusAccepted)
 }

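The handler change leans on two standard-library techniques: an unexported context key type, so internal callers can flag preload-only requests without going through HTTP headers, and `httptest.NewRecorder`, which lets the `Configure` goroutine invoke the handler in-process with no listener. A reduced sketch of both, with a toy handler standing in for `handleOpenAIInference`:

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// An unexported key type prevents collisions with other packages'
// context values; the PR's `type contextKey bool` is mirrored here.
type contextKey bool

const preloadOnlyKey contextKey = false

func handler(w http.ResponseWriter, r *http.Request) {
	// (the real handler resolves and retains a runner before this check)
	if r.Context().Value(preloadOnlyKey) != nil || r.Header.Get("X-Preload-Only") == "true" {
		return // preload-only: runner is loaded, skip inference
	}
	fmt.Fprintln(w, "inference ran")
}

func main() {
	req, err := http.NewRequestWithContext(
		context.WithValue(context.Background(), preloadOnlyKey, true),
		http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`),
	)
	if err != nil {
		panic(err)
	}

	// httptest.NewRecorder captures the response in memory, so the
	// handler runs in-process without a network round trip.
	rec := httptest.NewRecorder()
	handler(rec, req)
	fmt.Println("status:", rec.Code) // 200: model loaded, nothing inferred
}
```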