1package models
2
3import (
4 "slices"
5
6 "github.com/charmbracelet/bubbles/v2/help"
7 "github.com/charmbracelet/bubbles/v2/key"
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/charmbracelet/crush/internal/tui/components/completions"
12 "github.com/charmbracelet/crush/internal/tui/components/core"
13 "github.com/charmbracelet/crush/internal/tui/components/core/list"
14 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
15 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
16 "github.com/charmbracelet/crush/internal/tui/styles"
17 "github.com/charmbracelet/crush/internal/tui/util"
18 "github.com/charmbracelet/lipgloss/v2"
19)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
// Model types the dialog can select a model for. The dialog toggles between
// them with the Tab binding; LargeModelType is the initial choice.
const (
	// LargeModelType selects the model used for large tasks.
	LargeModelType int = 0
	// SmallModelType selects the model used for small tasks.
	SmallModelType int = 1
)
31
32// ModelSelectedMsg is sent when a model is selected
33type ModelSelectedMsg struct {
34 Model config.SelectedModel
35 ModelType config.SelectedModelType
36}
37
// CloseModelDialogMsg is a message type, judging by its name, for requesting
// that the model dialog be closed.
//
// NOTE(review): the dialog in this file closes itself with
// dialogs.CloseDialogMsg and never emits this type; confirm whether it is
// still used elsewhere before relying on it.
type CloseModelDialogMsg struct{}
40
41// ModelDialog interface for the model selection dialog
42type ModelDialog interface {
43 dialogs.DialogModel
44}
45
46type ModelOption struct {
47 Provider provider.Provider
48 Model provider.Model
49}
50
51type modelDialogCmp struct {
52 width int
53 wWidth int
54 wHeight int
55
56 modelList list.ListModel
57 keyMap KeyMap
58 help help.Model
59 modelType int
60}
61
62func NewModelDialogCmp() ModelDialog {
63 listKeyMap := list.DefaultKeyMap()
64 keyMap := DefaultKeyMap()
65
66 listKeyMap.Down.SetEnabled(false)
67 listKeyMap.Up.SetEnabled(false)
68 listKeyMap.HalfPageDown.SetEnabled(false)
69 listKeyMap.HalfPageUp.SetEnabled(false)
70 listKeyMap.Home.SetEnabled(false)
71 listKeyMap.End.SetEnabled(false)
72
73 listKeyMap.DownOneItem = keyMap.Next
74 listKeyMap.UpOneItem = keyMap.Previous
75
76 t := styles.CurrentTheme()
77 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
78 modelList := list.New(
79 list.WithFilterable(true),
80 list.WithKeyMap(listKeyMap),
81 list.WithInputStyle(inputStyle),
82 list.WithWrapNavigation(true),
83 )
84 help := help.New()
85 help.Styles = t.S().Help
86
87 return &modelDialogCmp{
88 modelList: modelList,
89 width: defaultWidth,
90 keyMap: DefaultKeyMap(),
91 help: help,
92 modelType: LargeModelType,
93 }
94}
95
96func (m *modelDialogCmp) Init() tea.Cmd {
97 m.SetModelType(m.modelType)
98 return m.modelList.Init()
99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102 switch msg := msg.(type) {
103 case tea.WindowSizeMsg:
104 m.wWidth = msg.Width
105 m.wHeight = msg.Height
106 m.SetModelType(m.modelType)
107 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
108 case tea.KeyPressMsg:
109 switch {
110 case key.Matches(msg, m.keyMap.Select):
111 selectedItemInx := m.modelList.SelectedIndex()
112 if selectedItemInx == list.NoSelection {
113 return m, nil
114 }
115 items := m.modelList.Items()
116 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
117
118 var modelType config.SelectedModelType
119 if m.modelType == LargeModelType {
120 modelType = config.SelectedModelTypeLarge
121 } else {
122 modelType = config.SelectedModelTypeSmall
123 }
124
125 return m, tea.Sequence(
126 util.CmdHandler(dialogs.CloseDialogMsg{}),
127 util.CmdHandler(ModelSelectedMsg{
128 Model: config.SelectedModel{
129 Model: selectedItem.Model.ID,
130 Provider: string(selectedItem.Provider.ID),
131 },
132 ModelType: modelType,
133 }),
134 )
135 case key.Matches(msg, m.keyMap.Tab):
136 if m.modelType == LargeModelType {
137 return m, m.SetModelType(SmallModelType)
138 } else {
139 return m, m.SetModelType(LargeModelType)
140 }
141 case key.Matches(msg, m.keyMap.Close):
142 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
143 default:
144 u, cmd := m.modelList.Update(msg)
145 m.modelList = u.(list.ListModel)
146 return m, cmd
147 }
148 }
149 return m, nil
150}
151
152func (m *modelDialogCmp) View() tea.View {
153 t := styles.CurrentTheme()
154 listView := m.modelList.View()
155 radio := m.modelTypeRadio()
156 content := lipgloss.JoinVertical(
157 lipgloss.Left,
158 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
159 listView.String(),
160 "",
161 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
162 )
163 v := tea.NewView(m.style().Render(content))
164 if listView.Cursor() != nil {
165 c := m.moveCursor(listView.Cursor())
166 v.SetCursor(c)
167 }
168 return v
169}
170
171func (m *modelDialogCmp) style() lipgloss.Style {
172 t := styles.CurrentTheme()
173 return t.S().Base.
174 Width(m.width).
175 Border(lipgloss.RoundedBorder()).
176 BorderForeground(t.BorderFocus)
177}
178
179func (m *modelDialogCmp) listWidth() int {
180 return defaultWidth - 2 // 4 for padding
181}
182
183func (m *modelDialogCmp) listHeight() int {
184 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
185 return min(listHeigh, m.wHeight/2)
186}
187
188func (m *modelDialogCmp) Position() (int, int) {
189 row := m.wHeight/4 - 2 // just a bit above the center
190 col := m.wWidth / 2
191 col -= m.width / 2
192 return row, col
193}
194
195func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
196 row, col := m.Position()
197 offset := row + 3 // Border + title
198 cursor.Y += offset
199 cursor.X = cursor.X + col + 2
200 return cursor
201}
202
203func (m *modelDialogCmp) ID() dialogs.DialogID {
204 return ModelsDialogID
205}
206
207func (m *modelDialogCmp) modelTypeRadio() string {
208 t := styles.CurrentTheme()
209 choices := []string{"Large Task", "Small Task"}
210 iconSelected := "◉"
211 iconUnselected := "○"
212 if m.modelType == LargeModelType {
213 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
214 }
215 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
216}
217
218func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
219 m.modelType = modelType
220
221 providers, err := config.Providers()
222 if err != nil {
223 return util.ReportError(err)
224 }
225
226 modelItems := []util.Model{}
227 selectIndex := 0
228
229 cfg := config.Get()
230 var currentModel config.SelectedModel
231 if m.modelType == LargeModelType {
232 currentModel = cfg.Models[config.SelectedModelTypeLarge]
233 } else {
234 currentModel = cfg.Models[config.SelectedModelTypeSmall]
235 }
236
237 // Create a map to track which providers we've already added
238 addedProviders := make(map[string]bool)
239
240 // First, add any configured providers that are not in the known providers list
241 // These should appear at the top of the list
242 knownProviders := provider.KnownProviders()
243 for providerID, providerConfig := range cfg.Providers {
244 if providerConfig.Disable {
245 continue
246 }
247
248 // Check if this provider is not in the known providers list
249 if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
250 // Convert config provider to provider.Provider format
251 configProvider := provider.Provider{
252 Name: string(providerID), // Use provider ID as name for unknown providers
253 ID: provider.InferenceProvider(providerID),
254 Models: make([]provider.Model, len(providerConfig.Models)),
255 }
256
257 // Convert models
258 for i, model := range providerConfig.Models {
259 configProvider.Models[i] = provider.Model{
260 ID: model.ID,
261 Name: model.Name,
262 CostPer1MIn: model.CostPer1MIn,
263 CostPer1MOut: model.CostPer1MOut,
264 CostPer1MInCached: model.CostPer1MInCached,
265 CostPer1MOutCached: model.CostPer1MOutCached,
266 ContextWindow: model.ContextWindow,
267 DefaultMaxTokens: model.DefaultMaxTokens,
268 CanReason: model.CanReason,
269 HasReasoningEffort: model.HasReasoningEffort,
270 DefaultReasoningEffort: model.DefaultReasoningEffort,
271 SupportsImages: model.SupportsImages,
272 }
273 }
274
275 // Add this unknown provider to the list
276 name := configProvider.Name
277 if name == "" {
278 name = string(configProvider.ID)
279 }
280 modelItems = append(modelItems, commands.NewItemSection(name))
281 for _, model := range configProvider.Models {
282 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
283 Provider: configProvider,
284 Model: model,
285 }))
286 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
287 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
288 }
289 }
290 addedProviders[providerID] = true
291 }
292 }
293
294 // Then add the known providers from the predefined list
295 for _, provider := range providers {
296 // Skip if we already added this provider as an unknown provider
297 if addedProviders[string(provider.ID)] {
298 continue
299 }
300
301 // Check if this provider is configured and not disabled
302 if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
303 continue
304 }
305
306 name := provider.Name
307 if name == "" {
308 name = string(provider.ID)
309 }
310 modelItems = append(modelItems, commands.NewItemSection(name))
311 for _, model := range provider.Models {
312 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
313 Provider: provider,
314 Model: model,
315 }))
316 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
317 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
318 }
319 }
320 }
321
322 return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
323}