models.go

  1package models
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/bubbles/v2/help"
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/models"
 11	"github.com/charmbracelet/crush/internal/tui/components/completions"
 12	"github.com/charmbracelet/crush/internal/tui/components/core"
 13	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27// ModelSelectedMsg is sent when a model is selected
 28type ModelSelectedMsg struct {
 29	Model models.Model
 30}
 31
 32// CloseModelDialogMsg is sent when a model is selected
 33type CloseModelDialogMsg struct{}
 34
 35// ModelDialog interface for the model selection dialog
 36type ModelDialog interface {
 37	dialogs.DialogModel
 38}
 39
 40type modelDialogCmp struct {
 41	width   int
 42	wWidth  int // Width of the terminal window
 43	wHeight int // Height of the terminal window
 44
 45	modelList list.ListModel
 46	keyMap    KeyMap
 47	help      help.Model
 48}
 49
 50func NewModelDialogCmp() ModelDialog {
 51	listKeyMap := list.DefaultKeyMap()
 52	keyMap := DefaultKeyMap()
 53
 54	listKeyMap.Down.SetEnabled(false)
 55	listKeyMap.Up.SetEnabled(false)
 56	listKeyMap.HalfPageDown.SetEnabled(false)
 57	listKeyMap.HalfPageUp.SetEnabled(false)
 58	listKeyMap.Home.SetEnabled(false)
 59	listKeyMap.End.SetEnabled(false)
 60
 61	listKeyMap.DownOneItem = keyMap.Next
 62	listKeyMap.UpOneItem = keyMap.Previous
 63
 64	t := styles.CurrentTheme()
 65	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 66	modelList := list.New(
 67		list.WithFilterable(true),
 68		list.WithKeyMap(listKeyMap),
 69		list.WithInputStyle(inputStyle),
 70		list.WithWrapNavigation(true),
 71	)
 72	help := help.New()
 73	help.Styles = t.S().Help
 74
 75	return &modelDialogCmp{
 76		modelList: modelList,
 77		width:     defaultWidth,
 78		keyMap:    DefaultKeyMap(),
 79		help:      help,
 80	}
 81}
 82
 83var ProviderPopularity = map[models.ModelProvider]int{
 84	models.ProviderAnthropic:  1,
 85	models.ProviderOpenAI:     2,
 86	models.ProviderGemini:     3,
 87	models.ProviderGROQ:       4,
 88	models.ProviderOpenRouter: 5,
 89	models.ProviderBedrock:    6,
 90	models.ProviderAzure:      7,
 91	models.ProviderVertexAI:   8,
 92	models.ProviderXAI:        9,
 93}
 94
 95var ProviderName = map[models.ModelProvider]string{
 96	models.ProviderAnthropic:  "Anthropic",
 97	models.ProviderOpenAI:     "OpenAI",
 98	models.ProviderGemini:     "Gemini",
 99	models.ProviderGROQ:       "Groq",
100	models.ProviderOpenRouter: "OpenRouter",
101	models.ProviderBedrock:    "AWS Bedrock",
102	models.ProviderAzure:      "Azure",
103	models.ProviderVertexAI:   "VertexAI",
104	models.ProviderXAI:        "xAI",
105}
106
107func (m *modelDialogCmp) Init() tea.Cmd {
108	cfg := config.Get()
109	enabledProviders := getEnabledProviders(cfg)
110
111	modelItems := []util.Model{}
112	for _, provider := range enabledProviders {
113		name, ok := ProviderName[provider]
114		if !ok {
115			name = string(provider) // Fallback to provider ID if name is not defined
116		}
117		modelItems = append(modelItems, commands.NewItemSection(name))
118		for _, model := range getModelsForProvider(provider) {
119			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
120		}
121	}
122	m.modelList.SetItems(modelItems)
123	return m.modelList.Init()
124}
125
126func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
127	switch msg := msg.(type) {
128	case tea.WindowSizeMsg:
129		m.wWidth = msg.Width
130		m.wHeight = msg.Height
131		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
132	case tea.KeyPressMsg:
133		switch {
134		case key.Matches(msg, m.keyMap.Select):
135			selectedItemInx := m.modelList.SelectedIndex()
136			if selectedItemInx == list.NoSelection {
137				return m, nil // No item selected, do nothing
138			}
139			items := m.modelList.Items()
140			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
141
142			return m, tea.Sequence(
143				util.CmdHandler(dialogs.CloseDialogMsg{}),
144				util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
145			)
146		case key.Matches(msg, m.keyMap.Close):
147			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
148		default:
149			u, cmd := m.modelList.Update(msg)
150			m.modelList = u.(list.ListModel)
151			return m, cmd
152		}
153	}
154	return m, nil
155}
156
157func (m *modelDialogCmp) View() string {
158	t := styles.CurrentTheme()
159	listView := m.modelList.View()
160	content := lipgloss.JoinVertical(
161		lipgloss.Left,
162		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
163		listView,
164		"",
165		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
166	)
167	return m.style().Render(content)
168}
169
170func (m *modelDialogCmp) Cursor() *tea.Cursor {
171	if cursor, ok := m.modelList.(util.Cursor); ok {
172		cursor := cursor.Cursor()
173		if cursor != nil {
174			cursor = m.moveCursor(cursor)
175			return cursor
176		}
177	}
178	return nil
179}
180
181func (m *modelDialogCmp) style() lipgloss.Style {
182	t := styles.CurrentTheme()
183	return t.S().Base.
184		Width(m.width).
185		Border(lipgloss.RoundedBorder()).
186		BorderForeground(t.BorderFocus)
187}
188
189func (m *modelDialogCmp) listWidth() int {
190	return defaultWidth - 2 // 4 for padding
191}
192
193func (m *modelDialogCmp) listHeight() int {
194	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
195	return min(listHeigh, m.wHeight/2)
196}
197
198func GetSelectedModel(cfg *config.Config) models.Model {
199	agentCfg := cfg.Agents[config.AgentCoder]
200	selectedModelID := agentCfg.Model
201	return models.SupportedModels[selectedModelID]
202}
203
204func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
205	var providers []models.ModelProvider
206	for providerID, provider := range cfg.Providers {
207		if !provider.Disabled {
208			providers = append(providers, providerID)
209		}
210	}
211
212	// Sort by provider popularity
213	slices.SortFunc(providers, func(a, b models.ModelProvider) int {
214		rA := ProviderPopularity[a]
215		rB := ProviderPopularity[b]
216
217		// models not included in popularity ranking default to last
218		if rA == 0 {
219			rA = 999
220		}
221		if rB == 0 {
222			rB = 999
223		}
224		return rA - rB
225	})
226	return providers
227}
228
229func getModelsForProvider(provider models.ModelProvider) []models.Model {
230	var providerModels []models.Model
231	for _, model := range models.SupportedModels {
232		if model.Provider == provider {
233			providerModels = append(providerModels, model)
234		}
235	}
236
237	// reverse alphabetical order (if llm naming was consistent latest would appear first)
238	slices.SortFunc(providerModels, func(a, b models.Model) int {
239		if a.Name > b.Name {
240			return -1
241		} else if a.Name < b.Name {
242			return 1
243		}
244		return 0
245	})
246
247	return providerModels
248}
249
250func (m *modelDialogCmp) Position() (int, int) {
251	row := m.wHeight/4 - 2 // just a bit above the center
252	col := m.wWidth / 2
253	col -= m.width / 2
254	return row, col
255}
256
257func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
258	row, col := m.Position()
259	offset := row + 3 // Border + title
260	cursor.Y += offset
261	cursor.X = cursor.X + col + 2
262	return cursor
263}
264
265func (m *modelDialogCmp) ID() dialogs.DialogID {
266	return ModelsDialogID
267}