1package models
2
3import (
4 "github.com/charmbracelet/bubbles/v2/help"
5 "github.com/charmbracelet/bubbles/v2/key"
6 tea "github.com/charmbracelet/bubbletea/v2"
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/tui/components/completions"
10 "github.com/charmbracelet/crush/internal/tui/components/core"
11 "github.com/charmbracelet/crush/internal/tui/components/core/list"
12 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
13 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
14 "github.com/charmbracelet/crush/internal/tui/styles"
15 "github.com/charmbracelet/crush/internal/tui/util"
16 "github.com/charmbracelet/lipgloss/v2"
17)
18
19const (
20 ModelsDialogID dialogs.DialogID = "models"
21
22 defaultWidth = 60
23)
24
// ModelSelectedMsg is emitted by the model dialog when the user confirms a
// model. Model carries the chosen model ID together with its provider ID.
type ModelSelectedMsg struct {
	Model config.PreferredModel
}
29
// CloseModelDialogMsg is, going by its name, a request to close the model
// dialog without selecting a model. The dialog's own Update closes itself by
// emitting dialogs.CloseDialogMsg rather than this type.
// NOTE(review): confirm whether any caller still sends or handles this message.
type CloseModelDialogMsg struct{}
32
// ModelDialog is the dialog that lets the user switch the active model. It
// adds nothing beyond dialogs.DialogModel; NewModelDialogCmp returns an
// implementation.
type ModelDialog interface {
	dialogs.DialogModel
}
37
// ModelOption is the value attached to each selectable row of the model list.
// It pairs a model with the provider that offers it, so a selection can be
// turned into a config.PreferredModel.
type ModelOption struct {
	Provider provider.Provider
	Model    provider.Model
}
42
// modelDialogCmp is the concrete ModelDialog: a bordered, filterable list of
// provider models with a key-help footer.
type modelDialogCmp struct {
	width   int // Dialog width used for the frame, title and help line (defaultWidth from the constructor)
	wWidth  int // Width of the terminal window
	wHeight int // Height of the terminal window

	modelList list.ListModel // Section headers and model rows
	keyMap    KeyMap         // Dialog bindings (Select, Close, Next, Previous, ...)
	help      help.Model     // Renders the key-help footer from keyMap
}
52
53func NewModelDialogCmp() ModelDialog {
54 listKeyMap := list.DefaultKeyMap()
55 keyMap := DefaultKeyMap()
56
57 listKeyMap.Down.SetEnabled(false)
58 listKeyMap.Up.SetEnabled(false)
59 listKeyMap.HalfPageDown.SetEnabled(false)
60 listKeyMap.HalfPageUp.SetEnabled(false)
61 listKeyMap.Home.SetEnabled(false)
62 listKeyMap.End.SetEnabled(false)
63
64 listKeyMap.DownOneItem = keyMap.Next
65 listKeyMap.UpOneItem = keyMap.Previous
66
67 t := styles.CurrentTheme()
68 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
69 modelList := list.New(
70 list.WithFilterable(true),
71 list.WithKeyMap(listKeyMap),
72 list.WithInputStyle(inputStyle),
73 list.WithWrapNavigation(true),
74 )
75 help := help.New()
76 help.Styles = t.S().Help
77
78 return &modelDialogCmp{
79 modelList: modelList,
80 width: defaultWidth,
81 keyMap: DefaultKeyMap(),
82 help: help,
83 }
84}
85
86func (m *modelDialogCmp) Init() tea.Cmd {
87 providers := config.Providers()
88
89 modelItems := []util.Model{}
90 selectIndex := 0
91 agentModel := config.GetAgentModel(config.AgentCoder)
92 agentProvider := config.GetAgentProvider(config.AgentCoder)
93 for _, provider := range providers {
94 name := provider.Name
95 if name == "" {
96 name = string(provider.ID)
97 }
98 modelItems = append(modelItems, commands.NewItemSection(name))
99 for _, model := range provider.Models {
100 if model.ID == agentModel.ID && provider.ID == agentProvider.ID {
101 selectIndex = len(modelItems) // Set the selected index to the current model
102 }
103 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
104 Provider: provider,
105 Model: model,
106 }))
107 }
108 }
109
110 return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
111}
112
113func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
114 switch msg := msg.(type) {
115 case tea.WindowSizeMsg:
116 m.wWidth = msg.Width
117 m.wHeight = msg.Height
118 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
119 case tea.KeyPressMsg:
120 switch {
121 case key.Matches(msg, m.keyMap.Select):
122 selectedItemInx := m.modelList.SelectedIndex()
123 if selectedItemInx == list.NoSelection {
124 return m, nil // No item selected, do nothing
125 }
126 items := m.modelList.Items()
127 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
128
129 return m, tea.Sequence(
130 util.CmdHandler(dialogs.CloseDialogMsg{}),
131 util.CmdHandler(ModelSelectedMsg{Model: config.PreferredModel{
132 ModelID: selectedItem.Model.ID,
133 Provider: selectedItem.Provider.ID,
134 }}),
135 )
136 case key.Matches(msg, m.keyMap.Close):
137 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
138 default:
139 u, cmd := m.modelList.Update(msg)
140 m.modelList = u.(list.ListModel)
141 return m, cmd
142 }
143 }
144 return m, nil
145}
146
147func (m *modelDialogCmp) View() tea.View {
148 t := styles.CurrentTheme()
149 listView := m.modelList.View()
150 content := lipgloss.JoinVertical(
151 lipgloss.Left,
152 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
153 listView.String(),
154 "",
155 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
156 )
157 v := tea.NewView(m.style().Render(content))
158 if listView.Cursor() != nil {
159 c := m.moveCursor(listView.Cursor())
160 v.SetCursor(c)
161 }
162 return v
163}
164
165func (m *modelDialogCmp) style() lipgloss.Style {
166 t := styles.CurrentTheme()
167 return t.S().Base.
168 Width(m.width).
169 Border(lipgloss.RoundedBorder()).
170 BorderForeground(t.BorderFocus)
171}
172
173func (m *modelDialogCmp) listWidth() int {
174 return defaultWidth - 2 // 4 for padding
175}
176
177func (m *modelDialogCmp) listHeight() int {
178 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
179 return min(listHeigh, m.wHeight/2)
180}
181
182func (m *modelDialogCmp) Position() (int, int) {
183 row := m.wHeight/4 - 2 // just a bit above the center
184 col := m.wWidth / 2
185 col -= m.width / 2
186 return row, col
187}
188
189func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
190 row, col := m.Position()
191 offset := row + 3 // Border + title
192 cursor.Y += offset
193 cursor.X = cursor.X + col + 2
194 return cursor
195}
196
197func (m *modelDialogCmp) ID() dialogs.DialogID {
198 return ModelsDialogID
199}