1package models
2
3import (
4 "github.com/charmbracelet/bubbles/v2/help"
5 "github.com/charmbracelet/bubbles/v2/key"
6 tea "github.com/charmbracelet/bubbletea/v2"
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/tui/components/completions"
10 "github.com/charmbracelet/crush/internal/tui/components/core"
11 "github.com/charmbracelet/crush/internal/tui/components/core/list"
12 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
13 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
14 "github.com/charmbracelet/crush/internal/tui/styles"
15 "github.com/charmbracelet/crush/internal/tui/util"
16 "github.com/charmbracelet/lipgloss/v2"
17)
18
19const (
20 ModelsDialogID dialogs.DialogID = "models"
21
22 defaultWidth = 60
23)
24
// Model type selectors: which configured model slot the dialog is editing.
const (
	// LargeModelType selects the large-task model slot.
	LargeModelType int = 0
	// SmallModelType selects the small-task model slot.
	SmallModelType int = 1
)
29
30// ModelSelectedMsg is sent when a model is selected
31type ModelSelectedMsg struct {
32 Model config.PreferredModel
33 ModelType config.ModelType
34}
35
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): this file never sends it (the dialog closes itself with
// dialogs.CloseDialogMsg); confirm whether it is still used elsewhere.
type CloseModelDialogMsg struct{}
38
39// ModelDialog interface for the model selection dialog
40type ModelDialog interface {
41 dialogs.DialogModel
42}
43
44type ModelOption struct {
45 Provider provider.Provider
46 Model provider.Model
47}
48
49type modelDialogCmp struct {
50 width int
51 wWidth int
52 wHeight int
53
54 modelList list.ListModel
55 keyMap KeyMap
56 help help.Model
57 modelType int
58}
59
60func NewModelDialogCmp() ModelDialog {
61 listKeyMap := list.DefaultKeyMap()
62 keyMap := DefaultKeyMap()
63
64 listKeyMap.Down.SetEnabled(false)
65 listKeyMap.Up.SetEnabled(false)
66 listKeyMap.HalfPageDown.SetEnabled(false)
67 listKeyMap.HalfPageUp.SetEnabled(false)
68 listKeyMap.Home.SetEnabled(false)
69 listKeyMap.End.SetEnabled(false)
70
71 listKeyMap.DownOneItem = keyMap.Next
72 listKeyMap.UpOneItem = keyMap.Previous
73
74 t := styles.CurrentTheme()
75 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
76 modelList := list.New(
77 list.WithFilterable(true),
78 list.WithKeyMap(listKeyMap),
79 list.WithInputStyle(inputStyle),
80 list.WithWrapNavigation(true),
81 )
82 help := help.New()
83 help.Styles = t.S().Help
84
85 return &modelDialogCmp{
86 modelList: modelList,
87 width: defaultWidth,
88 keyMap: DefaultKeyMap(),
89 help: help,
90 modelType: LargeModelType,
91 }
92}
93
94func (m *modelDialogCmp) Init() tea.Cmd {
95 m.SetModelType(m.modelType)
96 return m.modelList.Init()
97}
98
99func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
100 switch msg := msg.(type) {
101 case tea.WindowSizeMsg:
102 m.wWidth = msg.Width
103 m.wHeight = msg.Height
104 m.SetModelType(m.modelType)
105 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
106 case tea.KeyPressMsg:
107 switch {
108 case key.Matches(msg, m.keyMap.Select):
109 selectedItemInx := m.modelList.SelectedIndex()
110 if selectedItemInx == list.NoSelection {
111 return m, nil
112 }
113 items := m.modelList.Items()
114 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
115
116 var modelType config.ModelType
117 if m.modelType == LargeModelType {
118 modelType = config.LargeModel
119 } else {
120 modelType = config.SmallModel
121 }
122
123 return m, tea.Sequence(
124 util.CmdHandler(dialogs.CloseDialogMsg{}),
125 util.CmdHandler(ModelSelectedMsg{
126 Model: config.PreferredModel{
127 ModelID: selectedItem.Model.ID,
128 Provider: selectedItem.Provider.ID,
129 },
130 ModelType: modelType,
131 }),
132 )
133 case key.Matches(msg, m.keyMap.Tab):
134 if m.modelType == LargeModelType {
135 return m, m.SetModelType(SmallModelType)
136 } else {
137 return m, m.SetModelType(LargeModelType)
138 }
139 case key.Matches(msg, m.keyMap.Close):
140 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
141 default:
142 u, cmd := m.modelList.Update(msg)
143 m.modelList = u.(list.ListModel)
144 return m, cmd
145 }
146 }
147 return m, nil
148}
149
150func (m *modelDialogCmp) View() tea.View {
151 t := styles.CurrentTheme()
152 listView := m.modelList.View()
153 radio := m.modelTypeRadio()
154 content := lipgloss.JoinVertical(
155 lipgloss.Left,
156 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
157 listView.String(),
158 "",
159 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
160 )
161 v := tea.NewView(m.style().Render(content))
162 if listView.Cursor() != nil {
163 c := m.moveCursor(listView.Cursor())
164 v.SetCursor(c)
165 }
166 return v
167}
168
169func (m *modelDialogCmp) style() lipgloss.Style {
170 t := styles.CurrentTheme()
171 return t.S().Base.
172 Width(m.width).
173 Border(lipgloss.RoundedBorder()).
174 BorderForeground(t.BorderFocus)
175}
176
177func (m *modelDialogCmp) listWidth() int {
178 return defaultWidth - 2 // 4 for padding
179}
180
181func (m *modelDialogCmp) listHeight() int {
182 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
183 return min(listHeigh, m.wHeight/2)
184}
185
186func (m *modelDialogCmp) Position() (int, int) {
187 row := m.wHeight/4 - 2 // just a bit above the center
188 col := m.wWidth / 2
189 col -= m.width / 2
190 return row, col
191}
192
193func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
194 row, col := m.Position()
195 offset := row + 3 // Border + title
196 cursor.Y += offset
197 cursor.X = cursor.X + col + 2
198 return cursor
199}
200
201func (m *modelDialogCmp) ID() dialogs.DialogID {
202 return ModelsDialogID
203}
204
205func (m *modelDialogCmp) modelTypeRadio() string {
206 t := styles.CurrentTheme()
207 choices := []string{"Large Task", "Small Task"}
208 iconSelected := "◉"
209 iconUnselected := "○"
210 if m.modelType == LargeModelType {
211 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
212 }
213 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
214}
215
216func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
217 m.modelType = modelType
218
219 providers := config.Providers()
220 modelItems := []util.Model{}
221 selectIndex := 0
222
223 cfg := config.Get()
224 var currentModel config.PreferredModel
225 if m.modelType == LargeModelType {
226 currentModel = cfg.Models.Large
227 } else {
228 currentModel = cfg.Models.Small
229 }
230
231 for _, provider := range providers {
232 name := provider.Name
233 if name == "" {
234 name = string(provider.ID)
235 }
236 modelItems = append(modelItems, commands.NewItemSection(name))
237 for _, model := range provider.Models {
238 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
239 Provider: provider,
240 Model: model,
241 }))
242 if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
243 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
244 }
245 }
246 }
247
248 return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
249}