models.go

  1package models
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/bubbles/v2/help"
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/models"
 11	"github.com/charmbracelet/crush/internal/tui/components/completions"
 12	"github.com/charmbracelet/crush/internal/tui/components/core"
 13	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27// ModelSelectedMsg is sent when a model is selected
 28type ModelSelectedMsg struct {
 29	Model models.Model
 30}
 31
 32// CloseModelDialogMsg is sent when a model is selected
 33type CloseModelDialogMsg struct{}
 34
 35// ModelDialog interface for the model selection dialog
 36type ModelDialog interface {
 37	dialogs.DialogModel
 38}
 39
 40type modelDialogCmp struct {
 41	width   int
 42	wWidth  int // Width of the terminal window
 43	wHeight int // Height of the terminal window
 44
 45	modelList list.ListModel
 46	keyMap    KeyMap
 47	help      help.Model
 48}
 49
 50func NewModelDialogCmp() ModelDialog {
 51	listKeyMap := list.DefaultKeyMap()
 52	keyMap := DefaultKeyMap()
 53
 54	listKeyMap.Down.SetEnabled(false)
 55	listKeyMap.Up.SetEnabled(false)
 56	listKeyMap.NDown.SetEnabled(false)
 57	listKeyMap.NUp.SetEnabled(false)
 58	listKeyMap.HalfPageDown.SetEnabled(false)
 59	listKeyMap.HalfPageUp.SetEnabled(false)
 60	listKeyMap.Home.SetEnabled(false)
 61	listKeyMap.End.SetEnabled(false)
 62
 63	listKeyMap.DownOneItem = keyMap.Next
 64	listKeyMap.UpOneItem = keyMap.Previous
 65
 66	t := styles.CurrentTheme()
 67	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 68	modelList := list.New(
 69		list.WithFilterable(true),
 70		list.WithKeyMap(listKeyMap),
 71		list.WithInputStyle(inputStyle),
 72		list.WithWrapNavigation(true),
 73	)
 74	help := help.New()
 75	help.Styles = t.S().Help
 76
 77	return &modelDialogCmp{
 78		modelList: modelList,
 79		width:     defaultWidth,
 80		keyMap:    DefaultKeyMap(),
 81		help:      help,
 82	}
 83}
 84
 85var ProviderPopularity = map[models.ModelProvider]int{
 86	models.ProviderAnthropic:  1,
 87	models.ProviderOpenAI:     2,
 88	models.ProviderGemini:     3,
 89	models.ProviderGROQ:       4,
 90	models.ProviderOpenRouter: 5,
 91	models.ProviderBedrock:    6,
 92	models.ProviderAzure:      7,
 93	models.ProviderVertexAI:   8,
 94	models.ProviderXAI:        9,
 95}
 96
 97var ProviderName = map[models.ModelProvider]string{
 98	models.ProviderAnthropic:  "Anthropic",
 99	models.ProviderOpenAI:     "OpenAI",
100	models.ProviderGemini:     "Gemini",
101	models.ProviderGROQ:       "Groq",
102	models.ProviderOpenRouter: "OpenRouter",
103	models.ProviderBedrock:    "AWS Bedrock",
104	models.ProviderAzure:      "Azure",
105	models.ProviderVertexAI:   "VertexAI",
106	models.ProviderXAI:        "xAI",
107}
108
109func (m *modelDialogCmp) Init() tea.Cmd {
110	cfg := config.Get()
111	enabledProviders := getEnabledProviders(cfg)
112
113	modelItems := []util.Model{}
114	for _, provider := range enabledProviders {
115		name, ok := ProviderName[provider]
116		if !ok {
117			name = string(provider) // Fallback to provider ID if name is not defined
118		}
119		modelItems = append(modelItems, commands.NewItemSection(name))
120		for _, model := range getModelsForProvider(provider) {
121			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
122		}
123	}
124	m.modelList.SetItems(modelItems)
125	return m.modelList.Init()
126}
127
128func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
129	switch msg := msg.(type) {
130	case tea.WindowSizeMsg:
131		m.wWidth = msg.Width
132		m.wHeight = msg.Height
133		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
134	case tea.KeyPressMsg:
135		switch {
136		case key.Matches(msg, m.keyMap.Select):
137			selectedItemInx := m.modelList.SelectedIndex()
138			if selectedItemInx == list.NoSelection {
139				return m, nil // No item selected, do nothing
140			}
141			items := m.modelList.Items()
142			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
143
144			return m, tea.Sequence(
145				util.CmdHandler(dialogs.CloseDialogMsg{}),
146				util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
147			)
148		case key.Matches(msg, m.keyMap.Close):
149			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
150		default:
151			u, cmd := m.modelList.Update(msg)
152			m.modelList = u.(list.ListModel)
153			return m, cmd
154		}
155	}
156	return m, nil
157}
158
159func (m *modelDialogCmp) View() tea.View {
160	t := styles.CurrentTheme()
161	listView := m.modelList.View()
162	content := lipgloss.JoinVertical(
163		lipgloss.Left,
164		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
165		listView.String(),
166		"",
167		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
168	)
169	v := tea.NewView(m.style().Render(content))
170	if listView.Cursor() != nil {
171		c := m.moveCursor(listView.Cursor())
172		v.SetCursor(c)
173	}
174	return v
175}
176
177func (m *modelDialogCmp) style() lipgloss.Style {
178	t := styles.CurrentTheme()
179	return t.S().Base.
180		Width(m.width).
181		Border(lipgloss.RoundedBorder()).
182		BorderForeground(t.BorderFocus)
183}
184
185func (m *modelDialogCmp) listWidth() int {
186	return defaultWidth - 2 // 4 for padding
187}
188
189func (m *modelDialogCmp) listHeight() int {
190	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
191	return min(listHeigh, m.wHeight/2)
192}
193
194func GetSelectedModel(cfg *config.Config) models.Model {
195	agentCfg := cfg.Agents[config.AgentCoder]
196	selectedModelID := agentCfg.Model
197	return models.SupportedModels[selectedModelID]
198}
199
200func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
201	var providers []models.ModelProvider
202	for providerID, provider := range cfg.Providers {
203		if !provider.Disabled {
204			providers = append(providers, providerID)
205		}
206	}
207
208	// Sort by provider popularity
209	slices.SortFunc(providers, func(a, b models.ModelProvider) int {
210		rA := ProviderPopularity[a]
211		rB := ProviderPopularity[b]
212
213		// models not included in popularity ranking default to last
214		if rA == 0 {
215			rA = 999
216		}
217		if rB == 0 {
218			rB = 999
219		}
220		return rA - rB
221	})
222	return providers
223}
224
225func getModelsForProvider(provider models.ModelProvider) []models.Model {
226	var providerModels []models.Model
227	for _, model := range models.SupportedModels {
228		if model.Provider == provider {
229			providerModels = append(providerModels, model)
230		}
231	}
232
233	// reverse alphabetical order (if llm naming was consistent latest would appear first)
234	slices.SortFunc(providerModels, func(a, b models.Model) int {
235		if a.Name > b.Name {
236			return -1
237		} else if a.Name < b.Name {
238			return 1
239		}
240		return 0
241	})
242
243	return providerModels
244}
245
246func (m *modelDialogCmp) Position() (int, int) {
247	row := m.wHeight/4 - 2 // just a bit above the center
248	col := m.wWidth / 2
249	col -= m.width / 2
250	return row, col
251}
252
253func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
254	row, col := m.Position()
255	offset := row + 3 // Border + title
256	cursor.Y += offset
257	cursor.X = cursor.X + col + 2
258	return cursor
259}
260
261func (m *modelDialogCmp) ID() dialogs.DialogID {
262	return ModelsDialogID
263}