models.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"strings"
  7
  8	"github.com/charmbracelet/bubbles/key"
  9	tea "github.com/charmbracelet/bubbletea"
 10	"github.com/charmbracelet/lipgloss"
 11	"github.com/opencode-ai/opencode/internal/config"
 12	"github.com/opencode-ai/opencode/internal/llm/models"
 13	"github.com/opencode-ai/opencode/internal/tui/layout"
 14	"github.com/opencode-ai/opencode/internal/tui/styles"
 15	"github.com/opencode-ai/opencode/internal/tui/theme"
 16	"github.com/opencode-ai/opencode/internal/tui/util"
 17)
 18
 19const (
 20	numVisibleModels = 10
 21	maxDialogWidth   = 40
 22)
 23
 24// ModelSelectedMsg is sent when a model is selected
 25type ModelSelectedMsg struct {
 26	Model models.Model
 27}
 28
 29// CloseModelDialogMsg is sent when a model is selected
 30type CloseModelDialogMsg struct{}
 31
 32// ModelDialog interface for the model selection dialog
 33type ModelDialog interface {
 34	tea.Model
 35	layout.Bindings
 36}
 37
 38type modelDialogCmp struct {
 39	models             []models.Model
 40	provider           models.ModelProvider
 41	availableProviders []models.ModelProvider
 42
 43	selectedIdx     int
 44	width           int
 45	height          int
 46	scrollOffset    int
 47	hScrollOffset   int
 48	hScrollPossible bool
 49}
 50
 51type modelKeyMap struct {
 52	Up     key.Binding
 53	Down   key.Binding
 54	Left   key.Binding
 55	Right  key.Binding
 56	Enter  key.Binding
 57	Escape key.Binding
 58	J      key.Binding
 59	K      key.Binding
 60	H      key.Binding
 61	L      key.Binding
 62}
 63
 64var modelKeys = modelKeyMap{
 65	Up: key.NewBinding(
 66		key.WithKeys("up"),
 67		key.WithHelp("↑", "previous model"),
 68	),
 69	Down: key.NewBinding(
 70		key.WithKeys("down"),
 71		key.WithHelp("↓", "next model"),
 72	),
 73	Left: key.NewBinding(
 74		key.WithKeys("left"),
 75		key.WithHelp("←", "scroll left"),
 76	),
 77	Right: key.NewBinding(
 78		key.WithKeys("right"),
 79		key.WithHelp("→", "scroll right"),
 80	),
 81	Enter: key.NewBinding(
 82		key.WithKeys("enter"),
 83		key.WithHelp("enter", "select model"),
 84	),
 85	Escape: key.NewBinding(
 86		key.WithKeys("esc"),
 87		key.WithHelp("esc", "close"),
 88	),
 89	J: key.NewBinding(
 90		key.WithKeys("j"),
 91		key.WithHelp("j", "next model"),
 92	),
 93	K: key.NewBinding(
 94		key.WithKeys("k"),
 95		key.WithHelp("k", "previous model"),
 96	),
 97	H: key.NewBinding(
 98		key.WithKeys("h"),
 99		key.WithHelp("h", "scroll left"),
100	),
101	L: key.NewBinding(
102		key.WithKeys("l"),
103		key.WithHelp("l", "scroll right"),
104	),
105}
106
107func (m *modelDialogCmp) Init() tea.Cmd {
108	m.setupModels()
109	return nil
110}
111
112func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
113	switch msg := msg.(type) {
114	case tea.KeyMsg:
115		switch {
116		case key.Matches(msg, modelKeys.Up) || key.Matches(msg, modelKeys.K):
117			m.moveSelectionUp()
118		case key.Matches(msg, modelKeys.Down) || key.Matches(msg, modelKeys.J):
119			m.moveSelectionDown()
120		case key.Matches(msg, modelKeys.Left) || key.Matches(msg, modelKeys.H):
121			if m.hScrollPossible {
122				m.switchProvider(-1)
123			}
124		case key.Matches(msg, modelKeys.Right) || key.Matches(msg, modelKeys.L):
125			if m.hScrollPossible {
126				m.switchProvider(1)
127			}
128		case key.Matches(msg, modelKeys.Enter):
129			util.ReportInfo(fmt.Sprintf("selected model: %s", m.models[m.selectedIdx].Name))
130			return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]})
131		case key.Matches(msg, modelKeys.Escape):
132			return m, util.CmdHandler(CloseModelDialogMsg{})
133		}
134	case tea.WindowSizeMsg:
135		m.width = msg.Width
136		m.height = msg.Height
137	}
138
139	return m, nil
140}
141
142// moveSelectionUp moves the selection up or wraps to bottom
143func (m *modelDialogCmp) moveSelectionUp() {
144	if m.selectedIdx > 0 {
145		m.selectedIdx--
146	} else {
147		m.selectedIdx = len(m.models) - 1
148		m.scrollOffset = max(0, len(m.models)-numVisibleModels)
149	}
150
151	// Keep selection visible
152	if m.selectedIdx < m.scrollOffset {
153		m.scrollOffset = m.selectedIdx
154	}
155}
156
157// moveSelectionDown moves the selection down or wraps to top
158func (m *modelDialogCmp) moveSelectionDown() {
159	if m.selectedIdx < len(m.models)-1 {
160		m.selectedIdx++
161	} else {
162		m.selectedIdx = 0
163		m.scrollOffset = 0
164	}
165
166	// Keep selection visible
167	if m.selectedIdx >= m.scrollOffset+numVisibleModels {
168		m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
169	}
170}
171
172func (m *modelDialogCmp) switchProvider(offset int) {
173	newOffset := m.hScrollOffset + offset
174
175	// Ensure we stay within bounds
176	if newOffset < 0 {
177		newOffset = len(m.availableProviders) - 1
178	}
179	if newOffset >= len(m.availableProviders) {
180		newOffset = 0
181	}
182
183	m.hScrollOffset = newOffset
184	m.provider = m.availableProviders[m.hScrollOffset]
185	m.setupModelsForProvider(m.provider)
186}
187
188func (m *modelDialogCmp) View() string {
189	t := theme.CurrentTheme()
190	baseStyle := styles.BaseStyle()
191
192	// Capitalize first letter of provider name
193	providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:])
194	title := baseStyle.
195		Foreground(t.Primary()).
196		Bold(true).
197		Width(maxDialogWidth).
198		Padding(0, 0, 1).
199		Render(fmt.Sprintf("Select %s Model", providerName))
200
201	// Render visible models
202	endIdx := min(m.scrollOffset+numVisibleModels, len(m.models))
203	modelItems := make([]string, 0, endIdx-m.scrollOffset)
204
205	for i := m.scrollOffset; i < endIdx; i++ {
206		itemStyle := baseStyle.Width(maxDialogWidth)
207		if i == m.selectedIdx {
208			itemStyle = itemStyle.Background(t.Primary()).
209				Foreground(t.Background()).Bold(true)
210		}
211		modelItems = append(modelItems, itemStyle.Render(m.models[i].Name))
212	}
213
214	scrollIndicator := m.getScrollIndicators(maxDialogWidth)
215
216	content := lipgloss.JoinVertical(
217		lipgloss.Left,
218		title,
219		baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
220		scrollIndicator,
221	)
222
223	return baseStyle.Padding(1, 2).
224		Border(lipgloss.RoundedBorder()).
225		BorderBackground(t.Background()).
226		BorderForeground(t.TextMuted()).
227		Width(lipgloss.Width(content) + 4).
228		Render(content)
229}
230
231func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
232	var indicator string
233
234	if len(m.models) > numVisibleModels {
235		if m.scrollOffset > 0 {
236			indicator += "↑ "
237		}
238		if m.scrollOffset+numVisibleModels < len(m.models) {
239			indicator += "↓ "
240		}
241	}
242
243	if m.hScrollPossible {
244		if m.hScrollOffset > 0 {
245			indicator = "← " + indicator
246		}
247		if m.hScrollOffset < len(m.availableProviders)-1 {
248			indicator += "→"
249		}
250	}
251
252	if indicator == "" {
253		return ""
254	}
255
256	t := theme.CurrentTheme()
257	baseStyle := styles.BaseStyle()
258
259	return baseStyle.
260		Foreground(t.Primary()).
261		Width(maxWidth).
262		Align(lipgloss.Right).
263		Bold(true).
264		Render(indicator)
265}
266
267func (m *modelDialogCmp) BindingKeys() []key.Binding {
268	return layout.KeyMapToSlice(modelKeys)
269}
270
271func (m *modelDialogCmp) setupModels() {
272	cfg := config.Get()
273	modelInfo := GetSelectedModel(cfg)
274	m.availableProviders = getEnabledProviders(cfg)
275	m.hScrollPossible = len(m.availableProviders) > 1
276
277	m.provider = modelInfo.Provider
278	m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
279
280	m.setupModelsForProvider(m.provider)
281}
282
283func GetSelectedModel(cfg *config.Config) models.Model {
284
285	agentCfg := cfg.Agents[config.AgentCoder]
286	selectedModelId := agentCfg.Model
287	return models.SupportedModels[selectedModelId]
288}
289
290func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
291	var providers []models.ModelProvider
292	for providerId, provider := range cfg.Providers {
293		if !provider.Disabled {
294			providers = append(providers, providerId)
295		}
296	}
297
298	// Sort by provider popularity
299	slices.SortFunc(providers, func(a, b models.ModelProvider) int {
300		rA := models.ProviderPopularity[a]
301		rB := models.ProviderPopularity[b]
302
303		// models not included in popularity ranking default to last
304		if rA == 0 {
305			rA = 999
306		}
307		if rB == 0 {
308			rB = 999
309		}
310		return rA - rB
311	})
312	return providers
313}
314
315// findProviderIndex returns the index of the provider in the list, or -1 if not found
316func findProviderIndex(providers []models.ModelProvider, provider models.ModelProvider) int {
317	for i, p := range providers {
318		if p == provider {
319			return i
320		}
321	}
322	return -1
323}
324
325func (m *modelDialogCmp) setupModelsForProvider(provider models.ModelProvider) {
326	cfg := config.Get()
327	agentCfg := cfg.Agents[config.AgentCoder]
328	selectedModelId := agentCfg.Model
329
330	m.provider = provider
331	m.models = getModelsForProvider(provider)
332	m.selectedIdx = 0
333	m.scrollOffset = 0
334
335	// Try to select the current model if it belongs to this provider
336	if provider == models.SupportedModels[selectedModelId].Provider {
337		for i, model := range m.models {
338			if model.ID == selectedModelId {
339				m.selectedIdx = i
340				// Adjust scroll position to keep selected model visible
341				if m.selectedIdx >= numVisibleModels {
342					m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
343				}
344				break
345			}
346		}
347	}
348}
349
350func getModelsForProvider(provider models.ModelProvider) []models.Model {
351	var providerModels []models.Model
352	for _, model := range models.SupportedModels {
353		if model.Provider == provider {
354			providerModels = append(providerModels, model)
355		}
356	}
357
358	// reverse alphabetical order (if llm naming was consistent latest would appear first)
359	slices.SortFunc(providerModels, func(a, b models.Model) int {
360		if a.Name > b.Name {
361			return -1
362		} else if a.Name < b.Name {
363			return 1
364		}
365		return 0
366	})
367
368	return providerModels
369}
370
371func NewModelDialogCmp() ModelDialog {
372	return &modelDialogCmp{}
373}