models.go

  1package models
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/bubbles/v2/help"
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/fur/provider"
 11	"github.com/charmbracelet/crush/internal/tui/components/completions"
 12	"github.com/charmbracelet/crush/internal/tui/components/core"
 13	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 // Values for the dialog's model type; they select which configured model
// slot (large or small) the dialog edits.
const (
	// LargeModelType selects the large-task model slot.
	LargeModelType int = 0
	// SmallModelType selects the small-task model slot.
	SmallModelType int = 1
)
 31
 32// ModelSelectedMsg is sent when a model is selected
 33type ModelSelectedMsg struct {
 34	Model     config.SelectedModel
 35	ModelType config.SelectedModelType
 36}
 37
 // CloseModelDialogMsg presumably requests that the model selection dialog be
// closed. NOTE(review): the dialog in this file closes itself via
// dialogs.CloseDialogMsg; confirm who sends this message.
type CloseModelDialogMsg struct{}
 40
 41// ModelDialog interface for the model selection dialog
 42type ModelDialog interface {
 43	dialogs.DialogModel
 44}
 45
 46type ModelOption struct {
 47	Provider provider.Provider
 48	Model    provider.Model
 49}
 50
 51type modelDialogCmp struct {
 52	width   int
 53	wWidth  int
 54	wHeight int
 55
 56	modelList list.ListModel
 57	keyMap    KeyMap
 58	help      help.Model
 59	modelType int
 60}
 61
 62func NewModelDialogCmp() ModelDialog {
 63	listKeyMap := list.DefaultKeyMap()
 64	keyMap := DefaultKeyMap()
 65
 66	listKeyMap.Down.SetEnabled(false)
 67	listKeyMap.Up.SetEnabled(false)
 68	listKeyMap.HalfPageDown.SetEnabled(false)
 69	listKeyMap.HalfPageUp.SetEnabled(false)
 70	listKeyMap.Home.SetEnabled(false)
 71	listKeyMap.End.SetEnabled(false)
 72
 73	listKeyMap.DownOneItem = keyMap.Next
 74	listKeyMap.UpOneItem = keyMap.Previous
 75
 76	t := styles.CurrentTheme()
 77	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 78	modelList := list.New(
 79		list.WithFilterable(true),
 80		list.WithKeyMap(listKeyMap),
 81		list.WithInputStyle(inputStyle),
 82		list.WithWrapNavigation(true),
 83	)
 84	help := help.New()
 85	help.Styles = t.S().Help
 86
 87	return &modelDialogCmp{
 88		modelList: modelList,
 89		width:     defaultWidth,
 90		keyMap:    DefaultKeyMap(),
 91		help:      help,
 92		modelType: LargeModelType,
 93	}
 94}
 95
 96func (m *modelDialogCmp) Init() tea.Cmd {
 97	m.SetModelType(m.modelType)
 98	return m.modelList.Init()
 99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102	switch msg := msg.(type) {
103	case tea.WindowSizeMsg:
104		m.wWidth = msg.Width
105		m.wHeight = msg.Height
106		m.SetModelType(m.modelType)
107		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
108	case tea.KeyPressMsg:
109		switch {
110		case key.Matches(msg, m.keyMap.Select):
111			selectedItemInx := m.modelList.SelectedIndex()
112			if selectedItemInx == list.NoSelection {
113				return m, nil
114			}
115			items := m.modelList.Items()
116			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
117
118			var modelType config.SelectedModelType
119			if m.modelType == LargeModelType {
120				modelType = config.SelectedModelTypeLarge
121			} else {
122				modelType = config.SelectedModelTypeSmall
123			}
124
125			return m, tea.Sequence(
126				util.CmdHandler(dialogs.CloseDialogMsg{}),
127				util.CmdHandler(ModelSelectedMsg{
128					Model: config.SelectedModel{
129						Model:    selectedItem.Model.ID,
130						Provider: string(selectedItem.Provider.ID),
131					},
132					ModelType: modelType,
133				}),
134			)
135		case key.Matches(msg, m.keyMap.Tab):
136			if m.modelType == LargeModelType {
137				return m, m.SetModelType(SmallModelType)
138			} else {
139				return m, m.SetModelType(LargeModelType)
140			}
141		case key.Matches(msg, m.keyMap.Close):
142			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
143		default:
144			u, cmd := m.modelList.Update(msg)
145			m.modelList = u.(list.ListModel)
146			return m, cmd
147		}
148	}
149	return m, nil
150}
151
152func (m *modelDialogCmp) View() tea.View {
153	t := styles.CurrentTheme()
154	listView := m.modelList.View()
155	radio := m.modelTypeRadio()
156	content := lipgloss.JoinVertical(
157		lipgloss.Left,
158		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
159		listView.String(),
160		"",
161		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
162	)
163	v := tea.NewView(m.style().Render(content))
164	if listView.Cursor() != nil {
165		c := m.moveCursor(listView.Cursor())
166		v.SetCursor(c)
167	}
168	return v
169}
170
171func (m *modelDialogCmp) style() lipgloss.Style {
172	t := styles.CurrentTheme()
173	return t.S().Base.
174		Width(m.width).
175		Border(lipgloss.RoundedBorder()).
176		BorderForeground(t.BorderFocus)
177}
178
179func (m *modelDialogCmp) listWidth() int {
180	return defaultWidth - 2 // 4 for padding
181}
182
183func (m *modelDialogCmp) listHeight() int {
184	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
185	return min(listHeigh, m.wHeight/2)
186}
187
188func (m *modelDialogCmp) Position() (int, int) {
189	row := m.wHeight/4 - 2 // just a bit above the center
190	col := m.wWidth / 2
191	col -= m.width / 2
192	return row, col
193}
194
195func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
196	row, col := m.Position()
197	offset := row + 3 // Border + title
198	cursor.Y += offset
199	cursor.X = cursor.X + col + 2
200	return cursor
201}
202
203func (m *modelDialogCmp) ID() dialogs.DialogID {
204	return ModelsDialogID
205}
206
207func (m *modelDialogCmp) modelTypeRadio() string {
208	t := styles.CurrentTheme()
209	choices := []string{"Large Task", "Small Task"}
210	iconSelected := "◉"
211	iconUnselected := "○"
212	if m.modelType == LargeModelType {
213		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
214	}
215	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
216}
217
218func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
219	m.modelType = modelType
220
221	providers, err := config.Providers()
222	if err != nil {
223		return util.ReportError(err)
224	}
225
226	modelItems := []util.Model{}
227	selectIndex := 0
228
229	cfg := config.Get()
230	var currentModel config.SelectedModel
231	if m.modelType == LargeModelType {
232		currentModel = cfg.Models[config.SelectedModelTypeLarge]
233	} else {
234		currentModel = cfg.Models[config.SelectedModelTypeSmall]
235	}
236
237	// Create a map to track which providers we've already added
238	addedProviders := make(map[string]bool)
239
240	// First, add any configured providers that are not in the known providers list
241	// These should appear at the top of the list
242	knownProviders := provider.KnownProviders()
243	for providerID, providerConfig := range cfg.Providers {
244		if providerConfig.Disable {
245			continue
246		}
247
248		// Check if this provider is not in the known providers list
249		if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
250			// Convert config provider to provider.Provider format
251			configProvider := provider.Provider{
252				Name:   string(providerID), // Use provider ID as name for unknown providers
253				ID:     provider.InferenceProvider(providerID),
254				Models: make([]provider.Model, len(providerConfig.Models)),
255			}
256
257			// Convert models
258			for i, model := range providerConfig.Models {
259				configProvider.Models[i] = provider.Model{
260					ID:                     model.ID,
261					Name:                   model.Name,
262					CostPer1MIn:            model.CostPer1MIn,
263					CostPer1MOut:           model.CostPer1MOut,
264					CostPer1MInCached:      model.CostPer1MInCached,
265					CostPer1MOutCached:     model.CostPer1MOutCached,
266					ContextWindow:          model.ContextWindow,
267					DefaultMaxTokens:       model.DefaultMaxTokens,
268					CanReason:              model.CanReason,
269					HasReasoningEffort:     model.HasReasoningEffort,
270					DefaultReasoningEffort: model.DefaultReasoningEffort,
271					SupportsImages:         model.SupportsImages,
272				}
273			}
274
275			// Add this unknown provider to the list
276			name := configProvider.Name
277			if name == "" {
278				name = string(configProvider.ID)
279			}
280			modelItems = append(modelItems, commands.NewItemSection(name))
281			for _, model := range configProvider.Models {
282				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
283					Provider: configProvider,
284					Model:    model,
285				}))
286				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
287					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
288				}
289			}
290			addedProviders[providerID] = true
291		}
292	}
293
294	// Then add the known providers from the predefined list
295	for _, provider := range providers {
296		// Skip if we already added this provider as an unknown provider
297		if addedProviders[string(provider.ID)] {
298			continue
299		}
300
301		// Check if this provider is configured and not disabled
302		if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
303			continue
304		}
305
306		name := provider.Name
307		if name == "" {
308			name = string(provider.ID)
309		}
310		modelItems = append(modelItems, commands.NewItemSection(name))
311		for _, model := range provider.Models {
312			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
313				Provider: provider,
314				Model:    model,
315			}))
316			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
317				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
318			}
319		}
320	}
321
322	return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
323}