1package models
2
3import (
4 "slices"
5
6 "github.com/charmbracelet/bubbles/v2/help"
7 "github.com/charmbracelet/bubbles/v2/key"
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/llm/models"
11 "github.com/charmbracelet/crush/internal/tui/components/completions"
12 "github.com/charmbracelet/crush/internal/tui/components/core"
13 "github.com/charmbracelet/crush/internal/tui/components/core/list"
14 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
15 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
16 "github.com/charmbracelet/crush/internal/tui/styles"
17 "github.com/charmbracelet/crush/internal/tui/util"
18 "github.com/charmbracelet/lipgloss/v2"
19)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
27// ModelSelectedMsg is sent when a model is selected
28type ModelSelectedMsg struct {
29 Model models.Model
30}
31
// CloseModelDialogMsg signals that the model dialog should be closed.
//
// NOTE(review): the dialog in this file closes itself with
// dialogs.CloseDialogMsg and never sends this type; confirm whether it is
// still used elsewhere.
type CloseModelDialogMsg struct{}
34
35// ModelDialog interface for the model selection dialog
36type ModelDialog interface {
37 dialogs.DialogModel
38}
39
40type modelDialogCmp struct {
41 width int
42 wWidth int // Width of the terminal window
43 wHeight int // Height of the terminal window
44
45 modelList list.ListModel
46 keyMap KeyMap
47 help help.Model
48}
49
50func NewModelDialogCmp() ModelDialog {
51 listKeyMap := list.DefaultKeyMap()
52 keyMap := DefaultKeyMap()
53
54 listKeyMap.Down.SetEnabled(false)
55 listKeyMap.Up.SetEnabled(false)
56 listKeyMap.HalfPageDown.SetEnabled(false)
57 listKeyMap.HalfPageUp.SetEnabled(false)
58 listKeyMap.Home.SetEnabled(false)
59 listKeyMap.End.SetEnabled(false)
60
61 listKeyMap.DownOneItem = keyMap.Next
62 listKeyMap.UpOneItem = keyMap.Previous
63
64 t := styles.CurrentTheme()
65 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
66 modelList := list.New(
67 list.WithFilterable(true),
68 list.WithKeyMap(listKeyMap),
69 list.WithInputStyle(inputStyle),
70 list.WithWrapNavigation(true),
71 )
72 help := help.New()
73 help.Styles = t.S().Help
74
75 return &modelDialogCmp{
76 modelList: modelList,
77 width: defaultWidth,
78 keyMap: DefaultKeyMap(),
79 help: help,
80 }
81}
82
83var ProviderPopularity = map[models.ModelProvider]int{
84 models.ProviderAnthropic: 1,
85 models.ProviderOpenAI: 2,
86 models.ProviderGemini: 3,
87 models.ProviderGROQ: 4,
88 models.ProviderOpenRouter: 5,
89 models.ProviderBedrock: 6,
90 models.ProviderAzure: 7,
91 models.ProviderVertexAI: 8,
92 models.ProviderXAI: 9,
93}
94
95var ProviderName = map[models.ModelProvider]string{
96 models.ProviderAnthropic: "Anthropic",
97 models.ProviderOpenAI: "OpenAI",
98 models.ProviderGemini: "Gemini",
99 models.ProviderGROQ: "Groq",
100 models.ProviderOpenRouter: "OpenRouter",
101 models.ProviderBedrock: "AWS Bedrock",
102 models.ProviderAzure: "Azure",
103 models.ProviderVertexAI: "VertexAI",
104 models.ProviderXAI: "xAI",
105}
106
107func (m *modelDialogCmp) Init() tea.Cmd {
108 cfg := config.Get()
109 enabledProviders := getEnabledProviders(cfg)
110
111 modelItems := []util.Model{}
112 for _, provider := range enabledProviders {
113 name, ok := ProviderName[provider]
114 if !ok {
115 name = string(provider) // Fallback to provider ID if name is not defined
116 }
117 modelItems = append(modelItems, commands.NewItemSection(name))
118 for _, model := range getModelsForProvider(provider) {
119 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
120 }
121 }
122 m.modelList.SetItems(modelItems)
123 return m.modelList.Init()
124}
125
126func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
127 switch msg := msg.(type) {
128 case tea.WindowSizeMsg:
129 m.wWidth = msg.Width
130 m.wHeight = msg.Height
131 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
132 case tea.KeyPressMsg:
133 switch {
134 case key.Matches(msg, m.keyMap.Select):
135 selectedItemInx := m.modelList.SelectedIndex()
136 if selectedItemInx == list.NoSelection {
137 return m, nil // No item selected, do nothing
138 }
139 items := m.modelList.Items()
140 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
141
142 return m, tea.Sequence(
143 util.CmdHandler(dialogs.CloseDialogMsg{}),
144 util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
145 )
146 case key.Matches(msg, m.keyMap.Close):
147 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
148 default:
149 u, cmd := m.modelList.Update(msg)
150 m.modelList = u.(list.ListModel)
151 return m, cmd
152 }
153 }
154 return m, nil
155}
156
157func (m *modelDialogCmp) View() string {
158 t := styles.CurrentTheme()
159 listView := m.modelList.View()
160 content := lipgloss.JoinVertical(
161 lipgloss.Left,
162 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
163 listView,
164 "",
165 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
166 )
167 return m.style().Render(content)
168}
169
170func (m *modelDialogCmp) Cursor() *tea.Cursor {
171 if cursor, ok := m.modelList.(util.Cursor); ok {
172 cursor := cursor.Cursor()
173 if cursor != nil {
174 cursor = m.moveCursor(cursor)
175 return cursor
176 }
177 }
178 return nil
179}
180
181func (m *modelDialogCmp) style() lipgloss.Style {
182 t := styles.CurrentTheme()
183 return t.S().Base.
184 Width(m.width).
185 Border(lipgloss.RoundedBorder()).
186 BorderForeground(t.BorderFocus)
187}
188
189func (m *modelDialogCmp) listWidth() int {
190 return defaultWidth - 2 // 4 for padding
191}
192
193func (m *modelDialogCmp) listHeight() int {
194 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
195 return min(listHeigh, m.wHeight/2)
196}
197
198func GetSelectedModel(cfg *config.Config) models.Model {
199 agentCfg := cfg.Agents[config.AgentCoder]
200 selectedModelID := agentCfg.Model
201 return models.SupportedModels[selectedModelID]
202}
203
204func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
205 var providers []models.ModelProvider
206 for providerID, provider := range cfg.Providers {
207 if !provider.Disabled {
208 providers = append(providers, providerID)
209 }
210 }
211
212 // Sort by provider popularity
213 slices.SortFunc(providers, func(a, b models.ModelProvider) int {
214 rA := ProviderPopularity[a]
215 rB := ProviderPopularity[b]
216
217 // models not included in popularity ranking default to last
218 if rA == 0 {
219 rA = 999
220 }
221 if rB == 0 {
222 rB = 999
223 }
224 return rA - rB
225 })
226 return providers
227}
228
// getModelsForProvider collects every models.SupportedModels entry whose
// Provider equals provider. The result is sorted by Name in descending
// lexical order; the intent is that, when model names are versioned
// consistently, the newest model is listed first.
func getModelsForProvider(provider models.ModelProvider) []models.Model {
	var matches []models.Model
	for _, candidate := range models.SupportedModels {
		if candidate.Provider != provider {
			continue
		}
		matches = append(matches, candidate)
	}

	slices.SortFunc(matches, func(x, y models.Model) int {
		switch {
		case x.Name < y.Name:
			return 1
		case x.Name > y.Name:
			return -1
		default:
			return 0
		}
	})
	return matches
}
249
250func (m *modelDialogCmp) Position() (int, int) {
251 row := m.wHeight/4 - 2 // just a bit above the center
252 col := m.wWidth / 2
253 col -= m.width / 2
254 return row, col
255}
256
257func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
258 row, col := m.Position()
259 offset := row + 3 // Border + title
260 cursor.Y += offset
261 cursor.X = cursor.X + col + 2
262 return cursor
263}
264
265func (m *modelDialogCmp) ID() dialogs.DialogID {
266 return ModelsDialogID
267}