feat(ui): model dialog: implement model selection handling

Ayman Bagabas created

Change summary

internal/config/config.go         |  5 +++
internal/ui/dialog/messages.go    |  6 ++--
internal/ui/dialog/models.go      | 48 +++++++++++++++++++++-----------
internal/ui/dialog/models_item.go | 24 ++++++++++++++--
internal/ui/model/ui.go           | 22 ++++++++++++++
5 files changed, 81 insertions(+), 24 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -54,6 +54,11 @@ var defaultContextPaths = []string{
 
 type SelectedModelType string
 
+// String returns the string representation of the [SelectedModelType].
+func (s SelectedModelType) String() string {
+	return string(s)
+}
+
 const (
 	SelectedModelTypeLarge SelectedModelType = "large"
 	SelectedModelTypeSmall SelectedModelType = "small"

internal/ui/dialog/messages.go 🔗

@@ -2,7 +2,7 @@ package dialog
 
 import (
 	tea "charm.land/bubbletea/v2"
-	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/session"
 )
 
@@ -24,8 +24,8 @@ type SessionSelectedMsg struct {
 
 // ModelSelectedMsg is a message indicating a model has been selected.
 type ModelSelectedMsg struct {
-	Provider catwalk.Provider
-	Model    catwalk.Model
+	Model     config.SelectedModel
+	ModelType config.SelectedModelType
 }
 
 // Messages for commands

internal/ui/dialog/models.go 🔗

@@ -37,6 +37,30 @@ func (mt ModelType) String() string {
 	}
 }
 
+// Config returns the corresponding config model type.
+func (mt ModelType) Config() config.SelectedModelType {
+	switch mt {
+	case ModelTypeLarge:
+		return config.SelectedModelTypeLarge
+	case ModelTypeSmall:
+		return config.SelectedModelTypeSmall
+	default:
+		return ""
+	}
+}
+
+// Placeholder returns the input placeholder for the model type.
+func (mt ModelType) Placeholder() string {
+	switch mt {
+	case ModelTypeLarge:
+		return largeModelInputPlaceholder
+	case ModelTypeSmall:
+		return smallModelInputPlaceholder
+	default:
+		return ""
+	}
+}
+
 const (
 	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
@@ -180,8 +204,8 @@ func (m *Models) Update(msg tea.Msg) tea.Msg {
 			}
 
 			return ModelSelectedMsg{
-				Provider: modelItem.prov,
-				Model:    modelItem.model,
+				Model:     modelItem.SelectedModel(),
+				ModelType: modelItem.SelectedModelType(),
 			}
 		case key.Matches(msg, m.keyMap.Tab):
 			if m.modelType == ModelTypeLarge {
@@ -276,14 +300,8 @@ func (m *Models) setProviderItems() error {
 	t := m.com.Styles
 	cfg := m.com.Config()
 
-	selectedType := config.SelectedModelTypeLarge
-	if m.modelType == ModelTypeLarge {
-		selectedType = config.SelectedModelTypeLarge
-	} else {
-		selectedType = config.SelectedModelTypeSmall
-	}
-
 	var selectedItemID string
+	selectedType := m.modelType.Config()
 	currentModel := cfg.Models[selectedType]
 	recentItems := cfg.RecentModels[selectedType]
 
@@ -325,7 +343,7 @@ func (m *Models) setProviderItems() error {
 
 			group := NewModelGroup(t, name, true)
 			for _, model := range p.Models {
-				item := NewModelItem(t, provider, model, false)
+				item := NewModelItem(t, provider, model, m.modelType, false)
 				group.AppendItems(item)
 				itemsMap[item.ID()] = item
 				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
@@ -379,7 +397,7 @@ func (m *Models) setProviderItems() error {
 
 		group := NewModelGroup(t, name, providerConfigured)
 		for _, model := range displayProvider.Models {
-			item := NewModelItem(t, provider, model, false)
+			item := NewModelItem(t, provider, model, m.modelType, false)
 			group.AppendItems(item)
 			itemsMap[item.ID()] = item
 			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
@@ -402,7 +420,7 @@ func (m *Models) setProviderItems() error {
 			}
 
 			// Show provider for recent items
-			item = NewModelItem(t, item.prov, item.model, true)
+			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
 			item.showProvider = true
 
 			validRecentItems = append(validRecentItems, recent)
@@ -429,11 +447,7 @@ func (m *Models) setProviderItems() error {
 	m.list.SetSelectedItem(selectedItemID)
 
 	// Update placeholder based on model type
-	if m.modelType == ModelTypeLarge {
-		m.input.Placeholder = largeModelInputPlaceholder
-	} else {
-		m.input.Placeholder = smallModelInputPlaceholder
-	}
+	m.input.Placeholder = m.modelType.Placeholder()
 
 	return nil
 }

internal/ui/dialog/models_item.go 🔗

@@ -3,6 +3,7 @@ package dialog
 import (
 	"charm.land/lipgloss/v2"
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/ui/common"
 	"github.com/charmbracelet/crush/internal/ui/styles"
 	"github.com/charmbracelet/x/ansi"
@@ -49,8 +50,9 @@ func (m *ModelGroup) Render(width int) string {
 
 // ModelItem represents a list item for a model type.
 type ModelItem struct {
-	prov  catwalk.Provider
-	model catwalk.Model
+	prov      catwalk.Provider
+	model     catwalk.Model
+	modelType ModelType
 
 	cache        map[int]string
 	t            *styles.Styles
@@ -59,13 +61,29 @@ type ModelItem struct {
 	showProvider bool
 }
 
+// SelectedModel returns this model item as a [config.SelectedModel] instance.
+func (m *ModelItem) SelectedModel() config.SelectedModel {
+	return config.SelectedModel{
+		Model:           m.model.ID,
+		Provider:        string(m.prov.ID),
+		ReasoningEffort: m.model.DefaultReasoningEffort,
+		MaxTokens:       m.model.DefaultMaxTokens,
+	}
+}
+
+// SelectedModelType returns the type of model represented by this item.
+func (m *ModelItem) SelectedModelType() config.SelectedModelType {
+	return m.modelType.Config()
+}
+
 var _ ListItem = &ModelItem{}
 
 // NewModelItem creates a new ModelItem.
-func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model, showProvider bool) *ModelItem {
+func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model, typ ModelType, showProvider bool) *ModelItem {
 	return &ModelItem{
 		prov:         prov,
 		model:        model,
+		modelType:    typ,
 		t:            t,
 		cache:        make(map[int]string),
 		showProvider: showProvider,

internal/ui/model/ui.go 🔗

@@ -574,7 +574,27 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 		case dialog.QuitMsg:
 			cmds = append(cmds, tea.Quit)
 		case dialog.ModelSelectedMsg:
-			// TODO: Handle model switching
+			if m.com.App.AgentCoordinator.IsBusy() {
+				cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait..."))
+				break
+			}
+
+			cfg := m.com.Config()
+			if cfg == nil {
+				cmds = append(cmds, uiutil.ReportError(errors.New("configuration not found")))
+				break
+			}
+
+			if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil {
+				cmds = append(cmds, uiutil.ReportError(err))
+			}
+
+			// XXX: Should this be in a separate goroutine?
+			go m.com.App.UpdateAgentModel(context.TODO())
+
+			modelMsg := fmt.Sprintf("%s model changed to %s", msg.ModelType, msg.Model.Model)
+			cmds = append(cmds, uiutil.ReportInfo(modelMsg))
+			m.dialog.CloseDialog(dialog.ModelsID)
 		}
 
 		return tea.Batch(cmds...)