models.go

  1package models
  2
  3import (
  4	"github.com/charmbracelet/bubbles/v2/help"
  5	"github.com/charmbracelet/bubbles/v2/key"
  6	tea "github.com/charmbracelet/bubbletea/v2"
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/tui/components/completions"
 10	"github.com/charmbracelet/crush/internal/tui/components/core"
 11	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 12	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15	"github.com/charmbracelet/lipgloss/v2"
 16)
 17
 18const (
 19	ModelsDialogID dialogs.DialogID = "models"
 20
 21	defaultWidth = 60
 22)
 23
 24const (
 25	LargeModelType int = iota
 26	SmallModelType
 27)
 28
 29// ModelSelectedMsg is sent when a model is selected
 30type ModelSelectedMsg struct {
 31	Model     config.SelectedModel
 32	ModelType config.SelectedModelType
 33}
 34
 35// CloseModelDialogMsg is sent when a model is selected
 36type CloseModelDialogMsg struct{}
 37
 38// ModelDialog interface for the model selection dialog
 39type ModelDialog interface {
 40	dialogs.DialogModel
 41}
 42
 43type ModelOption struct {
 44	Provider provider.Provider
 45	Model    provider.Model
 46}
 47
 48type modelDialogCmp struct {
 49	width   int
 50	wWidth  int
 51	wHeight int
 52
 53	modelList *ModelListComponent
 54	keyMap    KeyMap
 55	help      help.Model
 56}
 57
 58func NewModelDialogCmp() ModelDialog {
 59	listKeyMap := list.DefaultKeyMap()
 60	keyMap := DefaultKeyMap()
 61
 62	listKeyMap.Down.SetEnabled(false)
 63	listKeyMap.Up.SetEnabled(false)
 64	listKeyMap.HalfPageDown.SetEnabled(false)
 65	listKeyMap.HalfPageUp.SetEnabled(false)
 66	listKeyMap.Home.SetEnabled(false)
 67	listKeyMap.End.SetEnabled(false)
 68
 69	listKeyMap.DownOneItem = keyMap.Next
 70	listKeyMap.UpOneItem = keyMap.Previous
 71
 72	t := styles.CurrentTheme()
 73	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 74	modelList := NewModelListComponent(listKeyMap, inputStyle)
 75	help := help.New()
 76	help.Styles = t.S().Help
 77
 78	return &modelDialogCmp{
 79		modelList: modelList,
 80		width:     defaultWidth,
 81		keyMap:    DefaultKeyMap(),
 82		help:      help,
 83	}
 84}
 85
 86func (m *modelDialogCmp) Init() tea.Cmd {
 87	return m.modelList.Init()
 88}
 89
 90func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 91	switch msg := msg.(type) {
 92	case tea.WindowSizeMsg:
 93		m.wWidth = msg.Width
 94		m.wHeight = msg.Height
 95		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
 96	case tea.KeyPressMsg:
 97		switch {
 98		case key.Matches(msg, m.keyMap.Select):
 99			selectedItemInx := m.modelList.SelectedIndex()
100			if selectedItemInx == list.NoSelection {
101				return m, nil
102			}
103			items := m.modelList.Items()
104			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
105
106			var modelType config.SelectedModelType
107			if m.modelList.GetModelType() == LargeModelType {
108				modelType = config.LargeModel
109			} else {
110				modelType = config.SelectedModelTypeSmall
111			}
112
113			return m, tea.Sequence(
114				util.CmdHandler(dialogs.CloseDialogMsg{}),
115				util.CmdHandler(ModelSelectedMsg{
116					Model: config.SelectedModel{
117						Model:    selectedItem.Model.ID,
118						Provider: string(selectedItem.Provider.ID),
119					},
120					ModelType: modelType,
121				}),
122			)
123		case key.Matches(msg, m.keyMap.Tab):
124			if m.modelList.GetModelType() == LargeModelType {
125				return m, m.modelList.SetModelType(SmallModelType)
126			} else {
127				return m, m.modelList.SetModelType(LargeModelType)
128			}
129		case key.Matches(msg, m.keyMap.Close):
130			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
131		default:
132			u, cmd := m.modelList.Update(msg)
133			m.modelList = u
134			return m, cmd
135		}
136	}
137	return m, nil
138}
139
140func (m *modelDialogCmp) View() tea.View {
141	t := styles.CurrentTheme()
142	listView := m.modelList.View()
143	radio := m.modelTypeRadio()
144	content := lipgloss.JoinVertical(
145		lipgloss.Left,
146		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
147		listView.String(),
148		"",
149		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
150	)
151	v := tea.NewView(m.style().Render(content))
152	if listView.Cursor() != nil {
153		c := m.moveCursor(listView.Cursor())
154		v.SetCursor(c)
155	}
156	return v
157}
158
159func (m *modelDialogCmp) style() lipgloss.Style {
160	t := styles.CurrentTheme()
161	return t.S().Base.
162		Width(m.width).
163		Border(lipgloss.RoundedBorder()).
164		BorderForeground(t.BorderFocus)
165}
166
167func (m *modelDialogCmp) listWidth() int {
168	return defaultWidth - 2 // 4 for padding
169}
170
171func (m *modelDialogCmp) listHeight() int {
172	items := m.modelList.Items()
173	listHeigh := len(items) + 2 + 4
174	return min(listHeigh, m.wHeight/2)
175}
176
177func (m *modelDialogCmp) Position() (int, int) {
178	row := m.wHeight/4 - 2 // just a bit above the center
179	col := m.wWidth / 2
180	col -= m.width / 2
181	return row, col
182}
183
184func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
185	row, col := m.Position()
186	offset := row + 3 // Border + title
187	cursor.Y += offset
188	cursor.X = cursor.X + col + 2
189	return cursor
190}
191
192func (m *modelDialogCmp) ID() dialogs.DialogID {
193	return ModelsDialogID
194}
195
196func (m *modelDialogCmp) modelTypeRadio() string {
197	t := styles.CurrentTheme()
198	choices := []string{"Large Task", "Small Task"}
199	iconSelected := "◉"
200	iconUnselected := "○"
201	if m.modelList.GetModelType() == LargeModelType {
202		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
203	}
204	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
205}