models.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"strings"
  7
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	tea "github.com/charmbracelet/bubbletea/v2"
 10	"github.com/charmbracelet/lipgloss/v2"
 11	"github.com/opencode-ai/opencode/internal/config"
 12	"github.com/opencode-ai/opencode/internal/llm/models"
 13	"github.com/opencode-ai/opencode/internal/tui/layout"
 14	"github.com/opencode-ai/opencode/internal/tui/styles"
 15	"github.com/opencode-ai/opencode/internal/tui/theme"
 16	"github.com/opencode-ai/opencode/internal/tui/util"
 17)
 18
 19const (
 20	numVisibleModels = 10
 21	maxDialogWidth   = 40
 22)
 23
 24// ModelSelectedMsg is sent when a model is selected
 25type ModelSelectedMsg struct {
 26	Model models.Model
 27}
 28
 29// CloseModelDialogMsg is sent when a model is selected
 30type CloseModelDialogMsg struct{}
 31
 32// ModelDialog interface for the model selection dialog
 33type ModelDialog interface {
 34	util.Model
 35	layout.Bindings
 36}
 37
 38type modelDialogCmp struct {
 39	models             []models.Model
 40	provider           models.ModelProvider
 41	availableProviders []models.ModelProvider
 42
 43	selectedIdx     int
 44	width           int
 45	height          int
 46	scrollOffset    int
 47	hScrollOffset   int
 48	hScrollPossible bool
 49}
 50
 51type modelKeyMap struct {
 52	Up     key.Binding
 53	Down   key.Binding
 54	Left   key.Binding
 55	Right  key.Binding
 56	Enter  key.Binding
 57	Escape key.Binding
 58	J      key.Binding
 59	K      key.Binding
 60	H      key.Binding
 61	L      key.Binding
 62}
 63
 64var modelKeys = modelKeyMap{
 65	Up: key.NewBinding(
 66		key.WithKeys("up"),
 67		key.WithHelp("↑", "previous model"),
 68	),
 69	Down: key.NewBinding(
 70		key.WithKeys("down"),
 71		key.WithHelp("↓", "next model"),
 72	),
 73	Left: key.NewBinding(
 74		key.WithKeys("left"),
 75		key.WithHelp("←", "scroll left"),
 76	),
 77	Right: key.NewBinding(
 78		key.WithKeys("right"),
 79		key.WithHelp("→", "scroll right"),
 80	),
 81	Enter: key.NewBinding(
 82		key.WithKeys("enter"),
 83		key.WithHelp("enter", "select model"),
 84	),
 85	Escape: key.NewBinding(
 86		key.WithKeys("esc"),
 87		key.WithHelp("esc", "close"),
 88	),
 89	J: key.NewBinding(
 90		key.WithKeys("j"),
 91		key.WithHelp("j", "next model"),
 92	),
 93	K: key.NewBinding(
 94		key.WithKeys("k"),
 95		key.WithHelp("k", "previous model"),
 96	),
 97	H: key.NewBinding(
 98		key.WithKeys("h"),
 99		key.WithHelp("h", "scroll left"),
100	),
101	L: key.NewBinding(
102		key.WithKeys("l"),
103		key.WithHelp("l", "scroll right"),
104	),
105}
106
107func (m *modelDialogCmp) Init() tea.Cmd {
108	m.setupModels()
109	return nil
110}
111
112func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
113	switch msg := msg.(type) {
114	case tea.KeyMsg:
115		switch {
116		case key.Matches(msg, modelKeys.Up) || key.Matches(msg, modelKeys.K):
117			m.moveSelectionUp()
118		case key.Matches(msg, modelKeys.Down) || key.Matches(msg, modelKeys.J):
119			m.moveSelectionDown()
120		case key.Matches(msg, modelKeys.Left) || key.Matches(msg, modelKeys.H):
121			if m.hScrollPossible {
122				m.switchProvider(-1)
123			}
124		case key.Matches(msg, modelKeys.Right) || key.Matches(msg, modelKeys.L):
125			if m.hScrollPossible {
126				m.switchProvider(1)
127			}
128		case key.Matches(msg, modelKeys.Enter):
129			util.ReportInfo(fmt.Sprintf("selected model: %s", m.models[m.selectedIdx].Name))
130			return m, util.CmdHandler(ModelSelectedMsg{Model: m.models[m.selectedIdx]})
131		case key.Matches(msg, modelKeys.Escape):
132			return m, util.CmdHandler(CloseModelDialogMsg{})
133		}
134	case tea.WindowSizeMsg:
135		m.width = msg.Width
136		m.height = msg.Height
137	}
138
139	return m, nil
140}
141
142// moveSelectionUp moves the selection up or wraps to bottom
143func (m *modelDialogCmp) moveSelectionUp() {
144	if m.selectedIdx > 0 {
145		m.selectedIdx--
146	} else {
147		m.selectedIdx = len(m.models) - 1
148		m.scrollOffset = max(0, len(m.models)-numVisibleModels)
149	}
150
151	// Keep selection visible
152	if m.selectedIdx < m.scrollOffset {
153		m.scrollOffset = m.selectedIdx
154	}
155}
156
157// moveSelectionDown moves the selection down or wraps to top
158func (m *modelDialogCmp) moveSelectionDown() {
159	if m.selectedIdx < len(m.models)-1 {
160		m.selectedIdx++
161	} else {
162		m.selectedIdx = 0
163		m.scrollOffset = 0
164	}
165
166	// Keep selection visible
167	if m.selectedIdx >= m.scrollOffset+numVisibleModels {
168		m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
169	}
170}
171
172func (m *modelDialogCmp) switchProvider(offset int) {
173	newOffset := m.hScrollOffset + offset
174
175	// Ensure we stay within bounds
176	if newOffset < 0 {
177		newOffset = len(m.availableProviders) - 1
178	}
179	if newOffset >= len(m.availableProviders) {
180		newOffset = 0
181	}
182
183	m.hScrollOffset = newOffset
184	m.provider = m.availableProviders[m.hScrollOffset]
185	m.setupModelsForProvider(m.provider)
186}
187
188func (m *modelDialogCmp) View() string {
189	t := theme.CurrentTheme()
190	baseStyle := styles.BaseStyle()
191
192	// Capitalize first letter of provider name
193	providerName := strings.ToUpper(string(m.provider)[:1]) + string(m.provider[1:])
194	title := baseStyle.
195		Foreground(t.Primary()).
196		Bold(true).
197		Width(maxDialogWidth).
198		Padding(0, 0, 1).
199		Render(fmt.Sprintf("Select %s Model", providerName))
200
201	// Render visible models
202	endIdx := min(m.scrollOffset+numVisibleModels, len(m.models))
203	modelItems := make([]string, 0, endIdx-m.scrollOffset)
204
205	for i := m.scrollOffset; i < endIdx; i++ {
206		itemStyle := baseStyle.Width(maxDialogWidth)
207		if i == m.selectedIdx {
208			itemStyle = itemStyle.Background(t.Primary()).
209				Foreground(t.Background()).Bold(true)
210		}
211		modelItems = append(modelItems, itemStyle.Render(m.models[i].Name))
212	}
213
214	scrollIndicator := m.getScrollIndicators(maxDialogWidth)
215
216	content := lipgloss.JoinVertical(
217		lipgloss.Left,
218		title,
219		baseStyle.Width(maxDialogWidth).Render(lipgloss.JoinVertical(lipgloss.Left, modelItems...)),
220		scrollIndicator,
221	)
222
223	return baseStyle.Padding(1, 2).
224		Border(lipgloss.RoundedBorder()).
225		BorderBackground(t.Background()).
226		BorderForeground(t.TextMuted()).
227		Width(lipgloss.Width(content) + 4).
228		Render(content)
229}
230
231func (m *modelDialogCmp) getScrollIndicators(maxWidth int) string {
232	var indicator string
233
234	if len(m.models) > numVisibleModels {
235		if m.scrollOffset > 0 {
236			indicator += "↑ "
237		}
238		if m.scrollOffset+numVisibleModels < len(m.models) {
239			indicator += "↓ "
240		}
241	}
242
243	if m.hScrollPossible {
244		if m.hScrollOffset > 0 {
245			indicator = "← " + indicator
246		}
247		if m.hScrollOffset < len(m.availableProviders)-1 {
248			indicator += "→"
249		}
250	}
251
252	if indicator == "" {
253		return ""
254	}
255
256	t := theme.CurrentTheme()
257	baseStyle := styles.BaseStyle()
258
259	return baseStyle.
260		Foreground(t.Primary()).
261		Width(maxWidth).
262		Align(lipgloss.Right).
263		Bold(true).
264		Render(indicator)
265}
266
267func (m *modelDialogCmp) BindingKeys() []key.Binding {
268	return layout.KeyMapToSlice(modelKeys)
269}
270
271func (m *modelDialogCmp) setupModels() {
272	cfg := config.Get()
273	modelInfo := GetSelectedModel(cfg)
274	m.availableProviders = getEnabledProviders(cfg)
275	m.hScrollPossible = len(m.availableProviders) > 1
276
277	m.provider = modelInfo.Provider
278	m.hScrollOffset = findProviderIndex(m.availableProviders, m.provider)
279
280	m.setupModelsForProvider(m.provider)
281}
282
283func GetSelectedModel(cfg *config.Config) models.Model {
284	agentCfg := cfg.Agents[config.AgentCoder]
285	selectedModelId := agentCfg.Model
286	return models.SupportedModels[selectedModelId]
287}
288
289func getEnabledProviders(cfg *config.Config) []models.ModelProvider {
290	var providers []models.ModelProvider
291	for providerId, provider := range cfg.Providers {
292		if !provider.Disabled {
293			providers = append(providers, providerId)
294		}
295	}
296
297	// Sort by provider popularity
298	slices.SortFunc(providers, func(a, b models.ModelProvider) int {
299		rA := models.ProviderPopularity[a]
300		rB := models.ProviderPopularity[b]
301
302		// models not included in popularity ranking default to last
303		if rA == 0 {
304			rA = 999
305		}
306		if rB == 0 {
307			rB = 999
308		}
309		return rA - rB
310	})
311	return providers
312}
313
314// findProviderIndex returns the index of the provider in the list, or -1 if not found
315func findProviderIndex(providers []models.ModelProvider, provider models.ModelProvider) int {
316	for i, p := range providers {
317		if p == provider {
318			return i
319		}
320	}
321	return -1
322}
323
324func (m *modelDialogCmp) setupModelsForProvider(provider models.ModelProvider) {
325	cfg := config.Get()
326	agentCfg := cfg.Agents[config.AgentCoder]
327	selectedModelId := agentCfg.Model
328
329	m.provider = provider
330	m.models = getModelsForProvider(provider)
331	m.selectedIdx = 0
332	m.scrollOffset = 0
333
334	// Try to select the current model if it belongs to this provider
335	if provider == models.SupportedModels[selectedModelId].Provider {
336		for i, model := range m.models {
337			if model.ID == selectedModelId {
338				m.selectedIdx = i
339				// Adjust scroll position to keep selected model visible
340				if m.selectedIdx >= numVisibleModels {
341					m.scrollOffset = m.selectedIdx - (numVisibleModels - 1)
342				}
343				break
344			}
345		}
346	}
347}
348
349func getModelsForProvider(provider models.ModelProvider) []models.Model {
350	var providerModels []models.Model
351	for _, model := range models.SupportedModels {
352		if model.Provider == provider {
353			providerModels = append(providerModels, model)
354		}
355	}
356
357	// reverse alphabetical order (if llm naming was consistent latest would appear first)
358	slices.SortFunc(providerModels, func(a, b models.Model) int {
359		if a.Name > b.Name {
360			return -1
361		} else if a.Name < b.Name {
362			return 1
363		}
364		return 0
365	})
366
367	return providerModels
368}
369
370func NewModelDialogCmp() ModelDialog {
371	return &modelDialogCmp{}
372}