1package models
2
3import (
4 "slices"
5
6 "github.com/charmbracelet/bubbles/v2/help"
7 "github.com/charmbracelet/bubbles/v2/key"
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/charmbracelet/crush/internal/tui/components/completions"
12 "github.com/charmbracelet/crush/internal/tui/components/core"
13 "github.com/charmbracelet/crush/internal/tui/components/core/list"
14 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
15 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
16 "github.com/charmbracelet/crush/internal/tui/styles"
17 "github.com/charmbracelet/crush/internal/tui/util"
18 "github.com/charmbracelet/lipgloss/v2"
19)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
// Model slots the dialog can edit. The active slot is stored in
// modelDialogCmp.modelType and toggled with the Tab key.
const (
	LargeModelType int = 0 // shown as "Large Task" in the dialog header
	SmallModelType int = 1 // shown as "Small Task" in the dialog header
)
31
32// ModelSelectedMsg is sent when a model is selected
33type ModelSelectedMsg struct {
34 Model config.SelectedModel
35 ModelType config.SelectedModelType
36}
37
// CloseModelDialogMsg is a message type for closing the model dialog.
//
// NOTE(review): this file never sends it; the dialog closes itself with
// dialogs.CloseDialogMsg. Confirm whether any caller still uses this type.
type CloseModelDialogMsg struct{}
40
41// ModelDialog interface for the model selection dialog
42type ModelDialog interface {
43 dialogs.DialogModel
44}
45
46type ModelOption struct {
47 Provider provider.Provider
48 Model provider.Model
49}
50
51type modelDialogCmp struct {
52 width int
53 wWidth int
54 wHeight int
55
56 modelList list.ListModel
57 keyMap KeyMap
58 help help.Model
59 modelType int
60}
61
62func NewModelDialogCmp() ModelDialog {
63 listKeyMap := list.DefaultKeyMap()
64 keyMap := DefaultKeyMap()
65
66 listKeyMap.Down.SetEnabled(false)
67 listKeyMap.Up.SetEnabled(false)
68 listKeyMap.HalfPageDown.SetEnabled(false)
69 listKeyMap.HalfPageUp.SetEnabled(false)
70 listKeyMap.Home.SetEnabled(false)
71 listKeyMap.End.SetEnabled(false)
72
73 listKeyMap.DownOneItem = keyMap.Next
74 listKeyMap.UpOneItem = keyMap.Previous
75
76 t := styles.CurrentTheme()
77 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
78 modelList := list.New(
79 list.WithFilterable(true),
80 list.WithKeyMap(listKeyMap),
81 list.WithInputStyle(inputStyle),
82 list.WithWrapNavigation(true),
83 )
84 help := help.New()
85 help.Styles = t.S().Help
86
87 return &modelDialogCmp{
88 modelList: modelList,
89 width: defaultWidth,
90 keyMap: DefaultKeyMap(),
91 help: help,
92 modelType: LargeModelType,
93 }
94}
95
96func (m *modelDialogCmp) Init() tea.Cmd {
97 m.SetModelType(m.modelType)
98 return m.modelList.Init()
99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102 switch msg := msg.(type) {
103 case tea.WindowSizeMsg:
104 m.wWidth = msg.Width
105 m.wHeight = msg.Height
106 m.SetModelType(m.modelType)
107 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
108 case tea.KeyPressMsg:
109 switch {
110 case key.Matches(msg, m.keyMap.Select):
111 selectedItemInx := m.modelList.SelectedIndex()
112 if selectedItemInx == list.NoSelection {
113 return m, nil
114 }
115 items := m.modelList.Items()
116 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
117
118 var modelType config.SelectedModelType
119 if m.modelType == LargeModelType {
120 modelType = config.SelectedModelTypeLarge
121 } else {
122 modelType = config.SelectedModelTypeSmall
123 }
124
125 return m, tea.Sequence(
126 util.CmdHandler(dialogs.CloseDialogMsg{}),
127 util.CmdHandler(ModelSelectedMsg{
128 Model: config.SelectedModel{
129 Model: selectedItem.Model.ID,
130 Provider: string(selectedItem.Provider.ID),
131 },
132 ModelType: modelType,
133 }),
134 )
135 case key.Matches(msg, m.keyMap.Tab):
136 if m.modelType == LargeModelType {
137 return m, m.SetModelType(SmallModelType)
138 } else {
139 return m, m.SetModelType(LargeModelType)
140 }
141 case key.Matches(msg, m.keyMap.Close):
142 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
143 default:
144 u, cmd := m.modelList.Update(msg)
145 m.modelList = u.(list.ListModel)
146 return m, cmd
147 }
148 }
149 return m, nil
150}
151
152func (m *modelDialogCmp) View() string {
153 t := styles.CurrentTheme()
154 listView := m.modelList.View()
155 radio := m.modelTypeRadio()
156 content := lipgloss.JoinVertical(
157 lipgloss.Left,
158 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
159 listView,
160 "",
161 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
162 )
163 return m.style().Render(content)
164}
165
166func (m *modelDialogCmp) Cursor() *tea.Cursor {
167 if cursor, ok := m.modelList.(util.Cursor); ok {
168 cursor := cursor.Cursor()
169 if cursor != nil {
170 cursor = m.moveCursor(cursor)
171 return cursor
172 }
173 }
174 return nil
175}
176
177func (m *modelDialogCmp) style() lipgloss.Style {
178 t := styles.CurrentTheme()
179 return t.S().Base.
180 Width(m.width).
181 Border(lipgloss.RoundedBorder()).
182 BorderForeground(t.BorderFocus)
183}
184
185func (m *modelDialogCmp) listWidth() int {
186 return defaultWidth - 2 // 4 for padding
187}
188
189func (m *modelDialogCmp) listHeight() int {
190 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
191 return min(listHeigh, m.wHeight/2)
192}
193
194func (m *modelDialogCmp) Position() (int, int) {
195 row := m.wHeight/4 - 2 // just a bit above the center
196 col := m.wWidth / 2
197 col -= m.width / 2
198 return row, col
199}
200
201func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
202 row, col := m.Position()
203 offset := row + 3 // Border + title
204 cursor.Y += offset
205 cursor.X = cursor.X + col + 2
206 return cursor
207}
208
209func (m *modelDialogCmp) ID() dialogs.DialogID {
210 return ModelsDialogID
211}
212
213func (m *modelDialogCmp) modelTypeRadio() string {
214 t := styles.CurrentTheme()
215 choices := []string{"Large Task", "Small Task"}
216 iconSelected := "◉"
217 iconUnselected := "○"
218 if m.modelType == LargeModelType {
219 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
220 }
221 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
222}
223
224func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
225 m.modelType = modelType
226
227 providers, err := config.Providers()
228 if err != nil {
229 return util.ReportError(err)
230 }
231
232 modelItems := []util.Model{}
233 selectIndex := 0
234
235 cfg := config.Get()
236 var currentModel config.SelectedModel
237 if m.modelType == LargeModelType {
238 currentModel = cfg.Models[config.SelectedModelTypeLarge]
239 } else {
240 currentModel = cfg.Models[config.SelectedModelTypeSmall]
241 }
242
243 // Create a map to track which providers we've already added
244 addedProviders := make(map[string]bool)
245
246 // First, add any configured providers that are not in the known providers list
247 // These should appear at the top of the list
248 knownProviders := provider.KnownProviders()
249 for providerID, providerConfig := range cfg.Providers {
250 if providerConfig.Disable {
251 continue
252 }
253
254 // Check if this provider is not in the known providers list
255 if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
256 // Convert config provider to provider.Provider format
257 configProvider := provider.Provider{
258 Name: string(providerID), // Use provider ID as name for unknown providers
259 ID: provider.InferenceProvider(providerID),
260 Models: make([]provider.Model, len(providerConfig.Models)),
261 }
262
263 // Convert models
264 for i, model := range providerConfig.Models {
265 configProvider.Models[i] = provider.Model{
266 ID: model.ID,
267 Name: model.Name,
268 CostPer1MIn: model.CostPer1MIn,
269 CostPer1MOut: model.CostPer1MOut,
270 CostPer1MInCached: model.CostPer1MInCached,
271 CostPer1MOutCached: model.CostPer1MOutCached,
272 ContextWindow: model.ContextWindow,
273 DefaultMaxTokens: model.DefaultMaxTokens,
274 CanReason: model.CanReason,
275 HasReasoningEffort: model.HasReasoningEffort,
276 DefaultReasoningEffort: model.DefaultReasoningEffort,
277 SupportsImages: model.SupportsImages,
278 }
279 }
280
281 // Add this unknown provider to the list
282 name := configProvider.Name
283 if name == "" {
284 name = string(configProvider.ID)
285 }
286 modelItems = append(modelItems, commands.NewItemSection(name))
287 for _, model := range configProvider.Models {
288 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
289 Provider: configProvider,
290 Model: model,
291 }))
292 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
293 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
294 }
295 }
296 addedProviders[providerID] = true
297 }
298 }
299
300 // Then add the known providers from the predefined list
301 for _, provider := range providers {
302 // Skip if we already added this provider as an unknown provider
303 if addedProviders[string(provider.ID)] {
304 continue
305 }
306
307 // Check if this provider is configured and not disabled
308 if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
309 continue
310 }
311
312 name := provider.Name
313 if name == "" {
314 name = string(provider.ID)
315 }
316 modelItems = append(modelItems, commands.NewItemSection(name))
317 for _, model := range provider.Models {
318 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
319 Provider: provider,
320 Model: model,
321 }))
322 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
323 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
324 }
325 }
326 }
327
328 return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
329}