1package models
2
import (
	"cmp"
	"slices"

	"github.com/charmbracelet/bubbles/v2/help"
	"github.com/charmbracelet/bubbles/v2/key"
	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/tui/components/completions"
	"github.com/charmbracelet/crush/internal/tui/components/core"
	"github.com/charmbracelet/crush/internal/tui/components/core/list"
	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
	"github.com/charmbracelet/crush/internal/tui/styles"
	"github.com/charmbracelet/crush/internal/tui/util"
	"github.com/charmbracelet/lipgloss/v2"
)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
27// ModelSelectedMsg is sent when a model is selected
28type ModelSelectedMsg struct {
29 Model models.Model
30}
31
32// CloseModelDialogMsg is sent when a model is selected
33type CloseModelDialogMsg struct{}
34
35// ModelDialog interface for the model selection dialog
36type ModelDialog interface {
37 dialogs.DialogModel
38}
39
40type modelDialogCmp struct {
41 width int
42 wWidth int // Width of the terminal window
43 wHeight int // Height of the terminal window
44
45 modelList list.ListModel
46 keyMap KeyMap
47 help help.Model
48}
49
50func NewModelDialogCmp() ModelDialog {
51 listKeyMap := list.DefaultKeyMap()
52 keyMap := DefaultKeyMap()
53
54 listKeyMap.Down.SetEnabled(false)
55 listKeyMap.Up.SetEnabled(false)
56 listKeyMap.HalfPageDown.SetEnabled(false)
57 listKeyMap.HalfPageUp.SetEnabled(false)
58 listKeyMap.Home.SetEnabled(false)
59 listKeyMap.End.SetEnabled(false)
60
61 listKeyMap.DownOneItem = keyMap.Next
62 listKeyMap.UpOneItem = keyMap.Previous
63
64 t := styles.CurrentTheme()
65 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
66 modelList := list.New(
67 list.WithFilterable(true),
68 list.WithKeyMap(listKeyMap),
69 list.WithInputStyle(inputStyle),
70 list.WithWrapNavigation(true),
71 )
72 help := help.New()
73 help.Styles = t.S().Help
74
75 return &modelDialogCmp{
76 modelList: modelList,
77 width: defaultWidth,
78 keyMap: DefaultKeyMap(),
79 help: help,
80 }
81}
82
83var ProviderPopularity = map[models.InferenceProvider]int{
84 models.ProviderAnthropic: 1,
85 models.ProviderOpenAI: 2,
86 models.ProviderGemini: 3,
87 models.ProviderGROQ: 4,
88 models.ProviderOpenRouter: 5,
89 models.ProviderBedrock: 6,
90 models.ProviderAzure: 7,
91 models.ProviderVertexAI: 8,
92 models.ProviderXAI: 9,
93}
94
95var ProviderName = map[models.InferenceProvider]string{
96 models.ProviderAnthropic: "Anthropic",
97 models.ProviderOpenAI: "OpenAI",
98 models.ProviderGemini: "Gemini",
99 models.ProviderGROQ: "Groq",
100 models.ProviderOpenRouter: "OpenRouter",
101 models.ProviderBedrock: "AWS Bedrock",
102 models.ProviderAzure: "Azure",
103 models.ProviderVertexAI: "VertexAI",
104 models.ProviderXAI: "xAI",
105}
106
107func (m *modelDialogCmp) Init() tea.Cmd {
108 cfg := config.Get()
109 enabledProviders := getEnabledProviders(cfg)
110
111 modelItems := []util.Model{}
112 for _, provider := range enabledProviders {
113 name, ok := ProviderName[provider]
114 if !ok {
115 name = string(provider) // Fallback to provider ID if name is not defined
116 }
117 modelItems = append(modelItems, commands.NewItemSection(name))
118 for _, model := range getModelsForProvider(provider) {
119 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
120 }
121 }
122 m.modelList.SetItems(modelItems)
123 return m.modelList.Init()
124}
125
126func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
127 switch msg := msg.(type) {
128 case tea.WindowSizeMsg:
129 m.wWidth = msg.Width
130 m.wHeight = msg.Height
131 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
132 case tea.KeyPressMsg:
133 switch {
134 case key.Matches(msg, m.keyMap.Select):
135 selectedItemInx := m.modelList.SelectedIndex()
136 if selectedItemInx == list.NoSelection {
137 return m, nil // No item selected, do nothing
138 }
139 items := m.modelList.Items()
140 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
141
142 return m, tea.Sequence(
143 util.CmdHandler(dialogs.CloseDialogMsg{}),
144 util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
145 )
146 case key.Matches(msg, m.keyMap.Close):
147 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
148 default:
149 u, cmd := m.modelList.Update(msg)
150 m.modelList = u.(list.ListModel)
151 return m, cmd
152 }
153 }
154 return m, nil
155}
156
157func (m *modelDialogCmp) View() tea.View {
158 t := styles.CurrentTheme()
159 listView := m.modelList.View()
160 content := lipgloss.JoinVertical(
161 lipgloss.Left,
162 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
163 listView.String(),
164 "",
165 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
166 )
167 v := tea.NewView(m.style().Render(content))
168 if listView.Cursor() != nil {
169 c := m.moveCursor(listView.Cursor())
170 v.SetCursor(c)
171 }
172 return v
173}
174
175func (m *modelDialogCmp) style() lipgloss.Style {
176 t := styles.CurrentTheme()
177 return t.S().Base.
178 Width(m.width).
179 Border(lipgloss.RoundedBorder()).
180 BorderForeground(t.BorderFocus)
181}
182
183func (m *modelDialogCmp) listWidth() int {
184 return defaultWidth - 2 // 4 for padding
185}
186
187func (m *modelDialogCmp) listHeight() int {
188 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
189 return min(listHeigh, m.wHeight/2)
190}
191
192func GetSelectedModel(cfg *config.Config) models.Model {
193 agentCfg := cfg.Agents[config.AgentCoder]
194 selectedModelID := agentCfg.Model
195 return models.SupportedModels[selectedModelID]
196}
197
198func getEnabledProviders(cfg *config.Config) []models.InferenceProvider {
199 var providers []models.InferenceProvider
200 for providerID, provider := range cfg.Providers {
201 if !provider.Disabled {
202 providers = append(providers, providerID)
203 }
204 }
205
206 // Sort by provider popularity
207 slices.SortFunc(providers, func(a, b models.InferenceProvider) int {
208 rA := ProviderPopularity[a]
209 rB := ProviderPopularity[b]
210
211 // models not included in popularity ranking default to last
212 if rA == 0 {
213 rA = 999
214 }
215 if rB == 0 {
216 rB = 999
217 }
218 return rA - rB
219 })
220 return providers
221}
222
223func getModelsForProvider(provider models.InferenceProvider) []models.Model {
224 var providerModels []models.Model
225 for _, model := range models.SupportedModels {
226 if model.Provider == provider {
227 providerModels = append(providerModels, model)
228 }
229 }
230
231 // reverse alphabetical order (if llm naming was consistent latest would appear first)
232 slices.SortFunc(providerModels, func(a, b models.Model) int {
233 if a.Name > b.Name {
234 return -1
235 } else if a.Name < b.Name {
236 return 1
237 }
238 return 0
239 })
240
241 return providerModels
242}
243
244func (m *modelDialogCmp) Position() (int, int) {
245 row := m.wHeight/4 - 2 // just a bit above the center
246 col := m.wWidth / 2
247 col -= m.width / 2
248 return row, col
249}
250
251func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
252 row, col := m.Position()
253 offset := row + 3 // Border + title
254 cursor.Y += offset
255 cursor.X = cursor.X + col + 2
256 return cursor
257}
258
259func (m *modelDialogCmp) ID() dialogs.DialogID {
260 return ModelsDialogID
261}