models.go

  1package models
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/bubbles/v2/help"
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/fur/provider"
 11	"github.com/charmbracelet/crush/internal/tui/components/completions"
 12	"github.com/charmbracelet/crush/internal/tui/components/core"
 13	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27const (
 28	LargeModelType int = iota
 29	SmallModelType
 30)
 31
// ModelSelectedMsg is sent when the user confirms an entry in the model dialog
// with the Select key. Update emits it right after dialogs.CloseDialogMsg.
type ModelSelectedMsg struct {
	// Model identifies the chosen entry by its model ID and provider ID.
	Model     config.SelectedModel
	// ModelType is the slot the choice applies to: SelectedModelTypeLarge or
	// SelectedModelTypeSmall, depending on the dialog's active radio option.
	ModelType config.SelectedModelType
}
 37
// CloseModelDialogMsg presumably requests that the model dialog be closed.
// NOTE(review): nothing in this file emits it. The select and close paths in
// Update send dialogs.CloseDialogMsg instead, so confirm whether this type is
// still used elsewhere.
type CloseModelDialogMsg struct{}
 40
// ModelDialog is the public interface of the model selection dialog. It adds
// nothing beyond dialogs.DialogModel; NewModelDialogCmp returns an implementation.
type ModelDialog interface {
	dialogs.DialogModel
}
 45
// ModelOption is the value attached to each selectable model entry in the
// dialog's list: a model together with the provider that offers it.
type ModelOption struct {
	// Provider is the provider whose section the entry is listed under.
	Provider provider.Provider
	// Model is the model the entry represents.
	Model    provider.Model
}
 50
