1package models
2
import (
	"cmp"
	"slices"

	"github.com/charmbracelet/bubbles/v2/help"
	"github.com/charmbracelet/bubbles/v2/key"
	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/tui/components/completions"
	"github.com/charmbracelet/crush/internal/tui/components/core"
	"github.com/charmbracelet/crush/internal/tui/components/core/list"
	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
	"github.com/charmbracelet/crush/internal/tui/styles"
	"github.com/charmbracelet/crush/internal/tui/util"
	"github.com/charmbracelet/lipgloss/v2"
)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
27// ModelSelectedMsg is sent when a model is selected
28type ModelSelectedMsg struct {
29 Model models.Model
30}
31
32// CloseModelDialogMsg is sent when a model is selected
33type CloseModelDialogMsg struct{}
34
35// ModelDialog interface for the model selection dialog
36type ModelDialog interface {
37 dialogs.DialogModel
38}
39
40type modelDialogCmp struct {
41 width int
42 wWidth int // Width of the terminal window
43 wHeight int // Height of the terminal window
44
45 modelList list.ListModel
46 keyMap KeyMap
47 help help.Model
48}
49
50func NewModelDialogCmp() ModelDialog {
51 listKeyMap := list.DefaultKeyMap()
52 keyMap := DefaultKeyMap()
53
54 listKeyMap.Down.SetEnabled(false)
55 listKeyMap.Up.SetEnabled(false)
56 listKeyMap.NDown.SetEnabled(false)
57 listKeyMap.NUp.SetEnabled(false)
58 listKeyMap.HalfPageDown.SetEnabled(false)
59 listKeyMap.HalfPageUp.SetEnabled(false)
60 listKeyMap.Home.SetEnabled(false)
61 listKeyMap.End.SetEnabled(false)
62
63 listKeyMap.DownOneItem = keyMap.Next
64 listKeyMap.UpOneItem = keyMap.Previous
65
66 t := styles.CurrentTheme()
67 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
68 modelList := list.New(
69 list.WithFilterable(true),
70 list.WithKeyMap(listKeyMap),
71 list.WithInputStyle(inputStyle),
72 list.WithWrapNavigation(true),
73 )
74 help := help.New()
75 help.Styles = t.S().Help
76
77 return &modelDialogCmp{
78 modelList: modelList,
79 width: defaultWidth,
80 keyMap: DefaultKeyMap(),
81 help: help,
82 }
83}
84
85var ProviderPopularity = map[models.ModelProvider]int{
86 models.ProviderAnthropic: 1,
87 models.ProviderOpenAI: 2,
88 models.ProviderGemini: 3,
89 models.ProviderGROQ: 4,
90 models.ProviderOpenRouter: 5,
91 models.ProviderBedrock: 6,
92 models.ProviderAzure: 7,
93 models.ProviderVertexAI: 8,
94 models.ProviderXAI: 9,
95}
96
97var ProviderName = map[models.ModelProvider]string{
98 models.ProviderAnthropic: "Anthropic",
99 models.ProviderOpenAI: "OpenAI",
100 models.ProviderGemini: "Gemini",
101 models.ProviderGROQ: "Groq",
102 models.ProviderOpenRouter: "OpenRouter",
103 models.ProviderBedrock: "AWS Bedrock",
104 models.ProviderAzure: "Azure",
105 models.ProviderVertexAI: "VertexAI",
106 models.ProviderXAI: "xAI",
107}
108
109func (m *modelDialogCmp) Init() tea.Cmd {
110 cfg := config.Get()
111 enabledProviders := getEnabledProviders(cfg)
112
113 modelItems := []util.Model{}
114 for _, provider := range enabledProviders {
115 name, ok := ProviderName[provider]
116 if !ok {
117 name = string(provider) // Fallback to provider ID if name is not defined
118 }
119 modelItems = append(modelItems, commands.NewItemSection(name))
120 for _, model := range getModelsForProvider(provider) {
121 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
122 }
123 }
124 m.modelList.SetItems(modelItems)
125 return m.modelList.Init()
126}
127
128func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
129 switch msg := msg.(type) {
130 case tea.WindowSizeMsg:
131 m.wWidth = msg.Width
132 m.wHeight = msg.Height
133 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
134 case tea.KeyPressMsg:
135 switch {
136 case key.Matches(msg, m.keyMap.Select):
137 selectedItemInx := m.modelList.SelectedIndex()
138 if selectedItemInx == list.NoSelection {
139 return m, nil // No item selected, do nothing
140 }
141 items := m.modelList.Items()
142 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
143
144 return m, tea.Sequence(
145 util.CmdHandler(dialogs.CloseDialogMsg{}),
146 util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
147 )
148 case key.Matches(msg, m.keyMap.Close):
149 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
150 default:
151 u, cmd := m.modelList.Update(msg)
152 m.modelList = u.(list.ListModel)
153 return m, cmd
154 }
155 }
156 return m, nil
157}
158
159func (m *modelDialogCmp) View() tea.View {
160 t := styles.CurrentTheme()
161 listView := m.modelList.View()
162 content := lipgloss.JoinVertical(
163 lipgloss.Left,
164 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
165 listView.String(),
166 "",
167 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
168 )
169 v := tea.NewView(m.style().Render(content))
170 if listView.Cursor() != nil {
171 c := m.moveCursor(listView.Cursor())
172 v.SetCursor(c)
173 }
174 return v
175}
176
177func (m *modelDialogCmp) style() lipgloss.Style {
178 t := styles.CurrentTheme()
179 return t.S().Base.
180 Width(m.width).
181 Border(lipgloss.RoundedBorder()).
182 BorderForeground(t.BorderFocus)
183}
184
185func (m *modelDialogCmp) listWidth() int {
186 return defaultWidth - 2 // 4 for padding
187}
188
189func (m *modelDialogCmp) listHeight() int {
190 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
191 return min(listHeigh, m.wHeight/2)
192}
193
194func GetSelectedModel(cfg *config.Config) models.Model {
195 agentCfg := cfg.Agents[config.AgentCoder]
196 selectedModelID := agentCfg.Model
197 return models.SupportedModels[selectedModelID]
198}
199
200func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
201 var providers []models.ModelProvider
202 for providerID, provider := range cfg.Providers {
203 if !provider.Disabled {
204 providers = append(providers, providerID)
205 }
206 }
207
208 // Sort by provider popularity
209 slices.SortFunc(providers, func(a, b models.ModelProvider) int {
210 rA := ProviderPopularity[a]
211 rB := ProviderPopularity[b]
212
213 // models not included in popularity ranking default to last
214 if rA == 0 {
215 rA = 999
216 }
217 if rB == 0 {
218 rB = 999
219 }
220 return rA - rB
221 })
222 return providers
223}
224
225func getModelsForProvider(provider models.ModelProvider) []models.Model {
226 var providerModels []models.Model
227 for _, model := range models.SupportedModels {
228 if model.Provider == provider {
229 providerModels = append(providerModels, model)
230 }
231 }
232
233 // reverse alphabetical order (if llm naming was consistent latest would appear first)
234 slices.SortFunc(providerModels, func(a, b models.Model) int {
235 if a.Name > b.Name {
236 return -1
237 } else if a.Name < b.Name {
238 return 1
239 }
240 return 0
241 })
242
243 return providerModels
244}
245
246func (m *modelDialogCmp) Position() (int, int) {
247 row := m.wHeight/4 - 2 // just a bit above the center
248 col := m.wWidth / 2
249 col -= m.width / 2
250 return row, col
251}
252
253func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
254 row, col := m.Position()
255 offset := row + 3 // Border + title
256 cursor.Y += offset
257 cursor.X = cursor.X + col + 2
258 return cursor
259}
260
261func (m *modelDialogCmp) ID() dialogs.DialogID {
262 return ModelsDialogID
263}