models.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"strings"
  7
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	tea "github.com/charmbracelet/bubbletea/v2"
 10	"github.com/charmbracelet/lipgloss/v2"
 11	"github.com/opencode-ai/opencode/internal/config"
 12	"github.com/opencode-ai/opencode/internal/llm/models"
 13	"github.com/opencode-ai/opencode/internal/tui/layout"
 14	"github.com/opencode-ai/opencode/internal/tui/styles"
 15	"github.com/opencode-ai/opencode/internal/tui/theme"
 16	"github.com/opencode-ai/opencode/internal/tui/util"
 17)
 18
 19const (
 20	numVisibleModels = 10
 21	maxDialogWidth   = 40
 22)
 23
 24// ModelSelectedMsg is sent when a model is selected
 25type ModelSelectedMsg struct {
 26	Model models.Model
 27}
 28
 29// CloseModelDialogMsg is sent when a model is selected
 30type CloseModelDialogMsg struct{}
 31
 32// ModelDialog interface for the model selection dialog
 33type ModelDialog interface {
 34	util.Model
 35	layout.Bindings
 36}
 37
 38type modelDialogCmp struct {
 39	models             []models.Model
 40	provider           models.ModelProvider
 41	availableProviders []models.ModelProvider
 42
 43	selectedIdx     int
 44	width           int
 45	height          int
 46	scrollOffset    int
 47	hScrollOffset   int
 48	hScrollPossible bool
 49}
 50
 51type modelKeyMap struct {
 52	Up     key.Binding
 53	Down   key.Binding
 54	Left   key.Binding
 55	Right  key.Binding
 56	Enter  key.Binding
 57	Escape key.Binding
 58	J      key.Binding
 59	K      key.Binding
 60	H      key.Binding
 61	L      key.Binding
 62}
 63
 64var modelKeys = modelKeyMap{
 65	Up: key.NewBinding(
 66		key.WithKeys("up"),
 67		key.WithHelp("↑", "previous model"),
 68	),
 69	Down: key.NewBinding(
 70		key.WithKeys("down"),
 71		key.WithHelp("↓", "next model"),
 72	),
 73	Left: key.NewBinding(
 74		key.WithKeys("left"),
 75		key.WithHelp("←", "scroll left"),
 76	),
 77	Right: key.NewBinding(
 78		key.WithKeys("right"),
 79		key.WithHelp("→", "scroll right"),
 80	),
 81	Enter: key.NewBinding(
 82		key.WithKeys("enter"),
 83		key.WithHelp("enter", "select model"),
 84	),
 85	Escape: key.NewBinding(
 86		key.WithKeys("esc"),
 87		key.WithHelp("esc", "close"),
 88	),
 89	J: key.NewBinding(
 90		key.WithKeys("j"),
 91		key.WithHelp("j", "next model"),
 92	),
 93	K: key.NewBinding(
 94		key.WithKeys("k"),
 95		key.WithHelp("k", "previous model"),
 96	),
 97	H: key.NewBinding(
 98		key.WithKeys("h"),
 99		key.WithHelp("h", "scroll left"),
100	),
101	L: key.NewBinding(
102		key.WithKeys("l"),
103		key.WithHelp("l", "scroll right"),
104	),
105}
106
107func (m *modelDialogCmp) Init() tea.Cmd {
108	m.setupModels()
109	return nil
110}
111
112func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
113	switch msg := msg.(type) {
114	case tea.KeyPressMsg:
115		switch {
116		case key.Matches(msg, modelKeys.Up) || key.Matches(msg, modelKeys.K):
117			m.moveSelectionUp()
118		case key.Matches(msg, modelKeys.Down) || key.Matches(msg, modelKeys.J):
119			m.moveSelectionDown()
120		case key.Matches(msg, modelKeys.Left) || key.Matches(msg, modelKeys.H):
121			if m.hScrollPossible {
122				m.switchProvider(-1)
123			}
124		case key.Matches(msg, modelKeys.Right) || key.Matches(msg, modelKeys.L):
125			if m.hScrollPossible {
126				m.switchProvider(1)
127			}
128		case key.Matches(msg, modelKeys.Enter):
129			util.ReportInfo(fmt.Sprintf("selected model: %s", m.models[m.selectedIdx].Name))
130			return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]})
131		case key.Matches(msg, modelKeys.Escape):
132			return m, util.CmdHandler(CloseModelDialogMsg{})
133		}
134	case tea.WindowSizeMsg:
135		m.width = msg.Width
136		m.height = msg.Height
137	}
138
139	return m, nil
140}
141
142// moveSelectionUp moves the selection up or wraps to bottom
143func (m *modelDialogCmp) moveSelectionUp() {
144	if m.selectedIdx > 0 {
145		m.selectedIdx--
146	} else {
147		m.selectedIdx = len(m.models) - 1
148		m.scrollOffset = max(0, len(m.models)-numVisibleModels)
149	}
150
151	// Keep selection visible
152	if m.selectedIdx < m.scrollOffset {
153		m.scrollOffset = m.selectedIdx
154	}
155}
156
157// moveSelectionDown moves the selection down or wraps to top
158func (m *modelDialogCmp) moveSelectionDown() {
159	if m.selectedIdx < len(m.models)-1 {
160		m.selectedIdx++
161	} else {
162		m.selectedIdx = 0
163		m.scrollOffset = 0
164	}
165
166	// Keep selection visible
167	if m.selectedIdx >= m.scrollOffset+numVisibleModels {
168		m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
169	}
170}
171
172func (m *modelDialogCmp) switchProvider(offset int) {
173	newOffset := m.hScrollOffset + offset
174
175	// Ensure we stay within bounds
176	if newOffset < 0 {
177		newOffset = len(m.availableProviders) - 1
178	}
179	if newOffset >= len(m.availableProviders) {
180		newOffset = 0
181	}
182
183	m.hScrollOffset = newOffset
184	m.provider = m.availableProviders[m.hScrollOffset]
185	m.setupModelsForProvider(m.provider)
186}
187
188func (m *modelDialogCmp) View() tea.View {
189	t := theme.CurrentTheme()
190	baseStyle := styles.BaseStyle()
191
192	// Capitalize first letter of provider name
193	providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:])
194	title := baseStyle.
195		Foreground(t.Primary()).
196		Bold(true).
197		Width(maxDialogWidth).
198		Padding(0, 0, 1).
199		Render(fmt.Sprintf("Select %s Model", providerName))
200
201	// Render visible models
202	endIdx := min(m.scrollOffset+numVisibleModels, len(m.models))
203	modelItems := make([]string, 0, endIdx-m.scrollOffset)
204
205	for i := m.scrollOffset; i < endIdx; i++ {
206		itemStyle := baseStyle.Width(maxDialogWidth)
207		if i == m.selectedIdx {
208			itemStyle = itemStyle.Background(t.Primary()).
209				Foreground(t.Background()).Bold(true)
210		}
211		modelItems = append(modelItems, itemStyle.Render(m.models[i].Name))
212	}
213
214	scrollIndicator := m.getScrollIndicators(maxDialogWidth)
215
216	content := lipgloss.JoinVertical(
217		lipgloss.Left,
218		title,
219		baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
220		scrollIndicator,
221	)
222
223	return tea.NewView(
224		baseStyle.Padding(1, 2).
225			Border(lipgloss.RoundedBorder()).
226			BorderBackground(t.Background()).
227			BorderForeground(t.TextMuted()).
228			Width(lipgloss.Width(content) + 4).
229			Render(content),
230	)
231}
232
233func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
234	var indicator string
235
236	if len(m.models) > numVisibleModels {
237		if m.scrollOffset > 0 {
238			indicator += "↑ "
239		}
240		if m.scrollOffset+numVisibleModels < len(m.models) {
241			indicator += "↓ "
242		}
243	}
244
245	if m.hScrollPossible {
246		if m.hScrollOffset > 0 {
247			indicator = "← " + indicator
248		}
249		if m.hScrollOffset < len(m.availableProviders)-1 {
250			indicator += "→"
251		}
252	}
253
254	if indicator == "" {
255		return ""
256	}
257
258	t := theme.CurrentTheme()
259	baseStyle := styles.BaseStyle()
260
261	return baseStyle.
262		Foreground(t.Primary()).
263		Width(maxWidth).
264		Align(lipgloss.Right).
265		Bold(true).
266		Render(indicator)
267}
268
269func (m *modelDialogCmp) BindingKeys() []key.Binding {
270	return layout.KeyMapToSlice(modelKeys)
271}
272
273func (m *modelDialogCmp) setupModels() {
274	cfg := config.Get()
275	modelInfo := GetSelectedModel(cfg)
276	m.availableProviders = getEnabledProviders(cfg)
277	m.hScrollPossible = len(m.availableProviders) > 1
278
279	m.provider = modelInfo.Provider
280	m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
281
282	m.setupModelsForProvider(m.provider)
283}
284
285func GetSelectedModel(cfg *config.Config) models.Model {
286	agentCfg := cfg.Agents[config.AgentCoder]
287	selectedModelId := agentCfg.Model
288	return models.SupportedModels[selectedModelId]
289}
290
291func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
292	var providers []models.ModelProvider
293	for providerId, provider := range cfg.Providers {
294		if !provider.Disabled {
295			providers = append(providers, providerId)
296		}
297	}
298
299	// Sort by provider popularity
300	slices.SortFunc(providers, func(a, b models.ModelProvider) int {
301		rA := models.ProviderPopularity[a]
302		rB := models.ProviderPopularity[b]
303
304		// models not included in popularity ranking default to last
305		if rA == 0 {
306			rA = 999
307		}
308		if rB == 0 {
309			rB = 999
310		}
311		return rA - rB
312	})
313	return providers
314}
315
316// findProviderIndex returns the index of the provider in the list, or -1 if not found
317func findProviderIndex(providers []models.ModelProvider, provider models.ModelProvider) int {
318	for i, p := range providers {
319		if p == provider {
320			return i
321		}
322	}
323	return -1
324}
325
326func (m *modelDialogCmp) setupModelsForProvider(provider models.ModelProvider) {
327	cfg := config.Get()
328	agentCfg := cfg.Agents[config.AgentCoder]
329	selectedModelId := agentCfg.Model
330
331	m.provider = provider
332	m.models = getModelsForProvider(provider)
333	m.selectedIdx = 0
334	m.scrollOffset = 0
335
336	// Try to select the current model if it belongs to this provider
337	if provider == models.SupportedModels[selectedModelId].Provider {
338		for i, model := range m.models {
339			if model.ID == selectedModelId {
340				m.selectedIdx = i
341				// Adjust scroll position to keep selected model visible
342				if m.selectedIdx >= numVisibleModels {
343					m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
344				}
345				break
346			}
347		}
348	}
349}
350
351func getModelsForProvider(provider models.ModelProvider) []models.Model {
352	var providerModels []models.Model
353	for _, model := range models.SupportedModels {
354		if model.Provider == provider {
355			providerModels = append(providerModels, model)
356		}
357	}
358
359	// reverse alphabetical order (if llm naming was consistent latest would appear first)
360	slices.SortFunc(providerModels, func(a, b models.Model) int {
361		if a.Name > b.Name {
362			return -1
363		} else if a.Name < b.Name {
364			return 1
365		}
366		return 0
367	})
368
369	return providerModels
370}
371
372func NewModelDialogCmp() ModelDialog {
373	return &modelDialogCmp{}
374}