models.go

  1package models
  2
  3import (
  4	"github.com/charmbracelet/bubbles/v2/help"
  5	"github.com/charmbracelet/bubbles/v2/key"
  6	tea "github.com/charmbracelet/bubbletea/v2"
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/tui/components/completions"
 10	"github.com/charmbracelet/crush/internal/tui/components/core"
 11	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 12	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 13	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 14	"github.com/charmbracelet/crush/internal/tui/styles"
 15	"github.com/charmbracelet/crush/internal/tui/util"
 16	"github.com/charmbracelet/lipgloss/v2"
 17)
 18
 19const (
 20	ModelsDialogID dialogs.DialogID = "models"
 21
 22	defaultWidth = 60
 23)
 24
 25// ModelSelectedMsg is sent when a model is selected
 26type ModelSelectedMsg struct {
 27	Model config.PreferredModel
 28}
 29
 30// CloseModelDialogMsg is sent when a model is selected
 31type CloseModelDialogMsg struct{}
 32
 33// ModelDialog interface for the model selection dialog
 34type ModelDialog interface {
 35	dialogs.DialogModel
 36}
 37
 38type ModelOption struct {
 39	Provider provider.Provider
 40	Model    provider.Model
 41}
 42
 43type modelDialogCmp struct {
 44	width   int
 45	wWidth  int // Width of the terminal window
 46	wHeight int // Height of the terminal window
 47
 48	modelList list.ListModel
 49	keyMap    KeyMap
 50	help      help.Model
 51}
 52
 53func NewModelDialogCmp() ModelDialog {
 54	listKeyMap := list.DefaultKeyMap()
 55	keyMap := DefaultKeyMap()
 56
 57	listKeyMap.Down.SetEnabled(false)
 58	listKeyMap.Up.SetEnabled(false)
 59	listKeyMap.HalfPageDown.SetEnabled(false)
 60	listKeyMap.HalfPageUp.SetEnabled(false)
 61	listKeyMap.Home.SetEnabled(false)
 62	listKeyMap.End.SetEnabled(false)
 63
 64	listKeyMap.DownOneItem = keyMap.Next
 65	listKeyMap.UpOneItem = keyMap.Previous
 66
 67	t := styles.CurrentTheme()
 68	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 69	modelList := list.New(
 70		list.WithFilterable(true),
 71		list.WithKeyMap(listKeyMap),
 72		list.WithInputStyle(inputStyle),
 73		list.WithWrapNavigation(true),
 74	)
 75	help := help.New()
 76	help.Styles = t.S().Help
 77
 78	return &modelDialogCmp{
 79		modelList: modelList,
 80		width:     defaultWidth,
 81		keyMap:    DefaultKeyMap(),
 82		help:      help,
 83	}
 84}
 85
 86func (m *modelDialogCmp) Init() tea.Cmd {
 87	providers := config.Providers()
 88
 89	modelItems := []util.Model{}
 90	selectIndex := 0
 91	agentModel := config.GetAgentModel(config.AgentCoder)
 92	agentProvider := config.GetAgentProvider(config.AgentCoder)
 93	for _, provider := range providers {
 94		name := provider.Name
 95		if name == "" {
 96			name = string(provider.ID)
 97		}
 98		modelItems = append(modelItems, commands.NewItemSection(name))
 99		for _, model := range provider.Models {
100			if model.ID == agentModel.ID && provider.ID == agentProvider.ID {
101				selectIndex = len(modelItems) // Set the selected index to the current model
102			}
103			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
104				Provider: provider,
105				Model:    model,
106			}))
107		}
108	}
109
110	return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
111}
112
113func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
114	switch msg := msg.(type) {
115	case tea.WindowSizeMsg:
116		m.wWidth = msg.Width
117		m.wHeight = msg.Height
118		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
119	case tea.KeyPressMsg:
120		switch {
121		case key.Matches(msg, m.keyMap.Select):
122			selectedItemInx := m.modelList.SelectedIndex()
123			if selectedItemInx == list.NoSelection {
124				return m, nil // No item selected, do nothing
125			}
126			items := m.modelList.Items()
127			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
128
129			return m, tea.Sequence(
130				util.CmdHandler(dialogs.CloseDialogMsg{}),
131				util.CmdHandler(ModelSelectedMsg{Model: config.PreferredModel{
132					ModelID:  selectedItem.Model.ID,
133					Provider: selectedItem.Provider.ID,
134				}}),
135			)
136		case key.Matches(msg, m.keyMap.Close):
137			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
138		default:
139			u, cmd := m.modelList.Update(msg)
140			m.modelList = u.(list.ListModel)
141			return m, cmd
142		}
143	}
144	return m, nil
145}
146
147func (m *modelDialogCmp) View() tea.View {
148	t := styles.CurrentTheme()
149	listView := m.modelList.View()
150	content := lipgloss.JoinVertical(
151		lipgloss.Left,
152		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-4)),
153		listView.String(),
154		"",
155		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
156	)
157	v := tea.NewView(m.style().Render(content))
158	if listView.Cursor() != nil {
159		c := m.moveCursor(listView.Cursor())
160		v.SetCursor(c)
161	}
162	return v
163}
164
165func (m *modelDialogCmp) style() lipgloss.Style {
166	t := styles.CurrentTheme()
167	return t.S().Base.
168		Width(m.width).
169		Border(lipgloss.RoundedBorder()).
170		BorderForeground(t.BorderFocus)
171}
172
173func (m *modelDialogCmp) listWidth() int {
174	return defaultWidth - 2 // 4 for padding
175}
176
177func (m *modelDialogCmp) listHeight() int {
178	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
179	return min(listHeigh, m.wHeight/2)
180}
181
182func (m *modelDialogCmp) Position() (int, int) {
183	row := m.wHeight/4 - 2 // just a bit above the center
184	col := m.wWidth / 2
185	col -= m.width / 2
186	return row, col
187}
188
189func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
190	row, col := m.Position()
191	offset := row + 3 // Border + title
192	cursor.Y += offset
193	cursor.X = cursor.X + col + 2
194	return cursor
195}
196
197func (m *modelDialogCmp) ID() dialogs.DialogID {
198	return ModelsDialogID
199}