models.go

  1package models
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/bubbles/v2/help"
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/models"
 11	"github.com/charmbracelet/crush/internal/tui/components/completions"
 12	"github.com/charmbracelet/crush/internal/tui/components/core"
 13	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27// ModelSelectedMsg is sent when a model is selected
 28type ModelSelectedMsg struct {
 29	Model models.Model
 30}
 31
 32// CloseModelDialogMsg is sent when a model is selected
 33type CloseModelDialogMsg struct{}
 34
 35// ModelDialog interface for the model selection dialog
 36type ModelDialog interface {
 37	dialogs.DialogModel
 38}
 39
 40type modelDialogCmp struct {
 41	width   int
 42	wWidth  int // Width of the terminal window
 43	wHeight int // Height of the terminal window
 44
 45	modelList list.ListModel
 46	keyMap    KeyMap
 47	help      help.Model
 48}
 49
 50func NewModelDialogCmp() ModelDialog {
 51	listKeyMap := list.DefaultKeyMap()
 52	keyMap := DefaultKeyMap()
 53
 54	listKeyMap.Down.SetEnabled(false)
 55	listKeyMap.Up.SetEnabled(false)
 56	listKeyMap.HalfPageDown.SetEnabled(false)
 57	listKeyMap.HalfPageUp.SetEnabled(false)
 58	listKeyMap.Home.SetEnabled(false)
 59	listKeyMap.End.SetEnabled(false)
 60
 61	listKeyMap.DownOneItem = keyMap.Next
 62	listKeyMap.UpOneItem = keyMap.Previous
 63
 64	t := styles.CurrentTheme()
 65	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 66	modelList := list.New(
 67		list.WithFilterable(true),
 68		list.WithKeyMap(listKeyMap),
 69		list.WithInputStyle(inputStyle),
 70		list.WithWrapNavigation(true),
 71	)
 72	help := help.New()
 73	help.Styles = t.S().Help
 74
 75	return &modelDialogCmp{
 76		modelList: modelList,
 77		width:     defaultWidth,
 78		keyMap:    DefaultKeyMap(),
 79		help:      help,
 80	}
 81}
 82
 83var ProviderPopularity = map[models.InferenceProvider]int{
 84	models.ProviderAnthropic:  1,
 85	models.ProviderOpenAI:     2,
 86	models.ProviderGemini:     3,
 87	models.ProviderGROQ:       4,
 88	models.ProviderOpenRouter: 5,
 89	models.ProviderBedrock:    6,
 90	models.ProviderAzure:      7,
 91	models.ProviderVertexAI:   8,
 92	models.ProviderXAI:        9,
 93}
 94
 95var ProviderName = map[models.InferenceProvider]string{
 96	models.ProviderAnthropic:  "Anthropic",
 97	models.ProviderOpenAI:     "OpenAI",
 98	models.ProviderGemini:     "Gemini",
 99	models.ProviderGROQ:       "Groq",
100	models.ProviderOpenRouter: "OpenRouter",
101	models.ProviderBedrock:    "AWS Bedrock",
102	models.ProviderAzure:      "Azure",
103	models.ProviderVertexAI:   "VertexAI",
104	models.ProviderXAI:        "xAI",
105}
106
107func (m *modelDialogCmp) Init() tea.Cmd {
108	cfg := config.Get()
109	enabledProviders := getEnabledProviders(cfg)
110
111	modelItems := []util.Model{}
112	for _, provider := range enabledProviders {
113		name, ok := ProviderName[provider]
114		if !ok {
115			name = string(provider) // Fallback to provider ID if name is not defined
116		}
117		modelItems = append(modelItems, commands.NewItemSection(name))
118		for _, model := range getModelsForProvider(provider) {
119			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
120		}
121	}
122	m.modelList.SetItems(modelItems)
123	return m.modelList.Init()
124}
125
126func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
127	switch msg := msg.(type) {
128	case tea.WindowSizeMsg:
129		m.wWidth = msg.Width
130		m.wHeight = msg.Height
131		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
132	case tea.KeyPressMsg:
133		switch {
134		case key.Matches(msg, m.keyMap.Select):
135			selectedItemInx := m.modelList.SelectedIndex()
136			if selectedItemInx == list.NoSelection {
137				return m, nil // No item selected, do nothing
138			}
139			items := m.modelList.Items()
140			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
141
142			return m, tea.Sequence(
143				util.CmdHandler(dialogs.CloseDialogMsg{}),
144				util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
145			)
146		case key.Matches(msg, m.keyMap.Close):
147			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
148		default:
149			u, cmd := m.modelList.Update(msg)
150			m.modelList = u.(list.ListModel)
151			return m, cmd
152		}
153	}
154	return m, nil
155}
156
157func (m *modelDialogCmp) View() tea.View {
158	t := styles.CurrentTheme()
159	listView := m.modelList.View()
160	content := lipgloss.JoinVertical(
161		lipgloss.Left,
162		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
163		listView.String(),
164		"",
165		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
166	)
167	v := tea.NewView(m.style().Render(content))
168	if listView.Cursor() != nil {
169		c := m.moveCursor(listView.Cursor())
170		v.SetCursor(c)
171	}
172	return v
173}
174
175func (m *modelDialogCmp) style() lipgloss.Style {
176	t := styles.CurrentTheme()
177	return t.S().Base.
178		Width(m.width).
179		Border(lipgloss.RoundedBorder()).
180		BorderForeground(t.BorderFocus)
181}
182
183func (m *modelDialogCmp) listWidth() int {
184	return defaultWidth - 2 // 4 for padding
185}
186
187func (m *modelDialogCmp) listHeight() int {
188	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
189	return min(listHeigh, m.wHeight/2)
190}
191
192func GetSelectedModel(cfg *config.Config) models.Model {
193	agentCfg := cfg.Agents[config.AgentCoder]
194	selectedModelID := agentCfg.Model
195	return models.SupportedModels[selectedModelID]
196}
197
198func getEnabledProviders(cfg *config.Config) []models.InferenceProvider {
199	var providers []models.InferenceProvider
200	for providerID, provider := range cfg.Providers {
201		if !provider.Disabled {
202			providers = append(providers, providerID)
203		}
204	}
205
206	// Sort by provider popularity
207	slices.SortFunc(providers, func(a, b models.InferenceProvider) int {
208		rA := ProviderPopularity[a]
209		rB := ProviderPopularity[b]
210
211		// models not included in popularity ranking default to last
212		if rA == 0 {
213			rA = 999
214		}
215		if rB == 0 {
216			rB = 999
217		}
218		return rA - rB
219	})
220	return providers
221}
222
223func getModelsForProvider(provider models.InferenceProvider) []models.Model {
224	var providerModels []models.Model
225	for _, model := range models.SupportedModels {
226		if model.Provider == provider {
227			providerModels = append(providerModels, model)
228		}
229	}
230
231	// reverse alphabetical order (if llm naming was consistent latest would appear first)
232	slices.SortFunc(providerModels, func(a, b models.Model) int {
233		if a.Name > b.Name {
234			return -1
235		} else if a.Name < b.Name {
236			return 1
237		}
238		return 0
239	})
240
241	return providerModels
242}
243
244func (m *modelDialogCmp) Position() (int, int) {
245	row := m.wHeight/4 - 2 // just a bit above the center
246	col := m.wWidth / 2
247	col -= m.width / 2
248	return row, col
249}
250
251func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
252	row, col := m.Position()
253	offset := row + 3 // Border + title
254	cursor.Y += offset
255	cursor.X = cursor.X + col + 2
256	return cursor
257}
258
259func (m *modelDialogCmp) ID() dialogs.DialogID {
260	return ModelsDialogID
261}