@@ -2,7 +2,7 @@ package dialog
import (
tea "charm.land/bubbletea/v2"
- "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/session"
)
@@ -24,8 +24,8 @@ type SessionSelectedMsg struct {
// ModelSelectedMsg is a message indicating a model has been selected.
type ModelSelectedMsg struct {
- Provider catwalk.Provider
- Model catwalk.Model
+ Model config.SelectedModel
+ ModelType config.SelectedModelType
}
// Messages for commands
@@ -37,6 +37,30 @@ func (mt ModelType) String() string {
}
}
+// Config returns the corresponding config model type.
+func (mt ModelType) Config() config.SelectedModelType {
+ switch mt {
+ case ModelTypeLarge:
+ return config.SelectedModelTypeLarge
+ case ModelTypeSmall:
+ return config.SelectedModelTypeSmall
+ default:
+ return ""
+ }
+}
+
+// Placeholder returns the input placeholder for the model type.
+func (mt ModelType) Placeholder() string {
+ switch mt {
+ case ModelTypeLarge:
+ return largeModelInputPlaceholder
+ case ModelTypeSmall:
+ return smallModelInputPlaceholder
+ default:
+ return ""
+ }
+}
+
const (
largeModelInputPlaceholder = "Choose a model for large, complex tasks"
smallModelInputPlaceholder = "Choose a model for small, simple tasks"
@@ -180,8 +204,8 @@ func (m *Models) Update(msg tea.Msg) tea.Msg {
}
return ModelSelectedMsg{
- Provider: modelItem.prov,
- Model: modelItem.model,
+ Model: modelItem.SelectedModel(),
+ ModelType: modelItem.SelectedModelType(),
}
case key.Matches(msg, m.keyMap.Tab):
if m.modelType == ModelTypeLarge {
@@ -276,14 +300,8 @@ func (m *Models) setProviderItems() error {
t := m.com.Styles
cfg := m.com.Config()
- selectedType := config.SelectedModelTypeLarge
- if m.modelType == ModelTypeLarge {
- selectedType = config.SelectedModelTypeLarge
- } else {
- selectedType = config.SelectedModelTypeSmall
- }
-
var selectedItemID string
+ selectedType := m.modelType.Config()
currentModel := cfg.Models[selectedType]
recentItems := cfg.RecentModels[selectedType]
@@ -325,7 +343,7 @@ func (m *Models) setProviderItems() error {
group := NewModelGroup(t, name, true)
for _, model := range p.Models {
- item := NewModelItem(t, provider, model, false)
+ item := NewModelItem(t, provider, model, m.modelType, false)
group.AppendItems(item)
itemsMap[item.ID()] = item
if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
@@ -379,7 +397,7 @@ func (m *Models) setProviderItems() error {
group := NewModelGroup(t, name, providerConfigured)
for _, model := range displayProvider.Models {
- item := NewModelItem(t, provider, model, false)
+ item := NewModelItem(t, provider, model, m.modelType, false)
group.AppendItems(item)
itemsMap[item.ID()] = item
if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
@@ -402,7 +420,7 @@ func (m *Models) setProviderItems() error {
}
// Show provider for recent items
- item = NewModelItem(t, item.prov, item.model, true)
+ item = NewModelItem(t, item.prov, item.model, m.modelType, true)
item.showProvider = true
validRecentItems = append(validRecentItems, recent)
@@ -429,11 +447,7 @@ func (m *Models) setProviderItems() error {
m.list.SetSelectedItem(selectedItemID)
// Update placeholder based on model type
- if m.modelType == ModelTypeLarge {
- m.input.Placeholder = largeModelInputPlaceholder
- } else {
- m.input.Placeholder = smallModelInputPlaceholder
- }
+ m.input.Placeholder = m.modelType.Placeholder()
return nil
}
@@ -3,6 +3,7 @@ package dialog
import (
"charm.land/lipgloss/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/ui/common"
"github.com/charmbracelet/crush/internal/ui/styles"
"github.com/charmbracelet/x/ansi"
@@ -49,8 +50,9 @@ func (m *ModelGroup) Render(width int) string {
// ModelItem represents a list item for a model type.
type ModelItem struct {
- prov catwalk.Provider
- model catwalk.Model
+ prov catwalk.Provider
+ model catwalk.Model
+ modelType ModelType
cache map[int]string
t *styles.Styles
@@ -59,13 +61,29 @@ type ModelItem struct {
showProvider bool
}
+// SelectedModel returns this model item as a [config.SelectedModel] instance.
+func (m *ModelItem) SelectedModel() config.SelectedModel {
+ return config.SelectedModel{
+ Model: m.model.ID,
+ Provider: string(m.prov.ID),
+ ReasoningEffort: m.model.DefaultReasoningEffort,
+ MaxTokens: m.model.DefaultMaxTokens,
+ }
+}
+
+// SelectedModelType returns the type of model represented by this item.
+func (m *ModelItem) SelectedModelType() config.SelectedModelType {
+ return m.modelType.Config()
+}
+
var _ ListItem = &ModelItem{}
// NewModelItem creates a new ModelItem.
-func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model, showProvider bool) *ModelItem {
+func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model, typ ModelType, showProvider bool) *ModelItem {
return &ModelItem{
prov: prov,
model: model,
+ modelType: typ,
t: t,
cache: make(map[int]string),
showProvider: showProvider,
@@ -574,7 +574,27 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
case dialog.QuitMsg:
cmds = append(cmds, tea.Quit)
case dialog.ModelSelectedMsg:
- // TODO: Handle model switching
+ if m.com.App.AgentCoordinator.IsBusy() {
+ cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait..."))
+ break
+ }
+
+ cfg := m.com.Config()
+ if cfg == nil {
+ cmds = append(cmds, uiutil.ReportError(errors.New("configuration not found")))
+ break
+ }
+
+ if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil {
+ cmds = append(cmds, uiutil.ReportError(err))
+ }
+
+ // XXX: Should this be in a separate goroutine?
+ go m.com.App.UpdateAgentModel(context.TODO())
+
+ modelMsg := fmt.Sprintf("%s model changed to %s", msg.ModelType, msg.Model.Model)
+ cmds = append(cmds, uiutil.ReportInfo(modelMsg))
+ m.dialog.CloseDialog(dialog.ModelsID)
}
return tea.Batch(cmds...)