models.go

  1package models
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/bubbles/v2/help"
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/fur/provider"
 11	"github.com/charmbracelet/crush/internal/tui/components/completions"
 12	"github.com/charmbracelet/crush/internal/tui/components/core"
 13	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27const (
 28	LargeModelType int = iota
 29	SmallModelType
 30)
 31
 32// ModelSelectedMsg is sent when a model is selected
 33type ModelSelectedMsg struct {
 34	Model     config.PreferredModel
 35	ModelType config.ModelType
 36}
 37
 38// CloseModelDialogMsg is sent when a model is selected
 39type CloseModelDialogMsg struct{}
 40
 41// ModelDialog interface for the model selection dialog
 42type ModelDialog interface {
 43	dialogs.DialogModel
 44}
 45
 46type ModelOption struct {
 47	Provider provider.Provider
 48	Model    provider.Model
 49}
 50
 51type modelDialogCmp struct {
 52	width   int
 53	wWidth  int
 54	wHeight int
 55
 56	modelList list.ListModel
 57	keyMap    KeyMap
 58	help      help.Model
 59	modelType int
 60}
 61
 62func NewModelDialogCmp() ModelDialog {
 63	listKeyMap := list.DefaultKeyMap()
 64	keyMap := DefaultKeyMap()
 65
 66	listKeyMap.Down.SetEnabled(false)
 67	listKeyMap.Up.SetEnabled(false)
 68	listKeyMap.HalfPageDown.SetEnabled(false)
 69	listKeyMap.HalfPageUp.SetEnabled(false)
 70	listKeyMap.Home.SetEnabled(false)
 71	listKeyMap.End.SetEnabled(false)
 72
 73	listKeyMap.DownOneItem = keyMap.Next
 74	listKeyMap.UpOneItem = keyMap.Previous
 75
 76	t := styles.CurrentTheme()
 77	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 78	modelList := list.New(
 79		list.WithFilterable(true),
 80		list.WithKeyMap(listKeyMap),
 81		list.WithInputStyle(inputStyle),
 82		list.WithWrapNavigation(true),
 83	)
 84	help := help.New()
 85	help.Styles = t.S().Help
 86
 87	return &modelDialogCmp{
 88		modelList: modelList,
 89		width:     defaultWidth,
 90		keyMap:    DefaultKeyMap(),
 91		help:      help,
 92		modelType: LargeModelType,
 93	}
 94}
 95
 96func (m *modelDialogCmp) Init() tea.Cmd {
 97	m.SetModelType(m.modelType)
 98	return m.modelList.Init()
 99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102	switch msg := msg.(type) {
103	case tea.WindowSizeMsg:
104		m.wWidth = msg.Width
105		m.wHeight = msg.Height
106		m.SetModelType(m.modelType)
107		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
108	case tea.KeyPressMsg:
109		switch {
110		case key.Matches(msg, m.keyMap.Select):
111			selectedItemInx := m.modelList.SelectedIndex()
112			if selectedItemInx == list.NoSelection {
113				return m, nil
114			}
115			items := m.modelList.Items()
116			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
117
118			var modelType config.ModelType
119			if m.modelType == LargeModelType {
120				modelType = config.LargeModel
121			} else {
122				modelType = config.SmallModel
123			}
124
125			return m, tea.Sequence(
126				util.CmdHandler(dialogs.CloseDialogMsg{}),
127				util.CmdHandler(ModelSelectedMsg{
128					Model: config.PreferredModel{
129						ModelID:  selectedItem.Model.ID,
130						Provider: selectedItem.Provider.ID,
131					},
132					ModelType: modelType,
133				}),
134			)
135		case key.Matches(msg, m.keyMap.Tab):
136			if m.modelType == LargeModelType {
137				return m, m.SetModelType(SmallModelType)
138			} else {
139				return m, m.SetModelType(LargeModelType)
140			}
141		case key.Matches(msg, m.keyMap.Close):
142			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
143		default:
144			u, cmd := m.modelList.Update(msg)
145			m.modelList = u.(list.ListModel)
146			return m, cmd
147		}
148	}
149	return m, nil
150}
151
152func (m *modelDialogCmp) View() string {
153	t := styles.CurrentTheme()
154	listView := m.modelList.View()
155	radio := m.modelTypeRadio()
156	content := lipgloss.JoinVertical(
157		lipgloss.Left,
158		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
159		listView,
160		"",
161		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
162	)
163	return m.style().Render(content)
164}
165
166func (m *modelDialogCmp) Cursor() *tea.Cursor {
167	if cursor, ok := m.modelList.(util.Cursor); ok {
168		cursor := cursor.Cursor()
169		if cursor != nil {
170			cursor = m.moveCursor(cursor)
171			return cursor
172		}
173	}
174	return nil
175}
176
177func (m *modelDialogCmp) style() lipgloss.Style {
178	t := styles.CurrentTheme()
179	return t.S().Base.
180		Width(m.width).
181		Border(lipgloss.RoundedBorder()).
182		BorderForeground(t.BorderFocus)
183}
184
185func (m *modelDialogCmp) listWidth() int {
186	return defaultWidth - 2 // 4 for padding
187}
188
189func (m *modelDialogCmp) listHeight() int {
190	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
191	return min(listHeigh, m.wHeight/2)
192}
193
194func (m *modelDialogCmp) Position() (int, int) {
195	row := m.wHeight/4 - 2 // just a bit above the center
196	col := m.wWidth / 2
197	col -= m.width / 2
198	return row, col
199}
200
201func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
202	row, col := m.Position()
203	offset := row + 3 // Border + title
204	cursor.Y += offset
205	cursor.X = cursor.X + col + 2
206	return cursor
207}
208
209func (m *modelDialogCmp) ID() dialogs.DialogID {
210	return ModelsDialogID
211}
212
213func (m *modelDialogCmp) modelTypeRadio() string {
214	t := styles.CurrentTheme()
215	choices := []string{"Large Task", "Small Task"}
216	iconSelected := "◉"
217	iconUnselected := "○"
218	if m.modelType == LargeModelType {
219		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
220	}
221	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
222}
223
224func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
225	m.modelType = modelType
226
227	providers := config.Providers()
228	modelItems := []util.Model{}
229	selectIndex := 0
230
231	cfg := config.Get()
232	var currentModel config.PreferredModel
233	if m.modelType == LargeModelType {
234		currentModel = cfg.Models.Large
235	} else {
236		currentModel = cfg.Models.Small
237	}
238
239	// Create a map to track which providers we've already added
240	addedProviders := make(map[provider.InferenceProvider]bool)
241
242	// First, add any configured providers that are not in the known providers list
243	// These should appear at the top of the list
244	knownProviders := provider.KnownProviders()
245	for providerID, providerConfig := range cfg.Providers {
246		if providerConfig.Disabled {
247			continue
248		}
249
250		// Check if this provider is not in the known providers list
251		if !slices.Contains(knownProviders, providerID) {
252			// Convert config provider to provider.Provider format
253			configProvider := provider.Provider{
254				Name:   string(providerID), // Use provider ID as name for unknown providers
255				ID:     providerID,
256				Models: make([]provider.Model, len(providerConfig.Models)),
257			}
258
259			// Convert models
260			for i, model := range providerConfig.Models {
261				configProvider.Models[i] = provider.Model{
262					ID:                     model.ID,
263					Name:                   model.Name,
264					CostPer1MIn:            model.CostPer1MIn,
265					CostPer1MOut:           model.CostPer1MOut,
266					CostPer1MInCached:      model.CostPer1MInCached,
267					CostPer1MOutCached:     model.CostPer1MOutCached,
268					ContextWindow:          model.ContextWindow,
269					DefaultMaxTokens:       model.DefaultMaxTokens,
270					CanReason:              model.CanReason,
271					HasReasoningEffort:     model.HasReasoningEffort,
272					DefaultReasoningEffort: model.ReasoningEffort,
273					SupportsImages:         model.SupportsImages,
274				}
275			}
276
277			// Add this unknown provider to the list
278			name := configProvider.Name
279			if name == "" {
280				name = string(configProvider.ID)
281			}
282			modelItems = append(modelItems, commands.NewItemSection(name))
283			for _, model := range configProvider.Models {
284				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
285					Provider: configProvider,
286					Model:    model,
287				}))
288				if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
289					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
290				}
291			}
292			addedProviders[providerID] = true
293		}
294	}
295
296	// Then add the known providers from the predefined list
297	for _, provider := range providers {
298		// Skip if we already added this provider as an unknown provider
299		if addedProviders[provider.ID] {
300			continue
301		}
302
303		// Check if this provider is configured and not disabled
304		if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
305			continue
306		}
307
308		name := provider.Name
309		if name == "" {
310			name = string(provider.ID)
311		}
312		modelItems = append(modelItems, commands.NewItemSection(name))
313		for _, model := range provider.Models {
314			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
315				Provider: provider,
316				Model:    model,
317			}))
318			if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
319				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
320			}
321		}
322	}
323
324	return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
325}