1package models
2
3import (
4 "slices"
5
6 "github.com/charmbracelet/bubbles/v2/help"
7 "github.com/charmbracelet/bubbles/v2/key"
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/lipgloss/v2"
10 "github.com/opencode-ai/opencode/internal/config"
11 "github.com/opencode-ai/opencode/internal/llm/models"
12 "github.com/opencode-ai/opencode/internal/tui/components/completions"
13 "github.com/opencode-ai/opencode/internal/tui/components/core"
14 "github.com/opencode-ai/opencode/internal/tui/components/core/list"
15 "github.com/opencode-ai/opencode/internal/tui/components/dialogs"
16 "github.com/opencode-ai/opencode/internal/tui/components/dialogs/commands"
17 "github.com/opencode-ai/opencode/internal/tui/styles"
18 "github.com/opencode-ai/opencode/internal/tui/util"
19)
20
const (
	// ID identifies the model selection dialog; it is the value returned by
	// modelDialogCmp.ID.
	ID dialogs.DialogID = "models"

	// defaultWidth is the width the dialog is created with and the basis for
	// the list width computed in listWidth.
	defaultWidth = 60
)
26
// ModelSelectedMsg is emitted by the dialog when the user confirms a model
// with the Select key binding.
type ModelSelectedMsg struct {
	// Model is the entry that was highlighted when Select was pressed.
	Model models.Model
}
31
// CloseModelDialogMsg signals that the model dialog should be closed.
//
// NOTE(review): nothing in this file emits it; the Select handler sends
// dialogs.CloseDialogMsg instead. Confirm whether this type is still used.
type CloseModelDialogMsg struct{}
34
// ModelDialog is the public type of the model selection dialog returned by
// NewModelDialogCmp. It currently adds no methods to dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
39
// modelDialogCmp is the concrete implementation of ModelDialog.
type modelDialogCmp struct {
	// width is the dialog width applied in style(); it is initialised to
	// defaultWidth by NewModelDialogCmp.
	width   int
	wWidth  int // Width of the terminal window
	wHeight int // Height of the terminal window

	// modelList holds the provider section headers and the selectable models.
	modelList list.ListModel
	// keyMap holds the dialog's own bindings (Select is matched in Update and
	// the whole map is rendered by the help view).
	keyMap KeyMap
	// help renders the key binding hints at the bottom of the dialog.
	help help.Model
}
49
50func NewModelDialogCmp() ModelDialog {
51 listKeyMap := list.DefaultKeyMap()
52 keyMap := DefaultKeyMap()
53
54 listKeyMap.Down.SetEnabled(false)
55 listKeyMap.Up.SetEnabled(false)
56 listKeyMap.NDown.SetEnabled(false)
57 listKeyMap.NUp.SetEnabled(false)
58 listKeyMap.HalfPageDown.SetEnabled(false)
59 listKeyMap.HalfPageUp.SetEnabled(false)
60 listKeyMap.Home.SetEnabled(false)
61 listKeyMap.End.SetEnabled(false)
62
63 listKeyMap.DownOneItem = keyMap.Next
64 listKeyMap.UpOneItem = keyMap.Previous
65
66 t := styles.CurrentTheme()
67 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
68 modelList := list.New(
69 list.WithFilterable(true),
70 list.WithKeyMap(listKeyMap),
71 list.WithInputStyle(inputStyle),
72 list.WithWrapNavigation(true),
73 )
74 help := help.New()
75 help.Styles = t.S().Help
76
77 return &modelDialogCmp{
78 modelList: modelList,
79 width: defaultWidth,
80 keyMap: DefaultKeyMap(),
81 help: help,
82 }
83}
84
// ProviderPopularity assigns each known provider a display rank; providers
// with a lower rank are listed first. getEnabledProviders treats a provider
// that is missing from this map (rank 0) as ranked last.
var ProviderPopularity = map[models.ModelProvider]int{
	models.ProviderAnthropic:  1,
	models.ProviderOpenAI:     2,
	models.ProviderGemini:     3,
	models.ProviderGROQ:       4,
	models.ProviderOpenRouter: 5,
	models.ProviderBedrock:    6,
	models.ProviderAzure:      7,
	models.ProviderVertexAI:   8,
	models.ProviderXAI:        9,
}
96
// ProviderName maps a provider to the human-readable title used for its
// section header in the dialog. Init falls back to the raw provider ID when a
// provider has no entry here.
var ProviderName = map[models.ModelProvider]string{
	models.ProviderAnthropic:  "Anthropic",
	models.ProviderOpenAI:     "OpenAI",
	models.ProviderGemini:     "Gemini",
	models.ProviderGROQ:       "Groq",
	models.ProviderOpenRouter: "OpenRouter",
	models.ProviderBedrock:    "AWS Bedrock",
	models.ProviderAzure:      "Azure",
	models.ProviderVertexAI:   "VertexAI",
	models.ProviderXAI:        "xAI",
}
108
109func (m *modelDialogCmp) Init() tea.Cmd {
110 cfg := config.Get()
111 enabledProviders := getEnabledProviders(cfg)
112
113 modelItems := []util.Model{}
114 for _, provider := range enabledProviders {
115 name, ok := ProviderName[provider]
116 if !ok {
117 name = string(provider) // Fallback to provider ID if name is not defined
118 }
119 modelItems = append(modelItems, commands.NewItemSection(name))
120 for _, model := range getModelsForProvider(provider) {
121 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
122 }
123 }
124 m.modelList.SetItems(modelItems)
125 return m.modelList.Init()
126}
127
128func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
129 switch msg := msg.(type) {
130 case tea.WindowSizeMsg:
131 m.wWidth = msg.Width
132 m.wHeight = msg.Height
133 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
134 case tea.KeyPressMsg:
135 switch {
136 case key.Matches(msg, m.keyMap.Select):
137 selectedItemInx := m.modelList.SelectedIndex()
138 if selectedItemInx == list.NoSelection {
139 return m, nil // No item selected, do nothing
140 }
141 items := m.modelList.Items()
142 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
143
144 return m, tea.Sequence(
145 util.CmdHandler(dialogs.CloseDialogMsg{}),
146 util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
147 )
148 default:
149 u, cmd := m.modelList.Update(msg)
150 m.modelList = u.(list.ListModel)
151 return m, cmd
152 }
153 }
154 return m, nil
155}
156
157func (m *modelDialogCmp) View() tea.View {
158 t := styles.CurrentTheme()
159 listView := m.modelList.View()
160 content := lipgloss.JoinVertical(
161 lipgloss.Left,
162 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
163 listView.String(),
164 "",
165 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
166 )
167 v := tea.NewView(m.style().Render(content))
168 if listView.Cursor() != nil {
169 c := m.moveCursor(listView.Cursor())
170 v.SetCursor(c)
171 }
172 return v
173}
174
175func (m *modelDialogCmp) style() lipgloss.Style {
176 t := styles.CurrentTheme()
177 return t.S().Base.
178 Width(m.width).
179 Border(lipgloss.RoundedBorder()).
180 BorderForeground(t.BorderFocus)
181}
182
183func (m *modelDialogCmp) listWidth() int {
184 return defaultWidth - 2 // 4 for padding
185}
186
187func (m *modelDialogCmp) listHeight() int {
188 listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
189 return min(listHeigh, m.wHeight/2)
190}
191
192func GetSelectedModel(cfg *config.Config) models.Model {
193 agentCfg := cfg.Agents[config.AgentCoder]
194 selectedModelId := agentCfg.Model
195 return models.SupportedModels[selectedModelId]
196}
197
198func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
199 var providers []models.ModelProvider
200 for providerId, provider := range cfg.Providers {
201 if !provider.Disabled {
202 providers = append(providers, providerId)
203 }
204 }
205
206 // Sort by provider popularity
207 slices.SortFunc(providers, func(a, b models.ModelProvider) int {
208 rA := ProviderPopularity[a]
209 rB := ProviderPopularity[b]
210
211 // models not included in popularity ranking default to last
212 if rA == 0 {
213 rA = 999
214 }
215 if rB == 0 {
216 rB = 999
217 }
218 return rA - rB
219 })
220 return providers
221}
222
223func getModelsForProvider(provider models.ModelProvider) []models.Model {
224 var providerModels []models.Model
225 for _, model := range models.SupportedModels {
226 if model.Provider == provider {
227 providerModels = append(providerModels, model)
228 }
229 }
230
231 // reverse alphabetical order (if llm naming was consistent latest would appear first)
232 slices.SortFunc(providerModels, func(a, b models.Model) int {
233 if a.Name > b.Name {
234 return -1
235 } else if a.Name < b.Name {
236 return 1
237 }
238 return 0
239 })
240
241 return providerModels
242}
243
244func (m *modelDialogCmp) Position() (int, int) {
245 row := m.wHeight/4 - 2 // just a bit above the center
246 col := m.wWidth / 2
247 col -= m.width / 2
248 return row, col
249}
250
251func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
252 row, col := m.Position()
253 offset := row + 3 // Border + title
254 cursor.Y += offset
255 cursor.X = cursor.X + col + 2
256 return cursor
257}
258
// ID returns the package-level ID constant that identifies this dialog.
func (m *modelDialogCmp) ID() dialogs.DialogID {
	return ID
}