1package models
2
3import (
4 "slices"
5
6 "github.com/charmbracelet/bubbles/v2/help"
7 "github.com/charmbracelet/bubbles/v2/key"
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/charmbracelet/crush/internal/tui/components/completions"
12 "github.com/charmbracelet/crush/internal/tui/components/core"
13 "github.com/charmbracelet/crush/internal/tui/components/core/list"
14 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
15 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
16 "github.com/charmbracelet/crush/internal/tui/styles"
17 "github.com/charmbracelet/crush/internal/tui/util"
18 "github.com/charmbracelet/lipgloss/v2"
19)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
// Model slots the dialog can edit; the active one is kept in
// modelDialogCmp.modelType.
const (
	// LargeModelType selects the large-task model slot.
	LargeModelType int = 0
	// SmallModelType selects the small-task model slot.
	SmallModelType int = 1
)
31
// ModelSelectedMsg is sent when the user confirms a model in the dialog.
// Model holds the chosen model ID and its provider ID; ModelType records which
// slot (config.LargeModel or config.SmallModel) the choice was made for.
type ModelSelectedMsg struct {
	Model     config.PreferredModel
	ModelType config.ModelType
}
37
// CloseModelDialogMsg presumably signals that the model dialog should close
// (inferred from its name).
// NOTE(review): nothing in this file emits it — the dialog sends
// dialogs.CloseDialogMsg instead; confirm it is still used elsewhere.
type CloseModelDialogMsg struct{}
40
// ModelDialog is the public handle for the model-selection dialog. It adds
// no methods beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
45
// ModelOption is the value carried by each selectable model entry in the
// list: the provider that offers the model, and the model itself.
type ModelOption struct {
	Provider provider.Provider
	Model    provider.Model
}
50
// modelDialogCmp is the concrete ModelDialog returned by NewModelDialogCmp.
type modelDialogCmp struct {
	width   int // dialog width in columns; initialized to defaultWidth
	wWidth  int // terminal width from the last tea.WindowSizeMsg
	wHeight int // terminal height from the last tea.WindowSizeMsg

	modelList list.ListModel // filterable list of provider sections and model entries
	keyMap    KeyMap
	help      help.Model
	modelType int // LargeModelType or SmallModelType: the slot being edited
}
61
62func NewModelDialogCmp() ModelDialog {
63 listKeyMap := list.DefaultKeyMap()
64 keyMap := DefaultKeyMap()
65
66 listKeyMap.Down.SetEnabled(false)
67 listKeyMap.Up.SetEnabled(false)
68 listKeyMap.HalfPageDown.SetEnabled(false)
69 listKeyMap.HalfPageUp.SetEnabled(false)
70 listKeyMap.Home.SetEnabled(false)
71 listKeyMap.End.SetEnabled(false)
72
73 listKeyMap.DownOneItem = keyMap.Next
74 listKeyMap.UpOneItem = keyMap.Previous
75
76 t := styles.CurrentTheme()
77 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
78 modelList := list.New(
79 list.WithFilterable(true),
80 list.WithKeyMap(listKeyMap),
81 list.WithInputStyle(inputStyle),
82 list.WithWrapNavigation(true),
83 )
84 help := help.New()
85 help.Styles = t.S().Help
86
87 return &modelDialogCmp{
88 modelList: modelList,
89 width: defaultWidth,
90 keyMap: DefaultKeyMap(),
91 help: help,
92 modelType: LargeModelType,
93 }
94}
95
96func (m *modelDialogCmp) Init() tea.Cmd {
97 m.SetModelType(m.modelType)
98 return m.modelList.Init()
99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102 switch msg := msg.(type) {
103 case tea.WindowSizeMsg:
104 m.wWidth = msg.Width
105 m.wHeight = msg.Height
106 m.SetModelType(m.modelType)
107 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
108 case tea.KeyPressMsg:
109 switch {
110 case key.Matches(msg, m.keyMap.Select):
111 selectedItemInx := m.modelList.SelectedIndex()
112 if selectedItemInx == list.NoSelection {
113 return m, nil
114 }
115 items := m.modelList.Items()
116 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
117
118 var modelType config.ModelType
119 if m.modelType == LargeModelType {
120 modelType = config.LargeModel
121 } else {
122 modelType = config.SmallModel
123 }
124
125 return m, tea.Sequence(
126 util.CmdHandler(dialogs.CloseDialogMsg{}),
127 util.CmdHandler(ModelSelectedMsg{
128 Model: config.PreferredModel{
129 ModelID: selectedItem.Model.ID,
130 Provider: selectedItem.Provider.ID,
131 },
132 ModelType: modelType,
133 }),
134 )
135 case key.Matches(msg, m.keyMap.Tab):
136 if m.modelType == LargeModelType {
137 return m, m.SetModelType(SmallModelType)
138 } else {
139 return m, m.SetModelType(LargeModelType)
140 }
141 case key.Matches(msg, m.keyMap.Close):
142 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
143 default:
144 u, cmd := m.modelList.Update(msg)
145 m.modelList = u.(list.ListModel)
146 return m, cmd
147 }
148 }
149 return m, nil
150}
151
152func (m *modelDialogCmp) View() tea.View {
153 t := styles.CurrentTheme()
154 listView := m.modelList.View()
155 radio := m.modelTypeRadio()
156 content := lipgloss.JoinVertical(
157 lipgloss.Left,
158 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
159 listView.String(),
160 "",
161 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
162 )
163 v := tea.NewView(m.style().Render(content))
164 if listView.Cursor() != nil {
165 c := m.moveCursor(listView.Cursor())
166 v.SetCursor(c)
167 }
168 return v
169}
170
171func (m *modelDialogCmp) style() lipgloss.Style {
172 t := styles.CurrentTheme()
173 return t.S().Base.
174 Width(m.width).
175 Border(lipgloss.RoundedBorder()).
176 BorderForeground(t.BorderFocus)
177}
178
179func (m *modelDialogCmp) listWidth() int {
180 return defaultWidth - 2 // 4 for padding
181}
182
183func (m *modelDialogCmp) listHeight() int {
184 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
185 return min(listHeigh, m.wHeight/2)
186}
187
188func (m *modelDialogCmp) Position() (int, int) {
189 row := m.wHeight/4 - 2 // just a bit above the center
190 col := m.wWidth / 2
191 col -= m.width / 2
192 return row, col
193}
194
195func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
196 row, col := m.Position()
197 offset := row + 3 // Border + title
198 cursor.Y += offset
199 cursor.X = cursor.X + col + 2
200 return cursor
201}
202
203func (m *modelDialogCmp) ID() dialogs.DialogID {
204 return ModelsDialogID
205}
206
207func (m *modelDialogCmp) modelTypeRadio() string {
208 t := styles.CurrentTheme()
209 choices := []string{"Large Task", "Small Task"}
210 iconSelected := "◉"
211 iconUnselected := "○"
212 if m.modelType == LargeModelType {
213 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
214 }
215 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
216}
217
218func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
219 m.modelType = modelType
220
221 providers := config.Providers()
222 modelItems := []util.Model{}
223 selectIndex := 0
224
225 cfg := config.Get()
226 var currentModel config.PreferredModel
227 if m.modelType == LargeModelType {
228 currentModel = cfg.Models.Large
229 } else {
230 currentModel = cfg.Models.Small
231 }
232
233 // Create a map to track which providers we've already added
234 addedProviders := make(map[provider.InferenceProvider]bool)
235
236 // First, add any configured providers that are not in the known providers list
237 // These should appear at the top of the list
238 knownProviders := provider.KnownProviders()
239 for providerID, providerConfig := range cfg.Providers {
240 if providerConfig.Disabled {
241 continue
242 }
243
244 // Check if this provider is not in the known providers list
245 if !slices.Contains(knownProviders, providerID) {
246 // Convert config provider to provider.Provider format
247 configProvider := provider.Provider{
248 Name: string(providerID), // Use provider ID as name for unknown providers
249 ID: providerID,
250 Models: make([]provider.Model, len(providerConfig.Models)),
251 }
252
253 // Convert models
254 for i, model := range providerConfig.Models {
255 configProvider.Models[i] = provider.Model{
256 ID: model.ID,
257 Name: model.Name,
258 CostPer1MIn: model.CostPer1MIn,
259 CostPer1MOut: model.CostPer1MOut,
260 CostPer1MInCached: model.CostPer1MInCached,
261 CostPer1MOutCached: model.CostPer1MOutCached,
262 ContextWindow: model.ContextWindow,
263 DefaultMaxTokens: model.DefaultMaxTokens,
264 CanReason: model.CanReason,
265 HasReasoningEffort: model.HasReasoningEffort,
266 DefaultReasoningEffort: model.ReasoningEffort,
267 SupportsImages: model.SupportsImages,
268 }
269 }
270
271 // Add this unknown provider to the list
272 name := configProvider.Name
273 if name == "" {
274 name = string(configProvider.ID)
275 }
276 modelItems = append(modelItems, commands.NewItemSection(name))
277 for _, model := range configProvider.Models {
278 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
279 Provider: configProvider,
280 Model: model,
281 }))
282 if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
283 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
284 }
285 }
286 addedProviders[providerID] = true
287 }
288 }
289
290 // Then add the known providers from the predefined list
291 for _, provider := range providers {
292 // Skip if we already added this provider as an unknown provider
293 if addedProviders[provider.ID] {
294 continue
295 }
296
297 // Check if this provider is configured and not disabled
298 if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
299 continue
300 }
301
302 name := provider.Name
303 if name == "" {
304 name = string(provider.ID)
305 }
306 modelItems = append(modelItems, commands.NewItemSection(name))
307 for _, model := range provider.Models {
308 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
309 Provider: provider,
310 Model: model,
311 }))
312 if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
313 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
314 }
315 }
316 }
317
318 return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
319}