// modelDialogCmp implements ModelDialog: a bordered, filterable list of
// provider models with a Large/Small radio indicating which model slot is
// being changed.
type modelDialogCmp struct {
	// width is the dialog's outer width; NewModelDialogCmp sets it to defaultWidth.
	width   int
	// wWidth and wHeight are the window size from the last tea.WindowSizeMsg;
	// Position and listHeight are derived from them.
	wWidth  int
	wHeight int

	// modelList holds the provider section headers and the model entries.
	modelList list.ListModel
	// keyMap holds the dialog's key bindings (Select, Tab, Close, Next, Previous).
	keyMap    KeyMap
	// help renders keyMap at the bottom of the dialog.
	help      help.Model
	// modelType is LargeModelType or SmallModelType.
	modelType int
}
 61
 62func NewModelDialogCmp() ModelDialog {
 63	listKeyMap := list.DefaultKeyMap()
 64	keyMap := DefaultKeyMap()
 65
 66	listKeyMap.Down.SetEnabled(false)
 67	listKeyMap.Up.SetEnabled(false)
 68	listKeyMap.HalfPageDown.SetEnabled(false)
 69	listKeyMap.HalfPageUp.SetEnabled(false)
 70	listKeyMap.Home.SetEnabled(false)
 71	listKeyMap.End.SetEnabled(false)
 72
 73	listKeyMap.DownOneItem = keyMap.Next
 74	listKeyMap.UpOneItem = keyMap.Previous
 75
 76	t := styles.CurrentTheme()
 77	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 78	modelList := list.New(
 79		list.WithFilterable(true),
 80		list.WithKeyMap(listKeyMap),
 81		list.WithInputStyle(inputStyle),
 82		list.WithWrapNavigation(true),
 83	)
 84	help := help.New()
 85	help.Styles = t.S().Help
 86
 87	return &modelDialogCmp{
 88		modelList: modelList,
 89		width:     defaultWidth,
 90		keyMap:    DefaultKeyMap(),
 91		help:      help,
 92		modelType: LargeModelType,
 93	}
 94}
 95
 96func (m *modelDialogCmp) Init() tea.Cmd {
 97	m.SetModelType(m.modelType)
 98	return m.modelList.Init()
 99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102	switch msg := msg.(type) {
103	case tea.WindowSizeMsg:
104		m.wWidth = msg.Width
105		m.wHeight = msg.Height
106		m.SetModelType(m.modelType)
107		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
108	case tea.KeyPressMsg:
109		switch {
110		case key.Matches(msg, m.keyMap.Select):
111			selectedItemInx := m.modelList.SelectedIndex()
112			if selectedItemInx == list.NoSelection {
113				return m, nil
114			}
115			items := m.modelList.Items()
116			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
117
118			var modelType config.SelectedModelType
119			if m.modelType == LargeModelType {
120				modelType = config.SelectedModelTypeLarge
121			} else {
122				modelType = config.SelectedModelTypeSmall
123			}
124
125			return m, tea.Sequence(
126				util.CmdHandler(dialogs.CloseDialogMsg{}),
127				util.CmdHandler(ModelSelectedMsg{
128					Model: config.SelectedModel{
129						Model:    selectedItem.Model.ID,
130						Provider: string(selectedItem.Provider.ID),
131					},
132					ModelType: modelType,
133				}),
134			)
135		case key.Matches(msg, m.keyMap.Tab):
136			if m.modelType == LargeModelType {
137				return m, m.SetModelType(SmallModelType)
138			} else {
139				return m, m.SetModelType(LargeModelType)
140			}
141		case key.Matches(msg, m.keyMap.Close):
142			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
143		default:
144			u, cmd := m.modelList.Update(msg)
145			m.modelList = u.(list.ListModel)
146			return m, cmd
147		}
148	}
149	return m, nil
150}
151
152func (m *modelDialogCmp) View() string {
153	t := styles.CurrentTheme()
154	listView := m.modelList.View()
155	radio := m.modelTypeRadio()
156	content := lipgloss.JoinVertical(
157		lipgloss.Left,
158		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
159		listView,
160		"",
161		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
162	)
163	return m.style().Render(content)
164}
165
166func (m *modelDialogCmp) Cursor() *tea.Cursor {
167	if cursor, ok := m.modelList.(util.Cursor); ok {
168		cursor := cursor.Cursor()
169		if cursor != nil {
170			cursor = m.moveCursor(cursor)
171			return cursor
172		}
173	}
174	return nil
175}
176
177func (m *modelDialogCmp) style() lipgloss.Style {
178	t := styles.CurrentTheme()
179	return t.S().Base.
180		Width(m.width).
181		Border(lipgloss.RoundedBorder()).
182		BorderForeground(t.BorderFocus)
183}
184
185func (m *modelDialogCmp) listWidth() int {
186	return defaultWidth - 2 // 4 for padding
187}
188
189func (m *modelDialogCmp) listHeight() int {
190	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
191	return min(listHeigh, m.wHeight/2)
192}
193
194func (m *modelDialogCmp) Position() (int, int) {
195	row := m.wHeight/4 - 2 // just a bit above the center
196	col := m.wWidth / 2
197	col -= m.width / 2
198	return row, col
199}
200
201func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
202	row, col := m.Position()
203	offset := row + 3 // Border + title
204	cursor.Y += offset
205	cursor.X = cursor.X + col + 2
206	return cursor
207}
208
209func (m *modelDialogCmp) ID() dialogs.DialogID {
210	return ModelsDialogID
211}
212
213func (m *modelDialogCmp) modelTypeRadio() string {
214	t := styles.CurrentTheme()
215	choices := []string{"Large Task", "Small Task"}
216	iconSelected := "◉"
217	iconUnselected := "○"
218	if m.modelType == LargeModelType {
219		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
220	}
221	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
222}
223
224func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
225	m.modelType = modelType
226
227	providers, err := config.Providers()
228	if err != nil {
229		return util.ReportError(err)
230	}
231
232	modelItems := []util.Model{}
233	selectIndex := 0
234
235	cfg := config.Get()
236	var currentModel config.SelectedModel
237	if m.modelType == LargeModelType {
238		currentModel = cfg.Models[config.SelectedModelTypeLarge]
239	} else {
240		currentModel = cfg.Models[config.SelectedModelTypeSmall]
241	}
242
243	// Create a map to track which providers we've already added
244	addedProviders := make(map[string]bool)
245
246	// First, add any configured providers that are not in the known providers list
247	// These should appear at the top of the list
248	knownProviders := provider.KnownProviders()
249	for providerID, providerConfig := range cfg.Providers {
250		if providerConfig.Disable {
251			continue
252		}
253
254		// Check if this provider is not in the known providers list
255		if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
256			// Convert config provider to provider.Provider format
257			configProvider := provider.Provider{
258				Name:   string(providerID), // Use provider ID as name for unknown providers
259				ID:     provider.InferenceProvider(providerID),
260				Models: make([]provider.Model, len(providerConfig.Models)),
261			}
262
263			// Convert models
264			for i, model := range providerConfig.Models {
265				configProvider.Models[i] = provider.Model{
266					ID:                     model.ID,
267					Name:                   model.Name,
268					CostPer1MIn:            model.CostPer1MIn,
269					CostPer1MOut:           model.CostPer1MOut,
270					CostPer1MInCached:      model.CostPer1MInCached,
271					CostPer1MOutCached:     model.CostPer1MOutCached,
272					ContextWindow:          model.ContextWindow,
273					DefaultMaxTokens:       model.DefaultMaxTokens,
274					CanReason:              model.CanReason,
275					HasReasoningEffort:     model.HasReasoningEffort,
276					DefaultReasoningEffort: model.DefaultReasoningEffort,
277					SupportsImages:         model.SupportsImages,
278				}
279			}
280
281			// Add this unknown provider to the list
282			name := configProvider.Name
283			if name == "" {
284				name = string(configProvider.ID)
285			}
286			modelItems = append(modelItems, commands.NewItemSection(name))
287			for _, model := range configProvider.Models {
288				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
289					Provider: configProvider,
290					Model:    model,
291				}))
292				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
293					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
294				}
295			}
296			addedProviders[providerID] = true
297		}
298	}
299
300	// Then add the known providers from the predefined list
301	for _, provider := range providers {
302		// Skip if we already added this provider as an unknown provider
303		if addedProviders[string(provider.ID)] {
304			continue
305		}
306
307		// Check if this provider is configured and not disabled
308		if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
309			continue
310		}
311
312		name := provider.Name
313		if name == "" {
314			name = string(provider.ID)
315		}
316		modelItems = append(modelItems, commands.NewItemSection(name))
317		for _, model := range provider.Models {
318			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
319				Provider: provider,
320				Model:    model,
321			}))
322			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
323				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
324			}
325		}
326	}
327
328	return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
329}