From 90491b4b8ae5be3605cd92d77418edfd5aa4f3e1 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Mon, 5 Jan 2026 12:45:48 -0500 Subject: [PATCH] feat(ui): model dialog: implement model selection handling --- internal/config/config.go | 5 ++++ internal/ui/dialog/messages.go | 6 ++-- internal/ui/dialog/models.go | 48 ++++++++++++++++++++----------- internal/ui/dialog/models_item.go | 24 ++++++++++++++-- internal/ui/model/ui.go | 22 +++++++++++++- 5 files changed, 81 insertions(+), 24 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index e68ad8c27ca7e3c2313a3b18b48bcbedc3d677e9..22f5ea87be2676920e49b472565c4aaf7425c52c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -54,6 +54,11 @@ var defaultContextPaths = []string{ type SelectedModelType string +// String returns the string representation of the [SelectedModelType]. +func (s SelectedModelType) String() string { + return string(s) +} + const ( SelectedModelTypeLarge SelectedModelType = "large" SelectedModelTypeSmall SelectedModelType = "small" diff --git a/internal/ui/dialog/messages.go b/internal/ui/dialog/messages.go index 2d69e4c9b841fcc9d8776e8e6fb5cf04e3d1d0f0..8efc59240e83ea8137cdaf14a7c87f903b8683b5 100644 --- a/internal/ui/dialog/messages.go +++ b/internal/ui/dialog/messages.go @@ -2,7 +2,7 @@ package dialog import ( tea "charm.land/bubbletea/v2" - "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/session" ) @@ -24,8 +24,8 @@ type SessionSelectedMsg struct { // ModelSelectedMsg is a message indicating a model has been selected. type ModelSelectedMsg struct { - Provider catwalk.Provider - Model catwalk.Model + Model config.SelectedModel + ModelType config.SelectedModelType } // Messages for commands diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 47b50cd758d68e84c1b0615b0b98b237ebbbcdb6..f0e0cb3b3e94f5749fe249b2758561f8ea615a9a 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -37,6 +37,30 @@ func (mt ModelType) String() string { } } +// Config returns the corresponding config model type. +func (mt ModelType) Config() config.SelectedModelType { + switch mt { + case ModelTypeLarge: + return config.SelectedModelTypeLarge + case ModelTypeSmall: + return config.SelectedModelTypeSmall + default: + return "" + } +} + +// Placeholder returns the input placeholder for the model type. +func (mt ModelType) Placeholder() string { + switch mt { + case ModelTypeLarge: + return largeModelInputPlaceholder + case ModelTypeSmall: + return smallModelInputPlaceholder + default: + return "" + } +} + const ( largeModelInputPlaceholder = "Choose a model for large, complex tasks" smallModelInputPlaceholder = "Choose a model for small, simple tasks" @@ -180,8 +204,8 @@ func (m *Models) Update(msg tea.Msg) tea.Msg { } return ModelSelectedMsg{ - Provider: modelItem.prov, - Model: modelItem.model, + Model: modelItem.SelectedModel(), + ModelType: modelItem.SelectedModelType(), } case key.Matches(msg, m.keyMap.Tab): if m.modelType == ModelTypeLarge { @@ -276,14 +300,8 @@ func (m *Models) setProviderItems() error { t := m.com.Styles cfg := m.com.Config() - selectedType := config.SelectedModelTypeLarge - if m.modelType == ModelTypeLarge { - selectedType = config.SelectedModelTypeLarge - } else { - selectedType = config.SelectedModelTypeSmall - } - var selectedItemID string + selectedType := m.modelType.Config() currentModel := cfg.Models[selectedType] recentItems := cfg.RecentModels[selectedType] @@ -325,7 +343,7 @@ func (m *Models) setProviderItems() error { group := NewModelGroup(t, name, true) for _, model := range p.Models { - item := NewModelItem(t, provider, model, false) + item := NewModelItem(t, provider, model, m.modelType, false) group.AppendItems(item) itemsMap[item.ID()] = item if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider { @@ -379,7 +397,7 @@ func (m *Models) setProviderItems() error { group := NewModelGroup(t, name, providerConfigured) for _, model := range displayProvider.Models { - item := NewModelItem(t, provider, model, false) + item := NewModelItem(t, provider, model, m.modelType, false) group.AppendItems(item) itemsMap[item.ID()] = item if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider { @@ -402,7 +420,7 @@ func (m *Models) setProviderItems() error { } // Show provider for recent items - item = NewModelItem(t, item.prov, item.model, true) + item = NewModelItem(t, item.prov, item.model, m.modelType, true) item.showProvider = true validRecentItems = append(validRecentItems, recent) @@ -429,11 +447,7 @@ func (m *Models) setProviderItems() error { m.list.SetSelectedItem(selectedItemID) // Update placeholder based on model type - if m.modelType == ModelTypeLarge { - m.input.Placeholder = largeModelInputPlaceholder - } else { - m.input.Placeholder = smallModelInputPlaceholder - } + m.input.Placeholder = m.modelType.Placeholder() return nil } diff --git a/internal/ui/dialog/models_item.go b/internal/ui/dialog/models_item.go index 46722f8592e417af5b90d6187d42a5cc11a89f7c..40a8a25c57cd7cf0ce6252ef3113ce2af2f8d2f4 100644 --- a/internal/ui/dialog/models_item.go +++ b/internal/ui/dialog/models_item.go @@ -3,6 +3,7 @@ package dialog import ( "charm.land/lipgloss/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/ui/common" "github.com/charmbracelet/crush/internal/ui/styles" "github.com/charmbracelet/x/ansi" @@ -49,8 +50,9 @@ func (m *ModelGroup) Render(width int) string { // ModelItem represents a list item for a model type. type ModelItem struct { - prov catwalk.Provider - model catwalk.Model + prov catwalk.Provider + model catwalk.Model + modelType ModelType cache map[int]string t *styles.Styles @@ -59,13 +61,29 @@ type ModelItem struct { showProvider bool } +// SelectedModel returns this model item as a [config.SelectedModel] instance. +func (m *ModelItem) SelectedModel() config.SelectedModel { + return config.SelectedModel{ + Model: m.model.ID, + Provider: string(m.prov.ID), + ReasoningEffort: m.model.DefaultReasoningEffort, + MaxTokens: m.model.DefaultMaxTokens, + } +} + +// SelectedModelType returns the type of model represented by this item. +func (m *ModelItem) SelectedModelType() config.SelectedModelType { + return m.modelType.Config() +} + var _ ListItem = &ModelItem{} // NewModelItem creates a new ModelItem. -func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model, showProvider bool) *ModelItem { +func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model, typ ModelType, showProvider bool) *ModelItem { return &ModelItem{ prov: prov, model: model, + modelType: typ, t: t, cache: make(map[int]string), showProvider: showProvider, diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index a52339208f64a2d628803af97bd3aef92965af5a..0f5dfecad68c00659b4c13ddd2c632844fd6e9da 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -574,7 +574,27 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { case dialog.QuitMsg: cmds = append(cmds, tea.Quit) case dialog.ModelSelectedMsg: - // TODO: Handle model switching + if m.com.App.AgentCoordinator.IsBusy() { + cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait...")) + break + } + + cfg := m.com.Config() + if cfg == nil { + cmds = append(cmds, uiutil.ReportError(errors.New("configuration not found"))) + break + } + + if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { + cmds = append(cmds, uiutil.ReportError(err)) + } + + // XXX: Should this be in a separate goroutine? + go m.com.App.UpdateAgentModel(context.TODO()) + + modelMsg := fmt.Sprintf("%s model changed to %s", msg.ModelType, msg.Model.Model) + cmds = append(cmds, uiutil.ReportInfo(modelMsg)) + m.dialog.CloseDialog(dialog.ModelsID) } return tea.Batch(cmds...)