From c4a5a99b6f32f7982df83bb5187e3c5ded4cd853 Mon Sep 17 00:00:00 2001 From: Amolith Date: Sat, 3 Jan 2026 20:55:01 -0700 Subject: [PATCH] feat(acp): implement SetSessionModel - Wire SetSessionModel stub to parse provider:model IDs, validate, and update the agent's active model via config and coordinator - Add buildSessionModelState helper to collect available models from all configured providers - Include Models in NewSessionResponse and LoadSessionResponse so clients can display model selection UI Assisted-by: Claude Sonnet 4 via Crush --- internal/acp/agent.go | 87 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 84 insertions(+), 3 deletions(-) diff --git a/internal/acp/agent.go b/internal/acp/agent.go index b0b3c6e13530401e4b6ad1f4617bf802d22ce5f4..f3687f0efb3cd7e74193a3ace5ce44bddc47d108 100644 --- a/internal/acp/agent.go +++ b/internal/acp/agent.go @@ -2,9 +2,12 @@ package acp import ( "context" + "fmt" "log/slog" + "strings" "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/message" "github.com/coder/acp-go-sdk" @@ -81,6 +84,7 @@ func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (a return acp.NewSessionResponse{ SessionId: acp.SessionId(sess.ID), + Models: a.buildSessionModelState(), }, nil } @@ -112,7 +116,9 @@ func (a *Agent) LoadSession(ctx context.Context, params acp.LoadSessionRequest) } } - return acp.LoadSessionResponse{}, nil + return acp.LoadSessionResponse{ + Models: a.buildSessionModelState(), + }, nil } // SetSessionMode handles mode switching (stub - Crush doesn't have modes yet). @@ -121,9 +127,44 @@ func (a *Agent) SetSessionMode(ctx context.Context, params acp.SetSessionModeReq return acp.SetSessionModeResponse{}, nil } -// SetSessionModel handles model switching (stub - model selection not yet wired). +// SetSessionModel handles model switching by parsing the model ID and updating +// the agent's active model. func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) { - slog.Debug("ACP SetSessionModel", "session_id", params.SessionId, "model_id", params.ModelId) + slog.Info("ACP SetSessionModel", "session_id", params.SessionId, "model_id", params.ModelId) + + // Parse model ID (format: "provider:model"). + parts := strings.SplitN(string(params.ModelId), ":", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return acp.SetSessionModelResponse{}, fmt.Errorf("invalid model ID format %q: expected provider:model", params.ModelId) + } + providerID, modelID := parts[0], parts[1] + + // Validate that the model exists. + cfg := config.Get() + if cfg.GetModel(providerID, modelID) == nil { + return acp.SetSessionModelResponse{}, fmt.Errorf("model %q not found for provider %q", modelID, providerID) + } + + // Check if the agent is busy. + if a.app.AgentCoordinator.IsBusy() { + return acp.SetSessionModelResponse{}, fmt.Errorf("agent is busy, cannot switch models") + } + + // Update the preferred model in config. + selectedModel := config.SelectedModel{ + Provider: providerID, + Model: modelID, + } + if err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel); err != nil { + return acp.SetSessionModelResponse{}, fmt.Errorf("failed to update preferred model: %w", err) + } + + // Apply the model change to the agent. + if err := a.app.UpdateAgentModel(ctx); err != nil { + return acp.SetSessionModelResponse{}, fmt.Errorf("failed to apply model change: %w", err) + } + + slog.Info("ACP SetSessionModel completed", "provider", providerID, "model", modelID) return acp.SetSessionModelResponse{}, nil } @@ -240,3 +281,43 @@ func (a *Agent) translateHistoryPart(role message.MessageRole, part message.Cont return nil } } + +// buildSessionModelState constructs the model state for session responses, +// listing all available models and the currently selected one. +func (a *Agent) buildSessionModelState() *acp.SessionModelState { + cfg := config.Get() + if cfg == nil { + return nil + } + + var availableModels []acp.ModelInfo + for providerID, providerConfig := range cfg.Providers.Seq2() { + if providerConfig.Disable { + continue + } + providerName := providerConfig.Name + if providerName == "" { + providerName = providerID + } + for _, model := range providerConfig.Models { + modelID := acp.ModelId(providerID + ":" + model.ID) + modelName := model.Name + if modelName == "" { + modelName = model.ID + } + availableModels = append(availableModels, acp.ModelInfo{ + ModelId: modelID, + Name: providerName + " / " + modelName, + }) + } + } + + // Get current model. + currentModel := cfg.Models[config.SelectedModelTypeLarge] + currentModelID := acp.ModelId(currentModel.Provider + ":" + currentModel.Model) + + return &acp.SessionModelState{ + AvailableModels: availableModels, + CurrentModelId: currentModelID, + } +}